Pytorch中实现只导入部分模型参数的方式


Posted in Python onJanuary 02, 2020

我们在做迁移学习,或者在分割,检测等任务想使用预训练好的模型,同时又有自己修改之后的结构,使得模型文件保存的参数,有一部分是不需要的(don't expected)。我们搭建的网络对保存文件来说,有一部分参数也是没有的(missed)。如果依旧使用torch.load(model.state_dict())的办法,就会出现 xxx expected,xxx missed类似的错误。那么在这种情况下,该如何导入模型呢?

好在Pytorch中的模型参数使用字典保存的,键是参数的名称,值是参数的具体数值。我们使用model.state_dict()获得这个字典,之后就能利用参数名称来实现导入。

请看下面的一个例子。

我们先搭建一个小小的网络。

import torch as t
from torch.nn import Module
from torch import nn
from torch.nn import functional as F
class Net(Module):
  def __init__(self):
    super(Net,self).__init__()
    self.conv1 = nn.Conv2d(3,32,3,1)
    self.conv2 = nn.Conv2d(32,3,3,1)
    self.w = nn.Parameter(t.randn(3,10))
    for p in self.children():
      nn.init.xavier_normal_(p.weight.data)
      nn.init.constant_(p.bias.data, 0)
  def forward(self, x):
    out = self.conv1(x)
    out = self.conv2(x)
 
    out = F.avg_pool2d(out,(out.shape[2],out.shape[3]))
    out = F.linear(out,weight=self.w)
    return out

然后我们保存这个网络的初始值。

model = Net()
t.save(model.state_dict(),'xxx.pth')

现在我们将Net修改一下,多加几个卷积层,但并不加入到forward中,仅仅出于少些几行的目的。

import torch as t
from torch.nn import Module
from torch import nn
from torch.nn import functional as F
 
 
class Net(Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(3, 32, 3, 1)
    self.conv2 = nn.Conv2d(32, 3, 3, 1)
    self.conv3 = nn.Conv2d(3,64,3,1)
    self.conv4 = nn.Conv2d(64,32,3,1)
    for p in self.children():
      nn.init.xavier_normal_(p.weight.data)
      nn.init.constant_(p.bias.data, 0)
 
    self.w = nn.Parameter(t.randn(3, 10))
  def forward(self, x):
    out = self.conv1(x)
    out = self.conv2(x)
 
    out = F.avg_pool2d(out, (out.shape[2], out.shape[3]))
    out = F.linear(out, weight=self.w)
    return out

我们现在试着导入之前保存的模型参数。

path = 'xxx.pth'
model = Net()
model.load_state_dict(t.load(path))
 
'''
RuntimeError: Error(s) in loading state_dict for Net:
 Missing key(s) in state_dict: "conv3.weight", "conv3.bias", "conv4.weight", "conv4.bias". 
'''

出现了没有在模型文件中找到error中的关键字的错误。

现在我们这样导入模型

path = 'xxx.pth'
model = Net()
save_model = t.load(path)
model_dict = model.state_dict()
state_dict = {k:v for k,v in save_model.items() if k in model_dict.keys()}
print(state_dict.keys()) # dict_keys(['w', 'conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias'])
model_dict.update(state_dict)
model.load_state_dict(model_dict)

看看上面的代码,很容易弄明白。其中model_dict.update的作用是更新代码中搭建的模型参数字典。为啥更新我其实并不清楚,但这一步骤是必须的,否则还会报错。

为了弄清楚为什么要更新model_dict,我们不妨分别输出state_dict和model_dict的关键值看一看。

for k in state_dict.keys():
  print(k)
 
'''
w
conv1.weight
conv1.bias
conv2.weight
conv2.bias
'''
for k in model_dict.keys():
  print(k)
 
'''
w
conv1.weight
conv1.bias
conv2.weight
conv2.bias
conv3.weight
conv3.bias
conv4.weight
conv4.bias
'''

这个结果也是预料之中的,所以我猜测,update之后,model_dict和state_dict中具有相同键的值已经同步了。updata的目的就是使model_dict带有state_dict中都具有的那一部分参数的值,对于model_dict中有的,但是save_dict中没有的参数,值不改变,参数仍然使用初始值。

以上这篇Pytorch中实现只导入部分模型参数的方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Python中for循环的使用方法
May 14 Python
在Python中用get()方法获取字典键值的教程
May 21 Python
python实现字典(dict)和字符串(string)的相互转换方法
Mar 01 Python
Python 加密的实例详解
Oct 09 Python
Python实现插入排序和选择排序的方法
May 12 Python
django 单表操作实例详解
Jul 30 Python
Python中模块(Module)和包(Package)的区别详解
Aug 07 Python
使用python3 实现插入数据到mysql
Mar 02 Python
Python新手学习函数默认参数设置
Jun 03 Python
Python3+selenium配置常见报错解决方案
Aug 28 Python
python pyhs2 的安装操作
Apr 07 Python
python3+PyQt5+Qt Designer实现界面可视化
Jun 10 Python
PyTorch中topk函数的用法详解
Jan 02 #Python
Pytorch训练过程出现nan的解决方式
Jan 02 #Python
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
Jan 02 #Python
基于MSELoss()与CrossEntropyLoss()的区别详解
Jan 02 #Python
python使用SQLAlchemy操作MySQL
Jan 02 #Python
pytorch 实现cross entropy损失函数计算方式
Jan 02 #Python
Matplotlib scatter绘制散点图的方法实现
Jan 02 #Python
You might like
用PHP实现登陆验证码(类似条行码状)
2006/10/09 PHP
提升PHP执行速度全攻略(下)
2006/10/09 PHP
PHP 正则表达式之正则处理函数小结(preg_match,preg_match_all,preg_replace,preg_split)
2012/10/05 PHP
基于PHP中的常用函数回顾
2013/07/11 PHP
yii,CI,yaf框架+smarty模板使用方法
2015/12/29 PHP
关于php支持的协议与封装协议总结(推荐)
2017/11/17 PHP
总结PHP代码规范、流程规范、git规范
2018/06/18 PHP
PHP配置ZendOpcache插件加速
2019/02/14 PHP
关于ExtJS4.1:快捷键支持的问题
2013/04/24 Javascript
原生javascript实现隔行换色
2015/01/04 Javascript
jQuery实现MSN中文网滑动Tab菜单效果代码
2015/09/09 Javascript
Bootstrap表单布局样式源代码
2016/07/04 Javascript
新手学习前端之js模仿淘宝主页网站
2016/10/31 Javascript
JavaScript中for循环的几种写法与效率总结
2017/02/03 Javascript
vue项目中使用axios上传图片等文件操作
2017/11/02 Javascript
three.js中文文档学习之通过模块导入
2017/11/20 Javascript
vue.js使用v-model实现表单元素(input) 双向数据绑定功能示例
2019/03/08 Javascript
BootstrapValidator验证用户名已存在(ajax)
2019/11/08 Javascript
Python数据结构与算法之图的基本实现及迭代器实例详解
2017/12/12 Python
python email smtplib模块发送邮件代码实例
2018/04/26 Python
python读取excel指定列数据并写入到新的excel方法
2018/07/10 Python
Flask框架Jinjia模板常用语法总结
2018/07/19 Python
详解PyTorch手写数字识别(MNIST数据集)
2019/08/16 Python
python 解决mysql where in 对列表(list,,array)问题
2020/06/06 Python
Python 测试框架unittest和pytest的优劣
2020/09/26 Python
详解HTML5常用的语义化标签
2019/09/27 HTML / CSS
大学生简历中个人的自我评价
2013/10/06 职场文书
12月小学生校园广播稿
2014/02/04 职场文书
益达广告词
2014/03/14 职场文书
学生会主席竞聘书
2014/03/31 职场文书
应聘护士求职信
2014/07/21 职场文书
党的群众路线教育实践活动总结材料
2014/10/30 职场文书
财务负责人岗位职责
2015/02/03 职场文书
小学教师师德师风自我评价
2015/03/04 职场文书
机关保密工作承诺书
2015/05/04 职场文书
Mybatis-plus在项目中的简单应用
2021/07/01 Java/Android