解决pytorch 的state_dict()拷贝问题


Posted in Python onMarch 03, 2021

先说结论

model.state_dict()是浅拷贝,返回的参数仍然会随着网络的训练而变化。

应该使用deepcopy(model.state_dict()),或将参数及时序列化到硬盘。

再讲故事,前几天在做一个模型的交叉验证训练时,通过model.state_dict()保存了每一组交叉验证模型的参数,后根据效果选择准确率最佳的模型load回去,结果每一次都是最后一个模型,从地址来看,每一个保存的state_dict()都具有不同的地址,但进一步发现state_dict()下的各个模型参数的地址是共享的,而我又使用了in-place的方式重置模型参数,进而导致了上述问题。

补充:pytorch中state_dict的理解

在PyTorch中,state_dict是一个Python字典对象(在这个有序字典中,key是各层参数名,value是各层参数),包含模型的可学习参数(即权重和偏差,以及bn层的的参数) 优化器对象(torch.optim)也具有state_dict,其中包含有关优化器状态以及所用超参数的信息。

其实看了如下代码的输出应该就懂了

import torch
import torch.nn as nn
import torchvision
import numpy as np
from torchsummary import summary
# Define model
class TheModelClass(nn.Module):
  def __init__(self):
    super(TheModelClass, self).__init__()
    self.conv1 = nn.Conv2d(3, 6, 5)
    self.pool = nn.MaxPool2d(2, 2)
    self.conv2 = nn.Conv2d(6, 16, 5)
    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)
  def forward(self, x):
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    x = x.view(-1, 16 * 5 * 5)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x
# Initialize model
model = TheModelClass()
# Initialize optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
  print(param_tensor,"\t", model.state_dict()[param_tensor].size())
# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
  print(var_name, "\t", optimizer.state_dict()[var_name])

输出如下:

Model's state_dict:
conv1.weight  torch.Size([6, 3, 5, 5])
conv1.bias  torch.Size([6])
conv2.weight  torch.Size([16, 6, 5, 5])
conv2.bias  torch.Size([16])
fc1.weight  torch.Size([120, 400])
fc1.bias  torch.Size([120])
fc2.weight  torch.Size([84, 120])
fc2.bias  torch.Size([84])
fc3.weight  torch.Size([10, 84])
fc3.bias  torch.Size([10])
Optimizer's state_dict:
state  {}
param_groups  [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2238501264336, 2238501329800, 2238501330016, 2238501327136, 2238501328576, 2238501329728, 2238501327928, 2238501327064, 2238501330808, 2238501328288]}]

我是刚接触深度学西的小白一个,希望大佬可以为我指出我的不足,此博客仅为自己的笔记!!!!

补充:pytorch保存模型时报错***object has no attribute 'state_dict'

定义了一个类BaseNet并实例化该类:

net=BaseNet()

保存net时报错 object has no attribute 'state_dict'

torch.save(net.state_dict(), models_dir)

原因是定义类的时候不是继承nn.Module类,比如:

class BaseNet(object):
  def __init__(self):

把类定义改为

class BaseNet(nn.Module):
  def __init__(self):
    super(BaseNet, self).__init__()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
Python标准库之循环器(itertools)介绍
Nov 25 Python
python函数形参用法实例分析
Aug 04 Python
python爬虫 execjs安装配置及使用
Jul 30 Python
Python collections中的双向队列deque简单介绍详解
Nov 04 Python
python导入不同目录下的自定义模块过程解析
Nov 18 Python
简单了解python列表和元组的区别
May 14 Python
python算的上脚本语言吗
Jun 22 Python
详解基于python的全局与局部序列比对的实现(DNA)
Oct 07 Python
python 监控服务器是否有人远程登录(详细思路+代码)
Dec 18 Python
python实现三种随机请求头方式
Jan 05 Python
如何用python识别滑块验证码中的缺口
Apr 01 Python
Python数据分析入门之数据读取与存储
May 13 Python
解决pytorch 保存模型遇到的问题
Mar 03 #Python
解决pytorch 模型复制的一些问题
Mar 03 #Python
Pytorch模型迁移和迁移学习,导入部分模型参数的操作
Mar 03 #Python
pytorch 实现L2和L1正则化regularization的操作
Mar 03 #Python
Pytorch自定义Dataset和DataLoader去除不存在和空数据的操作
Mar 03 #Python
python爬取youtube视频的示例代码
Mar 03 #Python
pytorch Dataset,DataLoader产生自定义的训练数据案例
Mar 03 #Python
You might like
Terran兵种介绍
2020/03/14 星际争霸
四个常见html网页乱码问题及解决办法
2015/09/08 PHP
PHP编写登录验证码功能 附调用方法
2016/05/19 PHP
javascript 鼠标悬浮图片显示原图 移出鼠标后原图消失(多图)
2009/12/28 Javascript
js运动框架_包括图片的淡入淡出效果
2013/05/11 Javascript
jquery教程ajax请求json数据示例
2014/01/13 Javascript
js常用自定义公共函数汇总
2014/01/15 Javascript
jquery取消选择select下拉框示例代码
2014/02/22 Javascript
javascript 操作符(~、&、|、^、)使用案例
2014/12/31 Javascript
Javascript6中字符串的四个新用法分享
2016/09/11 Javascript
jQuery grep()方法详解及实例代码
2016/10/30 Javascript
JS实现的数字格式化功能示例
2017/02/10 Javascript
jQuery实现遍历复选框的方法示例
2017/03/06 Javascript
用nodeJS搭建本地文件服务器的几种方法小结
2017/03/16 NodeJs
微信小程序tabbar不显示解决办法
2017/06/08 Javascript
vue操作动画的记录animate.css实例代码
2019/04/26 Javascript
js 实现碰撞检测的示例
2020/10/28 Javascript
Python中使用Beautiful Soup库的超详细教程
2015/04/30 Python
Python记录详细调用堆栈日志的方法
2015/05/05 Python
python实现猜数字小游戏
2020/03/24 Python
利用pandas进行大文件计数处理的方法
2018/07/25 Python
Python面向对象基础入门之编码细节与注意事项
2018/12/11 Python
pycharm new project变成灰色的解决方法
2019/06/27 Python
Python实现括号匹配方法详解
2020/02/10 Python
使用keras实现孪生网络中的权值共享教程
2020/06/11 Python
Python 解析简单的XML数据
2020/07/24 Python
ebookers英国:隶属全球最大的在线旅游公司Expedia
2017/12/28 全球购物
乡镇干部先进事迹材料
2014/02/03 职场文书
数控技校生自我鉴定
2014/03/02 职场文书
技术合作协议书范本
2014/04/18 职场文书
食品工程专业求职信
2014/06/15 职场文书
群众路线教育实践活动实施方案
2014/10/31 职场文书
青年岗位能手事迹材料
2014/12/23 职场文书
工程移交协议书
2016/03/24 职场文书
Golang 语言控制并发 Goroutine的方法
2021/06/30 Golang
python中mongodb包操作数据库
2022/04/19 Python