Pytorch中实现只导入部分模型参数的方式


Posted in Python onJanuary 02, 2020

我们在做迁移学习,或者在分割,检测等任务想使用预训练好的模型,同时又有自己修改之后的结构,使得模型文件保存的参数,有一部分是不需要的(don't expected)。我们搭建的网络对保存文件来说,有一部分参数也是没有的(missed)。如果依旧使用torch.load(model.state_dict())的办法,就会出现 xxx expected,xxx missed类似的错误。那么在这种情况下,该如何导入模型呢?

好在Pytorch中的模型参数使用字典保存的,键是参数的名称,值是参数的具体数值。我们使用model.state_dict()获得这个字典,之后就能利用参数名称来实现导入。

请看下面的一个例子。

我们先搭建一个小小的网络。

import torch as t
from torch.nn import Module
from torch import nn
from torch.nn import functional as F
class Net(Module):
  def __init__(self):
    super(Net,self).__init__()
    self.conv1 = nn.Conv2d(3,32,3,1)
    self.conv2 = nn.Conv2d(32,3,3,1)
    self.w = nn.Parameter(t.randn(3,10))
    for p in self.children():
      nn.init.xavier_normal_(p.weight.data)
      nn.init.constant_(p.bias.data, 0)
  def forward(self, x):
    out = self.conv1(x)
    out = self.conv2(x)
 
    out = F.avg_pool2d(out,(out.shape[2],out.shape[3]))
    out = F.linear(out,weight=self.w)
    return out

然后我们保存这个网络的初始值。

model = Net()
t.save(model.state_dict(),'xxx.pth')

现在我们将Net修改一下,多加几个卷积层,但并不加入到forward中,仅仅出于少些几行的目的。

import torch as t
from torch.nn import Module
from torch import nn
from torch.nn import functional as F
 
 
class Net(Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(3, 32, 3, 1)
    self.conv2 = nn.Conv2d(32, 3, 3, 1)
    self.conv3 = nn.Conv2d(3,64,3,1)
    self.conv4 = nn.Conv2d(64,32,3,1)
    for p in self.children():
      nn.init.xavier_normal_(p.weight.data)
      nn.init.constant_(p.bias.data, 0)
 
    self.w = nn.Parameter(t.randn(3, 10))
  def forward(self, x):
    out = self.conv1(x)
    out = self.conv2(x)
 
    out = F.avg_pool2d(out, (out.shape[2], out.shape[3]))
    out = F.linear(out, weight=self.w)
    return out

我们现在试着导入之前保存的模型参数。

path = 'xxx.pth'
model = Net()
model.load_state_dict(t.load(path))
 
'''
RuntimeError: Error(s) in loading state_dict for Net:
 Missing key(s) in state_dict: "conv3.weight", "conv3.bias", "conv4.weight", "conv4.bias". 
'''

出现了没有在模型文件中找到error中的关键字的错误。

现在我们这样导入模型

path = 'xxx.pth'
model = Net()
save_model = t.load(path)
model_dict = model.state_dict()
state_dict = {k:v for k,v in save_model.items() if k in model_dict.keys()}
print(state_dict.keys()) # dict_keys(['w', 'conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias'])
model_dict.update(state_dict)
model.load_state_dict(model_dict)

看看上面的代码,很容易弄明白。其中model_dict.update的作用是更新代码中搭建的模型参数字典。为啥更新我其实并不清楚,但这一步骤是必须的,否则还会报错。

为了弄清楚为什么要更新model_dict,我们不妨分别输出state_dict和model_dict的关键值看一看。

for k in state_dict.keys():
  print(k)
 
'''
w
conv1.weight
conv1.bias
conv2.weight
conv2.bias
'''
for k in model_dict.keys():
  print(k)
 
'''
w
conv1.weight
conv1.bias
conv2.weight
conv2.bias
conv3.weight
conv3.bias
conv4.weight
conv4.bias
'''

这个结果也是预料之中的,所以我猜测,update之后,model_dict和state_dict中具有相同键的值已经同步了。updata的目的就是使model_dict带有state_dict中都具有的那一部分参数的值,对于model_dict中有的,但是save_dict中没有的参数,值不改变,参数仍然使用初始值。

以上这篇Pytorch中实现只导入部分模型参数的方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python3基础之条件与循环控制实例解析
Aug 13 Python
Python的string模块中的Template类字符串模板用法
Jun 27 Python
Python使用filetype精确判断文件类型
Jul 02 Python
详解用python实现简单的遗传算法
Jan 02 Python
Python对数据进行插值和下采样的方法
Jul 03 Python
python2和python3的输入和输出区别介绍
Nov 20 Python
OpenCV 表盘指针自动读数的示例代码
Apr 10 Python
基于PyTorch的permute和reshape/view的区别介绍
Jun 18 Python
python 使用tkinter+you-get实现视频下载器
Nov 17 Python
python中zip()函数遍历多个列表方法
Feb 18 Python
如何用python反转图片,视频
Apr 24 Python
只用Python就可以制作的简单词云
Jun 07 Python
PyTorch中topk函数的用法详解
Jan 02 #Python
Pytorch训练过程出现nan的解决方式
Jan 02 #Python
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
Jan 02 #Python
基于MSELoss()与CrossEntropyLoss()的区别详解
Jan 02 #Python
python使用SQLAlchemy操作MySQL
Jan 02 #Python
pytorch 实现cross entropy损失函数计算方式
Jan 02 #Python
Matplotlib scatter绘制散点图的方法实现
Jan 02 #Python
You might like
一键生成各种尺寸Icon的php脚本(实例)
2017/02/08 PHP
PHP简单实现合并2个数字键数组值的方法
2017/05/30 PHP
Swoole 5将移除自动添加Event::wait()特性详解
2019/07/10 PHP
laravel高级的Join语法详解以及使用Join多个条件
2019/10/16 PHP
phpcmsv9.0任意文件上传漏洞解析
2020/10/20 PHP
为JavaScript添加重载函数的辅助方法
2010/07/04 Javascript
Firefox中autocomplete="off" 设置不起作用Bug的解决方法
2011/03/25 Javascript
jquery fancybox ie6不显示关闭按钮的解决办法
2013/12/25 Javascript
javascript实现简单的分页特效
2015/08/12 Javascript
JS实现仿雅虎首页快捷登录入口及导航模块效果
2015/09/19 Javascript
jquery使用on绑定a标签无效 只能用live解决
2016/06/02 Javascript
js实现自动轮换选项卡
2017/01/13 Javascript
babel基本使用详解
2017/02/17 Javascript
Vue中添加过渡效果的方法
2017/03/16 Javascript
react-native中ListView组件点击跳转的方法示例
2017/09/30 Javascript
解决layUI的页面显示不全的问题
2019/09/20 Javascript
layui下拉列表select实现可输入查找的方法
2019/09/28 Javascript
vue组件开发之tab切换组件使用详解
2020/08/21 Javascript
python计算程序开始到程序结束的运行时间和程序运行的CPU时间
2013/11/28 Python
Python中列表的一些基本操作知识汇总
2015/05/20 Python
python使用正则表达式提取网页URL的方法
2015/05/26 Python
详解Python网络爬虫功能的基本写法
2016/01/28 Python
python3利用venv配置虚拟环境及过程中的小问题小结
2018/08/01 Python
Pytorch: 自定义网络层实例
2020/01/07 Python
python_array[0][0]与array[0,0]的区别详解
2020/02/18 Python
python多维数组分位数的求取方式
2020/03/03 Python
Django ORM filter() 的运用详解
2020/05/14 Python
纯CSS3实现扇形动画菜单(简化版)实例源码
2017/01/17 HTML / CSS
EQVVS官网:设计师男装和女装
2018/10/24 全球购物
某公司面试题
2012/03/05 面试题
法定代表人授权委托书范本
2014/10/07 职场文书
2014年外贸业务员工作总结
2014/12/11 职场文书
2015年学校德育工作总结
2015/04/22 职场文书
公司地址变更通知
2015/04/25 职场文书
2016暑期社会实践新闻稿
2015/11/25 职场文书
自考生自我评价
2019/06/21 职场文书