解决pytorch 的state_dict()拷贝问题


Posted in Python onMarch 03, 2021

先说结论

model.state_dict()是浅拷贝,返回的参数仍然会随着网络的训练而变化。

应该使用deepcopy(model.state_dict()),或将参数及时序列化到硬盘。

再讲故事,前几天在做一个模型的交叉验证训练时,通过model.state_dict()保存了每一组交叉验证模型的参数,后根据效果选择准确率最佳的模型load回去,结果每一次都是最后一个模型,从地址来看,每一个保存的state_dict()都具有不同的地址,但进一步发现state_dict()下的各个模型参数的地址是共享的,而我又使用了in-place的方式重置模型参数,进而导致了上述问题。

补充:pytorch中state_dict的理解

在PyTorch中,state_dict是一个Python字典对象(在这个有序字典中,key是各层参数名,value是各层参数),包含模型的可学习参数(即权重和偏差,以及bn层的的参数) 优化器对象(torch.optim)也具有state_dict,其中包含有关优化器状态以及所用超参数的信息。

其实看了如下代码的输出应该就懂了

import torch
import torch.nn as nn
import torchvision
import numpy as np
from torchsummary import summary
# Define model
class TheModelClass(nn.Module):
  def __init__(self):
    super(TheModelClass, self).__init__()
    self.conv1 = nn.Conv2d(3, 6, 5)
    self.pool = nn.MaxPool2d(2, 2)
    self.conv2 = nn.Conv2d(6, 16, 5)
    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)
  def forward(self, x):
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    x = x.view(-1, 16 * 5 * 5)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x
# Initialize model
model = TheModelClass()
# Initialize optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
  print(param_tensor,"\t", model.state_dict()[param_tensor].size())
# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
  print(var_name, "\t", optimizer.state_dict()[var_name])

输出如下:

Model's state_dict:
conv1.weight  torch.Size([6, 3, 5, 5])
conv1.bias  torch.Size([6])
conv2.weight  torch.Size([16, 6, 5, 5])
conv2.bias  torch.Size([16])
fc1.weight  torch.Size([120, 400])
fc1.bias  torch.Size([120])
fc2.weight  torch.Size([84, 120])
fc2.bias  torch.Size([84])
fc3.weight  torch.Size([10, 84])
fc3.bias  torch.Size([10])
Optimizer's state_dict:
state  {}
param_groups  [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2238501264336, 2238501329800, 2238501330016, 2238501327136, 2238501328576, 2238501329728, 2238501327928, 2238501327064, 2238501330808, 2238501328288]}]

我是刚接触深度学西的小白一个,希望大佬可以为我指出我的不足,此博客仅为自己的笔记!!!!

补充:pytorch保存模型时报错***object has no attribute 'state_dict'

定义了一个类BaseNet并实例化该类:

net=BaseNet()

保存net时报错 object has no attribute 'state_dict'

torch.save(net.state_dict(), models_dir)

原因是定义类的时候不是继承nn.Module类,比如:

class BaseNet(object):
  def __init__(self):

把类定义改为

class BaseNet(nn.Module):
  def __init__(self):
    super(BaseNet, self).__init__()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
17个Python小技巧分享
Jan 23 Python
python读取json文件并将数据插入到mongodb的方法
Mar 23 Python
利用Fn.py库在Python中进行函数式编程
Apr 22 Python
详解Python中的文件操作
Aug 28 Python
浅谈Django学习migrate和makemigrations的差别
Jan 18 Python
Python文件常见操作实例分析【读写、遍历】
Dec 10 Python
DataFrame:通过SparkSql将scala类转为DataFrame的方法
Jan 29 Python
python中字符串数组逆序排列方法总结
Jun 23 Python
SELENIUM自动化模拟键盘快捷键操作实现解析
Oct 28 Python
python入门:argparse浅析 nargs='+'作用
Jul 12 Python
Python 测试框架unittest和pytest的优劣
Sep 26 Python
python基于tkinter制作下班倒计时工具
Apr 28 Python
解决pytorch 保存模型遇到的问题
Mar 03 #Python
解决pytorch 模型复制的一些问题
Mar 03 #Python
Pytorch模型迁移和迁移学习,导入部分模型参数的操作
Mar 03 #Python
pytorch 实现L2和L1正则化regularization的操作
Mar 03 #Python
Pytorch自定义Dataset和DataLoader去除不存在和空数据的操作
Mar 03 #Python
python爬取youtube视频的示例代码
Mar 03 #Python
pytorch Dataset,DataLoader产生自定义的训练数据案例
Mar 03 #Python
You might like
用Apache反向代理设置对外的WWW和文件服务器
2006/10/09 PHP
Thinkphp模板中使用自定义函数的方法
2012/09/23 PHP
PHP实现HTML生成PDF文件的方法
2014/11/07 PHP
PHP可变变量学习小结
2015/11/29 PHP
PHP保存Base64图片base64_decode的问题整理
2019/11/04 PHP
根据地区不同显示时间的javascript代码
2007/08/13 Javascript
在JavaScript中获取请求的URL参数
2010/12/22 Javascript
JS隐藏参数post传值实例
2013/04/18 Javascript
node.js中的fs.chown方法使用说明
2014/12/16 Javascript
js进行表单验证实例分析
2015/02/10 Javascript
jQuery使用模式窗口实现在主页面和子页面中互相传值的方法
2016/03/01 Javascript
Vue中 key keep-alive的实现原理
2018/09/18 Javascript
node.js连接mysql与基本用法示例
2019/01/05 Javascript
小程序云开发实现数据库异步操作同步化
2019/05/18 Javascript
JS简单表单验证功能完整示例
2020/01/26 Javascript
Vite和Vue CLI的优劣
2021/01/30 Vue.js
Numpy array数据的增、删、改、查实例
2018/06/04 Python
Python功能点实现:函数级/代码块级计时器
2019/01/02 Python
python多线程共享变量的使用和效率方法
2019/07/16 Python
Python读取excel文件中带公式的值的实现
2020/04/17 Python
浅谈keras2 predict和fit_generator的坑
2020/06/17 Python
pytorch cuda上tensor的定义 以及减少cpu的操作详解
2020/06/23 Python
如何以Winsows Service方式运行JupyterLab
2020/08/30 Python
Python爬虫中Selenium实现文件上传
2020/12/04 Python
python网络爬虫实现发送短信验证码的方法
2021/02/25 Python
html5 初试 indexedDB(推荐)
2016/07/21 HTML / CSS
泰国汽车、火车和轮渡票预订网站:Bus Online Ticket
2017/09/09 全球购物
巴西最大的家具及装饰用品店:Mobly
2017/10/11 全球购物
英国HYPE双肩包官网:英国本土时尚潮牌
2018/09/26 全球购物
医学院校毕业生自荐信范文
2014/01/01 职场文书
医院反腐倡廉演讲稿
2014/09/16 职场文书
民主生活会整改措施(党员)
2014/09/18 职场文书
民间个人借款协议书
2014/09/30 职场文书
六查六看心得体会
2014/10/14 职场文书
涨工资申请书应该怎么写?
2019/07/08 职场文书
Python Parser的用法
2021/05/12 Python