pytorch 手写数字识别项目 增量式训练

dataset.py

‘‘‘
准备数据集
‘‘‘
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor,Compose,Normalize
import torchvision
import config

def mnist_dataset(train):
    func = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=(0.1307,),
            std = (0.3081,)
        )
    ])

    #准备Mnist数据集
    return MNIST(root="../mnist",train=train,download=False,transform=func)

def get_dataloader(train = True):
    mnist = mnist_dataset(train)
    batch_size = config.train_batch_size if train else config.test_batch_size
    return DataLoader(mnist,batch_size=batch_size,shuffle=True)

if __name__ == ‘__main__‘:
    for (images,labels) in get_dataloader():
        print(images.size())
        print(labels)
        break

  model.py

‘‘‘定义模型‘‘‘

import torch.nn as nn
import torch.nn.functional as F

class MnistModel(nn.Module):
    def __init__(self):
        super(MnistModel,self).__init__()
        self.fc1 = nn.Linear(28*28,100)
        self.fc2 = nn.Linear(100,10)

    def forward(self,image):
        image_viwed = image.view(-1,28*28)
        fc1_out = self.fc1(image_viwed)
        fc1_out_relu = F.relu(fc1_out)
        out = self.fc2(fc1_out_relu)

        return F.log_softmax(out,dim=-1)

  config.py

‘‘‘
项目配置
‘‘‘
import torch

train_batch_size = 128
test_batch_size = 128

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  train.py

‘‘‘
进行模型的训练
‘‘‘
from dataset import get_dataloader
from models import MnistModel
from torch import optim
import torch.nn.functional as F
import config
from tqdm import tqdm
import numpy as np
import torch
import os
from eval import eval

#实例化模型、优化器、损失函数
model = MnistModel().to(config.device)
optimizer = optim.Adam(model.parameters(),lr=0.001)

if os.path.exists("./model/mnist_net.pt"):
    model.load_state_dict(torch.load("./model/mnist_net.pt"))
    optimizer.load_state_dict(torch.load("model/mnist_optimizer.pt"))

#迭代训练

def train(epoch):
    train_dataloader = get_dataloader(train=True)
    bar = tqdm(enumerate(train_dataloader),total=len(train_dataloader))
    total_loss = []
    for idx,(input,target) in bar:
        input = input.to(config.device)
        target = target.to(config.device)
        #梯度置为0
        optimizer.zero_grad()
        #计算得到预测值
        output = model(input)
        #得到损失
        loss = F.nll_loss(output,target)
        total_loss.append(loss.item())
        #反向传播,计算损失
        loss.backward()
        #参数更新
        optimizer.step()

        if idx%10 ==0:
            bar.set_description("epoch:{} idx:{},loss:{}".format(epoch,idx,np.mean(total_loss)))
            torch.save(model.state_dict(),"model/mnist_net.pt")
            torch.save(optimizer.state_dict(),"model/mnist_optimizer.pt")

if __name__ == ‘__main__‘:
    for i in range(10):
        train(i)
        eval()

  eval.py

‘‘‘
进行模型的训练
‘‘‘
from dataset import get_dataloader
from models import MnistModel
from torch import optim
import torch.nn.functional as F
import config
import numpy as np
import torch
import os

#迭代训练

def eval():
    # 实例化模型、优化器、损失函数
    model = MnistModel().to(config.device)
    optimizer = optim.Adam(model.parameters(), lr=0.01)

    if os.path.exists("./model/mnist_net.pt"):
        model.load_state_dict(torch.load("./model/mnist_net.pt"))
        optimizer.load_state_dict(torch.load("model/mnist_optimizer.pt"))
    test_dataloader = get_dataloader(train=False)
    total_loss = []
    total_acc = []
    with torch.no_grad():
        for input,target in test_dataloader:
            input = input.to(config.device)
            target = target.to(config.device)
            #计算得到预测值
            output = model(input)
            #计算损失
            loss = F.nll_loss(output,target)
            #反向传播,计算损失
            total_loss.append(loss.item())
            #计算准确率
            pred = output.max(dim=-1)[-1]
            total_acc.append(pred.eq(target).float().mean().item())
    print("test loss:{},test acc:{}".format(np.mean(total_loss),np.mean(total_acc)))

