Pytorch 实现权重初始化


Posted in Python onDecember 31, 2019

在TensorFlow中,权重的初始化主要是在声明张量的时候进行的。 而PyTorch则提供了另一种方法:首先应该声明张量,然后修改张量的权重。通过调用torch.nn.init包中的多种方法可以将权重初始化为直接访问张量的属性。

1、不初始化的效果

在Pytorch中,定义一个tensor,不进行初始化,打印看看结果:

w = torch.Tensor(3,4)
print (w)

可以看到这时候的初始化的数值都是随机的,而且特别大,这对网络的训练必定不好,最后导致精度提不上,甚至损失无法收敛。

2、初始化的效果

PyTorch提供了多种参数初始化函数:

torch.nn.init.constant(tensor, val)
torch.nn.init.normal(tensor, mean=0, std=1)
torch.nn.init.xavier_uniform(tensor, gain=1)

等等。详细请参考:http://pytorch.org/docs/nn.html#torch-nn-init

注意上面的初始化函数的参数tensor,虽然写的是tensor,但是也可以是Variable类型的。而神经网络的参数类型Parameter是Variable类的子类,所以初始化函数可以直接作用于神经网络参数。实际上,我们初始化也是直接去初始化神经网络的参数。

让我们试试效果:

w = torch.Tensor(3,4)
torch.nn.init.normal_(w)
print (w)

3、初始化神经网络的参数

对神经网络的初始化往往放在模型的__init__()函数中,如下所示:

class Net(nn.Module):

def __init__(self, block, layers, num_classes=1000):
  self.inplanes = 64
  super(Net, self).__init__()
  ***
  *** #定义自己的网络层
  ***

  for m in self.modules():
    if isinstance(m, nn.Conv2d):
      n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
      m.weight.data.normal_(0, math.sqrt(2. / n))
    elif isinstance(m, nn.BatchNorm2d):
      m.weight.data.fill_(1)
      m.bias.data.zero_()

***
*** #定义后续的函数
***

也可以采取另一种方式:

定义一个权重初始化函数,如下:

def weights_init(m):
  classname = m.__class__.__name__
  if classname.find('Conv2d') != -1:
    init.xavier_normal_(m.weight.data)
    init.constant_(m.bias.data, 0.0)
  elif classname.find('Linear') != -1:
    init.xavier_normal_(m.weight.data)
    init.constant_(m.bias.data, 0.0)

在模型声明时,调用初始化函数,初始化神经网络参数:

model = Net(*****)
model.apply(weights_init)

以上这篇Pytorch 实现权重初始化就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python通过imaplib模块读取gmail里邮件的方法
May 08 Python
详解 Python 与文件对象共事的实例
Sep 11 Python
Django实现组合搜索的方法示例
Jan 23 Python
详解Python核心对象类型字符串
Feb 11 Python
对python3新增的byte类型详解
Dec 04 Python
python绘制散点图并标记序号的方法
Dec 11 Python
python对于requests的封装方法详解
Jan 03 Python
Ubuntu下Anaconda和Pycharm配置方法详解
Jun 14 Python
Django ORM 自定义 char 类型字段解析
Aug 09 Python
Python 如何批量更新已安装的库
May 26 Python
Python如何脚本过滤文件中的注释
May 27 Python
Django数据统计功能count()的使用
Nov 30 Python
pytorch 归一化与反归一化实例
Dec 31 #Python
Pytorch 数据加载与数据预处理方式
Dec 31 #Python
pytorch 数据处理:定义自己的数据集合实例
Dec 31 #Python
pytorch: Parameter 的数据结构实例
Dec 31 #Python
Python测试线程应用程序过程解析
Dec 31 #Python
Python TCPServer 多线程多客户端通信的实现
Dec 31 #Python
python给指定csv表格中的联系人群发邮件(带附件的邮件)
Dec 31 #Python
You might like
PHP中在数据库中保存Checkbox数据(1)
2006/10/09 PHP
YII框架中搜索分页jQuery写法详解
2016/12/19 PHP
基于win2003虚拟机中apache服务器的访问
2017/08/01 PHP
PHP ajax+jQuery 实现批量删除功能实例代码小结
2018/12/06 PHP
Laravel 5.5 实现禁用用户注册示例
2019/10/24 PHP
不错的JS中变量相关的细节分析
2007/08/13 Javascript
JS 实现Table相同行的单元格自动合并示例代码
2013/08/27 Javascript
JS实现距离上次刷新已过多少秒示例
2014/05/23 Javascript
javascript实现全角半角检测的方法
2015/07/23 Javascript
原生js实现图片层叠轮播切换效果
2016/02/02 Javascript
基于Vue如何封装分页组件
2016/12/16 Javascript
jQuery编写设置和获取颜色的插件
2017/01/09 Javascript
基于Vue过渡状态实例讲解
2017/09/14 Javascript
如何将HTML字符转换为DOM节点并动态添加到文档中详解
2018/08/19 Javascript
详解Node.js中path模块的resolve()和join()方法的区别
2018/10/29 Javascript
JavaScript学习笔记之数组基本操作示例
2019/01/09 Javascript
vue中轮训器的使用
2019/01/27 Javascript
解决父组件将子组件作为弹窗调用只执行一次created的问题
2020/07/24 Javascript
[51:17]VGJ.T vs Mineski 2018国际邀请赛小组赛BO2 第二场 8.18
2018/08/19 DOTA
Python下载网络文本数据到本地内存的四种实现方法示例
2018/02/05 Python
Python编写合并字典并实现敏感目录的小脚本
2019/02/26 Python
在pycharm下设置自己的个性模版方法
2019/07/15 Python
django云端留言板实例详解
2019/07/22 Python
PyCharm2019安装教程及其使用(图文教程)
2019/09/29 Python
Python3实现配置文件差异对比脚本
2019/11/18 Python
详解python opencv、scikit-image和PIL图像处理库比较
2019/12/26 Python
详解CSS3的box-shadow属性制作边框阴影效果的方法
2016/05/10 HTML / CSS
突袭HTML5之Javascript API扩展4—拖拽(Drag/Drop)概述
2013/01/31 HTML / CSS
施华洛世奇日本官网:SWAROVSKI日本
2018/05/04 全球购物
EJB面试题
2015/07/28 面试题
写clone()方法时,通常都有一行代码,是什么?
2012/10/31 面试题
工程售后服务承诺书
2014/05/21 职场文书
遗嘱继承权公证书
2015/01/26 职场文书
施工员岗位职责范本
2015/04/11 职场文书
Java基础之线程锁相关知识总结
2021/06/30 Java/Android
十大最帅动漫男主 碓冰拓海上榜,第一是《灌篮高手》男主角
2022/03/18 日漫