解决pytorch 的state_dict()拷贝问题


Posted in Python onMarch 03, 2021

先说结论

model.state_dict()是浅拷贝,返回的参数仍然会随着网络的训练而变化。

应该使用deepcopy(model.state_dict()),或将参数及时序列化到硬盘。

再讲故事,前几天在做一个模型的交叉验证训练时,通过model.state_dict()保存了每一组交叉验证模型的参数,后根据效果选择准确率最佳的模型load回去,结果每一次都是最后一个模型,从地址来看,每一个保存的state_dict()都具有不同的地址,但进一步发现state_dict()下的各个模型参数的地址是共享的,而我又使用了in-place的方式重置模型参数,进而导致了上述问题。

补充:pytorch中state_dict的理解

在PyTorch中,state_dict是一个Python字典对象(在这个有序字典中,key是各层参数名,value是各层参数),包含模型的可学习参数(即权重和偏差,以及bn层的的参数) 优化器对象(torch.optim)也具有state_dict,其中包含有关优化器状态以及所用超参数的信息。

其实看了如下代码的输出应该就懂了

import torch
import torch.nn as nn
import torchvision
import numpy as np
from torchsummary import summary
# Define model
class TheModelClass(nn.Module):
  def __init__(self):
    super(TheModelClass, self).__init__()
    self.conv1 = nn.Conv2d(3, 6, 5)
    self.pool = nn.MaxPool2d(2, 2)
    self.conv2 = nn.Conv2d(6, 16, 5)
    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)
  def forward(self, x):
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    x = x.view(-1, 16 * 5 * 5)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x
# Initialize model
model = TheModelClass()
# Initialize optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
  print(param_tensor,"\t", model.state_dict()[param_tensor].size())
# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
  print(var_name, "\t", optimizer.state_dict()[var_name])

输出如下:

Model's state_dict:
conv1.weight  torch.Size([6, 3, 5, 5])
conv1.bias  torch.Size([6])
conv2.weight  torch.Size([16, 6, 5, 5])
conv2.bias  torch.Size([16])
fc1.weight  torch.Size([120, 400])
fc1.bias  torch.Size([120])
fc2.weight  torch.Size([84, 120])
fc2.bias  torch.Size([84])
fc3.weight  torch.Size([10, 84])
fc3.bias  torch.Size([10])
Optimizer's state_dict:
state  {}
param_groups  [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2238501264336, 2238501329800, 2238501330016, 2238501327136, 2238501328576, 2238501329728, 2238501327928, 2238501327064, 2238501330808, 2238501328288]}]

我是刚接触深度学西的小白一个,希望大佬可以为我指出我的不足,此博客仅为自己的笔记!!!!

补充:pytorch保存模型时报错***object has no attribute 'state_dict'

定义了一个类BaseNet并实例化该类:

net=BaseNet()

保存net时报错 object has no attribute 'state_dict'

torch.save(net.state_dict(), models_dir)

原因是定义类的时候不是继承nn.Module类,比如:

class BaseNet(object):
  def __init__(self):

把类定义改为

class BaseNet(nn.Module):
  def __init__(self):
    super(BaseNet, self).__init__()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
Python入门篇之字典
Oct 17 Python
python通过post提交数据的方法
May 06 Python
Python2.7基于淘宝接口获取IP地址所在地理位置的方法【测试可用】
Jun 07 Python
Python实现的rsa加密算法详解
Jan 24 Python
python之pandas用法大全
Mar 13 Python
对python的unittest架构公共参数token提取方法详解
Dec 17 Python
python树莓派红外反射传感器
Jan 21 Python
python在新的图片窗口显示图片(图像)的方法
Jul 11 Python
Python中zip()函数的简单用法举例
Sep 02 Python
python实现遍历文件夹图片并重命名
Mar 23 Python
python转化excel数字日期为标准日期操作
Jul 14 Python
如何用Python徒手写线性回归
Jan 25 Python
解决pytorch 保存模型遇到的问题
Mar 03 #Python
解决pytorch 模型复制的一些问题
Mar 03 #Python
Pytorch模型迁移和迁移学习,导入部分模型参数的操作
Mar 03 #Python
pytorch 实现L2和L1正则化regularization的操作
Mar 03 #Python
Pytorch自定义Dataset和DataLoader去除不存在和空数据的操作
Mar 03 #Python
python爬取youtube视频的示例代码
Mar 03 #Python
pytorch Dataset,DataLoader产生自定义的训练数据案例
Mar 03 #Python
You might like
php实现Linux服务器木马排查及加固功能
2014/12/29 PHP
php curl抓取网页的介绍和推广及使用CURL抓取淘宝页面集成方法
2015/11/30 PHP
DEDE实现转跳属性文档在模板上调用出转跳地址
2016/11/04 PHP
[原创]php集成安装包wampserver修改密码后phpmyadmin无法登陆的解决方法
2016/11/23 PHP
php使用PDO执行SQL语句的方法分析
2017/02/16 PHP
PHP实现随机生成水印图片功能
2017/03/22 PHP
php使用环形链表解决约瑟夫问题完整示例
2018/08/07 PHP
javascript删除数组重复元素的方法汇总
2015/06/24 Javascript
Angular 中 select指令用法详解
2016/09/29 Javascript
jQuery实现搜索页面关键字的功能
2017/02/16 Javascript
angular ng-click防止重复提交实例
2017/06/16 Javascript
Vue.js组件通信的几种姿势
2017/10/23 Javascript
js自定义input文件上传样式
2018/10/26 Javascript
JS div匀速移动动画与变速移动动画代码实例
2019/03/26 Javascript
微信小程序 setData 对 data数据影响问题
2019/04/18 Javascript
el-select 下拉框多选实现全选的实现
2019/08/02 Javascript
对layer弹出框中icon数字参数的说明介绍
2019/09/04 Javascript
es6中new.target的作用和使用场景简单示例分析
2020/03/14 Javascript
Vue组件简易模拟实现购物车
2020/12/21 Vue.js
[54:10]完美世界DOTA2联赛PWL S2 Magma vs FTD 第二场 11.29
2020/12/03 DOTA
python转换摩斯密码示例
2014/02/16 Python
搭建Python的Django框架环境并建立和运行第一个App的教程
2016/07/02 Python
python tornado修改log输出方式
2019/11/18 Python
python文件和文件夹复制函数
2020/02/07 Python
Python实现在Windows平台修改文件属性
2020/03/05 Python
美国二手复古奢侈品包包购物网站:LXRandCo
2019/06/18 全球购物
美国购买韩国护肤和美容产品网站:Althea Korea
2020/11/16 全球购物
健康家庭事迹材料
2014/05/02 职场文书
企业负责人任命书
2014/06/05 职场文书
安全责任书范文
2014/08/25 职场文书
高中生学习计划书
2014/09/15 职场文书
2016优秀班主任个人先进事迹材料
2016/02/26 职场文书
你离财务总监还有多远?速览CFO的岗位职责
2019/11/18 职场文书
如何使用php生成zip压缩包
2021/04/21 PHP
Python基础知识学习之类的继承
2021/05/31 Python
Vue3.0中Ref与Reactive的区别示例详析
2021/07/07 Vue.js