[Pytorch]Pytorch 保存模型与加载模型(转)

转自:知乎

目录:

  • 保存模型与加载模型
  • 冻结一部分参数,训练另一部分参数
  • 采用不同的学习率进行训练

1.保存模型与加载

简单的保存与加载方法:

# 保存整个网络
torch.save(net, PATH)
# 保存网络中的参数, 速度快,占空间少
torch.save(net.state_dict(),PATH)
#--------------------------------------------------
#针对上面一般的保存方法,加载的方法分别是:
model_dict=torch.load(PATH)
model_dict=model.load_state_dict(torch.load(PATH))

然而,在实验中往往需要保存更多的信息,比如优化器的参数,那么可以采取下面的方法保存:

torch.save({‘epoch‘: epochID + 1, ‘state_dict‘: model.state_dict(), ‘best_loss‘: lossMIN,
‘optimizer‘: optimizer.state_dict(),‘alpha‘: loss.alpha, ‘gamma‘: loss.gamma},
checkpoint_path + ‘/m-‘ + launchTimestamp + ‘-‘ + str("%.4f" % lossMIN) + ‘.pth.tar‘)

以上包含的信息有,epochID, state_dict, min loss, optimizer, 自定义损失函数的两个参数;格式以字典的格式存储。

加载的方式:

def load_checkpoint(model, checkpoint_PATH, optimizer):
if checkpoint != None:
model_CKPT = torch.load(checkpoint_PATH)
model.load_state_dict(model_CKPT[‘state_dict‘])
print(‘loading checkpoint!‘)
optimizer.load_state_dict(model_CKPT[‘optimizer‘])
return model, optimizer

其他的参数可以通过以字典的方式获得

但是,但是,我们可能修改了一部分网络,比如加了一些,删除一些,等等,那么需要过滤这些参数,加载方式:

def load_checkpoint(model, checkpoint, optimizer, loadOptimizer):
if checkpoint != ‘No‘:
print("loading checkpoint...")
model_dict = model.state_dict()
modelCheckpoint = torch.load(checkpoint)
pretrained_dict = modelCheckpoint[‘state_dict‘]
# 过滤操作
new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
model_dict.update(new_dict)
# 打印出来,更新了多少的参数
print(‘Total : {}, update: {}‘.format(len(pretrained_dict), len(new_dict)))
model.load_state_dict(model_dict)
print("loaded finished!")
# 如果不需要更新优化器那么设置为false
if loadOptimizer == True:
optimizer.load_state_dict(modelCheckpoint[‘optimizer‘])
print(‘loaded! optimizer‘)
else:
print(‘not loaded optimizer‘)
else:
print(‘No checkpoint is included‘)
return model, optimizer

2.冻结部分参数,训练另一部分参数

1)添加下面一句话到模型中

for p in self.parameters():
p.requires_grad = False

比如加载了resnet预训练模型之后,在resenet的基础上连接了新的模快,resenet模块那部分可以先暂时冻结不更新,只更新其他部分的参数,那么可以在下面加入上面那句话

class RESNET_MF(nn.Module):
def init(self, model, pretrained):
super(RESNET_MF, self).__init__()
self.resnet = model(pretrained)
for p in self.parameters():
p.requires_grad = False
self.f = SpectralNorm(nn.Conv2d(2048, 512, 1))
self.g = SpectralNorm(nn.Conv2d(2048, 512, 1))
self.h = SpectralNorm(nn.Conv2d(2048, 2048, 1))
...

同时在优化器中添加:filter(lambda p: p.requires_grad, model.parameters())

optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001, betas=(0.9, 0.999),
eps=1e-08, weight_decay=1e-5)

2) 参数保存在有序的字典中,那么可以通过查找参数的名字对应的id值,进行冻结

查找的代码:

 model_dict = torch.load(‘net.pth.tar‘).state_dict()
dict_name = list(model_dict)
for i, p in enumerate(dict_name):
print(i, p)

保存一下这个文件,可以看到大致是这个样子的:

0 gamma
1 resnet.conv1.weight
2 resnet.bn1.weight
3 resnet.bn1.bias
4 resnet.bn1.running_mean
5 resnet.bn1.running_var
6 resnet.layer1.0.conv1.weight
7 resnet.layer1.0.bn1.weight
8 resnet.layer1.0.bn1.bias
9 resnet.layer1.0.bn1.running_mean
....

同样在模型中添加这样的代码:

for i,p in enumerate(net.parameters()):
if i < 165:
p.requires_grad = False

在优化器中添加上面的那句话可以实现参数的屏蔽

原文地址:https://www.cnblogs.com/kk17/p/10074188.html

时间: 2024-08-29 02:03:19

[Pytorch]Pytorch 保存模型与加载模型(转)的相关文章

torch保存加载模型

