【OpenCV】opencv3.0中的SVM训练 mnist 手写字体识别

前言:

SVM(支持向量机)一种训练分类器的学习方法

mnist 是一个手写字体图像数据库,训练样本有60000个,测试样本有10000个

LibSVM 一个常用的SVM框架

OpenCV3.0 中的ml包含了很多的ML框架接口,就试试了。

详细的OpenCV文档:http://docs.opencv.org/3.0-beta/doc/tutorials/ml/introduction_to_svm/introduction_to_svm.html

mnist数据下载:http://yann.lecun.com/exdb/mnist/

LibSVM下载:http://www.csie.ntu.edu.tw/~cjlin/libsvm/

========================我是分割线=============================

训练的过程大致如下:

1. 读取mnist训练集数据

2. 训练

3. 读取mnist测试数据,对比预测结果,得到错误率

具体实现:

1. mnist给出的数据文件是二进制文件

四个文件,解压后如下

   

  "train-images.idx3-ubyte" 二进制文件,存储了头文件信息以及60000张28*28图像pixel信息(用于训练)
  "train-labels.idx1-ubyte" 二进制文件,存储了头文件信息以及60000张图像label信息
  "t10k-images.idx3-ubyte"二进制文件,存储了头文件信息以及10000张28*28图像pixel信息(用于测试)
  "t10k-labels.idx1-ubyte"二进制文件,存储了头文件信息以及10000张图像label信息

  因为OpenCV中没有直接导入MINST数据的文件,所以需要自己写函数来读取

  首先要知道,MNIST数据的数据格式

  

   IMAGE FILE包含四个int型的头部数据(magic number,number_of_images, number_of_rows, number_of_columns)

余下的每一个byte表示一个pixel的数据,范围是0-255(可以在读入的时候scale到0~1的区间)

LABEL FILE包含两个int型的头部数据(magic number, number of items)

余下的每一个byte表示一个label数据,范围是0-9

   注意(第一个坑):MNIST是大端存储,然而大部分的Intel处理器都是小端存储,所以对于int、long、float这些多字节的数据类型,就要一个一个byte地翻转过来,才能正确显示。

  

 1 //翻转
 2 int reverseInt(int i) {
 3     unsigned char c1, c2, c3, c4;
 4
 5     c1 = i & 255;
 6     c2 = (i >> 8) & 255;
 7     c3 = (i >> 16) & 255;
 8     c4 = (i >> 24) & 255;
 9
10     return ((int)c1 << 24) + ((int)c2 << 16) + ((int)c3 << 8) + c4;
11 }

  然后读取MNIST文件,但是它是二进制文件,打开方式

  所以不能用

  ifstream file(fileName);

  而要改成

  ifstream file(fileName, ios::binary);

  注意(第二个坑):如果用第一条指令来打开文件,不会报错,但是数据会出现错误,头部数据仍然正确,但是后面的pixel数据大部分都是0,我刚开始没注意,开始training的时候发现等了很久...真的是很久...(7+ hours)...估计是达到迭代终止的最大次数了,才停下来的

  嗯,stack overflow上也有类似的提问:

  

  注意(第三个坑):

  training时,IMAGE和LABEL的数据分别都放进一个MAT中存储,但是只能是CV32_F或者CV32_S的格式,不然会assertion报错

  OPENCV给出的文档中,例子是这样的:(但是predict的时候又会要求label的格式是unsigned int)所以...可以设置data的Mat格式为CV_32FC1,label的Mat格式为CV_32SC1

  

  顺便地,图像训练数据的转换存储格式(http://stackoverflow.com/questions/14694810/using-opencv-and-svm-with-images?rq=1)

  

  最后,为了验证读取数据的正确性,一个有效的办法就是输出第一个和最后一个数据(可以输出答应第一个/最后一个image以及label)

2. 训练

  (此处我是直接对原图像训练,并没有提取任何的特征)

  也有人建议这里应该对图像做HOG特征提取,再配合label训练(我还没试过...不知道效果如何...)

  

  opencv3.0和2.4的SVM接口有不同,基本可以按照以下的格式来执行:

ml::SVM::Params params;
params.svmType = ml::SVM::C_SVC;
params.kernelType = ml::SVM::POLY;
params.gamma = 3;
Ptr<ml::SVM> svm = ml::SVM::create(params);
Mat trainData; // 每行为一个样本
Mat labels;
svm->train( trainData , ml::ROW_SAMPLE , labels );
// ...

svm->save("....");//文件形式为xml,可以保存在txt或者xml文件中
Ptr<SVM> svm=statModel::load<SVM>("....");

Mat query; // 输入, 1个通道
Mat res;   // 输出
svm->predict(query, res);

但是要注意,如果报错的话最好去看opencv3.0的文档,里面有函数原型和解释,我在实际操作的过程中,也做了一些改动

   1)设置参数

    SVM的参数有很多,但是与C_SVC和RBF有关的就只有gamma和C,所以设置这两个就好,终止条件设置和默认一样,由经验可得(其实是查阅了很多的资料,把gamma设置成0.01,这样训练收敛速度会快很多)

