从头学pytorch(十六):VGG NET

VGG

AlexNet在Lenet的基础上增加了几个卷积层,改变了卷积核大小,每一层输出通道数目等,并且取得了很好的效果.但是并没有提出一个简单有效的思路.
VGG做到了这一点,提出了可以通过重复使?简单的基础块来构建深度学习模型的思路.

论文地址:https://arxiv.org/abs/1409.1556

vgg的结构如下所示:

上图给出了不同层数的vgg的结构.也就是常说的vgg16,vgg19等等.

VGG BLOCK

vgg的设计思路是,通过不断堆叠3x3的卷积核,不断加深模型深度.vgg net证明了加深模型深度对提高模型的学习能力是一个很有效的手段.


看上图就能发现,连续的2个3x3卷积,感受野和一个5x5卷积是一样的,但是前者有两次非线性变换,后者只有一次!,这就是连续堆叠小卷积核能提高
模型特征学习的关键.此外,2个3x3的参数数量也比一个5x5少.(2x3x3 < 5x5)

vgg的基础组成模块,每一个卷积层都由n个3x3卷积后面接2x2的最大池化.池化层的步幅为2.从而卷积层卷积后,宽高不变,池化后,宽高减半.
我们可以有以下代码:

def make_layers(in_channels,cfg):
    layers = []
    previous_channel = in_channels #上一层的输出的channel数量
    for v in cfg:
        if v == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2,stride=2))
        else:
            layers.append(nn.Conv2d(previous_channel,v,kernel_size=3,padding=1))
            layers.append(nn.ReLU())

            previous_channel = v

    conv = nn.Sequential(*layers)
    return conv

cfgs = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

cfgs定义了不同的vgg模型的结构,比如‘A‘代表vgg11. 数字代表卷积后的channel数. ‘M‘代表Maxpool

我们可以给出模型定义

class VGG(nn.Module):
    def __init__(self,input_channels,cfg,num_classes=10, init_weights=True):
        super(VGG, self).__init__()
        self.conv = make_layers(input_channels,cfg) # torch.Size([1, 512, 7, 7])
        self.fc = nn.Sequential(
            nn.Linear(512*7*7,4096),
            nn.ReLU(),
            nn.Linear(4096,4096),
            nn.ReLU(),
            nn.Linear(4096,num_classes)
        )

    def forward(self, img):
        feature = self.conv(img)
        output = self.fc(feature.view(img.shape[0], -1))
        return output

卷积层的输出可由以下测试代码得出

# conv = make_layers(1,cfgs['A'])
# X = torch.randn((1,1,224,224))
# out = conv(X)
# #print(out.shape)

加载数据

batch_size,num_workers=4,4
train_iter,test_iter = learntorch_utils.load_data(batch_size,num_workers,resize=224)

这里batch_size调到8我的显存就不够了...

定义模型

net = VGG(1,cfgs['A']).cuda()

定义损失函数

loss = nn.CrossEntropyLoss()

定义优化器 

opt = torch.optim.Adam(net.parameters(),lr=0.001)

定义评估函数

def test():
    acc_sum = 0
    batch = 0
    for X,y in test_iter:
        X,y = X.cuda(),y.cuda()
        y_hat = net(X)
        acc_sum += (y_hat.argmax(dim=1) == y).float().sum().item()
        batch += 1
    #print('acc_sum %d,batch %d' % (acc_sum,batch))

    return 1.0*acc_sum/(batch*batch_size)

训练

num_epochs = 3
def train():
    for epoch in range(num_epochs):
        train_l_sum,batch,acc_sum = 0,0,0
        start = time.time()
        for X,y in train_iter:
            # start_batch_begin = time.time()
            X,y = X.cuda(),y.cuda()
            y_hat = net(X)
            acc_sum += (y_hat.argmax(dim=1) == y).float().sum().item()

            l = loss(y_hat,y)
            opt.zero_grad()
            l.backward()

            opt.step()
            train_l_sum += l.item()

            batch += 1

            mean_loss = train_l_sum/(batch*batch_size) #计算平均到每张图片的loss
            start_batch_end = time.time()
            time_batch = start_batch_end - start

            print('epoch %d,batch %d,train_loss %.3f,time %.3f' %
                (epoch,batch,mean_loss,time_batch))

        print('***************************************')
        mean_loss = train_l_sum/(batch*batch_size) #计算平均到每张图片的loss
        train_acc = acc_sum/(batch*batch_size)     #计算训练准确率
        test_acc = test()                           #计算测试准确率
        end = time.time()
        time_per_epoch =  end - start
        print('epoch %d,train_loss %f,train_acc %f,test_acc %f,time %f' %
                (epoch + 1,mean_loss,train_acc,test_acc,time_per_epoch))

train()

4G的GTX 1050显卡,训练一个epoch大概一个多小时.
完整代码:https://github.com/sdu2011/learn_pytorch

原文地址:https://www.cnblogs.com/sdu20112013/p/12176304.html

时间: 2024-10-08 15:29:07

从头学pytorch(十六):VGG NET的相关文章

