pytorch 状态字典:state_dict使用详解


Posted in Python onJanuary 17, 2020

pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等)

(注意,只有那些参数可以训练的layer才会被保存到模型的state_dict中,如卷积层,线性层等等)

优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数(如lr, momentum,weight_decay等)

备注:

1) state_dict是在定义了model或optimizer之后pytorch自动生成的,可以直接调用.常用的保存state_dict的格式是".pt"或'.pth'的文件,即下面命令的 PATH="./***.pt"

torch.save(model.state_dict(), PATH)

2) load_state_dict 也是model或optimizer之后pytorch自动具备的函数,可以直接调用

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

注意:model.eval() 的重要性,在2)中最后用到了model.eval(),是因为,只有在执行该命令后,"dropout层"及"batch normalization层"才会进入 evalution 模态. 而在"训练(training)模态"与"评估(evalution)模态"下,这两层有不同的表现形式.

模态字典(state_dict)的保存(model是一个网络结构类的对象)

1.1)仅保存学习到的参数,用以下命令

torch.save(model.state_dict(), PATH)

1.2)加载model.state_dict,用以下命令

model = TheModelClass(*args, **kwargs)
 model.load_state_dict(torch.load(PATH))
 model.eval()

备注:model.load_state_dict的操作对象是 一个具体的对象,而不能是文件名

2.1)保存整个model的状态,用以下命令

torch.save(model,PATH)

2.2)加载整个model的状态,用以下命令:

# Model class must be defined somewhere

 model = torch.load(PATH)

 model.eval()

state_dict 是一个python的字典格式,以字典的格式存储,然后以字典的格式被加载,而且只加载key匹配的项

如何仅加载某一层的训练的到的参数(某一层的state)

If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.

conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']

加载模型参数后,如何设置某层某参数的"是否需要训练"(param.requires_grad)

for param in list(model.pretrained.parameters()):
 param.requires_grad = False

注意: requires_grad的操作对象是tensor.

疑问:能否直接对某个层直接之用requires_grad呢?例如:model.conv1.requires_grad=False

回答:经测试,不可以.model.conv1 没有requires_grad属性.

全部测试代码:

#-*-coding:utf-8-*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
 
 
 
# define model
class TheModelClass(nn.Module):
 def __init__(self):
  super(TheModelClass,self).__init__()
  self.conv1 = nn.Conv2d(3,6,5)
  self.pool = nn.MaxPool2d(2,2)
  self.conv2 = nn.Conv2d(6,16,5)
  self.fc1 = nn.Linear(16*5*5,120)
  self.fc2 = nn.Linear(120,84)
  self.fc3 = nn.Linear(84,10)
 
 def forward(self,x):
  x = self.pool(F.relu(self.conv1(x)))
  x = self.pool(F.relu(self.conv2(x)))
  x = x.view(-1,16*5*5)
  x = F.relu(self.fc1(x))
  x = F.relu(self.fc2(x))
  x = self.fc3(x)
  return x
 
# initial model
model = TheModelClass()
 
#initialize the optimizer
optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9)
 
# print the model's state_dict
print("model's state_dict:")
for param_tensor in model.state_dict():
 print(param_tensor,'\t',model.state_dict()[param_tensor].size())
 
print("\noptimizer's state_dict")
for var_name in optimizer.state_dict():
 print(var_name,'\t',optimizer.state_dict()[var_name])
 
print("\nprint particular param")
print('\n',model.conv1.weight.size())
print('\n',model.conv1.weight)
 
print("------------------------------------")
torch.save(model.state_dict(),'./model_state_dict.pt')
# model_2 = TheModelClass()
# model_2.load_state_dict(torch.load('./model_state_dict'))
# model.eval()
# print('\n',model_2.conv1.weight)
# print((model_2.conv1.weight == model.conv1.weight).size())
## 仅仅加载某一层的参数
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
print(conv1_weight_state==model.conv1.weight)
 
model_2 = TheModelClass()
model_2.load_state_dict(torch.load('./model_state_dict.pt'))
model_2.conv1.requires_grad=False
print(model_2.conv1.requires_grad)
print(model_2.conv1.bias.requires_grad)

