解决pytorch 的state_dict()拷贝问题


Posted in Python onMarch 03, 2021

先说结论

model.state_dict()是浅拷贝,返回的参数仍然会随着网络的训练而变化。

应该使用deepcopy(model.state_dict()),或将参数及时序列化到硬盘。

再讲故事,前几天在做一个模型的交叉验证训练时,通过model.state_dict()保存了每一组交叉验证模型的参数,后根据效果选择准确率最佳的模型load回去,结果每一次都是最后一个模型,从地址来看,每一个保存的state_dict()都具有不同的地址,但进一步发现state_dict()下的各个模型参数的地址是共享的,而我又使用了in-place的方式重置模型参数,进而导致了上述问题。

补充:pytorch中state_dict的理解

在PyTorch中,state_dict是一个Python字典对象(在这个有序字典中,key是各层参数名,value是各层参数),包含模型的可学习参数(即权重和偏差,以及bn层的的参数) 优化器对象(torch.optim)也具有state_dict,其中包含有关优化器状态以及所用超参数的信息。

其实看了如下代码的输出应该就懂了

import torch
import torch.nn as nn
import torchvision
import numpy as np
from torchsummary import summary
# Define model
class TheModelClass(nn.Module):
  def __init__(self):
    super(TheModelClass, self).__init__()
    self.conv1 = nn.Conv2d(3, 6, 5)
    self.pool = nn.MaxPool2d(2, 2)
    self.conv2 = nn.Conv2d(6, 16, 5)
    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)
  def forward(self, x):
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    x = x.view(-1, 16 * 5 * 5)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x
# Initialize model
model = TheModelClass()
# Initialize optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
  print(param_tensor,"\t", model.state_dict()[param_tensor].size())
# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
  print(var_name, "\t", optimizer.state_dict()[var_name])

输出如下:

Model's state_dict:
conv1.weight  torch.Size([6, 3, 5, 5])
conv1.bias  torch.Size([6])
conv2.weight  torch.Size([16, 6, 5, 5])
conv2.bias  torch.Size([16])
fc1.weight  torch.Size([120, 400])
fc1.bias  torch.Size([120])
fc2.weight  torch.Size([84, 120])
fc2.bias  torch.Size([84])
fc3.weight  torch.Size([10, 84])
fc3.bias  torch.Size([10])
Optimizer's state_dict:
state  {}
param_groups  [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2238501264336, 2238501329800, 2238501330016, 2238501327136, 2238501328576, 2238501329728, 2238501327928, 2238501327064, 2238501330808, 2238501328288]}]

我是刚接触深度学西的小白一个,希望大佬可以为我指出我的不足,此博客仅为自己的笔记!!!!

补充:pytorch保存模型时报错***object has no attribute 'state_dict'

定义了一个类BaseNet并实例化该类:

net=BaseNet()

保存net时报错 object has no attribute 'state_dict'

torch.save(net.state_dict(), models_dir)

原因是定义类的时候不是继承nn.Module类,比如:

class BaseNet(object):
  def __init__(self):

把类定义改为

class BaseNet(nn.Module):
  def __init__(self):
    super(BaseNet, self).__init__()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
python访问抓取网页常用命令总结
Apr 11 Python
Django数据库操作的实例(增删改查)
Sep 04 Python
Python实现针对含中文字符串的截取功能示例
Sep 22 Python
python把数组中的数字每行打印3个并保存在文档中的方法
Jul 17 Python
解决python有时候import不了当前的包问题
Aug 28 Python
python打印直角三角形与等腰三角形实例代码
Oct 20 Python
Django3.0 异步通信初体验(小结)
Dec 04 Python
python学生信息管理系统实现代码
Dec 17 Python
解决Python数据可视化中文部分显示方块问题
May 16 Python
解决pycharm中的run和debug失效无法点击运行
Jun 09 Python
解决import tensorflow导致jupyter内核死亡的问题
Feb 06 Python
Python使用paramiko连接远程服务器执行Shell命令的实现
Mar 04 Python
解决pytorch 保存模型遇到的问题
Mar 03 #Python
解决pytorch 模型复制的一些问题
Mar 03 #Python
Pytorch模型迁移和迁移学习,导入部分模型参数的操作
Mar 03 #Python
pytorch 实现L2和L1正则化regularization的操作
Mar 03 #Python
Pytorch自定义Dataset和DataLoader去除不存在和空数据的操作
Mar 03 #Python
python爬取youtube视频的示例代码
Mar 03 #Python
pytorch Dataset,DataLoader产生自定义的训练数据案例
Mar 03 #Python
You might like
PHP图片上传类带图片显示
2006/11/25 PHP
php时间戳转换的示例
2014/03/31 PHP
php验证码的制作思路和实现方法
2015/11/12 PHP
PHP的Yii框架的常用日志操作总结
2015/12/08 PHP
最新版本PHP 7 vs HHVM 多角度比较
2016/02/14 PHP
基于PHP生成简单的验证码
2016/06/01 PHP
PHP常见字符串处理函数用法示例【转换,转义,截取,比较,查找,反转,切割】
2016/12/24 PHP
浅谈PHP封装CURL
2019/03/06 PHP
javascript 关闭IE6、IE7
2009/06/01 Javascript
js switch case default 的用法示例介绍
2013/10/23 Javascript
动态加载脚本提升javascript性能
2014/02/24 Javascript
node.js中的fs.renameSync方法使用说明
2014/12/16 Javascript
JavaScript 匿名函数和闭包介绍
2015/04/13 Javascript
JQUERY表单暂存功能插件分享
2016/02/23 Javascript
jQuery Ajax 实例代码 ($.ajax、$.post、$.get)
2016/04/29 Javascript
实例分析浏览器中“JavaScript解析器”的工作原理
2016/12/12 Javascript
零基础轻松学JavaScript闭包
2016/12/30 Javascript
常见的浏览器Hack技巧整理
2017/06/29 Javascript
微信小程序引用iconfont图标的方法
2018/10/22 Javascript
JS实现悬浮球只在一侧滑动并且是横屏状态下
2020/08/19 Javascript
python抓取豆瓣图片并自动保存示例学习
2014/01/10 Python
Python实现Linux下守护进程的编写方法
2014/08/22 Python
Python的高级Git库 Gittle
2014/09/22 Python
遍历python字典几种方法总结(推荐)
2016/09/11 Python
Python3解决棋盘覆盖问题的方法示例
2017/12/07 Python
Python+tkinter使用40行代码实现计算器功能
2018/01/30 Python
pandas DataFrame行或列的删除方法的实现示例
2019/08/02 Python
jupyter notebook中美观显示矩阵实例
2020/04/17 Python
python集合的新增元素方法整理
2020/12/07 Python
VSCODE配置Markdown及Markdown基础语法详解
2021/01/19 Python
纯CSS3实现鼠标滑过按钮动画第二节
2020/07/16 HTML / CSS
维多利亚的秘密官方网站:Victoria’s Secret
2018/10/24 全球购物
SmartBuyGlasses德国:购买太阳镜和眼镜
2019/08/20 全球购物
交通文明倡议书
2014/05/16 职场文书
go web 预防跨站脚本的实现方式
2021/06/11 Golang
Android RecyclerView实现九宫格效果
2022/06/28 Java/Android