pytorch自定义初始化权重的方法


Posted in Python onAugust 17, 2019

在常见的pytorch代码中,我们见到的初始化方式都是调用init类对每层所有参数进行初始化。但是,有时我们有些特殊需求,比如用某一层的权重取优化其它层,或者手动指定某些权重的初始值。

核心思想就是构造和该层权重同一尺寸的矩阵去对该层权重赋值。但是,值得注意的是,pytorch中各层权重的数据类型是nn.Parameter,而不是Tensor或者Variable。

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
 
# 第一一个卷积层,我们可以看到它的权值是随机初始化的
w=torch.nn.Conv2d(2,2,3,padding=1)
print(w.weight)
 
 
# 第一种方法
print("1.使用另一个Conv层的权值")
q=torch.nn.Conv2d(2,2,3,padding=1) # 假设q代表一个训练好的卷积层
print(q.weight) # 可以看到q的权重和w是不同的
w.weight=q.weight # 把一个Conv层的权重赋值给另一个Conv层
print(w.weight)
 
# 第二种方法
print("2.使用来自Tensor的权值")
ones=torch.Tensor(np.ones([2,2,3,3])) # 先创建一个自定义权值的Tensor,这里为了方便将所有权值设为1
w.weight=torch.nn.Parameter(ones) # 把Tensor的值作为权值赋值给Conv层,这里需要先转为torch.nn.Parameter类型,否则将报错
print(w.weight)

附:Variable和Parameter的区别

Parameter 是torch.autograd.Variable的一个字类,常被用于Module的参数。例如权重和偏置。

Parameters和Modules一起使用的时候会有一些特殊的属性。parameters赋值给Module的属性的时候,它会被自动加到Module的参数列表中,即会出现在Parameter()迭代器中。将Varaible赋给Module的时候没有这样的属性。这可以在nn.Module的实现中详细看一下。这样做是为了保存模型的时候只保存权重偏置参数,不保存节点值。所以复写Variable加以区分。

另外一个不同是parameter不能设置volatile,而且require_grad默认设置为true。Varaible默认设置为False.

参数:

parameter.data 得到tensor数据

parameter.requires_grad 默认为True, BP过程中会求导

Parameter一般是在Modules中作为权重和偏置,自动加入参数列表,可以进行保存恢复。和Variable具有相同的运算。

我们可以这样简单区分,在计算图中,数据(包括输入数据和计算过程中产生的feature map等)时variable类型,该类型不会被保存到模型中。 网络的权重是parameter类型,在计算过程中会被更新,将会被保存到模型中。

以上这篇pytorch自定义初始化权重的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 自动补全(vim)
Nov 30 Python
python连接MySQL数据库实例分析
May 12 Python
Python自动登录126邮箱的方法
Jul 10 Python
PyQt5利用QPainter绘制各种图形的实例
Oct 19 Python
python使用xpath中遇到:到底是什么?
Jan 04 Python
python-itchat 获取微信群用户信息的实例
Feb 21 Python
python2.7使用plotly绘制本地散点图和折线图
Apr 02 Python
Python实用工具FuckIt.py介绍
Jul 02 Python
将Python文件打包成.EXE可执行文件的方法
Aug 11 Python
Python实现分数序列求和
Feb 25 Python
Python环境管理virtualenv&virtualenvwrapper的配置详解
Jul 01 Python
Jupyter notebook 更改文件打开的默认路径操作
May 21 Python
在Pytorch中使用样本权重(sample_weight)的正确方法
Aug 17 #Python
获取Pytorch中间某一层权重或者特征的例子
Aug 17 #Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
关于PyTorch源码解读之torchvision.models
Aug 17 #Python
django项目用higcharts统计最近七天文章点击量
Aug 17 #Python
Django对models里的objects的使用详解
Aug 17 #Python
You might like
xajax写的留言本
2006/11/25 PHP
php获取客户端IP及URL的方法示例
2017/02/03 PHP
浅谈ThinkPHP中initialize和construct的区别
2017/04/01 PHP
javascript学习随笔(使用window和frame)的技巧
2007/03/08 Javascript
js no-repeat写法 背景不重复
2009/03/18 Javascript
jqPlot jquery的页面图表绘制工具
2009/07/25 Javascript
js渐变显示渐变消失示例代码
2013/08/01 Javascript
js数组中如何随机取出一个值
2014/06/13 Javascript
jquery实现的点击翻书效果代码
2015/11/04 Javascript
如何动态加载外部Javascript文件
2015/12/02 Javascript
Bootstrap CDN和本地化环境搭建
2016/10/26 Javascript
完美解决JS文件页面加载时的阻塞问题
2016/12/18 Javascript
jQuery拖拽通过八个点改变div大小
2020/11/29 Javascript
xmlplus组件设计系列之文本框(TextBox)(3)
2017/05/03 Javascript
JSON 数据格式详解
2017/09/13 Javascript
实例分析js事件循环机制
2017/12/13 Javascript
JS获取指定月份的天数两种实现方法
2018/06/22 Javascript
微信小程序可滑动月日历组件使用详解
2019/10/21 Javascript
JavaScript逻辑运算符相关总结
2020/09/04 Javascript
vue实现顶部菜单栏
2020/11/08 Javascript
Python中使用wxPython开发的一个简易笔记本程序实例
2015/02/08 Python
Python设计模式之观察者模式简单示例
2018/01/10 Python
Python数据分析之双色球基于线性回归算法预测下期中奖结果示例
2018/02/08 Python
Python中存取文件的4种不同操作
2018/07/02 Python
python: 自动安装缺失库文件的方法
2018/10/22 Python
对python中的float除法和整除法的实例详解
2019/07/20 Python
Python使用Pandas库常见操作详解
2020/01/16 Python
今天学到的CSS最新技术(与图片背景相关)
2012/12/24 HTML / CSS
分享一个H5原生form表单的checkbox特效代码
2018/02/26 HTML / CSS
Prototype如何为一个Ajax添加一个参数
2015/12/06 面试题
应届专科生个人的自我评价
2014/01/05 职场文书
优秀导游先进事迹材料
2014/01/25 职场文书
社区工作感言
2014/02/21 职场文书
法学专业求职信
2014/07/15 职场文书
融资合作协议书范本
2014/10/17 职场文书
运动会广播稿300字
2015/08/19 职场文书