Ptr<SVM> svm = SVM::create();
svm->setType(SVM::C_SVC);
svm->setKernel(SVM::RBF);
svm->setGamma(0.01);
svm->setC(10.0);
svm->setTermCriteria(TermCriteria(CV_TERMCRIT_EPS, 1000,FLT_EPSILON));

  svm_type –指定SVM的类型,下面是可能的取值:

  CvSVM::C_SVC C类支持向量分类机。 n类分组 (n \geq 2),允许用异常值惩罚因子C进行不完全分类。
  CvSVM::NU_SVC \nu类支持向量分类机。n类似然不完全分类的分类器。参数为 \nu 取代C(其值在区间【0,1】中,nu越大,决策边界越平滑)。
  CvSVM::ONE_CLASS 单分类器,所有的训练数据提取自同一个类里,然后SVM建立了一个分界线以分割该类在特征空间中所占区域和其它类在特征空间中所占区域。
  CvSVM::EPS_SVR \epsilon类支持向量回归机。训练集中的特征向量和拟合出来的超平面的距离需要小于p。异常值惩罚因子C被采用。
  CvSVM::NU_SVR \nu类支持向量回归机。 \nu 代替了 p。

  kernel_type –SVM的内核类型,下面是可能的取值:

  CvSVM::LINEAR 线性内核。没有任何向映射至高维空间,线性区分(或回归)在原始特征空间中被完成,这是最快的选择。K(x_i, x_j) = x_i^T x_j.
  CvSVM::POLY 多项式内核: K(x_i, x_j) = (\gamma x_i^T x_j + coef0)^{degree}, \gamma > 0.
  CvSVM::RBF 基于径向的函数,对于大多数情况都是一个较好的选择: K(x_i, x_j) = e^{-\gamma ||x_i - x_j||^2}, \gamma > 0.
  CvSVM::SIGMOID Sigmoid函数内核:K(x_i, x_j) = \tanh(\gamma x_i^T x_j + coef0).

  degree – 内核函数(POLY)的参数degree。

  gamma – 内核函数(POLY/ RBF/ SIGMOID)的参数\gamma。

  coef0 – 内核函数(POLY/ SIGMOID)的参数coef0。

  Cvalue – SVM类型(C_SVC/ EPS_SVR/ NU_SVR)的参数C。

  nu – SVM类型(NU_SVC/ ONE_CLASS/ NU_SVR)的参数 \nu。

  p – SVM类型(EPS_SVR)的参数 \epsilon。

  class_weights – C_SVC中的可选权重,赋给指定的类,乘以C以后变成 class\_weights_i * C。所以这些权重影响不同类别的错误分类惩罚项。权重越大,某一类别的误分类数据的惩罚项就越大。

  term_crit – SVM的迭代训练过程的中止条件,解决部分受约束二次最优问题。您可以指定的公差和/或最大迭代次数。

   2)训练

