pytorch 在网络中添加可训练参数,修改预训练权重文件的方法


Posted in Python onAugust 17, 2019

实践中,针对不同的任务需求,我们经常会在现成的网络结构上做一定的修改来实现特定的目的。

假如我们现在有一个简单的两层感知机网络:

# -*- coding: utf-8 -*-
import torch
from torch.autograd import Variable
import torch.optim as optim
 
x = Variable(torch.FloatTensor([1, 2, 3])).cuda()
y = Variable(torch.FloatTensor([4, 5])).cuda()
 
class MLP(torch.nn.Module):
  def __init__(self):
    super(MLP, self).__init__()
    self.linear1 = torch.nn.Linear(3, 5)
    self.relu = torch.nn.ReLU()
    self.linear2 = torch.nn.Linear(5, 2)
 
  def forward(self, x):
    x = self.linear1(x)
    x = self.relu(x)
    x = self.linear2(x)
 
    return x
 
model = MLP().cuda()
 
loss_fn = torch.nn.MSELoss(size_average=False)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
 
for t in range(500):
  y_pred = model(x)
  loss = loss_fn(y_pred, y)
  print(t, loss.data[0])
  model.zero_grad()
  loss.backward()
  optimizer.step()
 
print(model(x))

现在想在前向传播时,在relu之后给x乘以一个可训练的系数,只需要在__init__函数中添加一个nn.Parameter类型变量,并在forward函数中乘以该变量即可:

class MLP(torch.nn.Module):
  def __init__(self):
    super(MLP, self).__init__()
    self.linear1 = torch.nn.Linear(3, 5)
    self.relu = torch.nn.ReLU()
    self.linear2 = torch.nn.Linear(5, 2)
    # the para to be added and updated in train phase, note that NO cuda() at last
    self.coefficient = torch.nn.Parameter(torch.Tensor([1.55]))
 
  def forward(self, x):
    x = self.linear1(x)
    x = self.relu(x)
    x = self.coefficient * x
    x = self.linear2(x)
 
    return x

注意,Parameter变量和Variable变量的操作大致相同,但是不能手动调用.cuda()方法将其加载在GPU上,事实上它会自动在GPU上加载,可以通过model.state_dict()或者model.named_parameters()函数查看现在的全部可训练参数(包括通过继承得到的父类中的参数):

print(model.state_dict().keys())
for i, j in model.named_parameters():
  print(i)
  print(j)

输出如下:

odict_keys(['linear1.weight', 'linear1.bias', 'linear2.weight', 'linear2.bias'])
linear1.weight
Parameter containing:
-0.3582 -0.0283 0.2607
 0.5190 -0.2221 0.0665
-0.2586 -0.3311 0.1927
-0.2765 0.5590 -0.2598
 0.4679 -0.2923 -0.3379
[torch.cuda.FloatTensor of size 5x3 (GPU 0)]
 
linear1.bias
Parameter containing:
-0.2549
-0.5246
-0.1109
 0.5237
-0.1362
[torch.cuda.FloatTensor of size 5 (GPU 0)]
 
linear2.weight
Parameter containing:
-0.0286 -0.3045 0.1928 -0.2323 0.2966
 0.2601 0.1441 -0.2159 0.2484 0.0544
[torch.cuda.FloatTensor of size 2x5 (GPU 0)]
 
linear2.bias
Parameter containing:
-0.4038
 0.3129
[torch.cuda.FloatTensor of size 2 (GPU 0)]

这个参数会在反向传播时与原有变量同时参与更新,这就达到了添加可训练参数的目的。

如果我们有原先网络的预训练权重,现在添加了一个新的参数,原有的权重文件自然就不能加载了,我们需要修改原权重文件,在其中添加我们的新变量的初始值。

调用model.state_dict查看我们添加的参数在参数字典中的完整名称,然后打开原先的权重文件:

a = torch.load("OldWeights.pth") a是一个collecitons.OrderedDict类型变量,也就是一个有序字典,直接将新参数名称和初始值作为键值对插入,然后保存即可。

a = torch.load("OldWeights.pth")
 
a["layer1.0.coefficient"] = torch.FloatTensor([1.2])
a["layer1.1.coefficient"] = torch.FloatTensor([1.5])
 