if __name__ == ‘__main__‘:
        eval()

  

D:\anaconda\python.exe C:/Users/liuxinyu/Desktop/pytorch_test/day3/手写数字识别/train.py
epoch:0 idx:460,loss:0.32289110562095413: 100%|██████████| 469/469 [00:24<00:00, 19.05it/s]
test loss:0.17968503131142147,test acc:0.9453125
epoch:1 idx:460,loss:0.15012750004513145: 100%|█████████▉| 468/469 [00:20<00:00, 22.10it/s]epoch:1 idx:460,loss:0.15012750004513145: 100%|██████████| 469/469 [00:20<00:00, 22.52it/s]
test loss:0.12370304338916947,test acc:0.9624208860759493
epoch:2 idx:460,loss:0.10398845713577534:  99%|█████████▉| 464/469 [00:21<00:00, 22.78it/s]epoch:2 idx:460,loss:0.10398845713577534: 100%|█████████▉| 467/469 [00:21<00:00, 22.71it/s]epoch:2 idx:460,loss:0.10398845713577534: 100%|██████████| 469/469 [00:21<00:00, 21.82it/s]
test loss:0.10385569722592077,test acc:0.9697389240506329
epoch:3 idx:460,loss:0.07973297938720653: 100%|█████████▉| 467/469 [00:22<00:00, 23.12it/s]epoch:3 idx:460,loss:0.07973297938720653: 100%|██████████| 469/469 [00:22<00:00, 20.84it/s]
test loss:0.08691684670652015,test acc:0.9754746835443038
epoch:4 idx:460,loss:0.0650228117158285: 100%|█████████▉| 468/469 [00:21<00:00, 24.06it/s]epoch:4 idx:460,loss:0.0650228117158285: 100%|██████████| 469/469 [00:21<00:00, 21.79it/s]
test loss:0.0803159438309413,test acc:0.9760680379746836
epoch:5 idx:460,loss:0.05270117848966101: 100%|██████████| 469/469 [00:21<00:00, 21.92it/s]
test loss:0.08102699166423158,test acc:0.9759691455696202
epoch:6 idx:460,loss:0.04386751471317642: 100%|██████████| 469/469 [00:19<00:00, 24.58it/s]
test loss:0.07991968260347089,test acc:0.9769580696202531
epoch:7 idx:460,loss:0.03656852366544161: 100%|██████████| 469/469 [00:15<00:00, 31.20it/s]
test loss:0.07767781678917288,test acc:0.9774525316455697
epoch:8 idx:460,loss:0.03112584312896925: 100%|██████████| 469/469 [00:14<00:00, 32.41it/s]
test loss:0.07755146227494071,test acc:0.9773536392405063
epoch:9 idx:460,loss:0.025217091969725495: 100%|██████████| 469/469 [00:14<00:00, 31.53it/s]
test loss:0.07112929566845863,test acc:0.9802215189873418

  

原文地址:https://www.cnblogs.com/LiuXinyu12378/p/12314982.html

时间: 2024-07-31 16:03:25

pytorch 手写数字识别项目 增量式训练的相关文章

利用手写数字识别项目详细描述BP深度神经网络的权重学习

本篇文章是针对学习<深度学习入门>(由日本学者斋藤康毅所著陆羽杰所译)中关于神经网络的学习一章来总结归纳一些收获. 本书提出神经网络的学习分四步:1.mini-batch 2.计算梯度 3.更新参数 4.重复前面步骤 1.从识别手写数字项目学习神经网络 所谓“从数据中学习”是指 可以由数据#自动决定权重#.当解决较为简单的问题,使用简单的神经网络时,网络里的权重可以人为的手动设置,去提取输入信息中特定的特征.但是在实际的神经网络中,参数往往是成千上万,甚至可能上亿的权重,这个时候人为手动设置是

用pytorch做手写数字识别,识别l率达97.8%

