pytorch 在网络中添加可训练参数,修改预训练权重文件的方法


Posted in Python onAugust 17, 2019

实践中,针对不同的任务需求,我们经常会在现成的网络结构上做一定的修改来实现特定的目的。

假如我们现在有一个简单的两层感知机网络:

# -*- coding: utf-8 -*-
import torch
from torch.autograd import Variable
import torch.optim as optim
 
x = Variable(torch.FloatTensor([1, 2, 3])).cuda()
y = Variable(torch.FloatTensor([4, 5])).cuda()
 
class MLP(torch.nn.Module):
  def __init__(self):
    super(MLP, self).__init__()
    self.linear1 = torch.nn.Linear(3, 5)
    self.relu = torch.nn.ReLU()
    self.linear2 = torch.nn.Linear(5, 2)
 
  def forward(self, x):
    x = self.linear1(x)
    x = self.relu(x)
    x = self.linear2(x)
 
    return x
 
model = MLP().cuda()
 
loss_fn = torch.nn.MSELoss(size_average=False)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
 
for t in range(500):
  y_pred = model(x)
  loss = loss_fn(y_pred, y)
  print(t, loss.data[0])
  model.zero_grad()
  loss.backward()
  optimizer.step()
 
print(model(x))

现在想在前向传播时,在relu之后给x乘以一个可训练的系数,只需要在__init__函数中添加一个nn.Parameter类型变量,并在forward函数中乘以该变量即可:

class MLP(torch.nn.Module):
  def __init__(self):
    super(MLP, self).__init__()
    self.linear1 = torch.nn.Linear(3, 5)
    self.relu = torch.nn.ReLU()
    self.linear2 = torch.nn.Linear(5, 2)
    # the para to be added and updated in train phase, note that NO cuda() at last
    self.coefficient = torch.nn.Parameter(torch.Tensor([1.55]))
 
  def forward(self, x):
    x = self.linear1(x)
    x = self.relu(x)
    x = self.coefficient * x
    x = self.linear2(x)
 
    return x

注意,Parameter变量和Variable变量的操作大致相同,但是不能手动调用.cuda()方法将其加载在GPU上,事实上它会自动在GPU上加载,可以通过model.state_dict()或者model.named_parameters()函数查看现在的全部可训练参数(包括通过继承得到的父类中的参数):

print(model.state_dict().keys())
for i, j in model.named_parameters():
  print(i)
  print(j)

输出如下:

odict_keys(['linear1.weight', 'linear1.bias', 'linear2.weight', 'linear2.bias'])
linear1.weight
Parameter containing:
-0.3582 -0.0283 0.2607
 0.5190 -0.2221 0.0665
-0.2586 -0.3311 0.1927
-0.2765 0.5590 -0.2598
 0.4679 -0.2923 -0.3379
[torch.cuda.FloatTensor of size 5x3 (GPU 0)]
 
linear1.bias
Parameter containing:
-0.2549
-0.5246
-0.1109
 0.5237
-0.1362
[torch.cuda.FloatTensor of size 5 (GPU 0)]
 
linear2.weight
Parameter containing:
-0.0286 -0.3045 0.1928 -0.2323 0.2966
 0.2601 0.1441 -0.2159 0.2484 0.0544
[torch.cuda.FloatTensor of size 2x5 (GPU 0)]
 
linear2.bias
Parameter containing:
-0.4038
 0.3129
[torch.cuda.FloatTensor of size 2 (GPU 0)]

这个参数会在反向传播时与原有变量同时参与更新,这就达到了添加可训练参数的目的。

如果我们有原先网络的预训练权重,现在添加了一个新的参数,原有的权重文件自然就不能加载了,我们需要修改原权重文件,在其中添加我们的新变量的初始值。

调用model.state_dict查看我们添加的参数在参数字典中的完整名称,然后打开原先的权重文件:

a = torch.load("OldWeights.pth") a是一个collecitons.OrderedDict类型变量,也就是一个有序字典,直接将新参数名称和初始值作为键值对插入,然后保存即可。

a = torch.load("OldWeights.pth")
 
a["layer1.0.coefficient"] = torch.FloatTensor([1.2])
a["layer1.1.coefficient"] = torch.FloatTensor([1.5])
 