Mat trainData;
Mat labels;
trainData = read_mnist_image(trainImage);
labels = read_mnist_label(trainLabel);

svm->train(trainData, ROW_SAMPLE, labels);

   3)保存

svm->save("mnist_dataset/mnist_svm.xml");

3. 测试,比对结果

(此处的FLT_EPSILON是一个极小的数,1.0 - FLT_EPSILON != 1.0)

Mat testData;
Mat tLabel;
testData = read_mnist_image(testImage);
tLabel = read_mnist_label(testLabel);

float count = 0;
for (int i = 0; i < testData.rows; i++) {
    Mat sample = testData.row(i);
    float res = svm1->predict(sample);
    res = std::abs(res - tLabel.at<unsigned int>(i, 0)) <= FLT_EPSILON ? 1.f : 0.f;
    count += res;
}
cout << "正确的识别个数 count = " << count << endl;
cout << "错误率为..." << (10000 - count + 0.0) / 10000 * 100.0 << "%....\n";

这里没有使用svm->predict(query, res);

然后就查看了opencv的文档,当传入数据是Mat 而不是cvMat时,可以利用predict的返回值(float)来判断预测是否正确。

运行结果:

1)1000个训练数据/1000个测试数据

  

2)2000个训练数据/2000个测试数据

  

3)5000个训练数据/5000个测试数据

  

4)10000个训练数据/10000个测试数据

  

5)60000个训练数据/10000个测试数据

  

最后,关于运行时间(在程序正确的前提下,训练时长和初始的参数设置有关),给出我最的运行结果(1000张图是11s左右,6000张是1300s ~ 2000s左右)

代码:

 1 #ifndef MNIST_H
 2 #define MNIST_H
 3
 4 #include <iostream>
 5 #include <string>
 6 #include <fstream>
 7 #include <ctime>
 8 #include <opencv2/opencv.hpp>
 9
10 using namespace cv;
11 using namespace std;
12
13 //小端存储转换
14 int reverseInt(int i);
15
16 //读取image数据集信息
17 Mat read_mnist_image(const string fileName);
18
19 //读取label数据集信息
20 Mat read_mnist_label(const string fileName);
21
22 #endif

