pytorch 状态字典:state_dict使用详解


Posted in Python onJanuary 17, 2020

pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等)

(注意,只有那些参数可以训练的layer才会被保存到模型的state_dict中,如卷积层,线性层等等)

优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数(如lr, momentum,weight_decay等)

备注:

1) state_dict是在定义了model或optimizer之后pytorch自动生成的,可以直接调用.常用的保存state_dict的格式是".pt"或'.pth'的文件,即下面命令的 PATH="./***.pt"

torch.save(model.state_dict(), PATH)

2) load_state_dict 也是model或optimizer之后pytorch自动具备的函数,可以直接调用

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

注意:model.eval() 的重要性,在2)中最后用到了model.eval(),是因为,只有在执行该命令后,"dropout层"及"batch normalization层"才会进入 evalution 模态. 而在"训练(training)模态"与"评估(evalution)模态"下,这两层有不同的表现形式.

模态字典(state_dict)的保存(model是一个网络结构类的对象)

1.1)仅保存学习到的参数,用以下命令

torch.save(model.state_dict(), PATH)

1.2)加载model.state_dict,用以下命令

model = TheModelClass(*args, **kwargs)
 model.load_state_dict(torch.load(PATH))
 model.eval()

备注:model.load_state_dict的操作对象是 一个具体的对象,而不能是文件名

2.1)保存整个model的状态,用以下命令

torch.save(model,PATH)

2.2)加载整个model的状态,用以下命令:

# Model class must be defined somewhere

 model = torch.load(PATH)

 model.eval()

state_dict 是一个python的字典格式,以字典的格式存储,然后以字典的格式被加载,而且只加载key匹配的项

如何仅加载某一层的训练的到的参数(某一层的state)

If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.

conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']

加载模型参数后,如何设置某层某参数的"是否需要训练"(param.requires_grad)

for param in list(model.pretrained.parameters()):
 param.requires_grad = False

注意: requires_grad的操作对象是tensor.

疑问:能否直接对某个层直接之用requires_grad呢?例如:model.conv1.requires_grad=False

回答:经测试,不可以.model.conv1 没有requires_grad属性.

全部测试代码:

#-*-coding:utf-8-*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
 
 
 
# define model
class TheModelClass(nn.Module):
 def __init__(self):
  super(TheModelClass,self).__init__()
  self.conv1 = nn.Conv2d(3,6,5)
  self.pool = nn.MaxPool2d(2,2)
  self.conv2 = nn.Conv2d(6,16,5)
  self.fc1 = nn.Linear(16*5*5,120)
  self.fc2 = nn.Linear(120,84)
  self.fc3 = nn.Linear(84,10)
 
 def forward(self,x):
  x = self.pool(F.relu(self.conv1(x)))
  x = self.pool(F.relu(self.conv2(x)))
  x = x.view(-1,16*5*5)
  x = F.relu(self.fc1(x))
  x = F.relu(self.fc2(x))
  x = self.fc3(x)
  return x
 
# initial model
model = TheModelClass()
 
#initialize the optimizer
optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9)
 
# print the model's state_dict
print("model's state_dict:")
for param_tensor in model.state_dict():
 print(param_tensor,'\t',model.state_dict()[param_tensor].size())
 
print("\noptimizer's state_dict")
for var_name in optimizer.state_dict():
 print(var_name,'\t',optimizer.state_dict()[var_name])
 
print("\nprint particular param")
print('\n',model.conv1.weight.size())
print('\n',model.conv1.weight)
 
print("------------------------------------")
torch.save(model.state_dict(),'./model_state_dict.pt')
# model_2 = TheModelClass()
# model_2.load_state_dict(torch.load('./model_state_dict'))
# model.eval()
# print('\n',model_2.conv1.weight)
# print((model_2.conv1.weight == model.conv1.weight).size())
## 仅仅加载某一层的参数
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
print(conv1_weight_state==model.conv1.weight)
 
model_2 = TheModelClass()
model_2.load_state_dict(torch.load('./model_state_dict.pt'))
model_2.conv1.requires_grad=False
print(model_2.conv1.requires_grad)
print(model_2.conv1.bias.requires_grad)

