pytorch 状态字典:state_dict使用详解


Posted in Python onJanuary 17, 2020

pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等)

(注意,只有那些参数可以训练的layer才会被保存到模型的state_dict中,如卷积层,线性层等等)

优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数(如lr, momentum,weight_decay等)

备注:

1) state_dict是在定义了model或optimizer之后pytorch自动生成的,可以直接调用.常用的保存state_dict的格式是".pt"或'.pth'的文件,即下面命令的 PATH="./***.pt"

torch.save(model.state_dict(), PATH)

2) load_state_dict 也是model或optimizer之后pytorch自动具备的函数,可以直接调用

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

注意:model.eval() 的重要性,在2)中最后用到了model.eval(),是因为,只有在执行该命令后,"dropout层"及"batch normalization层"才会进入 evalution 模态. 而在"训练(training)模态"与"评估(evalution)模态"下,这两层有不同的表现形式.

模态字典(state_dict)的保存(model是一个网络结构类的对象)

1.1)仅保存学习到的参数,用以下命令

torch.save(model.state_dict(), PATH)

1.2)加载model.state_dict,用以下命令

model = TheModelClass(*args, **kwargs)
 model.load_state_dict(torch.load(PATH))
 model.eval()

备注:model.load_state_dict的操作对象是 一个具体的对象,而不能是文件名

2.1)保存整个model的状态,用以下命令

torch.save(model,PATH)

2.2)加载整个model的状态,用以下命令:

# Model class must be defined somewhere

 model = torch.load(PATH)

 model.eval()

state_dict 是一个python的字典格式,以字典的格式存储,然后以字典的格式被加载,而且只加载key匹配的项

如何仅加载某一层的训练的到的参数(某一层的state)

If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.

conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']

加载模型参数后,如何设置某层某参数的"是否需要训练"(param.requires_grad)

for param in list(model.pretrained.parameters()):
 param.requires_grad = False

注意: requires_grad的操作对象是tensor.

疑问:能否直接对某个层直接之用requires_grad呢?例如:model.conv1.requires_grad=False

回答:经测试,不可以.model.conv1 没有requires_grad属性.

全部测试代码:

#-*-coding:utf-8-*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
 
 
 
# define model
class TheModelClass(nn.Module):
 def __init__(self):
  super(TheModelClass,self).__init__()
  self.conv1 = nn.Conv2d(3,6,5)
  self.pool = nn.MaxPool2d(2,2)
  self.conv2 = nn.Conv2d(6,16,5)
  self.fc1 = nn.Linear(16*5*5,120)
  self.fc2 = nn.Linear(120,84)
  self.fc3 = nn.Linear(84,10)
 
 def forward(self,x):
  x = self.pool(F.relu(self.conv1(x)))
  x = self.pool(F.relu(self.conv2(x)))
  x = x.view(-1,16*5*5)
  x = F.relu(self.fc1(x))
  x = F.relu(self.fc2(x))
  x = self.fc3(x)
  return x
 
# initial model
model = TheModelClass()
 
#initialize the optimizer
optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9)
 
# print the model's state_dict
print("model's state_dict:")
for param_tensor in model.state_dict():
 print(param_tensor,'\t',model.state_dict()[param_tensor].size())
 
print("\noptimizer's state_dict")
for var_name in optimizer.state_dict():
 print(var_name,'\t',optimizer.state_dict()[var_name])
 
print("\nprint particular param")
print('\n',model.conv1.weight.size())
print('\n',model.conv1.weight)
 
print("------------------------------------")
torch.save(model.state_dict(),'./model_state_dict.pt')
# model_2 = TheModelClass()
# model_2.load_state_dict(torch.load('./model_state_dict'))
# model.eval()
# print('\n',model_2.conv1.weight)
# print((model_2.conv1.weight == model.conv1.weight).size())
## 仅仅加载某一层的参数
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
print(conv1_weight_state==model.conv1.weight)
 
model_2 = TheModelClass()
model_2.load_state_dict(torch.load('./model_state_dict.pt'))
model_2.conv1.requires_grad=False
print(model_2.conv1.requires_grad)
print(model_2.conv1.bias.requires_grad)

