最近实现一个算法需要用到求逆等矩阵运算，在网上搜到一个别人写的矩阵类，试了一下效果不错，贴在这里做个保存。
matrix.h文件:
#ifndef MATRIX_H_INCLUDED
#define MATRIX_H_INCLUDED

#include <iostream>
#include <fstream>
#include <sstream>
#include <vector>
#include <string>

// NOTE(review): using-declarations in a header leak into every file that
// includes it. They are kept because matrix.cpp uses these names unqualified.
using std::vector;
using std::string;
using std::cout;
using std::cin;
using std::istream;
using std::ostream;

// Dense matrix of an arbitrary element type, stored as a vector of rows.
//
// The member functions below keep all rows the same length. The non-const
// operator[] hands out a whole row, so a caller that resizes a row through it
// can break that invariant.
template <typename Object>
class MATRIX
{
public:
    // Creates an empty (0 x 0) matrix.
    MATRIX() : array() {}

    // Creates a rows x cols matrix whose elements are value-initialized
    // (0 for arithmetic element types).
    MATRIX(int rows, int cols) : array(rows)
    {
        for (int i = 0; i < rows; ++i)
        {
            array[i].resize(cols);
        }
    }

    // Copying, assignment and destruction are the compiler-generated
    // member-wise versions; the only data member is a std::vector.

    void resize(int rows, int cols);          // change the size, keeping overlapping elements
    bool push_back(const vector<Object>& v);  // append a row; false if its length does not match
    void swap_row(int row1, int row2);        // exchange two rows; ignored if an index is invalid

    int rows() const { return static_cast<int>(array.size()); }
    int cols() const { return array.empty() ? 0 : static_cast<int>(array[0].size()); }
    bool empty() const { return rows() == 0; }                    // true if there are no rows
    bool square() const { return !empty() && rows() == cols(); }  // true for a non-empty n x n matrix

    // Row access; indices are not bounds-checked.
    const vector<Object>& operator[](int row) const { return array[row]; }
    vector<Object>& operator[](int row) { return array[row]; }

protected:
    vector< vector<Object> > array;
};

// Resizes the matrix to rows x cols. Elements inside the overlap of the old
// and new sizes keep their values; newly created elements are
// value-initialized. Every row ends up with exactly `cols` elements.
template <typename Object>
void MATRIX<Object>::resize(int rows, int cols)
{
    array.resize(rows);
    for (int i = 0; i < rows; ++i)
    {
        array[i].resize(cols);
    }
}

// Appends `v` as a new last row. An empty matrix accepts a row of any length;
// otherwise the row length must equal cols(). Returns false (and leaves the
// matrix unchanged) on a length mismatch.
template <typename Object>
bool MATRIX<Object>::push_back(const vector<Object>& v)
{
    if (rows() != 0 && cols() != static_cast<int>(v.size()))
    {
        return false;
    }
    array.push_back(v);
    return true;
}

// Exchanges rows row1 and row2. Does nothing if the indices are equal or
// either one is out of range.
template <typename Object>
void MATRIX<Object>::swap_row(int row1, int row2)
{
    if (row1 != row2 && row1 >= 0 && row1 < rows() && row2 >= 0 && row2 < rows())
    {
        array[row1].swap(array[row2]);  // O(1): swaps the row buffers, no element copies
    }
}

// Returns the transpose of a generic matrix (an empty matrix for empty input).
template <typename Object>
const MATRIX<Object> trans(const MATRIX<Object>& m)
{
    MATRIX<Object> ret;
    if (m.empty()) return ret;

    const int row = m.cols();
    const int col = m.rows();
    ret.resize(row, col);

    for (int i = 0; i < row; ++i)
    {
        for (int j = 0; j < col; ++j)
        {
            ret[i][j] = m[j][i];
        }
    }
    return ret;
}

//////////////////////////////////////////////////////////
// Matrix of doubles for numerical work, derived from MATRIX<double>.
// Adds arithmetic operators plus determinant, inverse and LU decomposition
// (declared below, defined in matrix.cpp).
class Matrix : public MATRIX<double>
{
public:
    Matrix() : MATRIX<double>() {}
    Matrix(int rows, int cols) : MATRIX<double>(rows, cols) {}

