pytorch 在网络中添加可训练参数,修改预训练权重文件的方法


Posted in Python onAugust 17, 2019

实践中,针对不同的任务需求,我们经常会在现成的网络结构上做一定的修改来实现特定的目的。

假如我们现在有一个简单的两层感知机网络:

# -*- coding: utf-8 -*-
import torch
from torch.autograd import Variable
import torch.optim as optim
 
x = Variable(torch.FloatTensor([1, 2, 3])).cuda()
y = Variable(torch.FloatTensor([4, 5])).cuda()
 
class MLP(torch.nn.Module):
  def __init__(self):
    super(MLP, self).__init__()
    self.linear1 = torch.nn.Linear(3, 5)
    self.relu = torch.nn.ReLU()
    self.linear2 = torch.nn.Linear(5, 2)
 
  def forward(self, x):
    x = self.linear1(x)
    x = self.relu(x)
    x = self.linear2(x)
 
    return x
 
model = MLP().cuda()
 
loss_fn = torch.nn.MSELoss(size_average=False)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
 
for t in range(500):
  y_pred = model(x)
  loss = loss_fn(y_pred, y)
  print(t, loss.data[0])
  model.zero_grad()
  loss.backward()
  optimizer.step()
 
print(model(x))

现在想在前向传播时,在relu之后给x乘以一个可训练的系数,只需要在__init__函数中添加一个nn.Parameter类型变量,并在forward函数中乘以该变量即可:

class MLP(torch.nn.Module):
  def __init__(self):
    super(MLP, self).__init__()
    self.linear1 = torch.nn.Linear(3, 5)
    self.relu = torch.nn.ReLU()
    self.linear2 = torch.nn.Linear(5, 2)
    # the para to be added and updated in train phase, note that NO cuda() at last
    self.coefficient = torch.nn.Parameter(torch.Tensor([1.55]))
 
  def forward(self, x):
    x = self.linear1(x)
    x = self.relu(x)
    x = self.coefficient * x
    x = self.linear2(x)
 
    return x

注意,Parameter变量和Variable变量的操作大致相同,但是不能手动调用.cuda()方法将其加载在GPU上,事实上它会自动在GPU上加载,可以通过model.state_dict()或者model.named_parameters()函数查看现在的全部可训练参数(包括通过继承得到的父类中的参数):

print(model.state_dict().keys())
for i, j in model.named_parameters():
  print(i)
  print(j)

输出如下:

odict_keys(['linear1.weight', 'linear1.bias', 'linear2.weight', 'linear2.bias'])
linear1.weight
Parameter containing:
-0.3582 -0.0283 0.2607
 0.5190 -0.2221 0.0665
-0.2586 -0.3311 0.1927
-0.2765 0.5590 -0.2598
 0.4679 -0.2923 -0.3379
[torch.cuda.FloatTensor of size 5x3 (GPU 0)]
 
linear1.bias
Parameter containing:
-0.2549
-0.5246
-0.1109
 0.5237
-0.1362
[torch.cuda.FloatTensor of size 5 (GPU 0)]
 
linear2.weight
Parameter containing:
-0.0286 -0.3045 0.1928 -0.2323 0.2966
 0.2601 0.1441 -0.2159 0.2484 0.0544
[torch.cuda.FloatTensor of size 2x5 (GPU 0)]
 
linear2.bias
Parameter containing:
-0.4038
 0.3129
[torch.cuda.FloatTensor of size 2 (GPU 0)]

这个参数会在反向传播时与原有变量同时参与更新,这就达到了添加可训练参数的目的。

如果我们有原先网络的预训练权重,现在添加了一个新的参数,原有的权重文件自然就不能加载了,我们需要修改原权重文件,在其中添加我们的新变量的初始值。

调用model.state_dict查看我们添加的参数在参数字典中的完整名称,然后打开原先的权重文件:

a = torch.load("OldWeights.pth") a是一个collecitons.OrderedDict类型变量,也就是一个有序字典,直接将新参数名称和初始值作为键值对插入,然后保存即可。

a = torch.load("OldWeights.pth")
 
a["layer1.0.coefficient"] = torch.FloatTensor([1.2])
a["layer1.1.coefficient"] = torch.FloatTensor([1.5])
 
torch.save(a, "Weights.pth")

