Conquer-Divide的经典例子之Strassen算法解决大型矩阵的相乘

通过汉诺塔问题理解递归的精髓中我讲解了怎么把一个复杂的问题一步步recursively划分了成简单显而易见的小问题。其实这个解决问题的思路就是算法中常用的divide
and conquer, 这篇日志通过解决矩阵的乘法,来了解另外一个基本divide and
conque思想的strassen算法。

矩阵A乘以B等于X, 则Xij = 
注意左乘右乘的区别,AB
与BA是不同的。
如果r = 1, 直接就是两个数的相乘。
如果r = 2, 例如
X = 
[ 1,
2; 
  3, 4];
Y = 
[ 2, 3;
 4, 5];
R =
XY的计算十分简单,但是如果r很大,耗时是O(r^3)。为了简化,可以把X,
Y各自划分成2X2的矩阵,每一个元素其实是有n/2行的矩阵
(注:这里仅讲解行数等于列数的情况。)

X = 
[A,
B;
C, D];

Y = 
[E, F;
G, H]

所以XY =[
AE+BG,
AF+BH;
CE+DG, CF+DH]

Strassen引入seven magic product 分别是P1, P2, P3 ,P4,
P5, P6, P7
P1 = A(F-H)
P2 = (A+B)H
P3 = (C+D)E
P4 = D(G-E)
P5 =
(A+D)(E+H)
P6 = (B-D)(G+H)
P7 = (A-C)(E+F)

这样XY

[P5+P4-P2+P6, P1+P2;
P3+P4,
P1+P5-P3-P7]

然后通过递归的策略计算矩阵的相乘,递归的出口是n = 1.

关键点就是这些,附上代码吧。

