Pytorch 实现权重初始化


Posted in Python onDecember 31, 2019

在TensorFlow中,权重的初始化主要是在声明张量的时候进行的。 而PyTorch则提供了另一种方法:首先应该声明张量,然后修改张量的权重。通过调用torch.nn.init包中的多种方法可以将权重初始化为直接访问张量的属性。

1、不初始化的效果

在Pytorch中,定义一个tensor,不进行初始化,打印看看结果:

w = torch.Tensor(3,4)
print (w)

可以看到这时候的初始化的数值都是随机的,而且特别大,这对网络的训练必定不好,最后导致精度提不上,甚至损失无法收敛。

2、初始化的效果

PyTorch提供了多种参数初始化函数:

torch.nn.init.constant(tensor, val)
torch.nn.init.normal(tensor, mean=0, std=1)
torch.nn.init.xavier_uniform(tensor, gain=1)

等等。详细请参考:http://pytorch.org/docs/nn.html#torch-nn-init

注意上面的初始化函数的参数tensor,虽然写的是tensor,但是也可以是Variable类型的。而神经网络的参数类型Parameter是Variable类的子类,所以初始化函数可以直接作用于神经网络参数。实际上,我们初始化也是直接去初始化神经网络的参数。

让我们试试效果:

w = torch.Tensor(3,4)
torch.nn.init.normal_(w)
print (w)

3、初始化神经网络的参数

对神经网络的初始化往往放在模型的__init__()函数中,如下所示:

class Net(nn.Module):

def __init__(self, block, layers, num_classes=1000):
  self.inplanes = 64
  super(Net, self).__init__()
  ***
  *** #定义自己的网络层
  ***

  for m in self.modules():
    if isinstance(m, nn.Conv2d):
      n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
      m.weight.data.normal_(0, math.sqrt(2. / n))
    elif isinstance(m, nn.BatchNorm2d):
      m.weight.data.fill_(1)
      m.bias.data.zero_()

***
*** #定义后续的函数
***

也可以采取另一种方式:

定义一个权重初始化函数,如下:

def weights_init(m):
  classname = m.__class__.__name__
  if classname.find('Conv2d') != -1:
    init.xavier_normal_(m.weight.data)
    init.constant_(m.bias.data, 0.0)
  elif classname.find('Linear') != -1:
    init.xavier_normal_(m.weight.data)
    init.constant_(m.bias.data, 0.0)

在模型声明时,调用初始化函数,初始化神经网络参数:

model = Net(*****)
model.apply(weights_init)

以上这篇Pytorch 实现权重初始化就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python语言技巧之三元运算符使用介绍
Mar 04 Python
在Python中使用matplotlib模块绘制数据图的示例
May 04 Python
Python的time模块中的常用方法整理
Jun 18 Python
Python的Asyncore异步Socket模块及实现端口转发的例子
Jun 14 Python
Python中类型检查的详细介绍
Feb 13 Python
Python语言描述KNN算法与Kd树
Dec 13 Python
win7+Python3.5下scrapy的安装方法
Jul 31 Python
浅谈python3.x pool.map()方法的实质
Jan 16 Python
python的re模块使用方法详解
Jul 26 Python
Python 元组拆包示例(Tuple Unpacking)
Dec 24 Python
Python函数的迭代器与生成器的示例代码
Jun 18 Python
Selenium获取登录Cookies并添加Cookies自动登录的方法
Dec 04 Python
pytorch 归一化与反归一化实例
Dec 31 #Python
Pytorch 数据加载与数据预处理方式
Dec 31 #Python
pytorch 数据处理:定义自己的数据集合实例
Dec 31 #Python
pytorch: Parameter 的数据结构实例
Dec 31 #Python
Python测试线程应用程序过程解析
Dec 31 #Python
Python TCPServer 多线程多客户端通信的实现
Dec 31 #Python
python给指定csv表格中的联系人群发邮件(带附件的邮件)
Dec 31 #Python
You might like
php设计模式之观察者模式的应用详解
2013/05/21 PHP
PHP判断IP并转跳到相应城市分站的方法
2015/03/25 PHP
php支付宝APP支付功能
2020/07/29 PHP
jquery插件制作 自增长输入框实现代码
2012/08/17 jQuery
jQuery方法简洁实现隔行换色及toggleClass的使用
2013/03/15 Javascript
自动设置iframe大小的jQuery代码
2013/09/11 Javascript
jquery默认校验规则整理
2014/03/24 Javascript
jQueryMobile之Helloworld与页面切换的方法
2015/02/04 Javascript
jQuery获取某天的农历日期并判断是否除夕或新年的方法
2016/03/01 Javascript
不能不知道的10个angularjs英文学习网站
2016/03/23 Javascript
浅析Bootstrip的select控件绑定数据的问题
2016/05/10 Javascript
JS制作图形验证码实现代码
2020/10/19 Javascript
jQuery Validate验证框架详解(推荐)
2016/12/17 Javascript
AngularJS使用带属性值的ng-app指令实现自定义模块自动加载的方法
2017/01/04 Javascript
基于JQuery及AJAX实现名人名言随机生成器
2017/02/10 Javascript
javascript完美实现给定日期返回上月日期的方法
2017/06/15 Javascript
vue2.0中click点击当前li实现动态切换class
2017/06/21 Javascript
详解AngularJS跨页面传值(ui-router)
2017/08/23 Javascript
angularjs使用div模拟textarea文本框的方法
2018/10/02 Javascript
深入学习JavaScript 高阶函数
2019/06/11 Javascript
LayUI数据接口返回实体封装的例子
2019/09/12 Javascript
详解vue中使用axios对同一个接口连续请求导致返回数据混乱的问题
2019/11/06 Javascript
微信小程序实现拨打电话功能的示例代码
2020/06/28 Javascript
Python通过matplotlib画双层饼图及环形图简单示例
2017/12/15 Python
用pandas中的DataFrame时选取行或列的方法
2018/07/11 Python
django的auth认证,authenticate和装饰器功能详解
2019/07/25 Python
Python输出指定字符串的方法
2020/02/06 Python
python十进制转二进制的详解
2020/02/07 Python
CSS3 flex布局之快速实现BorderLayout布局
2015/12/03 HTML / CSS
Lulu Guinness露露·吉尼斯官网:红唇包
2019/02/03 全球购物
小学生清明节演讲稿
2014/09/05 职场文书
财会专业大学生求职信
2014/09/26 职场文书
会议主持人开场白台词
2015/05/28 职场文书
六一活动主持词
2015/06/30 职场文书
企业内部管理控制:银行存款控制制度范本
2020/01/10 职场文书
redis击穿 雪崩 穿透超详细解决方案梳理
2022/03/17 Redis