torch.save(a, "Weights.pth")

现在权重就可以加载在修改后的模型上了。

以上这篇pytorch 在网络中添加可训练参数,修改预训练权重文件的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python局域网ip扫描示例分享
Apr 03 Python
python使用fcntl模块实现程序加锁功能示例
Jun 23 Python
python文件特定行插入和替换实例详解
Jul 12 Python
Python中循环引用(import)失败的解决方法
Apr 22 Python
Python 循环语句之 while,for语句详解
Apr 23 Python
python邮件发送smtplib使用详解
Jun 16 Python
python实现猜数字小游戏
Mar 24 Python
pandas DataFrame索引行列的实现
Jun 04 Python
Python 中Django安装和使用教程详解
Jul 03 Python
Python魔法方法 容器部方法详解
Jan 02 Python
pycharm导入源码的具体步骤
Aug 04 Python
python单向链表实例详解
May 25 Python
python PyQt5/Pyside2 按钮右击菜单实例代码
Aug 17 #Python
Pytorch 实现自定义参数层的例子
Aug 17 #Python
Python中PyQt5/PySide2的按钮控件使用实例
Aug 17 #Python
画pytorch模型图,以及参数计算的方法
Aug 17 #Python
pytorch 共享参数的示例
Aug 17 #Python
Pytorch卷积层手动初始化权值的实例
Aug 17 #Python
pytorch自定义初始化权重的方法
Aug 17 #Python
You might like
用PHP读取flv文件的播放时间长度
2009/09/03 PHP
php中curl、fsocket、file_get_content三个函数的使用比较
2014/05/09 PHP
php查找指定目录下指定大小文件的方法
2014/11/28 PHP
php数据库的增删改查 php与javascript之间的交互
2017/08/31 PHP
PHP实现微信支付(jsapi支付)流程步骤详解
2018/03/15 PHP
PHP执行系统命令函数实例讲解
2021/03/03 PHP
js最简单的拖拽效果实现代码
2010/09/24 Javascript
得到form下的所有的input的js代码
2013/11/07 Javascript
jQGrid Table操作列中点击【操作】按钮弹出按钮层的实现代码
2016/12/05 Javascript
JS常见创建类的方法小结【工厂方式,构造器方式,原型方式,联合方式等】
2017/04/01 Javascript
vue实现留言板todolist功能
2017/08/16 Javascript
详解vue-cli构建项目反向代理配置
2017/09/07 Javascript
nodejs实现简单的gulp打包
2017/12/21 NodeJs
基于iScroll实现内容滚动效果
2018/03/21 Javascript
微信小程序车牌号码模拟键盘输入功能的实现代码
2018/11/11 Javascript
配置eslint规范项目代码风格
2019/03/11 Javascript
[05:48]DOTA2英雄梦之声vol21 屠夫
2014/06/20 DOTA
[02:49]DOTA2完美大师赛首日观众采访
2017/11/23 DOTA
Python的迭代器和生成器使用实例
2015/01/14 Python
Python实现二分查找算法实例
2015/05/26 Python
python实现mysql的单引号字符串过滤方法
2015/11/14 Python
Python绘制3d螺旋曲线图实例代码
2017/12/20 Python
Python实现最常见加密方式详解
2019/07/13 Python
pytorch 指定gpu训练与多gpu并行训练示例
2019/12/31 Python
matplotlib quiver箭图绘制案例
2020/04/17 Python
使用Pytorch搭建模型的步骤
2020/11/16 Python
python中scrapy处理项目数据的实例分析
2020/11/22 Python
美国环保妈妈、儿童和婴儿用品购物网站:The Tot
2019/11/24 全球购物
delegate与普通函数的区别
2014/01/22 面试题
关于工作时间玩手机的检讨书
2014/09/18 职场文书
学校机关党总支领导班子整改工作方案
2014/10/26 职场文书
服务整改报告
2014/11/06 职场文书
学生党员检讨书范文
2014/12/27 职场文书
2015年禁毒宣传活动总结
2015/03/25 职场文书
厉行节约工作总结
2015/08/12 职场文书
丧事答谢词大全
2015/09/30 职场文书