pytorch 在网络中添加可训练参数,修改预训练权重文件的方法


Posted in Python onAugust 17, 2019

实践中,针对不同的任务需求,我们经常会在现成的网络结构上做一定的修改来实现特定的目的。

假如我们现在有一个简单的两层感知机网络:

# -*- coding: utf-8 -*-
import torch
from torch.autograd import Variable
import torch.optim as optim
 
x = Variable(torch.FloatTensor([1, 2, 3])).cuda()
y = Variable(torch.FloatTensor([4, 5])).cuda()
 
class MLP(torch.nn.Module):
  def __init__(self):
    super(MLP, self).__init__()
    self.linear1 = torch.nn.Linear(3, 5)
    self.relu = torch.nn.ReLU()
    self.linear2 = torch.nn.Linear(5, 2)
 
  def forward(self, x):
    x = self.linear1(x)
    x = self.relu(x)
    x = self.linear2(x)
 
    return x
 
model = MLP().cuda()
 
loss_fn = torch.nn.MSELoss(size_average=False)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
 
for t in range(500):
  y_pred = model(x)
  loss = loss_fn(y_pred, y)
  print(t, loss.data[0])
  model.zero_grad()
  loss.backward()
  optimizer.step()
 
print(model(x))

现在想在前向传播时,在relu之后给x乘以一个可训练的系数,只需要在__init__函数中添加一个nn.Parameter类型变量,并在forward函数中乘以该变量即可:

class MLP(torch.nn.Module):
  def __init__(self):
    super(MLP, self).__init__()
    self.linear1 = torch.nn.Linear(3, 5)
    self.relu = torch.nn.ReLU()
    self.linear2 = torch.nn.Linear(5, 2)
    # the para to be added and updated in train phase, note that NO cuda() at last
    self.coefficient = torch.nn.Parameter(torch.Tensor([1.55]))
 
  def forward(self, x):
    x = self.linear1(x)
    x = self.relu(x)
    x = self.coefficient * x
    x = self.linear2(x)
 
    return x

注意,Parameter变量和Variable变量的操作大致相同,但是不能手动调用.cuda()方法将其加载在GPU上,事实上它会自动在GPU上加载,可以通过model.state_dict()或者model.named_parameters()函数查看现在的全部可训练参数(包括通过继承得到的父类中的参数):

print(model.state_dict().keys())
for i, j in model.named_parameters():
  print(i)
  print(j)

输出如下:

odict_keys(['linear1.weight', 'linear1.bias', 'linear2.weight', 'linear2.bias'])
linear1.weight
Parameter containing:
-0.3582 -0.0283 0.2607
 0.5190 -0.2221 0.0665
-0.2586 -0.3311 0.1927
-0.2765 0.5590 -0.2598
 0.4679 -0.2923 -0.3379
[torch.cuda.FloatTensor of size 5x3 (GPU 0)]
 
linear1.bias
Parameter containing:
-0.2549
-0.5246
-0.1109
 0.5237
-0.1362
[torch.cuda.FloatTensor of size 5 (GPU 0)]
 
linear2.weight
Parameter containing:
-0.0286 -0.3045 0.1928 -0.2323 0.2966
 0.2601 0.1441 -0.2159 0.2484 0.0544
[torch.cuda.FloatTensor of size 2x5 (GPU 0)]
 
linear2.bias
Parameter containing:
-0.4038
 0.3129
[torch.cuda.FloatTensor of size 2 (GPU 0)]

这个参数会在反向传播时与原有变量同时参与更新,这就达到了添加可训练参数的目的。

如果我们有原先网络的预训练权重,现在添加了一个新的参数,原有的权重文件自然就不能加载了,我们需要修改原权重文件,在其中添加我们的新变量的初始值。

调用model.state_dict查看我们添加的参数在参数字典中的完整名称,然后打开原先的权重文件:

a = torch.load("OldWeights.pth") a是一个collecitons.OrderedDict类型变量,也就是一个有序字典,直接将新参数名称和初始值作为键值对插入,然后保存即可。

a = torch.load("OldWeights.pth")
 
a["layer1.0.coefficient"] = torch.FloatTensor([1.2])
a["layer1.1.coefficient"] = torch.FloatTensor([1.5])
 
