pytorch 在网络中添加可训练参数,修改预训练权重文件的方法


Posted in Python onAugust 17, 2019

实践中,针对不同的任务需求,我们经常会在现成的网络结构上做一定的修改来实现特定的目的。

假如我们现在有一个简单的两层感知机网络:

# -*- coding: utf-8 -*-
import torch
from torch.autograd import Variable
import torch.optim as optim
 
x = Variable(torch.FloatTensor([1, 2, 3])).cuda()
y = Variable(torch.FloatTensor([4, 5])).cuda()
 
class MLP(torch.nn.Module):
  def __init__(self):
    super(MLP, self).__init__()
    self.linear1 = torch.nn.Linear(3, 5)
    self.relu = torch.nn.ReLU()
    self.linear2 = torch.nn.Linear(5, 2)
 
  def forward(self, x):
    x = self.linear1(x)
    x = self.relu(x)
    x = self.linear2(x)
 
    return x
 
model = MLP().cuda()
 
loss_fn = torch.nn.MSELoss(size_average=False)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
 
for t in range(500):
  y_pred = model(x)
  loss = loss_fn(y_pred, y)
  print(t, loss.data[0])
  model.zero_grad()
  loss.backward()
  optimizer.step()
 
print(model(x))

现在想在前向传播时,在relu之后给x乘以一个可训练的系数,只需要在__init__函数中添加一个nn.Parameter类型变量,并在forward函数中乘以该变量即可:

class MLP(torch.nn.Module):
  def __init__(self):
    super(MLP, self).__init__()
    self.linear1 = torch.nn.Linear(3, 5)
    self.relu = torch.nn.ReLU()
    self.linear2 = torch.nn.Linear(5, 2)
    # the para to be added and updated in train phase, note that NO cuda() at last
    self.coefficient = torch.nn.Parameter(torch.Tensor([1.55]))
 
  def forward(self, x):
    x = self.linear1(x)
    x = self.relu(x)
    x = self.coefficient * x
    x = self.linear2(x)
 
    return x

注意,Parameter变量和Variable变量的操作大致相同,但是不能手动调用.cuda()方法将其加载在GPU上,事实上它会自动在GPU上加载,可以通过model.state_dict()或者model.named_parameters()函数查看现在的全部可训练参数(包括通过继承得到的父类中的参数):

print(model.state_dict().keys())
for i, j in model.named_parameters():
  print(i)
  print(j)

输出如下:

odict_keys(['linear1.weight', 'linear1.bias', 'linear2.weight', 'linear2.bias'])
linear1.weight
Parameter containing:
-0.3582 -0.0283 0.2607
 0.5190 -0.2221 0.0665
-0.2586 -0.3311 0.1927
-0.2765 0.5590 -0.2598
 0.4679 -0.2923 -0.3379
[torch.cuda.FloatTensor of size 5x3 (GPU 0)]
 
linear1.bias
Parameter containing:
-0.2549
-0.5246
-0.1109
 0.5237
-0.1362
[torch.cuda.FloatTensor of size 5 (GPU 0)]
 
linear2.weight
Parameter containing:
-0.0286 -0.3045 0.1928 -0.2323 0.2966
 0.2601 0.1441 -0.2159 0.2484 0.0544
[torch.cuda.FloatTensor of size 2x5 (GPU 0)]
 
linear2.bias
Parameter containing:
-0.4038
 0.3129
[torch.cuda.FloatTensor of size 2 (GPU 0)]

这个参数会在反向传播时与原有变量同时参与更新,这就达到了添加可训练参数的目的。

如果我们有原先网络的预训练权重,现在添加了一个新的参数,原有的权重文件自然就不能加载了,我们需要修改原权重文件,在其中添加我们的新变量的初始值。

调用model.state_dict查看我们添加的参数在参数字典中的完整名称,然后打开原先的权重文件:

a = torch.load("OldWeights.pth") a是一个collecitons.OrderedDict类型变量,也就是一个有序字典,直接将新参数名称和初始值作为键值对插入,然后保存即可。

a = torch.load("OldWeights.pth")
 
