pytorch 在网络中添加可训练参数,修改预训练权重文件的方法


Posted in Python onAugust 17, 2019

实践中,针对不同的任务需求,我们经常会在现成的网络结构上做一定的修改来实现特定的目的。

假如我们现在有一个简单的两层感知机网络:

# -*- coding: utf-8 -*-
import torch
from torch.autograd import Variable
import torch.optim as optim
 
x = Variable(torch.FloatTensor([1, 2, 3])).cuda()
y = Variable(torch.FloatTensor([4, 5])).cuda()
 
class MLP(torch.nn.Module):
  def __init__(self):
    super(MLP, self).__init__()
    self.linear1 = torch.nn.Linear(3, 5)
    self.relu = torch.nn.ReLU()
    self.linear2 = torch.nn.Linear(5, 2)
 
  def forward(self, x):
    x = self.linear1(x)
    x = self.relu(x)
    x = self.linear2(x)
 
    return x
 
model = MLP().cuda()
 
loss_fn = torch.nn.MSELoss(size_average=False)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
 
for t in range(500):
  y_pred = model(x)
  loss = loss_fn(y_pred, y)
  print(t, loss.data[0])
  model.zero_grad()
  loss.backward()
  optimizer.step()
 
print(model(x))

现在想在前向传播时,在relu之后给x乘以一个可训练的系数,只需要在__init__函数中添加一个nn.Parameter类型变量,并在forward函数中乘以该变量即可:

class MLP(torch.nn.Module):
  def __init__(self):
    super(MLP, self).__init__()
    self.linear1 = torch.nn.Linear(3, 5)
    self.relu = torch.nn.ReLU()
    self.linear2 = torch.nn.Linear(5, 2)
    # the para to be added and updated in train phase, note that NO cuda() at last
    self.coefficient = torch.nn.Parameter(torch.Tensor([1.55]))
 
  def forward(self, x):
    x = self.linear1(x)
    x = self.relu(x)
    x = self.coefficient * x
    x = self.linear2(x)
 
    return x

注意,Parameter变量和Variable变量的操作大致相同,但是不能手动调用.cuda()方法将其加载在GPU上,事实上它会自动在GPU上加载,可以通过model.state_dict()或者model.named_parameters()函数查看现在的全部可训练参数(包括通过继承得到的父类中的参数):

print(model.state_dict().keys())
for i, j in model.named_parameters():
  print(i)
  print(j)

输出如下:

odict_keys(['linear1.weight', 'linear1.bias', 'linear2.weight', 'linear2.bias'])
linear1.weight
Parameter containing:
-0.3582 -0.0283 0.2607
 0.5190 -0.2221 0.0665
-0.2586 -0.3311 0.1927
-0.2765 0.5590 -0.2598
 0.4679 -0.2923 -0.3379
[torch.cuda.FloatTensor of size 5x3 (GPU 0)]
 
linear1.bias
Parameter containing:
-0.2549
-0.5246
-0.1109
 0.5237
-0.1362
[torch.cuda.FloatTensor of size 5 (GPU 0)]
 
linear2.weight
Parameter containing:
-0.0286 -0.3045 0.1928 -0.2323 0.2966
 0.2601 0.1441 -0.2159 0.2484 0.0544
[torch.cuda.FloatTensor of size 2x5 (GPU 0)]
 
linear2.bias
Parameter containing:
-0.4038
 0.3129
[torch.cuda.FloatTensor of size 2 (GPU 0)]

这个参数会在反向传播时与原有变量同时参与更新,这就达到了添加可训练参数的目的。

如果我们有原先网络的预训练权重,现在添加了一个新的参数,原有的权重文件自然就不能加载了,我们需要修改原权重文件,在其中添加我们的新变量的初始值。

调用model.state_dict查看我们添加的参数在参数字典中的完整名称,然后打开原先的权重文件:

a = torch.load("OldWeights.pth") a是一个collecitons.OrderedDict类型变量,也就是一个有序字典,直接将新参数名称和初始值作为键值对插入,然后保存即可。

a = torch.load("OldWeights.pth")
 
a["layer1.0.coefficient"] = torch.FloatTensor([1.2])
a["layer1.1.coefficient"] = torch.FloatTensor([1.5])
 
torch.save(a, "Weights.pth")

现在权重就可以加载在修改后的模型上了。

