pytorch 状态字典:state_dict使用详解


Posted in Python onJanuary 17, 2020

pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等)

(注意,只有那些参数可以训练的layer才会被保存到模型的state_dict中,如卷积层,线性层等等)

优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数(如lr, momentum,weight_decay等)

备注:

1) state_dict是在定义了model或optimizer之后pytorch自动生成的,可以直接调用.常用的保存state_dict的格式是".pt"或'.pth'的文件,即下面命令的 PATH="./***.pt"

torch.save(model.state_dict(), PATH)

2) load_state_dict 也是model或optimizer之后pytorch自动具备的函数,可以直接调用

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

注意:model.eval() 的重要性,在2)中最后用到了model.eval(),是因为,只有在执行该命令后,"dropout层"及"batch normalization层"才会进入 evalution 模态. 而在"训练(training)模态"与"评估(evalution)模态"下,这两层有不同的表现形式.

模态字典(state_dict)的保存(model是一个网络结构类的对象)

1.1)仅保存学习到的参数,用以下命令

torch.save(model.state_dict(), PATH)

1.2)加载model.state_dict,用以下命令

model = TheModelClass(*args, **kwargs)
 model.load_state_dict(torch.load(PATH))
 model.eval()

备注:model.load_state_dict的操作对象是 一个具体的对象,而不能是文件名

2.1)保存整个model的状态,用以下命令

torch.save(model,PATH)

2.2)加载整个model的状态,用以下命令:

# Model class must be defined somewhere

 model = torch.load(PATH)

 model.eval()

state_dict 是一个python的字典格式,以字典的格式存储,然后以字典的格式被加载,而且只加载key匹配的项

如何仅加载某一层的训练的到的参数(某一层的state)

If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.

conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']

加载模型参数后,如何设置某层某参数的"是否需要训练"(param.requires_grad)

for param in list(model.pretrained.parameters()):
 param.requires_grad = False

注意: requires_grad的操作对象是tensor.

疑问:能否直接对某个层直接之用requires_grad呢?例如:model.conv1.requires_grad=False

回答:经测试,不可以.model.conv1 没有requires_grad属性.

全部测试代码:

#-*-coding:utf-8-*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
 
 
 
# define model
class TheModelClass(nn.Module):
 def __init__(self):
  super(TheModelClass,self).__init__()
  self.conv1 = nn.Conv2d(3,6,5)
  self.pool = nn.MaxPool2d(2,2)
  self.conv2 = nn.Conv2d(6,16,5)
  self.fc1 = nn.Linear(16*5*5,120)
  self.fc2 = nn.Linear(120,84)
  self.fc3 = nn.Linear(84,10)
 
 def forward(self,x):
  x = self.pool(F.relu(self.conv1(x)))
  x = self.pool(F.relu(self.conv2(x)))
  x = x.view(-1,16*5*5)
  x = F.relu(self.fc1(x))
  x = F.relu(self.fc2(x))
  x = self.fc3(x)
  return x
 
# initial model
model = TheModelClass()
 
#initialize the optimizer
optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9)
 
# print the model's state_dict
print("model's state_dict:")
for param_tensor in model.state_dict():
 print(param_tensor,'\t',model.state_dict()[param_tensor].size())
 
print("\noptimizer's state_dict")
for var_name in optimizer.state_dict():
 print(var_name,'\t',optimizer.state_dict()[var_name])
 
print("\nprint particular param")
print('\n',model.conv1.weight.size())
print('\n',model.conv1.weight)
 
print("------------------------------------")
torch.save(model.state_dict(),'./model_state_dict.pt')
# model_2 = TheModelClass()
# model_2.load_state_dict(torch.load('./model_state_dict'))
# model.eval()
# print('\n',model_2.conv1.weight)
# print((model_2.conv1.weight == model.conv1.weight).size())
## 仅仅加载某一层的参数
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
print(conv1_weight_state==model.conv1.weight)
 
model_2 = TheModelClass()
model_2.load_state_dict(torch.load('./model_state_dict.pt'))
model_2.conv1.requires_grad=False
print(model_2.conv1.requires_grad)
print(model_2.conv1.bias.requires_grad)

