Pytorch - TORCH.NN.INIT 参数初始化的操作


Posted in Python onFebruary 27, 2021

路径:

https://pytorch.org/docs/master/nn.init.html#nn-init-doc

初始化函数:torch.nn.init

# -*- coding: utf-8 -*-
"""
Created on 2019
@author: fancp
"""
import torch 
import torch.nn as nn
w = torch.empty(3,5)
#1.均匀分布 - u(a,b)
#torch.nn.init.uniform_(tensor, a=0.0, b=1.0)
print(nn.init.uniform_(w))
# =============================================================================
# tensor([[0.9160, 0.1832, 0.5278, 0.5480, 0.6754],
#     [0.9509, 0.8325, 0.9149, 0.8192, 0.9950],
#     [0.4847, 0.4148, 0.8161, 0.0948, 0.3787]])
# =============================================================================
#2.正态分布 - N(mean, std)
#torch.nn.init.normal_(tensor, mean=0.0, std=1.0)
print(nn.init.normal_(w))
# =============================================================================
# tensor([[ 0.4388, 0.3083, -0.6803, -1.1476, -0.6084],
#     [ 0.5148, -0.2876, -1.2222, 0.6990, -0.1595],
#     [-2.0834, -1.6288, 0.5057, -0.5754, 0.3052]])
# =============================================================================
#3.常数 - 固定值 val
#torch.nn.init.constant_(tensor, val)
print(nn.init.constant_(w, 0.3))
# =============================================================================
# tensor([[0.3000, 0.3000, 0.3000, 0.3000, 0.3000],
#     [0.3000, 0.3000, 0.3000, 0.3000, 0.3000],
#     [0.3000, 0.3000, 0.3000, 0.3000, 0.3000]])
# =============================================================================
#4.全1分布
#torch.nn.init.ones_(tensor)
print(nn.init.ones_(w))
# =============================================================================
# tensor([[1., 1., 1., 1., 1.],
#     [1., 1., 1., 1., 1.],
#     [1., 1., 1., 1., 1.]])
# =============================================================================
#5.全0分布
#torch.nn.init.zeros_(tensor)
print(nn.init.zeros_(w))
# =============================================================================
# tensor([[0., 0., 0., 0., 0.],
#     [0., 0., 0., 0., 0.],
#     [0., 0., 0., 0., 0.]])
# =============================================================================
#6.对角线为 1,其它为 0
#torch.nn.init.eye_(tensor)
print(nn.init.eye_(w))
# =============================================================================
# tensor([[1., 0., 0., 0., 0.],
#     [0., 1., 0., 0., 0.],
#     [0., 0., 1., 0., 0.]])
# =============================================================================
#7.xavier_uniform 初始化
#torch.nn.init.xavier_uniform_(tensor, gain=1.0)
#From - Understanding the difficulty of training deep feedforward neural networks - Bengio 2010
print(nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu')))
# =============================================================================
# tensor([[-0.1270, 0.3963, 0.9531, -0.2949, 0.8294],
#     [-0.9759, -0.6335, 0.9299, -1.0988, -0.1496],
#     [-0.7224, 0.2181, -1.1219, 0.8629, -0.8825]])
# =============================================================================
#8.xavier_normal 初始化
#torch.nn.init.xavier_normal_(tensor, gain=1.0)
print(nn.init.xavier_normal_(w))
# =============================================================================
# tensor([[ 1.0463, 0.1275, -0.3752, 0.1858, 1.1008],
#     [-0.5560, 0.2837, 0.1000, -0.5835, 0.7886],
#     [-0.2417, 0.1763, -0.7495, 0.4677, -0.1185]])
# =============================================================================
#9.kaiming_uniform 初始化
#torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
#From - Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification - HeKaiming 2015
print(nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu'))
# =============================================================================
# tensor([[-0.7712, 0.9344, 0.8304, 0.2367, 0.0478],
#     [-0.6139, -0.3916, -0.0835, 0.5975, 0.1717],
#     [ 0.3197, -0.9825, -0.5380, -1.0033, -0.3701]])
# =============================================================================
#10.kaiming_normal 初始化
#torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
print(nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu'))
# =============================================================================
# tensor([[-0.0210, 0.5532, -0.8647, 0.9813, 0.0466],
#     [ 0.7713, -1.0418, 0.7264, 0.5547, 0.7403],
#     [-0.8471, -1.7371, 1.3333, 0.0395, 1.0787]])
# =============================================================================
#11.正交矩阵 - (semi)orthogonal matrix
#torch.nn.init.orthogonal_(tensor, gain=1)
#From - Exact solutions to the nonlinear dynamics of learning in deep linear neural networks - Saxe 2013
print(nn.init.orthogonal_(w))
# =============================================================================
# tensor([[-0.0346, -0.7607, -0.0428, 0.4771, 0.4366],
#     [-0.0412, -0.0836, 0.9847, 0.0703, -0.1293],
#     [-0.6639, 0.4551, 0.0731, 0.1674, 0.5646]])
# =============================================================================
#12.稀疏矩阵 - sparse matrix 
#torch.nn.init.sparse_(tensor, sparsity, std=0.01)
#From - Deep learning via Hessian-free optimization - Martens 2010
print(nn.init.sparse_(w, sparsity=0.1))
# =============================================================================
# tensor([[ 0.0000, 0.0000, -0.0077, 0.0000, -0.0046],
#     [ 0.0152, 0.0030, 0.0000, -0.0029, 0.0005],
#     [ 0.0199, 0.0132, -0.0088, 0.0060, 0.0000]])
# =============================================================================

补充:【pytorch参数初始化】 pytorch默认参数初始化以及自定义参数初始化

本文用两个问题来引入

1.pytorch自定义网络结构不进行参数初始化会怎样,参数值是随机的吗?

2.如何自定义参数初始化?

先回答第一个问题

在pytorch中,有自己默认初始化参数方式,所以在你定义好网络结构以后,不进行参数初始化也是可以的。

1.Conv2d继承自_ConvNd,在_ConvNd中,可以看到默认参数就是进行初始化的,如下图所示

Pytorch - TORCH.NN.INIT 参数初始化的操作

Pytorch - TORCH.NN.INIT 参数初始化的操作

2.torch.nn.BatchNorm2d也一样有默认初始化的方式

Pytorch - TORCH.NN.INIT 参数初始化的操作

3.torch.nn.Linear也如此

Pytorch - TORCH.NN.INIT 参数初始化的操作

现在来回答第二个问题。

pytorch中对神经网络模型中的参数进行初始化方法如下:

from torch.nn import init
#define the initial function to init the layer's parameters for the network
def weigth_init(m):
  if isinstance(m, nn.Conv2d):
    init.xavier_uniform_(m.weight.data)
    init.constant_(m.bias.data,0.1)
  elif isinstance(m, nn.BatchNorm2d):
    m.weight.data.fill_(1)
    m.bias.data.zero_()
  elif isinstance(m, nn.Linear):
    m.weight.data.normal_(0,0.01)
    m.bias.data.zero_()

首先定义了一个初始化函数,接着进行调用就ok了,不过要先把网络模型实例化:

#Define Network
  model = Net(args.input_channel,args.output_channel)
  model.apply(weigth_init)

此上就完成了对模型中训练参数的初始化。

在知乎上也有看到一个类似的版本,也相应的贴上来作为参考了:

def initNetParams(net):
  '''Init net parameters.'''
  for m in net.modules():
    if isinstance(m, nn.Conv2d):
      init.xavier_uniform(m.weight)
      if m.bias:
        init.constant(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
      init.constant(m.weight, 1)
      init.constant(m.bias, 0)
    elif isinstance(m, nn.Linear):
      init.normal(m.weight, std=1e-3)
      if m.bias:
        init.constant(m.bias, 0) 
initNetParams(net)

再说一下关于模型的保存及加载

1.保存有两种方式,第一种是保存模型的整个结构信息和参数,第二种是只保存模型的参数

#保存整个网络模型及参数
 torch.save(net, 'net.pkl') 
 
 #仅保存模型参数
 torch.save(net.state_dict(), 'net_params.pkl')

2.加载对应保存的两种网络

# 保存和加载整个模型 
torch.save(model_object, 'model.pth') 
model = torch.load('model.pth') 
 
# 仅保存和加载模型参数 
torch.save(model_object.state_dict(), 'params.pth') 
model_object.load_state_dict(torch.load('params.pth'))

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
python实现矩阵乘法的方法
Jun 28 Python
Python自定义主从分布式架构实例分析
Sep 19 Python
解决Linux系统中python matplotlib画图的中文显示问题
Jun 15 Python
python3.0 模拟用户登录,三次错误锁定的实例
Nov 02 Python
Windows下python3.7安装教程
Jul 31 Python
Python实现微信翻译机器人的方法
Aug 13 Python
python实现录屏功能(亲测好用)
Mar 02 Python
Matplotlib使用Cursor实现UI定位的示例代码
Mar 12 Python
pandas使用之宽表变窄表的实现
Apr 12 Python
简单了解python shutil模块原理及使用方法
Apr 28 Python
Python+Xlwings 删除Excel的行和列
Dec 19 Python
python文件名批量重命名脚本实例代码
Apr 22 Python
python FTP编程基础入门
Feb 27 #Python
python SOCKET编程基础入门
Feb 27 #Python
python 对xml解析的示例
Feb 27 #Python
python如何发送带有附件、正文为HTML的邮件
Feb 27 #Python
pytorch __init__、forward与__call__的用法小结
Feb 27 #Python
python 实现有道翻译功能
Feb 26 #Python
Python爬取酷狗MP3音频的步骤
Feb 26 #Python
You might like
使用JSON实现数据的跨域传输的php代码
2011/12/20 PHP
php resizeimage 部分jpg文件 生成缩略图失败的原因分析及解决办法
2016/03/23 PHP
PHP面向对象中new self()与 new static()的区别浅析
2017/08/17 PHP
解决laravel资源加载路径设置的问题
2019/10/14 PHP
node.js应用后台守护进程管理器Forever安装和使用实例
2014/06/01 Javascript
判断数组是否包含某个元素的js函数实现方法
2016/05/19 Javascript
Javascript中apply、call、bind的巧妙使用
2016/08/18 Javascript
利用vue-router实现二级菜单内容转换
2016/11/30 Javascript
React 子组件向父组件传值的方法
2017/07/24 Javascript
Vue 将后台传过来的带html字段的字符串转换为 HTML
2018/03/29 Javascript
nodejs express配置自签名https服务器的方法
2018/05/22 NodeJs
使用Javascript简单计算器
2018/11/17 Javascript
用element的upload组件实现多图片上传和压缩的示例代码
2019/02/12 Javascript
详解关于Vuex的action传入多个参数的问题
2019/02/22 Javascript
vue实现鼠标移入移出事件代码实例
2019/03/27 Javascript
jQuery 隐藏/显示效果函数用法实例分析
2020/05/20 jQuery
详解js中的几种常用设计模式
2020/07/16 Javascript
vue实现虚拟列表功能的代码
2020/07/28 Javascript
基于element-ui对话框el-dialog初始化的校验问题解决
2020/09/11 Javascript
在vue项目中promise解决回调地狱和并发请求的问题
2020/11/09 Javascript
python实现按行切分文本文件的方法
2016/04/18 Python
Python装饰器用法示例小结
2018/02/11 Python
wxpython实现图书管理系统
2018/03/12 Python
matplotlib实现热成像图colorbar和极坐标图的方法
2018/12/13 Python
对python中基于tcp协议的通信(数据传输)实例讲解
2019/07/22 Python
关于pytorch中网络loss传播和参数更新的理解
2019/08/20 Python
pytorch 实现查看网络中的参数
2020/01/06 Python
Python Pillow(PIL)库的用法详解
2020/09/19 Python
python 模拟登陆163邮箱
2020/12/15 Python
简述Html5 IphoneX 适配方法
2018/02/08 HTML / CSS
类和结构的区别
2012/08/15 面试题
Linux内核的同步机制是什么?主要有哪几种内核锁
2016/07/11 面试题
犯错检讨书
2014/02/21 职场文书
乡镇组织委员个人整改措施
2014/09/16 职场文书
工作态度恶劣检讨书
2015/05/06 职场文书
python中 .npy文件的读写操作实例
2022/04/14 Python