pytorch 状态字典:state_dict使用详解


Posted in Python onJanuary 17, 2020

pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等)

(注意,只有那些参数可以训练的layer才会被保存到模型的state_dict中,如卷积层,线性层等等)

优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数(如lr, momentum,weight_decay等)

备注:

1) state_dict是在定义了model或optimizer之后pytorch自动生成的,可以直接调用.常用的保存state_dict的格式是".pt"或'.pth'的文件,即下面命令的 PATH="./***.pt"

torch.save(model.state_dict(), PATH)

2) load_state_dict 也是model或optimizer之后pytorch自动具备的函数,可以直接调用

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

注意:model.eval() 的重要性,在2)中最后用到了model.eval(),是因为,只有在执行该命令后,"dropout层"及"batch normalization层"才会进入 evalution 模态. 而在"训练(training)模态"与"评估(evalution)模态"下,这两层有不同的表现形式.

模态字典(state_dict)的保存(model是一个网络结构类的对象)

1.1)仅保存学习到的参数,用以下命令

torch.save(model.state_dict(), PATH)

1.2)加载model.state_dict,用以下命令

model = TheModelClass(*args, **kwargs)
 model.load_state_dict(torch.load(PATH))
 model.eval()

备注:model.load_state_dict的操作对象是 一个具体的对象,而不能是文件名

2.1)保存整个model的状态,用以下命令

torch.save(model,PATH)

2.2)加载整个model的状态,用以下命令:

# Model class must be defined somewhere

 model = torch.load(PATH)

 model.eval()

state_dict 是一个python的字典格式,以字典的格式存储,然后以字典的格式被加载,而且只加载key匹配的项

如何仅加载某一层的训练的到的参数(某一层的state)

If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.

conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']

加载模型参数后,如何设置某层某参数的"是否需要训练"(param.requires_grad)

for param in list(model.pretrained.parameters()):
 param.requires_grad = False

注意: requires_grad的操作对象是tensor.

疑问:能否直接对某个层直接之用requires_grad呢?例如:model.conv1.requires_grad=False

回答:经测试,不可以.model.conv1 没有requires_grad属性.

全部测试代码:

#-*-coding:utf-8-*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
 
 
 
# define model
class TheModelClass(nn.Module):
 def __init__(self):
  super(TheModelClass,self).__init__()
  self.conv1 = nn.Conv2d(3,6,5)
  self.pool = nn.MaxPool2d(2,2)
  self.conv2 = nn.Conv2d(6,16,5)
  self.fc1 = nn.Linear(16*5*5,120)
  self.fc2 = nn.Linear(120,84)
  self.fc3 = nn.Linear(84,10)
 
 def forward(self,x):
  x = self.pool(F.relu(self.conv1(x)))
  x = self.pool(F.relu(self.conv2(x)))
  x = x.view(-1,16*5*5)
  x = F.relu(self.fc1(x))
  x = F.relu(self.fc2(x))
  x = self.fc3(x)
  return x
 
# initial model
model = TheModelClass()
 
#initialize the optimizer
optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9)
 
# print the model's state_dict
print("model's state_dict:")
for param_tensor in model.state_dict():
 print(param_tensor,'\t',model.state_dict()[param_tensor].size())
 
print("\noptimizer's state_dict")
for var_name in optimizer.state_dict():
 print(var_name,'\t',optimizer.state_dict()[var_name])
 
print("\nprint particular param")
print('\n',model.conv1.weight.size())
print('\n',model.conv1.weight)
 
print("------------------------------------")
torch.save(model.state_dict(),'./model_state_dict.pt')
# model_2 = TheModelClass()
# model_2.load_state_dict(torch.load('./model_state_dict'))
# model.eval()
# print('\n',model_2.conv1.weight)
# print((model_2.conv1.weight == model.conv1.weight).size())
## 仅仅加载某一层的参数
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
print(conv1_weight_state==model.conv1.weight)
 
model_2 = TheModelClass()
model_2.load_state_dict(torch.load('./model_state_dict.pt'))
model_2.conv1.requires_grad=False
print(model_2.conv1.requires_grad)
print(model_2.conv1.bias.requires_grad)