pytorch做手写数字识别 效果如下: 工程目录如下 第一步  数据获取 下载MNIST库,这个库在网上,执行下面代码自动下载到当前data文件夹下 from torchvision.datasets import MNIST import torchvision mnist = MNIST(root='./data',train=True,download=True) print(mnist) print(mnist[0]) print(len(mnist)) img = mnist[0][

Pytorch入门实战一:LeNet神经网络实现 MNIST手写数字识别

记得第一次接触手写数字识别数据集还在学习TensorFlow,各种sess.run(),头都绕晕了.自从接触pytorch以来,一直想写点什么.曾经在2017年5月,Andrej Karpathy发表的一片Twitter,调侃道:l've been using PyTorch a few months now, l've never felt better, l've more energy.My skin is clearer. My eye sight has improved.确实,使用p

使用AI算法进行手写数字识别

人工智能 ??人工智能(Artificial Intelligence,简称AI)一词最初是在1956年Dartmouth学会上提出的,从那以后,研究者们发展了众多理论和原理,人工智能的概念也随之扩展.由于人工智能的研究是高度技术性和专业的,各分支领域都是深入且各不相通的,因而涉及范围极广 . 人工智能的核心问题包括建构能够跟人类似甚至超越人类的推理.知识.学习.交流.感知.使用工具和操控机械的能力等,当前人工智能已经有了初步成果,甚至在一些影像识别.语言分析.棋类游戏等等单方面的能力达到了超越

手把手教你搭建caffe及手写数字识别(全程命令提示、纯小白教程)

手把手教你搭建caffe及手写数字识别 作者:七月在线课程助教团队,骁哲.小蔡.李伟.July时间:二零一六年十一月九日交流:深度学习实战交流Q群 472899334,有问题可以加此群共同交流.另探究实验背后原理,请参看此课程:11月深度学习班. 一.前言 在前面的教程中,我们搭建了tensorflow.torch,教程发布后,大家的问题少了非常多.但另一大框架caffe的问题则也不少,加之caffe也是11月深度学习班要讲的三大框架之一,因此,我们再把caffe的搭建完整走一遍,手把手且全程命

机器学习(二)-kNN手写数字识别

一.kNN算法 1.kNN算法是机器学习的入门算法,其中不涉及训练,主要思想是计算待测点和参照点的距离,选取距离较近的参照点的类别作为待测点的的类别. 2,距离可以是欧式距离,夹角余弦距离等等. 3,k值不能选择太大或太小,k值含义,是最后选取距离最近的前k个参照点的类标,统计次数最多的记为待测点类标. 4,欧式距离公式: 二.关于kNN实现手写数字识别 1,手写数字训练集测试集的数据格式,本篇文章说明的是<机器学习实战>书提供的文件,将所有数字已经转化成32*32灰度矩阵. 三.代码结构构成

Tensorflow实战 手写数字识别(Tensorboard可视化)

一.前言 为了更好的理解Neural Network,本文使用Tensorflow实现一个最简单的神经网络,然后使用MNIST数据集进行测试.同时使用Tensorboard对训练过程进行可视化,算是打响学习Tensorflow的第一枪啦. 看本文之前,希望你已经具备机器学习和深度学习基础. 机器学习基础可以看我的系列博文: https://cuijiahua.com/blog/ml/ 深度学习基础可以看吴恩达老师的公开课: http://mooc.study.163.com/smartSpec/

Kaggle竞赛丨入门手写数字识别之KNN、CNN、降维

引言 这段时间来,看了西瓜书.蓝皮书,各种机器学习算法都有所了解,但在实践方面却缺乏相应的锻炼.于是我决定通过Kaggle这个平台来提升一下自己的应用能力,培养自己的数据分析能力. 我个人的计划是先从简单的数据集入手如手写数字识别.泰坦尼克号.房价预测,这些目前已经有丰富且成熟的方案可以参考,之后关注未来就业的方向如计算广告.点击率预测,有合适的时机,再与小伙伴一同参加线上比赛. 数据集 介绍 MNIST ("Modified National Institute of Standards an

使用Caffe进行手写数字识别执行流程解析

之前在 http://blog.csdn.net/fengbingchun/article/details/50987185 中仿照Caffe中的examples实现对手写数字进行识别,这里详细介绍下其执行流程并精简了实现代码,使用Caffe对MNIST数据集进行train的文章可以参考  http://blog.csdn.net/fengbingchun/article/details/68065338 : 1.   先注册所有层,执行layer_factory.hpp中类LayerRegis