pytorch 状态字典:state_dict使用详解


Posted in Python onJanuary 17, 2020

pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等)

(注意,只有那些参数可以训练的layer才会被保存到模型的state_dict中,如卷积层,线性层等等)

优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数(如lr, momentum,weight_decay等)

备注:

1) state_dict是在定义了model或optimizer之后pytorch自动生成的,可以直接调用.常用的保存state_dict的格式是".pt"或'.pth'的文件,即下面命令的 PATH="./***.pt"

torch.save(model.state_dict(), PATH)

2) load_state_dict 也是model或optimizer之后pytorch自动具备的函数,可以直接调用

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

注意:model.eval() 的重要性,在2)中最后用到了model.eval(),是因为,只有在执行该命令后,"dropout层"及"batch normalization层"才会进入 evalution 模态. 而在"训练(training)模态"与"评估(evalution)模态"下,这两层有不同的表现形式.

模态字典(state_dict)的保存(model是一个网络结构类的对象)

1.1)仅保存学习到的参数,用以下命令

torch.save(model.state_dict(), PATH)

1.2)加载model.state_dict,用以下命令

model = TheModelClass(*args, **kwargs)
 model.load_state_dict(torch.load(PATH))
 model.eval()

备注:model.load_state_dict的操作对象是 一个具体的对象,而不能是文件名

2.1)保存整个model的状态,用以下命令

torch.save(model,PATH)

2.2)加载整个model的状态,用以下命令:

# Model class must be defined somewhere

 model = torch.load(PATH)

 model.eval()

state_dict 是一个python的字典格式,以字典的格式存储,然后以字典的格式被加载,而且只加载key匹配的项

如何仅加载某一层的训练的到的参数(某一层的state)

If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.

conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']

加载模型参数后,如何设置某层某参数的"是否需要训练"(param.requires_grad)

for param in list(model.pretrained.parameters()):
 param.requires_grad = False

注意: requires_grad的操作对象是tensor.

疑问:能否直接对某个层直接之用requires_grad呢?例如:model.conv1.requires_grad=False

回答:经测试,不可以.model.conv1 没有requires_grad属性.

全部测试代码:

#-*-coding:utf-8-*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
 
 
 
# define model
class TheModelClass(nn.Module):
 def __init__(self):
  super(TheModelClass,self).__init__()
  self.conv1 = nn.Conv2d(3,6,5)
  self.pool = nn.MaxPool2d(2,2)
  self.conv2 = nn.Conv2d(6,16,5)
  self.fc1 = nn.Linear(16*5*5,120)
  self.fc2 = nn.Linear(120,84)
  self.fc3 = nn.Linear(84,10)
 
 def forward(self,x):
  x = self.pool(F.relu(self.conv1(x)))
  x = self.pool(F.relu(self.conv2(x)))
  x = x.view(-1,16*5*5)
  x = F.relu(self.fc1(x))
  x = F.relu(self.fc2(x))
  x = self.fc3(x)
  return x
 
# initial model
model = TheModelClass()
 
#initialize the optimizer
optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9)
 
# print the model's state_dict
print("model's state_dict:")
for param_tensor in model.state_dict():
 print(param_tensor,'\t',model.state_dict()[param_tensor].size())
 
print("\noptimizer's state_dict")
for var_name in optimizer.state_dict():
 print(var_name,'\t',optimizer.state_dict()[var_name])
 
print("\nprint particular param")
print('\n',model.conv1.weight.size())
print('\n',model.conv1.weight)
 
print("------------------------------------")
torch.save(model.state_dict(),'./model_state_dict.pt')
# model_2 = TheModelClass()
# model_2.load_state_dict(torch.load('./model_state_dict'))
# model.eval()
# print('\n',model_2.conv1.weight)
# print((model_2.conv1.weight == model.conv1.weight).size())
## 仅仅加载某一层的参数
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
print(conv1_weight_state==model.conv1.weight)
 
model_2 = TheModelClass()
model_2.load_state_dict(torch.load('./model_state_dict.pt'))
model_2.conv1.requires_grad=False
print(model_2.conv1.requires_grad)
print(model_2.conv1.bias.requires_grad)

