资讯详情

Strassen矩阵乘法(C++)

思路

两个矩阵A,B相乘时.有三种方法

三个for循环, 时间复杂度为O(n^3).因为Cij=∑(k=1->n)Aik*Bkj,需要循环, 且C中有n^2个元素, 因此,时间复杂度为O(n^3)

首先将A,B,C分为相等大小的方块矩阵.

所以C11=A11*B11 A12*B21, C12=A11*B12 A12*B22,

C21=A21*B11 A22*B21, C22=A21*B12 A22*B22

用T(n)表示n*n矩阵乘法, 所以有T(n)=8T(n/2) Θ(n^2). 其中, 8T(n/2)表示8次子矩阵乘法, 子矩阵的规模为n/2 * n/2. θ(n^2)表示4次矩阵加法的时间复杂性和合并C矩阵的时间复杂性.最后结果是Θ(n^3)与暴力计算时间的复杂性相同.

,可以优化时间复杂度O(n^log7).

现在重新定义7个新矩阵

M1=(A11 A22)*(B11 B22)

M2=(A21 A22)*B11

M3=A11*(B12-B22)

M4=A22*(B21-B11)

M5=(A11 A12)*B22

M6=(A21-A11)*(B11 B12)

M7=(A12-A22)*(B21 B22)

结果矩阵C可以组合上述矩阵,如下

C11=M1 M4-M5 M7

C12=M3 M5

C21=M2 M4

C22=M1-M2 M3 M6

此时共有7次乘法,18次加减法. 写递推公式T(n)=7T(n/2) Θ(n^2). 最终结果是O(n^log7)=O(n^2.807).

