解决pytorch 的state_dict()拷贝问题


Posted in Python onMarch 03, 2021

先说结论

model.state_dict()是浅拷贝,返回的参数仍然会随着网络的训练而变化。

应该使用deepcopy(model.state_dict()),或将参数及时序列化到硬盘。

再讲故事,前几天在做一个模型的交叉验证训练时,通过model.state_dict()保存了每一组交叉验证模型的参数,后根据效果选择准确率最佳的模型load回去,结果每一次都是最后一个模型,从地址来看,每一个保存的state_dict()都具有不同的地址,但进一步发现state_dict()下的各个模型参数的地址是共享的,而我又使用了in-place的方式重置模型参数,进而导致了上述问题。

补充:pytorch中state_dict的理解

在PyTorch中,state_dict是一个Python字典对象(在这个有序字典中,key是各层参数名,value是各层参数),包含模型的可学习参数(即权重和偏差,以及bn层的的参数) 优化器对象(torch.optim)也具有state_dict,其中包含有关优化器状态以及所用超参数的信息。

其实看了如下代码的输出应该就懂了

import torch
import torch.nn as nn
import torchvision
import numpy as np
from torchsummary import summary
# Define model
class TheModelClass(nn.Module):
  def __init__(self):
    super(TheModelClass, self).__init__()
    self.conv1 = nn.Conv2d(3, 6, 5)
    self.pool = nn.MaxPool2d(2, 2)
    self.conv2 = nn.Conv2d(6, 16, 5)
    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)
  def forward(self, x):
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    x = x.view(-1, 16 * 5 * 5)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x
# Initialize model
model = TheModelClass()
# Initialize optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
  print(param_tensor,"\t", model.state_dict()[param_tensor].size())
# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
  print(var_name, "\t", optimizer.state_dict()[var_name])

输出如下:

Model's state_dict:
conv1.weight  torch.Size([6, 3, 5, 5])
conv1.bias  torch.Size([6])
conv2.weight  torch.Size([16, 6, 5, 5])
conv2.bias  torch.Size([16])
fc1.weight  torch.Size([120, 400])
fc1.bias  torch.Size([120])
fc2.weight  torch.Size([84, 120])
fc2.bias  torch.Size([84])
fc3.weight  torch.Size([10, 84])
fc3.bias  torch.Size([10])
Optimizer's state_dict:
state  {}
param_groups  [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2238501264336, 2238501329800, 2238501330016, 2238501327136, 2238501328576, 2238501329728, 2238501327928, 2238501327064, 2238501330808, 2238501328288]}]

我是刚接触深度学西的小白一个,希望大佬可以为我指出我的不足,此博客仅为自己的笔记!!!!

补充:pytorch保存模型时报错***object has no attribute 'state_dict'

定义了一个类BaseNet并实例化该类:

net=BaseNet()

保存net时报错 object has no attribute 'state_dict'

torch.save(net.state_dict(), models_dir)

原因是定义类的时候不是继承nn.Module类,比如:

class BaseNet(object):
  def __init__(self):

把类定义改为

class BaseNet(nn.Module):
  def __init__(self):
    super(BaseNet, self).__init__()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
Python中实现字符串类型与字典类型相互转换的方法
Aug 18 Python
深入解析Python中的urllib2模块
Nov 13 Python
Python实现图片尺寸缩放脚本
Mar 10 Python
python 筛选数据集中列中value长度大于20的数据集方法
Jun 14 Python
Python基本数据结构之字典类型dict用法分析
Jun 08 Python
python多任务之协程的使用详解
Aug 26 Python
Python 文件操作之读取文件(read),文件指针与写入文件(write),文件打开方式示例
Sep 29 Python
使用Numpy对特征中的异常值进行替换及条件替换方式
Jun 08 Python
Django中日期时间型字段进行年月日时分秒分组统计
Nov 27 Python
如何通过Python实现RabbitMQ延迟队列
Nov 28 Python
如何在python中实现ECDSA你知道吗
Nov 23 Python
Python使用OpenCV实现虚拟缩放效果
Feb 28 Python
解决pytorch 保存模型遇到的问题
Mar 03 #Python
解决pytorch 模型复制的一些问题
Mar 03 #Python
Pytorch模型迁移和迁移学习,导入部分模型参数的操作
Mar 03 #Python
pytorch 实现L2和L1正则化regularization的操作
Mar 03 #Python
Pytorch自定义Dataset和DataLoader去除不存在和空数据的操作
Mar 03 #Python
python爬取youtube视频的示例代码
Mar 03 #Python
pytorch Dataset,DataLoader产生自定义的训练数据案例
Mar 03 #Python
You might like
整理的9个实用的PHP库简介和下载
2010/11/09 PHP
php用正则表达式匹配URL的简单方法
2013/11/12 PHP
destoon实现调用自增数字从1开始的方法
2014/08/21 PHP
php发送短信验证码完成注册功能
2015/11/24 PHP
如何打开php的gd2库
2017/02/09 PHP
PHP基于SMTP协议实现邮件发送实例代码
2017/04/27 PHP
PHP实现对xml的增删改查操作案例分析
2017/05/19 PHP
PHP7 其他语言层面的修改
2021/03/09 PHP
让你的网站可编辑的实现js代码
2009/10/19 Javascript
基于jquery的不规则矩形的排列实现代码
2012/04/16 Javascript
详解微信开发中snsapi_base和snsapi_userinfo及静默授权的实现
2017/03/11 Javascript
详解React-Native解决键盘遮挡问题(Keyboard遮挡问题)
2017/07/13 Javascript
深入理解Angularjs 脏值检测
2018/10/12 Javascript
Vue中的methods、watch、computed的区别
2018/11/26 Javascript
python requests 测试代理ip是否生效
2018/07/25 Python
python pytest进阶之fixture详解
2019/06/27 Python
OpenCV3.0+Python3.6实现特定颜色的物体追踪
2019/07/23 Python
Python(PyS60)实现简单语音整点报时
2019/11/18 Python
python给图像加上mask,并提取mask区域实例
2020/01/19 Python
pycharm解决关闭flask后依旧可以访问服务的问题
2020/04/03 Python
python新手学习可变和不可变对象
2020/06/11 Python
Python持续监听文件变化代码实例
2020/07/22 Python
使用Pytorch搭建模型的步骤
2020/11/16 Python
抽象方法、抽象类怎样声明
2014/10/25 面试题
物业经理求职自我评价
2013/09/22 职场文书
幼儿园保育员辞职信
2014/01/12 职场文书
阿德的梦教学反思
2014/02/06 职场文书
入党介绍人评语
2014/05/06 职场文书
广告宣传策划方案
2014/05/21 职场文书
运动会口号16字
2014/06/07 职场文书
家具商场的活动方案
2014/08/16 职场文书
2015年质量月活动总结报告
2015/03/27 职场文书
五星红旗迎风飘扬观后感
2015/06/17 职场文书
六一文艺汇演主持词
2015/06/30 职场文书
捐书仪式主持词
2015/07/04 职场文书
2019年特色火锅店的创业计划书模板
2019/08/28 职场文书