手把手教你写一个用pytorch实现的Lenet5

最近为了实现HR-net在学习pytorch,然后突然发现这个框架简直比tensorflow要方便太多太多啊,我本来其实不太喜欢python,但是这个框架使用的流畅性真的让我非常的喜欢,下面我就开始介绍从0开始编写一个Lenet并用它来训练cifar10。

1.首先需要先找到Lenet的结构图再考虑怎么去实现它,在网上找了一个供参考

2.需要下载好cifar-10的数据集,在pytorch下默认的是下载cifar-10-python版本的,由于官网速度较慢,我直接提供度娘网盘的链接:链接:https://pan.baidu.com/s/18LNEZmGVkzEwf3SgOrO2rw  密码:n1h7

3.下载好数据集后,需要定义网络的结构,根据图我们可以看出,整个lenet只有两个卷积层,两个池化层(其实应该叫降采样层,那个时候还没有池化),三个全连接层。

pytorch中有一个容器,叫做Sequential,你可以在这个容器里添加你需要使用的卷积,池化,全连接操作,但是,这个Sequential它只能包含类方法定义的层,而不能包含像torch.Functional里面的函数方法(可能我说的不专业,见谅),所以如果当你想自己定义某个层的话,例如在输入全连接层之前,需要将形如[batch_size,channel,higth,width]的tensor转化成[batch_size,channel*higth*width]这种形式,那我如果想在Sequential这个容器里加入这一个操作该怎么办呢,这时候就需要我们继承nn.Module这个类来实现,具体的方法如下

import torch
import torch.nn as nn
class Flatten(nn.Module):
    def __init__(self):
        super(Flatten, self).__init__()
    def forward(self,input):
        out=input.view(input.size(0),-1)
        return out

好了,介绍完Sequential我们就开始实现这个网络的结构吧

#文件名是Lenet5.py
import torch
import torch.nn as nn
from pytorch__lesson.pytorch_mnist.main import Flatten
class Lenet(nn.Module):
    def __init__(self):
        super(Lenet, self).__init__()
        self.net=nn.Sequential(
            nn.Conv2d(3,6,5,stride=1,padding=0),
            nn.MaxPool2d(kernel_size=2,stride=2,padding=0),
            nn.Conv2d(6,16,5,stride=1,padding=0),
            nn.MaxPool2d(kernel_size=2,stride=2,padding=0),
            Flatten(),
            nn.Linear(400,120),
            nn.ReLU(inplace=True),
            nn.Linear(120,84),
            nn.ReLU(inplace=True),
            nn.Linear(84,10),
            nn.ReLU(inplace=True)
        )
        # self.criteon=nn.CrossEntropyLoss()
    def forward(self,x):
        logits= self.net(x)
        # pred=nn.Softmax(logits,dim=1),这一行不需要写,因为在CrossEntropyLoss这一步包含了softmax的操作
        return logits
# net=Lenet()
# input=torch.randn(2,3,32,32)
# out=net(input)
# print(out.shape)

其中这里面的Flatten就是上面代码的Flatten类。因为它继承了nn.Module因此可以直接将其放在Sequential里面了,以后定义任何网络,我们都可以使用这个类来进行tensor的展平操作。

4.接下来就可以定义训练部分的代码了

import torch
from torchvision import datasets,transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import torch.functional as F
from pytorch__lesson.pytorch_mnist.Lenet5 import Lenet

batch_size=32
def main():
    # cifar表示的是在当前的目录下新建一个叫cifar的文件夹,这个方法一次只能加载一张
    cifar_train=datasets.CIFAR10(‘cifar‘,train=True,transform=transforms.Compose([
        transforms.Resize((32,32)),
        transforms.ToTensor()
    ]),download=True)
    # 这个方法才能保证一次读取进来的是一个batch_size大小的数据
    cifar_train_loader=DataLoader(cifar_train,batch_size=batch_size,shuffle=True)

    cifar_test=datasets.CIFAR10(‘cifar‘,train=False,transform=transforms.Compose([
        transforms.Resize((32,32)),
        transforms.ToTensor()
    ]))
    cifar_test_loader=DataLoader(cifar_test,batch_size=batch_size,shuffle=False)

    x,label=iter(cifar_train_loader).next()
    print(‘x shapex:‘,x.shape,‘label shape:‘,label.shape)

    # use CrossEntropy as the loss function
    criteon=nn.CrossEntropyLoss()
    # use Lenet() function to build a model
    # net=Lenet().to(device) 将模型放入cuda上进行加速
    net=Lenet()
    optimizer=optim.Adam(net.parameters(),lr=1e-3)
    # device=torch.device(‘cuda‘)
    # net=Lenet().to(device) 将模型放入cuda上进行加速
    print(net)
    for epoch in range(1000):
        for batchidx,(x,label) in enumerate(cifar_train_loader):
            # 生成软对数
            # 将网络转化成train的模式
            net.train()
            logits=net(x)
            # x,label=x.to(device),label.to(device)
            # 使用crossentropyloss的就不需要将logits放入到softmax中了,直接就可以计算出loss
            loss=criteon(logits,label)

            #接下来进行反向的传播,先是将梯度清零,再进行反向传播,再进行梯度更新
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # loss是一个tensor scalor 是一个长度为0的标量
        print(epoch,loss.item())

        net.eval()
        with torch.no_grad():
            # 将整个网络转换成test模式或者validation模式
            # test这一部分不需要构造计算图也不需要统计梯度,因此将这部分放在函数torch.no_grad()
            total_correct=0
            total_num=0
            for x,label in cifar_test_loader:
                #  如果有gpu的话先将x和label放入gup进行加速
                # [batch_size,10]
                logits=net(x)
                # 取出最大下标的索引[b]
                pred=logits.argmax(dim=1)
                # eq函数调用后会返回一个byte,true或者false估计,然后需要将其转换成float类型再通过item()函数来提取它的值
                total_correct+=torch.eq(label,pred).float().sum()
                total_num+=x.size(0)
            acc=total_correct/total_num
            print(epoch,‘the acc of the test is :‘,(acc*100))

