pytorch 状态字典:state_dict使用详解


Posted in Python onJanuary 17, 2020

pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等)

(注意,只有那些参数可以训练的layer才会被保存到模型的state_dict中,如卷积层,线性层等等)

优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数(如lr, momentum,weight_decay等)

备注:

1) state_dict是在定义了model或optimizer之后pytorch自动生成的,可以直接调用.常用的保存state_dict的格式是".pt"或'.pth'的文件,即下面命令的 PATH="./***.pt"

torch.save(model.state_dict(), PATH)

2) load_state_dict 也是model或optimizer之后pytorch自动具备的函数,可以直接调用

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

注意:model.eval() 的重要性,在2)中最后用到了model.eval(),是因为,只有在执行该命令后,"dropout层"及"batch normalization层"才会进入 evalution 模态. 而在"训练(training)模态"与"评估(evalution)模态"下,这两层有不同的表现形式.

模态字典(state_dict)的保存(model是一个网络结构类的对象)

1.1)仅保存学习到的参数,用以下命令

torch.save(model.state_dict(), PATH)

1.2)加载model.state_dict,用以下命令

model = TheModelClass(*args, **kwargs)
 model.load_state_dict(torch.load(PATH))
 model.eval()

备注:model.load_state_dict的操作对象是 一个具体的对象,而不能是文件名

2.1)保存整个model的状态,用以下命令

torch.save(model,PATH)

2.2)加载整个model的状态,用以下命令:

# Model class must be defined somewhere

 model = torch.load(PATH)

 model.eval()

state_dict 是一个python的字典格式,以字典的格式存储,然后以字典的格式被加载,而且只加载key匹配的项

如何仅加载某一层的训练的到的参数(某一层的state)

If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.

conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']

加载模型参数后,如何设置某层某参数的"是否需要训练"(param.requires_grad)

for param in list(model.pretrained.parameters()):
 param.requires_grad = False

注意: requires_grad的操作对象是tensor.

疑问:能否直接对某个层直接之用requires_grad呢?例如:model.conv1.requires_grad=False

回答:经测试,不可以.model.conv1 没有requires_grad属性.

全部测试代码:

#-*-coding:utf-8-*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
 
 
 
# define model
class TheModelClass(nn.Module):
 def __init__(self):
  super(TheModelClass,self).__init__()
  self.conv1 = nn.Conv2d(3,6,5)
  self.pool = nn.MaxPool2d(2,2)
  self.conv2 = nn.Conv2d(6,16,5)
  self.fc1 = nn.Linear(16*5*5,120)
  self.fc2 = nn.Linear(120,84)
  self.fc3 = nn.Linear(84,10)
 
 def forward(self,x):
  x = self.pool(F.relu(self.conv1(x)))
  x = self.pool(F.relu(self.conv2(x)))
  x = x.view(-1,16*5*5)
  x = F.relu(self.fc1(x))
  x = F.relu(self.fc2(x))
  x = self.fc3(x)
  return x
 
# initial model
model = TheModelClass()
 
#initialize the optimizer
optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9)
 
# print the model's state_dict
print("model's state_dict:")
for param_tensor in model.state_dict():
 print(param_tensor,'\t',model.state_dict()[param_tensor].size())
 
print("\noptimizer's state_dict")
for var_name in optimizer.state_dict():
 print(var_name,'\t',optimizer.state_dict()[var_name])
 
print("\nprint particular param")
print('\n',model.conv1.weight.size())
print('\n',model.conv1.weight)
 
print("------------------------------------")
torch.save(model.state_dict(),'./model_state_dict.pt')
# model_2 = TheModelClass()
# model_2.load_state_dict(torch.load('./model_state_dict'))
# model.eval()
# print('\n',model_2.conv1.weight)
# print((model_2.conv1.weight == model.conv1.weight).size())
## 仅仅加载某一层的参数
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
print(conv1_weight_state==model.conv1.weight)
 
model_2 = TheModelClass()
model_2.load_state_dict(torch.load('./model_state_dict.pt'))
model_2.conv1.requires_grad=False
print(model_2.conv1.requires_grad)
print(model_2.conv1.bias.requires_grad)

