pytorch自定义初始化权重的方法


Posted in Python onAugust 17, 2019

在常见的pytorch代码中,我们见到的初始化方式都是调用init类对每层所有参数进行初始化。但是,有时我们有些特殊需求,比如用某一层的权重取优化其它层,或者手动指定某些权重的初始值。

核心思想就是构造和该层权重同一尺寸的矩阵去对该层权重赋值。但是,值得注意的是,pytorch中各层权重的数据类型是nn.Parameter,而不是Tensor或者Variable。

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
 
# 第一一个卷积层,我们可以看到它的权值是随机初始化的
w=torch.nn.Conv2d(2,2,3,padding=1)
print(w.weight)
 
 
# 第一种方法
print("1.使用另一个Conv层的权值")
q=torch.nn.Conv2d(2,2,3,padding=1) # 假设q代表一个训练好的卷积层
print(q.weight) # 可以看到q的权重和w是不同的
w.weight=q.weight # 把一个Conv层的权重赋值给另一个Conv层
print(w.weight)
 
# 第二种方法
print("2.使用来自Tensor的权值")
ones=torch.Tensor(np.ones([2,2,3,3])) # 先创建一个自定义权值的Tensor,这里为了方便将所有权值设为1
w.weight=torch.nn.Parameter(ones) # 把Tensor的值作为权值赋值给Conv层,这里需要先转为torch.nn.Parameter类型,否则将报错
print(w.weight)

附:Variable和Parameter的区别

Parameter 是torch.autograd.Variable的一个字类,常被用于Module的参数。例如权重和偏置。

Parameters和Modules一起使用的时候会有一些特殊的属性。parameters赋值给Module的属性的时候,它会被自动加到Module的参数列表中,即会出现在Parameter()迭代器中。将Varaible赋给Module的时候没有这样的属性。这可以在nn.Module的实现中详细看一下。这样做是为了保存模型的时候只保存权重偏置参数,不保存节点值。所以复写Variable加以区分。

另外一个不同是parameter不能设置volatile,而且require_grad默认设置为true。Varaible默认设置为False.

参数:

parameter.data 得到tensor数据

parameter.requires_grad 默认为True, BP过程中会求导

Parameter一般是在Modules中作为权重和偏置,自动加入参数列表,可以进行保存恢复。和Variable具有相同的运算。

我们可以这样简单区分,在计算图中,数据(包括输入数据和计算过程中产生的feature map等)时variable类型,该类型不会被保存到模型中。 网络的权重是parameter类型,在计算过程中会被更新,将会被保存到模型中。

以上这篇pytorch自定义初始化权重的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现的简单窗口倒计时界面实例
May 05 Python
win10下Python3.6安装、配置以及pip安装包教程
Oct 01 Python
python清理子进程机制剖析
Nov 23 Python
Python批量发送post请求的实现代码
May 05 Python
Scrapy框架使用的基本知识
Oct 21 Python
python实现对输入的密文加密
Mar 20 Python
对python中的os.getpid()和os.fork()函数详解
Aug 08 Python
Django框架 querySet功能解析
Sep 04 Python
python3 assert 断言的使用详解 (区别于python2)
Nov 27 Python
Python 解析简单的XML数据
Jul 24 Python
详解python第三方库的安装、PyInstaller库、random库
Mar 03 Python
pytorch锁死在dataloader(训练时卡死)
May 28 Python
在Pytorch中使用样本权重(sample_weight)的正确方法
Aug 17 #Python
获取Pytorch中间某一层权重或者特征的例子
Aug 17 #Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
关于PyTorch源码解读之torchvision.models
Aug 17 #Python
django项目用higcharts统计最近七天文章点击量
Aug 17 #Python
Django对models里的objects的使用详解
Aug 17 #Python
You might like
php 运行效率总结(提示程序速度)
2009/11/26 PHP
IIS7.X配置PHP运行环境小结
2011/06/09 PHP
PHP __autoload函数(自动载入类文件)的使用方法
2012/02/04 PHP
php 检查电子邮件函数(自写)
2014/01/16 PHP
Codeigniter购物车类不能添加中文的解决方法
2014/11/29 PHP
php实现的树形结构数据存取类实例
2014/11/29 PHP
ThinkPHP打开验证码页面显示乱码的解决方法
2014/12/18 PHP
PHP远程调试之XDEBUG
2015/12/29 PHP
CI框架常用经典操作类总结(路由,伪静态,分页,session,验证码等)
2016/11/21 PHP
PHP使用数组实现矩阵数学运算的方法示例
2017/05/29 PHP
PHP性能测试工具xhprof安装与使用方法详解
2018/04/29 PHP
PHP-FPM和Nginx的通信机制详解
2019/02/01 PHP
PHP设计模式之简单工厂和工厂模式实例分析
2019/03/25 PHP
基于jquery封装的一个js分页
2011/11/15 Javascript
jquery使用淘宝接口跨域查询手机号码归属地实例
2013/11/28 Javascript
js防止DIV布局滚动时闪动的解决方法
2014/10/30 Javascript
使用Sticker.js实现贴纸效果
2015/01/28 Javascript
jQuery插件pagewalkthrough实现引导页效果
2015/07/05 Javascript
jQuery Timelinr实现垂直水平时间轴插件(附源码下载)
2016/02/16 Javascript
Vue 父子组件、组件间通信
2017/03/08 Javascript
JavaScript数据结构之链表的实现
2017/03/19 Javascript
微信小程序实现带缩略图轮播效果
2018/11/04 Javascript
深入理解javascript prototype的相关知识
2019/09/19 Javascript
[08:17]Ti9 现场cosplay
2019/09/10 DOTA
Python操作使用MySQL数据库的实例代码
2017/05/25 Python
python登录WeChat 实现自动回复实例详解
2019/05/28 Python
Python字符串中添加、插入特定字符的方法
2019/09/10 Python
如何在Python 游戏中模拟引力
2020/03/27 Python
Html5如何唤起百度地图App的方法
2019/01/27 HTML / CSS
德国街头和运动文化高品质商店:BSTN Store
2017/08/26 全球购物
新郎婚礼致辞
2015/07/27 职场文书
班主任培训研修日志
2015/11/13 职场文书
优胜劣汰,强者为王——读《鲁滨逊漂流记》有感
2019/08/15 职场文书
Golang 获取文件md5校验的方法以及效率对比
2021/05/08 Golang
基于HTML十秒做出淘宝页面
2021/10/24 HTML / CSS
win10蓝屏0xc0000001安全模式进不了怎么办?win10出现0xc0000001的解决方法
2022/08/05 数码科技