以上这篇pytorch 状态字典:state_dict使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
简单的Python2.7编程初学经验总结
Apr 01 Python
windows下安装Python和pip终极图文教程
Mar 05 Python
PyQt5利用QPainter绘制各种图形的实例
Oct 19 Python
python正则中最短匹配实现代码
Jan 16 Python
python模拟表单提交登录图书馆
Apr 27 Python
Python中创建二维数组
Oct 17 Python
python实现弹窗祝福效果
Apr 07 Python
Django ORM多对多查询方法(自定义第三张表&ManyToManyField)
Aug 09 Python
Python 解码Base64 得到码流格式文本实例
Jan 09 Python
Python对Tornado请求与响应的数据处理
Feb 12 Python
Python QT组件库qtwidgets的使用
Nov 02 Python
在python中实现导入一个需要传参的模块
May 12 Python
Python标准库itertools的使用方法
Jan 17 #Python
Python实现投影法分割图像示例(二)
Jan 17 #Python
Python常用库大全及简要说明
Jan 17 #Python
Python Sphinx使用实例及问题解决
Jan 17 #Python
通过实例了解Python str()和repr()的区别
Jan 17 #Python
python无序链表删除重复项的方法
Jan 17 #Python
Python实现投影法分割图像示例(一)
Jan 17 #Python
You might like
不用数据库的多用户文件自由上传投票系统(1)
2006/10/09 PHP
在PHP3中实现SESSION的功能(二)
2006/10/09 PHP
php adodb连接带密码access数据库实例,测试成功
2008/05/14 PHP
php Notice: Undefined index 错误提示解决方法
2010/08/29 PHP
大家在抢红包,程序员在研究红包算法
2015/08/31 PHP
php注册和登录界面的实现案例(推荐)
2016/10/24 PHP
利用phpexcel对数据库数据的导入excel(excel筛选)、导出excel
2017/04/27 PHP
thinkphp中的多表关联查询的实例详解
2017/10/12 PHP
PhpSpreadsheet设置单元格常用操作汇总
2020/11/13 PHP
input+select(multiple) 实现下拉框输入值
2009/05/21 Javascript
基于jquery实现的服务器验证控件的启用和禁用代码
2010/04/27 Javascript
php实例分享之实现显示网站运行时间
2014/05/20 Javascript
JavaScript简单表格编辑功能实现方法
2015/04/16 Javascript
javascript生成不重复的随机数
2015/07/17 Javascript
javascript仿百度输入框提示自动下拉补全
2016/01/07 Javascript
Bootstrap3 模态框使用实例
2017/02/22 Javascript
jQuery实现拖动效果的实例代码
2017/06/25 jQuery
vue使用drag与drop实现拖拽的示例代码
2017/09/07 Javascript
深入浅析js原型链和vue构造函数
2018/10/25 Javascript
使用vue实现各类弹出框组件
2019/07/03 Javascript
Node对CommonJS的模块规范
2019/11/06 Javascript
VUE 项目在IE11白屏报错 SCRIPT1002: 语法错误的解决
2020/09/27 Javascript
深入解析Python中的上下文管理器
2016/06/28 Python
Python使用Matplotlib实现Logos设计代码
2017/12/25 Python
你还在@微信官方?聊聊Python生成你想要的微信头像
2019/09/25 Python
解决django接口无法通过ip进行访问的问题
2020/03/27 Python
基于plt.title无法显示中文的快速解决
2020/05/16 Python
python爬虫用request库处理cookie的实例讲解
2021/02/20 Python
html5 自定义播放器核心代码
2013/12/20 HTML / CSS
美国最受欢迎的度假目的地优惠套餐:BookVIP
2018/09/27 全球购物
aden + anais英国官网:美国婴儿贴身用品品牌
2019/09/08 全球购物
Crocs欧洲官网:Crocs Europe
2020/01/14 全球购物
骨干教师培训制度
2014/01/13 职场文书
校本教研工作方案
2014/01/14 职场文书
优秀士兵先进事迹
2014/02/06 职场文书
团员年度个人总结
2015/02/26 职场文书