pytorch 在网络中添加可训练参数,修改预训练权重文件的方法


Posted in Python onAugust 17, 2019

实践中,针对不同的任务需求,我们经常会在现成的网络结构上做一定的修改来实现特定的目的。

假如我们现在有一个简单的两层感知机网络:

# -*- coding: utf-8 -*-
import torch
from torch.autograd import Variable
import torch.optim as optim
 
x = Variable(torch.FloatTensor([1, 2, 3])).cuda()
y = Variable(torch.FloatTensor([4, 5])).cuda()
 
class MLP(torch.nn.Module):
  def __init__(self):
    super(MLP, self).__init__()
    self.linear1 = torch.nn.Linear(3, 5)
    self.relu = torch.nn.ReLU()
    self.linear2 = torch.nn.Linear(5, 2)
 
  def forward(self, x):
    x = self.linear1(x)
    x = self.relu(x)
    x = self.linear2(x)
 
    return x
 
model = MLP().cuda()
 
loss_fn = torch.nn.MSELoss(size_average=False)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
 
for t in range(500):
  y_pred = model(x)
  loss = loss_fn(y_pred, y)
  print(t, loss.data[0])
  model.zero_grad()
  loss.backward()
  optimizer.step()
 
print(model(x))

现在想在前向传播时,在relu之后给x乘以一个可训练的系数,只需要在__init__函数中添加一个nn.Parameter类型变量,并在forward函数中乘以该变量即可:

class MLP(torch.nn.Module):
  def __init__(self):
    super(MLP, self).__init__()
    self.linear1 = torch.nn.Linear(3, 5)
    self.relu = torch.nn.ReLU()
    self.linear2 = torch.nn.Linear(5, 2)
    # the para to be added and updated in train phase, note that NO cuda() at last
    self.coefficient = torch.nn.Parameter(torch.Tensor([1.55]))
 
  def forward(self, x):
    x = self.linear1(x)
    x = self.relu(x)
    x = self.coefficient * x
    x = self.linear2(x)
 
    return x

注意,Parameter变量和Variable变量的操作大致相同,但是不能手动调用.cuda()方法将其加载在GPU上,事实上它会自动在GPU上加载,可以通过model.state_dict()或者model.named_parameters()函数查看现在的全部可训练参数(包括通过继承得到的父类中的参数):

print(model.state_dict().keys())
for i, j in model.named_parameters():
  print(i)
  print(j)

输出如下:

odict_keys(['linear1.weight', 'linear1.bias', 'linear2.weight', 'linear2.bias'])
linear1.weight
Parameter containing:
-0.3582 -0.0283 0.2607
 0.5190 -0.2221 0.0665
-0.2586 -0.3311 0.1927
-0.2765 0.5590 -0.2598
 0.4679 -0.2923 -0.3379
[torch.cuda.FloatTensor of size 5x3 (GPU 0)]
 
linear1.bias
Parameter containing:
-0.2549
-0.5246
-0.1109
 0.5237
-0.1362
[torch.cuda.FloatTensor of size 5 (GPU 0)]
 
linear2.weight
Parameter containing:
-0.0286 -0.3045 0.1928 -0.2323 0.2966
 0.2601 0.1441 -0.2159 0.2484 0.0544
[torch.cuda.FloatTensor of size 2x5 (GPU 0)]
 
linear2.bias
Parameter containing:
-0.4038
 0.3129
[torch.cuda.FloatTensor of size 2 (GPU 0)]

这个参数会在反向传播时与原有变量同时参与更新,这就达到了添加可训练参数的目的。

如果我们有原先网络的预训练权重,现在添加了一个新的参数,原有的权重文件自然就不能加载了,我们需要修改原权重文件,在其中添加我们的新变量的初始值。

调用model.state_dict查看我们添加的参数在参数字典中的完整名称,然后打开原先的权重文件:

a = torch.load("OldWeights.pth") a是一个collecitons.OrderedDict类型变量,也就是一个有序字典,直接将新参数名称和初始值作为键值对插入,然后保存即可。

a = torch.load("OldWeights.pth")
 
a["layer1.0.coefficient"] = torch.FloatTensor([1.2])
a["layer1.1.coefficient"] = torch.FloatTensor([1.5])
 
torch.save(a, "Weights.pth")

现在权重就可以加载在修改后的模型上了。