以上这篇pytorch 在网络中添加可训练参数,修改预训练权重文件的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python获取文件ssdeep值的方法
Oct 05 Python
利用matplotlib+numpy绘制多种绘图的方法实例
May 03 Python
Python算术运算符实例详解
May 31 Python
python学习入门细节知识点
Mar 29 Python
利用OpenCV和Python实现查找图片差异
Dec 19 Python
Python读取配置文件(config.ini)以及写入配置文件
Apr 08 Python
使用PyQt5实现图片查看器的示例代码
Apr 21 Python
python实现密度聚类(模板代码+sklearn代码)
Apr 27 Python
使用 django orm 写 exists 条件过滤实例
May 20 Python
用python查找统一局域网下ip对应的mac地址
Jan 13 Python
教你使用TensorFlow2识别验证码
Jun 11 Python
Python学习之异常中的finally使用详解
Mar 16 Python
python PyQt5/Pyside2 按钮右击菜单实例代码
Aug 17 #Python
Pytorch 实现自定义参数层的例子
Aug 17 #Python
Python中PyQt5/PySide2的按钮控件使用实例
Aug 17 #Python
画pytorch模型图,以及参数计算的方法
Aug 17 #Python
pytorch 共享参数的示例
Aug 17 #Python
Pytorch卷积层手动初始化权值的实例
Aug 17 #Python
pytorch自定义初始化权重的方法
Aug 17 #Python
You might like
在PHP中读取和写入WORD文档的代码
2008/04/09 PHP
PHP中文分词的简单实现代码分享
2011/07/17 PHP
ThinkPHP分组下自定义标签库实例
2014/11/01 PHP
自编函数解决pathinfo()函数处理中文问题
2014/11/03 PHP
PHP中通过fopen()函数访问远程文件示例
2014/11/18 PHP
PHP实现通过URL提取根域名
2016/03/31 PHP
thinkPHP实现递归循环栏目并按照树形结构无限极输出的方法
2016/05/19 PHP
jquery tab插件制作实现代码
2010/06/22 Javascript
基于jQuery的input输入框下拉提示层(自动邮箱后缀名)
2012/06/14 Javascript
jquery animate实现鼠标放上去显示离开隐藏效果
2013/07/21 Javascript
JS+CSS实现自动改变切换方向图片幻灯切换效果的方法
2015/03/02 Javascript
深入理解JavaScript系列(38):设计模式之职责链模式详解
2015/03/04 Javascript
js实现左侧网页tab滑动门效果代码
2015/09/06 Javascript
js和jquery中获取非行间样式
2017/05/05 jQuery
BootStrap Fileinput上传插件使用实例代码
2017/07/28 Javascript
vue.js声明式渲染和条件与循环基础知识
2017/07/31 Javascript
vue-prop父组件向子组件进行传值的方法
2018/03/01 Javascript
Vue插件打包与发布的方法示例
2018/08/20 Javascript
webpack4手动搭建Vue开发环境实现todoList项目的方法
2019/05/16 Javascript
vue  elementUI 表单嵌套验证的实例代码
2019/11/06 Javascript
vue中js判断长时间不操作界面自动退出登录(推荐)
2020/01/22 Javascript
[01:08:30]DOTA2-DPC中国联赛 正赛 Ehome vs Elephant BO3 第一场 2月28日
2021/03/11 DOTA
python通过socket查询whois的方法
2015/07/18 Python
pygame游戏之旅 添加游戏介绍
2018/11/20 Python
对pandas的算术运算和数据对齐实例详解
2018/12/22 Python
Django命名URL和反向解析URL实现解析
2019/08/09 Python
Python imageio读取视频并进行编解码详解
2019/12/10 Python
Python常用base64 md5 aes des crc32加密解密方法汇总
2020/11/06 Python
宝塔面板出现“open_basedir restriction in effect. ”的解决方法
2021/03/14 PHP
北美主要的汽车零部件零售商:AutoShack.com
2019/02/23 全球购物
俄罗斯香水和化妆品在线商店:Aroma-butik
2020/02/28 全球购物
担保书范文
2015/01/20 职场文书
left join、inner join、right join的区别
2021/04/05 MySQL
解决Maven项目中 Invalid bound statement 无效的绑定问题
2021/06/15 Java/Android
SpringDataJPA实体类关系映射配置方式
2021/12/06 Java/Android
Java存储没有重复元素的数组
2022/04/29 Java/Android