Pytorch中实现只导入部分模型参数的方式


Posted in Python onJanuary 02, 2020

我们在做迁移学习,或者在分割,检测等任务想使用预训练好的模型,同时又有自己修改之后的结构,使得模型文件保存的参数,有一部分是不需要的(don't expected)。我们搭建的网络对保存文件来说,有一部分参数也是没有的(missed)。如果依旧使用torch.load(model.state_dict())的办法,就会出现 xxx expected,xxx missed类似的错误。那么在这种情况下,该如何导入模型呢?

好在Pytorch中的模型参数使用字典保存的,键是参数的名称,值是参数的具体数值。我们使用model.state_dict()获得这个字典,之后就能利用参数名称来实现导入。

请看下面的一个例子。

我们先搭建一个小小的网络。

import torch as t
from torch.nn import Module
from torch import nn
from torch.nn import functional as F
class Net(Module):
  def __init__(self):
    super(Net,self).__init__()
    self.conv1 = nn.Conv2d(3,32,3,1)
    self.conv2 = nn.Conv2d(32,3,3,1)
    self.w = nn.Parameter(t.randn(3,10))
    for p in self.children():
      nn.init.xavier_normal_(p.weight.data)
      nn.init.constant_(p.bias.data, 0)
  def forward(self, x):
    out = self.conv1(x)
    out = self.conv2(x)
 
    out = F.avg_pool2d(out,(out.shape[2],out.shape[3]))
    out = F.linear(out,weight=self.w)
    return out

然后我们保存这个网络的初始值。

model = Net()
t.save(model.state_dict(),'xxx.pth')

现在我们将Net修改一下,多加几个卷积层,但并不加入到forward中,仅仅出于少些几行的目的。

import torch as t
from torch.nn import Module
from torch import nn
from torch.nn import functional as F
 
 
class Net(Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(3, 32, 3, 1)
    self.conv2 = nn.Conv2d(32, 3, 3, 1)
    self.conv3 = nn.Conv2d(3,64,3,1)
    self.conv4 = nn.Conv2d(64,32,3,1)
    for p in self.children():
      nn.init.xavier_normal_(p.weight.data)
      nn.init.constant_(p.bias.data, 0)
 
    self.w = nn.Parameter(t.randn(3, 10))
  def forward(self, x):
    out = self.conv1(x)
    out = self.conv2(x)
 
    out = F.avg_pool2d(out, (out.shape[2], out.shape[3]))
    out = F.linear(out, weight=self.w)
    return out

我们现在试着导入之前保存的模型参数。

path = 'xxx.pth'
model = Net()
model.load_state_dict(t.load(path))
 
'''
RuntimeError: Error(s) in loading state_dict for Net:
 Missing key(s) in state_dict: "conv3.weight", "conv3.bias", "conv4.weight", "conv4.bias". 
'''

出现了没有在模型文件中找到error中的关键字的错误。

现在我们这样导入模型

path = 'xxx.pth'
model = Net()
save_model = t.load(path)
model_dict = model.state_dict()
state_dict = {k:v for k,v in save_model.items() if k in model_dict.keys()}
print(state_dict.keys()) # dict_keys(['w', 'conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias'])
model_dict.update(state_dict)
model.load_state_dict(model_dict)

看看上面的代码,很容易弄明白。其中model_dict.update的作用是更新代码中搭建的模型参数字典。为啥更新我其实并不清楚,但这一步骤是必须的,否则还会报错。

为了弄清楚为什么要更新model_dict,我们不妨分别输出state_dict和model_dict的关键值看一看。

for k in state_dict.keys():
  print(k)
 
'''
w
conv1.weight
conv1.bias
conv2.weight
conv2.bias
'''
for k in model_dict.keys():
  print(k)
 
'''
w
conv1.weight
conv1.bias
conv2.weight
conv2.bias
conv3.weight
conv3.bias
conv4.weight
conv4.bias
'''

这个结果也是预料之中的,所以我猜测,update之后,model_dict和state_dict中具有相同键的值已经同步了。updata的目的就是使model_dict带有state_dict中都具有的那一部分参数的值,对于model_dict中有的,但是save_dict中没有的参数,值不改变,参数仍然使用初始值。

以上这篇Pytorch中实现只导入部分模型参数的方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python显示进度条的方法
Sep 20 Python
python类继承与子类实例初始化用法分析
Apr 17 Python
详解Python装饰器由浅入深
Dec 09 Python
tensorflow入门之训练简单的神经网络方法
Feb 26 Python
python统计字母、空格、数字等字符个数的实例
Jun 29 Python
Numpy中矩阵matrix读取一列的方法及数组和矩阵的相互转换实例
Jul 02 Python
Python设计模式之工厂方法模式实例详解
Jan 18 Python
Python实现通过解析域名获取ip地址的方法分析
May 17 Python
python安装requests库的实例代码
Jun 25 Python
基于Python的ModbusTCP客户端实现详解
Jul 13 Python
Python发送手机动态验证码代码实例
Feb 28 Python
Python3 类型标注支持操作
Jun 02 Python
PyTorch中topk函数的用法详解
Jan 02 #Python
Pytorch训练过程出现nan的解决方式
Jan 02 #Python
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
Jan 02 #Python
基于MSELoss()与CrossEntropyLoss()的区别详解
Jan 02 #Python
python使用SQLAlchemy操作MySQL
Jan 02 #Python
pytorch 实现cross entropy损失函数计算方式
Jan 02 #Python
Matplotlib scatter绘制散点图的方法实现
Jan 02 #Python
You might like
在PWS上安装PHP4.0正式版
2006/10/09 PHP
一个PHP的QRcode类与大家分享
2011/11/13 PHP
[原创]PHP字符串中插入子字符串方法总结
2016/05/06 PHP
laravel 实现设置时区的简单方法
2019/10/10 PHP
Jquery实战_读书笔记1—选择jQuery
2010/01/22 Javascript
常用一些Javascript判断函数
2012/08/14 Javascript
js 将json字符串转换为json对象的方法解析
2013/11/13 Javascript
js实现的map方法示例代码
2014/01/13 Javascript
javascript在IE下trim函数无法使用的解决方法
2014/09/12 Javascript
JS实现为表格动态添加标题的方法
2015/03/31 Javascript
jquery对象访问是什么及使用方法介绍
2016/05/03 Javascript
Java框架SSH结合Easyui控件实现省市县三级联动示例解析
2016/06/12 Javascript
JavaScript中访问id对象 属性的方式访问属性(实例代码)
2016/10/28 Javascript
详解PHP中pathinfo()函数导致的安全问题
2017/01/05 Javascript
Vue.js实现模拟微信朋友圈开发demo
2017/04/20 Javascript
基于jQuery解决ios10以上版本缩放问题
2017/11/03 jQuery
JS设计模式之命令模式概念与用法分析
2018/02/06 Javascript
node基于async/await对mysql进行封装
2019/06/20 Javascript
解决layer.open后laydate失效的问题
2019/09/06 Javascript
基于vue-cli3和element实现登陆页面
2019/11/13 Javascript
vue 函数调用加括号与不加括号的区别
2020/10/29 Javascript
[14:25]教你分分钟做大人:主宰(HEROS)
2014/12/08 DOTA
python实现的阳历转阴历(农历)算法
2014/04/25 Python
Django视图之ORM数据库查询操作API的实例
2017/10/27 Python
详解tensorflow训练自己的数据集实现CNN图像分类
2018/02/07 Python
python制作填词游戏步骤详解
2019/05/05 Python
tensorflow之并行读入数据详解
2020/02/05 Python
哈工大自然语言处理工具箱之ltp在windows10下的安装使用教程
2020/05/07 Python
Python 如何调试程序崩溃错误
2020/08/03 Python
50个强大璀璨的CSS3/JS技术运用实例
2010/02/27 HTML / CSS
美国名牌香水折扣网站:Hottperfume
2021/02/10 全球购物
学习雷锋演讲稿
2014/05/10 职场文书
辩论赛新闻稿
2015/07/17 职场文书
Python打包exe时各种异常处理方案总结
2021/05/18 Python
pytorch训练神经网络爆内存的解决方案
2021/05/22 Python
Redis唯一ID生成器的实现
2022/07/07 Redis