矩阵乘积计算(Strassen)

问题描述

​ 已知A,B两个矩阵计算其乘积C?

矩阵乘积数学公式:

​ 假设存在两个矩阵A为m×n矩阵,B为k×l矩阵,若需要计算AB则必须n=k,若需要计算BA必须l=m否则无法进行计算,先假定n=k即B为n×l矩阵则AB的结果为一个m×l的矩阵并且该矩阵每个点的元素的值表示为CijCij则:

方法一:直接计算

​ 直接利用多重for循环求出相关矩阵对应的点的值即可

//矩阵的数据结构,随机矩阵,非特殊矩阵
struct array
{int **data;                 //数据域int row;int col;
};/***  初始化矩阵元素,用随机数填充*  只为研究算法因此为进行相关的内存检查*  flag用来标记是否生成空矩阵,即元素全部为0的矩阵*/
void init_array(struct array *ptr,const int row,const int col,int flag)
{int i = 0,j = 0;ptr->data = (int **)malloc(sizeof(int)*row);                                    //??内存分配for (i = 0; i < row; i++){*(ptr->data + i) = (int*)malloc(sizeof(int)*col);                           //??内存分配}ptr->col = col;ptr->row = row;srand(time(NULL));for (i = 0; i < row; i++){for (j = 0; j < col; j++){if (flag){ptr->data[i][j] = rand() % ARRAY_PRCE;}else{ptr->data[i][j] = 0;}}}
}/***  打印矩阵元素*/
void print_array(const struct array *ptr, const char *msg)
{int i, j;printf("%s\n", msg);for (i = 0; i < ptr->row; i++){for (j = 0; j < ptr->col; j++){printf("%4d", ptr->data[i][j]);}printf("\n");}
}
/***  销毁内存*/
void delete_array(struct array *ptr)
{int i = 0;for (i = 0; i < ptr->row; i++){free(*(ptr->data + i));*(ptr->data + i) = NULL;}free(ptr->data);
}/**************************************************************************************************************************************************************************************//**
*   矩阵乘法求解
*   问题描述:已知两个可以进行相乘的矩阵,求的乘积后的结果
*//**
*   方法一:暴力直接求解
*   利用矩阵乘法规则直接进行求解罗列出每个点的值求的最终的矩阵
*/
struct array mult_array(const struct array *ptr1, const struct array *ptr2)
{int i = 0;int j = 0;int k = 0;struct array ptr;if (ptr1->col != ptr2->row)                         //检查是否符合可以进行乘积的要求{return;}init_array(&ptr, ptr1->row, ptr2->col, 0);for (i = 0; i < ptr.row;i ++){for (j = 0; j < ptr.col; j++){for (k = 0; k < ptr1->col; k++){ptr.data[i][j] += ptr1->data[i][k] * ptr2->data[k][j];}}}return ptr;
}

执行效果

时间复杂度为O(n3n^3)

方法二:分治算法

​ 将矩阵分解为一个个小矩阵进行计算然后将计算结果合并得到相关的结果。源于矩阵服从分配率和结合律,并不支持交换律。

​ 三个矩阵本身就可以写成下面的格式

​ 那么相关的计算可以写成

​ 同理A11等一些子矩阵也可以写成相关的子矩阵,就这样将矩阵不断分解为小矩阵进行计算,最后归并为一个矩阵。

​ 时间复杂度为O(n3n^3)

