pytorch 在网络中添加可训练参数,修改预训练权重文件的方法


Posted in Python onAugust 17, 2019

实践中,针对不同的任务需求,我们经常会在现成的网络结构上做一定的修改来实现特定的目的。

假如我们现在有一个简单的两层感知机网络:

# -*- coding: utf-8 -*-
import torch
from torch.autograd import Variable
import torch.optim as optim
 
x = Variable(torch.FloatTensor([1, 2, 3])).cuda()
y = Variable(torch.FloatTensor([4, 5])).cuda()
 
class MLP(torch.nn.Module):
  def __init__(self):
    super(MLP, self).__init__()
    self.linear1 = torch.nn.Linear(3, 5)
    self.relu = torch.nn.ReLU()
    self.linear2 = torch.nn.Linear(5, 2)
 
  def forward(self, x):
    x = self.linear1(x)
    x = self.relu(x)
    x = self.linear2(x)
 
    return x
 
model = MLP().cuda()
 
loss_fn = torch.nn.MSELoss(size_average=False)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
 
for t in range(500):
  y_pred = model(x)
  loss = loss_fn(y_pred, y)
  print(t, loss.data[0])
  model.zero_grad()
  loss.backward()
  optimizer.step()
 
print(model(x))

现在想在前向传播时,在relu之后给x乘以一个可训练的系数,只需要在__init__函数中添加一个nn.Parameter类型变量,并在forward函数中乘以该变量即可:

class MLP(torch.nn.Module):
  def __init__(self):
    super(MLP, self).__init__()
    self.linear1 = torch.nn.Linear(3, 5)
    self.relu = torch.nn.ReLU()
    self.linear2 = torch.nn.Linear(5, 2)
    # the para to be added and updated in train phase, note that NO cuda() at last
    self.coefficient = torch.nn.Parameter(torch.Tensor([1.55]))
 
  def forward(self, x):
    x = self.linear1(x)
    x = self.relu(x)
    x = self.coefficient * x
    x = self.linear2(x)
 
    return x

注意,Parameter变量和Variable变量的操作大致相同,但是不能手动调用.cuda()方法将其加载在GPU上,事实上它会自动在GPU上加载,可以通过model.state_dict()或者model.named_parameters()函数查看现在的全部可训练参数(包括通过继承得到的父类中的参数):

print(model.state_dict().keys())
for i, j in model.named_parameters():
  print(i)
  print(j)

输出如下:

odict_keys(['linear1.weight', 'linear1.bias', 'linear2.weight', 'linear2.bias'])
linear1.weight
Parameter containing:
-0.3582 -0.0283 0.2607
 0.5190 -0.2221 0.0665
-0.2586 -0.3311 0.1927
-0.2765 0.5590 -0.2598
 0.4679 -0.2923 -0.3379
[torch.cuda.FloatTensor of size 5x3 (GPU 0)]
 
linear1.bias
Parameter containing:
-0.2549
-0.5246
-0.1109
 0.5237
-0.1362
[torch.cuda.FloatTensor of size 5 (GPU 0)]
 
linear2.weight
Parameter containing:
-0.0286 -0.3045 0.1928 -0.2323 0.2966
 0.2601 0.1441 -0.2159 0.2484 0.0544
[torch.cuda.FloatTensor of size 2x5 (GPU 0)]
 
linear2.bias
Parameter containing:
-0.4038
 0.3129
[torch.cuda.FloatTensor of size 2 (GPU 0)]

这个参数会在反向传播时与原有变量同时参与更新,这就达到了添加可训练参数的目的。

如果我们有原先网络的预训练权重,现在添加了一个新的参数,原有的权重文件自然就不能加载了,我们需要修改原权重文件,在其中添加我们的新变量的初始值。

调用model.state_dict查看我们添加的参数在参数字典中的完整名称,然后打开原先的权重文件:

a = torch.load("OldWeights.pth") a是一个collecitons.OrderedDict类型变量,也就是一个有序字典,直接将新参数名称和初始值作为键值对插入,然后保存即可。

a = torch.load("OldWeights.pth")
 
a["layer1.0.coefficient"] = torch.FloatTensor([1.2])
a["layer1.1.coefficient"] = torch.FloatTensor([1.5])
 