    // Element-wise += / -=, matrix product *=, and multiplication by the
    // inverse /=. Size requirements are checked in matrix.cpp.
    const Matrix& operator+=(const Matrix& m);
    const Matrix& operator-=(const Matrix& m);
    const Matrix& operator*=(const Matrix& m);
    const Matrix& operator/=(const Matrix& m);
};

bool operator==(const Matrix& lhs, const Matrix& rhs);                // element-wise equality
bool operator!=(const Matrix& lhs, const Matrix& rhs);                // negation of ==
const Matrix operator+(const Matrix& lhs, const Matrix& rhs);         // element-wise sum
const Matrix operator-(const Matrix& lhs, const Matrix& rhs);         // element-wise difference
const Matrix operator*(const Matrix& lhs, const Matrix& rhs);         // matrix product
const Matrix operator/(const Matrix& lhs, const Matrix& rhs);         // lhs * inverse(rhs)
const double det(const Matrix& m);                                    // determinant
const double det(const Matrix& m, int start, int end);                // determinant of a diagonal sub-block
const Matrix abs(const Matrix& m);                                    // element-wise absolute value
const double max(const Matrix& m);                                    // largest element
const double max(const Matrix& m, int& row, int& col);                // largest element and its position
const double min(const Matrix& m);                                    // smallest element
const double min(const Matrix& m, int& row, int& col);                // smallest element and its position
const Matrix trans(const Matrix& m);                                  // transpose
const Matrix submatrix(const Matrix& m, int rb, int re, int cb, int ce);  // rows rb..re, cols cb..ce (inclusive)
const Matrix inverse(const Matrix& m);                                // inverse of a square matrix
const Matrix LU(const Matrix& m);                                     // LU decomposition of a square matrix
const Matrix readMatrix(istream& in = std::cin);                      // read a matrix from a text stream
const Matrix readMatrix(string file);                                 // read a matrix from a text file
const Matrix loadMatrix(string file);                                 // read a matrix from a binary file
void printMatrix(const Matrix& m, ostream& out = std::cout);          // print a matrix to a stream
void printMatrix(const Matrix& m, string file);                       // print a matrix to a text file
void saveMatrix(const Matrix& m, string file);                        // save a matrix as a binary file

