pytorch 状态字典:state_dict使用详解


Posted in Python onJanuary 17, 2020

pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等)

(注意,只有那些参数可以训练的layer才会被保存到模型的state_dict中,如卷积层,线性层等等)

优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数(如lr, momentum,weight_decay等)

备注:

1) state_dict是在定义了model或optimizer之后pytorch自动生成的,可以直接调用.常用的保存state_dict的格式是".pt"或'.pth'的文件,即下面命令的 PATH="./***.pt"

torch.save(model.state_dict(), PATH)

2) load_state_dict 也是model或optimizer之后pytorch自动具备的函数,可以直接调用

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

注意:model.eval() 的重要性,在2)中最后用到了model.eval(),是因为,只有在执行该命令后,"dropout层"及"batch normalization层"才会进入 evalution 模态. 而在"训练(training)模态"与"评估(evalution)模态"下,这两层有不同的表现形式.

模态字典(state_dict)的保存(model是一个网络结构类的对象)

1.1)仅保存学习到的参数,用以下命令

torch.save(model.state_dict(), PATH)

1.2)加载model.state_dict,用以下命令

model = TheModelClass(*args, **kwargs)
 model.load_state_dict(torch.load(PATH))
 model.eval()

备注:model.load_state_dict的操作对象是 一个具体的对象,而不能是文件名

2.1)保存整个model的状态,用以下命令

torch.save(model,PATH)

2.2)加载整个model的状态,用以下命令:

# Model class must be defined somewhere

 model = torch.load(PATH)

 model.eval()

state_dict 是一个python的字典格式,以字典的格式存储,然后以字典的格式被加载,而且只加载key匹配的项

如何仅加载某一层的训练的到的参数(某一层的state)

If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.

conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']

加载模型参数后,如何设置某层某参数的"是否需要训练"(param.requires_grad)

for param in list(model.pretrained.parameters()):
 param.requires_grad = False

注意: requires_grad的操作对象是tensor.

疑问:能否直接对某个层直接之用requires_grad呢?例如:model.conv1.requires_grad=False

回答:经测试,不可以.model.conv1 没有requires_grad属性.

全部测试代码:

#-*-coding:utf-8-*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
 
 
 
# define model
class TheModelClass(nn.Module):
 def __init__(self):
  super(TheModelClass,self).__init__()
  self.conv1 = nn.Conv2d(3,6,5)
  self.pool = nn.MaxPool2d(2,2)
  self.conv2 = nn.Conv2d(6,16,5)
  self.fc1 = nn.Linear(16*5*5,120)
  self.fc2 = nn.Linear(120,84)
  self.fc3 = nn.Linear(84,10)
 
 def forward(self,x):
  x = self.pool(F.relu(self.conv1(x)))
  x = self.pool(F.relu(self.conv2(x)))
  x = x.view(-1,16*5*5)
  x = F.relu(self.fc1(x))
  x = F.relu(self.fc2(x))
  x = self.fc3(x)
  return x
 
# initial model
model = TheModelClass()
 
#initialize the optimizer
optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9)
 
# print the model's state_dict
print("model's state_dict:")
for param_tensor in model.state_dict():
 print(param_tensor,'\t',model.state_dict()[param_tensor].size())
 
print("\noptimizer's state_dict")
for var_name in optimizer.state_dict():
 print(var_name,'\t',optimizer.state_dict()[var_name])
 
print("\nprint particular param")
print('\n',model.conv1.weight.size())
print('\n',model.conv1.weight)
 
print("------------------------------------")
torch.save(model.state_dict(),'./model_state_dict.pt')
# model_2 = TheModelClass()
# model_2.load_state_dict(torch.load('./model_state_dict'))
# model.eval()
# print('\n',model_2.conv1.weight)
# print((model_2.conv1.weight == model.conv1.weight).size())
## 仅仅加载某一层的参数
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
print(conv1_weight_state==model.conv1.weight)
 
model_2 = TheModelClass()
model_2.load_state_dict(torch.load('./model_state_dict.pt'))
model_2.conv1.requires_grad=False
print(model_2.conv1.requires_grad)
print(model_2.conv1.bias.requires_grad)