以上这篇pytorch 状态字典:state_dict使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python随手笔记之标准类型内建函数
Dec 02 Python
Python实现二分查找与bisect模块详解
Jan 13 Python
Python如何快速实现分布式任务
Jul 06 Python
Python实现检测文件MD5值的方法示例
Apr 11 Python
python程序快速缩进多行代码方法总结
Jun 23 Python
如何使用Flask-Migrate拓展数据库表结构
Jul 24 Python
python pandas cumsum求累计次数的用法
Jul 29 Python
pygame实现弹球游戏
Apr 14 Python
Python pip安装模块提示错误解决方案
May 22 Python
opencv 图像礼帽和图像黑帽的实现
Jul 07 Python
Python collections.deque双边队列原理详解
Oct 05 Python
Python实现单例模式的5种方法
Jun 15 Python
Python标准库itertools的使用方法
Jan 17 #Python
Python实现投影法分割图像示例(二)
Jan 17 #Python
Python常用库大全及简要说明
Jan 17 #Python
Python Sphinx使用实例及问题解决
Jan 17 #Python
通过实例了解Python str()和repr()的区别
Jan 17 #Python
python无序链表删除重复项的方法
Jan 17 #Python
Python实现投影法分割图像示例(一)
Jan 17 #Python
You might like
使用array_map简单搞定PHP删除文件、删除目录
2014/10/29 PHP
phpStudy访问速度慢和启动失败的解决办法
2015/11/19 PHP
PHP中使用array函数新建一个数组
2015/11/19 PHP
php中关于换行的实例写法
2019/09/26 PHP
yii框架结合charjs统计上一年与当前年数据的方法示例
2020/04/04 PHP
jquery的Tooltip插件 qtip使用详细说明
2010/09/08 Javascript
对象无length属性时IE6/IE7中无法将其转换成伪数组(ArrayLike)
2011/07/31 Javascript
点击显示指定元素隐藏其他同辈元素的方法
2014/02/19 Javascript
php读取sqlite数据库入门实例代码
2014/06/25 Javascript
jquery使用remove()方法删除指定class子元素
2015/03/26 Javascript
js Canvas实现圆形时钟教程
2016/09/19 Javascript
微信小程序 选择器(时间,日期,地区)实例详解
2016/11/16 Javascript
javascript监听页面刷新和页面关闭事件方法详解
2017/01/09 Javascript
浅谈原型对象的常用开发模式
2017/07/22 Javascript
angularjs实现对表单输入改变的监控(ng-change和watch两种方式)
2018/08/29 Javascript
详解jQuery-each()方法
2019/03/13 jQuery
vue2配置scss的方法步骤
2019/06/06 Javascript
vue 开发之路由配置方法详解
2019/12/02 Javascript
ES6学习笔记之字符串、数组、对象、函数新增知识点实例分析
2020/01/22 Javascript
[01:45]DOTA2众星出演!DSPL刀塔次级职业联赛宣传片
2014/11/21 DOTA
[52:31]VP vs Serenity 2018国际邀请赛小组赛BO2 第二场 8.16
2018/08/17 DOTA
Python中logging模块的用法实例
2014/09/29 Python
高效使用Python字典的清单
2018/04/04 Python
解决python3爬虫无法显示中文的问题
2018/04/12 Python
解决Pandas to_json()中文乱码,转化为json数组的问题
2018/05/10 Python
python psutil监控进程实例
2019/12/17 Python
Html5移动端弹幕动画实现示例代码
2018/08/27 HTML / CSS
美国独家设计师眼镜在线光学商店:Glasses Gallery
2017/12/28 全球购物
比利时香水网上商店:NOTINO
2018/03/28 全球购物
好家长事迹材料
2014/01/23 职场文书
美容院营销方案
2014/03/05 职场文书
企业管理毕业生求职信范文
2014/03/07 职场文书
实习介绍信模板
2015/01/30 职场文书
用Python简陋模拟n阶魔方
2021/04/17 Python
Python中for后接else的语法使用
2021/05/18 Python
vue elementUI批量上传文件
2022/04/26 Vue.js