以上这篇pytorch 状态字典:state_dict使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 正则表达式 概述及常用字符
May 04 Python
python通过imaplib模块读取gmail里邮件的方法
May 08 Python
Django添加KindEditor富文本编辑器的使用
Oct 24 Python
pygame游戏之旅 添加icon和bgm音效的方法
Nov 21 Python
在python中使用requests 模拟浏览器发送请求数据的方法
Dec 26 Python
python生成带有表格的图片实例
Feb 03 Python
Python实例方法、类方法、静态方法的区别与作用详解
Mar 25 Python
详解用Python实现自动化监控远程服务器
May 18 Python
python算法与数据结构之冒泡排序实例详解
Jun 22 Python
解析python 类方法、对象方法、静态方法
Aug 15 Python
python - asyncio异步编程
Apr 06 Python
Python实现学生管理系统(面向对象版)
Jun 24 Python
Python标准库itertools的使用方法
Jan 17 #Python
Python实现投影法分割图像示例(二)
Jan 17 #Python
Python常用库大全及简要说明
Jan 17 #Python
Python Sphinx使用实例及问题解决
Jan 17 #Python
通过实例了解Python str()和repr()的区别
Jan 17 #Python
python无序链表删除重复项的方法
Jan 17 #Python
Python实现投影法分割图像示例(一)
Jan 17 #Python
You might like
php下几个常用的去空、分组、调试数组函数
2009/02/22 PHP
eAccelerator的安装与使用详解
2013/06/13 PHP
Thinkphp中数据按分类嵌套循环实现方法
2014/10/30 PHP
JS 无法通过W3C验证的处理方法
2010/03/09 Javascript
javascript 基础篇2 数据类型,语句,函数
2012/03/14 Javascript
nodejs URL模块操作URL相关方法介绍
2015/03/03 NodeJs
js上传图片及预览功能实例分析
2015/04/24 Javascript
java中String类型变量的赋值问题介绍
2016/03/23 Javascript
javascript css红色经典选项卡效果实现代码
2016/05/17 Javascript
第一次动手实现bootstrap table分页效果
2016/09/22 Javascript
Node.js用readline模块实现输入输出
2016/12/16 Javascript
激动人心的 Angular HttpClient的源码解析
2017/07/10 Javascript
vue.js父子组件通信动态绑定的实例
2018/09/28 Javascript
vue router动态路由设置参数可选问题
2019/08/21 Javascript
微信小程序swiper组件实现抖音翻页切换视频功能的实例代码
2020/06/24 Javascript
微信小程序连接服务器展示MQTT数据信息的实现
2020/07/14 Javascript
vue 判断元素内容是否超过宽度的方式
2020/07/29 Javascript
[02:24]DOTA2痛苦女王 英雄基础教程
2013/11/26 DOTA
Python的shutil模块中文件的复制操作函数详解
2016/07/05 Python
python 处理dataframe中的时间字段方法
2018/04/10 Python
pycharm执行python时,填写参数的方法
2018/10/29 Python
python判断列表的连续数字范围并分块的方法
2018/11/16 Python
Python通过TensorFlow卷积神经网络实现猫狗识别
2019/03/14 Python
python操作kafka实践的示例代码
2019/06/19 Python
Python实现自动打开电脑应用的示例代码
2020/04/17 Python
tensorflow 动态获取 BatchSzie 的大小实例
2020/06/30 Python
python 写一个性能测试工具(一)
2020/10/24 Python
Python使用cn2an实现中文数字与阿拉伯数字的相互转换
2021/03/02 Python
摩托车和ATV零件、配件和服装的首选在线零售商:MotoSport
2017/12/22 全球购物
自立自强的名人事例
2014/02/10 职场文书
《难忘的泼水节》教学反思
2014/02/27 职场文书
贪污检举信范文
2015/03/02 职场文书
2015年高一班主任工作总结
2015/05/13 职场文书
护士业务学习心得体会
2016/01/25 职场文书
先进个人事迹材料(2016推荐版)
2016/03/01 职场文书
MySQL pt-slave-restart工具的使用简介
2021/04/07 MySQL