mnist.h

  1 #include "mnist.h"
  2
  3 //计时器
  4 double cost_time;
  5 clock_t start_time;
  6 clock_t end_time;
  7
  8 //测试item个数
  9 int testNum = 10000;
 10
 11 int reverseInt(int i) {
 12     unsigned char c1, c2, c3, c4;
 13
 14     c1 = i & 255;
 15     c2 = (i >> 8) & 255;
 16     c3 = (i >> 16) & 255;
 17     c4 = (i >> 24) & 255;
 18
 19     return ((int)c1 << 24) + ((int)c2 << 16) + ((int)c3 << 8) + c4;
 20 }
 21
 22 Mat read_mnist_image(const string fileName) {
 23     int magic_number = 0;
 24     int number_of_images = 0;
 25     int n_rows = 0;
 26     int n_cols = 0;
 27
 28     Mat DataMat;
 29
 30     ifstream file(fileName, ios::binary);
 31     if (file.is_open())
 32     {
 33         cout << "成功打开图像集 ... \n";
 34
 35         file.read((char*)&magic_number, sizeof(magic_number));
 36         file.read((char*)&number_of_images, sizeof(number_of_images));
 37         file.read((char*)&n_rows, sizeof(n_rows));
 38         file.read((char*)&n_cols, sizeof(n_cols));
 39         //cout << magic_number << " " << number_of_images << " " << n_rows << " " << n_cols << endl;
 40
 41         magic_number = reverseInt(magic_number);
 42         number_of_images = reverseInt(number_of_images);
 43         n_rows = reverseInt(n_rows);
 44         n_cols = reverseInt(n_cols);
 45         cout << "MAGIC NUMBER = " << magic_number
 46             << " ;NUMBER OF IMAGES = " << number_of_images
 47             << " ; NUMBER OF ROWS = " << n_rows
 48             << " ; NUMBER OF COLS = " << n_cols << endl;
 49
 50         //-test-
 51         //number_of_images = testNum;
 52         //输出第一张和最后一张图,检测读取数据无误
 53         Mat s = Mat::zeros(n_rows, n_rows * n_cols, CV_32FC1);
 54         Mat e = Mat::zeros(n_rows, n_rows * n_cols, CV_32FC1);
 55
 56         cout << "开始读取Image数据......\n";
 57         start_time = clock();
 58         DataMat = Mat::zeros(number_of_images, n_rows * n_cols, CV_32FC1);
 59         for (int i = 0; i < number_of_images; i++) {
 60             for (int j = 0; j < n_rows * n_cols; j++) {
 61                 unsigned char temp = 0;
 62                 file.read((char*)&temp, sizeof(temp));
 63                 float pixel_value = float((temp + 0.0) / 255.0);
 64                 DataMat.at<float>(i, j) = pixel_value;
 65
 66                 //打印第一张和最后一张图像数据
 67                 if (i == 0) {
 68                     s.at<float>(j / n_cols, j % n_cols) = pixel_value;
 69                 }
 70                 else if (i == number_of_images - 1) {
 71                     e.at<float>(j / n_cols, j % n_cols) = pixel_value;
 72                 }
 73             }
 74         }
 75         end_time = clock();
 76         cost_time = (end_time - start_time) / CLOCKS_PER_SEC;
 77         cout << "读取Image数据完毕......" << cost_time << "s\n";
 78
 79         imshow("first image", s);
 80         imshow("last image", e);
 81         waitKey(0);
 82     }
 83     file.close();
 84     return DataMat;
 85 }
 86
 87 Mat read_mnist_label(const string fileName) {
 88     int magic_number;
 89     int number_of_items;
 90
 91     Mat LabelMat;
 92
 93     ifstream file(fileName, ios::binary);
 94     if (file.is_open())
 95     {
 96         cout << "成功打开Label集 ... \n";
 97
 98         file.read((char*)&magic_number, sizeof(magic_number));
 99         file.read((char*)&number_of_items, sizeof(number_of_items));
100         magic_number = reverseInt(magic_number);
101         number_of_items = reverseInt(number_of_items);
102
103         cout << "MAGIC NUMBER = " << magic_number << "  ; NUMBER OF ITEMS = " << number_of_items << endl;
104
105         //-test-
106         //number_of_items = testNum;
107         //记录第一个label和最后一个label
108         unsigned int s = 0, e = 0;
109
110         cout << "开始读取Label数据......\n";
111         start_time = clock();
112         LabelMat = Mat::zeros(number_of_items, 1, CV_32SC1);
113         for (int i = 0; i < number_of_items; i++) {
114             unsigned char temp = 0;
115             file.read((char*)&temp, sizeof(temp));
116             LabelMat.at<unsigned int>(i, 0) = (unsigned int)temp;
117
118             //打印第一个和最后一个label
119             if (i == 0) s = (unsigned int)temp;
120             else if (i == number_of_items - 1) e = (unsigned int)temp;
121         }
122         end_time = clock();
123         cost_time = (end_time - start_time) / CLOCKS_PER_SEC;
124         cout << "读取Label数据完毕......" << cost_time << "s\n";
125
126         cout << "first label = " << s << endl;
127         cout << "last label = " << e << endl;
128     }
129     file.close();
130     return LabelMat;
131 }

