用SVM(有核和无核函数)进行MNIST手写字体的分类

1.普通SVM分类MNIST数据集

 1 #导入必备的包
 2 import numpy as np
 3 import struct
 4 import matplotlib.pyplot as plt
 5 import os
 6 ##加载svm模型
 7 from sklearn import svm
 8 ###用于做数据预处理
 9 from sklearn import preprocessing
10 import time
11
12 #加载数据的路径
13 path=‘./dataset/mnist/raw‘
14 def load_mnist_train(path, kind=‘train‘):
15     labels_path = os.path.join(path,‘%s-labels-idx1-ubyte‘% kind)
16     images_path = os.path.join(path,‘%s-images-idx3-ubyte‘% kind)
17     with open(labels_path, ‘rb‘) as lbpath:
18         magic, n = struct.unpack(‘>II‘,lbpath.read(8))
19         labels = np.fromfile(lbpath,dtype=np.uint8)
20     with open(images_path, ‘rb‘) as imgpath:
21         magic, num, rows, cols = struct.unpack(‘>IIII‘,imgpath.read(16))
22         images = np.fromfile(imgpath,dtype=np.uint8).reshape(len(labels), 784)
23     return images, labels
24 def load_mnist_test(path, kind=‘t10k‘):
25     labels_path = os.path.join(path,‘%s-labels-idx1-ubyte‘% kind)
26     images_path = os.path.join(path,‘%s-images-idx3-ubyte‘% kind)
27     with open(labels_path, ‘rb‘) as lbpath:
28         magic, n = struct.unpack(‘>II‘,lbpath.read(8))
29         labels = np.fromfile(lbpath,dtype=np.uint8)
30     with open(images_path, ‘rb‘) as imgpath:
31         magic, num, rows, cols = struct.unpack(‘>IIII‘,imgpath.read(16))
32         images = np.fromfile(imgpath,dtype=np.uint8).reshape(len(labels), 784)
33     return images, labels
34 train_images,train_labels=load_mnist_train(path)
35 test_images,test_labels=load_mnist_test(path)
36
37 X=preprocessing.StandardScaler().fit_transform(train_images)
38 X_train=X[0:60000]
39 y_train=train_labels[0:60000]
40
41 print(time.strftime(‘%Y-%m-%d %H:%M:%S‘))
42 model_svc = svm.LinearSVC()
43 #model_svc = svm.SVC()
44 model_svc.fit(X_train,y_train)
45 print(time.strftime(‘%Y-%m-%d %H:%M:%S‘))
46
47 ##显示前30个样本的真实标签和预测值,用图显示
48 x=preprocessing.StandardScaler().fit_transform(test_images)
49 x_test=x[0:10000]
50 y_pred=test_labels[0:10000]
51 print(model_svc.score(x_test,y_pred))
52 y=model_svc.predict(x)
53
54 fig1=plt.figure(figsize=(8,8))
55 fig1.subplots_adjust(left=0,right=1,bottom=0,top=1,hspace=0.05,wspace=0.05)
56 for i in range(100):
57     ax=fig1.add_subplot(10,10,i+1,xticks=[],yticks=[])
58     ax.imshow(np.reshape(test_images[i], [28,28]),cmap=plt.cm.binary,interpolation=‘nearest‘)
59     ax.text(0,2,"pred:"+str(y[i]),color=‘red‘)
60     #ax.text(0,32,"real:"+str(test_labels[i]),color=‘blue‘)
61 plt.show()

2.运行结果:

开始时间:2018-11-17 08:31:09

结束时间:2018-11-17 08:53:04

用时:21分55秒

精度:0.9122

预测图片:

原文地址:https://www.cnblogs.com/yaowuyangwei521/p/9975714.html

时间: 2024-08-15 07:44:18

用SVM(有核和无核函数)进行MNIST手写字体的分类的相关文章

【OpenCV】opencv3.0中的SVM训练 mnist 手写字体识别

前言: SVM(支持向量机)一种训练分类器的学习方法 mnist 是一个手写字体图像数据库,训练样本有60000个,测试样本有10000个 LibSVM 一个常用的SVM框架 OpenCV3.0 中的ml包含了很多的ML框架接口,就试试了. 详细的OpenCV文档:http://docs.opencv.org/3.0-beta/doc/tutorials/ml/introduction_to_svm/introduction_to_svm.html mnist数据下载:http://yann.l