/***  方法二:利用分治思想进行求解*  存在问题无法解决不同类型的矩阵的问题,要求矩阵的行列必须为2的n次方,若不符合要求可以使用*  补0来构造相关的矩阵*/
Matrix* Matrix::merge_calc(const Matrix& x)
{if (x.row == 1)             //当前的矩阵为单个的元素{Matrix *ptr = new Matrix(x.row, x.col);ptr->clear((this->getElem(0, 0))*(x.getElem(0, 0)));return ptr;}//将第一个矩阵分解为四个子矩阵Matrix A11(0,0,row/2,col/2,*this);Matrix A12(row / 2, 0, row, col / 2, *this);Matrix A21(0, col / 2, row / 2, col, *this);Matrix A22(row / 2, col / 2, row, col, *this);//将第二个矩阵分解为四个子矩阵Matrix B11(0, 0, row / 2, col / 2, x);Matrix B12(row / 2, 0, row, col / 2, x);Matrix B21(0, col / 2, row / 2, col, x);Matrix B22(row / 2, col / 2, row, col, x);Matrix *C11 = Matrix::add(A11.merge_calc(B11), A12.merge_calc(B21));Matrix *C12 = Matrix::add(A11.merge_calc(B12), A12.merge_calc(B22));Matrix *C21 = Matrix::add(A21.merge_calc(B11), A22.merge_calc(B21));Matrix *C22 = Matrix::add(A21.merge_calc(B12), A22.merge_calc(B22));//将C11,C12,C21,C22合并为一个完整的矩阵Matrix* ptr = Matrix::merge(C11, C12, C21, C22);return ptr;
}

方法三:Strassen算法

​ Strassen算法同样是使用分治的思想解决问题,只不过,不同的是当矩阵的阶很大时就会采取一个递推式进行计算相关递推式为:

                            S1 = B12 - B22S2 = A11 + A12S3 = A21 + A22S4 = B21 - B11S5 = A11 + A22S6 = B11 + B22S7 = A12 - A22S8 = B21 + B22S9 = A11 - A21S10 = B11 + B12 
                            P1 = A11 * S1P2 = S2 * B22P3 = S3 * B11P4 = A22 * S4P5 = S5 * S6P6 = S7 * S8P7 = S9 * S10
                            C11 = P5 + P4 - P2 + P6C12 = P1 + P2C21 = P3 + P4C22 = P5 + P1 - P3 - P7

​ 其中A11,A12,A21,A22和B11,B12,B21,B22分别为两个乘数A和B矩阵的四个子矩阵。C11,C12,C21,C22为最终的结果C矩阵的四个子矩阵。该递推式是被数学家证明过的。

​ 该算法的效率为O(n(log27)n^(log_27)),但是相对来说额外空间的使用也是很多的。

Matrix* Matrix::strassen_calc(const Matrix& x)
{if (x.row < 2){return this->force_calc(x);}//将第一个矩阵分解为四个子矩阵Matrix A11(0, 0, row / 2, col / 2, *this);Matrix A12(row / 2, 0, row, col / 2, *this);Matrix A21(0, col / 2, row / 2, col, *this);Matrix A22(row / 2, col / 2, row, col, *this);//将第二个矩阵分解为四个子矩阵Matrix B11(0, 0, row / 2, col / 2, x);Matrix B12(row / 2, 0, row, col / 2, x);Matrix B21(0, col / 2, row / 2, col, x);Matrix B22(row / 2, col / 2, row, col, x);Matrix* S1 = B12 - B22;Matrix* S2 = A11 + A12;Matrix* S3 = A21 + A22;Matrix* S4 = B21 - B11;Matrix* S5 = A11 + A22;Matrix* S6 = B11 + B22;Matrix* S7 = A12 - A22;Matrix* S8 = B21 + B22;Matrix* S9 = A11 - A21;Matrix* S10 = B11 + B12;Matrix* P1 = B12 - B22;Matrix* P2 = B12 - B22;Matrix* P3 = B12 - B22;Matrix* P4 = B12 - B22;Matrix* P5 = B12 - B22;Matrix* P6 = B12 - B22;Matrix* P7 = B12 - B22;P1 = A11.strassen_calc(*S1);P2 = S2->strassen_calc(B22);P3 = S3->strassen_calc(B11);P4 = A22.strassen_calc(*S4);P5 = S5->strassen_calc(*S6);P6 = S7->strassen_calc(*S8);P7 = S9->strassen_calc(*S10);Matrix *C11 = Matrix::sub(Matrix::add(P5, P4), Matrix::sub(P2, P6));Matrix *C12 = Matrix::add(P1, P2);Matrix *C21 = Matrix::add(P3, P4);Matrix *C22 = Matrix::sub(Matrix::add(P5, P1), Matrix::add(P3, P7));return Matrix::merge(C11,C12,C21,C22);
}