mnist.cpp

  1 /*
  2 svm_type –
  3 指定SVM的类型,下面是可能的取值:
  4 CvSVM::C_SVC C类支持向量分类机。 n类分组  (n \geq 2),允许用异常值惩罚因子C进行不完全分类。
  5 CvSVM::NU_SVC \nu类支持向量分类机。n类似然不完全分类的分类器。参数为 \nu 取代C(其值在区间【0,1】中,nu越大,决策边界越平滑)。
  6 CvSVM::ONE_CLASS 单分类器,所有的训练数据提取自同一个类里,然后SVM建立了一个分界线以分割该类在特征空间中所占区域和其它类在特征空间中所占区域。
  7 CvSVM::EPS_SVR \epsilon类支持向量回归机。训练集中的特征向量和拟合出来的超平面的距离需要小于p。异常值惩罚因子C被采用。
  8 CvSVM::NU_SVR \nu类支持向量回归机。 \nu 代替了 p。
  9
 10 可从 [LibSVM] 获取更多细节。
 11
 12 kernel_type –
 13 SVM的内核类型,下面是可能的取值:
 14 CvSVM::LINEAR 线性内核。没有任何向映射至高维空间,线性区分(或回归)在原始特征空间中被完成,这是最快的选择。K(x_i, x_j) = x_i^T x_j.
 15 CvSVM::POLY 多项式内核: K(x_i, x_j) = (\gamma x_i^T x_j + coef0)^{degree}, \gamma > 0.
 16 CvSVM::RBF 基于径向的函数,对于大多数情况都是一个较好的选择: K(x_i, x_j) = e^{-\gamma ||x_i - x_j||^2}, \gamma > 0.
 17 CvSVM::SIGMOID Sigmoid函数内核:K(x_i, x_j) = \tanh(\gamma x_i^T x_j + coef0).
 18
 19 degree – 内核函数(POLY)的参数degree。
 20
 21 gamma – 内核函数(POLY/ RBF/ SIGMOID)的参数\gamma。
 22
 23 coef0 – 内核函数(POLY/ SIGMOID)的参数coef0。
 24
 25 Cvalue – SVM类型(C_SVC/ EPS_SVR/ NU_SVR)的参数C。
 26
 27 nu – SVM类型(NU_SVC/ ONE_CLASS/ NU_SVR)的参数 \nu。
 28
 29 p – SVM类型(EPS_SVR)的参数 \epsilon。
 30
 31 class_weights – C_SVC中的可选权重,赋给指定的类,乘以C以后变成 class\_weights_i * C。所以这些权重影响不同类别的错误分类惩罚项。权重越大,某一类别的误分类数据的惩罚项就越大。
 32
 33 term_crit – SVM的迭代训练过程的中止条件,解决部分受约束二次最优问题。您可以指定的公差和/或最大迭代次数。
 34
 35 */
 36
 37
 38 #include "mnist.h"
 39
 40 #include <opencv2/core.hpp>
 41 #include <opencv2/imgproc.hpp>
 42 #include "opencv2/imgcodecs.hpp"
 43 #include <opencv2/highgui.hpp>
 44 #include <opencv2/ml.hpp>
 45
 46 #include <string>
 47 #include <iostream>
 48
 49 using namespace std;
 50 using namespace cv;
 51 using namespace cv::ml;
 52
 53 string trainImage = "mnist_dataset/train-images.idx3-ubyte";
 54 string trainLabel = "mnist_dataset/train-labels.idx1-ubyte";
 55 string testImage = "mnist_dataset/t10k-images.idx3-ubyte";
 56 string testLabel = "mnist_dataset/t10k-labels.idx1-ubyte";
 57 //string testImage = "mnist_dataset/train-images.idx3-ubyte";
 58 //string testLabel = "mnist_dataset/train-labels.idx1-ubyte";
 59
 60 //计时器
 61 double cost_time_;
 62 clock_t start_time_;
 63 clock_t end_time_;
 64
 65 int main()
 66 {
 67
 68     //--------------------- 1. Set up training data ---------------------------------------
 69     Mat trainData;
 70     Mat labels;
 71     trainData = read_mnist_image(trainImage);
 72     labels = read_mnist_label(trainLabel);
 73
 74     cout << trainData.rows << " " << trainData.cols << endl;
 75     cout << labels.rows << " " << labels.cols << endl;
 76
 77     //------------------------ 2. Set up the support vector machines parameters --------------------
 78     Ptr<SVM> svm = SVM::create();
 79     svm->setType(SVM::C_SVC);
 80     svm->setKernel(SVM::RBF);
 81     //svm->setDegree(10.0);
 82     svm->setGamma(0.01);
 83     //svm->setCoef0(1.0);
 84     svm->setC(10.0);
 85     //svm->setNu(0.5);
 86     //svm->setP(0.1);
 87     svm->setTermCriteria(TermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON));
 88
 89     //------------------------ 3. Train the svm ----------------------------------------------------
 90     cout << "Starting training process" << endl;
 91     start_time_ = clock();
 92     svm->train(trainData, ROW_SAMPLE, labels);
 93     end_time_ = clock();
 94     cost_time_ = (end_time_ - start_time_) / CLOCKS_PER_SEC;
 95     cout << "Finished training process...cost " << cost_time_ << " seconds..." << endl;
 96
 97     //------------------------ 4. save the svm ----------------------------------------------------
 98     svm->save("mnist_dataset/mnist_svm.xml");
 99     cout << "save as /mnist_dataset/mnist_svm.xml" << endl;
