pytorch自定义初始化权重的方法


Posted in Python onAugust 17, 2019

在常见的pytorch代码中,我们见到的初始化方式都是调用init类对每层所有参数进行初始化。但是,有时我们有些特殊需求,比如用某一层的权重取优化其它层,或者手动指定某些权重的初始值。

核心思想就是构造和该层权重同一尺寸的矩阵去对该层权重赋值。但是,值得注意的是,pytorch中各层权重的数据类型是nn.Parameter,而不是Tensor或者Variable。

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
 
# 第一一个卷积层,我们可以看到它的权值是随机初始化的
w=torch.nn.Conv2d(2,2,3,padding=1)
print(w.weight)
 
 
# 第一种方法
print("1.使用另一个Conv层的权值")
q=torch.nn.Conv2d(2,2,3,padding=1) # 假设q代表一个训练好的卷积层
print(q.weight) # 可以看到q的权重和w是不同的
w.weight=q.weight # 把一个Conv层的权重赋值给另一个Conv层
print(w.weight)
 
# 第二种方法
print("2.使用来自Tensor的权值")
ones=torch.Tensor(np.ones([2,2,3,3])) # 先创建一个自定义权值的Tensor,这里为了方便将所有权值设为1
w.weight=torch.nn.Parameter(ones) # 把Tensor的值作为权值赋值给Conv层,这里需要先转为torch.nn.Parameter类型,否则将报错
print(w.weight)

附:Variable和Parameter的区别

Parameter 是torch.autograd.Variable的一个字类,常被用于Module的参数。例如权重和偏置。

Parameters和Modules一起使用的时候会有一些特殊的属性。parameters赋值给Module的属性的时候,它会被自动加到Module的参数列表中,即会出现在Parameter()迭代器中。将Varaible赋给Module的时候没有这样的属性。这可以在nn.Module的实现中详细看一下。这样做是为了保存模型的时候只保存权重偏置参数,不保存节点值。所以复写Variable加以区分。

另外一个不同是parameter不能设置volatile,而且require_grad默认设置为true。Varaible默认设置为False.

参数:

parameter.data 得到tensor数据

parameter.requires_grad 默认为True, BP过程中会求导

Parameter一般是在Modules中作为权重和偏置,自动加入参数列表,可以进行保存恢复。和Variable具有相同的运算。

我们可以这样简单区分,在计算图中,数据(包括输入数据和计算过程中产生的feature map等)时variable类型,该类型不会被保存到模型中。 网络的权重是parameter类型,在计算过程中会被更新,将会被保存到模型中。

以上这篇pytorch自定义初始化权重的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现保存网页到本地示例
Mar 16 Python
python从入门到精通(DAY 3)
Dec 20 Python
举例讲解Python中的Null模式与桥接模式编程
Feb 02 Python
在Python程序和Flask框架中使用SQLAlchemy的教程
Jun 06 Python
浅谈python中的数字类型与处理工具
Aug 02 Python
python好玩的项目—色情图片识别代码分享
Nov 07 Python
python下解压缩zip文件并删除文件的实例
Apr 24 Python
对python requests发送json格式数据的实例详解
Dec 19 Python
Python整数对象实现原理详解
Jul 01 Python
Python bytes string相互转换过程解析
Mar 05 Python
解决使用Pandas 读取超过65536行的Excel文件问题
Nov 10 Python
Python selenium的这三种等待方式一定要会!
Jun 10 Python
在Pytorch中使用样本权重(sample_weight)的正确方法
Aug 17 #Python
获取Pytorch中间某一层权重或者特征的例子
Aug 17 #Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
关于PyTorch源码解读之torchvision.models
Aug 17 #Python
django项目用higcharts统计最近七天文章点击量
Aug 17 #Python
Django对models里的objects的使用详解
Aug 17 #Python
You might like
供参考的 php 学习提高路线分享
2011/10/23 PHP
解析link_mysql的php版
2013/06/30 PHP
php7 安装yar 生成docker镜像
2017/05/09 PHP
PHP命令空间namespace及use的用法小结
2017/11/27 PHP
PHP7.1实现的AES与RSA加密操作示例
2018/06/15 PHP
PHP APP微信提现接口代码
2018/09/30 PHP
JavaScript While 循环基础教程
2007/04/05 Javascript
JavaScript 拾漏补遗
2009/12/27 Javascript
使用jquery实现select添加实现后台权限添加的效果
2011/05/28 Javascript
JS子父窗口互相操作取值赋值的方法介绍
2013/05/11 Javascript
JS中数组Array的用法示例介绍
2014/02/20 Javascript
javascript正则表达式中的replace方法详解
2015/04/20 Javascript
javascript中undefined与null的区别
2015/08/16 Javascript
整理Javascript数组学习笔记
2015/11/29 Javascript
jQuery+PHP实现微信转盘抽奖功能的方法
2016/05/25 Javascript
Bootstrap进度条学习使用
2017/02/09 Javascript
js字符串处理之绝妙的代码
2019/04/05 Javascript
element-ui表格合并span-method的实现方法
2019/05/21 Javascript
使用layui监听器监听select下拉框,事件绑定不成功的解决方法
2019/09/28 Javascript
小程序使用wxs解决wxml保留2位小数问题
2019/12/13 Javascript
小程序按钮避免多次调用接口和点击方案实现(不用showLoading)
2020/04/15 Javascript
[01:59]DOTA2首部纪录片《Free to play》预告片
2014/03/12 DOTA
python爬虫获取京东手机图片的图文教程
2017/12/29 Python
pandas object格式转float64格式的方法
2018/04/10 Python
Python针对给定列表中元素进行翻转操作的方法分析
2018/04/27 Python
浅析Python与Mongodb数据库之间的操作方法
2019/07/01 Python
python Django里CSRF 对应策略详解
2019/08/05 Python
kafka监控获取指定topic的消息总量示例
2019/12/23 Python
学前教育学生自荐信范文
2013/12/31 职场文书
同事吵架检讨书
2014/02/05 职场文书
2014年勤工助学工作总结
2014/11/24 职场文书
小学生作文评语集锦
2014/12/25 职场文书
初中学生操行评语
2014/12/26 职场文书
贷款担保书
2015/01/20 职场文书
增值税发票丢失证明
2015/06/19 职场文书
社区服务理念口号
2015/12/25 职场文书