[java] view
plain
copy

    1. //multiply matrix multiplication

    2. import java.util.Scanner;

    3. public class Strassen{

    4. public Strassen(){}
    5. /** split a parent matrix into child matrics8*/

    6. public static void split(int[][] P, int[][] C, int iB, int jB){

    7. for(int i1=0, i2 = iB; i1<C.length; i1++, i2++)

    8. for(int j1=0, j2=jB; j1<C.length; j1++, j2++)

    9. C[i1][j1] = P[i2][j2];

    10. }
    11. /**join child matric into parent matrix*/

    12. public static void join(int[][] C, int[][] P, int iB, int jB){

    13. for(int i1=0, i2 = iB; i1<C.length; i1++, i2++)

    14. for(int j1=0, j2=jB; j1<C.length; j1++, j2++)

    15. P[i2][j2]=C[i1][j1];

    16. }
    17. /**add two matrics into one*/

    18. public static int[][] add(int[][] A, int[][] B){

    19. //A and B has the same dimension

    20. int n = A.length;

    21. int[][] C = new int[n][n];

    22. for (int i=0; i<n; i++)

    23. for(int j=0; j<n; j++)

    24. C[i][j] = A[i][j] + B[i][j];
    25. return C;

    26. }
    27. //subtract one matric by another

    28. public static int[][] sub(int[][] A, int[][] B){

    29. //A and B has the same dimension

    30. int n = A.length;

    31. int[][] C = new int[n][n];

    32. for (int i=0; i<n; i++)

    33. for(int j=0; j<n; j++)

    34. C[i][j] = A[i][j] - B[i][j];

    35. return C;

    36. }
    37. //Multiply matrix

    38. public static int[][] multiply(int[][] A, int[][] B){

    39. int n = A.length;

    40. int[][] R = new int[n][n];
    41. /**exit*/

    42. if(n==1)

    43. R[0][0] = A[0][0]+B[0][0];
    44. else{

    45. //divide A into 4 submatrix

    46. int[][] A11 = new int[n/2][n/2];

    47. int[][] A12 = new int[n/2][n/2];

    48. int[][] A21 = new int[n/2][n/2];

    49. int[][] A22 = new int[n/2][n/2];
    50. split(A, A11, 0, 0);

    51. split(A, A12, 0, n/2);

    52. split(A, A21, n/2, 0);

    53. split(A, A22, n/2, n/2);
    54. //divide B into 4 submatric

    55. int[][] B11 = new int[n/2][n/2];

    56. int[][] B12 = new int[n/2][n/2];

    57. int[][] B21 = new int[n/2][n/2];

    58. int[][] B22 = new int[n/2][n/2];
    59. split(B, B11, 0, 0);

    60. split(B, B12, 0, n/2);

    61. split(B, B21, n/2, 0);

    62. split(B, B22, n/2, n/2);
    63. //seven magic products

    64. int[][] P1 = multiply(A11, sub(B12, B22));

    65. int[][] P2 = multiply(add(A11,A12), B22);

    66. int[][] P3 = multiply(add(A21, A22), B11);

    67. int[][] P4 = multiply(A22, sub(B21, B11));

    68. int[][] P5 = multiply(add(A11, A22), add(B11, B22));

    69. int[][] P6 = multiply(sub(A12, A22), add(B21, B22));

    70. int[][] P7 = multiply(sub(A11, A21), add(B11, B12));
    71. //new 4 submatrix

    72. int[][] R11 = add(add(P5, sub(P4, P2)), P6);

    73. int[][] R12 = add(P1, P2);

    74. int[][] R21 = add(P3, P4);

    75. int[][] R22 = sub(sub(add(P1, P5), P3), P7);
    76. //joint together

    77. join(R11, R, 0, 0);

    78. join(R12, R, 0, n/2);

    79. join(R21, R, n/2, 0);

    80. join(R22, R, n/2, n/2);
    81. }

    82. return R;

    83. }
    84. //main

    85. public static void main(String[] args){
    86. Scanner scan = new Scanner(System.in);

    87. System.out.println("Strassen Multiplication Algorithm Test\n");

    88. Strassen s = new Strassen();
    89. System.out.println("Fetch the matric A and B...");

    90. int N = scan.nextInt();

    91. int[][] A = new int[N][N];

    92. int[][] B = new int[N][N];
    93. for (int i = 0; i < N; i++)

    94. for (int j = 0; j < N; j++)

    95. A[i][j] = scan.nextInt();
    96. for (int i = 0; i < N; i++)

    97. for (int j = 0; j < N; j++)

    98. B[i][j] = scan.nextInt();
    99. System.out.println("Fetch Completed!");
    100. int[][] C = s.multiply(A, B);
    101. System.out.println("\nmatrices A = ");

    102. for (int i = 0; i < N; i++){

    103. for (int j = 0; j < N; j++)

    104. System.out.print(A[i][j] +" ");

    105. System.out.println();

    106. }
    107. System.out.println("\nmatrices B =");

    108. for (int i = 0; i < N; i++) {

    109. for (int j = 0; j < N; j++)

    110. System.out.print(B[i][j] +" ");

    111. System.out.println();

    112. }
    113. System.out.println("\nProduct of matrices A and  B  = ");

    114. for (int i = 0; i < N; i++)

    115. {

    116. for (int j = 0; j < N; j++)

    117. System.out.print(C[i][j] +" ");

    118. System.out.println();

    119. }

    120. }

    121. }

Conquer-Divide的经典例子之Strassen算法解决大型矩阵的相乘

时间: 2024-10-12 15:41:30

Conquer-Divide的经典例子之Strassen算法解决大型矩阵的相乘的相关文章

Strassen算法及其python实现

题目描述 请编程实现矩阵乘法,并考虑当矩阵规模较大时的优化方法. 思路分析 根据wikipedia上的介绍:两个矩阵的乘法仅当第一个矩阵B的列数和另一个矩阵A的行数相等时才能定义.如A是m×n矩阵和B是n×p矩阵,它们的乘积AB是一个m×p矩阵,它的一个元素其中 1 ≤ i ≤ m, 1 ≤ j ≤ p. 值得一提的是,矩阵乘法满足结合律和分配率,但并不满足交换律,如下图所示的这个例子,两个矩阵交换相乘后,结果变了: 下面咱们来具体解决这个矩阵相乘的问题. 解法一.暴力解法 其实,通过前面的分析

矩阵乘法的Strassen算法详解

