解决pytorch 的state_dict()拷贝问题


Posted in Python onMarch 03, 2021

先说结论

model.state_dict()是浅拷贝,返回的参数仍然会随着网络的训练而变化。

应该使用deepcopy(model.state_dict()),或将参数及时序列化到硬盘。

再讲故事,前几天在做一个模型的交叉验证训练时,通过model.state_dict()保存了每一组交叉验证模型的参数,后根据效果选择准确率最佳的模型load回去,结果每一次都是最后一个模型,从地址来看,每一个保存的state_dict()都具有不同的地址,但进一步发现state_dict()下的各个模型参数的地址是共享的,而我又使用了in-place的方式重置模型参数,进而导致了上述问题。

补充:pytorch中state_dict的理解

在PyTorch中,state_dict是一个Python字典对象(在这个有序字典中,key是各层参数名,value是各层参数),包含模型的可学习参数(即权重和偏差,以及bn层的的参数) 优化器对象(torch.optim)也具有state_dict,其中包含有关优化器状态以及所用超参数的信息。

其实看了如下代码的输出应该就懂了

import torch
import torch.nn as nn
import torchvision
import numpy as np
from torchsummary import summary
# Define model
class TheModelClass(nn.Module):
  def __init__(self):
    super(TheModelClass, self).__init__()
    self.conv1 = nn.Conv2d(3, 6, 5)
    self.pool = nn.MaxPool2d(2, 2)
    self.conv2 = nn.Conv2d(6, 16, 5)
    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)
  def forward(self, x):
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    x = x.view(-1, 16 * 5 * 5)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x
# Initialize model
model = TheModelClass()
# Initialize optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
  print(param_tensor,"\t", model.state_dict()[param_tensor].size())
# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
  print(var_name, "\t", optimizer.state_dict()[var_name])

输出如下:

Model's state_dict:
conv1.weight  torch.Size([6, 3, 5, 5])
conv1.bias  torch.Size([6])
conv2.weight  torch.Size([16, 6, 5, 5])
conv2.bias  torch.Size([16])
fc1.weight  torch.Size([120, 400])
fc1.bias  torch.Size([120])
fc2.weight  torch.Size([84, 120])
fc2.bias  torch.Size([84])
fc3.weight  torch.Size([10, 84])
fc3.bias  torch.Size([10])
Optimizer's state_dict:
state  {}
param_groups  [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2238501264336, 2238501329800, 2238501330016, 2238501327136, 2238501328576, 2238501329728, 2238501327928, 2238501327064, 2238501330808, 2238501328288]}]

我是刚接触深度学西的小白一个,希望大佬可以为我指出我的不足,此博客仅为自己的笔记!!!!

补充:pytorch保存模型时报错***object has no attribute 'state_dict'

定义了一个类BaseNet并实例化该类:

net=BaseNet()

保存net时报错 object has no attribute 'state_dict'

torch.save(net.state_dict(), models_dir)

原因是定义类的时候不是继承nn.Module类,比如:

class BaseNet(object):
  def __init__(self):

把类定义改为

class BaseNet(nn.Module):
  def __init__(self):
    super(BaseNet, self).__init__()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
使用python获取CPU和内存信息的思路与实现(linux系统)
Jan 03 Python
Python解析网页源代码中的115网盘链接实例
Sep 30 Python
详解Python中open()函数指定文件打开方式的用法
Jun 04 Python
python-opencv 将连续图片写成视频格式的方法
Jan 08 Python
Python3利用Dlib实现摄像头实时人脸检测和平铺显示示例
Feb 21 Python
python sort、sort_index方法代码实例
Mar 28 Python
解决pyqt5中QToolButton无法使用的问题
Jun 21 Python
Python Pandas数据中对时间的操作
Jul 30 Python
在Python3 numpy中mean和average的区别详解
Aug 24 Python
Python基于Socket实现简单聊天室
Feb 17 Python
基于python 等频分箱qcut问题的解决
Mar 03 Python
python使用for...else跳出双层嵌套循环的方法实例
May 17 Python
解决pytorch 保存模型遇到的问题
Mar 03 #Python
解决pytorch 模型复制的一些问题
Mar 03 #Python
Pytorch模型迁移和迁移学习,导入部分模型参数的操作
Mar 03 #Python
pytorch 实现L2和L1正则化regularization的操作
Mar 03 #Python
Pytorch自定义Dataset和DataLoader去除不存在和空数据的操作
Mar 03 #Python
python爬取youtube视频的示例代码
Mar 03 #Python
pytorch Dataset,DataLoader产生自定义的训练数据案例
Mar 03 #Python
You might like
php获取mysql版本的几种方法小结
2008/03/25 PHP
php中文繁体和简体相互转换的方法
2015/03/21 PHP
在CentOS上搭建LAMP+vsftpd环境的简单指南
2015/08/01 PHP
WordPress主题制作中自定义头部的相关PHP函数解析
2016/01/08 PHP
PHP5.0 TIDY_PARSE_FILE缓冲区溢出漏洞的解决方案
2018/10/14 PHP
JS面向对象编程之对象使用分析
2010/08/19 Javascript
jquery实现心算练习代码
2010/12/06 Javascript
Safari5中alert的无限循环BUG
2011/04/07 Javascript
jquery中ajax学习笔记一
2011/10/16 Javascript
JavaScript 命名空间 使用介绍
2013/08/29 Javascript
jQuery实现鼠标点击弹出渐变层的方法
2015/07/09 Javascript
JavaScript实现简单获取当前网页网址的方法
2015/11/09 Javascript
深入理解JavaScript中的对象复制(Object Clone)
2016/05/18 Javascript
JS及PHP代码编写八大排序算法
2016/07/12 Javascript
JS 组件系列之Bootstrap Table 冻结列功能IE浏览器兼容性问题解决方案
2017/06/30 Javascript
Node.JS更改Windows注册表Regedit的方法小结
2017/08/18 Javascript
vue项目中导入swiper插件的方法
2018/01/30 Javascript
js中split()方法得到的数组长度问题
2018/07/19 Javascript
微信小程序自定义tabBar组件开发详解
2020/09/24 Javascript
详解如何探测小程序返回到webview页面
2019/05/14 Javascript
如何进行微信公众号开发的本地调试的方法
2019/06/16 Javascript
taro 实现购物车逻辑的实例代码
2020/06/05 Javascript
JS数组reduce()方法原理及使用技巧解析
2020/07/14 Javascript
[04:52]第二届DOTA2亚洲邀请赛主赛事第一天比赛集锦:OG娜迦海妖放大配合谜团大中3人
2017/04/02 DOTA
python Django连接MySQL数据库做增删改查
2013/11/07 Python
利用Python绘制数据的瀑布图的教程
2015/04/07 Python
一波神奇的Python语句、函数与方法的使用技巧总结
2015/12/08 Python
python基于pyDes库实现des加密的方法
2017/04/29 Python
python pandas消除空值和空格以及 Nan数据替换方法
2018/10/30 Python
python中random模块详解
2021/03/01 Python
施华洛世奇德国官网:SWAROVSKI德国
2017/02/01 全球购物
.NET是怎么支持多种语言的
2015/02/24 面试题
教师产假请假条范文
2014/04/10 职场文书
生态养殖创业计划书
2014/05/06 职场文书
2016大学生优秀志愿者事迹材料
2016/02/25 职场文书
pytorch 两个GPU同时训练的解决方案
2021/06/01 Python