torch.save(a, "Weights.pth")

现在权重就可以加载在修改后的模型上了。

以上这篇pytorch 在网络中添加可训练参数,修改预训练权重文件的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用正则表达式检测密码强度源码分享
Jun 11 Python
对python多线程中Lock()与RLock()锁详解
Jan 11 Python
Python基础学习之基本数据结构详解【数字、字符串、列表、元组、集合、字典】
Jun 18 Python
Pyqt5实现英文学习词典
Jun 24 Python
Python符号计算之实现函数极限的方法
Jul 15 Python
Python 异常的捕获、异常的传递与主动抛出异常操作示例
Sep 23 Python
在Python中实现函数重载的示例代码
Dec 12 Python
Flask框架搭建虚拟环境的步骤分析
Dec 21 Python
Python读取表格类型文件代码实例
Feb 17 Python
Python 爬虫性能相关总结
Aug 03 Python
Python将CSV文件转化为HTML文件的操作方法
Jun 30 Python
python实现简单聊天功能
Jul 07 Python
python PyQt5/Pyside2 按钮右击菜单实例代码
Aug 17 #Python
Pytorch 实现自定义参数层的例子
Aug 17 #Python
Python中PyQt5/PySide2的按钮控件使用实例
Aug 17 #Python
画pytorch模型图,以及参数计算的方法
Aug 17 #Python
pytorch 共享参数的示例
Aug 17 #Python
Pytorch卷积层手动初始化权值的实例
Aug 17 #Python
pytorch自定义初始化权重的方法
Aug 17 #Python
You might like
PHP 设计模式之观察者模式介绍
2012/02/22 PHP
thinkphp连贯操作实例分析
2014/11/22 PHP
PHP获取当前所在目录位置的方法
2014/11/26 PHP
php+js实现的无刷新下载文件功能示例
2019/08/23 PHP
SwfUpload在IE10上不出现上传按钮的解决方法
2013/06/25 Javascript
javascript表单验证使用示例(javascript验证邮箱)
2014/01/07 Javascript
js 获取、清空input type="file"的值示例代码
2014/02/19 Javascript
angular简介和其特点介绍
2015/01/29 Javascript
介绍JavaScript的一个微型模版
2015/06/24 Javascript
JavaScript实现点击按钮切换网页背景色的方法
2015/10/17 Javascript
AngularJS初始化静态模板详解
2016/01/14 Javascript
jQuery实现的简单百分比进度条效果示例
2016/08/01 Javascript
javascript 动态脚本添加的简单方法
2016/10/11 Javascript
jQuery和JavaScript节点插入元素的方法对比
2016/11/18 Javascript
Angular(5.2->6.1)升级小结
2018/12/27 Javascript
[38:21]2018DOTA2亚洲邀请赛3月30日 小组赛A组 LGD VS Newbee
2018/03/31 DOTA
机器学习python实战之决策树
2017/11/01 Python
Python遍历pandas数据方法总结
2018/02/09 Python
python购物车程序简单代码
2018/04/18 Python
numpy.linspace函数具体使用详解
2019/05/27 Python
python中的selenium安装的步骤(浏览器自动化测试框架)
2020/03/17 Python
VSCode配合pipenv搞定虚拟环境的实现方法
2020/05/17 Python
结合CSS3的新特性来总结垂直居中的实现方法
2016/05/30 HTML / CSS
canvas与html5实现视频截图功能示例
2016/12/15 HTML / CSS
澳大利亚在线百货商店:Real Smart
2017/08/13 全球购物
欧姆龙医疗欧洲有限公司:Omron Healthcare Europe B.V
2020/06/13 全球购物
面向对象编程OOP的优点
2013/01/22 面试题
年度考核自我鉴定
2013/11/09 职场文书
应届护士推荐信
2013/11/16 职场文书
体育教育毕业生自荐信
2013/11/21 职场文书
施工材料员岗位职责
2014/02/12 职场文书
学生评语大全
2014/04/18 职场文书
大学军训的体会
2014/11/08 职场文书
导游词400字
2015/02/13 职场文书
Java Spring Boot 正确读取配置文件中的属性的值
2022/04/20 Java/Android
MySQL控制流函数(-if ,elseif,else,case...when)
2022/07/07 MySQL