保存模型 torch.save(my_model.state_dict(), "params.pkl") 加载模型 先初始化model网络结构 model.load_state_dict(torch.load("params.pkl")) 原文地址:https://www.cnblogs.com/rise0111/p/11621640.html

[iTyran原创]iPhone中OpenGL ES显示3DS MAX模型之二:lib3ds加载模型

[iTyran原创]iPhone中OpenGL ES显示3DS MAX模型之二:lib3ds加载模型 作者:u0u0 - iTyran 在上一节中,我们分析了OBJ格式.OBJ格式优点是文本形式,可读性好,缺点也很明显,计算机解析文本过程会比解析二进制文件慢很多.OBJ还有个问题是各种3D建模工具导出的布局格式还不太一样,face还有多边形(超过三边形),不利于在OpenGL ES里面加载. .3ds文件是OBJ的二进制形式,并且多很多信息.有一个C语言写的开源库可以用来加.3ds文件,这就是l

OpenGL(二)加载模型

在OpenGL(一) OpenGL管线 与 可编程管线流程中,提到加载VBO.IBO的相关技术,本篇详细说一下.实际应用时,我们是不可能手写顶点和索引点.通常模型是使用3dMax或Maya制作,然后在OpenGL程序中 加载模型 .本文着重分析这些文件的格式以及 加载模型 的流程和方法. 大体流程 加载模型 的主要流程是: 读取模型文件内容 解析 vbo(vertex buffer object) 和 ibo(index buffer object) 信息.其中vbo包括顶点的位置.纹理坐标.法

解决在Azure SharePoint 2013 “在为项或数据源“FirstRSDS.rsds”加载模型时出现错误。请确认连接信息正确并且您有权访问该数据源。”

解决在Azure SharePoint 2013  "在为项或数据源"FirstRSDS.rsds"加载模型时出现错误.请确认连接信息正确并且您有权访问该数据源." 错误抓图如下 错误描述 <detail><ErrorCode xmlns="http://www.microsoft.com/sql/reportingservices">rsCannotRetrieveModel</ErrorCode><H

Libgdx New 3D API 教程之 -- 使用Libgdx加载模型

http://bbs.9ria.com/thread-221701-1-1.html 在前面的教程中,我们已经看到如何设置libgdx渲染3D场景.我们已经设置了Camera,增加了一些灯光并渲染一个绿色的盒子.现在让我们添加一个比盒子更有趣的东西,模型Model. 您可以从您喜爱的建模应用程序或使用已有的模型.我找了gdx-invaders里面的飞船模型文件,你可以点这里下载.您可以解压缩后,将文件放到的android项目的assets目录下.请注意,它包含三个文件,这些文件需要放同一个文件夹

libgdx3D第二讲-加载模型

定义: 将一个类(Adaptee)的接口转换成客户(Client)希望的另外一个接口(Target). 目标接口(Target):客户所期待的接口.目标可以是具体的或抽象的类,也可以是接口. 需要适配的类(Adaptee):需要适配的类或适配者类. 适配器(Adapter):使得一个东西适合另一个东西的东西.百度中定义为:接口转换器.通过包装一个需要适配的对象,把源接口转换成目标接口. 为什么要适配:需要的东西已做好,但是不能用,短时间又不能改造,想办法适配它. 作用: 使得原本由于接口不兼容而

Unity3d-WWW实现图片资源显示以及保存和本地加载

本文固定连接:http://blog.csdn.net/u013108312/article/details/52712844 WWW实现图片资源显示以及保存和本地加载 using UnityEngine; using System.Collections; using System.IO; using UnityEditor; enum GetPicType { DownLoad = 0, LocalLoad, } public class Picture : MonoBehaviour {

[深度学习] Pytorch(三)—— 多/单GPU、CPU,训练保存、加载模型参数问题

[深度学习] Pytorch(三)-- 多/单GPU.CPU,训练保存.加载预测模型问题 上一篇实践学习中,遇到了在多/单个GPU.GPU与CPU的不同环境下训练保存.加载使用使用模型的问题,如果保存.加载的上述三类环境不同,加载时会出错.就去研究了一下,做了实验,得出以下结论: 多/单GPU训练保存模型参数.CPU加载使用模型 #保存 PATH = 'cifar_net.pth' torch.save(net.module.state_dict(), PATH) #加载 net = Net()

pytorch中修改后的模型如何加载预训练模型

问题描述 简单来说,比如你要加载一个vgg16模型,但是你自己需要的网络结构并不是原本的vgg16网络,可能你删掉某些层,可能你改掉某些层,这时你去加载预训练模型,就会报错,错误原因就是你的模型和原本的模型不匹配. 此时有两种解决方法: 1.重新解析参数的字典,将预训练模型的参数提取出来,然后放在自己的模型中对应的位置 2.直接用原本的vgg16网络去加载预训练模型,然后再修改网络. 具体操作待续吧...... 我个人推荐第一种方法. 原文地址:https://www.cnblogs.com/y