Pytorch中实现只导入部分模型参数的方式


Posted in Python onJanuary 02, 2020

我们在做迁移学习,或者在分割,检测等任务想使用预训练好的模型,同时又有自己修改之后的结构,使得模型文件保存的参数,有一部分是不需要的(don't expected)。我们搭建的网络对保存文件来说,有一部分参数也是没有的(missed)。如果依旧使用torch.load(model.state_dict())的办法,就会出现 xxx expected,xxx missed类似的错误。那么在这种情况下,该如何导入模型呢?

好在Pytorch中的模型参数使用字典保存的,键是参数的名称,值是参数的具体数值。我们使用model.state_dict()获得这个字典,之后就能利用参数名称来实现导入。

请看下面的一个例子。

我们先搭建一个小小的网络。

import torch as t
from torch.nn import Module
from torch import nn
from torch.nn import functional as F
class Net(Module):
  def __init__(self):
    super(Net,self).__init__()
    self.conv1 = nn.Conv2d(3,32,3,1)
    self.conv2 = nn.Conv2d(32,3,3,1)
    self.w = nn.Parameter(t.randn(3,10))
    for p in self.children():
      nn.init.xavier_normal_(p.weight.data)
      nn.init.constant_(p.bias.data, 0)
  def forward(self, x):
    out = self.conv1(x)
    out = self.conv2(x)
 
    out = F.avg_pool2d(out,(out.shape[2],out.shape[3]))
    out = F.linear(out,weight=self.w)
    return out

然后我们保存这个网络的初始值。

model = Net()
t.save(model.state_dict(),'xxx.pth')

现在我们将Net修改一下,多加几个卷积层,但并不加入到forward中,仅仅出于少些几行的目的。

import torch as t
from torch.nn import Module
from torch import nn
from torch.nn import functional as F
 
 
class Net(Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(3, 32, 3, 1)
    self.conv2 = nn.Conv2d(32, 3, 3, 1)
    self.conv3 = nn.Conv2d(3,64,3,1)
    self.conv4 = nn.Conv2d(64,32,3,1)
    for p in self.children():
      nn.init.xavier_normal_(p.weight.data)
      nn.init.constant_(p.bias.data, 0)
 
    self.w = nn.Parameter(t.randn(3, 10))
  def forward(self, x):
    out = self.conv1(x)
    out = self.conv2(x)
 
    out = F.avg_pool2d(out, (out.shape[2], out.shape[3]))
    out = F.linear(out, weight=self.w)
    return out

我们现在试着导入之前保存的模型参数。

path = 'xxx.pth'
model = Net()
model.load_state_dict(t.load(path))
 
'''
RuntimeError: Error(s) in loading state_dict for Net:
 Missing key(s) in state_dict: "conv3.weight", "conv3.bias", "conv4.weight", "conv4.bias". 
'''

出现了没有在模型文件中找到error中的关键字的错误。

现在我们这样导入模型

path = 'xxx.pth'
model = Net()
save_model = t.load(path)
model_dict = model.state_dict()
state_dict = {k:v for k,v in save_model.items() if k in model_dict.keys()}
print(state_dict.keys()) # dict_keys(['w', 'conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias'])
model_dict.update(state_dict)
model.load_state_dict(model_dict)

看看上面的代码,很容易弄明白。其中model_dict.update的作用是更新代码中搭建的模型参数字典。为啥更新我其实并不清楚,但这一步骤是必须的,否则还会报错。

为了弄清楚为什么要更新model_dict,我们不妨分别输出state_dict和model_dict的关键值看一看。

for k in state_dict.keys():
  print(k)
 
'''
w
conv1.weight
conv1.bias
conv2.weight
conv2.bias
'''
for k in model_dict.keys():
  print(k)
 
'''
w
conv1.weight
conv1.bias
conv2.weight
conv2.bias
conv3.weight
conv3.bias
conv4.weight
conv4.bias
'''

这个结果也是预料之中的,所以我猜测,update之后,model_dict和state_dict中具有相同键的值已经同步了。updata的目的就是使model_dict带有state_dict中都具有的那一部分参数的值,对于model_dict中有的,但是save_dict中没有的参数,值不改变,参数仍然使用初始值。

以上这篇Pytorch中实现只导入部分模型参数的方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python库urllib与urllib2主要区别分析
Jul 13 Python
python爬取51job中hr的邮箱
May 14 Python
python中MethodType方法介绍与使用示例
Aug 03 Python
使用Eclipse如何开发python脚本
Apr 11 Python
Django框架自定义模型管理器与元选项用法分析
Jul 22 Python
详解如何在cmd命令窗口中搭建简单的python开发环境
Aug 29 Python
pygame实现俄罗斯方块游戏(基础篇3)
Oct 29 Python
解决pyshp UnicodeDecodeError的问题
Dec 06 Python
django API 中接口的互相调用实例
Apr 01 Python
python 一维二维插值实例
Apr 22 Python
opencv 图像滤波(均值,方框,高斯,中值)
Jul 08 Python
Pytorch GPU内存占用很高,但是利用率很低如何解决
Jun 01 Python
PyTorch中topk函数的用法详解
Jan 02 #Python
Pytorch训练过程出现nan的解决方式
Jan 02 #Python
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
Jan 02 #Python
基于MSELoss()与CrossEntropyLoss()的区别详解
Jan 02 #Python
python使用SQLAlchemy操作MySQL
Jan 02 #Python
pytorch 实现cross entropy损失函数计算方式
Jan 02 #Python
Matplotlib scatter绘制散点图的方法实现
Jan 02 #Python
You might like
PHP中的类-什么叫类
2006/11/20 PHP
PHP冒泡排序算法代码详细解读
2011/07/17 PHP
探讨捕获php错误信息方法的详解
2013/06/09 PHP
Laravel学习教程之model validation的使用示例
2017/10/23 PHP
PHP count()函数讲解
2019/02/03 PHP
ThinkPHP5.1+Ajax实现的无刷新分页功能示例
2020/02/10 PHP
js+FSO遍历文件夹下文件并显示
2007/03/07 Javascript
基于MVC3方式实现下拉列表联动(JQuery)
2013/09/02 Javascript
JavaScript调用ajax获取文本文件内容实现代码
2014/03/28 Javascript
JS中的进制转换以及作用
2016/06/26 Javascript
Vue.js动态组件解析
2016/09/09 Javascript
jQuery实现选项卡功能(两种方法)
2017/03/08 Javascript
AngularJS动态添加数据并删除的实例
2018/02/27 Javascript
浅谈Koa2框架利用CORS完成跨域ajax请求
2018/03/06 Javascript
Javasript设计模式之链式调用详解
2018/04/26 Javascript
一秒学会微信小程序制作table表格
2019/02/14 Javascript
JavaScript forEach中return失效问题解决方案
2020/06/01 Javascript
[38:54]完美世界DOTA2联赛PWL S2 Rebirth vs LBZS 第一场 11.28
2020/12/01 DOTA
Django在Win7下的安装及创建项目hello word简明教程
2014/07/14 Python
Python使用PyGreSQL操作PostgreSQL数据库教程
2014/07/30 Python
Python中的左斜杠、右斜杠(正斜杠和反斜杠)
2016/08/30 Python
Python排序搜索基本算法之希尔排序实例分析
2017/12/09 Python
python爬虫_实现校园网自动重连脚本的教程
2018/04/22 Python
python覆盖写入,追加写入的实例
2019/06/26 Python
简单了解python 生成器 列表推导式 生成器表达式
2019/08/22 Python
超实用的 30 段 Python 案例
2019/10/10 Python
Python实现子类调用父类的初始化实例
2020/03/12 Python
基于python获取本地时间并转换时间戳和日期格式
2020/10/27 Python
CSS Grid布局教程之什么是网格布局
2014/12/30 HTML / CSS
海滩咖啡馆:Beach Cafe
2018/02/02 全球购物
俄罗斯购买自行车网站:Vamvelosiped
2021/01/29 全球购物
安全生产网格化管理实施方案
2014/03/01 职场文书
2016年班主任新年寄语
2015/08/18 职场文书
社区志愿者服务心得体会
2016/01/22 职场文书
numpy数据类型dtype转换实现
2021/04/24 Python
pandas:get_dummies()与pd.factorize()的用法及区别说明
2021/05/21 Python