pytorch自定义初始化权重的方法


Posted in Python onAugust 17, 2019

在常见的pytorch代码中,我们见到的初始化方式都是调用init类对每层所有参数进行初始化。但是,有时我们有些特殊需求,比如用某一层的权重取优化其它层,或者手动指定某些权重的初始值。

核心思想就是构造和该层权重同一尺寸的矩阵去对该层权重赋值。但是,值得注意的是,pytorch中各层权重的数据类型是nn.Parameter,而不是Tensor或者Variable。

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
 
# 第一一个卷积层,我们可以看到它的权值是随机初始化的
w=torch.nn.Conv2d(2,2,3,padding=1)
print(w.weight)
 
 
# 第一种方法
print("1.使用另一个Conv层的权值")
q=torch.nn.Conv2d(2,2,3,padding=1) # 假设q代表一个训练好的卷积层
print(q.weight) # 可以看到q的权重和w是不同的
w.weight=q.weight # 把一个Conv层的权重赋值给另一个Conv层
print(w.weight)
 
# 第二种方法
print("2.使用来自Tensor的权值")
ones=torch.Tensor(np.ones([2,2,3,3])) # 先创建一个自定义权值的Tensor,这里为了方便将所有权值设为1
w.weight=torch.nn.Parameter(ones) # 把Tensor的值作为权值赋值给Conv层,这里需要先转为torch.nn.Parameter类型,否则将报错
print(w.weight)

附:Variable和Parameter的区别

Parameter 是torch.autograd.Variable的一个字类,常被用于Module的参数。例如权重和偏置。

Parameters和Modules一起使用的时候会有一些特殊的属性。parameters赋值给Module的属性的时候,它会被自动加到Module的参数列表中,即会出现在Parameter()迭代器中。将Varaible赋给Module的时候没有这样的属性。这可以在nn.Module的实现中详细看一下。这样做是为了保存模型的时候只保存权重偏置参数,不保存节点值。所以复写Variable加以区分。

另外一个不同是parameter不能设置volatile,而且require_grad默认设置为true。Varaible默认设置为False.

参数:

parameter.data 得到tensor数据

parameter.requires_grad 默认为True, BP过程中会求导

Parameter一般是在Modules中作为权重和偏置,自动加入参数列表,可以进行保存恢复。和Variable具有相同的运算。

我们可以这样简单区分,在计算图中,数据(包括输入数据和计算过程中产生的feature map等)时variable类型,该类型不会被保存到模型中。 网络的权重是parameter类型,在计算过程中会被更新,将会被保存到模型中。

以上这篇pytorch自定义初始化权重的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现扫描指定目录下的子目录及文件的方法
Jul 16 Python
python通过exifread模块获得图片exif信息的方法
Mar 16 Python
使用Python装饰器在Django框架下去除冗余代码的教程
Apr 16 Python
python字典的常用操作方法小结
May 16 Python
python实现机械分词之逆向最大匹配算法代码示例
Dec 13 Python
mac安装scrapy并创建项目的实例讲解
Jun 13 Python
JavaScript中的模拟事件和自定义事件实例分析
Jul 27 Python
python redis 删除key脚本的实例
Feb 19 Python
python绘图模块matplotlib示例详解
Jul 26 Python
python线性插值解析
Jul 05 Python
django表单中的按钮获取数据的实例分析
Jul 31 Python
Django如何在不停机的情况下创建索引
Aug 02 Python
在Pytorch中使用样本权重(sample_weight)的正确方法
Aug 17 #Python
获取Pytorch中间某一层权重或者特征的例子
Aug 17 #Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
关于PyTorch源码解读之torchvision.models
Aug 17 #Python
django项目用higcharts统计最近七天文章点击量
Aug 17 #Python
Django对models里的objects的使用详解
Aug 17 #Python
You might like
用PHP4访问Oracle815
2006/10/09 PHP
php empty() 检查一个变量是否为空
2011/11/10 PHP
php实现utf-8和GB2312编码相互转换函数代码
2013/02/07 PHP
php 模拟get_headers函数的代码示例
2013/04/27 PHP
详解php设置session(过期、失效、有效期)
2015/11/12 PHP
php for 循环使用的简单实例
2016/06/02 PHP
php redis实现文章发布系统(用户投票系统)
2017/03/04 PHP
使用JS操作页面表格,元素的一些技巧
2007/02/02 Javascript
extJs 下拉框联动实现代码
2010/04/09 Javascript
EXTJS FORM HIDDEN TEXTFIELD 赋值 使用value不好用的问题
2011/04/16 Javascript
Javascript实现重力弹跳拖拽运动效果示例
2013/06/28 Javascript
js Select下拉列表框进行多选、移除、交换内容的具体实现方法
2013/08/13 Javascript
再分享70+免费的jquery 图片滑块效果插件和教程
2014/12/15 Javascript
javascript实现时间格式输出FormatDate函数
2015/01/13 Javascript
nodejs 整合kindEditor实现图片上传
2015/02/03 NodeJs
javascript中indexOf技术详解
2015/05/07 Javascript
javascript创建对象、对象继承的实用方式详解
2016/03/08 Javascript
特殊日期提示功能的实现方法
2016/06/16 Javascript
详解vue2.0 transition 多个元素嵌套使用过渡
2017/06/19 Javascript
vue+Element-ui实现分页效果实例代码详解
2018/12/10 Javascript
微信小程序自定义底部弹出框功能
2020/11/18 Javascript
[06:07]DOTA2-DPC中国联赛3月5日Recap集锦
2021/03/11 DOTA
python引用DLL文件的方法
2015/05/11 Python
基于python的Tkinter编写登陆注册界面
2017/06/30 Python
opencv+python实现均值滤波
2020/02/19 Python
Python如何实现爬取B站视频
2020/05/20 Python
pandas使用函数批量处理数据(map、apply、applymap)
2020/11/27 Python
阿迪达斯印尼官方网站:adidas印尼
2020/02/10 全球购物
会计电算化专业毕业生推荐信
2013/12/24 职场文书
公务员培训心得体会
2013/12/28 职场文书
采购部部长岗位职责
2014/02/06 职场文书
人力资源管理求职信
2014/08/07 职场文书
营业用房租赁协议书
2014/11/26 职场文书
2014年挂职干部工作总结
2014/12/06 职场文书
食堂卫生管理制度
2015/08/04 职场文书
go类型转换及与C的类型转换方式
2021/05/05 Golang