pytorch 状态字典:state_dict使用详解


Posted in Python onJanuary 17, 2020

pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等)

(注意,只有那些参数可以训练的layer才会被保存到模型的state_dict中,如卷积层,线性层等等)

优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数(如lr, momentum,weight_decay等)

备注:

1) state_dict是在定义了model或optimizer之后pytorch自动生成的,可以直接调用.常用的保存state_dict的格式是".pt"或'.pth'的文件,即下面命令的 PATH="./***.pt"

torch.save(model.state_dict(), PATH)

2) load_state_dict 也是model或optimizer之后pytorch自动具备的函数,可以直接调用

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

注意:model.eval() 的重要性,在2)中最后用到了model.eval(),是因为,只有在执行该命令后,"dropout层"及"batch normalization层"才会进入 evalution 模态. 而在"训练(training)模态"与"评估(evalution)模态"下,这两层有不同的表现形式.

模态字典(state_dict)的保存(model是一个网络结构类的对象)

1.1)仅保存学习到的参数,用以下命令

torch.save(model.state_dict(), PATH)

1.2)加载model.state_dict,用以下命令

model = TheModelClass(*args, **kwargs)
 model.load_state_dict(torch.load(PATH))
 model.eval()

备注:model.load_state_dict的操作对象是 一个具体的对象,而不能是文件名

2.1)保存整个model的状态,用以下命令

torch.save(model,PATH)

2.2)加载整个model的状态,用以下命令:

# Model class must be defined somewhere

 model = torch.load(PATH)

 model.eval()

state_dict 是一个python的字典格式,以字典的格式存储,然后以字典的格式被加载,而且只加载key匹配的项

如何仅加载某一层的训练的到的参数(某一层的state)

If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.

conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']

加载模型参数后,如何设置某层某参数的"是否需要训练"(param.requires_grad)

for param in list(model.pretrained.parameters()):
 param.requires_grad = False

注意: requires_grad的操作对象是tensor.

疑问:能否直接对某个层直接之用requires_grad呢?例如:model.conv1.requires_grad=False

回答:经测试,不可以.model.conv1 没有requires_grad属性.

全部测试代码:

#-*-coding:utf-8-*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
 
 
 
# define model
class TheModelClass(nn.Module):
 def __init__(self):
  super(TheModelClass,self).__init__()
  self.conv1 = nn.Conv2d(3,6,5)
  self.pool = nn.MaxPool2d(2,2)
  self.conv2 = nn.Conv2d(6,16,5)
  self.fc1 = nn.Linear(16*5*5,120)
  self.fc2 = nn.Linear(120,84)
  self.fc3 = nn.Linear(84,10)
 
 def forward(self,x):
  x = self.pool(F.relu(self.conv1(x)))
  x = self.pool(F.relu(self.conv2(x)))
  x = x.view(-1,16*5*5)
  x = F.relu(self.fc1(x))
  x = F.relu(self.fc2(x))
  x = self.fc3(x)
  return x
 
# initial model
model = TheModelClass()
 
#initialize the optimizer
optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9)
 
# print the model's state_dict
print("model's state_dict:")
for param_tensor in model.state_dict():
 print(param_tensor,'\t',model.state_dict()[param_tensor].size())
 
print("\noptimizer's state_dict")
for var_name in optimizer.state_dict():
 print(var_name,'\t',optimizer.state_dict()[var_name])
 
print("\nprint particular param")
print('\n',model.conv1.weight.size())
print('\n',model.conv1.weight)
 
print("------------------------------------")
torch.save(model.state_dict(),'./model_state_dict.pt')
# model_2 = TheModelClass()
# model_2.load_state_dict(torch.load('./model_state_dict'))
# model.eval()
# print('\n',model_2.conv1.weight)
# print((model_2.conv1.weight == model.conv1.weight).size())
## 仅仅加载某一层的参数
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
print(conv1_weight_state==model.conv1.weight)
 
model_2 = TheModelClass()
model_2.load_state_dict(torch.load('./model_state_dict.pt'))
model_2.conv1.requires_grad=False
print(model_2.conv1.requires_grad)
print(model_2.conv1.bias.requires_grad)