torch.save(a, "Weights.pth")

现在权重就可以加载在修改后的模型上了。

以上这篇pytorch 在网络中添加可训练参数,修改预训练权重文件的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python GAE、Django导出Excel的方法
Nov 24 Python
python遍历文件夹并删除特定格式文件的示例
Mar 05 Python
python采用requests库模拟登录和抓取数据的简单示例
Jul 05 Python
Python使用Flask框架同时上传多个文件的方法
Mar 21 Python
Python脚本实现虾米网签到功能
Apr 12 Python
深入浅析python定时杀进程
Jun 06 Python
python 文件操作删除某行的实例
Sep 04 Python
基于Python中capitalize()与title()的区别详解
Dec 09 Python
python中set()函数简介及实例解析
Jan 09 Python
Python 仅获取响应头, 不获取实体的实例
Aug 21 Python
Python ConfigParser模块的使用示例
Oct 12 Python
python实现按日期归档文件
Jan 30 Python
python PyQt5/Pyside2 按钮右击菜单实例代码
Aug 17 #Python
Pytorch 实现自定义参数层的例子
Aug 17 #Python
Python中PyQt5/PySide2的按钮控件使用实例
Aug 17 #Python
画pytorch模型图,以及参数计算的方法
Aug 17 #Python
pytorch 共享参数的示例
Aug 17 #Python
Pytorch卷积层手动初始化权值的实例
Aug 17 #Python
pytorch自定义初始化权重的方法
Aug 17 #Python
You might like
生成缩略图
2006/10/09 PHP
使用Apache的htaccess防止图片被盗链的解决方法
2013/04/27 PHP
PHP抓屏函数实现屏幕快照代码分享
2014/01/02 PHP
php给一组指定关键词添加span标签的方法
2015/03/31 PHP
php原生数据库分页的代码实例
2019/02/18 PHP
javascript模仿msgbox提示效果代码
2008/06/10 Javascript
JS Excel读取和写入操作(模板操作)实现代码
2010/04/11 Javascript
javascript客户端解决方案 缓存提供程序
2010/07/14 Javascript
基于jquery实现的表格分页实现代码
2011/06/21 Javascript
jQuery数组处理代码详解(含实例演示)
2012/02/03 Javascript
JavaScript调试技巧之console.log()详解
2014/03/19 Javascript
jQuery中replaceAll()方法用法实例
2015/01/16 Javascript
javascript文件加载管理简单实现方法
2015/07/25 Javascript
js实现匹配时换色的输入提示特效代码
2015/08/17 Javascript
基于jQuery Bar Indicator 插件实现进度条展示效果
2015/09/30 Javascript
jquery动态增加删减表格行特效
2015/11/20 Javascript
JavaScript判断数组是否存在key的简单实例
2016/08/03 Javascript
微信小程序 window_x64环境搭建
2016/09/30 Javascript
AngularJS+bootstrap实现动态选择商品功能示例
2017/05/17 Javascript
AngularJS ng-repeat指令及Ajax的应用实例分析
2017/07/06 Javascript
vue动态路由实现多级嵌套面包屑的思路与方法
2017/08/16 Javascript
VuePress 快速踩坑小结
2019/02/14 Javascript
javascript面向对象创建对象的方式小结
2019/07/29 Javascript
JS实现放烟花效果
2020/03/10 Javascript
详解使用Python处理文件目录的相关方法
2015/10/16 Python
Python 网络爬虫--关于简单的模拟登录实例讲解
2018/06/01 Python
使用Python给头像加上圣诞帽或圣诞老人小图标附源码
2019/12/25 Python
pytorch 模型的train模式与eval模式实例
2020/02/20 Python
Python如何使用OS模块调用cmd
2020/02/27 Python
python中rc1什么意思
2020/06/19 Python
Python制作运行进度条的实现效果(代码运行不无聊)
2021/02/24 Python
杭州-DOTNET笔试题集
2013/09/25 面试题
学生出入校管理制度
2014/01/16 职场文书
女儿十岁生日答谢词
2014/01/27 职场文书
优秀教师申报材料
2014/12/16 职场文书
Python使用socket去实现TCP客户端和TCP服务端
2022/04/12 Python