执行效果:

完整的代码

//Matrix.h
#pragma once
#ifndef _MATRIX_H_
#define _MATRIX_H_#include <iostream>
#include <vector>
using std::vector;
#define VISE 5
#define GATE 16                 //用来限定使用哪种算法进行计算#include <cstdlib>
#include <ctime>typedef int type;
class Matrix
{
private:int row;                            //行int col;                            //列vector<vector<type>> data;          //数据
public:Matrix(int row, int col) :data(row),row(row),col(col)                   //矩阵数据生成利用随机数进行生成{for (int i = 0; i < row; i++){data[i].resize(col);}srand(time(0));for (int i = 0; i < row; i++){for (int j = 0; j < col; j++){data[i][j] = rand() % VISE;}}}Matrix(int row1, int col1, int row2, int col2, const Matrix& x) :row(row2 - row1), col(col2 - col1),data(row){for (int i = 0; i < row; i++){data[i].resize(col);}for (int i = 0; i < row; i++){for (int j = 0; j < col; j++){data[i][j] = x.getElem(col1 + i, row1 + j);}}}Matrix(const Matrix& x){*this = x;}//相关算数运算操作Matrix* operator+(const Matrix&);Matrix* operator-(const Matrix&);Matrix* operator*(const Matrix&);static Matrix* add(const Matrix*, const Matrix*);                               //+static Matrix* sub(const Matrix*, const Matrix*);                               //-static Matrix* merge(const Matrix*, const Matrix*,const Matrix*,const Matrix*); //将四个子矩阵合并为一个矩阵//获取矩阵的相关元素vector<type> operator[](const int);             //取得rowtype getElem(const int,const int) const;        //获取相关节点的数据void setElem(const int, const int, type);       //设置节点的数据   //计算乘法的算法Matrix* force_calc(const Matrix&);              //直接暴力求解Matrix* merge_calc(const Matrix&);              //分治求解Matrix* strassen_calc(const Matrix&);           //Strassen算法void show();                                    //打印矩阵bool isSimilar(const Matrix& x);                //行列相同即为同类型矩阵void clear(type);                               //设置矩阵中所有的元素为同一个指定的值                                ~Matrix();
};#endif
//_MATRIX_H_
//Matrix.cpp
#include "Matrix.h"Matrix* Matrix::operator*(const Matrix& x)
{if (x.row != this->col){return nullptr;}if (x.row < VISE && x.col < VISE && row < VISE && col < VISE){return this->merge_calc(x);}return this->strassen_calc(x);
}/***  方法一:暴力直接求解问题*  时间复杂度为O(n^3)*/
Matrix* Matrix::force_calc(const Matrix& x)
{if (x.row != this->col)                                             //行列不同无法进行乘法,可以进行补零将相关矩阵填充为可使用的矩阵{                                                                   //这里不进行相关的编写return nullptr;}Matrix *ptr = new Matrix(row, x.col);ptr->clear(0);for (int i = 0; i < row; i++){for (int j = 0; j < x.col; j++){for (int k = 0; k < col; k++){ptr->setElem(i, j, ptr->getElem(i, j) + getElem(i, k) * x.getElem(k, j));}}}return ptr;
}void Matrix::clear(type cur = 0)
{for (int i = 0; i < row; i++){for (int j = 0; j < col; j++){data[i][j] = cur;}}
}/***  方法二:利用分治思想进行求解*  存在问题无法解决不同类型的矩阵的问题,要求矩阵的行列必须为2的n次方,若不符合要求可以使用*  补0来构造相关的矩阵*/
Matrix* Matrix::merge_calc(const Matrix& x)
{if (x.row == 1)             //当前的矩阵为单个的元素{Matrix *ptr = new Matrix(x.row, x.col);ptr->clear((this->getElem(0, 0))*(x.getElem(0, 0)));return ptr;}//将第一个矩阵分解为四个子矩阵Matrix A11(0,0,row/2,col/2,*this);Matrix A12(row / 2, 0, row, col / 2, *this);Matrix A21(0, col / 2, row / 2, col, *this);Matrix A22(row / 2, col / 2, row, col, *this);//将第二个矩阵分解为四个子矩阵Matrix B11(0, 0, row / 2, col / 2, x);Matrix B12(row / 2, 0, row, col / 2, x);Matrix B21(0, col / 2, row / 2, col, x);Matrix B22(row / 2, col / 2, row, col, x);Matrix *C11 = Matrix::add(A11.merge_calc(B11), A12.merge_calc(B21));Matrix *C12 = Matrix::add(A11.merge_calc(B12), A12.merge_calc(B22));Matrix *C21 = Matrix::add(A21.merge_calc(B11), A22.merge_calc(B21));Matrix *C22 = Matrix::add(A21.merge_calc(B12), A22.merge_calc(B22));//将C11,C12,C21,C22合并为一个完整的矩阵Matrix* ptr = Matrix::merge(C11, C12, C21, C22);return ptr;
}Matrix* Matrix::strassen_calc(const Matrix& x)
{if (x.row < 2){return this->force_calc(x);}//将第一个矩阵分解为四个子矩阵Matrix A11(0, 0, row / 2, col / 2, *this);Matrix A12(row / 2, 0, row, col / 2, *this);Matrix A21(0, col / 2, row / 2, col, *this);Matrix A22(row / 2, col / 2, row, col, *this);//将第二个矩阵分解为四个子矩阵Matrix B11(0, 0, row / 2, col / 2, x);Matrix B12(row / 2, 0, row, col / 2, x);Matrix B21(0, col / 2, row / 2, col, x);Matrix B22(row / 2, col / 2, row, col, x);Matrix* S1 = B12 - B22;Matrix* S2 = A11 + A12;Matrix* S3 = A21 + A22;Matrix* S4 = B21 - B11;Matrix* S5 = A11 + A22;Matrix* S6 = B11 + B22;Matrix* S7 = A12 - A22;Matrix* S8 = B21 + B22;Matrix* S9 = A11 - A21;Matrix* S10 = B11 + B12;Matrix* P1 = B12 - B22;Matrix* P2 = B12 - B22;Matrix* P3 = B12 - B22;Matrix* P4 = B12 - B22;Matrix* P5 = B12 - B22;Matrix* P6 = B12 - B22;Matrix* P7 = B12 - B22;P1 = A11.strassen_calc(*S1);P2 = S2->strassen_calc(B22);P3 = S3->strassen_calc(B11);P4 = A22.strassen_calc(*S4);P5 = S5->strassen_calc(*S6);P6 = S7->strassen_calc(*S8);P7 = S9->strassen_calc(*S10);Matrix *C11 = Matrix::sub(Matrix::add(P5, P4), Matrix::sub(P2, P6));Matrix *C12 = Matrix::add(P1, P2);Matrix *C21 = Matrix::add(P3, P4);Matrix *C22 = Matrix::sub(Matrix::add(P5, P1), Matrix::add(P3, P7));return Matrix::merge(C11,C12,C21,C22);
}/***  将四个子矩阵合并为一个完整的矩阵*  也可以使用分治思想进行解决,以后可能会添加相关的功能*/
Matrix* Matrix::merge(const Matrix* p1, const Matrix* p2,const Matrix* p3, const Matrix* p4)
{//不符合可以进行合并的条件if (!(p1->row == p2->row && p2->col == p4->col && p4->row == p3->row && p1->col == p3->col)){return nullptr;}Matrix* ptr = new Matrix(p1->row + p3->row, p2->col + p1->col);ptr->clear(0);//重新装值for (int i = 0; i < p1->row; i++){for (int j = 0; j < p1->col; j++){ptr->setElem(i, j, p1->getElem(i, j));}}for (int i = 0; i < p2->row; i++){for (int j = 0; j < p2->col; j++){ptr->setElem(i, j + p1->col, p2->getElem(i, j));}}for (int i = 0; i < p3->row; i++){for (int j = 0; j < p3->col; j++){ptr->setElem(i + p1->row, j, p3->getElem(i, j));}}for (int i = 0; i < p4->row; i++){for (int j = 0; j < p4->col; j++){ptr->setElem(p1->row + i, p1->col + j, p4->getElem(i, j));}}return ptr;
}Matrix* Matrix::sub(const Matrix* p1, const Matrix* p2)
{if (!(p1->col == p2->col && p1->row == p2->row)){return nullptr;}Matrix *ptr = new Matrix(p1->row, p1->col);for (int i = 0; i < p1->row; i++){for (int j = 0; j < p1->col; j++){ptr->setElem(i, j, (p1->getElem(i, j) - p2->getElem(i, j)));}}return ptr;
}Matrix* Matrix::add(const Matrix* p1, const Matrix* p2)
{if (!(p1->col == p2->col && p1->row == p2->row)){return nullptr;}Matrix *ptr = new Matrix(p1->row, p1->col);for (int i = 0; i < p1->row; i++){for (int j = 0; j < p1->col; j++){ptr->setElem(i, j, (p1->getElem(i, j) + p2->getElem(i, j)));}}return ptr;
}Matrix* Matrix::operator+(const Matrix& x)
{if (!isSimilar(x)){return nullptr;}Matrix *ptr = new Matrix(x.row, x.col);                             //内存需要释放for (int i = 0; i < row; i++){for (int j = 0; j < col; j++){ptr->setElem(i, j, this->getElem(i, j) + x.getElem(i, j));}}return ptr;
}Matrix* Matrix::operator-(const Matrix& x)
{if (!isSimilar(x)){return nullptr;}Matrix *ptr = new Matrix(x.row, x.col);                             //内存需要释放for (int i = 0; i < row; i++){for (int j = 0; j < col; j++){ptr->setElem(i, j, this->getElem(i, j) - x.getElem(i, j));}}return ptr;
}vector<type> Matrix::operator[](const int row)
{return data[row];
}type Matrix::getElem(int row, int col)const
{return this->data[row][col];
}void Matrix::setElem(int row, int col, type cur)
{this->data[row][col] = cur;
}void Matrix::show()
{for (int i = 0; i < row; i++){if (i == 0){std::cout << "┏";}else if (i == row - 1){std::cout << "┗";}else{std::cout << "┃";}for (int j = 0; j < col; j++){std::cout.width(4);std::cout << data[i][j];}if (i == 0){std::cout << "   ┓";}else if (i == row - 1){std::cout << "   ┛";}else{std::cout << "   ┃";}std::cout << std::endl;}
}bool Matrix::isSimilar(const Matrix& x)
{return x.row == this->row && this->col == x.col;
}Matrix::~Matrix()
{this->row = 0;this->col = 0;
}