以上这篇pytorch 状态字典:state_dict使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python局域网ip扫描示例分享
Apr 03 Python
python和shell实现的校验IP地址合法性脚本分享
Oct 23 Python
Python自动化构建工具scons使用入门笔记
Mar 10 Python
Python装饰器实现几类验证功能做法实例
May 18 Python
Flask数据库迁移简单介绍
Oct 24 Python
Python下简易的单例模式详解
Apr 08 Python
详解python解压压缩包的五种方法
Jul 05 Python
用Python配平化学方程式的方法
Jul 20 Python
python处理RSTP视频流过程解析
Jan 11 Python
python实现飞机大战项目
Mar 11 Python
python 数据类型强制转换的总结
Jan 25 Python
Pycharm创建python文件自动添加日期作者等信息(步骤详解)
Feb 03 Python
Python标准库itertools的使用方法
Jan 17 #Python
Python实现投影法分割图像示例(二)
Jan 17 #Python
Python常用库大全及简要说明
Jan 17 #Python
Python Sphinx使用实例及问题解决
Jan 17 #Python
通过实例了解Python str()和repr()的区别
Jan 17 #Python
python无序链表删除重复项的方法
Jan 17 #Python
Python实现投影法分割图像示例(一)
Jan 17 #Python
You might like
php简单开启gzip压缩方法(zlib.output_compression)
2013/04/13 PHP
php创建无限级树型菜单
2015/11/05 PHP
两种php实现图片上传的方法
2016/01/22 PHP
基于jquery的防止大图片撑破页面的实现代码(立即缩放)
2011/10/24 Javascript
解决Jquery load()加载GB2312页面时出现乱码的两种方案
2013/09/10 Javascript
javascript中字符串拼接详解
2014/09/26 Javascript
JS获取时间的方法
2015/01/21 Javascript
JavaScript中split() 使用方法汇总
2015/04/17 Javascript
轻量级的原生js日历插件calendar.js使用指南
2015/04/28 Javascript
全面解析Javascript无限添加QQ好友原理
2016/06/15 Javascript
Angular页面间切换及传值的4种方法
2016/11/04 Javascript
Vue.js 2.0窥探之Virtual DOM到底是什么?
2017/02/10 Javascript
H5上传本地图片并预览功能
2017/05/08 Javascript
jquery.onoff实现简单的开关按钮功能(推荐)
2018/05/24 jQuery
原生javascript自定义input[type=radio]效果示例
2019/08/27 Javascript
文章或博客自动生成章节目录索引(支持三级)的实现代码
2020/05/10 Javascript
jQuery 选择方法及$(this)用法实例分析
2020/05/19 jQuery
Vue结合路由配置递归实现菜单栏功能
2020/06/16 Javascript
Python爬取读者并制作成PDF
2015/03/10 Python
python文件特定行插入和替换实例详解
2017/07/12 Python
利用Python查看目录中的文件示例详解
2017/08/28 Python
浅谈机器学习需要的了解的十大算法
2017/12/15 Python
Python一句代码实现找出所有水仙花数的方法
2018/11/13 Python
Python 中Django安装和使用教程详解
2019/07/03 Python
win10系统下python3安装及pip换源和使用教程
2020/01/06 Python
如何在Win10系统使用Python3连接Hive
2020/10/15 Python
微软开源最强Python自动化神器Playwright(不用写一行代码)
2021/01/05 Python
摩顿布朗英国官方网上商店:奢华沐浴、身体和头发护理
2016/10/29 全球购物
决定成败的关键——创业计划书
2014/01/24 职场文书
初中学习计划书范文
2014/09/15 职场文书
务虚会发言材料
2014/12/25 职场文书
成品仓库管理员岗位职责
2015/04/09 职场文书
公司搬迁通知
2015/04/20 职场文书
公司保洁员管理制度
2015/08/04 职场文书
游戏《我的世界》澄清Xbox版暂无计划加入光追
2022/04/03 其他游戏
清空 Oracle 安装记录并重新安装
2022/04/26 Oracle