PyTorch LSTM的一个简单例子:实现MNIST图片分类

上一篇博客中,我们实现了用LSTM对单词进行词性判断,本篇博客我们将实现用LSTM对MNIST图片分类。MNIST图片的大小为28*28,我们将其看成长度为28的序列,序列中的每个数据的维度是28,这样我们就可以把它变成一个序列数据了。代码如下。

‘‘‘
本程序实现用LSTM对MNIST进行图片分类
‘‘‘

import torch
import numpy as np
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt

# Hyper parameter
EPOCH = 1
LR = 0.001    # learning rate
BATCH_SIZE = 50

# Mnist digit dataset
train_data = torchvision.datasets.MNIST(
    root=‘/Users/wangpeng/Desktop/all/CS/Courses/Deep Learning/mofan_PyTorch/mnist/‘,    # mnist has been downloaded before, use it directly
    train=True,    # this is training data
    transform=torchvision.transforms.ToTensor(),    # Converts a PIL.Image or numpy.ndarray to
                                                    # torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]
    download=False,
)

# print(train_data.data.size())       # (60000, 28, 28)
# print(train_data.targets.size())    # (60000)
# plot one image
# plt.imshow(train_data.data[0].numpy(), cmap=‘gray‘)
# plt.title(‘{:d}‘.format(train_data.targets[0]))
# plt.show()

# Data Loader for easy mini-batch return in training, the image batch shape will be (50, 1, 28, 28)
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

test_data = torchvision.datasets.MNIST(
    root=‘/Users/wangpeng/Desktop/all/CS/Courses/Deep Learning/mofan_PyTorch/mnist/‘,
    train=False,  # this is training data
)
# print(test_data.data.size())       # (10000, 28, 28)
# print(test_data.targets.size())    # (10000)
# pick 2000 samples to speed up testing
test_x = test_data.data.type(torch.FloatTensor)[:2000]/255    # shape (2000, 28, 28), value in range(0,1)
test_y = test_data.targets[:2000]

class LSTMnet(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_layer, n_class):
        super(LSTMnet, self).__init__()
        self.n_layer = n_layer
        self.hidden_dim = hidden_dim
        self.lstm = nn.LSTM(in_dim, hidden_dim, n_layer, batch_first=True)
        self.linear = nn.Linear(hidden_dim, n_class)

    def forward(self, x):                  # x‘s shape (batch_size, 序列长度, 序列中每个数据的长度)
        out, _ = self.lstm(x)              # out‘s shape (batch_size, 序列长度, hidden_dim)
        out = out[:, -1, :]                # 中间的序列长度取-1,表示取序列中的最后一个数据,这个数据长度为hidden_dim,
                                           # 得到的out的shape为(batch_size, hidden_dim)
        out = self.linear(out)             # 经过线性层后,out的shape为(batch_size, n_class)
        return out

model = LSTMnet(28, 64, 2, 10)             # 图片大小28*28,lstm的每个隐藏层64个节点,2层隐藏层
if torch.cuda.is_available():
    model = model.cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()

# training and testing
for epoch in range(EPOCH):
    for iteration, (train_x, train_y) in enumerate(train_loader):    # train_x‘s shape (BATCH_SIZE,1,28,28)
        train_x = train_x.squeeze()        # after squeeze, train_x‘s shape (BATCH_SIZE,28,28),
                                           # 第一个28是序列长度,第二个28是序列中每个数据的长度。
        output = model(train_x)
        loss = criterion(output, train_y)  # cross entropy loss
        optimizer.zero_grad()              # clear gradients for this training step
        loss.backward()                    # backpropagation, compute gradients
        optimizer.step()                   # apply gradients

        if iteration % 100 == 0:
            test_output = model(test_x)
            predict_y = torch.max(test_output, 1)[1].numpy()
            accuracy = float((predict_y == test_y.numpy()).astype(int).sum()) / float(test_y.size(0))
            print(‘epoch:{:<2d} | iteration:{:<4d} | loss:{:<6.4f} | accuracy:{:<4.2f}‘.format(epoch, iteration, loss, accuracy))

# print 10 predictions from test data
test_out = model(test_x[:10])
pred_y = torch.max(test_out, dim=1)[1].data.numpy()
print(‘The predict number is:‘)
print(pred_y)
print(‘The real number is:‘)
print(test_y[:10].numpy())

