基于PyTorch实现MNIST手写字识别

本篇不涉及模型原理,只是分享下代码。想要了解模型原理的可以去看网上很多大牛的博客。

目前代码实现了CNN和LSTM两个网络,整个代码分为四部分:

  • Config:项目中涉及的参数;
  • CNN:卷积神经网络结构;
  • LSTM:长短期记忆网络结构;
  • TrainProcess

    模型训练及评估,参数model控制训练何种模型(CNN or LSTM)。

完整代码

Talk is cheap, show me the code.

# -*- coding: utf-8 -*-

# @author: Awesome_Tang
# @date: 2019-04-05
# @version: python3.7

import torch
from torchvision import datasets, transforms
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from datetime import datetime

class Config:
    batch_size = 64
    epoch = 10
    alpha = 1e-3

    print_per_step = 100  # 控制输出

class CNN(nn.Module):

    def __init__(self):
        super(CNN, self).__init__()
        """
        Conv2d参数:
        第一位:input channels  输入通道数
        第二位:output channels 输出通道数
        第三位:kernel size 卷积核尺寸
        第四位:stride 步长,默认为1
        第五位:padding size 默认为0,不补
        """
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )

        self.fc1 = nn.Sequential(
            nn.Linear(64 * 5 * 5, 128),
            nn.BatchNorm1d(128),
            nn.ReLU()
        )

        self.fc2 = nn.Sequential(
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),  # 加快收敛速度的方法(注:批标准化一般放在全连接层后面,激活函数层的前面)
            nn.ReLU()
        )

        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

class LSTM(nn.Module):
    def __init__(self):
        super(LSTM, self).__init__()

        self.lstm = nn.LSTM(
            input_size=28,
            hidden_size=64,
            num_layers=1,
            batch_first=True,
        )

        self.output = nn.Linear(64, 10)

    def forward(self, x):
        r_out, (_, _) = self.lstm(x, None)

        out = self.output(r_out[:, -1, :])
        return out

class TrainProcess:

    def __init__(self, model="CNN"):
        self.train, self.test = self.load_data()
        self.model = model
        if self.model == "CNN":
            self.net = CNN()
        elif self.model == "LSTM":
            self.net = LSTM()
        else:
            raise ValueError('"CNN" or "LSTM" is expected, but received "%s".' % model)
        self.criterion = nn.CrossEntropyLoss()  # 定义损失函数
        self.optimizer = optim.Adam(self.net.parameters(), lr=Config.alpha)

    @staticmethod
    def load_data():
        print("Loading Data......")
        """加载MNIST数据集,本地数据不存在会自动下载"""
        train_data = datasets.MNIST(root='./data/',
                                    train=True,
                                    transform=transforms.ToTensor(),
                                    download=True)

        test_data = datasets.MNIST(root='./data/',
                                   train=False,
                                   transform=transforms.ToTensor())

        # 返回一个数据迭代器
        # shuffle:是否打乱顺序
        train_loader = torch.utils.data.DataLoader(dataset=train_data,
                                                   batch_size=Config.batch_size,
                                                   shuffle=True)

        test_loader = torch.utils.data.DataLoader(dataset=test_data,
                                                  batch_size=Config.batch_size,
                                                  shuffle=False)
        return train_loader, test_loader

    def train_step(self):
        steps = 0
        start_time = datetime.now()

        print("Training & Evaluating based on '%s'......" % self.model)
        for epoch in range(Config.epoch):
            print("Epoch {:3}.".format(epoch + 1))

            for data, label in self.train:
                data, label = Variable(data.cpu()), Variable(label.cpu())
                # LSTM输入为3维,CNN输入为4维
                if self.model == "LSTM":
                    data = data.view(-1, 28, 28)
                self.optimizer.zero_grad()  # 将梯度归零
                outputs = self.net(data)  # 将数据传入网络进行前向运算
                loss = self.criterion(outputs, label)  # 得到损失函数
                loss.backward()  # 反向传播
                self.optimizer.step()  # 通过梯度做一步参数更新

                # 每100次打印一次结果
                if steps % Config.print_per_step == 0:
                    _, predicted = torch.max(outputs, 1)
                    correct = int(sum(predicted == label))  # 计算预测正确个数
                    accuracy = correct / Config.batch_size  # 计算准确率
                    end_time = datetime.now()
                    time_diff = (end_time - start_time).seconds
                    time_usage = '{:3}m{:3}s'.format(int(time_diff / 60), time_diff % 60)
                    msg = "Step {:5}, Loss:{:6.2f}, Accuracy:{:8.2%}, Time usage:{:9}."
                    print(msg.format(steps, loss, accuracy, time_usage))

                steps += 1

        test_loss = 0.
        test_correct = 0
        for data, label in self.test:
            data, label = Variable(data.cpu()), Variable(label.cpu())
            if self.model == "LSTM":
                data = data.view(-1, 28, 28)
            outputs = self.net(data)
            loss = self.criterion(outputs, label)
            test_loss += loss * Config.batch_size
            _, predicted = torch.max(outputs, 1)
            correct = int(sum(predicted == label))
            test_correct += correct

        accuracy = test_correct / len(self.test.dataset)
        loss = test_loss / len(self.test.dataset)
        print("Test Loss: {:5.2f}, Accuracy: {:6.2%}".format(loss, accuracy))

        end_time = datetime.now()
        time_diff = (end_time - start_time).seconds
        print("Time Usage: {:5.2f} mins.".format(time_diff / 60.))

if __name__ == "__main__":
    p = TrainProcess(model='CNN')
    p.train_step()


