Pytorch中实现只导入部分模型参数的方式


Posted in Python onJanuary 02, 2020

我们在做迁移学习,或者在分割,检测等任务想使用预训练好的模型,同时又有自己修改之后的结构,使得模型文件保存的参数,有一部分是不需要的(don't expected)。我们搭建的网络对保存文件来说,有一部分参数也是没有的(missed)。如果依旧使用torch.load(model.state_dict())的办法,就会出现 xxx expected,xxx missed类似的错误。那么在这种情况下,该如何导入模型呢?

好在Pytorch中的模型参数使用字典保存的,键是参数的名称,值是参数的具体数值。我们使用model.state_dict()获得这个字典,之后就能利用参数名称来实现导入。

请看下面的一个例子。

我们先搭建一个小小的网络。

import torch as t
from torch.nn import Module
from torch import nn
from torch.nn import functional as F
class Net(Module):
  def __init__(self):
    super(Net,self).__init__()
    self.conv1 = nn.Conv2d(3,32,3,1)
    self.conv2 = nn.Conv2d(32,3,3,1)
    self.w = nn.Parameter(t.randn(3,10))
    for p in self.children():
      nn.init.xavier_normal_(p.weight.data)
      nn.init.constant_(p.bias.data, 0)
  def forward(self, x):
    out = self.conv1(x)
    out = self.conv2(x)
 
    out = F.avg_pool2d(out,(out.shape[2],out.shape[3]))
    out = F.linear(out,weight=self.w)
    return out

然后我们保存这个网络的初始值。

model = Net()
t.save(model.state_dict(),'xxx.pth')

现在我们将Net修改一下,多加几个卷积层,但并不加入到forward中,仅仅出于少些几行的目的。

import torch as t
from torch.nn import Module
from torch import nn
from torch.nn import functional as F
 
 
class Net(Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(3, 32, 3, 1)
    self.conv2 = nn.Conv2d(32, 3, 3, 1)
    self.conv3 = nn.Conv2d(3,64,3,1)
    self.conv4 = nn.Conv2d(64,32,3,1)
    for p in self.children():
      nn.init.xavier_normal_(p.weight.data)
      nn.init.constant_(p.bias.data, 0)
 
    self.w = nn.Parameter(t.randn(3, 10))
  def forward(self, x):
    out = self.conv1(x)
    out = self.conv2(x)
 
    out = F.avg_pool2d(out, (out.shape[2], out.shape[3]))
    out = F.linear(out, weight=self.w)
    return out

我们现在试着导入之前保存的模型参数。

path = 'xxx.pth'
model = Net()
model.load_state_dict(t.load(path))
 
'''
RuntimeError: Error(s) in loading state_dict for Net:
 Missing key(s) in state_dict: "conv3.weight", "conv3.bias", "conv4.weight", "conv4.bias". 
'''

出现了没有在模型文件中找到error中的关键字的错误。

现在我们这样导入模型

path = 'xxx.pth'
model = Net()
save_model = t.load(path)
model_dict = model.state_dict()
state_dict = {k:v for k,v in save_model.items() if k in model_dict.keys()}
print(state_dict.keys()) # dict_keys(['w', 'conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias'])
model_dict.update(state_dict)
model.load_state_dict(model_dict)

看看上面的代码,很容易弄明白。其中model_dict.update的作用是更新代码中搭建的模型参数字典。为啥更新我其实并不清楚,但这一步骤是必须的,否则还会报错。

为了弄清楚为什么要更新model_dict,我们不妨分别输出state_dict和model_dict的关键值看一看。

for k in state_dict.keys():
  print(k)
 
'''
w
conv1.weight
conv1.bias
conv2.weight
conv2.bias
'''
for k in model_dict.keys():
  print(k)
 
'''
w
conv1.weight
conv1.bias
conv2.weight
conv2.bias
conv3.weight
conv3.bias
conv4.weight
conv4.bias
'''

这个结果也是预料之中的,所以我猜测,update之后,model_dict和state_dict中具有相同键的值已经同步了。updata的目的就是使model_dict带有state_dict中都具有的那一部分参数的值,对于model_dict中有的,但是save_dict中没有的参数,值不改变,参数仍然使用初始值。

以上这篇Pytorch中实现只导入部分模型参数的方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python tempfile模块学习笔记(临时文件)
May 25 Python
Python3 正在毁灭 Python的原因分析
Nov 28 Python
python模拟enum枚举类型的方法小结
Apr 30 Python
使用python批量化音乐文件格式转换的实例
Jan 09 Python
完美解决Python matplotlib绘图时汉字显示不正常的问题
Jan 29 Python
很酷的python表白工具 你喜欢我吗
Apr 11 Python
使用python编写一个语音朗读闹钟功能的示例代码
Jul 14 Python
Python爬虫Scrapy框架CrawlSpider原理及使用案例
Nov 20 Python
python 根据列表批量下载网易云音乐的免费音乐
Dec 03 Python
Pytorch如何切换 cpu和gpu的使用详解
Mar 01 Python
字典算法实现及操作 --python(实用)
Mar 31 Python
Python采集股票数据并制作可视化柱状图
Apr 04 Python
PyTorch中topk函数的用法详解
Jan 02 #Python
Pytorch训练过程出现nan的解决方式
Jan 02 #Python
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
Jan 02 #Python
基于MSELoss()与CrossEntropyLoss()的区别详解
Jan 02 #Python
python使用SQLAlchemy操作MySQL
Jan 02 #Python
pytorch 实现cross entropy损失函数计算方式
Jan 02 #Python
Matplotlib scatter绘制散点图的方法实现
Jan 02 #Python
You might like
php实现的Cookies操作类实例
2014/09/24 PHP
thinkphp中html:list标签传递多个参数实例
2014/10/30 PHP
关于Laravel参数验证的一些疑与惑
2019/11/19 PHP
Javascript倒计时代码
2010/08/12 Javascript
js跨浏览器实现将字符串转化为xml对象的方法
2013/09/25 Javascript
jquery在项目中做复选框时遇到的一些问题笔记
2013/11/17 Javascript
javascript数组快速打乱重排的方法
2014/01/02 Javascript
js数组去重的常用方法总结
2014/01/24 Javascript
Jquery之Bind方法参数传递与接收的三种方法
2014/06/24 Javascript
jQuery EasyUI Pagination实现分页的常用方法
2016/05/21 Javascript
关于vue.js弹窗组件的知识点总结
2016/09/11 Javascript
JavaScript中的编码和解码函数
2017/02/15 Javascript
详解Windows下安装Nodejs步骤
2017/05/18 NodeJs
jquery.rotate.js实现可选抽奖次数和中奖内容的转盘抽奖代码
2017/08/23 jQuery
ES6中字符串string常用的新增方法小结
2017/11/07 Javascript
详解Node.js模板引擎Jade入门
2018/01/19 Javascript
vue 引入公共css文件的简单方法(推荐)
2018/01/20 Javascript
Python 调用DLL操作抄表机
2009/01/12 Python
浅谈Python类里的__init__方法函数,Python类的构造函数
2016/12/10 Python
python实现移位加密和解密
2019/03/22 Python
python字典一键多值实例代码分享
2019/06/14 Python
如何通过Python实现标签云算法
2019/07/02 Python
python有序查找算法 二分法实例解析
2020/02/18 Python
Python3.8.2安装包及安装教程图文详解(附安装包)
2020/11/28 Python
椰子猫砂:CatSpot
2018/08/27 全球购物
网络信息管理员岗位职责
2014/01/05 职场文书
十佳大学生事迹材料
2014/01/29 职场文书
远程网络教育毕业生自我鉴定
2014/04/14 职场文书
文明礼仪倡议书
2015/04/28 职场文书
网吧员工管理制度
2015/08/05 职场文书
学习委员竞选稿
2015/11/20 职场文书
怎样评估创业计划书是否有可行性?
2019/08/07 职场文书
Pytest之测试命名规则的使用
2021/04/16 Python
深入浅析python3 依赖倒置原则(示例代码)
2021/07/09 Python
Mysql调整优化之四种分区方式以及组合分区
2022/04/13 MySQL
如何vue使用el-table遍历循环表头和表体数据
2022/04/26 Vue.js