结果如下:

参考资料:

[1] 10分钟快速入门PyTorch (6)

[2] 莫烦PyTorch教程系列:CNN卷积神经网络

原文地址:https://www.cnblogs.com/picassooo/p/12556293.html

时间: 2024-10-05 23:25:27

PyTorch LSTM的一个简单例子:实现MNIST图片分类的相关文章

从一个简单例子来理解js引用类型指针的工作方式

? 1 2 3 4 5 6 7 <script> var a = {n:1};  var b = a;   a.x = a = {n:2};  console.log(a.x);// --> undefined  console.log(b.x);// --> [object Object]  </script> 上面的例子看似简单,但结果并不好了解,很容易把人们给想绕了--"a.x不是指向对象a了么?为啥log(a.x)是undefined?".&

C语言多线程的一个简单例子

多线程的一个简单例子: #include <stdio.h> #include <stdlib.h> #include <string.h> #include <unistd.h> #include <pthread.h> void * print_a(void *); void * print_b(void *); int main(){ pthread_t t0; pthread_t t1; // 创建线程A if(pthread_creat

生产者与消费者的一个简单例子

生产者 #include<fstream> #include<iostream> #include<Windows.h> using namespace std; int main(void) { ofstream out; const char ch = '*'; long long k = 0; DWORD64 time = GetTickCount64(); while (true) { if (GetTickCount64() - time > 5000)

一个简单的全屏图片上下打开显示网页效果

打包下载地址:http://download.csdn.net/detail/sweetsuzyhyf/7602105 上源码看效果: <!DOCTYPE html> <html> <head> <title></title> <style> body { margin: 0; padding: 0; } .wrap { overflow: hidden; position: fixed; z-index: 99999; width:

PyQt安装与一个简单例子

PyQt在Windows+Visual Studio下安装所需文件如下: python-2.7.3.msi (www.python.org/download) sip-4.14.2.zip (www.riverbankcomputing.co.uk/software/sip/download) PyQt-Py2.7-x86-gpl-4.9.6-1.exe(www.riverbankcomputing.co.uk/software/pyqt/download) 安装方法: 首先安装python2.

词法分析程序 LEX和VC6整合使用的一个简单例子

词法分析的理论知识不少,包括了正规式.正规文法.它们之间的转换以及确定的有穷自动机和不确定的有穷自动机等等... 要自己写一个词法分析器也不会很难,只要给出了最简的有穷自动机,就能很方便实现了,用if.switch-case来写一通所谓的状态转换就可以,我近期会写一个简单的词法分析程序来作为例子... 现在已经有人发明了一个叫LEX的工具让你去应用,那我们就省了不少力气,毕竟没到万不得已的时候,我们都没必要重新发明轮子,从另一个角度来说,使用工具是我们人类知识继承的一种方法,也是我们比其他动物优

一个简单例子了解使用互斥量线程同步

在刚开始学习学习线程同步时总是认为两个线程或是多个线程共同运行,但是那样是做的. 同步就是协同步调,按预定的先后次序进行运行.如:你说完,我再说. "同"字从字面上容易理解为一起动作. 其实不是,"同"字应是指协同.协助.互相配合. 如进程.线程同步,可理解为进程或线程A和B一块配合,A执行到一定程度时要依靠B的某个结果,于是停下来,示意B运行:B依言执行,再将结果给A:A再继续操作. 所谓同步,就是在发出一个功能调用时,在没有得到结果之前,该调用就不返回,同时其它

netsh interface portproxy的一个简单例子

netsh interface portproxy的微软帮助文档地址: https://technet.microsoft.com/zh-cn/library/cc776297(WS.10).aspx#BKMK_1 下面是一个简单的例子: //显示所有 portproxy 参数,包括 v4tov4.v4tov6.v6tov4 和 v6tov6 的端口/地址对. C:\>netsh interface portproxy show all //因为没有配置过它,所以没有东西可以显示. //添加配置

Spring MVC:使用SimpleUrlHandlerMapping的一个简单例子

实现一个控制器ShirdrnController,如下所示: package org.shirdrn.spring.mvc; import java.util.Date; import javax.servlet.http.HttpServletRequest;import javax.servlet.http.HttpServletResponse; import org.apache.commons.logging.Log;import org.apache.commons.logging.