题目描述 请编程实现矩阵乘法,并考虑当矩阵规模较大时的优化方法. 思路分析 根据wikipedia上的介绍:两个矩阵的乘法仅当第一个矩阵B的列数和另一个矩阵A的行数相等时才能定义.如A是m×n矩阵和B是n×p矩阵,它们的乘积AB是一个m×p矩阵,它的一个元素其中 1 ≤ i ≤ m, 1 ≤ j ≤ p. 值得一提的是,矩阵乘法满足结合律和分配率,但并不满足交换律,如下图所示的这个例子,两个矩阵交换相乘后,结果变了: 下面咱们来具体解决这个矩阵相乘的问题. 解法一.暴力解法 其实,通过前面的分析

小猪的数据结构辅助教程——2.5 经典例子:约瑟夫问题的解决

小猪的数据结构辅助教程--2.5 经典例子:约瑟夫问题的解决 标签(空格分隔): 数据结构 约瑟夫问题的解析 关于问题的故事背景就不提了,我们直接说这个问题的内容吧: 一堆人,围成一个圈,然后规定一个数N,然后依次报数,当报数到N,这个人自杀,其他人鼓掌!啪啪啪, 接着又从1开始报数,报到N又自杀-以此类推,直到死剩最后一个人,那么游戏结束! 这就是问题,而我们用计算机模拟的话,用户输入:N(参与人数),M(第几个人死),结果返回最后一个人! 类似的问题有跳海问题,猴子选王等,下面我们就以N =

经典的十个机器学习算法

1.C4.5 机器学习中,决策树是一个预测模型:他代表的是对象属性与对象值之间的一种映射关系.树中每个节点表示某个对象,而每个分叉路径则代表的某个可能的 属性值,而每个叶结点则对应从根节点到该叶节点所经历的路径所表示的对象的值.决策树仅有单一输出,若欲有复数输出,可以建立独立的决策树以处理不同输 出. 从数据产生决策树的机器学习技术叫做决策树学习, 通俗说就是决策树. 决策树学习也是数据挖掘中一个普通的方法.在这里,每个决策树都表述了一种树型结构,他由他的分支来对该类型的对象依靠属性进行分类.每

递归的几个经典例子

注意:构造方法不可递归,否则是无限创建对象; 递归的几个经典例子: 1.HannoiTower 1 import java.util.Scanner; 2 public class HanoiTower{ 3 //level代表盘子个数;三个char类型代表柱子 4 public static void moveDish(int level, char from, char inter, char to){ 5 if(level == 1){ 6 System.out.println("从&qu

矩阵乘法的Strassen算法及时间复杂度

[问题]普通方法计算矩阵相乘,时间复杂度为O(n^3),请设计优化算法. [Strassen算法] [时间复杂度]

信号量基础和两个经典例子

信号量基础和两个经典例子 信号量(semaphore) 用于进程中传递信号的一个整数值. 三个操作: 1.一个信号量可以初始化为非负值 2.semWait操作可以使信号量减1,若信号量的值为负,则执行semWait的进程被阻塞.否则进程继续执行. 3.semSignal操作使信号量加1.若信号量的值小于等于0,则被semWait操作阻塞的进程讲被接触阻塞. ps: semWait对应P原语,semSignal对应V原语. 信号量以及PV原语的C语言定义如下 struct semaphore {

python经典例子

http://wangwei007.blog.51cto.com/68019/1106735  检查Linux系统日志error和mysql错误日志的脚本 http://wangwei007.blog.51cto.com/68019/1102836  pickle http://wangwei007.blog.51cto.com/68019/1045577  python用zipfile模块打包文件或是目录.解压zip文件实例 http://blog.163.com/kefan_1987/blo

storm经典例子的wordcount的实现

storm有个经典的例子wordcount,其实这几乎可以说是大数据的经典例子了,mapreduce也会有这个例子.但是storm给的例子包里的WordCountTopology用到了python的调用,直接用eclipse跑起来的话会报错,这里做了个小改动. 1.WordCountTopology.java package storm.starter; import backtype.storm.Config; import backtype.storm.LocalCluster; impor