#include <bits/stdc  .h>  using namespace std;  // 暴力求解矩阵相乘 void MUL(int** MatrixA,int** MatrixB,int** MatrixResult,int Msize){     for(int i=0;i<Msize;i  ){         for(int j=0;j<Msize;j  ){             MatrixResult[i][j]=0;             for(int k=0;k<Msize;k  ){                 MatrixResult[i][j] =MatrixA[i][k]*MatrixB[k][j];             }         }     } }  // 矩阵相加运算 void ADD(int** MatrixA,int** MatrixB,int** MatrixResult,int Msize){     for(int i=0;i<Msize;i  ){         for(int j=0;j<Msize;j  ){             MatrixResult[i][j]=MatrixA[i][j] MatrixB[i][j];         }     } }  // 矩阵相减运算 void SUB(int** MatrixA,int** MatrixB,int** MatrixResult,int Msize){     for(int i=0;i<Msize;i  ){         for(int j=0;j<Msize;j  ){             MatrixResult[i][j]=MatrixA[i][j]-MatrixB[i][j];         }     } }  // Strassen算法 void Strassen(int N,int** MatrixA,int** MatrixB,int** MatrixC){     int halfSize=N/2;     if(N<=2){         MUL(MatrixA,MatrixB,MatrixC,N);     }     else{         // 创建二维数组指针         int** A11;         int** A12;         int** A21;         int** A22;          int** B11;         int** B12;         int** B21;         int** B22;          int** C11;         int** C12;         int** C21;         int** C22;          int** M1;         int** M2;         int** M3;         int** M4;         int** M5;         int** M6;         int** M7;         int** AResult;         int** BResult;         // 初始化         A11=new int*[halfSize];         A12=new int*[halfSize];         A21=new int*[halfSize];         A22=new int*[halfSize];          B11=new int*[halfSize];         B12=new int*[halfSize];         B21=new int*[halfSize];         B22=new int*[halfSize];          C11=new int*[halfSize];         C12=new int*[halfSize];         C21=new int*[halfSize];         C22=new int*[halfSize];          M1=new int*[halfSize];         M2=new int*[halfSize];         M3=new int*[halfSize];         M4=new int*[halfSize];         M5=new int*[halfSize];         M6=new int*[halfSize];         M7=new int*[halfSize];         AResult=new int*[halfSize];         BResult=new int*[halfSize];          for(int i=0;i<halfSize;i  ){             A11[i]=new int[halfSize];             A12[i]=new int[halfSize];             A21[i]=new int[halfSize];             A22[i]=new int[halfSize];              B11[i]=new int[halfSize];             B12[i]=new int[halfSize];             B21[i]=new int[halfSize];             B22[i]=new int[halfSize];              C11[i]=new int[halfSize];             C12[i]=new int[halfSize];             C21[i]=new int[halfSize];             C22[i]=new int[halfSize];              M1[i]=new int[halfSize];             M2[i]=new int[halfSize];             M3[i]=new int[halfSize];             M4[i]=new int[halfSize];             M5[i]=new int[halfSize];             M6[i]=new int[halfSize];             M7[i]=new int[halfSize];              AResult[i]=new int[halfSize];             BResult[i]=new int[halfSize];         }          // 把MatrixA和MatrixB分块         for(int i=0;i<N/2;i  ){             for(int j=0;j<N/2;j  ){                 A11[i][j]=MatrixA[i][j];                 A12[i][j]=MatrixA[i][j N/2];                 A21[i][j]=MatrixA[i N/2][j];                 A22[i][j]=MatrixA[i N/2][j N/2];                  B11[i][j]=MatrixB[i][j];                 B12[i][j]=MatrixB[i][j N/2];                 B21[i][j]=MatrixB[i N/2][j];                 B22[i][j]=MatrixB[i N/2][j N/2];             }         }          // M1=(A11 A22)*(B11 B22)         ADD(A11,A22,AResult,halfSize);         ADD(B11,B22,BResult,halfSize);         Strassen(halfSize,AResult,BResult,M1);          // M2=(A21 A22)*B11         ADD(A21,A22,AResult,halfSize);         Strassen(halfSiz,AResult,B11,M2);

        // M3=A11*(B12-B22)
        SUB(B12,B22,BResult,halfSize);
        Strassen(halfSize,A11,BResult,M3);

        // M4=A22*(B21-B11)
        SUB(B21,B11,BResult,halfSize);
        Strassen(halfSize,A22,BResult,M4);

        // M5=(A11+A12)B22
        ADD( A11, A12, AResult, halfSize);
        Strassen(halfSize, AResult, B22, M5);

        // M6=(A21-A11)*(B11+B12)
        SUB( A21, A11, AResult, halfSize);
        ADD( B11, B12, BResult, halfSize);
        Strassen( halfSize, AResult, BResult, M6);

        // M7=(A12-A22)*(B21+B22)
        SUB(A12, A22, AResult, halfSize);
        ADD(B21, B22, BResult, halfSize);
        Strassen(halfSize, AResult, BResult, M7);

        // C11=M1+M4-M5+M7
        ADD( M1, M4, AResult, halfSize);
        SUB( M7, M5, BResult, halfSize);
        ADD( AResult, BResult, C11, halfSize);

        // C12=M3+M5
        ADD( M3, M5, C12, halfSize);

        // C21=M2+M4
        ADD( M2, M4, C21, halfSize);

        // C22=M1-M2+M3+M6
        ADD( M1, M3, AResult, halfSize);
        SUB( M6, M2, BResult, halfSize);
        ADD( AResult, BResult, C22, halfSize);

        // 把C11,C12,C21,C22矩阵合并成一个大矩阵MatrixC
        for(int i=0;i<N/2;i++){
            for(int j=0;j<N/2;j++){
                MatrixC[i][j]=C11[i][j];
                MatrixC[i][j+N/2]=C12[i][j];
                MatrixC[i+N/2][j]=C21[i][j];
                MatrixC[i+N/2][j+N/2]=C22[i][j];
            }
        }

        // 释放空间
        for (int i = 0; i < halfSize; i++)
        {
            delete[] A11[i];delete[] A12[i];delete[] A21[i];
            delete[] A22[i];

            delete[] B11[i];delete[] B12[i];delete[] B21[i];
            delete[] B22[i];
            delete[] C11[i];delete[] C12[i];delete[] C21[i];
            delete[] C22[i];
            delete[] M1[i];delete[] M2[i];delete[] M3[i];delete[] M4[i];
            delete[] M5[i];delete[] M6[i];delete[] M7[i];
            delete[] AResult[i];delete[] BResult[i] ;
        }
        delete[] A11;delete[] A12;delete[] A21;delete[] A22;
        delete[] B11;delete[] B12;delete[] B21;delete[] B22;
        delete[] C11;delete[] C12;delete[] C21;delete[] C22;
        delete[] M1;delete[] M2;delete[] M3;delete[] M4;delete[] M5;
        delete[] M6;delete[] M7;
        delete[] AResult;
        delete[] BResult;
    }

}

int main()
{
    int MSize;
    cin >> MSize;

    // 定义三个矩阵
    int** MatrixA;
    int** MatrixB;
    int** MatrixC;

    // 初始化三个矩阵
    MatrixA=new int*[MSize];
    MatrixB=new int*[MSize];
    MatrixC=new int*[MSize];
    for(int i=0;i<MSize;i++){
        MatrixA[i]=new int[MSize];
        MatrixB[i]=new int[MSize];
        MatrixC[i]=new int[MSize];
    }

    // 输入相乘的矩阵
    for(int i=0;i<MSize;i++){
        for(int j=0;j<MSize;j++){
            cin >> MatrixA[i][j];
        }
    }
    for(int i=0;i<MSize;i++){
        for(int j=0;j<MSize;j++){
            cin >> MatrixB[i][j];
        }
    }

    Strassen(MSize,MatrixA,MatrixB,MatrixC);

    // 打印输出结果矩阵
    for(int i=0;i<MSize;i++){
        for(int j=0;j<MSize;j++){
            cout << MatrixC[i][j] << " ";
        }
        cout << endl;
    }

    return 0;
}


/* 一组数据
4
1 2 4 7
8 3 6 5
4 7 2 1
6 4 3 1
1 2 4 7
8 3 6 5
4 7 2 1
6 4 3 1
*/

标签: a21传感器微型对射传感器

锐单商城拥有海量元器件数据手册IC替代型号,打造 电子元器件IC百科大全!

锐单商城 - 一站式电子元器件采购平台