解决pytorch 的state_dict()拷贝问题


Posted in Python onMarch 03, 2021

先说结论

model.state_dict()是浅拷贝,返回的参数仍然会随着网络的训练而变化。

应该使用deepcopy(model.state_dict()),或将参数及时序列化到硬盘。

再讲故事,前几天在做一个模型的交叉验证训练时,通过model.state_dict()保存了每一组交叉验证模型的参数,后根据效果选择准确率最佳的模型load回去,结果每一次都是最后一个模型,从地址来看,每一个保存的state_dict()都具有不同的地址,但进一步发现state_dict()下的各个模型参数的地址是共享的,而我又使用了in-place的方式重置模型参数,进而导致了上述问题。

补充:pytorch中state_dict的理解

在PyTorch中,state_dict是一个Python字典对象(在这个有序字典中,key是各层参数名,value是各层参数),包含模型的可学习参数(即权重和偏差,以及bn层的的参数) 优化器对象(torch.optim)也具有state_dict,其中包含有关优化器状态以及所用超参数的信息。

其实看了如下代码的输出应该就懂了

import torch
import torch.nn as nn
import torchvision
import numpy as np
from torchsummary import summary
# Define model
class TheModelClass(nn.Module):
  def __init__(self):
    super(TheModelClass, self).__init__()
    self.conv1 = nn.Conv2d(3, 6, 5)
    self.pool = nn.MaxPool2d(2, 2)
    self.conv2 = nn.Conv2d(6, 16, 5)
    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)
  def forward(self, x):
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    x = x.view(-1, 16 * 5 * 5)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x
# Initialize model
model = TheModelClass()
# Initialize optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
  print(param_tensor,"\t", model.state_dict()[param_tensor].size())
# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
  print(var_name, "\t", optimizer.state_dict()[var_name])

输出如下:

Model's state_dict:
conv1.weight  torch.Size([6, 3, 5, 5])
conv1.bias  torch.Size([6])
conv2.weight  torch.Size([16, 6, 5, 5])
conv2.bias  torch.Size([16])
fc1.weight  torch.Size([120, 400])
fc1.bias  torch.Size([120])
fc2.weight  torch.Size([84, 120])
fc2.bias  torch.Size([84])
fc3.weight  torch.Size([10, 84])
fc3.bias  torch.Size([10])
Optimizer's state_dict:
state  {}
param_groups  [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2238501264336, 2238501329800, 2238501330016, 2238501327136, 2238501328576, 2238501329728, 2238501327928, 2238501327064, 2238501330808, 2238501328288]}]

我是刚接触深度学西的小白一个,希望大佬可以为我指出我的不足,此博客仅为自己的笔记!!!!

补充:pytorch保存模型时报错***object has no attribute 'state_dict'

定义了一个类BaseNet并实例化该类:

net=BaseNet()

保存net时报错 object has no attribute 'state_dict'

torch.save(net.state_dict(), models_dir)

原因是定义类的时候不是继承nn.Module类,比如:

class BaseNet(object):
  def __init__(self):

把类定义改为

class BaseNet(nn.Module):
  def __init__(self):
    super(BaseNet, self).__init__()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
Python实现全局变量的两个解决方法
Jul 03 Python
Python矩阵常见运算操作实例总结
Sep 29 Python
python 实现在Excel末尾增加新行
May 02 Python
浅谈python中requests模块导入的问题
May 18 Python
python计算两个数的百分比方法
Jun 29 Python
Python计算库numpy进行方差/标准方差/样本标准方差/协方差的计算
Dec 28 Python
Python 变量的创建过程详解
Sep 02 Python
Pytorch的mean和std调查实例
Jan 02 Python
利用pytorch实现对CIFAR-10数据集的分类
Jan 14 Python
在 Windows 下搭建高效的 django 开发环境的详细教程
Jul 27 Python
Django开发RESTful API实现增删改查(入门级)
May 10 Python
在 Python 中利用 Pool 进行多线程
Apr 24 Python
解决pytorch 保存模型遇到的问题
Mar 03 #Python
解决pytorch 模型复制的一些问题
Mar 03 #Python
Pytorch模型迁移和迁移学习,导入部分模型参数的操作
Mar 03 #Python
pytorch 实现L2和L1正则化regularization的操作
Mar 03 #Python
Pytorch自定义Dataset和DataLoader去除不存在和空数据的操作
Mar 03 #Python
python爬取youtube视频的示例代码
Mar 03 #Python
pytorch Dataset,DataLoader产生自定义的训练数据案例
Mar 03 #Python
You might like
德生9700DX电路分析
2021/03/02 无线电
php学习之运算符相关概念
2011/06/09 PHP
解析PHP中empty is_null和isset的测试
2013/06/29 PHP
浅析PHP Socket技术
2013/08/02 PHP
php实现的九九乘法口诀表简洁版
2014/07/28 PHP
PHP中curl_setopt函数用法实例分析
2015/04/16 PHP
PHP中异常处理的一些方法整理
2015/07/03 PHP
wampserver改变默认网站目录的办法
2015/08/05 PHP
PHP实现仿百度文库,豆丁在线文档效果(word,excel,ppt转flash)
2016/03/10 PHP
php实现将HTML页面转换成word并且保存的方法
2016/10/14 PHP
使用jquery与图片美化checkbox和radio控件的代码(打包下载)
2010/11/11 Javascript
Jquery index()方法 获取相应元素索引值
2012/10/12 Javascript
JS 模态对话框和非模态对话框操作技巧汇总
2013/04/15 Javascript
angularjs实现与服务器交互分享
2014/06/24 Javascript
jQuery实现感应鼠标动画效果自动伸长的输入框实例
2015/02/24 Javascript
javascript实现了照片拖拽点击置顶的照片墙代码
2015/04/03 Javascript
js实现仿MSN带关闭功能的右下角弹窗代码
2015/09/04 Javascript
Jquery插件之Fancybox丰富的弹出层效果附源码下载
2015/12/02 Javascript
jQuery基于扩展实现的倒计时效果
2016/05/14 Javascript
原生js实现焦点轮播图效果
2017/01/12 Javascript
javascript实现Emrips反质数枚举的示例代码
2017/12/06 Javascript
浅谈微信小程序之官方UI框架we-ui使用教程
2018/08/20 Javascript
javascript 数组精简技巧小结
2020/02/26 Javascript
通过实例了解Nodejs模块系统及require机制
2020/07/16 NodeJs
TensorFlow神经网络优化策略学习
2018/03/09 Python
使用Python的SymPy库解决数学运算问题的方法
2019/03/27 Python
python实现基于朴素贝叶斯的垃圾分类算法
2019/07/09 Python
一封普通求职者的求职信
2013/11/20 职场文书
给实习单位的感谢信
2014/02/01 职场文书
工商治理实习生的自我评价分享
2014/02/20 职场文书
管理建议书范文
2014/05/13 职场文书
实习单位指导教师评语
2014/12/30 职场文书
2015年八一建军节演讲稿
2015/03/19 职场文书
物流业务员岗位职责
2015/04/03 职场文书
具结保证书范本
2015/05/11 职场文书
2015入党自传书范文
2015/06/26 职场文书