Pytorch 实现权重初始化


Posted in Python onDecember 31, 2019

在TensorFlow中,权重的初始化主要是在声明张量的时候进行的。 而PyTorch则提供了另一种方法:首先应该声明张量,然后修改张量的权重。通过调用torch.nn.init包中的多种方法可以将权重初始化为直接访问张量的属性。

1、不初始化的效果

在Pytorch中,定义一个tensor,不进行初始化,打印看看结果:

w = torch.Tensor(3,4)
print (w)

可以看到这时候的初始化的数值都是随机的,而且特别大,这对网络的训练必定不好,最后导致精度提不上,甚至损失无法收敛。

2、初始化的效果

PyTorch提供了多种参数初始化函数:

torch.nn.init.constant(tensor, val)
torch.nn.init.normal(tensor, mean=0, std=1)
torch.nn.init.xavier_uniform(tensor, gain=1)

等等。详细请参考:http://pytorch.org/docs/nn.html#torch-nn-init

注意上面的初始化函数的参数tensor,虽然写的是tensor,但是也可以是Variable类型的。而神经网络的参数类型Parameter是Variable类的子类,所以初始化函数可以直接作用于神经网络参数。实际上,我们初始化也是直接去初始化神经网络的参数。

让我们试试效果:

w = torch.Tensor(3,4)
torch.nn.init.normal_(w)
print (w)

3、初始化神经网络的参数

对神经网络的初始化往往放在模型的__init__()函数中,如下所示:

class Net(nn.Module):

def __init__(self, block, layers, num_classes=1000):
  self.inplanes = 64
  super(Net, self).__init__()
  ***
  *** #定义自己的网络层
  ***

  for m in self.modules():
    if isinstance(m, nn.Conv2d):
      n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
      m.weight.data.normal_(0, math.sqrt(2. / n))
    elif isinstance(m, nn.BatchNorm2d):
      m.weight.data.fill_(1)
      m.bias.data.zero_()

***
*** #定义后续的函数
***

也可以采取另一种方式:

定义一个权重初始化函数,如下:

def weights_init(m):
  classname = m.__class__.__name__
  if classname.find('Conv2d') != -1:
    init.xavier_normal_(m.weight.data)
    init.constant_(m.bias.data, 0.0)
  elif classname.find('Linear') != -1:
    init.xavier_normal_(m.weight.data)
    init.constant_(m.bias.data, 0.0)

在模型声明时,调用初始化函数,初始化神经网络参数:

model = Net(*****)
model.apply(weights_init)

以上这篇Pytorch 实现权重初始化就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中使用不同编码读写txt文件详解
May 28 Python
举例简单讲解Python中的数据存储模块shelve的用法
Mar 03 Python
python3之微信文章爬虫实例讲解
Jul 12 Python
Django 导出 Excel 代码的实例详解
Aug 11 Python
python 3.0 模拟用户登录功能并实现三次错误锁定
Nov 01 Python
python模块导入的细节详解
Dec 10 Python
Python 复平面绘图实例
Nov 21 Python
python脚本实现mp4中的音频提取并保存在原目录
Feb 27 Python
使用Python将图片转正方形的两种方法实例代码详解
Apr 29 Python
Python项目跨域问题解决方案
Jun 22 Python
Python基础之变量的相关知识总结
Jun 23 Python
Python使用openpyxl批量处理数据
Jun 23 Python
pytorch 归一化与反归一化实例
Dec 31 #Python
Pytorch 数据加载与数据预处理方式
Dec 31 #Python
pytorch 数据处理:定义自己的数据集合实例
Dec 31 #Python
pytorch: Parameter 的数据结构实例
Dec 31 #Python
Python测试线程应用程序过程解析
Dec 31 #Python
Python TCPServer 多线程多客户端通信的实现
Dec 31 #Python
python给指定csv表格中的联系人群发邮件(带附件的邮件)
Dec 31 #Python
You might like
php在程序中将网页生成word文档并提供下载的代码
2012/10/09 PHP
PHP无限分类(树形类)的深入分析
2013/06/02 PHP
PHP添加Xdebug扩展的方法
2014/02/12 PHP
基于PHP实现的事件机制实例分析
2015/06/18 PHP
PHP函数积累总结
2019/03/19 PHP
js跑马灯代码(自写)
2013/04/17 Javascript
JS中判断null、undefined与NaN的方法
2014/03/24 Javascript
js实现的点击数量加一可操作数据库
2014/05/09 Javascript
Node.js开源应用框架HapiJS介绍
2015/01/14 Javascript
js 判断登录界面的账号密码是否为空
2017/02/08 Javascript
详解Windows下安装Nodejs步骤
2017/05/18 NodeJs
Vue.js2.0中的变化小结
2017/10/24 Javascript
Js通过AES加密后PHP用Openssl解密的方法
2019/07/12 Javascript
JavaScript使用表单元素验证表单的示例代码
2019/08/20 Javascript
[02:40]DOTA2殁境神蚀者 英雄基础教程
2013/11/26 DOTA
python使用beautifulsoup从爱奇艺网抓取视频播放
2014/01/23 Python
Python中使用item()方法遍历字典的例子
2014/08/26 Python
Python中逗号的三种作用实例分析
2015/06/08 Python
python将一组数分成每3个一组的实例
2018/11/14 Python
对python中字典keys,values,items的使用详解
2019/02/03 Python
Python爬虫 bilibili视频弹幕提取过程详解
2019/07/31 Python
Scrapy框架基本命令与settings.py设置
2020/02/06 Python
TensorFLow 变量命名空间实例
2020/02/11 Python
Python中的整除和取模实例
2020/06/03 Python
加拿大购物频道:The Shopping Channel
2016/07/21 全球购物
viagogo波兰票务平台:演唱会、体育比赛、戏剧门票
2018/04/23 全球购物
白酒业务员岗位职责
2013/12/27 职场文书
高三语文教学反思
2014/01/15 职场文书
电子装配专业毕业生求职信
2014/04/23 职场文书
2015年保险公司工作总结
2015/04/24 职场文书
建国大业观后感800字
2015/06/01 职场文书
七年级数学教学反思
2016/02/17 职场文书
python - asyncio异步编程
2021/04/06 Python
详解如何使用Node.js实现热重载页面
2021/05/06 Javascript
vue-cropper组件实现图片切割上传
2021/05/27 Vue.js
详解thinkphp的Auth类认证
2021/05/28 PHP