Pytorch - TORCH.NN.INIT 参数初始化的操作


Posted in Python onFebruary 27, 2021

路径:

https://pytorch.org/docs/master/nn.init.html#nn-init-doc

初始化函数:torch.nn.init

# -*- coding: utf-8 -*-
"""
Created on 2019
@author: fancp
"""
import torch 
import torch.nn as nn
w = torch.empty(3,5)
#1.均匀分布 - u(a,b)
#torch.nn.init.uniform_(tensor, a=0.0, b=1.0)
print(nn.init.uniform_(w))
# =============================================================================
# tensor([[0.9160, 0.1832, 0.5278, 0.5480, 0.6754],
#     [0.9509, 0.8325, 0.9149, 0.8192, 0.9950],
#     [0.4847, 0.4148, 0.8161, 0.0948, 0.3787]])
# =============================================================================
#2.正态分布 - N(mean, std)
#torch.nn.init.normal_(tensor, mean=0.0, std=1.0)
print(nn.init.normal_(w))
# =============================================================================
# tensor([[ 0.4388, 0.3083, -0.6803, -1.1476, -0.6084],
#     [ 0.5148, -0.2876, -1.2222, 0.6990, -0.1595],
#     [-2.0834, -1.6288, 0.5057, -0.5754, 0.3052]])
# =============================================================================
#3.常数 - 固定值 val
#torch.nn.init.constant_(tensor, val)
print(nn.init.constant_(w, 0.3))
# =============================================================================
# tensor([[0.3000, 0.3000, 0.3000, 0.3000, 0.3000],
#     [0.3000, 0.3000, 0.3000, 0.3000, 0.3000],
#     [0.3000, 0.3000, 0.3000, 0.3000, 0.3000]])
# =============================================================================
#4.全1分布
#torch.nn.init.ones_(tensor)
print(nn.init.ones_(w))
# =============================================================================
# tensor([[1., 1., 1., 1., 1.],
#     [1., 1., 1., 1., 1.],
#     [1., 1., 1., 1., 1.]])
# =============================================================================
#5.全0分布
#torch.nn.init.zeros_(tensor)
print(nn.init.zeros_(w))
# =============================================================================
# tensor([[0., 0., 0., 0., 0.],
#     [0., 0., 0., 0., 0.],
#     [0., 0., 0., 0., 0.]])
# =============================================================================
#6.对角线为 1,其它为 0
#torch.nn.init.eye_(tensor)
print(nn.init.eye_(w))
# =============================================================================
# tensor([[1., 0., 0., 0., 0.],
#     [0., 1., 0., 0., 0.],
#     [0., 0., 1., 0., 0.]])
# =============================================================================
#7.xavier_uniform 初始化
#torch.nn.init.xavier_uniform_(tensor, gain=1.0)
#From - Understanding the difficulty of training deep feedforward neural networks - Bengio 2010
print(nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu')))
# =============================================================================
# tensor([[-0.1270, 0.3963, 0.9531, -0.2949, 0.8294],
#     [-0.9759, -0.6335, 0.9299, -1.0988, -0.1496],
#     [-0.7224, 0.2181, -1.1219, 0.8629, -0.8825]])
# =============================================================================
#8.xavier_normal 初始化
#torch.nn.init.xavier_normal_(tensor, gain=1.0)
print(nn.init.xavier_normal_(w))
# =============================================================================
# tensor([[ 1.0463, 0.1275, -0.3752, 0.1858, 1.1008],
#     [-0.5560, 0.2837, 0.1000, -0.5835, 0.7886],
#     [-0.2417, 0.1763, -0.7495, 0.4677, -0.1185]])
# =============================================================================
#9.kaiming_uniform 初始化
#torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
#From - Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification - HeKaiming 2015
print(nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu'))
# =============================================================================
# tensor([[-0.7712, 0.9344, 0.8304, 0.2367, 0.0478],
#     [-0.6139, -0.3916, -0.0835, 0.5975, 0.1717],
#     [ 0.3197, -0.9825, -0.5380, -1.0033, -0.3701]])
# =============================================================================
#10.kaiming_normal 初始化
#torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
print(nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu'))
# =============================================================================
# tensor([[-0.0210, 0.5532, -0.8647, 0.9813, 0.0466],
#     [ 0.7713, -1.0418, 0.7264, 0.5547, 0.7403],
#     [-0.8471, -1.7371, 1.3333, 0.0395, 1.0787]])
# =============================================================================
#11.正交矩阵 - (semi)orthogonal matrix
#torch.nn.init.orthogonal_(tensor, gain=1)
#From - Exact solutions to the nonlinear dynamics of learning in deep linear neural networks - Saxe 2013
print(nn.init.orthogonal_(w))
# =============================================================================
# tensor([[-0.0346, -0.7607, -0.0428, 0.4771, 0.4366],
#     [-0.0412, -0.0836, 0.9847, 0.0703, -0.1293],
#     [-0.6639, 0.4551, 0.0731, 0.1674, 0.5646]])
# =============================================================================
#12.稀疏矩阵 - sparse matrix 
#torch.nn.init.sparse_(tensor, sparsity, std=0.01)
#From - Deep learning via Hessian-free optimization - Martens 2010
print(nn.init.sparse_(w, sparsity=0.1))
# =============================================================================
# tensor([[ 0.0000, 0.0000, -0.0077, 0.0000, -0.0046],
#     [ 0.0152, 0.0030, 0.0000, -0.0029, 0.0005],
#     [ 0.0199, 0.0132, -0.0088, 0.0060, 0.0000]])
# =============================================================================

补充:【pytorch参数初始化】 pytorch默认参数初始化以及自定义参数初始化

本文用两个问题来引入

1.pytorch自定义网络结构不进行参数初始化会怎样,参数值是随机的吗?

2.如何自定义参数初始化?

先回答第一个问题

在pytorch中,有自己默认初始化参数方式,所以在你定义好网络结构以后,不进行参数初始化也是可以的。

1.Conv2d继承自_ConvNd,在_ConvNd中,可以看到默认参数就是进行初始化的,如下图所示

Pytorch - TORCH.NN.INIT 参数初始化的操作

Pytorch - TORCH.NN.INIT 参数初始化的操作

2.torch.nn.BatchNorm2d也一样有默认初始化的方式

Pytorch - TORCH.NN.INIT 参数初始化的操作

3.torch.nn.Linear也如此

Pytorch - TORCH.NN.INIT 参数初始化的操作

现在来回答第二个问题。

pytorch中对神经网络模型中的参数进行初始化方法如下:

from torch.nn import init
#define the initial function to init the layer's parameters for the network
def weigth_init(m):
  if isinstance(m, nn.Conv2d):
    init.xavier_uniform_(m.weight.data)
    init.constant_(m.bias.data,0.1)
  elif isinstance(m, nn.BatchNorm2d):
    m.weight.data.fill_(1)
    m.bias.data.zero_()
  elif isinstance(m, nn.Linear):
    m.weight.data.normal_(0,0.01)
    m.bias.data.zero_()

首先定义了一个初始化函数,接着进行调用就ok了,不过要先把网络模型实例化:

#Define Network
  model = Net(args.input_channel,args.output_channel)
  model.apply(weigth_init)

此上就完成了对模型中训练参数的初始化。

在知乎上也有看到一个类似的版本,也相应的贴上来作为参考了:

def initNetParams(net):
  '''Init net parameters.'''
  for m in net.modules():
    if isinstance(m, nn.Conv2d):
      init.xavier_uniform(m.weight)
      if m.bias:
        init.constant(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
      init.constant(m.weight, 1)
      init.constant(m.bias, 0)
    elif isinstance(m, nn.Linear):
      init.normal(m.weight, std=1e-3)
      if m.bias:
        init.constant(m.bias, 0) 
initNetParams(net)

再说一下关于模型的保存及加载

1.保存有两种方式,第一种是保存模型的整个结构信息和参数,第二种是只保存模型的参数

#保存整个网络模型及参数
 torch.save(net, 'net.pkl') 
 
 #仅保存模型参数
 torch.save(net.state_dict(), 'net_params.pkl')

2.加载对应保存的两种网络

# 保存和加载整个模型 
torch.save(model_object, 'model.pth') 
model = torch.load('model.pth') 
 
# 仅保存和加载模型参数 
torch.save(model_object.state_dict(), 'params.pth') 
model_object.load_state_dict(torch.load('params.pth'))

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
Python实现的Kmeans++算法实例
Apr 26 Python
python+mysql实现简单的web程序
Sep 11 Python
在Python中使用swapCase()方法转换大小写的教程
May 20 Python
python3中zip()函数使用详解
Jun 29 Python
对python pandas 画移动平均线的方法详解
Nov 28 Python
Python中按值来获取指定的键
Mar 04 Python
django搭建项目配置环境和创建表过程详解
Jul 22 Python
python单线程下实现多个socket并发过程详解
Jul 27 Python
tensorflow的ckpt及pb模型持久化方式及转化详解
Feb 12 Python
浅谈在django中使用filter()(即对QuerySet操作)时踩的坑
Mar 31 Python
python高级特性简介
Aug 13 Python
如何使用scrapy中的ItemLoader提取数据
Sep 30 Python
python FTP编程基础入门
Feb 27 #Python
python SOCKET编程基础入门
Feb 27 #Python
python 对xml解析的示例
Feb 27 #Python
python如何发送带有附件、正文为HTML的邮件
Feb 27 #Python
pytorch __init__、forward与__call__的用法小结
Feb 27 #Python
python 实现有道翻译功能
Feb 26 #Python
Python爬取酷狗MP3音频的步骤
Feb 26 #Python
You might like
教你如何把一篇文章按要求分段
2006/10/09 PHP
PHP实现定时生成HTML网站首页实例代码
2008/11/20 PHP
php 连接mysql连接被重置的解决方法
2011/02/15 PHP
PHP随机生成唯一HASH值自定义函数
2015/04/20 PHP
PHP中filter函数校验数据的方法详解
2015/07/31 PHP
PHP Pipeline 实现中间件的示例代码
2020/04/26 PHP
通过javascript设置css属性的代码
2009/12/28 Javascript
require简单实现单页应用程序(SPA)
2016/07/12 Javascript
第一次接触神奇的Bootstrap菜单和导航
2016/08/01 Javascript
学习React中ref的两个demo示例
2018/08/14 Javascript
react在安卓中输入框被手机键盘遮挡问题的解决方法
2018/09/03 Javascript
vue-router的使用方法及含参数的配置方法
2018/11/13 Javascript
vue项目设置scrollTop不起作用(总结)
2018/12/21 Javascript
js console.log打印对象时属性缺失的解决方法
2019/05/23 Javascript
ES6基础之 Promise 对象用法实例详解
2019/08/22 Javascript
[57:22]完美世界DOTA2联赛PWL S2 FTD vs PXG 第二场 11.27
2020/12/01 DOTA
Python实现根据指定端口探测服务器/模块部署的方法
2014/08/25 Python
浅谈python 线程池threadpool之实现
2017/11/17 Python
django上传图片并生成缩略图方法示例
2017/12/11 Python
详解Python3的TFTP文件传输
2018/06/26 Python
python selenium firefox使用详解
2019/02/26 Python
python实现PID算法及测试的例子
2019/08/08 Python
使用python将最新的测试报告以附件的形式发到指定邮箱
2019/09/20 Python
浅析python内置模块collections
2019/11/15 Python
Pytorch 的损失函数Loss function使用详解
2020/01/02 Python
小 200 行 Python 代码制作一个换脸程序
2020/05/12 Python
pytorch 把图片数据转化成tensor的操作
2021/03/04 Python
美国受信赖的教育产品供应商:Nest Learning
2018/06/14 全球购物
俄罗斯连接商品和买家的在线平台:goods.ru
2020/11/30 全球购物
WINDOWS域的具体实现方式是什么
2014/02/20 面试题
奥巴马演讲稿
2014/01/08 职场文书
家长给幼儿园的表扬信
2014/01/09 职场文书
党员活动日总结
2014/05/05 职场文书
个人反四风对照检查材料思想汇报
2014/09/23 职场文书
2016大学先进团支部事迹材料
2016/03/01 职场文书
go语言使用Casbin实现角色的权限控制
2021/06/26 Golang