以上这篇pytorch 状态字典:state_dict使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用Python中的wxPython实现最基本的浏览器功能
Apr 14 Python
使用Python的Twisted框架编写简单的网络客户端
Apr 16 Python
Python二分查找详解
Sep 13 Python
python中文分词,使用结巴分词对python进行分词(实例讲解)
Nov 14 Python
python3利用smtplib通过qq邮箱发送邮件方法示例
Dec 03 Python
在Pycharm中项目解释器与环境变量的设置方法
Oct 29 Python
Python去除字符串前后空格的几种方法
Mar 04 Python
python3 tcp的粘包现象和解决办法解析
Dec 09 Python
基于梯度爆炸的解决方法:clip gradient
Feb 04 Python
Django 设置多环境配置文件载入问题
Feb 25 Python
利用matplotlib为图片上添加触发事件进行交互
Apr 23 Python
python对输出的奇数偶数排序实例代码
Dec 04 Python
Python标准库itertools的使用方法
Jan 17 #Python
Python实现投影法分割图像示例(二)
Jan 17 #Python
Python常用库大全及简要说明
Jan 17 #Python
Python Sphinx使用实例及问题解决
Jan 17 #Python
通过实例了解Python str()和repr()的区别
Jan 17 #Python
python无序链表删除重复项的方法
Jan 17 #Python
Python实现投影法分割图像示例(一)
Jan 17 #Python
You might like
一个改进的UBB类
2006/10/09 PHP
兼容firefox,chrome的网页灰度效果
2011/08/08 PHP
yii2框架中使用下拉菜单的自动搜索yii-widget-select2实例分析
2016/01/09 PHP
jQuery 表单验证扩展代码(二)
2010/10/20 Javascript
30个最佳jQuery Lightbox效果插件分享
2011/04/11 Javascript
kindeditor修复会替换script内容的问题
2015/04/03 Javascript
jQuery实现的Div窗口震动效果实例
2015/08/07 Javascript
基于jquery实现鼠标左右拖动滑块滑动附源码下载
2015/12/23 Javascript
JavaScript基于扩展String实现替换字符串中index处字符的方法
2017/06/13 Javascript
JavaScript实现兼容IE6的收起折叠与展开效果实例
2017/09/20 Javascript
layer.open 按钮的点击事件关闭方法
2018/08/17 Javascript
vue自定义键盘信息、监听数据变化的方法示例【基于vm.$watch】
2019/03/16 Javascript
微信小程序API—获取定位的详解
2019/04/30 Javascript
在Express中提供静态文件的实现方法
2019/10/17 Javascript
Layui表格监听行单双击事件讲解
2019/11/14 Javascript
小程序实现图片预览裁剪插件
2019/11/22 Javascript
javascript设计模式之迭代器模式
2020/01/30 Javascript
Nodejs环境实现socket通信过程解析
2020/07/03 NodeJs
[02:11]完美世界DOTA2联赛10月28日赛事精彩集锦:来吧展示实力强劲
2020/10/29 DOTA
Python程序中使用SQLAlchemy时出现乱码的解决方案
2015/04/24 Python
总结Python编程中三条常用的技巧
2015/05/11 Python
使用tensorflow实现线性svm
2018/09/07 Python
浅谈python新式类和旧式类区别
2019/04/26 Python
python打印文件的前几行或最后几行教程
2020/02/13 Python
Python制作一个仿QQ办公版的图形登录界面
2020/09/22 Python
amazeui页面校验功能的实现代码
2020/08/24 HTML / CSS
伯利陶器:Burleigh Pottery
2018/01/03 全球购物
丝绸和人造花卉、植物和树木:Nearly Natural
2018/11/28 全球购物
美国家用和厨房电器销售网站:Appliances Connection
2020/01/24 全球购物
Loreto Gallo英国:欧洲领先的在线药房
2021/01/21 全球购物
应届生.NET方向面试题
2015/05/23 面试题
违反单位工作制度检讨书
2014/10/25 职场文书
班主任先进事迹材料
2014/12/17 职场文书
公司会议开幕词
2016/03/03 职场文书
SpringBoot整合RabbitMQ的5种模式实战
2021/08/02 Java/Android
用Python可视化新冠疫情数据
2022/01/18 Python