矩阵乘积计算(Strassen)相关推荐

  1. 离散数学·(不调用第三方库)普通矩阵乘积/关系矩阵乘积,理论+python代码实现

    矩阵乘法如何计算? 普通矩阵乘法:第一个矩阵的列数等于第二个矩阵的行数. 矩阵关系运算前提: (1)第一个矩阵的列数等于第二个矩阵的行数. (2)两个矩阵的元素均是0或1. 这里以关系矩阵乘法为例: ...

  2. C语言数组使用、数组相关的宏定义剖析,及矩阵乘积、杨辉三角实例

         数组一直是编程语言学习中需要重点掌握的部分,它是最基本的数据结构类型,是构成字符串及其他许多重要构造结构的基础.相对许多其他高级语言来说,C语言对数组本身提供的支持并不太多,尤其是不支持动态 ...

  3. Lua计算kronecker 积、Khatri-Rao积、Hadamard积、普通矩阵乘积

    Lua计算kronecker 积.Khatri-Rao积.Hadamard积.普通矩阵乘积 function Kron(A,B,mark)local C ={}if mark==0thenrowC=r ...

  4. Python矩阵计算类:计算矩阵加和、矩阵乘积、矩阵转置、矩阵行列式值、伴随矩阵和逆矩阵

    最近在Python程序设计中遇到一道设计矩阵计算类的题目,原题目要求计算矩阵加和和矩阵乘积,而我出于设计和挑战自己的目的,为自己增加难度,因此设计出矩阵计算类,不仅可以求出矩阵加和和矩阵乘积,还能计算 ...

  5. 矩阵相乘的strassen算法_矩阵乘法的Strassen算法+动态规划算法(矩阵链相乘和硬币问题)...

    矩阵乘法的Strassen 这个算法就是在矩阵乘法中采用分治法,能够有效的提高算法的效率. 先来看看咱们在高等代数中学的普通矩阵的乘法 两个矩阵相乘 上边这种普通求解方法的复杂度为: O(n3) 也称 ...

  6. 矩阵相乘的strassen算法_4-2.矩阵乘法的Strassen算法详解

    题目描述 请编程实现矩阵乘法,并考虑当矩阵规模较大时的优化方法. 思路分析 根据wikipedia上的介绍:两个矩阵的乘法仅当第一个矩阵B的列数和另一个矩阵A的行数相等时才能定义.如A是m×n矩阵和B ...

  7. 矩阵乘法计算速度再次突破极限,我炼丹能更快了吗?| 哈佛、MIT

    梦晨 发自 凹非寺 量子位 报道 | 公众号 QbitAI n阶矩阵乘法最优解的时间复杂度再次被突破,达到了. 按定义直接算的话,时间复杂度是O(n³). 光这么说可能不太直观,从图上可以看出,n足够 ...

  8. 将矩阵转为一行_矩阵与矩阵乘积简介

    作者|Hadrien Jean 编译|VK 来源|Towards Data Science 原文链接:https://towardsdatascience.com/introduction-to-ma ...

  9. ZZULIOJ 1127: 矩阵乘积

    矩阵乘积 题目描述 计算两个矩阵A和B的乘积. 输入 第一行三个正整数m.p和n,0<=m,n,p<=10,表示矩阵A是m行p列,矩阵B是p行n列: 接下来的m行是矩阵A的内容,每行p个整 ...