以上这篇pytorch 在网络中添加可训练参数,修改预训练权重文件的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
pyqt4教程之实现半透明的天气预报界面示例
Mar 02 Python
python实现爬取千万淘宝商品的方法
Jun 30 Python
python从入门到精通(DAY 3)
Dec 20 Python
Python编程中对super函数的正确理解和用法解析
Jul 02 Python
Python AES加密实例解析
Jan 18 Python
Python 的字典(Dict)是如何存储的
Jul 05 Python
基于python实现语音录入识别代码实例
Jan 17 Python
python无序链表删除重复项的方法
Jan 17 Python
Python列表倒序输出及其效率详解
Mar 04 Python
如何基于Python Matplotlib实现网格动画
Jul 20 Python
Python中的面向接口编程示例详解
Jan 17 Python
解决python存数据库速度太慢的问题
Apr 23 Python
python PyQt5/Pyside2 按钮右击菜单实例代码
Aug 17 #Python
Pytorch 实现自定义参数层的例子
Aug 17 #Python
Python中PyQt5/PySide2的按钮控件使用实例
Aug 17 #Python
画pytorch模型图,以及参数计算的方法
Aug 17 #Python
pytorch 共享参数的示例
Aug 17 #Python
Pytorch卷积层手动初始化权值的实例
Aug 17 #Python
pytorch自定义初始化权重的方法
Aug 17 #Python
You might like
CI使用Tank Auth转移数据库导致密码用户错误的解决办法
2014/06/12 PHP
2个Codeigniter文件批量上传控制器写法例子
2014/07/25 PHP
php利用fsockopen GET/POST提交表单及上传文件
2017/05/22 PHP
Yii框架实现对数据库的CURD操作示例
2019/09/03 PHP
document.getElementById获取控件对象为空的解决方法
2013/11/20 Javascript
js购物车实现思路及代码(个人感觉不错)
2013/12/23 Javascript
使用js实现关闭js弹出层的窗口
2014/02/10 Javascript
json字符串之间的相互转换示例代码
2014/08/21 Javascript
Jquery全屏相册插件zoomvisualizer具有调节放大与缩小功能
2015/11/02 Javascript
基于Jquery和CSS3制作数字时钟附源码下载(CSS3篇)
2015/11/24 Javascript
大型JavaScript应用程序架构设计模式
2016/06/29 Javascript
Bootstrap时间选择器datetimepicker和daterangepicker使用实例解析
2016/09/17 Javascript
巧用canvas
2017/01/21 Javascript
jquery uploadify如何取消已上传成功文件
2017/02/08 Javascript
JavaScript 过滤关键字
2017/03/20 Javascript
Angular 表单控件示例代码
2017/06/26 Javascript
vue引入swiper插件的使用实例
2017/07/19 Javascript
Vue axios设置访问基础路径方法
2018/09/19 Javascript
SSM+layUI 根据登录信息显示不同的页面方法
2019/09/20 Javascript
[02:10]探秘浦东源深体育馆 DOTA2 Supermajor不见不散
2018/05/17 DOTA
[01:07:19]2018DOTA2亚洲邀请赛 4.5 淘汰赛 Mineski vs VG 第一场
2018/04/06 DOTA
Python字符串的encode与decode研究心得乱码问题解决方法
2009/03/23 Python
Scrapy-redis爬虫分布式爬取的分析和实现
2017/02/07 Python
python实现停车管理系统
2018/11/30 Python
详解Django中CBV(Class Base Views)模型源码分析
2019/02/25 Python
Django用户登录与注册系统的实现示例
2020/06/03 Python
如何将json数据转换为python数据
2020/09/04 Python
Python爬虫回测股票的实例讲解
2021/01/22 Python
利用简洁的图片预加载组件提升html5移动页面的用户体验
2016/03/11 HTML / CSS
html5记忆翻牌游戏实现思路及代码
2013/07/25 HTML / CSS
cosme官方海外旗舰店:日本最大化妆品和美容产品的综合口碑网站
2017/01/18 全球购物
澳大利亚便宜的家庭购物网站:CrazySales
2018/02/06 全球购物
医院节能减排方案
2014/06/13 职场文书
教师节横幅标语
2014/10/08 职场文书
任长霞观后感
2015/06/16 职场文书
2015年党风廉政建设个人总结
2015/08/18 职场文书