Pytorch 实现权重初始化


Posted in Python onDecember 31, 2019

在TensorFlow中,权重的初始化主要是在声明张量的时候进行的。 而PyTorch则提供了另一种方法:首先应该声明张量,然后修改张量的权重。通过调用torch.nn.init包中的多种方法可以将权重初始化为直接访问张量的属性。

1、不初始化的效果

在Pytorch中,定义一个tensor,不进行初始化,打印看看结果:

w = torch.Tensor(3,4)
print (w)

可以看到这时候的初始化的数值都是随机的,而且特别大,这对网络的训练必定不好,最后导致精度提不上,甚至损失无法收敛。

2、初始化的效果

PyTorch提供了多种参数初始化函数:

torch.nn.init.constant(tensor, val)
torch.nn.init.normal(tensor, mean=0, std=1)
torch.nn.init.xavier_uniform(tensor, gain=1)

等等。详细请参考:http://pytorch.org/docs/nn.html#torch-nn-init

注意上面的初始化函数的参数tensor,虽然写的是tensor,但是也可以是Variable类型的。而神经网络的参数类型Parameter是Variable类的子类,所以初始化函数可以直接作用于神经网络参数。实际上,我们初始化也是直接去初始化神经网络的参数。

让我们试试效果:

w = torch.Tensor(3,4)
torch.nn.init.normal_(w)
print (w)

3、初始化神经网络的参数

对神经网络的初始化往往放在模型的__init__()函数中,如下所示:

class Net(nn.Module):

def __init__(self, block, layers, num_classes=1000):
  self.inplanes = 64
  super(Net, self).__init__()
  ***
  *** #定义自己的网络层
  ***

  for m in self.modules():
    if isinstance(m, nn.Conv2d):
      n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
      m.weight.data.normal_(0, math.sqrt(2. / n))
    elif isinstance(m, nn.BatchNorm2d):
      m.weight.data.fill_(1)
      m.bias.data.zero_()

***
*** #定义后续的函数
***

也可以采取另一种方式:

定义一个权重初始化函数,如下:

def weights_init(m):
  classname = m.__class__.__name__
  if classname.find('Conv2d') != -1:
    init.xavier_normal_(m.weight.data)
    init.constant_(m.bias.data, 0.0)
  elif classname.find('Linear') != -1:
    init.xavier_normal_(m.weight.data)
    init.constant_(m.bias.data, 0.0)

在模型声明时,调用初始化函数,初始化神经网络参数:

model = Net(*****)
model.apply(weights_init)

以上这篇Pytorch 实现权重初始化就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python同时给两个收件人发送邮件的方法
Apr 30 Python
python实现读取excel写入mysql的小工具详解
Nov 20 Python
python实现二叉树的遍历
Dec 11 Python
python cs架构实现简单文件传输
Mar 20 Python
Python将一个Excel拆分为多个Excel
Nov 07 Python
在Python中实现shuffle给列表洗牌
Nov 08 Python
python 发送和接收ActiveMQ消息的实例
Jan 30 Python
python2.7 安装pip的方法步骤(管用)
May 05 Python
pyqt5之将textBrowser的内容写入txt文档的方法
Jun 21 Python
Python 获取 datax 执行结果保存到数据库的方法
Jul 11 Python
Python正则表达式高级使用方法汇总
Jun 18 Python
python获取时间戳的实现示例(10位和13位)
Sep 23 Python
pytorch 归一化与反归一化实例
Dec 31 #Python
Pytorch 数据加载与数据预处理方式
Dec 31 #Python
pytorch 数据处理:定义自己的数据集合实例
Dec 31 #Python
pytorch: Parameter 的数据结构实例
Dec 31 #Python
Python测试线程应用程序过程解析
Dec 31 #Python
Python TCPServer 多线程多客户端通信的实现
Dec 31 #Python
python给指定csv表格中的联系人群发邮件(带附件的邮件)
Dec 31 #Python
You might like
国内php原创论坛
2006/10/09 PHP
php邮件发送,php发送邮件的类
2011/03/24 PHP
php若干单维数组遍历方法的比较
2011/09/20 PHP
php截取字符串并保留完整xml标签的函数代码
2013/02/06 PHP
php使用QueryList轻松采集js动态渲染页面方法
2018/09/11 PHP
Jquery 插件开发笔记整理
2011/01/17 Javascript
基于jQuery实现select下拉选择可输入附源码下载
2016/02/03 Javascript
详解打造 Vue.js 可复用组件
2017/03/24 Javascript
nodejs个人博客开发第四步 数据模型
2017/04/12 NodeJs
JS回调函数基本定义与用法实例分析
2017/05/24 Javascript
template.js前端模板引擎使用详解
2017/10/10 Javascript
详述 Sublime Text 打开 GBK 格式中文乱码的解决方法
2017/10/26 Javascript
pace.js和NProgress.js两个加载进度插件的一点小总结
2018/01/31 Javascript
JavaScript正则表达式函数总结(常用)
2018/02/22 Javascript
vue实现循环切换动画
2018/10/17 Javascript
Vue axios全局拦截 get请求、post请求、配置请求的实例代码
2018/11/28 Javascript
jquery多级树形下拉菜单的实例代码
2019/07/09 jQuery
JS表格的动态操作完整示例
2020/01/13 Javascript
Python实现配置文件备份的方法
2015/07/30 Python
浅谈Python的list中的选取范围
2018/11/12 Python
使用python进行波形及频谱绘制的方法
2019/06/17 Python
python 输出列表元素实例(以空格/逗号为分隔符)
2019/12/25 Python
Python3.7实现验证码登录方式代码实例
2020/02/14 Python
Python使用struct处理二进制(pack和unpack用法)
2020/11/12 Python
解决pytorch 保存模型遇到的问题
2021/03/03 Python
HTML5进阶段内联标签汇总(小篇)
2016/07/13 HTML / CSS
JD Sports德国官网:英国领先的运动鞋和运动服饰零售商
2018/02/26 全球购物
优秀应届生推荐信
2013/11/09 职场文书
工程监理应届生求职信
2013/11/09 职场文书
顶碗少年教学反思
2014/02/21 职场文书
房屋转让协议书
2014/04/11 职场文书
2014年学生会部门工作总结
2014/11/07 职场文书
个人汇报材料范文
2014/12/30 职场文书
辞职信模板(中英文版)
2015/02/27 职场文书
使用numpy nonzero 找出非0元素
2021/05/14 Python
海贼王十大潜力果实,路飞仅排第十,第一可毁世界(震震果实)
2022/03/18 日漫