简单HOG+SVM mnist手写数字分类

使用工具 :VS2013 + OpenCV 3.1 数据集:minst 训练数据:60000张 测试数据:10000张 输出模型:HOG_SVM_DATA.xml 数据准备 train-images-idx3-ubyte.gz:  training set images (9912422 bytes) train-labels-idx1-ubyte.gz:  training set labels (28881 bytes) t10k-images-idx3-ubyte.gz:   test s

学习OpenCV——SVM 手写数字检测

转自http://blog.csdn.net/firefight/article/details/6452188 是MNIST手写数字图片库:http://code.google.com/p/supplement-of-the-mnist-database-of-handwritten-digits/downloads/list 其他方法:http://blog.csdn.net/onezeros/article/details/5672192 使用OPENCV训练手写数字识别分类器 1,下载训

前馈全连接神经网络和函数逼近、时间序列预测、手写数字识别

https://www.cnblogs.com/conmajia/p/annt-feed-forward-fully-connected-neural-networks.html Andrew Kirillov 著Conmajia 译2019 年 1 月 12 日 原文发表于 CodeProject(2018 年 9 月 28 日). 中文版有小幅修改,已获作者本人授权. 本文介绍了如何使用 ANNT 神经网络库生成前馈全连接神经网络并应用到问题求解. 全文约 12,000 字,建议阅读时间 3

libsvm代码阅读(2):svm.cpp浅谈和函数指针(转)

svm.cpp浅谈 svm.cpp总共有3159行代码,实现了svm算法的核心功能,里面总共有Cache.Kernel.ONE_CLASS_Q.QMatrix.Solver.Solver_NU.SVC_Q.SVR_Q 8个类(如下图1所示),而它们之间的继承和组合关系如图2.图3所示.在这些类中Cache.Kernel.Solver是核心类,对整个算法起支撑作用.在以后的博文中我们将对这3个核心类做重点注解分析,另外还将对svm.cpp中的svm_train函数做一个注解分析. 图1 图2 图3

SVM 手写数字识别

初次是根据“支持向量机通俗导论(理解SVM的三层境界)”对SVM有了简单的了解.总的来说其主要的思想可以概括为以下两点(也是别人的总结) 1.SVM是对二分类问题在线性可分的情况下提出的,当样本线性不可分时,它通过非线性的映射算法,将在低维空间线性不可分的样本映射到高维的特征空间使其线性可分,从而使得对非线性可分样本进行线性分类. 2.SVM是建立在统计学习理论的 VC理论和结构风险最小化原理基础上的,在保证样本分类精度的前提下,建立最优的分割超平面,使得学习器有较好的全局性和推广性. 第一个能

【机器学习算法实现】kNN算法__手写识别——基于Python和NumPy函数库

[机器学习算法实现]系列文章将记录个人阅读机器学习论文.书籍过程中所碰到的算法,每篇文章描述一个具体的算法.算法的编程实现.算法的具体应用实例.争取每个算法都用多种语言编程实现.所有代码共享至github:https://github.com/wepe/MachineLearning-Demo     欢迎交流指正! (1)kNN算法_手写识别实例--基于Python和NumPy函数库 1.kNN算法简介 kNN算法,即K最近邻(k-NearestNeighbor)分类算法,是最简单的机器学习算

oracle 10G 没有 PIVOT 函数怎么办,自己写一个不久有了

众所周知,静态SQL的输出结构必须也是静态的.对于经典的行转列问题,如果行数不定导致输出的列数不定,标准的答案就是使用动态SQL, 到11G里面则有XML结果的PIVOT. 但是 oracle 10G 没有 PIVOT 函数怎么办,自己写一个不久有了.上代码 直接点. CREATE OR REPLACEtype PivotImpl_shx as object( ret_type anytype, -- The return type of the table function stmt varc

js函数定义 参数只要写名称就可以了

js函数定义  参数只要写名称就可以了 以下为标准: function add(type)  { } 不要写成下面这个样子 function add(var type)  { } 哎 妹的  老何java混淆