Pytorch中实现只导入部分模型参数的方式


Posted in Python onJanuary 02, 2020

我们在做迁移学习,或者在分割,检测等任务想使用预训练好的模型,同时又有自己修改之后的结构,使得模型文件保存的参数,有一部分是不需要的(don't expected)。我们搭建的网络对保存文件来说,有一部分参数也是没有的(missed)。如果依旧使用torch.load(model.state_dict())的办法,就会出现 xxx expected,xxx missed类似的错误。那么在这种情况下,该如何导入模型呢?

好在Pytorch中的模型参数使用字典保存的,键是参数的名称,值是参数的具体数值。我们使用model.state_dict()获得这个字典,之后就能利用参数名称来实现导入。

请看下面的一个例子。

我们先搭建一个小小的网络。

import torch as t
from torch.nn import Module
from torch import nn
from torch.nn import functional as F
class Net(Module):
  def __init__(self):
    super(Net,self).__init__()
    self.conv1 = nn.Conv2d(3,32,3,1)
    self.conv2 = nn.Conv2d(32,3,3,1)
    self.w = nn.Parameter(t.randn(3,10))
    for p in self.children():
      nn.init.xavier_normal_(p.weight.data)
      nn.init.constant_(p.bias.data, 0)
  def forward(self, x):
    out = self.conv1(x)
    out = self.conv2(x)
 
    out = F.avg_pool2d(out,(out.shape[2],out.shape[3]))
    out = F.linear(out,weight=self.w)
    return out

然后我们保存这个网络的初始值。

model = Net()
t.save(model.state_dict(),'xxx.pth')

现在我们将Net修改一下,多加几个卷积层,但并不加入到forward中,仅仅出于少些几行的目的。

import torch as t
from torch.nn import Module
from torch import nn
from torch.nn import functional as F
 
 
class Net(Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(3, 32, 3, 1)
    self.conv2 = nn.Conv2d(32, 3, 3, 1)
    self.conv3 = nn.Conv2d(3,64,3,1)
    self.conv4 = nn.Conv2d(64,32,3,1)
    for p in self.children():
      nn.init.xavier_normal_(p.weight.data)
      nn.init.constant_(p.bias.data, 0)
 
    self.w = nn.Parameter(t.randn(3, 10))
  def forward(self, x):
    out = self.conv1(x)
    out = self.conv2(x)
 
    out = F.avg_pool2d(out, (out.shape[2], out.shape[3]))
    out = F.linear(out, weight=self.w)
    return out

我们现在试着导入之前保存的模型参数。

path = 'xxx.pth'
model = Net()
model.load_state_dict(t.load(path))
 
'''
RuntimeError: Error(s) in loading state_dict for Net:
 Missing key(s) in state_dict: "conv3.weight", "conv3.bias", "conv4.weight", "conv4.bias". 
'''

出现了没有在模型文件中找到error中的关键字的错误。

现在我们这样导入模型

path = 'xxx.pth'
model = Net()
save_model = t.load(path)
model_dict = model.state_dict()
state_dict = {k:v for k,v in save_model.items() if k in model_dict.keys()}
print(state_dict.keys()) # dict_keys(['w', 'conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias'])
model_dict.update(state_dict)
model.load_state_dict(model_dict)

看看上面的代码,很容易弄明白。其中model_dict.update的作用是更新代码中搭建的模型参数字典。为啥更新我其实并不清楚,但这一步骤是必须的,否则还会报错。

为了弄清楚为什么要更新model_dict,我们不妨分别输出state_dict和model_dict的关键值看一看。

for k in state_dict.keys():
  print(k)
 
'''
w
conv1.weight
conv1.bias
conv2.weight
conv2.bias
'''
for k in model_dict.keys():
  print(k)
 
'''
w
conv1.weight
conv1.bias
conv2.weight
conv2.bias
conv3.weight
conv3.bias
conv4.weight
conv4.bias
'''

这个结果也是预料之中的,所以我猜测,update之后,model_dict和state_dict中具有相同键的值已经同步了。updata的目的就是使model_dict带有state_dict中都具有的那一部分参数的值,对于model_dict中有的,但是save_dict中没有的参数,值不改变,参数仍然使用初始值。

以上这篇Pytorch中实现只导入部分模型参数的方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现的防DDoS脚本
Feb 08 Python
举例讲解Python中is和id的用法
Apr 03 Python
python使用webbrowser浏览指定url的方法
Apr 04 Python
python创建和删除目录的方法
Apr 29 Python
Python提取网页中超链接的方法
Sep 18 Python
详解Python多线程Selenium跨浏览器测试
Apr 01 Python
PyTorch上实现卷积神经网络CNN的方法
Apr 28 Python
Python模拟百度自动输入搜索功能的实例
Feb 14 Python
python三方库之requests的快速上手
Mar 04 Python
如何为Python终端提供持久性历史记录
Sep 03 Python
python和node.js生成当前时间戳的示例
Sep 29 Python
python字典与json转换的方法总结
Dec 28 Python
PyTorch中topk函数的用法详解
Jan 02 #Python
Pytorch训练过程出现nan的解决方式
Jan 02 #Python
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
Jan 02 #Python
基于MSELoss()与CrossEntropyLoss()的区别详解
Jan 02 #Python
python使用SQLAlchemy操作MySQL
Jan 02 #Python
pytorch 实现cross entropy损失函数计算方式
Jan 02 #Python
Matplotlib scatter绘制散点图的方法实现
Jan 02 #Python
You might like
PHPUnit PHP测试框架安装方法
2011/03/23 PHP
PHP实践教程之过滤、验证、转义与密码详解
2017/07/24 PHP
jquery 学习之二 属性(类)
2010/11/25 Javascript
jquery中:input和input的区别分析
2011/07/13 Javascript
JS解决ie6下png透明的方法实例
2013/08/02 Javascript
JavaScript伸缩的菜单简单示例
2013/12/03 Javascript
js动态修改表格行colspan列跨度的方法
2015/03/30 Javascript
jQuery中的ajax async同步和异步详解
2015/09/29 Javascript
使用jQuery.Qrcode插件在客户端动态生成二维码并添加自定义Logo
2016/09/01 Javascript
浅析vue-router原理
2018/10/19 Javascript
mpvue+vant app搭建微信小程序的方法步骤
2019/02/11 Javascript
通过npm或yarn自动生成vue组件的方法示例
2019/02/12 Javascript
vue+moment实现倒计时效果
2019/08/26 Javascript
JS数组的高级使用方法示例小结
2020/03/14 Javascript
vue通过过滤器实现数据格式化
2020/07/20 Javascript
python获取各操作系统硬件信息的方法
2015/06/03 Python
[原创]windows下Anaconda的安装与配置正解(Anaconda入门教程)
2018/04/05 Python
python实现排序算法解析
2018/09/08 Python
python批量复制图片到另一个文件夹
2018/09/17 Python
Python之pymysql的使用小结
2019/07/01 Python
python飞机大战pygame碰撞检测实现方法分析
2019/12/17 Python
pytorch中使用cuda扩展的实现示例
2020/02/12 Python
python GUI库图形界面开发之PyQt5日期时间控件QDateTimeEdit详细使用方法与实例
2020/02/27 Python
jupyter notebook 恢复误删单元格或者历史代码的实现
2020/04/17 Python
你的自行车健身专家:FaFit24
2016/11/16 全球购物
欧洲最大的预定车位市场:JustPark
2020/01/06 全球购物
网络体系结构及协议的定义
2014/03/13 面试题
送餐员岗位职责范本
2014/02/21 职场文书
机电职业生涯规划书范文
2014/03/08 职场文书
合作意向协议书范本
2014/03/31 职场文书
营销总监岗位职责
2014/09/16 职场文书
初中毕业生自我评价
2015/03/02 职场文书
政审证明材料
2015/06/19 职场文书
Pytorch使用shuffle打乱数据的操作
2021/05/20 Python
golang生成vcf通讯录格式文件详情
2022/03/25 Golang
CentOS下安装Jenkins的完整步骤
2022/04/07 Servers