使用MLP解决OCR问题(OpenCV)(下)

分类模型:

分类模型涉及的一个比较关键的问题就是输出的10维向量是如何与具体的类别挂钩的。实际上:10维向量的每一位都代表一类,在对于训练集的表达中,如果输入数据是0,则10维向量的第一位赋值为1,其余均为0。即0对应[1,0,0,0,0,0,0,0,0,0]。MLP模型训练完成后,就需要对用户输入的数据所属类别进行判定。这时得到的输出数据基本不可能是正好的所属类为1,其他位置为0。那具体的分类方法就是判断这10位中哪一位最大,则这个输入就属于哪一类。

test_sample = test_set.row(tsample);

		//分类器的输出

		nnetwork.predict(test_sample, classificationResult);
		//输出向量中最大的值即为样本所属的类

		// 以下的工作就是找到最大的数是哪个
		int maxIndex = 0;
		float value=0.0f;
		float maxValue=classificationResult.at<float>(0,0);
		for(int index=1;index<CLASSES;index++)
		{
			value = classificationResult.at<float>(0,index);
			if(value>maxValue)
			{
				maxValue = value;
				maxIndex=index;

			}
		}

		printf("Testing Sample %i -> class result (digit %d)\n", tsample, maxIndex);

测试集:

测试集是用来测试训练好的模型是否有良好的泛化性,即是否能识别训练集以外的数据。所以这里就要求训练集与测试集最好不要有相同的图片。如果测试结果不满意,则需要增加训练集重新训练或者调整MLP的参数。

文章的最后将整个CPP文件分享给大家

#include <opencv2/opencv.hpp>
#include <string.h>
#include <fstream>
#include <stdio.h>
using namespace std;
using namespace cv;

#define ATTRIBUTES 135  //每一个样本的像素总数.9X15
#define CLASSES 10
#define TRAINING_SAMPLES 460
#define TEST_SAMPLES 200

//将int型转为string型
string convertInt(int number)
{
	stringstream ss;
	ss << number;
	return ss.str();
}
//将图像矩阵转为一个向量
void convertToPixelValueArray(Mat &img,int pixelarray[])
{
	int i =0;
	for(int x=0;x<15;x++)
	{
		for(int y=0;y<9;y++)
		{
			pixelarray[i]=(img.at<uchar>(x,y)==255)?1:0;
			i++;

		}

	}
}
//读取样本集,并将样本集按照一个样本一行的形式写入一个文件
void readFile(string datasetPath,int samplesPerClass,string outputfile )
{
	fstream file(outputfile.c_str(),ios::out);
	for(int sample = 1; sample<=samplesPerClass;sample++)
	{
		for(int digit=0;digit<10;digit++)
		{   //构建图像路径
			string imagePath = datasetPath+convertInt(digit)+"\\"+convertInt(sample)+".bmp";

			Mat img = imread(imagePath,0);
			Mat output;

			int pixelValueArray[135];

			//图像矩阵转为向量
			convertToPixelValueArray(img,pixelValueArray);
			//将这个向量写入文件
			for(int d=0;d<135;d++){
				file<<pixelValueArray[d]<<",";
			}
			//将所属类别写入文件(行尾)
			file<<digit<<"\n";

		}
	}
	file.close();
}
//从样本集生成的文件中读取数据
void read_dataset(char *filename, Mat &data, Mat &classes,  int total_samples)
{

	int label;
	float pixelvalue;
	FILE* inputfile = fopen( filename, "r" );

	for(int row = 0; row < total_samples; row++)
	{

		for(int col = 0; col <=ATTRIBUTES; col++)
		{

			if (col < ATTRIBUTES){

				fscanf(inputfile, "%f,", &pixelvalue);
				data.at<float>(row,col) = pixelvalue;

			}
			else if (col == ATTRIBUTES){
				//输出向量的结构是应属类别的位置赋值为1,其余赋值为0
				fscanf(inputfile, "%i", &label);
				classes.at<float>(row,label) = 1.0;

			}
		}
	}

	fclose(inputfile);

}