以上这篇pytorch 状态字典:state_dict使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现Tab自动补全和历史命令管理的方法
Mar 12 Python
python内存管理分析
Apr 08 Python
python如何在列表、字典中筛选数据
Mar 19 Python
Pandas 数据框增、删、改、查、去重、抽样基本操作方法
Apr 12 Python
PyQt5每天必学之弹出消息框
Apr 19 Python
在Pycharm中将pyinstaller加入External Tools的方法
Jan 16 Python
在django项目中导出数据到excel文件并实现下载的功能
Mar 13 Python
为什么称python为胶水语言
Jun 16 Python
Python不支持 i ++ 语法的原因解析
Jul 22 Python
使用Python判断一个文件是否被占用的方法教程
Dec 16 Python
python-for x in range的用法(注意要点、细节)
May 10 Python
python APScheduler执行定时任务介绍
Apr 19 Python
Python标准库itertools的使用方法
Jan 17 #Python
Python实现投影法分割图像示例(二)
Jan 17 #Python
Python常用库大全及简要说明
Jan 17 #Python
Python Sphinx使用实例及问题解决
Jan 17 #Python
通过实例了解Python str()和repr()的区别
Jan 17 #Python
python无序链表删除重复项的方法
Jan 17 #Python
Python实现投影法分割图像示例(一)
Jan 17 #Python
You might like
将数字格式的计算结果转为汉字格式
2006/10/09 PHP
PHP无敌近乎加密方式!
2010/07/17 PHP
php Smarty 字符比较代码
2011/02/27 PHP
php相对当前文件include其它文件的方法
2015/03/13 PHP
yii2.0实现pathinfo的形式访问的配置方法
2016/04/06 PHP
50个优秀经典PHP算法大集合 附源码
2020/08/26 PHP
对联广告js flash激活
2006/10/19 Javascript
基于jQuery的让非HTML5浏览器支持placeholder属性的代码
2011/05/24 Javascript
jquery实现向下滑出的二级导航下滑菜单效果
2015/08/25 Javascript
JQuery+Ajax实现数据查询、排序和分页功能
2015/09/27 Javascript
jQuery javascript获得网页的高度与宽度的实现代码
2016/04/26 Javascript
JS实现点击网页判断是否安装app并打开否则跳转app store
2016/11/18 Javascript
EditPlus 正则表达式 实战(3)
2016/12/15 Javascript
jQuery Form插件使用详解_动力节点Java学院整理
2017/07/17 jQuery
详解js静态资源文件请求的处理
2017/08/01 Javascript
详解JavaScript中操作符和表达式
2018/09/12 Javascript
JavaScript遍历查找数组中最大值与最小值的方法示例
2019/05/24 Javascript
python切换hosts文件代码示例
2013/12/31 Python
python实现在无须过多援引的情况下创建字典的方法
2014/09/25 Python
Python实现导出数据生成excel报表的方法示例
2017/07/12 Python
python pillow模块使用方法详解
2019/08/30 Python
python使用rsa非对称加密过程解析
2019/12/28 Python
TensorFlow的reshape操作 tf.reshape的实现
2020/04/19 Python
在keras下实现多个模型的融合方式
2020/05/23 Python
Python脚本实现监听服务器的思路代码详解
2020/05/28 Python
CSS3制作皮卡丘动画壁纸的示例
2020/11/02 HTML / CSS
英国最受欢迎的母婴精品品牌:JoJo Maman BéBé
2021/02/17 全球购物
应届毕业生专业个人求职自荐信格式
2013/11/20 职场文书
创先争优制度
2014/01/21 职场文书
优秀体育委员自荐书
2014/01/31 职场文书
廉洁自律演讲稿
2014/05/22 职场文书
教师“一帮一”结对子活动总结
2015/05/07 职场文书
教师节简报
2015/07/20 职场文书
《成长的天空》读后感3篇
2019/12/06 职场文书
Python中22个万用公式的小结
2021/07/21 Python
python利用while求100内的整数和方式
2021/11/07 Python