pytorch 状态字典:state_dict使用详解


Posted in Python onJanuary 17, 2020

pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等)

(注意,只有那些参数可以训练的layer才会被保存到模型的state_dict中,如卷积层,线性层等等)

优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数(如lr, momentum,weight_decay等)

备注:

1) state_dict是在定义了model或optimizer之后pytorch自动生成的,可以直接调用.常用的保存state_dict的格式是".pt"或'.pth'的文件,即下面命令的 PATH="./***.pt"

torch.save(model.state_dict(), PATH)

2) load_state_dict 也是model或optimizer之后pytorch自动具备的函数,可以直接调用

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

注意:model.eval() 的重要性,在2)中最后用到了model.eval(),是因为,只有在执行该命令后,"dropout层"及"batch normalization层"才会进入 evalution 模态. 而在"训练(training)模态"与"评估(evalution)模态"下,这两层有不同的表现形式.

模态字典(state_dict)的保存(model是一个网络结构类的对象)

1.1)仅保存学习到的参数,用以下命令

torch.save(model.state_dict(), PATH)

1.2)加载model.state_dict,用以下命令

model = TheModelClass(*args, **kwargs)
 model.load_state_dict(torch.load(PATH))
 model.eval()

备注:model.load_state_dict的操作对象是 一个具体的对象,而不能是文件名

2.1)保存整个model的状态,用以下命令

torch.save(model,PATH)

2.2)加载整个model的状态,用以下命令:

# Model class must be defined somewhere

 model = torch.load(PATH)

 model.eval()

state_dict 是一个python的字典格式,以字典的格式存储,然后以字典的格式被加载,而且只加载key匹配的项

如何仅加载某一层的训练的到的参数(某一层的state)

If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.

conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']

加载模型参数后,如何设置某层某参数的"是否需要训练"(param.requires_grad)

for param in list(model.pretrained.parameters()):
 param.requires_grad = False

注意: requires_grad的操作对象是tensor.

疑问:能否直接对某个层直接之用requires_grad呢?例如:model.conv1.requires_grad=False

回答:经测试,不可以.model.conv1 没有requires_grad属性.

全部测试代码:

#-*-coding:utf-8-*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
 
 
 
# define model
class TheModelClass(nn.Module):
 def __init__(self):
  super(TheModelClass,self).__init__()
  self.conv1 = nn.Conv2d(3,6,5)
  self.pool = nn.MaxPool2d(2,2)
  self.conv2 = nn.Conv2d(6,16,5)
  self.fc1 = nn.Linear(16*5*5,120)
  self.fc2 = nn.Linear(120,84)
  self.fc3 = nn.Linear(84,10)
 
 def forward(self,x):
  x = self.pool(F.relu(self.conv1(x)))
  x = self.pool(F.relu(self.conv2(x)))
  x = x.view(-1,16*5*5)
  x = F.relu(self.fc1(x))
  x = F.relu(self.fc2(x))
  x = self.fc3(x)
  return x
 
# initial model
model = TheModelClass()
 
#initialize the optimizer
optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9)
 
# print the model's state_dict
print("model's state_dict:")
for param_tensor in model.state_dict():
 print(param_tensor,'\t',model.state_dict()[param_tensor].size())
 
print("\noptimizer's state_dict")
for var_name in optimizer.state_dict():
 print(var_name,'\t',optimizer.state_dict()[var_name])
 
print("\nprint particular param")
print('\n',model.conv1.weight.size())
print('\n',model.conv1.weight)
 
print("------------------------------------")
torch.save(model.state_dict(),'./model_state_dict.pt')
# model_2 = TheModelClass()
# model_2.load_state_dict(torch.load('./model_state_dict'))
# model.eval()
# print('\n',model_2.conv1.weight)
# print((model_2.conv1.weight == model.conv1.weight).size())
## 仅仅加载某一层的参数
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
print(conv1_weight_state==model.conv1.weight)
 
model_2 = TheModelClass()
model_2.load_state_dict(torch.load('./model_state_dict.pt'))
model_2.conv1.requires_grad=False
print(model_2.conv1.requires_grad)
print(model_2.conv1.bias.requires_grad)

