pytorch 状态字典:state_dict使用详解


Posted in Python onJanuary 17, 2020

pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等)

(注意,只有那些参数可以训练的layer才会被保存到模型的state_dict中,如卷积层,线性层等等)

优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数(如lr, momentum,weight_decay等)

备注:

1) state_dict是在定义了model或optimizer之后pytorch自动生成的,可以直接调用.常用的保存state_dict的格式是".pt"或'.pth'的文件,即下面命令的 PATH="./***.pt"

torch.save(model.state_dict(), PATH)

2) load_state_dict 也是model或optimizer之后pytorch自动具备的函数,可以直接调用

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

注意:model.eval() 的重要性,在2)中最后用到了model.eval(),是因为,只有在执行该命令后,"dropout层"及"batch normalization层"才会进入 evalution 模态. 而在"训练(training)模态"与"评估(evalution)模态"下,这两层有不同的表现形式.

模态字典(state_dict)的保存(model是一个网络结构类的对象)

1.1)仅保存学习到的参数,用以下命令

torch.save(model.state_dict(), PATH)

1.2)加载model.state_dict,用以下命令

model = TheModelClass(*args, **kwargs)
 model.load_state_dict(torch.load(PATH))
 model.eval()

备注:model.load_state_dict的操作对象是 一个具体的对象,而不能是文件名

2.1)保存整个model的状态,用以下命令

torch.save(model,PATH)

2.2)加载整个model的状态,用以下命令:

# Model class must be defined somewhere

 model = torch.load(PATH)

 model.eval()

state_dict 是一个python的字典格式,以字典的格式存储,然后以字典的格式被加载,而且只加载key匹配的项

如何仅加载某一层的训练的到的参数(某一层的state)

If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.

conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']

加载模型参数后,如何设置某层某参数的"是否需要训练"(param.requires_grad)

for param in list(model.pretrained.parameters()):
 param.requires_grad = False

注意: requires_grad的操作对象是tensor.

疑问:能否直接对某个层直接之用requires_grad呢?例如:model.conv1.requires_grad=False

回答:经测试,不可以.model.conv1 没有requires_grad属性.

全部测试代码:

#-*-coding:utf-8-*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
 
 
 
# define model
class TheModelClass(nn.Module):
 def __init__(self):
  super(TheModelClass,self).__init__()
  self.conv1 = nn.Conv2d(3,6,5)
  self.pool = nn.MaxPool2d(2,2)
  self.conv2 = nn.Conv2d(6,16,5)
  self.fc1 = nn.Linear(16*5*5,120)
  self.fc2 = nn.Linear(120,84)
  self.fc3 = nn.Linear(84,10)
 
 def forward(self,x):
  x = self.pool(F.relu(self.conv1(x)))
  x = self.pool(F.relu(self.conv2(x)))
  x = x.view(-1,16*5*5)
  x = F.relu(self.fc1(x))
  x = F.relu(self.fc2(x))
  x = self.fc3(x)
  return x
 
# initial model
model = TheModelClass()
 
#initialize the optimizer
optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9)
 
# print the model's state_dict
print("model's state_dict:")
for param_tensor in model.state_dict():
 print(param_tensor,'\t',model.state_dict()[param_tensor].size())
 
print("\noptimizer's state_dict")
for var_name in optimizer.state_dict():
 print(var_name,'\t',optimizer.state_dict()[var_name])
 
print("\nprint particular param")
print('\n',model.conv1.weight.size())
print('\n',model.conv1.weight)
 
print("------------------------------------")
torch.save(model.state_dict(),'./model_state_dict.pt')
# model_2 = TheModelClass()
# model_2.load_state_dict(torch.load('./model_state_dict'))
# model.eval()
# print('\n',model_2.conv1.weight)
# print((model_2.conv1.weight == model.conv1.weight).size())
## 仅仅加载某一层的参数
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
print(conv1_weight_state==model.conv1.weight)
 
model_2 = TheModelClass()
model_2.load_state_dict(torch.load('./model_state_dict.pt'))
model_2.conv1.requires_grad=False
print(model_2.conv1.requires_grad)
print(model_2.conv1.bias.requires_grad)