以上这篇pytorch 状态字典:state_dict使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python程序员开发中常犯的10个错误
Jul 07 Python
linux下python抓屏实现方法
May 22 Python
用pickle存储Python的原生对象方法
Apr 28 Python
Python 比较两个数组的元素的异同方法
Aug 17 Python
对python 各种删除文件失败的处理方式分享
Apr 24 Python
Django xadmin开启搜索功能的实现
Nov 15 Python
解决python -m pip install --upgrade pip 升级不成功问题
Mar 05 Python
PyQt5中QSpinBox计数器的实现
Jan 18 Python
从Pytorch模型pth文件中读取参数成numpy矩阵的操作
Mar 04 Python
Python爬取用户观影数据并分析用户与电影之间的隐藏信息!
Jun 29 Python
Selenium浏览器自动化如何上传文件
Apr 06 Python
Django框架中视图的用法
Jun 10 Python
Python标准库itertools的使用方法
Jan 17 #Python
Python实现投影法分割图像示例(二)
Jan 17 #Python
Python常用库大全及简要说明
Jan 17 #Python
Python Sphinx使用实例及问题解决
Jan 17 #Python
通过实例了解Python str()和repr()的区别
Jan 17 #Python
python无序链表删除重复项的方法
Jan 17 #Python
Python实现投影法分割图像示例(一)
Jan 17 #Python
You might like
风格模板初级不完全修改教程
2006/10/09 PHP
PHP聊天室技术
2006/10/09 PHP
Codeigniter中mkdir创建目录遇到权限问题和解决方法
2014/07/25 PHP
PHP下的Oracle客户端扩展(OCI8)安装教程
2014/09/10 PHP
自己写的php中文截取函数mb_strlen和mb_substr
2015/02/09 PHP
Yii2搭建后台并实现rbac权限控制完整实例教程
2016/04/28 PHP
PHP封装cURL工具类与应用示例
2019/07/01 PHP
动态加载js的几种方法
2006/10/23 Javascript
Jquery 快速构建可拖曳的购物车DragDrop
2009/11/30 Javascript
jQuery操作 input type=checkbox的实现代码
2012/06/14 Javascript
网页前端优化之滚动延时加载图片示例
2013/07/13 Javascript
javascript事件函数中获得事件源的两种不错方法
2014/03/17 Javascript
jquery使用正则表达式验证email地址的方法
2015/01/22 Javascript
jQuery表单对象属性过滤选择器实例详解
2016/09/13 Javascript
jQuery中$.ajax()方法参数解析
2016/10/22 Javascript
JS正则匹配中文的方法示例
2017/01/06 Javascript
Vue.js学习笔记之修饰符详解
2017/07/25 Javascript
原生js+cookie实现购物车功能的方法分析
2017/12/21 Javascript
Vue高版本中一些新特性的使用详解
2018/09/25 Javascript
微信小程序模板消息限制实现无限制主动推送的示例代码
2019/08/27 Javascript
JS 5种遍历对象的方式
2020/06/16 Javascript
Vant+postcss-pxtorem 实现浏览器适配功能
2021/02/05 Javascript
[05:45]Ti4观战指南(下)
2014/07/07 DOTA
Python中关于Sequence切片的下标问题详解
2017/06/15 Python
Python实现购物系统(示例讲解)
2017/09/13 Python
python3实现绘制二维点图
2019/12/04 Python
python求最大公约数和最小公倍数的简单方法
2020/02/13 Python
python词云库wordcloud的使用方法与实例详解
2020/02/17 Python
CSS3中线性颜色渐变的一些实现方法
2015/07/14 HTML / CSS
毕业生自荐信
2013/12/14 职场文书
社区七一党员活动方案
2014/01/25 职场文书
小学英语教学反思案例
2014/02/04 职场文书
爱心捐书活动总结
2014/07/05 职场文书
反对四风自我剖析材料
2014/10/07 职场文书
学校勤俭节约倡议书
2015/04/29 职场文书
八一建军节主持词
2015/07/01 职场文书