Pytorch 实现权重初始化


Posted in Python onDecember 31, 2019

在TensorFlow中,权重的初始化主要是在声明张量的时候进行的。 而PyTorch则提供了另一种方法:首先应该声明张量,然后修改张量的权重。通过调用torch.nn.init包中的多种方法可以将权重初始化为直接访问张量的属性。

1、不初始化的效果

在Pytorch中,定义一个tensor,不进行初始化,打印看看结果:

w = torch.Tensor(3,4)
print (w)

可以看到这时候的初始化的数值都是随机的,而且特别大,这对网络的训练必定不好,最后导致精度提不上,甚至损失无法收敛。

2、初始化的效果

PyTorch提供了多种参数初始化函数:

torch.nn.init.constant(tensor, val)
torch.nn.init.normal(tensor, mean=0, std=1)
torch.nn.init.xavier_uniform(tensor, gain=1)

等等。详细请参考:http://pytorch.org/docs/nn.html#torch-nn-init

注意上面的初始化函数的参数tensor,虽然写的是tensor,但是也可以是Variable类型的。而神经网络的参数类型Parameter是Variable类的子类,所以初始化函数可以直接作用于神经网络参数。实际上,我们初始化也是直接去初始化神经网络的参数。

让我们试试效果:

w = torch.Tensor(3,4)
torch.nn.init.normal_(w)
print (w)

3、初始化神经网络的参数

对神经网络的初始化往往放在模型的__init__()函数中,如下所示:

class Net(nn.Module):

def __init__(self, block, layers, num_classes=1000):
  self.inplanes = 64
  super(Net, self).__init__()
  ***
  *** #定义自己的网络层
  ***

  for m in self.modules():
    if isinstance(m, nn.Conv2d):
      n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
      m.weight.data.normal_(0, math.sqrt(2. / n))
    elif isinstance(m, nn.BatchNorm2d):
      m.weight.data.fill_(1)
      m.bias.data.zero_()

***
*** #定义后续的函数
***

也可以采取另一种方式:

定义一个权重初始化函数,如下:

def weights_init(m):
  classname = m.__class__.__name__
  if classname.find('Conv2d') != -1:
    init.xavier_normal_(m.weight.data)
    init.constant_(m.bias.data, 0.0)
  elif classname.find('Linear') != -1:
    init.xavier_normal_(m.weight.data)
    init.constant_(m.bias.data, 0.0)

在模型声明时,调用初始化函数,初始化神经网络参数:

model = Net(*****)
model.apply(weights_init)

以上这篇Pytorch 实现权重初始化就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用rsa加密算法模块模拟新浪微博登录
Jan 22 Python
python SSH模块登录,远程机执行shell命令实例解析
Jan 12 Python
78行Python代码实现现微信撤回消息功能
Jul 26 Python
详解Python3注释知识点
Feb 19 Python
Python read函数按字节(字符)读取文件的实现
Jul 03 Python
10分钟教你用python动画演示深度优先算法搜寻逃出迷宫的路径
Aug 12 Python
Python打开文件、文件读写操作、with方式、文件常用函数实例分析
Jan 07 Python
基于python监控程序是否关闭
Jan 14 Python
字典算法实现及操作 --python(实用)
Mar 31 Python
上手简单,功能强大的Python爬虫框架——feapder
Apr 27 Python
浅谈Python 中的复数问题
May 19 Python
pytorch中[..., 0]的用法说明
May 20 Python
pytorch 归一化与反归一化实例
Dec 31 #Python
Pytorch 数据加载与数据预处理方式
Dec 31 #Python
pytorch 数据处理:定义自己的数据集合实例
Dec 31 #Python
pytorch: Parameter 的数据结构实例
Dec 31 #Python
Python测试线程应用程序过程解析
Dec 31 #Python
Python TCPServer 多线程多客户端通信的实现
Dec 31 #Python
python给指定csv表格中的联系人群发邮件(带附件的邮件)
Dec 31 #Python
You might like
php 动态执行带有参数的类方法
2009/04/10 PHP
一个基于phpQuery的php通用采集类分享
2014/04/09 PHP
php递归json类实例
2014/12/02 PHP
php读取文件内容的方法汇总
2015/01/24 PHP
php判断表是否存在的方法
2015/06/18 PHP
php检测文本的编码
2015/07/26 PHP
php设计模式之单例模式代码
2016/06/11 PHP
PHP实现文件上传操作和封装
2020/03/04 PHP
JS+CSS制作DIV层可(最小化/拖拽/排序)功能实现代码
2013/02/25 Javascript
JS数组去重与取重的示例代码
2014/01/24 Javascript
textarea焦点的用法实现获取焦点清空失去焦点提示效果
2014/05/19 Javascript
javascript实现仿IE顶部的可关闭警告条
2015/05/05 Javascript
13个PHP函数超实用
2015/10/21 Javascript
BootStrap智能表单实战系列(八)表单配置json详解
2016/06/13 Javascript
vue中用H5实现文件上传的方法实例代码
2017/05/27 Javascript
Vue 2.5 Level E 发布了: 新功能特性一览
2017/10/24 Javascript
一步步教会你微信小程序的登录鉴权
2018/04/09 Javascript
微信小程序中换行空格(多个空格)写法详解
2018/07/10 Javascript
JS判断用户用的哪个浏览器实例详解
2018/10/09 Javascript
如何在基于vue-cli的项目自定义打包环境
2018/11/10 Javascript
JS实现简易留言板增删功能
2020/02/08 Javascript
在Angular项目使用socket.io实现通信的方法
2021/01/05 Javascript
[35:44]2014 DOTA2华西杯精英邀请赛 5 24 iG VS VG
2014/05/26 DOTA
Python利用正则表达式实现计算器算法思路解析
2018/04/25 Python
python 对类的成员函数开启线程的方法
2019/01/22 Python
python 实现的发送邮件模板【普通邮件、带附件、带图片邮件】
2019/07/06 Python
HTML5+CSS3应用详解
2014/02/24 HTML / CSS
TUMI马来西亚官方网站:国际领先的高品质商旅箱包品牌
2018/04/26 全球购物
意大利在线大学图书馆:Libreria universitaria
2019/07/16 全球购物
Linux文件操作命令都有哪些
2015/02/27 面试题
后勤自我鉴定
2013/10/13 职场文书
社会实践感言
2014/01/25 职场文书
工程项目经理任命书
2014/06/05 职场文书
素质拓展训练感想
2015/08/07 职场文书
因个人工作失误检讨书
2019/06/21 职场文书
Java多条件判断场景中规则执行器的设计
2021/06/26 Java/Android