int main( int argc, char** argv )
{

	readFile("E:\\workdir\\NN\\character_train\\",46,"E:\\workdir\\NN\\trainingset.txt");
	readFile("E:\\workdir\\NN\\character_test\\",20,"E:\\workdir\\NN\\testset.txt");

	//训练样本集构成的矩阵
	Mat training_set(TRAINING_SAMPLES,ATTRIBUTES,CV_32F);
	//训练样本集的标签(输出向量)构成的矩阵
	Mat training_set_classifications(TRAINING_SAMPLES, CLASSES, CV_32F,Scalar(-1));
	//测试样本集构成的矩阵
	Mat test_set(TEST_SAMPLES,ATTRIBUTES,CV_32F);
	//测试样本集的标签(输出向量)构成的矩阵
	Mat test_set_classifications(TEST_SAMPLES,CLASSES,CV_32F,Scalar(-1));

	//
	Mat classificationResult(1, CLASSES, CV_32F);

	read_dataset("E:\\workdir\\NN\\trainingset.txt", training_set, training_set_classifications, TRAINING_SAMPLES);
	read_dataset("E:\\workdir\\NN\\testset.txt", test_set, test_set_classifications, TEST_SAMPLES);

	// 定义MLP的结构
	// 神经网络总共有三层
	// - 135输入节点
	// - 16 隐藏节点
	// - 10 输出节点.

	cv::Mat layers(3,1,CV_32S);
	layers.at<int>(0,0) = ATTRIBUTES;//input layer
	layers.at<int>(1,0)=16;//hidden layer
	layers.at<int>(2,0) =CLASSES;//output layer

	//创建神经网络
	//for more details check http://docs.opencv.org/modules/ml/doc/neural_networks.html
	CvANN_MLP nnetwork(layers, CvANN_MLP::SIGMOID_SYM,2.0/3.0,1);

	CvANN_MLP_TrainParams params(                                  

		// 终止训练在 1000 次迭代之后
		// 或者神经网络的权值某次迭代
		// 之后发生了很小的改变
		cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 1000, 0.000001),
		// 使用BP算法训练
		CvANN_MLP_TrainParams::BACKPROP,
		// BP算法的系数
		// recommended values taken from http://docs.opencv.org/modules/ml/doc/neural_networks.html#cvann-mlp-trainparams
		0.1,
		0.1);

	// 训练神经网络

	printf( "\nUsing training dataset\n");
	int iterations = nnetwork.train(training_set, training_set_classifications,cv::Mat(),cv::Mat(),params);
	printf( "Training iterations: %i\n\n", iterations);

	// 保存模型到一个XML文件
	CvFileStorage* storage = cvOpenFileStorage( "E:\\workdir\\NN\\param.xml", 0, CV_STORAGE_WRITE );
	nnetwork.write(storage,"DigitOCR");
	cvReleaseFileStorage(&storage);

	// 对生成的模型进行测试.
	cv::Mat test_sample;

	int correct_class = 0;

	int wrong_class = 0;

	//分类矩阵记录某个样本分到某类的次数.
	int classification_matrix[CLASSES][CLASSES]={{}};

	for (int tsample = 0; tsample < TEST_SAMPLES; tsample++)
	{
		test_sample = test_set.row(tsample);

		//分类器的输出

		nnetwork.predict(test_sample, classificationResult);
		//输出向量中最大的值即为样本所属的类

		// 以下的工作就是找到最大的数是哪个
		int maxIndex = 0;
		float value=0.0f;
		float maxValue=classificationResult.at<float>(0,0);
		for(int index=1;index<CLASSES;index++)
		{
			value = classificationResult.at<float>(0,index);
			if(value>maxValue)
			{
				maxValue = value;
				maxIndex=index;

			}
		}

		printf("Testing Sample %i -> class result (digit %d)\n", tsample, maxIndex);

		//现在比较神经网络的预测结果与真实结果. 如果分类正确
		//test_set_classifications[tsample][ maxIndex] 应该是 1.
		//如果分类错误, 记录下来.
		if (test_set_classifications.at<float>(tsample, maxIndex)!=1.0f)
		{

			wrong_class++;

			//标记分类矩阵
			for(int class_index=0;class_index<CLASSES;class_index++)
			{
				if(test_set_classifications.at<float>(tsample, class_index)==1.0f)
				{

					classification_matrix[class_index][maxIndex]++;// A class_index sample was wrongly classified as maxindex.
					break;
				}
			}

		}
		else
		{
			correct_class++;
			classification_matrix[maxIndex][maxIndex]++;
		}
	}

	//输出测试结果
	printf( "\nResults on the testing dataset\n"
		"\tCorrect classification: %d (%g%%)\n"
		"\tWrong classifications: %d (%g%%)\n",
		correct_class, (double) correct_class*100/TEST_SAMPLES,
		wrong_class, (double) wrong_class*100/TEST_SAMPLES);
	cout<<"   ";
	for (int i = 0; i < CLASSES; i++)
	{
		cout<< i<<"\t";
	}
	cout<<"\n";
	for(int row=0;row<CLASSES;row++)
	{
		cout<<row<<"  ";
		for(int col=0;col<CLASSES;col++)
		{
			cout<<classification_matrix[row][col]<<"\t";
		}
		cout<<"\n";
	}

	return 0;

}

参考文献:   http://www.nithinrajs.in/ocr-using-artificial-neural-network-opencv-part-1/

时间: 2024-12-29 16:51:23

使用MLP解决OCR问题(OpenCV)(下)的相关文章

XE6移动开发环境搭建之IOS篇(5):解决Windows和虚拟机下Mac OSX的共享问题(有图有真相)

