Pytorch中实现只导入部分模型参数的方式


Posted in Python onJanuary 02, 2020

我们在做迁移学习,或者在分割,检测等任务想使用预训练好的模型,同时又有自己修改之后的结构,使得模型文件保存的参数,有一部分是不需要的(don't expected)。我们搭建的网络对保存文件来说,有一部分参数也是没有的(missed)。如果依旧使用torch.load(model.state_dict())的办法,就会出现 xxx expected,xxx missed类似的错误。那么在这种情况下,该如何导入模型呢?

好在Pytorch中的模型参数使用字典保存的,键是参数的名称,值是参数的具体数值。我们使用model.state_dict()获得这个字典,之后就能利用参数名称来实现导入。

请看下面的一个例子。

我们先搭建一个小小的网络。

import torch as t
from torch.nn import Module
from torch import nn
from torch.nn import functional as F
class Net(Module):
  def __init__(self):
    super(Net,self).__init__()
    self.conv1 = nn.Conv2d(3,32,3,1)
    self.conv2 = nn.Conv2d(32,3,3,1)
    self.w = nn.Parameter(t.randn(3,10))
    for p in self.children():
      nn.init.xavier_normal_(p.weight.data)
      nn.init.constant_(p.bias.data, 0)
  def forward(self, x):
    out = self.conv1(x)
    out = self.conv2(x)
 
    out = F.avg_pool2d(out,(out.shape[2],out.shape[3]))
    out = F.linear(out,weight=self.w)
    return out

然后我们保存这个网络的初始值。

model = Net()
t.save(model.state_dict(),'xxx.pth')

现在我们将Net修改一下,多加几个卷积层,但并不加入到forward中,仅仅出于少些几行的目的。

import torch as t
from torch.nn import Module
from torch import nn
from torch.nn import functional as F
 
 
class Net(Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(3, 32, 3, 1)
    self.conv2 = nn.Conv2d(32, 3, 3, 1)
    self.conv3 = nn.Conv2d(3,64,3,1)
    self.conv4 = nn.Conv2d(64,32,3,1)
    for p in self.children():
      nn.init.xavier_normal_(p.weight.data)
      nn.init.constant_(p.bias.data, 0)
 
    self.w = nn.Parameter(t.randn(3, 10))
  def forward(self, x):
    out = self.conv1(x)
    out = self.conv2(x)
 
    out = F.avg_pool2d(out, (out.shape[2], out.shape[3]))
    out = F.linear(out, weight=self.w)
    return out

我们现在试着导入之前保存的模型参数。

path = 'xxx.pth'
model = Net()
model.load_state_dict(t.load(path))
 
'''
RuntimeError: Error(s) in loading state_dict for Net:
 Missing key(s) in state_dict: "conv3.weight", "conv3.bias", "conv4.weight", "conv4.bias". 
'''

出现了没有在模型文件中找到error中的关键字的错误。

现在我们这样导入模型

path = 'xxx.pth'
model = Net()
save_model = t.load(path)
model_dict = model.state_dict()
state_dict = {k:v for k,v in save_model.items() if k in model_dict.keys()}
print(state_dict.keys()) # dict_keys(['w', 'conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias'])
model_dict.update(state_dict)
model.load_state_dict(model_dict)

看看上面的代码,很容易弄明白。其中model_dict.update的作用是更新代码中搭建的模型参数字典。为啥更新我其实并不清楚,但这一步骤是必须的,否则还会报错。

为了弄清楚为什么要更新model_dict,我们不妨分别输出state_dict和model_dict的关键值看一看。

for k in state_dict.keys():
  print(k)
 
'''
w
conv1.weight
conv1.bias
conv2.weight
conv2.bias
'''
for k in model_dict.keys():
  print(k)
 
'''
w
conv1.weight
conv1.bias
conv2.weight
conv2.bias
conv3.weight
conv3.bias
conv4.weight
conv4.bias
'''

这个结果也是预料之中的,所以我猜测,update之后,model_dict和state_dict中具有相同键的值已经同步了。updata的目的就是使model_dict带有state_dict中都具有的那一部分参数的值,对于model_dict中有的,但是save_dict中没有的参数,值不改变,参数仍然使用初始值。

以上这篇Pytorch中实现只导入部分模型参数的方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
php使用递归与迭代实现快速排序示例
Jan 23 Python
Python中的类学习笔记
Sep 23 Python
python轻松查到删除自己的微信好友
Jan 10 Python
Django小白教程之Django用户注册与登录
Apr 22 Python
Python写的一个定时重跑获取数据库数据
Dec 28 Python
Python基于lxml模块解析html获取页面内所有叶子节点xpath路径功能示例
May 16 Python
Python装饰器知识点补充
May 28 Python
Django 开发调试工具 Django-debug-toolbar使用详解
Jul 23 Python
Python sep参数使用方法详解
Feb 12 Python
不到20行实现Python代码即可制作精美证件照
Apr 24 Python
Django 如何使用日期时间选择器规范用户的时间输入示例代码详解
May 22 Python
Python创建简单的神经网络实例讲解
Jan 04 Python
PyTorch中topk函数的用法详解
Jan 02 #Python
Pytorch训练过程出现nan的解决方式
Jan 02 #Python
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
Jan 02 #Python
基于MSELoss()与CrossEntropyLoss()的区别详解
Jan 02 #Python
python使用SQLAlchemy操作MySQL
Jan 02 #Python
pytorch 实现cross entropy损失函数计算方式
Jan 02 #Python
Matplotlib scatter绘制散点图的方法实现
Jan 02 #Python
You might like
PHP获取当前url的具体方法全面解析
2013/11/26 PHP
PHP实现下载远程图片保存到本地的方法
2017/06/19 PHP
PHP实现打包zip并下载功能
2018/06/12 PHP
深入解析PHP底层机制及相关原理
2020/12/11 PHP
JS获取scrollHeight问题想到的标准问题
2007/05/27 Javascript
javascript控制swfObject应用介绍
2012/11/29 Javascript
JS可以控制样式的名称写法一览
2014/01/16 Javascript
js使用html()或text()方法获取设置p标签的显示的值
2014/08/01 Javascript
JavaScript弹出窗口方法汇总
2014/08/12 Javascript
兼容主流浏览器的jQuery+CSS 实现遮罩层的简单代码
2014/10/14 Javascript
Bootstrap Table使用心得总结
2016/11/29 Javascript
javascript数据类型详解
2017/02/07 Javascript
jQuery中layer分页器的使用
2017/03/13 Javascript
vue使用echarts图表的详细方法
2018/10/22 Javascript
微信小程序引用iconfont图标的方法
2018/10/22 Javascript
Vue.js 事件修饰符的使用教程
2018/11/01 Javascript
js实现单元格拖拽效果
2020/02/10 Javascript
浅谈vue 二级路由嵌套和二级路由高亮问题
2020/08/06 Javascript
Python中Class类用法实例分析
2015/11/12 Python
Using Django with GAE Python 后台抓取多个网站的页面全文
2016/02/17 Python
python实现监控某个服务 服务崩溃即发送邮件报告
2018/06/21 Python
Python爬虫之正则表达式基本用法实例分析
2018/08/08 Python
Python+Pyqt实现简单GUI电子时钟
2021/02/22 Python
解决Django 在ForeignKey中出现 non-nullable field错误的问题
2019/08/06 Python
解决pycharm 安装numpy失败的问题
2019/12/05 Python
Python3 把一个列表按指定数目分成多个列表的方式
2019/12/25 Python
Python使用qrcode二维码库生成二维码方法详解
2020/02/17 Python
Django单元测试中Fixtures的使用方法
2020/02/26 Python
python为什么要安装到c盘
2020/07/20 Python
Python 利用Entrez库筛选下载PubMed文献摘要的示例
2020/11/24 Python
东方电视购物:东方CJ
2016/10/12 全球购物
印尼太阳百货公司网站:Matahari
2018/02/04 全球购物
高中政治教学反思
2014/01/18 职场文书
大学运动会入场词
2014/02/22 职场文书
教师自我剖析材料(群众路线)
2014/09/29 职场文书
Mysql8.0递归查询的简单用法示例
2021/08/04 MySQL