解决pytorch 的state_dict()拷贝问题


Posted in Python onMarch 03, 2021

先说结论

model.state_dict()是浅拷贝,返回的参数仍然会随着网络的训练而变化。

应该使用deepcopy(model.state_dict()),或将参数及时序列化到硬盘。

再讲故事,前几天在做一个模型的交叉验证训练时,通过model.state_dict()保存了每一组交叉验证模型的参数,后根据效果选择准确率最佳的模型load回去,结果每一次都是最后一个模型,从地址来看,每一个保存的state_dict()都具有不同的地址,但进一步发现state_dict()下的各个模型参数的地址是共享的,而我又使用了in-place的方式重置模型参数,进而导致了上述问题。

补充:pytorch中state_dict的理解

在PyTorch中,state_dict是一个Python字典对象(在这个有序字典中,key是各层参数名,value是各层参数),包含模型的可学习参数(即权重和偏差,以及bn层的的参数) 优化器对象(torch.optim)也具有state_dict,其中包含有关优化器状态以及所用超参数的信息。

其实看了如下代码的输出应该就懂了

import torch
import torch.nn as nn
import torchvision
import numpy as np
from torchsummary import summary
# Define model
class TheModelClass(nn.Module):
  def __init__(self):
    super(TheModelClass, self).__init__()
    self.conv1 = nn.Conv2d(3, 6, 5)
    self.pool = nn.MaxPool2d(2, 2)
    self.conv2 = nn.Conv2d(6, 16, 5)
    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)
  def forward(self, x):
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    x = x.view(-1, 16 * 5 * 5)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x
# Initialize model
model = TheModelClass()
# Initialize optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
  print(param_tensor,"\t", model.state_dict()[param_tensor].size())
# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
  print(var_name, "\t", optimizer.state_dict()[var_name])

输出如下:

Model's state_dict:
conv1.weight  torch.Size([6, 3, 5, 5])
conv1.bias  torch.Size([6])
conv2.weight  torch.Size([16, 6, 5, 5])
conv2.bias  torch.Size([16])
fc1.weight  torch.Size([120, 400])
fc1.bias  torch.Size([120])
fc2.weight  torch.Size([84, 120])
fc2.bias  torch.Size([84])
fc3.weight  torch.Size([10, 84])
fc3.bias  torch.Size([10])
Optimizer's state_dict:
state  {}
param_groups  [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2238501264336, 2238501329800, 2238501330016, 2238501327136, 2238501328576, 2238501329728, 2238501327928, 2238501327064, 2238501330808, 2238501328288]}]

我是刚接触深度学西的小白一个,希望大佬可以为我指出我的不足,此博客仅为自己的笔记!!!!

补充:pytorch保存模型时报错***object has no attribute 'state_dict'

定义了一个类BaseNet并实例化该类:

net=BaseNet()

保存net时报错 object has no attribute 'state_dict'

torch.save(net.state_dict(), models_dir)

原因是定义类的时候不是继承nn.Module类,比如:

class BaseNet(object):
  def __init__(self):

把类定义改为

class BaseNet(nn.Module):
  def __init__(self):
    super(BaseNet, self).__init__()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
python3访问sina首页中文的处理方法
Feb 24 Python
Python 基础之字符串string详解及实例
Apr 01 Python
利用python为运维人员写一个监控脚本
Mar 25 Python
pandas多级分组实现排序的方法
Apr 20 Python
python3读取csv和xlsx文件的实例
Jun 22 Python
python实现矩阵打印
Mar 02 Python
这可能是最好玩的python GUI入门实例(推荐)
Jul 19 Python
处理Selenium3+python3定位鼠标悬停才显示的元素
Jul 31 Python
Python Selenium 之数据驱动测试的实现
Aug 01 Python
用Python画一个LinkinPark的logo代码实例
Sep 10 Python
django在开发中取消外键约束的实现
May 20 Python
python有几个版本
Jun 17 Python
解决pytorch 保存模型遇到的问题
Mar 03 #Python
解决pytorch 模型复制的一些问题
Mar 03 #Python
Pytorch模型迁移和迁移学习,导入部分模型参数的操作
Mar 03 #Python
pytorch 实现L2和L1正则化regularization的操作
Mar 03 #Python
Pytorch自定义Dataset和DataLoader去除不存在和空数据的操作
Mar 03 #Python
python爬取youtube视频的示例代码
Mar 03 #Python
pytorch Dataset,DataLoader产生自定义的训练数据案例
Mar 03 #Python
You might like
PHP入门
2006/10/09 PHP
用 Composer构建自己的 PHP 框架之基础准备
2014/10/30 PHP
PHP连接SQLServer2005的方法
2015/01/27 PHP
wordpress网站转移到本地运行测试的方法
2017/03/15 PHP
PHP多进程之pcntl_fork的实例详解
2017/10/15 PHP
php面试中关于面向对象的相关问题
2019/02/13 PHP
解决laravel资源加载路径设置的问题
2019/10/14 PHP
javascript入门·图片对象(无刷新变换图片)\滚动图像
2007/10/01 Javascript
javascript 写类方式之五
2009/07/05 Javascript
网络图片延迟加载实现代码 超越jquery控件
2010/03/27 Javascript
最佳JS代码编写的14条技巧
2011/01/09 Javascript
jQuery操作CheckBox的方法介绍(选中,取消,取值)
2014/02/04 Javascript
纯html+css+javascript实现楼层跳跃式的页面布局(实例代码)
2017/10/25 Javascript
浅谈Vue.js 组件中的v-on绑定自定义事件理解
2017/11/17 Javascript
解决element-ui中下拉菜单子选项click事件不触发的问题
2018/08/22 Javascript
详解JavaScript 中的批处理和缓存
2020/11/19 Javascript
python之yield表达式学习
2014/09/02 Python
Python中获取网页状态码的两个方法
2014/11/03 Python
用Python代码来解图片迷宫的方法整理
2015/04/02 Python
从Python程序中访问Java类的简单示例
2015/04/20 Python
Python如何判断数独是否合法
2016/09/08 Python
分享一个可以生成各种进制格式IP的小工具实例代码
2017/07/28 Python
使用 Python 实现简单的 switch/case 语句的方法
2018/09/17 Python
Python数据预处理之数据规范化(归一化)示例
2019/01/08 Python
Django ORM多对多查询方法(自定义第三张表&ManyToManyField)
2019/08/09 Python
Python计算公交发车时间的完整代码
2020/02/12 Python
python3 自动打印出最新版本执行的mysql2redis实例
2020/04/09 Python
接口自动化多层嵌套json数据处理代码实例
2020/11/20 Python
德国购买踏板车网站:Microscooter
2019/10/14 全球购物
什么造成了Java里面的异常
2016/04/24 面试题
阿里巴巴Oracle DBA笔试题答案-备份恢复类
2013/11/20 面试题
美术教师自我鉴定
2014/02/12 职场文书
新春文艺演出主持词
2014/03/27 职场文书
个人委托函范文
2015/01/29 职场文书
大学生自我推荐信范文
2015/03/24 职场文书
vue3引入highlight.js进行代码高亮的方法实例
2022/04/08 Vue.js