100
101
102     //------------------------ 5. load the svm ----------------------------------------------------
103     cout << "开始导入SVM文件...\n";
104     Ptr<SVM> svm1 = StatModel::load<SVM>("mnist_dataset/mnist_svm.xml");
105     cout << "成功导入SVM文件...\n";
106
107
108     //------------------------ 6. read the test dataset -------------------------------------------
109     cout << "开始导入测试数据...\n";
110     Mat testData;
111     Mat tLabel;
112     testData = read_mnist_image(testImage);
113     tLabel = read_mnist_label(testLabel);
114     cout << "成功导入测试数据!!!\n";
115
116
117     float count = 0;
118     for (int i = 0; i < testData.rows; i++) {
119         Mat sample = testData.row(i);
120         float res = svm1->predict(sample);
121         res = std::abs(res - tLabel.at<unsigned int>(i, 0)) <= FLT_EPSILON ? 1.f : 0.f;
122         count += res;
123     }
124     cout << "正确的识别个数 count = " << count << endl;
125     cout << "错误率为..." << (10000 - count + 0.0) / 10000 * 100.0 << "%....\n";
126
127     system("pause");
128     return 0;
129 }

main.cpp

一些网站(资料):(其实都很容易搜索到的=_=, 但是搬了人家的东西,就还是贴一下...

http://blog.csdn.net/augusdi/article/details/9005352

http://blog.csdn.net/arthur503/article/details/19974057

http://blog.csdn.net/laihonghuan/article/details/49387237

http://docs.opencv.org/3.0-beta/modules/ml/doc/support_vector_machines.html#prediction-with-svm

http://stackoverflow.com/questions/14694810/using-opencv-and-svm-with-images?rq=1

http://docs.opencv.org/2.4/modules/ml/doc/support_vector_machines.html#cvsvm-train

http://blog.csdn.net/u010869312/article/details/44927721

http://blog.csdn.net/heroacool/article/details/50579955

http://docs.opencv.org/3.0-beta/doc/tutorials/ml/introduction_to_svm/introduction_to_svm.html

http://guyvercz.blog.163.com/blog/static/252545292011112974915402/

http://stackoverflow.com/questions/12993941/how-can-i-read-the-mnist-dataset-with-c?lq=1

时间: 2024-10-16 11:14:34

【OpenCV】opencv3.0中的SVM训练 mnist 手写字体识别的相关文章

Tensorflow实践 mnist手写数字识别

minst数据集                                         tensorflow的文档中就自带了mnist手写数字识别的例子,是一个很经典也比较简单的入门tensorflow的例子,非常值得自己动手亲自实践一下.由于我用的不是tensorflow中自带的mnist数据集,而是从kaggle的网站下载下来的,数据集有些不太一样,所以直接按照tensorflow官方文档上的参数训练的话还是踩了一些坑,特此记录. 首先从kaggle网站下载mnist数据集,一份是

tensorflow 基础学习五:MNIST手写数字识别

MNIST数据集介绍: from tensorflow.examples.tutorials.mnist import input_data # 载入MNIST数据集,如果指定地址下没有已经下载好的数据,tensorflow会自动下载数据 mnist=input_data.read_data_sets('.',one_hot=True) # 打印 Training data size:55000. print("Training data size: {}".format(mnist.

Pytorch入门实战一:LeNet神经网络实现 MNIST手写数字识别

记得第一次接触手写数字识别数据集还在学习TensorFlow,各种sess.run(),头都绕晕了.自从接触pytorch以来,一直想写点什么.曾经在2017年5月,Andrej Karpathy发表的一片Twitter,调侃道:l've been using PyTorch a few months now, l've never felt better, l've more energy.My skin is clearer. My eye sight has improved.确实,使用p

用SVM(有核和无核函数)进行MNIST手写字体的分类

1.普通SVM分类MNIST数据集 1 #导入必备的包 2 import numpy as np 3 import struct 4 import matplotlib.pyplot as plt 5 import os 6 ##加载svm模型 7 from sklearn import svm 8 ###用于做数据预处理 9 from sklearn import preprocessing 10 import time 11 12 #加载数据的路径 13 path='./dataset/mn

Caffe 例子使用记录(一)mnist手写数字识别

1. 安装caffe请参考 http://www.cnblogs.com/xuanyuyt/p/5726926.html 2. 下载训练和测试数据.caffe识别leveldb或者lmdb格式的数据,这里提供转换好的LEVELDB格式数据集,解压缩到mnist例子目录下 链接:http://pan.baidu.com/s/1gfjXteV 密码:45j6 3. 打开lenet_solver.prototxt,这里可以自己试着改几个参数看看最终效果 # The train/test net pro

Tensorflow之MNIST手写数字识别:分类问题(2)

整体代码: #数据读取 import tensorflow as tf import matplotlib.pyplot as plt import numpy as np from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("MNIST_data/",one_hot=True) #定义待输入数据的占位符 #mnist中每张照片共有28*28=784个像

三种方法实现MNIST 手写数字识别

MNIST数据集下载: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) #one_hot 独热编码,也叫一位有效编码.在任意时候只有一位为1,其他位都是0 1 使用逻辑回归: import tensorflow as tf # 导入数据集 #fr

SVM:利用SVM算法实现手写图片识别(数据集50000张图片)—Jason niu

import mnist_loader # Third-party libraries from sklearn import svm def svm_baseline(): training_data, validation_data, test_data = mnist_loader.load_data() # train clf = svm.SVC() clf.fit(training_data[0], training_data[1]) predictions = [int(a) for

安装MXnet包,实现MNIST手写数体识别

我想写一系列深度学习的简单实战教程,用mxnet做实现平台的实例代码简单讲解深度学习常用的一些技术方向和实战样例.这一系列的主要内容偏向于讲解实际的例子,从样例和代码里中学习解决实际问题.我会默认读者有一定神经网络和深度学习的基础知识,读者在这里不会看到大段推导和理论阐述.基础理论知识十分重要,如果读者对理论知识有兴趣,可以参看已有的深度学习教程补充和巩固理论基础,这里http://deeplearning.net/reading-list/tutorials/有一些不错的理论教程,相关的理论知