pytorch 在网络中添加可训练参数,修改预训练权重文件的方法


Posted in Python onAugust 17, 2019

实践中,针对不同的任务需求,我们经常会在现成的网络结构上做一定的修改来实现特定的目的。

假如我们现在有一个简单的两层感知机网络:

# -*- coding: utf-8 -*-
import torch
from torch.autograd import Variable
import torch.optim as optim
 
x = Variable(torch.FloatTensor([1, 2, 3])).cuda()
y = Variable(torch.FloatTensor([4, 5])).cuda()
 
class MLP(torch.nn.Module):
  def __init__(self):
    super(MLP, self).__init__()
    self.linear1 = torch.nn.Linear(3, 5)
    self.relu = torch.nn.ReLU()
    self.linear2 = torch.nn.Linear(5, 2)
 
  def forward(self, x):
    x = self.linear1(x)
    x = self.relu(x)
    x = self.linear2(x)
 
    return x
 
model = MLP().cuda()
 
loss_fn = torch.nn.MSELoss(size_average=False)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
 
for t in range(500):
  y_pred = model(x)
  loss = loss_fn(y_pred, y)
  print(t, loss.data[0])
  model.zero_grad()
  loss.backward()
  optimizer.step()
 
print(model(x))

现在想在前向传播时,在relu之后给x乘以一个可训练的系数,只需要在__init__函数中添加一个nn.Parameter类型变量,并在forward函数中乘以该变量即可:

class MLP(torch.nn.Module):
  def __init__(self):
    super(MLP, self).__init__()
    self.linear1 = torch.nn.Linear(3, 5)
    self.relu = torch.nn.ReLU()
    self.linear2 = torch.nn.Linear(5, 2)
    # the para to be added and updated in train phase, note that NO cuda() at last
    self.coefficient = torch.nn.Parameter(torch.Tensor([1.55]))
 
  def forward(self, x):
    x = self.linear1(x)
    x = self.relu(x)
    x = self.coefficient * x
    x = self.linear2(x)
 
    return x

注意,Parameter变量和Variable变量的操作大致相同,但是不能手动调用.cuda()方法将其加载在GPU上,事实上它会自动在GPU上加载,可以通过model.state_dict()或者model.named_parameters()函数查看现在的全部可训练参数(包括通过继承得到的父类中的参数):

print(model.state_dict().keys())
for i, j in model.named_parameters():
  print(i)
  print(j)

输出如下:

odict_keys(['linear1.weight', 'linear1.bias', 'linear2.weight', 'linear2.bias'])
linear1.weight
Parameter containing:
-0.3582 -0.0283 0.2607
 0.5190 -0.2221 0.0665
-0.2586 -0.3311 0.1927
-0.2765 0.5590 -0.2598
 0.4679 -0.2923 -0.3379
[torch.cuda.FloatTensor of size 5x3 (GPU 0)]
 
linear1.bias
Parameter containing:
-0.2549
-0.5246
-0.1109
 0.5237
-0.1362
[torch.cuda.FloatTensor of size 5 (GPU 0)]
 
linear2.weight
Parameter containing:
-0.0286 -0.3045 0.1928 -0.2323 0.2966
 0.2601 0.1441 -0.2159 0.2484 0.0544
[torch.cuda.FloatTensor of size 2x5 (GPU 0)]
 
linear2.bias
Parameter containing:
-0.4038
 0.3129
[torch.cuda.FloatTensor of size 2 (GPU 0)]

这个参数会在反向传播时与原有变量同时参与更新,这就达到了添加可训练参数的目的。

如果我们有原先网络的预训练权重,现在添加了一个新的参数,原有的权重文件自然就不能加载了,我们需要修改原权重文件,在其中添加我们的新变量的初始值。

调用model.state_dict查看我们添加的参数在参数字典中的完整名称,然后打开原先的权重文件:

a = torch.load("OldWeights.pth") a是一个collecitons.OrderedDict类型变量,也就是一个有序字典,直接将新参数名称和初始值作为键值对插入,然后保存即可。

a = torch.load("OldWeights.pth")
 
a["layer1.0.coefficient"] = torch.FloatTensor([1.2])
a["layer1.1.coefficient"] = torch.FloatTensor([1.5])
 