torch.save(a, "Weights.pth")

现在权重就可以加载在修改后的模型上了。

以上这篇pytorch 在网络中添加可训练参数,修改预训练权重文件的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python转换HTML到Text纯文本的方法
Jan 15 Python
Python打包可执行文件的方法详解
Sep 19 Python
Python生成密码库功能示例
May 23 Python
python使用邻接矩阵构造图代码示例
Nov 10 Python
python 读取视频,处理后,实时计算帧数fps的方法
Jul 10 Python
python实现桌面壁纸切换功能
Jan 21 Python
Python基础知识点 初识Python.md
May 14 Python
Python3将jpg转为pdf文件的方法示例
Dec 13 Python
python中图像通道分离与合并实例
Jan 17 Python
使用python3 实现插入数据到mysql
Mar 02 Python
查看已安装tensorflow版本的方法示例
Apr 19 Python
django queryset 去重 .distinct()说明
May 19 Python
python PyQt5/Pyside2 按钮右击菜单实例代码
Aug 17 #Python
Pytorch 实现自定义参数层的例子
Aug 17 #Python
Python中PyQt5/PySide2的按钮控件使用实例
Aug 17 #Python
画pytorch模型图,以及参数计算的方法
Aug 17 #Python
pytorch 共享参数的示例
Aug 17 #Python
Pytorch卷积层手动初始化权值的实例
Aug 17 #Python
pytorch自定义初始化权重的方法
Aug 17 #Python
You might like
无线电波是什么?它是怎样传输的?
2021/03/01 无线电
php获取数组长度的方法(有实例)
2013/10/27 PHP
PHP 下载文件时自动添加bom头的方法实例
2014/01/10 PHP
PHP实现将textarea的值根据回车换行拆分至数组
2015/06/10 PHP
jquery多浏览器捕捉回车事件代码
2010/06/22 Javascript
jquery的键盘事件修改代码
2011/02/24 Javascript
Ajax 数据请求的简单分析
2011/04/05 Javascript
javascript 基础篇1 什么是js 建立第一个js程序
2012/03/14 Javascript
javascript 闭包详解
2015/02/15 Javascript
用JS实现图片轮播效果代码(一)
2016/06/26 Javascript
js实现浏览器倒计时跳转页面效果
2016/08/12 Javascript
js实现年月日表单三级联动
2020/04/17 Javascript
微信小程序实现图片轮播及文件上传
2017/04/07 Javascript
js编写简单的计时器功能
2017/07/15 Javascript
vue滚动轴插件better-scroll使用详解
2017/10/17 Javascript
vue.js图片转Base64上传图片并预览的实现方法
2018/08/02 Javascript
详解webpack编译速度提升之DllPlugin
2019/02/05 Javascript
Vue项目vscode 安装eslint插件的方法(代码自动修复)
2020/04/15 Javascript
vue 项目中当访问路由不存在的时候默认访问404页面操作
2020/08/31 Javascript
基于js实现的图片拖拽排序源码实例
2020/11/04 Javascript
微信小程序实现底部弹出模态框
2020/11/18 Javascript
python中去空格函数的用法
2014/08/21 Python
python实现登陆知乎获得个人收藏并保存为word文件
2015/03/16 Python
Python编程给numpy矩阵添加一列方法示例
2017/12/04 Python
python进程间通信Queue工作过程详解
2019/11/01 Python
使用layui实现左侧菜单栏及动态操作tab项的方法
2020/11/10 HTML / CSS
美国眼镜网站:EyeBuyDirect
2017/04/13 全球购物
Otticanet意大利:最顶尖的世界名牌眼镜, 能得到打折季的价格
2019/03/10 全球购物
迪拜领先运动补剂零售品牌中文站:Sporter商城
2019/08/20 全球购物
银行自荐信范文
2013/10/07 职场文书
机械系大学毕业生推荐信
2013/11/27 职场文书
单位介绍信格式
2015/01/31 职场文书
员工家属慰问信
2015/03/24 职场文书
青年联谊会致辞
2015/07/31 职场文书
Spring Security中用JWT退出登录时遇到的坑
2021/10/16 Java/Android
Spring Boot 的创建和运行示例代码详解
2022/07/23 Java/Android