if __name__==‘__main__‘:
    main()

因为我的电脑没有英伟达的显卡,不支持cuda加速,因此的话没办法都训练出来截图,如果有N卡的,可以自己试试,注释写的比较详细,我就不再赘述了,不是很难。

ps:我太唠叨了吧??

原文地址:https://www.cnblogs.com/daremosiranaihana/p/12564245.html

时间: 2024-08-27 15:09:10

手把手教你写一个用pytorch实现的Lenet5的相关文章

大神手把手教你写一个页面模板引擎,只需20行Javascript代码!

只用20行Javascript代码就写出一个页面模板引擎的大神是AbsurdJS的作者,下面是他分享的全文,转需. 不知道你有木有听说过一个基于Javascript的Web页面预处理器,叫做AbsurdJS.我是它的作者,目前我还在不断地完善它.最初我只是打算写一个CSS的预处理器,不过后来扩展到了CSS和HTML,可以用来把Javascript代码转成CSS和HTML代码.当然,由于可以生成HTML代码,你也可以把它当成一个模板引擎,用于在标记语言中填充数据. 于是我又想着能不能写一些简单的代

手把手教你写一个RN小程序!

时间过得真快,眨眼已经快3年了! 1.我的第一个App 还记得我14年初写的第一个iOS小程序,当时是给别人写的一个单机的相册,也是我开发的第一个完整的app,虽然功能挺少,但是耐不住心中的激动啊,现在我开始学react native,那么现在对于react native也算是有所了解了,就用网上的接口开发一个小程序,现在带大家来写这个程序!接口是用看知乎的API,简简单单的只有get,可以从这里入门,也算是带大家入门吧,过后我会把源代码放在我的github上,前期项目肯定特别简陋,后面慢慢来优

手把手教你写一个java的orm(二)

创建映射关系 ? 想要实现一个orm的功能,我觉得就是要将class和数据库中的表创建映射关系.把class的名称和表的名称,class属性名称和表的字段名称,属性类型与表的字段类型一一对应起来.可以通过配置文件,注解等等各种方式实现这个映射关系. 需要的依赖 ? 因为编写配置文件总是一件十分繁琐的事情,所以我决定使用注解的方式来实现这个映射.在项目刚开始写的时候我用的是自定义注解的方法.自己规定一套注解,后来觉得这样没有太大的必要,因为已经有jpa里的一套注解.所以直接用就好了.所以添加依赖:

手把手教你写一个java的orm(完)

生成sql:select 上一篇讲了怎样生成一个sql中where的一部分,之后我们要做事情就简单很多了,就只要像最开始一样的生成各种sql语句就好了,之后只要再加上我们需要的条件,一个完整的sql就顺利的做好了. 现在我们开始写生成查询语句的sql.一个查询语句大致上是这样的: SELECT name, id, create_date, age, mark, status FROM user 这里可以看出来,一个基础的查询语句基本上就是一个 SELECT 后面加上需要查询的字段,跟上 FROM

手把手教你写一个通用的helm chart

[TOC] 1. 模板介绍 首先,放上此模板链接: https://github.com/ygqygq2/charts/tree/master/mod-chart 此chart可当作POD单image的通用模板,只需要使用sed替换下chart名,并修改下README.md和NOTES.txt就可以了.下文,我通过复制此chart成example-chart来作示范说明. [[email protected] mod-chart]# tree . ├── Chart.yaml ├── READM

Android开发之手把手教你写ButterKnife框架(二)

欢迎转载,转载请标明出处: http://blog.csdn.net/johnny901114/article/details/52664112 本文出自:[余志强的博客] 上一篇博客Android开发之手把手教你写ButterKnife框架(一)我们讲了ButterKnife是什么.ButterKnife的作用和功能介绍以及ButterKnife的实现原理. 本篇博客主要讲在android studio中如何使用apt. 一.新建个项目, 然后创建一个module名叫processor 新建m

手把手教你写Sublime中的Snippet

手把手教你写Sublime中的Snippet Sublime Text号称最性感的编辑器, 并且越来越多人使用, 美观, 高效 关于如何使用Sublime text可以参考我的另一篇文章, 相信你会喜欢上的..Sublime Text 2使用心得 现在介绍一下Snippet, Snippets are smart templates that will insert text for you and adapt it to their context. Snippet 是插入到文本中的智能模板并

手把手教你写专利申请书/怎样申请专利

手把手教你写专利申请书·怎样申请专利 摘要小前言(一)申请前的准备工作    1.申请前查询    2.其它方面的考虑    3.申请文件准备(二)填写专利申请系列文档    1.实际操作步骤    2.详细操作    3.经验分享.注意事项(三)关于费用(四)其它的话參考资源提示常见问题的问与答 摘要: 怎样写好专利申请?由于非常多专利申请人都是第一次申请,因此,可能有一种神奇和些许恐惧.本文写的是怎样写专利申请书,手把手教你写专利申请并提供申请专利时的注意事项,专利申请费用及费用减缓等相关參

手把手教你写Windows 64位平台调试器

本文网页排版有些差,已上传了doc,可以下载阅读.本文中的所有代码已打包,下载地址在此. -------------------------------------------------------------------------------------------------------------------------------------------------------------- 手写一个调试器有助于我们理解hook.进程注入等底层黑客技术具体实现,在编写过程中需要涉及大