最近做一个机器学习的课题,主体是matlab写的,其中有部分训练的核心算法是用c++写的,因为有太多的循环和数值计算用c++比较快。这也是我第一次用c++写matlab的模块,感觉走了很多弯路,下面给大家分享一点经验。
matlab中的c++编程称为mex编程:matlab executive matlab 可执行文件,至于其中的具体机制我不是很清楚,有的大神会比较清楚编译期间产生的各种文件。
1) mex编程中指针和索引:
matlab中默认的数据类型是double,用class()函数可以看到变量的数据类型:
matlab代码如下:
mex mex.cpp ‘-g‘; a = [1.1,2.1,3;4,5,6;7,8,9] mex(a)
命令mex 用来编译mex文件,上面代码中 mex mex.cpp ‘-g’ 编译了mex.cpp这个c++文件,编译完成之后会生成一个“mex.mexw64”的文件,后缀名说明这是在win64下编译完成的mex文件,后面的‘-g‘是一个附加参数,在这里不用理解。编译后的mex文件可以当matlab函数使用。
在matlab代码文件同目录下的c文件mex.cpp代码如下:
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) { double *input; input = mxGetPr(prhs[0]); printf("第一个值%f\n",*input); printf("第二个值%f\n",*(input+1)); }
mexFunction有四个参数,nlhs( number left hand s):左边参数个数,也就是matlab函数输出值得个数,mxArray *plhs[]是一个指针数组,数组中的每一个元素都是一个指针,指向输出的矩阵;nrhs 是右边参数个数,也就是输入参数的个数,mxArray *prhs[]数组中的每个指针指向输入矩阵。mxGetPr()函数返回一个double*型的指针,指向矩阵的第一个元素,在matlab代码中调用:mex(a),那也就是 prhs[0]是输入矩阵a的地址,而 input = mxGetPr(prhs[0]) ,input指向了a第一个元素1 。
那矩阵第一排第二列的值a(1,2)的地址是多少呢?是(input+1)吗?在这里我们运行上面的matlab代码,得到的结果如下:
可以看出,输出的*(input+1)是4,也就是说,c++中的matlab矩阵是按列进行索引的。这里是一个需要注意的地方,因为很多地方要对matlab输入的矩阵进行遍历得到矩阵的元素值,如果索引出错,那就完全错了。其实这里的内在原因,是因为在matlab中矩阵是按列进行索引的,而c++中指针式按行往后加的。
有很多函数可以方便我们对矩阵进行索引,uint32 mxGetM(mxArray *)输入一个矩阵的指针,返回该矩阵的行数,uint32 mxGetN(mxArray *)返回列数,对行数和列数适当的计算,可以方便的访问矩阵元素,例如,访问a(i,j): *(input+N*(j-1)+(i-1)) ,N为矩阵行数,这里需要-1的原因是,matlab的行数列数从1开始计数,而c的数组则从0开始索引。
2)mex编程中的数据类型与指针移位的重要关系,mxGetPr() 与 mxGetData():
前面说过,matlab里的默认数据类型是double,那么,如果把mex函数的输入矩阵的数据类型转换一下,会出现什么结果呢?
matlab 代码:
1 clc 2 mex mex.cpp ‘-g‘; 3 a = [1.1,2.1,3;4,5,6;7,8,9]; 4 a=single(a) 5 mex(a)
c++代码:
1 #include "mex.h" 2 void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) 3 { 4 double *input; 5 input = mxGetPr(prhs[0]); 6 printf("第一个值%f\n",*input); 7 printf("第二个值%f\n",*(input+1)); 8 }
c++代码并没有变,matlab代码也仅仅进行了一个数据类型转换,我们看看输出结果:
可以看到这里输出的已经不是我们期望的数值了。在我调试mex代码的时候这个问题苦恼了我很久,因为mex不方便调试,很多时候输出的结果不是想要的,而且我的输入矩阵都是上万维的,很难调试。这里输入矩阵a变成了single单精度类型,前面我们说过,mxGetPr()返回double类型的指针,当我们用double类型指针访问一个单精度(在c++)中我们称之为浮点型float的数据的时候,当然会发成内存越界,用取值符号*去取值的时候超过了数据的内存块,因此发生错误,如果我们修改c++代码:
1 #include "mex.h" 2 void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) 3 { 4 float* input; 5 input = (float*)mxGetPr(prhs[0]); 6 printf("第一个值%f\n",*input); 7 printf("第二个值%f\n",*(input+1)); 8 }
将input类型设置成float,并将mxGetPr()的返回类型强制转换为float*就可以了。在这里还有一个函数mxGetData()也可以返回输入矩阵的头地址,只不过mxGetData()返回的是char*类型的指针,而mxGetPr()返回的是double*类型的指针,可以根据自己的需要选取函数,或者转换指针类型。如果指针类型不对,极有可能造成内存访问错误,导致matlab死掉。
3) nlhs 与 nrhs的作用
mexFunction函数中,两个指针参数分别指向输入输出的矩阵,而nrhs和nlhs分别记录输入输出矩阵的个数,在一般的操作中,我们仅仅对输入矩阵进行取值,运算,对输出矩阵进行赋值,nrhs和nlhs不是很常用,但是也是极其重要的。例如,在上面的代码中,如果我在matlab代码中这样调用mex:mex(),不输入任何参数,matlab就会马上死掉。因为在mex文件的cpp代码中,你用指针访问了输入矩阵的值,而在参数中你没有给mex输入任何参数,使得矩阵指针为野指针,导致内存错误。如果编码中出现这种参数不对的情况,将导致matlab频繁死掉,我的工作中数据特别多,准备数据需要几十分钟,这样让我非常痛苦。解决的方法就是利用nlhs和nrhs这两个参数。在mexFunction中判断nlhs的值来判断输入参数的个数,用nrhs判断输入参数的个数。如果输入参数少于某个值或者不满足你的要求可以让mexFunction直接return,避免后续的程序导致内存错误。