Pytorch中实现只导入部分模型参数的方式


Posted in Python onJanuary 02, 2020

我们在做迁移学习,或者在分割,检测等任务想使用预训练好的模型,同时又有自己修改之后的结构,使得模型文件保存的参数,有一部分是不需要的(don't expected)。我们搭建的网络对保存文件来说,有一部分参数也是没有的(missed)。如果依旧使用torch.load(model.state_dict())的办法,就会出现 xxx expected,xxx missed类似的错误。那么在这种情况下,该如何导入模型呢?

好在Pytorch中的模型参数使用字典保存的,键是参数的名称,值是参数的具体数值。我们使用model.state_dict()获得这个字典,之后就能利用参数名称来实现导入。

请看下面的一个例子。

我们先搭建一个小小的网络。

import torch as t
from torch.nn import Module
from torch import nn
from torch.nn import functional as F
class Net(Module):
  def __init__(self):
    super(Net,self).__init__()
    self.conv1 = nn.Conv2d(3,32,3,1)
    self.conv2 = nn.Conv2d(32,3,3,1)
    self.w = nn.Parameter(t.randn(3,10))
    for p in self.children():
      nn.init.xavier_normal_(p.weight.data)
      nn.init.constant_(p.bias.data, 0)
  def forward(self, x):
    out = self.conv1(x)
    out = self.conv2(x)
 
    out = F.avg_pool2d(out,(out.shape[2],out.shape[3]))
    out = F.linear(out,weight=self.w)
    return out

然后我们保存这个网络的初始值。

model = Net()
t.save(model.state_dict(),'xxx.pth')

现在我们将Net修改一下,多加几个卷积层,但并不加入到forward中,仅仅出于少些几行的目的。

import torch as t
from torch.nn import Module
from torch import nn
from torch.nn import functional as F
 
 
class Net(Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(3, 32, 3, 1)
    self.conv2 = nn.Conv2d(32, 3, 3, 1)
    self.conv3 = nn.Conv2d(3,64,3,1)
    self.conv4 = nn.Conv2d(64,32,3,1)
    for p in self.children():
      nn.init.xavier_normal_(p.weight.data)
      nn.init.constant_(p.bias.data, 0)
 
    self.w = nn.Parameter(t.randn(3, 10))
  def forward(self, x):
    out = self.conv1(x)
    out = self.conv2(x)
 
    out = F.avg_pool2d(out, (out.shape[2], out.shape[3]))
    out = F.linear(out, weight=self.w)
    return out

我们现在试着导入之前保存的模型参数。

path = 'xxx.pth'
model = Net()
model.load_state_dict(t.load(path))
 
'''
RuntimeError: Error(s) in loading state_dict for Net:
 Missing key(s) in state_dict: "conv3.weight", "conv3.bias", "conv4.weight", "conv4.bias". 
'''

出现了没有在模型文件中找到error中的关键字的错误。

现在我们这样导入模型

path = 'xxx.pth'
model = Net()
save_model = t.load(path)
model_dict = model.state_dict()
state_dict = {k:v for k,v in save_model.items() if k in model_dict.keys()}
print(state_dict.keys()) # dict_keys(['w', 'conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias'])
model_dict.update(state_dict)
model.load_state_dict(model_dict)

看看上面的代码,很容易弄明白。其中model_dict.update的作用是更新代码中搭建的模型参数字典。为啥更新我其实并不清楚,但这一步骤是必须的,否则还会报错。

为了弄清楚为什么要更新model_dict,我们不妨分别输出state_dict和model_dict的关键值看一看。

for k in state_dict.keys():
  print(k)
 
'''
w
conv1.weight
conv1.bias
conv2.weight
conv2.bias
'''
for k in model_dict.keys():
  print(k)
 
'''
w
conv1.weight
conv1.bias
conv2.weight
conv2.bias
conv3.weight
conv3.bias
conv4.weight
conv4.bias
'''

这个结果也是预料之中的,所以我猜测,update之后,model_dict和state_dict中具有相同键的值已经同步了。updata的目的就是使model_dict带有state_dict中都具有的那一部分参数的值,对于model_dict中有的,但是save_dict中没有的参数,值不改变,参数仍然使用初始值。

以上这篇Pytorch中实现只导入部分模型参数的方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
为Python的web框架编写前端模版的教程
Apr 30 Python
Python 2.7中文显示与处理方法
Jul 16 Python
django如何连接已存在数据的数据库
Aug 14 Python
python random从集合中随机选择元素的方法
Jan 23 Python
Python文件打开方式实例详解【a、a+、r+、w+区别】
Mar 30 Python
PyQt5固定窗口大小的方法
Jun 18 Python
Python 装饰器原理、定义与用法详解
Dec 07 Python
python GUI库图形界面开发之PyQt5表格控件QTableView详细使用方法与实例
Mar 01 Python
Pycharm激活码激活两种快速方式(附最新激活码和插件)
Mar 12 Python
pandas.DataFrame.drop_duplicates 用法介绍
Jul 06 Python
Pytorch之扩充tensor的操作
Mar 04 Python
什么是Python装饰器?如何定义和使用?
Apr 11 Python
PyTorch中topk函数的用法详解
Jan 02 #Python
Pytorch训练过程出现nan的解决方式
Jan 02 #Python
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
Jan 02 #Python
基于MSELoss()与CrossEntropyLoss()的区别详解
Jan 02 #Python
python使用SQLAlchemy操作MySQL
Jan 02 #Python
pytorch 实现cross entropy损失函数计算方式
Jan 02 #Python
Matplotlib scatter绘制散点图的方法实现
Jan 02 #Python
You might like
简单的php数据库操作类代码(增,删,改,查)
2013/04/08 PHP
LotusPhp笔记之:Cookie组件的使用详解
2013/05/06 PHP
解决PhpMyAdmin中导入2M以上大文件限制的方法分享
2014/06/06 PHP
PHP以mysqli方式连接类完整代码实例
2014/07/15 PHP
highchart数据源纵轴json内的值必须是int(详解)
2017/02/20 PHP
js中方法重载如何实现?以及函数的参数问题
2013/08/01 Javascript
jQuery实现级联菜单效果(仿淘宝首页菜单动画)
2014/04/10 Javascript
JavaScript阻止事件冒泡示例分享
2014/12/28 Javascript
Jquery 实现grid绑定模板
2015/01/28 Javascript
JavaScript中split() 使用方法汇总
2015/04/17 Javascript
JSON相关知识汇总
2015/07/03 Javascript
详解JavaScript ES6中的Generator
2015/07/28 Javascript
jQuery Validate表单验证插件 添加class属性形式的校验
2016/01/18 Javascript
jQuery模拟物体自由落体运动(附演示与demo源码下载)
2016/01/21 Javascript
浅谈Web页面向后台提交数据的方式和选择
2016/09/23 Javascript
jQuery实现手势解锁密码特效
2017/08/14 jQuery
JS实现基于拖拽改变物体大小的方法
2018/01/23 Javascript
浅谈javascript事件环微任务和宏任务队列原理
2020/09/12 Javascript
python命令行参数解析OptionParser类用法实例
2014/10/09 Python
在Python中操作文件之truncate()方法的使用教程
2015/05/25 Python
Python使用base64模块进行二进制数据编码详解
2018/01/11 Python
python opencv之SIFT算法示例
2018/02/24 Python
python实现自动获取IP并发送到邮箱
2018/12/26 Python
Python两个字典键同值相加的几种方法
2019/03/05 Python
python 根据字典的键值进行排序的方法
2019/07/24 Python
python 监控logcat关键字功能
2020/09/04 Python
Python列表推导式实现代码实例
2020/09/09 Python
python Matplotlib基础--如何添加文本和标注
2021/01/26 Python
解决pytorch 保存模型遇到的问题
2021/03/03 Python
英国最大的汽车交易网站:Auto Trader UK
2016/09/23 全球购物
Notino芬兰:购买香水和化妆品
2019/04/15 全球购物
教师自我鉴定
2013/12/13 职场文书
市场营销专业自荐书
2014/06/10 职场文书
我的1919观后感
2015/06/03 职场文书
竞聘书的秘诀
2019/04/02 职场文书
win10清理dns缓存
2022/04/19 数码科技