以上这篇pytorch 状态字典:state_dict使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python列表(list)、字典(dict)、字符串(string)基本操作小结
Nov 28 Python
python使用Queue在多个子进程间交换数据的方法
Apr 18 Python
Python的Django应用程序解决AJAX跨域访问问题的方法
May 31 Python
python中urllib.unquote乱码的原因与解决方法
Apr 24 Python
python使用matplotlib绘图时图例显示问题的解决
Apr 27 Python
python matplotlib绘图,修改坐标轴刻度为文字的实例
May 25 Python
python导入模块交叉引用的方法
Jan 19 Python
Python range与enumerate函数区别解析
Feb 28 Python
jupyter note 实现将数据保存为word
Apr 14 Python
用Python实现定时备份Mongodb数据并上传到FTP服务器
Jan 27 Python
python中time.ctime()实例用法
Feb 03 Python
matplotlib 范围选区(SpanSelector)的使用
Feb 24 Python
Python标准库itertools的使用方法
Jan 17 #Python
Python实现投影法分割图像示例(二)
Jan 17 #Python
Python常用库大全及简要说明
Jan 17 #Python
Python Sphinx使用实例及问题解决
Jan 17 #Python
通过实例了解Python str()和repr()的区别
Jan 17 #Python
python无序链表删除重复项的方法
Jan 17 #Python
Python实现投影法分割图像示例(一)
Jan 17 #Python
You might like
2019年漫画销量排行榜:鬼灭登顶 海贼单卷制霸 尾田盛赞鬼灭
2020/03/08 日漫
php 取得瑞年与平年的天数的代码
2009/08/10 PHP
Zend Framework入门教程之Zend_Config组件用法详解
2016/12/09 PHP
Javascript 入门基础学习
2010/03/10 Javascript
JavaScript动态调整TextArea高度的代码
2010/12/28 Javascript
各情景下元素宽高的获取实现代码
2011/09/13 Javascript
JS批量修改PS中图层名称的方法
2014/01/26 Javascript
js中文逗号转英文实现
2014/02/11 Javascript
jquery中常用的函数和属性详细解析
2014/03/07 Javascript
jQuery产品间断向下滚动效果核心代码
2014/05/08 Javascript
arguments对象验证函数的参数是否合法
2015/06/26 Javascript
详解JavaScript时间格式化
2015/12/23 Javascript
浅析JSONP技术原理及实现
2016/06/08 Javascript
Bootstrap 3.x打印预览背景色与文字显示异常的解决
2016/11/06 Javascript
javascript中活灵活现的Array对象详解
2016/11/30 Javascript
详解javascript表单的Ajax提交插件的使用
2016/12/29 Javascript
深入理解JavaScript 参数按值传递
2017/05/24 Javascript
node.js读取Excel数据(下载图片)的方法示例
2018/08/02 Javascript
教你如何将 Sublime 3 打造成 Python/Django IDE开发利器
2014/07/04 Python
Python错误提示:[Errno 24] Too many open files的分析与解决
2017/02/16 Python
Python中.py文件打包成exe可执行文件详解
2017/03/22 Python
Python3.6通过自带的urllib通过get或post方法请求url的实例
2018/05/10 Python
用Python shell简化开发
2018/08/08 Python
Python实现的爬取百度贴吧图片功能完整示例
2019/05/10 Python
python实现静态服务器
2019/09/05 Python
Python3+RIDE+RobotFramework自动化测试框架搭建过程详解
2020/09/23 Python
python3 os进行嵌套操作的实例讲解
2020/11/19 Python
css3中的calc函数浅析
2018/07/10 HTML / CSS
美国大城市最热门旅游景点门票:CityPASS
2016/12/16 全球购物
世界顶级足球门票网站:Live Football Tickets
2017/10/14 全球购物
学期评语大全
2014/04/30 职场文书
学校节能宣传周活动总结
2014/07/09 职场文书
水电工岗位职责
2015/02/14 职场文书
创业项目(超低成本创业项目)
2019/08/16 职场文书
关于办理居住证的介绍信模板
2019/11/27 职场文书
Python&Matlab实现灰狼优化算法的示例代码
2022/03/21 Python