以上这篇pytorch 状态字典:state_dict使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的装饰器用法详解
Jan 14 Python
linux环境下python中MySQLdb模块的安装方法
Jun 16 Python
matplotlib subplots 设置总图的标题方法
May 25 Python
mac下给python3安装requests库和scrapy库的实例
Jun 13 Python
用于业余项目的8个优秀Python库
Sep 21 Python
使用python的pexpect模块,实现远程免密登录的示例
Feb 14 Python
numpy数组之存取文件的实现示例
May 24 Python
给大家整理了19个pythonic的编程习惯(小结)
Sep 25 Python
pip 安装库比较慢的解决方法(国内镜像)
Oct 06 Python
Python使用Chrome插件实现爬虫过程图解
Jun 09 Python
Python字节单位转换(将字节转换为K M G T)
Mar 02 Python
python 中yaml文件用法大全
Jul 04 Python
Python标准库itertools的使用方法
Jan 17 #Python
Python实现投影法分割图像示例(二)
Jan 17 #Python
Python常用库大全及简要说明
Jan 17 #Python
Python Sphinx使用实例及问题解决
Jan 17 #Python
通过实例了解Python str()和repr()的区别
Jan 17 #Python
python无序链表删除重复项的方法
Jan 17 #Python
Python实现投影法分割图像示例(一)
Jan 17 #Python
You might like
安装PHP可能遇到的问题“无法载入mysql扩展” 的解决方法
2007/04/16 PHP
深入理解:XML与对象的序列化与反序列化
2013/06/08 PHP
php分页示例分享
2014/04/30 PHP
thinkphp表单上传文件并将文件路径保存到数据库中
2016/07/28 PHP
PHP实现二叉树深度优先遍历(前序、中序、后序)和广度优先遍历(层次)实例详解
2018/04/20 PHP
jquery+ashx无刷新GridView数据显示插件(实现分页、排序、过滤功能)
2010/04/25 Javascript
解决Extjs上传图片无法预览的解决方法
2012/03/22 Javascript
ExtJs默认的字体大小改变的几种方法(自己整理)
2013/04/18 Javascript
使用js对select动态添加和删除OPTION示例代码
2013/08/12 Javascript
jquery indexOf使用方法
2013/08/19 Javascript
jquery通过扩展select控件实现支持enter或focus选择的方法
2015/11/19 Javascript
jQuery UI Grid 模态框中的表格实例代码
2017/04/01 jQuery
node中koa中间件机制详解
2017/08/22 Javascript
vue打包之后生成一个配置文件修改接口的方法
2018/12/09 Javascript
jQuery实现的别踩白块小游戏完整示例
2019/01/07 jQuery
微信用户访问小程序的登录过程详解
2019/09/20 Javascript
如何通过Proxy实现JSBridge模块化封装
2020/10/22 Javascript
Python实现在线音乐播放器
2017/03/03 Python
Python+matplotlib实现计算两个信号的交叉谱密度实例
2018/01/08 Python
python求最大连续子数组的和
2018/07/07 Python
django中账号密码验证登陆功能的实现方法
2019/07/15 Python
Python3批量移动指定文件到指定文件夹方法示例
2019/09/02 Python
python实现小世界网络生成
2019/11/21 Python
Flask中endpoint的理解(小结)
2019/12/11 Python
django实现模型字段动态choice的操作
2020/04/01 Python
windows10在visual studio2019下配置使用openCV4.3.0
2020/07/14 Python
美国领先的家庭智能音响系统品牌:Sonos
2018/07/20 全球购物
递归计算如下递归函数的值(斐波拉契)
2012/02/04 面试题
公司培训欢迎词
2014/01/10 职场文书
省三好学生申请材料
2014/01/22 职场文书
物流管理专业自荐信
2014/06/23 职场文书
供用电专业求职信
2014/07/07 职场文书
2014年招商工作总结
2014/11/22 职场文书
客房服务员岗位职责
2015/02/09 职场文书
高中同学会致辞
2015/08/01 职场文书
Python语言内置数据类型
2022/02/24 Python