现在权重就可以加载在修改后的模型上了。

以上这篇pytorch 在网络中添加可训练参数,修改预训练权重文件的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python逐行读取文件内容的三种方法
Jan 20 Python
ubuntu17.4下为python和python3装上pip的方法
Jun 12 Python
python 剪切移动文件的实现代码
Aug 02 Python
Selenium+Python 自动化操控登录界面实例(有简单验证码图片校验)
Jun 28 Python
python 子类调用父类的构造函数实例
Mar 12 Python
将pycharm配置为matlab或者spyder的用法说明
Jun 08 Python
python在CMD界面读取excel所有数据的示例
Sep 28 Python
python如何写个俄罗斯方块
Nov 06 Python
Selenium环境变量配置(火狐浏览器)及验证实现
Dec 07 Python
PyCharm 解决找不到新打开项目的窗口问题
Jan 15 Python
Python 将代码转换为可执行文件脱离python环境运行(步骤详解)
Jan 25 Python
基于PyTorch中view的用法说明
Mar 03 Python
python PyQt5/Pyside2 按钮右击菜单实例代码
Aug 17 #Python
Pytorch 实现自定义参数层的例子
Aug 17 #Python
Python中PyQt5/PySide2的按钮控件使用实例
Aug 17 #Python
画pytorch模型图,以及参数计算的方法
Aug 17 #Python
pytorch 共享参数的示例
Aug 17 #Python
Pytorch卷积层手动初始化权值的实例
Aug 17 #Python
pytorch自定义初始化权重的方法
Aug 17 #Python
You might like
上海无线电三厂简史修改版
2021/03/01 无线电
php调用mysql存储过程实例分析
2014/12/29 PHP
Zend Framework连接Mysql数据库实例分析
2016/03/19 PHP
thinkphp在低版本Nginx 下支持PATHINFO的方法分享
2016/05/27 PHP
thinkphp5框架API token身份验证功能示例
2019/05/21 PHP
基于Laravel-admin 后台的自定义页面用法详解
2019/09/30 PHP
php服务器的系统详解
2019/10/12 PHP
收集的网上用的ajax之chat.js文件
2007/04/08 Javascript
php gethostbyname获取域名ip地址函数详解
2010/01/24 Javascript
原生javascript实现DIV拖拽并计算重复面积
2015/01/02 Javascript
详细分析JavaScript函数定义
2015/07/16 Javascript
JavaScript里 ==与===区别详解
2016/08/16 Javascript
AngularJS控制器详解及示例代码
2016/08/16 Javascript
bootstrap表单按回车会自动刷新页面的解决办法
2017/03/08 Javascript
JS获取一个表单字段中多条数据并转化为json格式
2017/10/17 Javascript
详解微信小程序审核不通过的解决方法
2018/01/17 Javascript
在Vue组件中使用 TypeScript的方法
2018/02/28 Javascript
jQuery实现图片简单轮播功能示例
2018/08/13 jQuery
JavaScript 2018 中即将迎来的新功能
2018/09/21 Javascript
angularjs通过过滤器返回超链接的方法
2018/10/26 Javascript
vue读取本地的excel文件并显示在网页上方法示例
2019/05/29 Javascript
JavaScript中this函数使用实例解析
2020/02/21 Javascript
JS定时器如何实现提交成功提示功能
2020/06/12 Javascript
Vue select 绑定动态变量的实例讲解
2020/10/22 Javascript
基于wxpython开发的简单gui计算器实例
2015/05/30 Python
十条建议帮你提高Python编程效率
2016/02/16 Python
利用python实现数据分析
2017/01/11 Python
Python安装图文教程 Pycharm安装教程
2018/03/27 Python
详解Python中的type和object
2018/08/15 Python
CSS3媒体查询(Media Queries)介绍
2013/09/12 HTML / CSS
生物有机护肤品:Aurelia Probiotic Skincare
2018/01/31 全球购物
泰国演唱会订票网站:StubHub泰国
2018/02/26 全球购物
品恩科技软件测试面试题
2014/10/26 面试题
百度吧主申请感言
2014/01/12 职场文书
公务员四风问题对照检查材料整改措施
2014/09/26 职场文书
php将xml转化对象的实例详解
2021/11/17 PHP