学习OpenCV——SVM 手写数字检测

转自http://blog.csdn.net/firefight/article/details/6452188

是MNIST手写数字图片库:http://code.google.com/p/supplement-of-the-mnist-database-of-handwritten-digits/downloads/list

其他方法:http://blog.csdn.net/onezeros/article/details/5672192

使用OPENCV训练手写数字识别分类器

1,下载训练数据和测试数据文件,这里用的是MNIST手写数字图片库,其中训练数据库中为60000个,测试数据库中为10000个
2,创建训练数据和测试数据文件读取函数,注意字节顺序为大端
3,确定字符特征方式为最简单的8×8网格内的字符点数


4,创建SVM,训练并读取,结果如下
 1000个训练样本,测试数据正确率80.21%(并没有体现SVM小样本高准确率的特性啊)
  10000个训练样本,测试数据正确率95.45%
  60000个训练样本,测试数据正确率97.67%

5,编写手写输入的GUI程序,并进行验证,效果还可以接受。

以下为主要代码,以供参考

(类似的也实现了随机树分类器,比较发现在相同的样本数情况下,SVM准确率略高)

[cpp] view plaincopyprint?

    1. #include "stdafx.h"
    2. #include <fstream>
    3. #include "opencv2/opencv.hpp"
    4. #include <vector>
    5. using namespace std;
    6. using namespace cv;
    7. #define SHOW_PROCESS 0
    8. #define ON_STUDY 0
    9. class NumTrainData
    10. {
    11. public:
    12. NumTrainData()
    13. {
    14. memset(data, 0, sizeof(data));
    15. result = -1;
    16. }
    17. public:
    18. float data[64];
    19. int result;
    20. };
    21. vector<NumTrainData> buffer;
    22. int featureLen = 64;
    23. void swapBuffer(char* buf)
    24. {
    25. char temp;
    26. temp = *(buf);
    27. *buf = *(buf+3);
    28. *(buf+3) = temp;
    29. temp = *(buf+1);
    30. *(buf+1) = *(buf+2);
    31. *(buf+2) = temp;
    32. }
    33. void GetROI(Mat& src, Mat& dst)
    34. {
    35. int left, right, top, bottom;
    36. left = src.cols;
    37. right = 0;
    38. top = src.rows;
    39. bottom = 0;
    40. //Get valid area
    41. for(int i=0; i<src.rows; i++)
    42. {
    43. for(int j=0; j<src.cols; j++)
    44. {
    45. if(src.at<uchar>(i, j) > 0)
    46. {
    47. if(j<left) left = j;
    48. if(j>right) right = j;
    49. if(i<top) top = i;
    50. if(i>bottom) bottom = i;
    51. }
    52. }
    53. }
    54. //Point center;
    55. //center.x = (left + right) / 2;
    56. //center.y = (top + bottom) / 2;
    57. int width = right - left;
    58. int height = bottom - top;
    59. int len = (width < height) ? height : width;
    60. //Create a squre
    61. dst = Mat::zeros(len, len, CV_8UC1);
    62. //Copy valid data to squre center
    63. Rect dstRect((len - width)/2, (len - height)/2, width, height);
    64. Rect srcRect(left, top, width, height);
    65. Mat dstROI = dst(dstRect);
    66. Mat srcROI = src(srcRect);
    67. srcROI.copyTo(dstROI);
    68. }
    69. int ReadTrainData(int maxCount)
    70. {
    71. //Open image and label file
    72. const char fileName[] = "../res/train-images.idx3-ubyte";
    73. const char labelFileName[] = "../res/train-labels.idx1-ubyte";
    74. ifstream lab_ifs(labelFileName, ios_base::binary);
    75. ifstream ifs(fileName, ios_base::binary);
    76. if( ifs.fail() == true )
    77. return -1;
    78. if( lab_ifs.fail() == true )
    79. return -1;
    80. //Read train data number and image rows / cols
    81. char magicNum[4], ccount[4], crows[4], ccols[4];
    82. ifs.read(magicNum, sizeof(magicNum));
    83. ifs.read(ccount, sizeof(ccount));
    84. ifs.read(crows, sizeof(crows));
    85. ifs.read(ccols, sizeof(ccols));
    86. int count, rows, cols;
    87. swapBuffer(ccount);
    88. swapBuffer(crows);
    89. swapBuffer(ccols);
    90. memcpy(&count, ccount, sizeof(count));
    91. memcpy(&rows, crows, sizeof(rows));
    92. memcpy(&cols, ccols, sizeof(cols));
    93. //Just skip label header
    94. lab_ifs.read(magicNum, sizeof(magicNum));
    95. lab_ifs.read(ccount, sizeof(ccount));
    96. //Create source and show image matrix
    97. Mat src = Mat::zeros(rows, cols, CV_8UC1);
    98. Mat temp = Mat::zeros(8, 8, CV_8UC1);
    99. Mat img, dst;
    100. char label = 0;
    101. Scalar templateColor(255, 0, 255 );
    102. NumTrainData rtd;
    103. //int loop = 1000;
    104. int total = 0;
    105. while(!ifs.eof())
    106. {
    107. if(total >= count)
    108. break;
    109. total++;
    110. cout << total << endl;
    111. //Read label
    112. lab_ifs.read(&label, 1);
    113. label = label + ‘0‘;
    114. //Read source data
    115. ifs.read((char*)src.data, rows * cols);
    116. GetROI(src, dst);
    117. #if(SHOW_PROCESS)
    118. //Too small to watch
    119. img = Mat::zeros(dst.rows*10, dst.cols*10, CV_8UC1);
    120. resize(dst, img, img.size());
    121. stringstream ss;
    122. ss << "Number " << label;
    123. string text = ss.str();
    124. putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
    125. //imshow("img", img);
    126. #endif
    127. rtd.result = label;
    128. resize(dst, temp, temp.size());
    129. //threshold(temp, temp, 10, 1, CV_THRESH_BINARY);
    130. for(int i = 0; i<8; i++)
    131. {
    132. for(int j = 0; j<8; j++)
    133. {
    134. rtd.data[ i*8 + j] = temp.at<uchar>(i, j);
    135. }
    136. }
    137. buffer.push_back(rtd);
    138. //if(waitKey(0)==27) //ESC to quit
    139. //  break;
    140. maxCount--;
    141. if(maxCount == 0)
    142. break;
    143. }
    144. ifs.close();
    145. lab_ifs.close();
    146. return 0;
    147. }
    148. void newRtStudy(vector<NumTrainData>& trainData)
    149. {
    150. int testCount = trainData.size();
    151. Mat data = Mat::zeros(testCount, featureLen, CV_32FC1);
    152. Mat res = Mat::zeros(testCount, 1, CV_32SC1);
    153. for (int i= 0; i< testCount; i++)
    154. {
    155. NumTrainData td = trainData.at(i);
    156. memcpy(data.data + i*featureLen*sizeof(float), td.data, featureLen*sizeof(float));
    157. res.at<unsigned int>(i, 0) = td.result;
    158. }
    159. /////////////START RT TRAINNING//////////////////
    160. CvRTrees forest;
    161. CvMat* var_importance = 0;
    162. forest.train( data, CV_ROW_SAMPLE, res, Mat(), Mat(), Mat(), Mat(),
    163. CvRTParams(10,10,0,false,15,0,true,4,100,0.01f,CV_TERMCRIT_ITER));
    164. forest.save( "new_rtrees.xml" );
    165. }
    166. int newRtPredict()
    167. {
    168. CvRTrees forest;
    169. forest.load( "new_rtrees.xml" );
    170. const char fileName[] = "../res/t10k-images.idx3-ubyte";
    171. const char labelFileName[] = "../res/t10k-labels.idx1-ubyte";
    172. ifstream lab_ifs(labelFileName, ios_base::binary);
    173. ifstream ifs(fileName, ios_base::binary);
    174. if( ifs.fail() == true )
    175. return -1;
    176. if( lab_ifs.fail() == true )
    177. return -1;
    178. char magicNum[4], ccount[4], crows[4], ccols[4];
    179. ifs.read(magicNum, sizeof(magicNum));
    180. ifs.read(ccount, sizeof(ccount));
    181. ifs.read(crows, sizeof(crows));
    182. ifs.read(ccols, sizeof(ccols));
    183. int count, rows, cols;
    184. swapBuffer(ccount);
    185. swapBuffer(crows);
    186. swapBuffer(ccols);
    187. memcpy(&count, ccount, sizeof(count));
    188. memcpy(&rows, crows, sizeof(rows));
    189. memcpy(&cols, ccols, sizeof(cols));
    190. Mat src = Mat::zeros(rows, cols, CV_8UC1);
    191. Mat temp = Mat::zeros(8, 8, CV_8UC1);
    192. Mat m = Mat::zeros(1, featureLen, CV_32FC1);
    193. Mat img, dst;
    194. //Just skip label header
    195. lab_ifs.read(magicNum, sizeof(magicNum));
    196. lab_ifs.read(ccount, sizeof(ccount));
    197. char label = 0;
    198. Scalar templateColor(255, 0, 0);
    199. NumTrainData rtd;
    200. int right = 0, error = 0, total = 0;
    201. int right_1 = 0, error_1 = 0, right_2 = 0, error_2 = 0;
    202. while(ifs.good())
    203. {
    204. //Read label
    205. lab_ifs.read(&label, 1);
    206. label = label + ‘0‘;
    207. //Read data
    208. ifs.read((char*)src.data, rows * cols);
    209. GetROI(src, dst);
    210. //Too small to watch
    211. img = Mat::zeros(dst.rows*30, dst.cols*30, CV_8UC3);
    212. resize(dst, img, img.size());
    213. rtd.result = label;
    214. resize(dst, temp, temp.size());
    215. //threshold(temp, temp, 10, 1, CV_THRESH_BINARY);
    216. for(int i = 0; i<8; i++)
    217. {
    218. for(int j = 0; j<8; j++)
    219. {
    220. m.at<float>(0,j + i*8) = temp.at<uchar>(i, j);
    221. }
    222. }
    223. if(total >= count)
    224. break;
    225. char ret = (char)forest.predict(m);
    226. if(ret == label)
    227. {
    228. right++;
    229. if(total <= 5000)
    230. right_1++;
    231. else
    232. right_2++;
    233. }
    234. else
    235. {
    236. error++;
    237. if(total <= 5000)
    238. error_1++;
    239. else
    240. error_2++;
    241. }
    242. total++;
    243. #if(SHOW_PROCESS)
    244. stringstream ss;
    245. ss << "Number " << label << ", predict " << ret;
    246. string text = ss.str();
    247. putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
    248. imshow("img", img);
    249. if(waitKey(0)==27) //ESC to quit
    250. break;
    251. #endif
    252. }
    253. ifs.close();
    254. lab_ifs.close();
    255. stringstream ss;
    256. ss << "Total " << total << ", right " << right <<", error " << error;
    257. string text = ss.str();
    258. putText(img, text, Point(50, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
    259. imshow("img", img);
    260. waitKey(0);
    261. return 0;
    262. }
    263. void newSvmStudy(vector<NumTrainData>& trainData)
    264. {
    265. int testCount = trainData.size();
    266. Mat m = Mat::zeros(1, featureLen, CV_32FC1);
    267. Mat data = Mat::zeros(testCount, featureLen, CV_32FC1);
    268. Mat res = Mat::zeros(testCount, 1, CV_32SC1);
    269. for (int i= 0; i< testCount; i++)
    270. {
    271. NumTrainData td = trainData.at(i);
    272. memcpy(m.data, td.data, featureLen*sizeof(float));
    273. normalize(m, m);
    274. memcpy(data.data + i*featureLen*sizeof(float), m.data, featureLen*sizeof(float));
    275. res.at<unsigned int>(i, 0) = td.result;
    276. }
    277. /////////////START SVM TRAINNING//////////////////
    278. CvSVM svm = CvSVM();
    279. CvSVMParams param;
    280. CvTermCriteria criteria;
    281. criteria= cvTermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON);
    282. param= CvSVMParams(CvSVM::C_SVC, CvSVM::RBF, 10.0, 8.0, 1.0, 10.0, 0.5, 0.1, NULL, criteria);
    283. svm.train(data, res, Mat(), Mat(), param);
    284. svm.save( "SVM_DATA.xml" );
    285. }
    286. int newSvmPredict()
    287. {
    288. CvSVM svm = CvSVM();
    289. svm.load( "SVM_DATA.xml" );
    290. const char fileName[] = "../res/t10k-images.idx3-ubyte";
    291. const char labelFileName[] = "../res/t10k-labels.idx1-ubyte";
    292. ifstream lab_ifs(labelFileName, ios_base::binary);
    293. ifstream ifs(fileName, ios_base::binary);
    294. if( ifs.fail() == true )
    295. return -1;
    296. if( lab_ifs.fail() == true )
    297. return -1;
    298. char magicNum[4], ccount[4], crows[4], ccols[4];
    299. ifs.read(magicNum, sizeof(magicNum));
    300. ifs.read(ccount, sizeof(ccount));
    301. ifs.read(crows, sizeof(crows));
    302. ifs.read(ccols, sizeof(ccols));
    303. int count, rows, cols;
    304. swapBuffer(ccount);
    305. swapBuffer(crows);
    306. swapBuffer(ccols);
    307. memcpy(&count, ccount, sizeof(count));
    308. memcpy(&rows, crows, sizeof(rows));
    309. memcpy(&cols, ccols, sizeof(cols));
    310. Mat src = Mat::zeros(rows, cols, CV_8UC1);
    311. Mat temp = Mat::zeros(8, 8, CV_8UC1);
    312. Mat m = Mat::zeros(1, featureLen, CV_32FC1);
    313. Mat img, dst;
    314. //Just skip label header
    315. lab_ifs.read(magicNum, sizeof(magicNum));
    316. lab_ifs.read(ccount, sizeof(ccount));
    317. char label = 0;
    318. Scalar templateColor(255, 0, 0);
    319. NumTrainData rtd;
    320. int right = 0, error = 0, total = 0;
    321. int right_1 = 0, error_1 = 0, right_2 = 0, error_2 = 0;
    322. while(ifs.good())
    323. {
    324. //Read label
    325. lab_ifs.read(&label, 1);
    326. label = label + ‘0‘;
    327. //Read data
    328. ifs.read((char*)src.data, rows * cols);
    329. GetROI(src, dst);
    330. //Too small to watch
    331. img = Mat::zeros(dst.rows*30, dst.cols*30, CV_8UC3);
    332. resize(dst, img, img.size());
    333. rtd.result = label;
    334. resize(dst, temp, temp.size());
    335. //threshold(temp, temp, 10, 1, CV_THRESH_BINARY);
    336. for(int i = 0; i<8; i++)
    337. {
    338. for(int j = 0; j<8; j++)
    339. {
    340. m.at<float>(0,j + i*8) = temp.at<uchar>(i, j);
    341. }
    342. }
    343. if(total >= count)
    344. break;
    345. normalize(m, m);
    346. char ret = (char)svm.predict(m);
    347. if(ret == label)
    348. {
    349. right++;
    350. if(total <= 5000)
    351. right_1++;
    352. else
    353. right_2++;
    354. }
    355. else
    356. {
    357. error++;
    358. if(total <= 5000)
    359. error_1++;
    360. else
    361. error_2++;
    362. }
    363. total++;
    364. #if(SHOW_PROCESS)
    365. stringstream ss;
    366. ss << "Number " << label << ", predict " << ret;
    367. string text = ss.str();
    368. putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
    369. imshow("img", img);
    370. if(waitKey(0)==27) //ESC to quit
    371. break;
    372. #endif
    373. }
    374. ifs.close();
    375. lab_ifs.close();
    376. stringstream ss;
    377. ss << "Total " << total << ", right " << right <<", error " << error;
    378. string text = ss.str();
    379. putText(img, text, Point(50, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
    380. imshow("img", img);
    381. waitKey(0);
    382. return 0;
    383. }
    384. int main( int argc, char *argv[] )
    385. {
    386. #if(ON_STUDY)
    387. int maxCount = 60000;
    388. ReadTrainData(maxCount);
    389. //newRtStudy(buffer);
    390. newSvmStudy(buffer);
    391. #else
    392. //newRtPredict();
    393. newSvmPredict();
    394. #endif
    395. return 0;
    396. }
    397. //from: http://blog.csdn.net/yangtrees/article/details/7458466
时间: 2024-10-10 21:55:23

学习OpenCV——SVM 手写数字检测的相关文章

SVM 手写数字识别

初次是根据“支持向量机通俗导论(理解SVM的三层境界)”对SVM有了简单的了解.总的来说其主要的思想可以概括为以下两点(也是别人的总结) 1.SVM是对二分类问题在线性可分的情况下提出的,当样本线性不可分时,它通过非线性的映射算法,将在低维空间线性不可分的样本映射到高维的特征空间使其线性可分,从而使得对非线性可分样本进行线性分类. 2.SVM是建立在统计学习理论的 VC理论和结构风险最小化原理基础上的,在保证样本分类精度的前提下,建立最优的分割超平面,使得学习器有较好的全局性和推广性. 第一个能

mnist手写数字检测

# -*- coding: utf-8 -*- """ Created on Tue Apr 23 06:16:04 2019 @author: 92958 """ import numpy as np import tensorflow as tf #下载并载入mnist(55000*28*28图片) #from tensorflow.examples.tutorials.mnist import input_data #创造变量mnist,用

简单HOG+SVM mnist手写数字分类

使用工具 :VS2013 + OpenCV 3.1 数据集:minst 训练数据:60000张 测试数据:10000张 输出模型:HOG_SVM_DATA.xml 数据准备 train-images-idx3-ubyte.gz:  training set images (9912422 bytes) train-labels-idx1-ubyte.gz:  training set labels (28881 bytes) t10k-images-idx3-ubyte.gz:   test s

《神经网络和深度学习》系列文章一:使用神经网络识别手写数字

出处: Michael Nielsen的<Neural Network and Deep Leraning> 本节译者:哈工大SCIR硕士生 徐梓翔 (https://github.com/endyul) 声明:我们将不定期连载该书的中文翻译,如需转载请联系[email protected],未经授权不得转载. “本文转载自[哈工大SCIR]微信公众号,转载已征得同意.” 使用神经网络识别手写数字 感知机 sigmoid神经元 神经网络的结构 用简单的网络结构解决手写数字识别 通过梯度下降法学

使用神经网络来识别手写数字【译(三)- 用Python代码实现

实现我们分类数字的网络 好,让我们使用随机梯度下降和 MNIST训练数据来写一个程序来学习怎样失败手写数字. 我们也难怪Python (2.7) 来实现.只有 74 行代码!我们需要的第一个东西是 MNIST数据.如果有 github 账号,你可以将这些代码库克隆下来, git clone https://github.com/mnielsen/neural-networks-and-deep-learning.git 或者你可以到这里 下载. Incidentally, 当我先前说到 MNIS

keras入门实战:手写数字识别

近些年由于理论知识的硬件的快速发展,使得深度学习达到了空前的火热.深度学习已经在很多方面都成功得到了应用,尤其是在图像识别和分类领域,机器识别图像的能力甚至超过了人类. 本文用深度学习Python库Keras实现深度学习入门教程mnist手写数字识别.mnist手写数字识别是机器学习和深度学习领域的"hello world",MNIST数据集是手写数字的数据集合,训练集规模为60000,测试集为10000. 本文的内容包括: 如何用Keras加载MNIST数据集 对于MNIST问题如何

手写数字识别【QT+OpenCV】

[说明] 手写数字识别的实现方式很多. 本文尽量将其简化,以让大家能够快速了解怎样实现一个动起来的系统. [截图] [思路] 1.特征提取 将图像划分为5*5大小的区域,然后计算该区域内黑色(或白色)的像素点所占比例. 将需要测试的图像.用来分类的图像都进行特征提取. 2.计算当前的测试图像与用来分类的图像之间的欧氏距离. 3.找出欧式距离最小的值即为与当前测试图像最匹配的图像,即将该图像所代表的数字作为当前测试图像的结果. 4.为了处理上的方便,做了简化处理如下: 4.1仅仅选用10幅用来分类

利用手写数字识别项目详细描述BP深度神经网络的权重学习

本篇文章是针对学习<深度学习入门>(由日本学者斋藤康毅所著陆羽杰所译)中关于神经网络的学习一章来总结归纳一些收获. 本书提出神经网络的学习分四步:1.mini-batch 2.计算梯度 3.更新参数 4.重复前面步骤 1.从识别手写数字项目学习神经网络 所谓“从数据中学习”是指 可以由数据#自动决定权重#.当解决较为简单的问题,使用简单的神经网络时,网络里的权重可以人为的手动设置,去提取输入信息中特定的特征.但是在实际的神经网络中,参数往往是成千上万,甚至可能上亿的权重,这个时候人为手动设置是

学习笔记TF024:TensorFlow实现Softmax Regression(回归)识别手写数字

TensorFlow实现Softmax Regression(回归)识别手写数字.MNIST(Mixed National Institute of Standards and Technology database),简单机器视觉数据集,28X28像素手写数字,只有灰度值信息,空白部分为0,笔迹根据颜色深浅取[0, 1], 784维,丢弃二维空间信息,目标分0~9共10类.数据加载,data.read_data_sets, 55000个样本,测试集10000样本,验证集5000样本.样本标注信