Pytorch 实现权重初始化


Posted in Python onDecember 31, 2019

在TensorFlow中,权重的初始化主要是在声明张量的时候进行的。 而PyTorch则提供了另一种方法:首先应该声明张量,然后修改张量的权重。通过调用torch.nn.init包中的多种方法可以将权重初始化为直接访问张量的属性。

1、不初始化的效果

在Pytorch中,定义一个tensor,不进行初始化,打印看看结果:

w = torch.Tensor(3,4)
print (w)

可以看到这时候的初始化的数值都是随机的,而且特别大,这对网络的训练必定不好,最后导致精度提不上,甚至损失无法收敛。

2、初始化的效果

PyTorch提供了多种参数初始化函数:

torch.nn.init.constant(tensor, val)
torch.nn.init.normal(tensor, mean=0, std=1)
torch.nn.init.xavier_uniform(tensor, gain=1)

等等。详细请参考:http://pytorch.org/docs/nn.html#torch-nn-init

注意上面的初始化函数的参数tensor,虽然写的是tensor,但是也可以是Variable类型的。而神经网络的参数类型Parameter是Variable类的子类,所以初始化函数可以直接作用于神经网络参数。实际上,我们初始化也是直接去初始化神经网络的参数。

让我们试试效果:

w = torch.Tensor(3,4)
torch.nn.init.normal_(w)
print (w)

3、初始化神经网络的参数

对神经网络的初始化往往放在模型的__init__()函数中,如下所示:

class Net(nn.Module):

def __init__(self, block, layers, num_classes=1000):
  self.inplanes = 64
  super(Net, self).__init__()
  ***
  *** #定义自己的网络层
  ***

  for m in self.modules():
    if isinstance(m, nn.Conv2d):
      n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
      m.weight.data.normal_(0, math.sqrt(2. / n))
    elif isinstance(m, nn.BatchNorm2d):
      m.weight.data.fill_(1)
      m.bias.data.zero_()

***
*** #定义后续的函数
***

也可以采取另一种方式:

定义一个权重初始化函数,如下:

def weights_init(m):
  classname = m.__class__.__name__
  if classname.find('Conv2d') != -1:
    init.xavier_normal_(m.weight.data)
    init.constant_(m.bias.data, 0.0)
  elif classname.find('Linear') != -1:
    init.xavier_normal_(m.weight.data)
    init.constant_(m.bias.data, 0.0)

在模型声明时,调用初始化函数,初始化神经网络参数:

model = Net(*****)
model.apply(weights_init)

以上这篇Pytorch 实现权重初始化就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Numpy掩码式数组详解
Apr 17 Python
用TensorFlow实现多类支持向量机的示例代码
Apr 28 Python
详解python里的命名规范
Jul 16 Python
Python实现FTP弱口令扫描器的方法示例
Jan 31 Python
python pygame实现五子棋小游戏
Oct 26 Python
Python爬虫学习之翻译小程序
Jul 30 Python
Django实现列表页商品数据返回教程
Apr 03 Python
解决Pytorch自定义层出现多Variable共享内存错误问题
Jun 28 Python
keras实现VGG16 CIFAR10数据集方式
Jul 07 Python
Python3爬虫关于识别检验滑动验证码的实例
Jul 30 Python
python实现马丁策略的实例详解
Jan 15 Python
Pandas数据分析的一些常用小技巧
Feb 07 Python
pytorch 归一化与反归一化实例
Dec 31 #Python
Pytorch 数据加载与数据预处理方式
Dec 31 #Python
pytorch 数据处理:定义自己的数据集合实例
Dec 31 #Python
pytorch: Parameter 的数据结构实例
Dec 31 #Python
Python测试线程应用程序过程解析
Dec 31 #Python
Python TCPServer 多线程多客户端通信的实现
Dec 31 #Python
python给指定csv表格中的联系人群发邮件(带附件的邮件)
Dec 31 #Python
You might like
[原创]PHP中通过ADODB库实现调用Access数据库之修正版本
2006/12/31 PHP
详解PHP5.6.30与Apache2.4.x配置
2017/06/02 PHP
Thinkphp5框架实现图片、音频和视频文件的上传功能详解
2019/08/27 PHP
PHP设计模式概论【概念、分类、原则等】
2020/05/01 PHP
jQuery dialog 异步调用ashx,webservice数据的代码
2010/08/03 Javascript
javascript的console.log()用法小结
2012/05/31 Javascript
javascript中的undefined和not defined区别示例介绍
2014/02/26 Javascript
jQuery表格排序组件-tablesorter使用示例
2014/05/26 Javascript
Bootstrap每天必学之进度条
2015/11/30 Javascript
JS获取IE版本号与HTML设置IE文档模式的方法
2016/10/09 Javascript
BootStrap实现带有增删改查功能的表格(DEMO详解)
2016/10/26 Javascript
jQuery Password Validation密码验证
2016/12/30 Javascript
微信小程序 数组(增,删,改,查)等操作实例详解
2017/01/05 Javascript
微信小程序 下拉列表的实现实例代码
2017/03/08 Javascript
解决JS外部文件中文注释出现乱码问题
2017/07/09 Javascript
快速了解vue-cli 3.0 新特性
2018/02/28 Javascript
Python更新数据库脚本两种方法及对比介绍
2017/07/27 Python
Python多进程库multiprocessing中进程池Pool类的使用详解
2017/11/24 Python
给 TensorFlow 变量进行赋值的方式
2020/02/10 Python
Python MySQLdb 执行sql语句时的参数传递方式
2020/03/04 Python
如何使用Python自动生成报表并以邮件发送
2020/10/15 Python
详解Django关于StreamingHttpResponse与FileResponse文件下载的最优方法
2021/01/07 Python
css3实现蒙版弹幕功能
2019/06/18 HTML / CSS
男女时尚与复古风格在线购物:RoseGal(全球免费送货)
2017/07/19 全球购物
TobyDeals美国:在电子产品上获得最好的优惠和折扣
2019/08/11 全球购物
计算机专业推荐信范文
2013/11/20 职场文书
大学生期末自我鉴定
2014/02/01 职场文书
商业房地产广告语
2014/03/13 职场文书
大学班级学风建设方案
2014/05/01 职场文书
关于青春的演讲稿三分钟
2014/08/22 职场文书
私人房屋买卖协议书
2014/10/04 职场文书
普通党员群众路线教育实践活动心得体会
2014/11/04 职场文书
食品质检员岗位职责
2015/04/08 职场文书
员工工作心得体会
2019/05/07 职场文书
Python中os模块的简单使用及重命名操作
2021/04/17 Python
python创建字典及相关管理操作
2022/04/13 Python