a["layer1.0.coefficient"] = torch.FloatTensor([1.2])
a["layer1.1.coefficient"] = torch.FloatTensor([1.5])
 
torch.save(a, "Weights.pth")

现在权重就可以加载在修改后的模型上了。

以上这篇pytorch 在网络中添加可训练参数,修改预训练权重文件的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
推荐11个实用Python库
Jan 23 Python
python在windows下实现ping操作并接收返回信息的方法
Mar 20 Python
在Python中操作字典之fromkeys()方法的使用
May 21 Python
Python下载网络小说实例代码
Feb 03 Python
Python3中的json模块使用详解
May 05 Python
Python Numpy库安装与基本操作示例
Jan 08 Python
NumPy 数组使用大全
Apr 25 Python
python实现在cmd窗口显示彩色文字
Jun 24 Python
Python装饰器原理与基本用法分析
Jan 07 Python
python统计文章中单词出现次数实例
Feb 27 Python
Python logging模块原理解析及应用
Aug 13 Python
python opencv pytesseract 验证码识别的实现
Aug 28 Python
python PyQt5/Pyside2 按钮右击菜单实例代码
Aug 17 #Python
Pytorch 实现自定义参数层的例子
Aug 17 #Python
Python中PyQt5/PySide2的按钮控件使用实例
Aug 17 #Python
画pytorch模型图,以及参数计算的方法
Aug 17 #Python
pytorch 共享参数的示例
Aug 17 #Python
Pytorch卷积层手动初始化权值的实例
Aug 17 #Python
pytorch自定义初始化权重的方法
Aug 17 #Python
You might like
虫族 Zerg 历史背景
2020/03/14 星际争霸
PHP处理大量表单字段的便捷方法
2015/02/07 PHP
Yii2中OAuth扩展及QQ互联登录实现方法
2016/05/16 PHP
thinkphp5 加载静态资源路径与常量的方法
2017/12/24 PHP
javascript 定义初始化数组函数
2009/09/07 Javascript
Jquery 表单取值赋值的一些基本操作
2009/10/11 Javascript
jquery插件制作 图片走廊 gallery
2012/08/17 Javascript
基于jquery自己写tab滑动门(通用版)
2012/10/30 Javascript
父元素与子iframe相互获取变量和元素对象的具体实现
2013/10/15 Javascript
jQuery EasyUI Dialog拖不下来如何解决
2015/09/28 Javascript
AngularJS删除路由中的#符号的方法
2016/09/20 Javascript
jquery easyui如何实现格式化列
2017/07/30 jQuery
基于node.js实现微信支付退款功能
2017/12/19 Javascript
bootstrap实现点击删除按钮弹出确认框的实例代码
2018/08/16 Javascript
Vue源码学习之关于对Array的数据侦听实现
2019/04/23 Javascript
ES6 Class中实现私有属性的一些方法总结
2019/07/08 Javascript
python time模块用法实例详解
2014/09/11 Python
Python获取文件所在目录和文件名的方法
2017/01/12 Python
用pandas按列合并两个文件的实例
2018/04/12 Python
Python爬虫之正则表达式基本用法实例分析
2018/08/08 Python
对pandas中两种数据类型Series和DataFrame的区别详解
2018/11/12 Python
python二进制读写及特殊码同步实现详解
2019/10/11 Python
基于python的itchat库实现微信聊天机器人(推荐)
2019/10/29 Python
Python 测试框架unittest和pytest的优劣
2020/09/26 Python
CSS3制作精致的照片墙特效
2016/06/07 HTML / CSS
韩国11街:11STREET
2018/03/27 全球购物
英国手工制作的现代与经典的沙发和床:Love Your Home
2020/09/26 全球购物
如何客观的进行自我评价
2013/12/17 职场文书
期终自我鉴定
2014/02/17 职场文书
工程技术员岗位职责
2014/03/02 职场文书
《大江保卫战》教学反思
2014/04/11 职场文书
数学教研活动总结
2014/07/02 职场文书
巾帼文明岗事迹材料
2014/12/24 职场文书
中秋联欢会主持词
2015/07/04 职场文书
和领导吃饭祝酒词
2015/08/11 职场文书
Python中的嵌套循环详情
2022/03/23 Python