Peace~~

原文地址:https://www.cnblogs.com/awesometang/p/12005651.html

时间: 2024-10-28 22:48:11

基于PyTorch实现MNIST手写字识别的相关文章

基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型

一.卷积神经网络模型知识要点卷积卷积 1.卷积 2.池化 3.全连接 4.梯度下降法 5.softmax 本次就是用最简单的方法给大家讲解这些概念,因为具体的各种论文网上都有,连推导都有,所以本文主要就是给大家做个铺垫,如有错误请指正,相互学习共同进步. 二.卷积神经网络讲解 2.1卷积神经网络作用 大家应该知道大名鼎鼎的傅里叶变换,即一个波形,可以有不同的正弦函数和余弦函数进行叠加完成,卷积神经网络也是一样,可以认为一张图片是由各种不同特征的图片叠加而成的,所以它的作用是用来提取特定的特征,举

基于tensorflow的MNIST手写识别

这个例子,是学习tensorflow的人员通常会用到的,也是基本的学习曲线中的一环.我也是! 这个例子很简单,这里,就是简单的说下,不同的tensorflow版本,相关的接口函数,可能会有不一样哟.在TensorFlow的中文介绍文档中的内容,有些可能与你使用的tensorflow的版本不一致了,我这里用到的tensorflow的版本就有这个问题. 另外,还给大家说下,例子中的MNIST所用到的资源图片,在原始的官网上,估计很多人都下载不到了.我也提供一下下载地址. 我的tensorflow的版

使用逻辑回归进行mnist手写字识别

1.引言 逻辑回归(LR)在分类问题中的应用十分广泛,它是一个基于概率的线性分类器,通过建立一个简单的输入层和输出层,即可实现对输入数据的有效分类.而该网络结构的主要参数只有两个,分别是权重和偏置,本文定义损耗函数为负对数,然后通过随机梯度下降算法(SGD)来对参数进行更新,并定义误差函数来衡量训练的阶段. 2.具体训练过程 在第3部分将会给出本文的完整python代码,其中用到的文件mnist.pkl.gz可以去网上下载,放到与python文件同一目录下面即可. 首先,定义一个基于object

Tensorflow编程基础之Mnist手写识别实验+关于cross_entropy的理解

好久没有静下心来写点东西了,最近好像又回到了高中时候的状态,休息不好,无法全心学习,恶性循环,现在终于调整的好一点了,听着纯音乐突然非常伤感,那些曾经快乐的大学时光啊,突然又慢慢的一下子出现在了眼前,不知道我大学的那些小伙伴们现在都怎么样了,考研的刚刚希望他考上,实习的菜头希望他早日脱离苦海,小瑞哥希望他早日出成果,范爷熊健研究生一定要过的开心啊!天哥也哥早日结婚领证!那些回不去的曾经的快乐的时光,你们都还好吗! 最近开始接触Tensorflow,可能是论文里用的是这个框架吧,其实我还是觉得py

win10下通过Anaconda安装TensorFlow-GPU1.3版本,并配置pycharm运行Mnist手写识别程序

折腾了一天半终于装好了win10下的TensorFlow-GPU版,在这里做个记录. 准备安装包: visual studio 2015: Anaconda3-4.2.0-Windows-x86_64: pycharm-community: CUDA:cuda_8.0.61_win10:下载时选择 exe(local) CUDA补丁:cuda_8.0.61.2_windows: cuDNN:cudnn-8.0-windows10-x64-v6.0;如果你安装的TensorFlow版本和我一样1.

tensorflow笔记(四)之MNIST手写识别系列一

tensorflow笔记(四)之MNIST手写识别系列一 版权声明:本文为博主原创文章,转载请指明转载地址 http://www.cnblogs.com/fydeblog/p/7436310.html 前言 这篇博客将利用神经网络去训练MNIST数据集,通过学习到的模型去分类手写数字. 我会将本篇博客的jupyter notebook放在最后,方便你下载在线调试!推荐结合官方的tensorflow教程来看这个notebook! 1. MNIST数据集的导入 这里介绍一下MNIST,MNIST是在

我的第一个svm程序:手写字识别

之前学过svm相关知识,基本原理不算复杂,今天做了一个手写字识别程序,总算验证了svm的效果. 因为只是验证效果,实现上原则是简单,使用python + libsvm + PIL(python image library).这部分工作花了一些时间: PIL: http://www.pythonware.com/products/pil/ 下载源码包,解压之后运行:python setup.py install即可. max下python libsvm安装使用:http://blog.csdn.n

Haskell手撸Softmax回归实现MNIST手写识别

Haskell手撸Softmax回归实现MNIST手写识别 前言 初学Haskell,看的书是Learn You a Haskell for Great Good, 才刚看到Making Our Own Types and Typeclasses这一章. 为了加深对Haskell的理解,便动手写了个Softmax回归.纯粹造轮子,只用了base. 显示图片虽然用了OpenGL,但是本文不会提到关于OpenGL的内容.虽说是造轮子, 但是这轮子造得还是使我受益匪浅.Softmax回归方面的内容参考

Python 基于KNN算法的手写识别系统

本文主要利用k-近邻分类器实现手写识别系统,训练数据集大约2000个样本,每个数字大约有200个样本,每个样本保存在一个txt文件中,手写体图像本身是32X32的二值图像,如下图所示: 手写数字识别系统的测试代码: from numpy import * import operator from os import listdir #inX    要检测的数据 #dataSet   数据集 #labels    结果集 #k      要对比的长度 def classify0(inX, data