Pytorch 实现权重初始化


Posted in Python onDecember 31, 2019

在TensorFlow中,权重的初始化主要是在声明张量的时候进行的。 而PyTorch则提供了另一种方法:首先应该声明张量,然后修改张量的权重。通过调用torch.nn.init包中的多种方法可以将权重初始化为直接访问张量的属性。

1、不初始化的效果

在Pytorch中,定义一个tensor,不进行初始化,打印看看结果:

w = torch.Tensor(3,4)
print (w)

可以看到这时候的初始化的数值都是随机的,而且特别大,这对网络的训练必定不好,最后导致精度提不上,甚至损失无法收敛。

2、初始化的效果

PyTorch提供了多种参数初始化函数:

torch.nn.init.constant(tensor, val)
torch.nn.init.normal(tensor, mean=0, std=1)
torch.nn.init.xavier_uniform(tensor, gain=1)

等等。详细请参考:http://pytorch.org/docs/nn.html#torch-nn-init

注意上面的初始化函数的参数tensor,虽然写的是tensor,但是也可以是Variable类型的。而神经网络的参数类型Parameter是Variable类的子类,所以初始化函数可以直接作用于神经网络参数。实际上,我们初始化也是直接去初始化神经网络的参数。

让我们试试效果:

w = torch.Tensor(3,4)
torch.nn.init.normal_(w)
print (w)

3、初始化神经网络的参数

对神经网络的初始化往往放在模型的__init__()函数中,如下所示:

class Net(nn.Module):

def __init__(self, block, layers, num_classes=1000):
  self.inplanes = 64
  super(Net, self).__init__()
  ***
  *** #定义自己的网络层
  ***

  for m in self.modules():
    if isinstance(m, nn.Conv2d):
      n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
      m.weight.data.normal_(0, math.sqrt(2. / n))
    elif isinstance(m, nn.BatchNorm2d):
      m.weight.data.fill_(1)
      m.bias.data.zero_()

***
*** #定义后续的函数
***

也可以采取另一种方式:

定义一个权重初始化函数,如下:

def weights_init(m):
  classname = m.__class__.__name__
  if classname.find('Conv2d') != -1:
    init.xavier_normal_(m.weight.data)
    init.constant_(m.bias.data, 0.0)
  elif classname.find('Linear') != -1:
    init.xavier_normal_(m.weight.data)
    init.constant_(m.bias.data, 0.0)

在模型声明时,调用初始化函数,初始化神经网络参数:

model = Net(*****)
model.apply(weights_init)

以上这篇Pytorch 实现权重初始化就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中Iterator迭代器的使用杂谈
Jun 20 Python
详解Django-auth-ldap 配置方法
Dec 10 Python
python flask框架实现重定向功能示例
Jul 02 Python
python批量读取文件名并写入txt文件中
Sep 05 Python
如何基于Python批量下载音乐
Nov 11 Python
基于Python实现扑克牌面试题
Dec 11 Python
python生成大写32位uuid代码
Mar 03 Python
Python自动化测试中yaml文件读取操作
Aug 20 Python
Jupyter Notebook安装及使用方法解析
Nov 12 Python
Python自动化测试基础必备知识点总结
Feb 07 Python
Python中生成ndarray实例讲解
Feb 22 Python
Python基础之教你怎么在M1系统上使用pandas
May 08 Python
pytorch 归一化与反归一化实例
Dec 31 #Python
Pytorch 数据加载与数据预处理方式
Dec 31 #Python
pytorch 数据处理:定义自己的数据集合实例
Dec 31 #Python
pytorch: Parameter 的数据结构实例
Dec 31 #Python
Python测试线程应用程序过程解析
Dec 31 #Python
Python TCPServer 多线程多客户端通信的实现
Dec 31 #Python
python给指定csv表格中的联系人群发邮件(带附件的邮件)
Dec 31 #Python
You might like
基于qmail的完整WEBMAIL解决方案安装详解
2006/10/09 PHP
php获取远程图片的两种 CURL方式和sockets方式获取远程图片
2011/11/07 PHP
PHP封装的一个支持HTML、JS、PHP重定向的多功能跳转函数
2014/06/19 PHP
PHP MVC框架skymvc支持多文件上传
2016/05/26 PHP
Javascript remove 自定义数组删除方法
2009/10/20 Javascript
JQuery select标签操作代码段
2010/05/16 Javascript
jQueryUI写一个调整分类的拖放效果实现代码
2012/05/10 Javascript
幻灯片带网页设计中的20个奇妙应用示例小结
2012/05/27 Javascript
javascript实现信息的显示和隐藏如注册页面
2013/12/03 Javascript
Jquery中国地图热点效果-鼠标经过弹出提示层信息的简单实例
2014/02/12 Javascript
JS实现控制表格行内容垂直对齐的方法
2015/03/30 Javascript
JavaScript直播评论发弹幕切图功能点集合效果代码
2016/06/26 Javascript
基于JS实现的随机数字抽签实例
2016/12/08 Javascript
vuejs开发组件分享之H5图片上传、压缩及拍照旋转的问题处理
2017/03/06 Javascript
关于javascript作用域的常见面试题分享
2017/06/18 Javascript
详解Angular 开发环境搭建
2017/06/22 Javascript
利用d3.js实现蜂巢图表带动画效果
2019/09/03 Javascript
layui实现数据表格隐藏列的示例
2019/10/25 Javascript
浅谈Vue.use到底是什么鬼
2020/01/21 Javascript
Python ORM框架SQLAlchemy学习笔记之安装和简单查询实例
2014/06/10 Python
快速入手Python字符编码
2016/08/03 Python
快速了解Python相对导入
2018/01/12 Python
浅谈Django学习migrate和makemigrations的差别
2018/01/18 Python
删除DataFrame中值全为NaN或者包含有NaN的列或行方法
2018/11/06 Python
python实现飞机大战游戏(pygame版)
2020/10/26 Python
Python爬虫headers处理及网络超时问题解决方案
2020/06/19 Python
python 贪心算法的实现
2020/09/18 Python
Scrapy爬虫文件批量运行的实现
2020/09/30 Python
解释下面关于J2EE的名词
2013/11/15 面试题
房地产管理毕业生自荐信
2013/11/04 职场文书
电子商务网站的创业计划书
2014/01/05 职场文书
中华美德颂演讲稿
2014/05/20 职场文书
工厂仓管员岗位职责
2015/04/01 职场文书
学历证明样本
2015/06/16 职场文书
幼儿园语言教学反思
2016/02/23 职场文书
 Python 中 logging 模块使用详情
2022/03/03 Python