torch.save(a, "Weights.pth")

现在权重就可以加载在修改后的模型上了。

以上这篇pytorch 在网络中添加可训练参数,修改预训练权重文件的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python笔记(2)
Oct 24 Python
Python遍历目录中的所有文件的方法
Jul 08 Python
python 网络编程详解及简单实例
Apr 25 Python
Python常用字符串替换函数strip、replace及sub用法示例
May 21 Python
Python测试线程应用程序过程解析
Dec 31 Python
Python基于QQ邮箱实现SSL发送
Apr 26 Python
Python xlwt模块使用代码实例
Jun 10 Python
基于Python的自媒体小助手---登录页面的实现代码
Jun 29 Python
解决python运行效率不高的问题
Jul 20 Python
Python如何在单元测试中给对象打补丁
Aug 03 Python
利用Python pandas对Excel进行合并的方法示例
Nov 04 Python
Django中的JWT身份验证的实现
May 07 Python
python PyQt5/Pyside2 按钮右击菜单实例代码
Aug 17 #Python
Pytorch 实现自定义参数层的例子
Aug 17 #Python
Python中PyQt5/PySide2的按钮控件使用实例
Aug 17 #Python
画pytorch模型图,以及参数计算的方法
Aug 17 #Python
pytorch 共享参数的示例
Aug 17 #Python
Pytorch卷积层手动初始化权值的实例
Aug 17 #Python
pytorch自定义初始化权重的方法
Aug 17 #Python
You might like
改德生G88 - 加装等响度低音提升电路
2021/03/02 无线电
漂亮但不安全的CTB
2006/10/09 PHP
利用PHP函数计算中英文字符串长度的方法
2014/11/11 PHP
php自定义错误处理用法实例
2015/03/20 PHP
PHP实现原比例生成缩略图的方法
2016/02/03 PHP
在云虚拟主机部署thinkphp5项目的步骤详解
2017/12/21 PHP
PHP生成指定范围内的N个不重复的随机数
2019/03/18 PHP
Yii 框架控制器创建使用及控制器响应操作示例
2019/10/14 PHP
asp批量修改记录的代码
2008/06/25 Javascript
FileUpload上传图片(图片不变形)
2010/08/05 Javascript
javascript获取鼠标位置部分的实例代码(兼容IE,FF)
2013/08/05 Javascript
一个奇葩的最短的 IE 版本判断JS脚本
2014/05/28 Javascript
JQuery实现表格动态增加行并对新行添加事件
2014/07/30 Javascript
JavaScript中浅讲ajax图文详解
2016/11/11 Javascript
Three.js利用dat.GUI如何简化试验流程详解
2017/09/26 Javascript
JavaScript适配器模式详解
2017/10/19 Javascript
webpack 4.0.0-beta.0版本新特性介绍
2018/02/10 Javascript
vue 微信授权登录解决方案
2018/04/10 Javascript
详解Vue CLI3配置之filenameHashing使用和源码设计使用和源码设计
2018/08/31 Javascript
详解vue中使用transition和animation的实例代码
2020/12/12 Vue.js
在Python中marshal对象序列化的相关知识
2015/07/01 Python
python 获取当天每个准点时间戳的实例
2018/05/22 Python
对python插入数据库和生成插入sql的示例讲解
2018/11/14 Python
使用python进行拆分大文件的方法
2018/12/10 Python
Python+OpenCV+图片旋转并用原底色填充新四角的例子
2019/12/12 Python
python 用pandas实现数据透视表功能
2020/12/21 Python
使用canvas压缩图片大小的方法示例
2019/08/02 HTML / CSS
智利最大的网上商店:Linio智利
2016/11/24 全球购物
德国狗狗用品在线商店:Schecker
2017/03/17 全球购物
工商治理实习生的自我评价分享
2014/02/20 职场文书
旅游文化节策划方案
2014/06/06 职场文书
国际贸易系求职信
2014/08/09 职场文书
2015试用期转正工作总结
2014/12/12 职场文书
庆六一开幕词
2015/01/29 职场文书
golang 实现Location跳转方式
2021/05/02 Golang
css如何把元素固定在容器底部的四种方式
2022/06/16 HTML / CSS