Pytorch 实现权重初始化


Posted in Python onDecember 31, 2019

在TensorFlow中,权重的初始化主要是在声明张量的时候进行的。 而PyTorch则提供了另一种方法:首先应该声明张量,然后修改张量的权重。通过调用torch.nn.init包中的多种方法可以将权重初始化为直接访问张量的属性。

1、不初始化的效果

在Pytorch中,定义一个tensor,不进行初始化,打印看看结果:

w = torch.Tensor(3,4)
print (w)

可以看到这时候的初始化的数值都是随机的,而且特别大,这对网络的训练必定不好,最后导致精度提不上,甚至损失无法收敛。

2、初始化的效果

PyTorch提供了多种参数初始化函数:

torch.nn.init.constant(tensor, val)
torch.nn.init.normal(tensor, mean=0, std=1)
torch.nn.init.xavier_uniform(tensor, gain=1)

等等。详细请参考:http://pytorch.org/docs/nn.html#torch-nn-init

注意上面的初始化函数的参数tensor,虽然写的是tensor,但是也可以是Variable类型的。而神经网络的参数类型Parameter是Variable类的子类,所以初始化函数可以直接作用于神经网络参数。实际上,我们初始化也是直接去初始化神经网络的参数。

让我们试试效果:

w = torch.Tensor(3,4)
torch.nn.init.normal_(w)
print (w)

3、初始化神经网络的参数

对神经网络的初始化往往放在模型的__init__()函数中,如下所示:

class Net(nn.Module):

def __init__(self, block, layers, num_classes=1000):
  self.inplanes = 64
  super(Net, self).__init__()
  ***
  *** #定义自己的网络层
  ***

  for m in self.modules():
    if isinstance(m, nn.Conv2d):
      n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
      m.weight.data.normal_(0, math.sqrt(2. / n))
    elif isinstance(m, nn.BatchNorm2d):
      m.weight.data.fill_(1)
      m.bias.data.zero_()

***
*** #定义后续的函数
***

也可以采取另一种方式:

定义一个权重初始化函数,如下:

def weights_init(m):
  classname = m.__class__.__name__
  if classname.find('Conv2d') != -1:
    init.xavier_normal_(m.weight.data)
    init.constant_(m.bias.data, 0.0)
  elif classname.find('Linear') != -1:
    init.xavier_normal_(m.weight.data)
    init.constant_(m.bias.data, 0.0)

在模型声明时,调用初始化函数,初始化神经网络参数:

model = Net(*****)
model.apply(weights_init)

以上这篇Pytorch 实现权重初始化就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用MD5加密字符串示例
Aug 22 Python
Python工程师面试题 与Python Web相关
Jan 14 Python
python使用SMTP发送qq或sina邮件
Oct 21 Python
Python及PyCharm下载与安装教程
Nov 18 Python
利用selenium 3.7和python3添加cookie模拟登陆的实现
Nov 20 Python
Python使用numpy模块创建数组操作示例
Jun 20 Python
分享vim python缩进等一些配置
Jul 02 Python
利用PyQt中的QThread类实现多线程
Feb 18 Python
python解析xml文件方式(解析、更新、写入)
Mar 05 Python
基于Python pyecharts实现多种图例代码解析
Aug 10 Python
Python urlopen()参数代码示例解析
Dec 10 Python
python 实现网易邮箱邮件阅读和删除的辅助小脚本
Mar 01 Python
pytorch 归一化与反归一化实例
Dec 31 #Python
Pytorch 数据加载与数据预处理方式
Dec 31 #Python
pytorch 数据处理:定义自己的数据集合实例
Dec 31 #Python
pytorch: Parameter 的数据结构实例
Dec 31 #Python
Python测试线程应用程序过程解析
Dec 31 #Python
Python TCPServer 多线程多客户端通信的实现
Dec 31 #Python
python给指定csv表格中的联系人群发邮件(带附件的邮件)
Dec 31 #Python
You might like
php的zip解压缩类pclzip使用示例
2014/03/14 PHP
JavaScript 定义function的三种方式小结
2009/10/16 Javascript
获取数组中最大最小值方法js代码(自写)
2013/08/12 Javascript
JS获取月的最后一天与JS得到一个月份最大天数的实例代码
2013/12/16 Javascript
浅析jquery的js图表组件highcharts
2014/03/06 Javascript
再谈Jquery Ajax方法传递到action(补充)
2014/05/12 Javascript
nodejs调用cmd命令实现复制目录
2015/05/04 NodeJs
javascript实现数字倒计时特效
2016/03/30 Javascript
jQuery Validate格式验证功能实例代码(包括重名验证)
2017/07/18 jQuery
详解node.js的http模块实例演示
2018/07/12 Javascript
node 解析图片二维码的内容代码实例
2019/09/11 Javascript
微信头像地址失效踩坑记附带解决方案
2019/09/23 Javascript
vue 防止页面加载时看到花括号的解决操作
2020/11/09 Javascript
js实现鼠标切换图片(无定时器)
2021/01/27 Javascript
Python中的yield浅析
2014/06/16 Python
Python实现数通设备端口使用情况监控实例
2015/07/15 Python
Python如何import文件夹下的文件(实现方法)
2017/01/24 Python
python实现下载文件的三种方法
2017/02/09 Python
python3中int(整型)的使用教程
2017/03/23 Python
通过Python 获取Android设备信息的轻量级框架
2017/12/18 Python
python中join()方法介绍
2018/10/11 Python
python配置grpc环境
2019/01/01 Python
python 输出所有大小写字母的方法
2019/01/02 Python
flask应用部署到服务器的方法
2019/07/12 Python
PurCotton全棉时代官网:100%天然棉花生产的生活护理用品
2016/11/18 全球购物
美国电子产品主要品牌的授权在线零售商:DataVision
2019/03/23 全球购物
思想政治自我鉴定
2013/10/06 职场文书
面试后的英文感谢信
2014/02/01 职场文书
冰淇淋开店创业计划书
2014/02/01 职场文书
仓库规划计划书
2014/04/28 职场文书
社区科普工作方案
2014/06/03 职场文书
社团活动总结模板
2014/06/30 职场文书
分居协议书范本
2014/11/03 职场文书
承诺函范文
2015/01/21 职场文书
Go语言使用select{}阻塞main函数介绍
2021/04/25 Golang
python实现简单聊天功能
2021/07/07 Python