#endif  // MATRIX_H_INCLUDED
matrix.cpp文件:
1 #include "Matrix.h" 2 #include <iomanip> //用于设置输出格式 3 4 using std::ifstream; 5 using std::ofstream; 6 using std::istringstream; 7 using std::cerr; 8 using std::endl; 9 10 const Matrix& Matrix::operator+=( const Matrix& m ) 11 { 12 if ( rows() != m.rows() || rows() != m.cols() ) 13 { 14 return *this; 15 } 16 17 int r = rows(); 18 int c = cols(); 19 20 for ( int i = 0; i < r; ++i ) 21 { 22 for ( int j = 0; j < c; ++j ) 23 { 24 array[i][j] += m[i][j]; 25 } 26 } 27 28 return *this; 29 } 30 31 32 const Matrix& Matrix::operator-=( const Matrix& m ) 33 { 34 if ( rows() != m.rows() || cols() != m.cols() ) 35 { 36 return *this; 37 } 38 39 int r = rows(); 40 int c = cols(); 41 42 for ( int i = 0; i < r; ++i ) 43 { 44 for ( int j = 0; j < c; ++j ) 45 { 46 array[i][j] -= m[i][j]; 47 } 48 } 49 50 return *this; 51 } 52 53 const Matrix& Matrix::operator*=( const Matrix& m ) 54 { 55 if ( cols() != m.rows() || !m.square() ) 56 { 57 return *this; 58 } 59 60 Matrix ret( rows(), cols() ); 61 62 int r = rows(); 63 int c = cols(); 64 65 for ( int i = 0; i < r; ++i ) 66 { 67 for ( int j = 0; j < c; ++j ) 68 { 69 double sum = 0.0; 70 for ( int k = 0; k < c; ++k ) 71 { 72 sum += array[i][k] * m[k][j]; 73 } 74 ret[i][j] = sum; 75 } 76 } 77 78 *this = ret; 79 return *this; 80 } 81 82 const Matrix& Matrix::operator/=( const Matrix& m ) 83 { 84 Matrix tmp = inverse( m ); 85 return operator*=( tmp ); 86 } 87 88 89 bool operator==( const Matrix& lhs, const Matrix& rhs ) 90 { 91 if ( lhs.rows() != rhs.rows() || lhs.cols() != rhs.cols() ) 92 { 93 return false; 94 } 95 96 for ( int i = 0; i < lhs.rows(); ++i ) 97 { 98 for ( int j = 0; j < lhs.cols(); ++j ) 99 { 100 if ( rhs[i][j] != rhs[i][j] ) 101 { 102 return false; 103 } 104 } 105 } 106 107 return true; 108 } 109 110 bool operator!=( const Matrix& lhs, const Matrix& rhs ) 111 { 112 return !( lhs == rhs ); 113 } 114 115 const Matrix operator+( const Matrix& lhs, const Matrix& rhs ) 116 { 117 Matrix m; 118 if ( lhs.rows() != rhs.rows() || 
lhs.cols() != rhs.cols() ) 119 { 120 return m; 121 } 122 123 m = lhs; 124 m += rhs; 125 126 return m; 127 } 128 129 const Matrix operator-( const Matrix& lhs, const Matrix& rhs ) 130 { 131 Matrix m; 132 if ( lhs.rows() != rhs.rows() || lhs.cols() != rhs.cols() ) 133 { 134 return m; 135 } 136 137 m = lhs; 138 m -= rhs; 139 140 return m; 141 } 142 143 const Matrix operator*( const Matrix& lhs, const Matrix& rhs ) 144 { 145 Matrix m; 146 if ( lhs.cols() != rhs.rows() ) 147 { 148 return m; 149 } 150 151 m.resize( lhs.rows(), rhs.cols() ); 152 153 int r = m.rows(); 154 int c = m.cols(); 155 int K = lhs.cols(); 156 157 for ( int i = 0; i < r; ++i ) 158 { 159 for ( int j = 0; j < c; ++j ) 160 { 161 double sum = 0.0; 162 for ( int k = 0; k < K; ++k ) 163 { 164 sum += lhs[i][k] * rhs[k][j]; 165 } 166 m[i][j] = sum; 167 } 168 } 169 170 return m; 171 } 172 173 const Matrix operator/( const Matrix& lhs, const Matrix& rhs ) 174 { 175 Matrix tmp = inverse( rhs ); 176 Matrix m; 177 178 if ( tmp.empty() ) 179 { 180 return m; 181 } 182 183 return m = lhs * tmp; 184 } 185 186 inline static double LxAbs( double d ) 187 { 188 return (d>=0)?(d):(-d); 189 } 190 191 inline 192 static bool isSignRev( const vector<double>& v ) 193 { 194 int p = 0; 195 int sum = 0; 196 int n = (int)v.size(); 197 198 for ( int i = 0; i < n; ++i ) 199 { 200 p = (int)v[i]; 201 if ( p >= 0 ) 202 { 203 sum += p + i; 204 } 205 } 206 207 if ( sum % 2 == 0 ) // 如果是偶数,说明不变号 208 { 209 return false; 210 } 211 return true; 212 } 213 214 // 计算方阵行列式 215 const double det( const Matrix& m ) 216 { 217 double ret = 0.0; 218 219 if ( m.empty() || !m.square() ) return ret; 220 221 Matrix N = LU( m ); 222 223 if ( N.empty() ) return ret; 224 225 ret = 1.0; 226 for ( int i = 0; i < N.cols(); ++ i ) 227 { 228 ret *= N[i][i]; 229 } 230 231 if ( isSignRev( N[N.rows()-1] )) 232 { 233 return -ret; 234 } 235 236 return ret; 237 } 238 239 // 计算矩阵指定子方阵的行列式 240 const double det( const Matrix& m, int start, int end ) 241 { 242 return det( 
submatrix(m, start, end, start, end) ); 243 } 244 245 246 // 计算矩阵转置 247 const Matrix trans( const Matrix& m ) 248 { 249 Matrix ret; 250 if ( m.empty() ) return ret; 251 252 int r = m.cols(); 253 int c = m.rows(); 254 255 ret.resize(r, c); 256 for ( int i = 0; i < r; ++i ) 257 { 258 for ( int j = 0; j < c; ++j ) 259 { 260 ret[i][j] = m[j][i]; 261 } 262 } 263 264 return ret; 265 } 266 267 // 计算逆矩阵 268 const Matrix inverse( const Matrix& m ) 269 { 270 Matrix ret; 271 272 if ( m.empty() || !m.square() ) 273 { 274 return ret; 275 } 276 277 int n = m.rows(); 278 279 ret.resize( n, n ); 280 Matrix A(m); 281 282 for ( int i = 0; i < n; ++i ) ret[i][i] = 1.0; 283 284 for ( int j = 0; j < n; ++j ) //每一列 285 { 286 int p = j; 287 double maxV = LxAbs(A[j][j]); 288 for ( int i = j+1; i < n; ++i ) // 找到第j列中元素绝对值最大行 289 { 290 if ( maxV < LxAbs(A[i][j]) ) 291 { 292 p = i; 293 maxV = LxAbs(A[i][j]); 294 } 295 } 296 297 if ( maxV < 1e-20 ) 298 { 299 ret.resize(0,0); 300 return ret; 301 } 302 303 if ( j!= p ) 304 { 305 A.swap_row( j, p ); 306 ret.swap_row( j, p ); 307 } 308 309 double d = A[j][j]; 310 for ( int i = j; i < n; ++i ) A[j][i] /= d; 311 for ( int i = 0; i < n; ++i ) ret[j][i] /= d; 312 313 for ( int i = 0; i < n; ++i ) 314 { 315 if ( i != j ) 316 { 317 double q = A[i][j]; 318 for ( int k = j; k < n; ++k ) 319 { 320 A [i][k] -= q * A[j][k]; 321 } 322 for ( int k = 0; k < n; ++k ) 323 { 324 ret[i][k] -= q * ret[j][k]; 325 } 326 } 327 } 328 } 329 330 return ret; 331 } 332 333 // 计算绝对值 334 const Matrix abs( const Matrix& m ) 335 { 336 Matrix ret; 337 338 if( m.empty() ) 339 { 340 return ret; 341 } 342 343 int r = m.rows(); 344 int c = m.cols(); 345 ret.resize( r, c ); 346 347 for ( int i = 0; i < r; ++i ) 348 { 349 for ( int j = 0; j < c; ++j ) 350 { 351 double t = m[i][j]; 352 if ( t < 0 ) ret[i][j] = -t; 353 else ret[i][j] = t; 354 } 355 } 356 357 return ret; 358 } 359 360 // 返回矩阵所有元素的最大值 361 const double max( const Matrix& m ) 362 { 363 if ( m.empty() ) return 0.; 364 365 
double ret = m[0][0]; 366 int r = m.rows(); 367 int c = m.cols(); 368 369 for ( int i = 0; i < r; ++i ) 370 { 371 for ( int j = 0; j < c; ++j ) 372 { 373 if ( m[i][j] > ret ) ret = m[i][j]; 374 } 375 } 376 return ret; 377 } 378 379 // 计算矩阵最大值,并返回该元素的引用 380 const double max( const Matrix& m, int& row, int& col ) 381 { 382 if ( m.empty() ) return 0.; 383 384 double ret = m[0][0]; 385 row = 0; 386 col = 0; 387 388 int r = m.rows(); 389 int c = m.cols(); 390 391 for ( int i = 0; i < r; ++i ) 392 { 393 for ( int j = 0; j < c; ++j ) 394 { 395 if ( m[i][j] > ret ) 396 { 397 ret = m[i][j]; 398 row = i; 399 col = j; 400 } 401 } 402 } 403 return ret; 404 } 405 406 // 计算矩阵所有元素最小值 407 const double min( const Matrix& m ) 408 { 409 if ( m.empty() ) return 0.; 410 411 double ret = m[0][0]; 412 int r = m.rows(); 413 int c = m.cols(); 414 415 for ( int i = 0; i < r; ++i ) 416 { 417 for ( int j = 0; j < c; ++j ) 418 { 419 if ( m[i][j] > ret ) ret = m[i][j]; 420 } 421 } 422 423 return ret; 424 } 425 426 // 计算矩阵最小值,并返回该元素的引用 427 const double min( const Matrix& m, int& row, int& col) 428 { 429 if ( m.empty() ) return 0.; 430 431 double ret = m[0][0]; 432 row = 0; 433 col = 0; 434 int r = m.rows(); 435 int c = m.cols(); 436 437 for ( int i = 0; i < r; ++i ) 438 { 439 for ( int j = 0; j < c; ++j ) 440 { 441 if ( m[i][j] > ret ) 442 { 443 ret = m[i][j]; 444 row = i; 445 col = j; 446 } 447 } 448 } 449 450 return ret; 451 } 452 453 // 取矩阵中指定位置的子矩阵 454 const Matrix submatrix(const Matrix& m,int rb,int re,int cb,int ce) 455 { 456 Matrix ret; 457 if ( m.empty() ) return ret; 458 459 if ( rb < 0 || re >= m.rows() || rb > re ) return ret; 460 if ( cb < 0 || ce >= m.cols() || cb > ce ) return ret; 461 462 ret.resize( re-rb+1, ce-cb+1 ); 463 464 for ( int i = rb; i <= re; ++i ) 465 { 466 for ( int j = cb; j <= ce; ++j ) 467 { 468 ret[i-rb][j-cb] = m[i][j]; 469 } 470 } 471 472 return ret; 473 } 474 475 476 inline static 477 int max_idx( const Matrix& m, int k, int n ) 478 { 479 int p = k; 480 for ( 
int i = k+1; i < n; ++i ) 481 { 482 if ( LxAbs(m[p][k]) < LxAbs(m[i][k]) ) 483 { 484 p = i; 485 } 486 } 487 return p; 488 } 489 490 // 计算方阵 M 的 LU 分解 491 // 其中L为对角线元素全为1的下三角阵,U为对角元素依赖M的上三角阵 492 // 使得 M = LU 493 // 返回矩阵下三角部分存储L(对角元素除外),上三角部分存储U(包括对角线元素) 494 const Matrix LU( const Matrix& m ) 495 { 496 Matrix ret; 497 498 if ( m.empty() || !m.square() ) return ret; 499 500 int n = m.rows(); 501 ret.resize( n+1, n ); 502 503 for ( int i = 0; i < n; ++i ) 504 { 505 ret[n][i] = -1.0; 506 } 507 508 for ( int i = 0; i < n; ++i ) 509 { 510 for ( int j = 0; j < n; ++j ) 511 { 512 ret[i][j] = m[i][j]; 513 } 514 } 515 516 for ( int k = 0; k < n-1; ++k ) 517 { 518 int p = max_idx( ret, k, n ); 519 if ( p != k ) // 进行行交换 520 { 521 ret.swap_row( k, p ); 522 ret[n][k] = (double)p; // 记录将换信息 523 } 524 525 if ( ret[k][k] == 0.0 ) 526 { 527 cout << "ERROR: " << endl; 528 ret.resize(0,0); 529 return ret; 530 } 531 532 for ( int i = k+1; i < n; ++i ) 533 { 534 ret[i][k] /= ret[k][k]; 535 for ( int j = k+1; j < n; ++j ) 536 { 537 ret[i][j] -= ret[i][k] * ret[k][j]; 538 } 539 } 540 } 541 542 return ret; 543 } 544 545 //--------------------------------------------------- 546 // 读取和打印 547 //--------------------------------------------------- 548 // 从输入流读取矩阵 549 const Matrix readMatrix( istream& in ) 550 { 551 Matrix M; 552 string str; 553 double b; 554 vector<double> v; 555 556 while( getline( in, str ) ) 557 { 558 for ( string::size_type i = 0; i < str.size(); ++i ) 559 { 560 if ( str[i] == ‘,‘ || str[i] == ‘;‘) 561 { 562 str[i] = ‘ ‘; 563 } 564 else if ( str[i] != ‘.‘ && (str[i] < ‘0‘ || str[i] > ‘9‘) 565 && str[i] != ‘ ‘ && str[i] != ‘\t‘ && str[i] != ‘-‘) 566 { 567 M.resize(0,0); 568 return M; 569 } 570 } 571 572 istringstream sstream(str); 573 v.resize(0); 574 575 while ( sstream >> b ) 576 { 577 v.push_back(b); 578 } 579 if ( v.size() == 0 ) 580 { 581 continue; 582 } 583 if ( !M.push_back( v ) ) 584 { 585 M.resize( 0, 0 ); 586 return M; 587 } 588 } 589 590 return M; 591 } 592 593 // 
从文本文件读入矩阵 594 const Matrix readMatrix( string file ) 595 { 596 ifstream fin( file.c_str() ); 597 Matrix M; 598 599 if ( !fin ) 600 { 601 cerr << "Error: open file " << file << " failed." << endl; 602 return M; 603 } 604 605 M = readMatrix( fin ); 606 fin.close(); 607 608 return M; 609 } 610 611 // 将矩阵输出到指定输出流 612 void printMatrix( const Matrix& m, ostream& out ) 613 { 614 if ( m.empty() ) 615 { 616 return; 617 } 618 619 int r = m.rows(); 620 int c = m.cols(); 621 622 int n = 0; // 数据小数点前最大位数 623 double maxV = max(abs(m)); 624 while( maxV >= 1.0 ) 625 { 626 maxV /= 10; 627 ++n; 628 } 629 if ( n == 0 ) n = 1; // 如果最大数绝对值小于1,这小数点前位数为1,为数字0 630 int pre = 4; // 小数点后数据位数 631 int wid = n + pre + 3; // 控制字符宽度=n+pre+符号位+小数点位 632 633 out<<std::showpoint; 634 out<<std::setiosflags(std::ios::fixed); 635 out<<std::setprecision( pre ); 636 for ( int i = 0; i < r; ++i ) 637 { 638 for ( int j = 0; j < c; ++j ) 639 { 640 out<<std::setw(wid) << m[i][j]; 641 } 642 out << endl; 643 } 644 645 out<<std::setprecision(6); 646 out<<std::noshowpoint; 647 } 648 649 // 将矩阵打印到指定文件 650 void printMatrix( const Matrix& m, string file ) 651 { 652 ofstream fout( file.c_str() ); 653 if ( !fout ) return; 654 655 printMatrix( m, fout ); 656 fout.close(); 657 } 658 659 // 将矩阵数据存为二进制文件 660 void saveMatrix( const Matrix& m, string file ) 661 { 662 if ( m.empty() ) return; 663 664 ofstream fout(file.c_str(), std::ios_base::out|std::ios::binary ); 665 if ( !fout ) return; 666 667 int r = m.rows(); 668 int c = m.cols(); 669 char Flag[12] = "MATRIX_DATA"; 670 fout.write( (char*)&Flag, sizeof(Flag) ); 671 fout.write( (char*)&r, sizeof(r) ); 672 fout.write( (char*)&c, sizeof(c) ); 673 674 for ( int i = 0; i < r; ++i ) 675 { 676 for ( int j = 0; j < c; ++j ) 677 { 678 double t = m[i][j]; 679 fout.write( (char*)&t, sizeof(t) ); 680 } 681 } 682 683 fout.close(); 684 } 685 686 // 从二进制文件load矩阵 687 const Matrix loadMatrix( string file ) 688 { 689 Matrix m; 690 691 ifstream fin( file.c_str(), 
std::ios_base::in|std::ios::binary ); 692 if ( !fin ) return m; 693 694 char Flag[12]; 695 fin.read((char*)&Flag, sizeof(Flag)); 696 697 string str( Flag ); 698 if ( str != "MATRIX_DATA" ) 699 { 700 return m; 701 } 702 703 int r, c; 704 fin.read( (char*)&r, sizeof(r) ); 705 fin.read( (char*)&c, sizeof(c) ); 706 707 if ( r <=0 || c <=0 ) return m; 708 709 m.resize( r, c ); 710 double t; 711 712 for ( int i = 0; i < r; ++i ) 713 { 714 for ( int j = 0; j < c; ++j ) 715 { 716 fin.read( (char*)&t, sizeof(t) ); 717 m[i][j] = t; 718 } 719 } 720 721 return m; 722 }
时间: 2024-10-03 22:37:33