从头学pytorch(十):模型参数访问/初始化/共享

模型参数的访问初始化和共享 参数访问 参数访问:通过下述两个方法.这两个方法是在nn.Module类中实现的.继承自该类的子类也有相同方法. .parameters() .named_parameters() import torch from torch import nn from torch.nn import init net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1)) # pytorch已进行默认初始化 pr

从头学pytorch(十五):AlexNet

AlexNet AlexNet是2012年提出的一个模型,并且赢得了ImageNet图像识别挑战赛的冠军.首次证明了由计算机自动学习到的特征可以超越手工设计的特征,对计算机视觉的研究有着极其重要的意义. AlexNet的设计思路和LeNet是非常类似的.不同点主要有以下几点: 激活函数由sigmoid改为Relu AlexNet使用了dropout,LeNet没有使用 AlexNet引入了大量的图像增广,如翻转.裁剪和颜色变化,从而进一步扩大数据集来缓解过拟合 激活函数 relu \[\text

从头学pytorch(十四):lenet

卷积神经网络 在之前的文章里,对28 X 28的图像,我们是通过把它展开为长度为784的一维向量,然后送进全连接层,训练出一个分类模型.这样做主要有两个问题 图像在同一列邻近的像素在这个向量中可能相距较远.它们构成的模式可能难以被模型识别. 对于大尺寸的输入图像,使用全连接层容易造成模型过大.假设输入是高和宽均为1000像素的彩色照片(含3个通道).即使全连接层输出个数仍是256,该层权重参数的形状是\(3,000,000\times 256\),按照参数为float,占用4字节计算,它占用了大

从头学pytorch(五) 多层感知机及其实现

多层感知机 上图所示的多层感知机中,输入和输出个数分别为4和3,中间的隐藏层中包含了5个隐藏单元(hidden unit).由于输入层不涉及计算,图3.3中的多层感知机的层数为2.由图3.3可见,隐藏层中的神经元和输入层中各个输入完全连接,输出层中的神经元和隐藏层中的各个神经元也完全连接.因此,多层感知机中的隐藏层和输出层都是全连接层. 具体来说,给定一个小批量样本\(\boldsymbol{X} \in \mathbb{R}^{n \times d}\),其批量大小为\(n\),输入个数为\(

Java从零开始学三十六(JAVA IO- 字符流)

一.字符流 BufferedReader:BufferedReader是从缓冲区之中读取内容,所有的输入的字节数据都将放在缓冲区之中 BufferedWriter:把一批数据写入到缓冲区,当缓冲区区的满时,再把缓冲区的内容写到字符输出流中 二.对文本文件的读写 2.1.字符输入流 2.2.字符输出流 2.3.综合使用 package com.pb.io.buffered; import java.io.BufferedReader; import java.io.BufferedWriter;

Java从零开始学四十六(Junit)

一.软件测试 软件开发: 项目调研--需求分析--软件设计--程序编码--软件测试--运行维护 软件测试:利用测试工具按照测试方案和流程对产品进行功能和性能测试,使用人工或者自动手段来运行或测试某个系统的过程.目的在于检验是否满足规定的需求,确认预期结果与实际结果之间的差别. 墨盒测试-白盒测试-回归测试-单元测试 二.JUnit-单元测试工具 三.测试Junit测试类 创建被测试类 package com.pb.junit; /** *1.创建被测试类 *2.Junit 3.0或者4.0 这里

Java从零开始学二十六(包装类)

一.包装类 包装类是将基本类型封装到一个类中.也就是将基本数据类型包装成一个类类型. java程序设计为每一种基本类型都提供了一个包装类.这些包装类就在java.lang包中.有8个包装类 二.包装类的构造方法和静态方法 2.1.第一种 public Type (type value) 其中首字母大写的Type表示包装类,小写的type表示基本类型 这个构造方法接收一个基本数据类型值,并创建一个与之相应的包装类. 可以使用new关键字将一个基本类型包装为一个对象 Integer intValue

三分钟教你学Git(十六) - 统计

有时候想统计仓库的情况,比如代码量,贡献者之类的. 1 统计某人的commit数量 git log --author="$(git config --get user.name)" --oneline | wc -l 2 统计某人的代码量 git log --author="$(git config --get user.name)" --pretty=tformat: --numstat | awk '{adds += $1; subs += $2; all +=

边记边学PHP-(十六)PHP使用MySQL扩展库操作数据库

PHP提供了很多扩展库,这里说的是使用MySQL扩展库,但是这种扩展库在不久的将来就会被摒弃,因为如果使用MySQL扩展库编写的代码在运行的时候会有warning的提示.我本来想直接写另一种,但是感觉这是基础.MySQL扩展库,一说到库,自然而然就想到是一堆函数,很多函数组成一个库,使用扩展库也就是使用里面的函数.MySQL扩展库是完全面向过程的,显然不符合面向对象的特性,被摒弃也是可以理解的.废话不多说,直接上重点. 一.PHP使用MySQL扩展库操作数据库的示意图 此图是我自己画的,可能有不