Pytorch 实现权重初始化


Posted in Python onDecember 31, 2019

在TensorFlow中,权重的初始化主要是在声明张量的时候进行的。 而PyTorch则提供了另一种方法:首先应该声明张量,然后修改张量的权重。通过调用torch.nn.init包中的多种方法可以将权重初始化为直接访问张量的属性。

1、不初始化的效果

在Pytorch中,定义一个tensor,不进行初始化,打印看看结果:

w = torch.Tensor(3,4)
print (w)

可以看到这时候的初始化的数值都是随机的,而且特别大,这对网络的训练必定不好,最后导致精度提不上,甚至损失无法收敛。

2、初始化的效果

PyTorch提供了多种参数初始化函数:

torch.nn.init.constant(tensor, val)
torch.nn.init.normal(tensor, mean=0, std=1)
torch.nn.init.xavier_uniform(tensor, gain=1)

等等。详细请参考:http://pytorch.org/docs/nn.html#torch-nn-init

注意上面的初始化函数的参数tensor,虽然写的是tensor,但是也可以是Variable类型的。而神经网络的参数类型Parameter是Variable类的子类,所以初始化函数可以直接作用于神经网络参数。实际上,我们初始化也是直接去初始化神经网络的参数。

让我们试试效果:

w = torch.Tensor(3,4)
torch.nn.init.normal_(w)
print (w)

3、初始化神经网络的参数

对神经网络的初始化往往放在模型的__init__()函数中,如下所示:

class Net(nn.Module):

def __init__(self, block, layers, num_classes=1000):
  self.inplanes = 64
  super(Net, self).__init__()
  ***
  *** #定义自己的网络层
  ***

  for m in self.modules():
    if isinstance(m, nn.Conv2d):
      n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
      m.weight.data.normal_(0, math.sqrt(2. / n))
    elif isinstance(m, nn.BatchNorm2d):
      m.weight.data.fill_(1)
      m.bias.data.zero_()

***
*** #定义后续的函数
***

也可以采取另一种方式:

定义一个权重初始化函数,如下:

def weights_init(m):
  classname = m.__class__.__name__
  if classname.find('Conv2d') != -1:
    init.xavier_normal_(m.weight.data)
    init.constant_(m.bias.data, 0.0)
  elif classname.find('Linear') != -1:
    init.xavier_normal_(m.weight.data)
    init.constant_(m.bias.data, 0.0)

在模型声明时,调用初始化函数,初始化神经网络参数:

model = Net(*****)
model.apply(weights_init)

以上这篇Pytorch 实现权重初始化就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python操作xml文件详细介绍
Jun 09 Python
Python入门篇之条件、循环
Oct 17 Python
Python and、or以及and-or语法总结
Apr 14 Python
Python 专题二 条件语句和循环语句的基础知识
Mar 19 Python
Python栈算法的实现与简单应用示例
Nov 01 Python
python判断一个数是否能被另一个整数整除的实例
Dec 12 Python
pandas 把数据写入txt文件每行固定写入一定数量的值方法
Dec 28 Python
Python定义函数功能与用法实例详解
Apr 08 Python
Python 从subprocess运行的子进程中实时获取输出的例子
Aug 14 Python
Python代理IP爬虫的新手使用教程
Sep 05 Python
Python中的延迟绑定原理详解
Oct 11 Python
浅谈Keras的Sequential与PyTorch的Sequential的区别
Jun 17 Python
pytorch 归一化与反归一化实例
Dec 31 #Python
Pytorch 数据加载与数据预处理方式
Dec 31 #Python
pytorch 数据处理:定义自己的数据集合实例
Dec 31 #Python
pytorch: Parameter 的数据结构实例
Dec 31 #Python
Python测试线程应用程序过程解析
Dec 31 #Python
Python TCPServer 多线程多客户端通信的实现
Dec 31 #Python
python给指定csv表格中的联系人群发邮件(带附件的邮件)
Dec 31 #Python
You might like
twig模板常用语句实例小结
2016/02/04 PHP
PHP程序员简单的开展服务治理架构操作详解(二)
2020/05/14 PHP
javascript 面向对象编程  function是方法(函数)
2009/09/17 Javascript
使用Jquery搭建最佳用户体验的登录页面之记住密码自动登录功能(含后台代码)
2011/07/10 Javascript
获取客户端电脑日期时间js代码(jquery)
2012/09/12 Javascript
Jquery 动态循环输出表格具体方法
2013/11/23 Javascript
基于javascript html5实现多文件上传
2016/03/03 Javascript
jQuery操作iframe中js函数的方法小结
2016/07/06 Javascript
浅析Javascript ES6中的原生Promise
2016/08/25 Javascript
用javascript获取任意颜色的更亮或更暗颜色值示例代码
2017/07/21 Javascript
js模拟百度模糊搜索的实例
2017/08/04 Javascript
基于react组件之间的参数传递(详解)
2017/09/05 Javascript
js导出Excel表格超出26位英文字符的解决方法ES6
2017/11/15 Javascript
jQuery AJAX 方法success()后台传来的4种数据详解
2018/08/08 jQuery
bootstrapTable+ajax加载数据 refresh更新数据
2018/08/31 Javascript
vue数据操作之点击事件实现num加减功能示例
2019/01/19 Javascript
ES6 Object属性新的写法实例小结
2019/06/25 Javascript
JS+CSS实现炫酷光感效果
2020/09/05 Javascript
原生JS生成指定位数的验证码
2020/10/28 Javascript
[42:06]2019国际邀请赛全明星赛 8.23
2019/09/05 DOTA
python实现跨文件全局变量的方法
2014/07/07 Python
python3 pillow生成简单验证码图片的示例
2017/09/19 Python
50行Python代码实现人脸检测功能
2018/01/23 Python
django rest framework 过滤时间操作
2020/07/12 Python
H5调用相机拍照并压缩图片的实例代码
2017/07/20 HTML / CSS
美国正宗设计师眼镜在线零售商:EYEZZ
2019/03/23 全球购物
酒吧总经理岗位职责
2013/12/10 职场文书
建筑设计学生的自我评价
2014/01/16 职场文书
年度考核自我评价
2014/01/25 职场文书
房地产开发项目建议书
2014/05/16 职场文书
小学领导班子对照材料
2014/08/23 职场文书
副总经理岗位职责范本
2014/09/30 职场文书
2014年设计师工作总结
2014/11/25 职场文书
2016猴年春节慰问信
2015/11/30 职场文书
Rust 连接 PostgreSQL 数据库的详细过程
2022/01/22 PostgreSQL
世界无敌的ICOM IC-R9500宽频接收机
2022/03/25 无线电