XE6移动开发环境搭建之IOS篇(5):解决Windows和虚拟机下Mac OSX的共享问题(有图有真相) 2014-08-20 20:28 网上能找到的关于Delphi XE系列的移动开发环境的相关文章甚少,本文尽量以详细的内容.傻瓜式的表达来告诉你想要的答案. 在安装XE6 PAServer前,我们先解决Windows和虚拟机下Mac的文件共享问题,由于虚拟机和我们安装的XE6是同一台电脑,所以此问题很好解决.网上相传有很多的共享大法,但是在WIN7这种权限管制得过份的系统下显得相对复杂了,

解决php 5.4下dedecms登陆后台空白,标题不能为空错误

这两天有人反应新版的php-fpm的php版本为5.4.7对dedecms5.6兼容性不好. dedecms安装完成后会出现登陆后台空白,发布文章时提示"标题不能为空". 1.解决dedecms登陆后台空白错误因为php5.4的版本废除了session_register,所以需要去掉session_register函数 修改:"include/userlogin.class.php",注释掉session_register,修改后如下//@session_regis

解决Xilinx_ISE在Win8下打开崩溃闪退的方法

解决Xilinx_ISE在Win8下打开崩溃闪退的方法 在64位windows8或者8.1上安装xilinx ise之后,加载 licence或者保存文件的时候,ise应用程序就会崩溃,出现闪退的情况. 修复方法: 第一步: 找到xilinx安装文件下的子文件,我的是安装在D盘. [plain] view plaincopy D:\Xilinx\14.4\ISE_DS\ISE\lib\nt64 在这个文件夹中搜索文件 libPortability 会出来两个文件 [plain] view pla

解决IE和firefox 下flash盖住div的问题(转载)

原文地址:http://www.oschina.net/question/171410_26563 做的企业站 顶部是flash的滚动图片. 右侧是在线客户,但是flash盖住了在线客户的div 网上搜索解决方法如下: <object classid="clsid:D27CDB6E-AE6D-11cf-96B8-444553540000" codebase="http://download.macromedia.com/pub/shockwave/cabs/flash/

如何解决Windows 10系统下设备的声音问题

如何解决Windows 10系统下设备的声音问题? 请阅读下面的说明来解决Windows 10设备上的声音问题. 1. 检查设备管理器 打开开始菜单,键入设备管理器, 从出现的结果中选择并打开它. 在声音.视频和游戏控制器栏目下, 选择并打开你的声卡 . 选择 驱动程序 一栏, 并选择 更新驱动程序. 如果系统没有找到新的驱动,可以尝试在ASUS官网寻找驱动. 如果上述步骤无效,尝试重装声卡驱动: 打开 设备管理器, 右击声卡驱动, 选择 卸载. 重启电脑,系统就会自动尝试重装声卡驱动. 如果无

解决QT5.3.1下触控笔无法工作的问题

刚进新公司,据开发人员说存在QT5.3.1下触控笔无法工作,而在QT5.2.1下能正常工作,研究了一下. Steps: 1. 首先当然是看下问题是否真的存在.测试情况: 环境 结果 备注 QT5.2.1, win 8.1 32bit, mingw 触控笔正常工作 QT5.3.1, win 8.1 32bit, mingw 触控笔无法工作 用手指可以正常触控 2. 目前的情况来看应该是QT发布QT5.3.x时引入的新BUG,决定到QT-PROJECT上的BUGREPORTS搜下是否已经有此BUG,

解决Extjs有IE下z-index属性的问题

在用Extjs时,有时候,在Google浏览器上面没有任何问题,但是相同的页面在IE下面就会有问题,直接报错,点击中断,进行后可以看到如下的信息: Google里面没这个问题,加一句代码就能解决在窗体的构造函数里面加上一行代码  style: 'z-index: -1;', 以后如果出现类似的问题,如果中断,进去后看到如下的提示,并且google中没有问题,那么就加上这么一行代码,具体原因还不清楚,但是这行代码可以解决这个问题 解决Extjs有IE下z-index属性的问题

解决谷歌浏览器在win8下没有注册类的问题

在网上搜索了很多方法,终于找到一种有用的,分享下 新建一个txt,里面存放代码 Windows Registry Editor Version 5.00 [HKEY_LOCAL_MACHINE\SOFTWARE\Classes\ChromeHTML\shell\open\command] @="\"C:\\Program Files (x86)\\Google\\Chrome\\Application\\chrome.exe\" -- \"%1\"&quo

解决svn在win7下安装后右键无菜单项的问题

解决svn在win7下安装后右键无菜单的问题.该版本为1.8.10版本,其中包括安装包跟中文插件包,请先安装TortoiseSVN-1.8.10.26129-win32-svn-1.8.11.1420009704:然后安装LanguagePack_1.8.10.26129-win32-zh_CN.msi 下载无需积分,下载地址:http://download.csdn.net/detail/a358763471/9058629 注意事项:安装之前请确保已经卸载现有版本,并且最好用清理工具清一下注