以上这篇pytorch 状态字典:state_dict使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用优化器来提升Python程序的执行效率的教程
Apr 02 Python
Python中自定义函数的教程
Apr 27 Python
python统计文本文件内单词数量的方法
May 30 Python
Python的Django框架中的Context使用
Jul 15 Python
Python中列表、字典、元组数据结构的简单学习笔记
Mar 20 Python
Python中使用装饰器来优化尾递归的示例
Jun 18 Python
Python处理json字符串转化为字典的简单实现
Jul 07 Python
Python实现字符串逆序输出功能示例
Jun 24 Python
Python实现扩展内置类型的方法分析
Oct 16 Python
Python列表list排列组合操作示例
Dec 18 Python
Python高斯消除矩阵
Jan 02 Python
django将网络中的图片,保存成model中的ImageField的实例
Aug 07 Python
Python标准库itertools的使用方法
Jan 17 #Python
Python实现投影法分割图像示例(二)
Jan 17 #Python
Python常用库大全及简要说明
Jan 17 #Python
Python Sphinx使用实例及问题解决
Jan 17 #Python
通过实例了解Python str()和repr()的区别
Jan 17 #Python
python无序链表删除重复项的方法
Jan 17 #Python
Python实现投影法分割图像示例(一)
Jan 17 #Python
You might like
php四种基础算法代码实例
2013/10/29 PHP
支持生僻字且自动识别utf-8编码的php汉字转拼音类
2014/06/27 PHP
php使用fputcsv()函数csv文件读写数据的方法
2015/01/06 PHP
微信公众号判断用户是否已关注php代码解析
2016/06/24 PHP
PHP不使用递归的无限级分类简单实例
2016/11/05 PHP
Laravel 5.4重新登录实现跳转到登录前页面的原理和方法
2017/07/13 PHP
比较详细的关于javascript中void(0)的具体含义解释
2007/08/02 Javascript
兼容多浏览器的字幕特效Marquee的通用js类
2008/07/20 Javascript
Jquery进度条插件 Progress Bar小问题解决
2011/07/12 Javascript
window.open关于浏览器拦截问题分析及解决方法
2013/02/05 Javascript
JavaScript控制各种浏览器全屏模式的方法、属性和事件介绍
2014/04/03 Javascript
JQuery中DOM加载与事件执行实例分析
2015/06/13 Javascript
JS实现弹性漂浮效果的广告代码
2015/09/02 Javascript
获取JS中网页各种高宽与位置的方法总结
2016/07/27 Javascript
从源码里了解vue中的nextTick的使用
2018/11/22 Javascript
[08:54]《一刀刀一天》之DOTA全时刻18:十九支奔赴西雅图队伍全部出炉
2014/06/04 DOTA
[10:18]2018DOTA2国际邀请赛寻真——Fnatic能否笑到最后?
2018/08/14 DOTA
Python中os.path用法分析
2015/01/15 Python
Python读取键盘输入的2种方法
2015/06/16 Python
解决uWSGI的编码问题详解
2017/03/24 Python
了解不常见但是实用的Python技巧
2019/05/23 Python
python求平均数、方差、中位数的例子
2019/08/22 Python
40行Python代码实现天气预报和每日鸡汤推送功能
2020/02/27 Python
使用Python和百度语音识别生成视频字幕的实现
2020/04/09 Python
Python基于内置函数type创建新类型
2020/10/22 Python
用python制作个音乐下载器
2021/01/30 Python
马来西亚在线时尚女装商店:KEI MAG
2017/09/28 全球购物
波兰多品牌运动商店:StreetStyle24.pl
2020/09/22 全球购物
新领导上任欢迎词
2014/01/13 职场文书
校园摄影活动策划方案
2014/02/05 职场文书
工会优秀工作者事迹
2014/08/17 职场文书
企业务虚会发言材料
2014/10/20 职场文书
党支部培养考察意见
2015/06/02 职场文书
银行资信证明
2015/06/17 职场文书
药品销售员2015年终工作总结
2015/10/22 职场文书
netty 实现tomcat的示例代码
2022/06/05 Servers