最新文章

  1. My97DatePicker日历控件日报、每周和每月的选择
  2. is NULL , is NOT NULL 有时索引失效 || in 走索引, not in 索引失效 ||单列索引和复合索引 || 查看索引使用情况
  3. 以非root 用戶安裝並啟動高級單服務器版
  4. 机器学习第6天:数据可视化神器--Matplotlib
  5. 深入理解ThreadLocal
  6. ASP.NET MVC入门(一)---MVC的Hello World
  7. FreeBSD9.1安装Gnome2桌面
  8. Flask笔记-使用flask-sqlacodegen自动生成model
  9. 扒一扒那些奇葩的甲方吧
  10. Atom 备份神器 —— Sync Settings
  11. linux系统shell脚本编程,Linux系统shell脚本编程(一)
  12. html留言板代码_接口测试平台代码实现19.首页优化
  13. Pr常见问题,pr素材脱机后该如何恢复?
  14. C++引用(作为函数参数和返回值)
  15. Bayesian framework 贝叶斯框架 (R)
  16. 电脑wifi 找不到网络怎么办
  17. 向量法计算体积的思路(没有代码了)
  18. 成为会带团队的技术人 跨团队:没有汇报线的人和事就是推不动?
  19. 微信CRM六大模块详解
  20. mysql的UNIX_TIMESTAMP用法

热门文章

  1. Access表数据类型/字段类型
  2. linuxRC的含义
  3. 【无标题】置信规则库研究现状,研究知识图谱,研究大全一览
  4. UE Base64图片格式 的加载显示方式
  5. php 取整 floor,php 取整函数(floor,ceil,round,intval)
  6. 初一计算机课怎么上,初中信息技术七年级上册《初识计算机》公开课PPT课件
  7. Android移动应用开发之TextView实现阴影跑马灯文字效果
  8. 常用音频接口:TDM,PDM,I2S,PCM
  9. 国外网络开放课程遍地开花
  10. 如何生成指定分布的随机数