Pytorch中实现只导入部分模型参数的方式


Posted in Python onJanuary 02, 2020

我们在做迁移学习,或者在分割,检测等任务想使用预训练好的模型,同时又有自己修改之后的结构,使得模型文件保存的参数,有一部分是不需要的(don't expected)。我们搭建的网络对保存文件来说,有一部分参数也是没有的(missed)。如果依旧使用torch.load(model.state_dict())的办法,就会出现 xxx expected,xxx missed类似的错误。那么在这种情况下,该如何导入模型呢?

好在Pytorch中的模型参数使用字典保存的,键是参数的名称,值是参数的具体数值。我们使用model.state_dict()获得这个字典,之后就能利用参数名称来实现导入。

请看下面的一个例子。

我们先搭建一个小小的网络。

import torch as t
from torch.nn import Module
from torch import nn
from torch.nn import functional as F
class Net(Module):
  def __init__(self):
    super(Net,self).__init__()
    self.conv1 = nn.Conv2d(3,32,3,1)
    self.conv2 = nn.Conv2d(32,3,3,1)
    self.w = nn.Parameter(t.randn(3,10))
    for p in self.children():
      nn.init.xavier_normal_(p.weight.data)
      nn.init.constant_(p.bias.data, 0)
  def forward(self, x):
    out = self.conv1(x)
    out = self.conv2(x)
 
    out = F.avg_pool2d(out,(out.shape[2],out.shape[3]))
    out = F.linear(out,weight=self.w)
    return out

然后我们保存这个网络的初始值。

model = Net()
t.save(model.state_dict(),'xxx.pth')

现在我们将Net修改一下,多加几个卷积层,但并不加入到forward中,仅仅出于少些几行的目的。

import torch as t
from torch.nn import Module
from torch import nn
from torch.nn import functional as F
 
 
class Net(Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(3, 32, 3, 1)
    self.conv2 = nn.Conv2d(32, 3, 3, 1)
    self.conv3 = nn.Conv2d(3,64,3,1)
    self.conv4 = nn.Conv2d(64,32,3,1)
    for p in self.children():
      nn.init.xavier_normal_(p.weight.data)
      nn.init.constant_(p.bias.data, 0)
 
    self.w = nn.Parameter(t.randn(3, 10))
  def forward(self, x):
    out = self.conv1(x)
    out = self.conv2(x)
 
    out = F.avg_pool2d(out, (out.shape[2], out.shape[3]))
    out = F.linear(out, weight=self.w)
    return out

我们现在试着导入之前保存的模型参数。

path = 'xxx.pth'
model = Net()
model.load_state_dict(t.load(path))
 
'''
RuntimeError: Error(s) in loading state_dict for Net:
 Missing key(s) in state_dict: "conv3.weight", "conv3.bias", "conv4.weight", "conv4.bias". 
'''

出现了没有在模型文件中找到error中的关键字的错误。

现在我们这样导入模型

path = 'xxx.pth'
model = Net()
save_model = t.load(path)
model_dict = model.state_dict()
state_dict = {k:v for k,v in save_model.items() if k in model_dict.keys()}
print(state_dict.keys()) # dict_keys(['w', 'conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias'])
model_dict.update(state_dict)
model.load_state_dict(model_dict)

看看上面的代码,很容易弄明白。其中model_dict.update的作用是更新代码中搭建的模型参数字典。为啥更新我其实并不清楚,但这一步骤是必须的,否则还会报错。

为了弄清楚为什么要更新model_dict,我们不妨分别输出state_dict和model_dict的关键值看一看。

for k in state_dict.keys():
  print(k)
 
'''
w
conv1.weight
conv1.bias
conv2.weight
conv2.bias
'''
for k in model_dict.keys():
  print(k)
 
'''
w
conv1.weight
conv1.bias
conv2.weight
conv2.bias
conv3.weight
conv3.bias
conv4.weight
conv4.bias
'''

这个结果也是预料之中的,所以我猜测,update之后,model_dict和state_dict中具有相同键的值已经同步了。updata的目的就是使model_dict带有state_dict中都具有的那一部分参数的值,对于model_dict中有的,但是save_dict中没有的参数,值不改变,参数仍然使用初始值。

以上这篇Pytorch中实现只导入部分模型参数的方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现子类调用父类的方法
Nov 10 Python
Python实现把回车符\r\n转换成\n
Apr 23 Python
python生成器generator用法实例分析
Jun 04 Python
Vue的el-scrollbar实现自定义滚动
May 29 Python
浅谈python之新式类
Aug 12 Python
python 常见字符串与函数的用法详解
Nov 23 Python
Python函数返回不定数量的值方法
Jan 22 Python
Python实现简单石头剪刀布游戏
Jan 20 Python
在Python中COM口的调用方法
Jul 03 Python
Django 解决阿里云部署同步数据库报错的问题
May 14 Python
Python嵌入C/C++进行开发详解
Jun 09 Python
python 写一个文件分发小程序
Dec 05 Python
PyTorch中topk函数的用法详解
Jan 02 #Python
Pytorch训练过程出现nan的解决方式
Jan 02 #Python
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
Jan 02 #Python
基于MSELoss()与CrossEntropyLoss()的区别详解
Jan 02 #Python
python使用SQLAlchemy操作MySQL
Jan 02 #Python
pytorch 实现cross entropy损失函数计算方式
Jan 02 #Python
Matplotlib scatter绘制散点图的方法实现
Jan 02 #Python
You might like
PHP读取RSS(Feed)简单实例
2014/06/12 PHP
三个思路解决laravel上传文件报错:413 Request Entity Too Large问题
2017/11/13 PHP
浅谈PHP中如何实现Hook机制
2017/11/14 PHP
阿里云Win2016安装Apache和PHP环境图文教程
2018/03/11 PHP
php获取手机端的号码以及ip地址实例代码
2018/09/12 PHP
laravel中数据显示方法(默认值和下拉option默认选中)
2019/10/11 PHP
js cookies 常见网页木马挂马代码 24小时只加载一次
2009/04/13 Javascript
JavaScript创建一个欢迎cookie弹出窗实现代码
2013/03/15 Javascript
表格奇偶行设置不同颜色的核心JS代码
2013/12/24 Javascript
DOM基础教程之事件类型
2015/01/20 Javascript
一道优雅面试题分析js中fn()和return fn()的区别
2016/07/05 Javascript
一个炫酷的Bootstrap导航菜单
2016/12/28 Javascript
nodejs入门教程二:创建一个简单应用示例
2017/04/24 NodeJs
小程序获取当前位置加搜索附近热门小区及商区的方法
2019/04/08 Javascript
js实现从右往左匀速显示图片(无缝轮播)
2020/06/29 Javascript
Vue+Java+Base64实现条码解析的示例
2020/09/23 Javascript
详解python如何在django中为用户模型添加自定义权限
2018/10/15 Python
Python用61行代码实现图片像素化的示例代码
2018/12/10 Python
python调用java的jar包方法
2018/12/15 Python
python 监听salt job状态,并任务数据推送到redis中的方法
2019/01/14 Python
Python进阶之全面解读高级特性之切片
2019/02/19 Python
图文详解Django使用Pycharm连接MySQL数据库
2019/08/09 Python
详解Django将秒转换为xx天xx时xx分
2019/09/27 Python
Python 网络编程之TCP客户端/服务端功能示例【基于socket套接字】
2019/10/12 Python
匈牙利最大的健身制造商和销售商:inSPORTline
2018/10/30 全球购物
"火柴棍式"程序员面试题
2014/03/16 面试题
如何在存储过程中使用Loop
2016/01/05 面试题
中学教师实习自我鉴定
2013/09/28 职场文书
淘宝活动总结范文
2014/06/26 职场文书
机械设备与数控技术专业求职信
2014/08/10 职场文书
高中课前三分钟演讲稿
2014/08/18 职场文书
离婚纠纷代理词
2015/05/23 职场文书
2015年高校教师个人工作总结
2015/05/25 职场文书
2016五一劳动节慰问信
2015/11/30 职场文书
用Python可视化新冠疫情数据
2022/01/18 Python
javascript进阶篇深拷贝实现的四种方式
2022/07/07 Javascript