以上这篇pytorch 状态字典:state_dict使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python深入学习之对象的属性
Aug 31 Python
Python时间的精准正则匹配方法分析
Aug 17 Python
Django中针对基于类的视图添加csrf_exempt实例代码
Feb 11 Python
python实现数独游戏 java简单实现数独游戏
Mar 30 Python
python字符串string的内置方法实例详解
May 14 Python
Python将列表数据写入文件(txt, csv,excel)
Apr 03 Python
在Python中COM口的调用方法
Jul 03 Python
在SQLite-Python中实现返回、查询中文字段的方法
Jul 17 Python
Python 使用matplotlib模块模拟掷骰子
Aug 08 Python
Python如何定义有可选参数的元类
Jul 31 Python
Python+pyftpdlib实现局域网文件互传
Aug 24 Python
如何使用python-opencv批量生成带噪点噪线的数字验证码
Dec 21 Python
Python标准库itertools的使用方法
Jan 17 #Python
Python实现投影法分割图像示例(二)
Jan 17 #Python
Python常用库大全及简要说明
Jan 17 #Python
Python Sphinx使用实例及问题解决
Jan 17 #Python
通过实例了解Python str()和repr()的区别
Jan 17 #Python
python无序链表删除重复项的方法
Jan 17 #Python
Python实现投影法分割图像示例(一)
Jan 17 #Python
You might like
火影忍者:这才是千手柱间和扉间的真正死因,角都就比较搞笑了!
2020/03/10 日漫
DOTA2 1月28日更新:监管系统降临刀塔世界
2021/01/28 DOTA
PHP中date与gmdate的区别及默认时区设置
2014/05/12 PHP
php实例分享之html转为rtf格式
2014/06/02 PHP
php判断邮箱地址是否存在的方法
2016/02/13 PHP
jquery中eq和get的区别与使用方法
2011/04/14 Javascript
jquery 选项卡效果 新手代码
2011/07/08 Javascript
理解JAVASCRIPT中hasOwnProperty()的作用
2013/06/05 Javascript
JavaScript中的字符串操作详解
2013/11/12 Javascript
jquery解析xml字符串简单示例
2014/04/11 Javascript
使用upstart把nodejs应用封装为系统服务实例
2014/06/01 NodeJs
15个值得开发人员关注的jQuery开发技巧和心得总结【经典收藏】
2016/05/25 Javascript
关于Jquery中的事件绑定总结
2016/10/26 Javascript
vue.js之vue-cli脚手架的搭建详解
2017/05/05 Javascript
mui 打开新窗口的方式总结及注意事项
2017/08/20 Javascript
nodejs socket服务端和客户端简单通信功能
2017/09/14 NodeJs
Vuex中mutations与actions的区别详解
2018/03/01 Javascript
Node.js Buffer模块功能及常用方法实例分析
2019/01/05 Javascript
Angular8基础应用之表单及其验证
2019/08/11 Javascript
react-router-dom 嵌套路由的实现
2020/05/02 Javascript
vue自定义标签和单页面多路由的实现代码
2020/05/03 Javascript
[02:49]2018DOTA2亚洲邀请赛主赛事决赛日战况回顾 Mineski鏖战5局夺得辉耀
2018/04/10 DOTA
[20:39]DOTA2-DPC中国联赛 正赛开幕式 1月18日
2021/03/11 DOTA
python处理html转义字符的方法详解
2016/07/01 Python
window下eclipse安装python插件教程
2017/04/24 Python
python使用两种发邮件的方式smtp和outlook示例
2017/06/02 Python
python实现浪漫的烟花秀
2019/01/30 Python
最新远光软件笔试题面试题内容
2013/11/08 面试题
我有一个梦想演讲稿
2014/05/05 职场文书
四查四看自我剖析材料
2014/09/19 职场文书
乡镇机关党员民主评议表自我评价
2014/09/21 职场文书
检讨书1000字
2014/10/11 职场文书
年会邀请函范文
2015/01/30 职场文书
大学生村官入党自传
2015/06/26 职场文书
pytorch实现手写数字图片识别
2021/05/20 Python
Python读取和写入Excel数据
2022/04/20 Python