Pytorch 实现权重初始化


Posted in Python onDecember 31, 2019

在TensorFlow中,权重的初始化主要是在声明张量的时候进行的。 而PyTorch则提供了另一种方法:首先应该声明张量,然后修改张量的权重。通过调用torch.nn.init包中的多种方法可以将权重初始化为直接访问张量的属性。

1、不初始化的效果

在Pytorch中,定义一个tensor,不进行初始化,打印看看结果:

w = torch.Tensor(3,4)
print (w)

可以看到这时候的初始化的数值都是随机的,而且特别大,这对网络的训练必定不好,最后导致精度提不上,甚至损失无法收敛。

2、初始化的效果

PyTorch提供了多种参数初始化函数:

torch.nn.init.constant(tensor, val)
torch.nn.init.normal(tensor, mean=0, std=1)
torch.nn.init.xavier_uniform(tensor, gain=1)

等等。详细请参考:http://pytorch.org/docs/nn.html#torch-nn-init

注意上面的初始化函数的参数tensor,虽然写的是tensor,但是也可以是Variable类型的。而神经网络的参数类型Parameter是Variable类的子类,所以初始化函数可以直接作用于神经网络参数。实际上,我们初始化也是直接去初始化神经网络的参数。

让我们试试效果:

w = torch.Tensor(3,4)
torch.nn.init.normal_(w)
print (w)

3、初始化神经网络的参数

对神经网络的初始化往往放在模型的__init__()函数中,如下所示:

class Net(nn.Module):

def __init__(self, block, layers, num_classes=1000):
  self.inplanes = 64
  super(Net, self).__init__()
  ***
  *** #定义自己的网络层
  ***

  for m in self.modules():
    if isinstance(m, nn.Conv2d):
      n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
      m.weight.data.normal_(0, math.sqrt(2. / n))
    elif isinstance(m, nn.BatchNorm2d):
      m.weight.data.fill_(1)
      m.bias.data.zero_()

***
*** #定义后续的函数
***

也可以采取另一种方式:

定义一个权重初始化函数,如下:

def weights_init(m):
  classname = m.__class__.__name__
  if classname.find('Conv2d') != -1:
    init.xavier_normal_(m.weight.data)
    init.constant_(m.bias.data, 0.0)
  elif classname.find('Linear') != -1:
    init.xavier_normal_(m.weight.data)
    init.constant_(m.bias.data, 0.0)

在模型声明时,调用初始化函数,初始化神经网络参数:

model = Net(*****)
model.apply(weights_init)

以上这篇Pytorch 实现权重初始化就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
教你如何在Django 1.6中正确使用 Signal
Jun 22 Python
python仿抖音表白神器
Apr 08 Python
pyqt5 comboBox获得下标、文本和事件选中函数的方法
Jun 14 Python
python 数据提取及拆分的实现代码
Aug 26 Python
Python for i in range ()用法详解
Sep 18 Python
pytorch方法测试详解——归一化(BatchNorm2d)
Jan 15 Python
Python基于pyecharts实现关联图绘制
Mar 27 Python
jupyter notebook tensorflow打印device信息实例
Apr 20 Python
python中如何写类
Jun 29 Python
解决运行出现'dict' object has no attribute 'has_key'问题
Jul 15 Python
使用Selenium实现微博爬虫(预登录、展开全文、翻页)
Apr 13 Python
Python批量将csv文件转化成xml文件的实例
May 10 Python
pytorch 归一化与反归一化实例
Dec 31 #Python
Pytorch 数据加载与数据预处理方式
Dec 31 #Python
pytorch 数据处理:定义自己的数据集合实例
Dec 31 #Python
pytorch: Parameter 的数据结构实例
Dec 31 #Python
Python测试线程应用程序过程解析
Dec 31 #Python
Python TCPServer 多线程多客户端通信的实现
Dec 31 #Python
python给指定csv表格中的联系人群发邮件(带附件的邮件)
Dec 31 #Python
You might like
咖啡豆要不要放冰箱的原因
2021/03/04 冲泡冲煮
PHP设置头信息及取得返回头信息的方法
2016/01/25 PHP
php创建图像具体步骤
2017/03/13 PHP
PHP 的Opcache加速的使用方法
2017/12/29 PHP
javascript中的一些注意事项 更新中
2010/12/06 Javascript
jQuery ajax serialize() 方法使用示例
2014/11/02 Javascript
AngularJS模块管理问题的非常规处理方法
2015/04/29 Javascript
轻松使用jQuery双向select控件Bootstrap Dual Listbox
2015/12/13 Javascript
jQuery的实例及必知重要的jQuery选择器详解
2016/05/20 Javascript
总结JavaScript设计模式编程中的享元模式使用
2016/05/21 Javascript
AngularJS Phonecat实例讲解
2016/11/21 Javascript
JavaScript设计模式之代理模式详解
2017/06/09 Javascript
JS散列表碰撞处理、开链法、HashTable散列示例
2019/02/08 Javascript
Python open()文件处理使用介绍
2014/11/30 Python
Windows下使Python2.x版本的解释器与3.x共存的方法
2015/10/25 Python
Python实现简单的多任务mysql转xml的方法
2017/02/08 Python
浅谈python中的数字类型与处理工具
2017/08/02 Python
详解Python自建logging模块
2018/01/29 Python
Python求均值,方差,标准差的实例
2019/06/29 Python
Python何时应该使用Lambda函数
2019/07/02 Python
python多进程并行代码实例
2019/09/30 Python
详解Python 实现 ZeroMQ 的三种基本工作模式
2020/03/24 Python
解决Jupyter Notebook使用parser.parse_args出现错误问题
2020/04/20 Python
python+adb+monkey实现Rom稳定性测试详解
2020/04/23 Python
H5仿微信界面教程(一)
2017/07/05 HTML / CSS
小学教师的自我评价范例
2013/10/31 职场文书
安全月活动总结
2014/05/05 职场文书
工作犯错保证书
2015/05/11 职场文书
单位病假条范文
2015/08/17 职场文书
MySql 8.0及对应驱动包匹配的注意点说明
2021/06/23 MySQL
浅谈Redis位图(Bitmap)及Redis二进制中的问题
2021/07/15 Redis
MySQL 1130异常,无法远程登录解决方案详解
2021/08/23 MySQL
sass 常用备忘案例详解
2021/09/15 HTML / CSS
MySQL分库分表详情
2021/09/25 MySQL
python模块与C和C++动态库相互调用实现过程示例
2021/11/02 Python
canvas实现贪食蛇的实践
2022/02/15 Javascript