Pytorch中实现只导入部分模型参数的方式


Posted in Python onJanuary 02, 2020

我们在做迁移学习,或者在分割,检测等任务想使用预训练好的模型,同时又有自己修改之后的结构,使得模型文件保存的参数,有一部分是不需要的(don't expected)。我们搭建的网络对保存文件来说,有一部分参数也是没有的(missed)。如果依旧使用torch.load(model.state_dict())的办法,就会出现 xxx expected,xxx missed类似的错误。那么在这种情况下,该如何导入模型呢?

好在Pytorch中的模型参数使用字典保存的,键是参数的名称,值是参数的具体数值。我们使用model.state_dict()获得这个字典,之后就能利用参数名称来实现导入。

请看下面的一个例子。

我们先搭建一个小小的网络。

import torch as t
from torch.nn import Module
from torch import nn
from torch.nn import functional as F
class Net(Module):
  def __init__(self):
    super(Net,self).__init__()
    self.conv1 = nn.Conv2d(3,32,3,1)
    self.conv2 = nn.Conv2d(32,3,3,1)
    self.w = nn.Parameter(t.randn(3,10))
    for p in self.children():
      nn.init.xavier_normal_(p.weight.data)
      nn.init.constant_(p.bias.data, 0)
  def forward(self, x):
    out = self.conv1(x)
    out = self.conv2(x)
 
    out = F.avg_pool2d(out,(out.shape[2],out.shape[3]))
    out = F.linear(out,weight=self.w)
    return out

然后我们保存这个网络的初始值。

model = Net()
t.save(model.state_dict(),'xxx.pth')

现在我们将Net修改一下,多加几个卷积层,但并不加入到forward中,仅仅出于少些几行的目的。

import torch as t
from torch.nn import Module
from torch import nn
from torch.nn import functional as F
 
 
class Net(Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(3, 32, 3, 1)
    self.conv2 = nn.Conv2d(32, 3, 3, 1)
    self.conv3 = nn.Conv2d(3,64,3,1)
    self.conv4 = nn.Conv2d(64,32,3,1)
    for p in self.children():
      nn.init.xavier_normal_(p.weight.data)
      nn.init.constant_(p.bias.data, 0)
 
    self.w = nn.Parameter(t.randn(3, 10))
  def forward(self, x):
    out = self.conv1(x)
    out = self.conv2(x)
 
    out = F.avg_pool2d(out, (out.shape[2], out.shape[3]))
    out = F.linear(out, weight=self.w)
    return out

我们现在试着导入之前保存的模型参数。

path = 'xxx.pth'
model = Net()
model.load_state_dict(t.load(path))
 
'''
RuntimeError: Error(s) in loading state_dict for Net:
 Missing key(s) in state_dict: "conv3.weight", "conv3.bias", "conv4.weight", "conv4.bias". 
'''

出现了没有在模型文件中找到error中的关键字的错误。

现在我们这样导入模型

path = 'xxx.pth'
model = Net()
save_model = t.load(path)
model_dict = model.state_dict()
state_dict = {k:v for k,v in save_model.items() if k in model_dict.keys()}
print(state_dict.keys()) # dict_keys(['w', 'conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias'])
model_dict.update(state_dict)
model.load_state_dict(model_dict)

看看上面的代码,很容易弄明白。其中model_dict.update的作用是更新代码中搭建的模型参数字典。为啥更新我其实并不清楚,但这一步骤是必须的,否则还会报错。

为了弄清楚为什么要更新model_dict,我们不妨分别输出state_dict和model_dict的关键值看一看。

for k in state_dict.keys():
  print(k)
 
'''
w
conv1.weight
conv1.bias
conv2.weight
conv2.bias
'''
for k in model_dict.keys():
  print(k)
 
'''
w
conv1.weight
conv1.bias
conv2.weight
conv2.bias
conv3.weight
conv3.bias
conv4.weight
conv4.bias
'''

这个结果也是预料之中的,所以我猜测,update之后,model_dict和state_dict中具有相同键的值已经同步了。updata的目的就是使model_dict带有state_dict中都具有的那一部分参数的值,对于model_dict中有的,但是save_dict中没有的参数,值不改变,参数仍然使用初始值。

以上这篇Pytorch中实现只导入部分模型参数的方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 排列组合之itertools
Mar 20 Python
python网络编程学习笔记(九):数据库客户端 DB-API
Jun 09 Python
把项目从Python2.x移植到Python3.x的经验总结
Apr 20 Python
使用Python的web.py框架实现类似Django的ORM查询的教程
May 02 Python
python中requests使用代理proxies方法介绍
Oct 25 Python
利用Python在一个文件的头部插入数据的实例
May 02 Python
多个应用共存的Django配置方法
May 30 Python
想学python 这5本书籍你必看!
Dec 11 Python
python global关键字的用法详解
Sep 05 Python
DRF框架API版本管理实现方法解析
Aug 21 Python
python实现腾讯滑块验证码识别
Apr 27 Python
详解分布式系统中如何用python实现Paxos
May 18 Python
PyTorch中topk函数的用法详解
Jan 02 #Python
Pytorch训练过程出现nan的解决方式
Jan 02 #Python
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
Jan 02 #Python
基于MSELoss()与CrossEntropyLoss()的区别详解
Jan 02 #Python
python使用SQLAlchemy操作MySQL
Jan 02 #Python
pytorch 实现cross entropy损失函数计算方式
Jan 02 #Python
Matplotlib scatter绘制散点图的方法实现
Jan 02 #Python
You might like
php遍历目录输出目录及其下的所有文件示例
2014/01/27 PHP
ThinkPHP之M方法实例详解
2014/06/20 PHP
PHP中的插件机制原理和实例
2014/07/08 PHP
visual studio code 调试php方法(图文详解)
2017/09/15 PHP
PHP设计模式之适配器模式原理与用法分析
2018/04/25 PHP
PHP校验15位和18位身份证号的类封装
2018/11/07 PHP
PDO::errorCode讲解
2019/01/28 PHP
javascript实现的距离现在多长时间后的一个格式化的日期
2009/10/29 Javascript
JQuery操作元素的css样式
2015/03/09 Javascript
JavaScript中几种排序算法的简单实现
2015/07/29 Javascript
JavaScript中Array的实用操作技巧分享
2016/09/11 Javascript
jQuery插件FusionCharts绘制2D环饼图效果示例【附demo源码】
2017/04/10 jQuery
JavaScript编程设计模式之构造器模式实例分析
2017/10/25 Javascript
微信小程序实现刷脸登录
2018/05/25 Javascript
webpack 样式加载的实现原理
2018/06/12 Javascript
Node.js 使用axios读写influxDB的方法示例
2018/10/26 Javascript
vue微信分享到朋友圈 vue微信发送给好友
2018/11/28 Javascript
Vue路由模块化配置的完整步骤
2019/08/14 Javascript
jQuery zTree树插件的使用教程
2019/08/16 jQuery
理解Proxy及使用Proxy实现vue数据双向绑定操作
2020/07/18 Javascript
[55:54]FNATIC vs EG 2019国际邀请赛小组赛 BO2 第一场 8.15
2019/08/16 DOTA
用Python抢过年的火车票附源码
2015/12/07 Python
Python内存管理实例分析
2019/07/10 Python
Python实现Selenium自动化Page模式
2019/07/14 Python
使用tensorflow实现矩阵分解方式
2020/02/07 Python
Python+Kepler.gl轻松制作酷炫路径动画的实现示例
2020/06/02 Python
QML用PathView实现轮播图
2020/06/03 Python
新西兰演唱会和体育门票网站:Ticketmaster新西兰
2017/10/07 全球购物
娱乐地球:Entertainment Earth
2020/01/08 全球购物
SQL注入攻击的种类有哪些
2013/12/30 面试题
node中使用shell脚本的方法步骤
2021/03/23 Javascript
中英文自我评价语句
2013/12/20 职场文书
建筑班组长岗位职责
2014/01/02 职场文书
教育技术职业规划范文
2014/03/04 职场文书
求职简历自荐信
2014/06/18 职场文书
2015年幼儿园国庆节活动总结
2015/07/30 职场文书