pytorch  网络参数 weight bias 初始化详解


Posted in Python onJune 24, 2020

权重初始化对于训练神经网络至关重要,好的初始化权重可以有效的避免梯度消失等问题的发生。

在pytorch的使用过程中有几种权重初始化的方法供大家参考。

注意:第一种方法不推荐。尽量使用后两种方法。

# not recommend
def weights_init(m):
 classname = m.__class__.__name__
 if classname.find('Conv') != -1:
  m.weight.data.normal_(0.0, 0.02)
 elif classname.find('BatchNorm') != -1:
  m.weight.data.normal_(1.0, 0.02)
  m.bias.data.fill_(0)
# recommend
def initialize_weights(m):
 if isinstance(m, nn.Conv2d):
  m.weight.data.normal_(0, 0.02)
  m.bias.data.zero_()
 elif isinstance(m, nn.Linear):
  m.weight.data.normal_(0, 0.02)
  m.bias.data.zero_()
# recommend
def weights_init(m): 
 if isinstance(m, nn.Conv2d): 
  nn.init.xavier_normal_(m.weight.data) 
  nn.init.xavier_normal_(m.bias.data)
 elif isinstance(m, nn.BatchNorm2d):
  nn.init.constant_(m.weight,1)
  nn.init.constant_(m.bias, 0)
 elif isinstance(m, nn.BatchNorm1d):
  nn.init.constant_(m.weight,1)
  nn.init.constant_(m.bias, 0)

编写好weights_init函数后,可以使用模型的apply方法对模型进行权重初始化。

net = Residual() # generate an instance network from the Net class

net.apply(weights_init) # apply weight init

补充知识:Pytorch权值初始化及参数分组

1. 模型参数初始化

# ————————————————— 利用model.apply(weights_init)实现初始化
def weights_init(m):
  classname = m.__class__.__name__
  if classname.find('Conv') != -1:
    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
    m.weight.data.normal_(0, math.sqrt(2. / n))
    if m.bias is not None:
      m.bias.data.zero_()
  elif classname.find('BatchNorm') != -1:
    m.weight.data.fill_(1)
    m.bias.data.zero_()
  elif classname.find('Linear') != -1:
    n = m.weight.size(1)
    m.weight.data.normal_(0, 0.01)
    m.bias.data = torch.ones(m.bias.data.size())
    
# ————————————————— 直接放在__init__构造函数中实现初始化
for m in self.modules():
  if isinstance(m, nn.Conv2d):
    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
    m.weight.data.normal_(0, math.sqrt(2. / n))
    if m.bias is not None:
      m.bias.data.zero_()
  elif isinstance(m, nn.BatchNorm2d):
    m.weight.data.fill_(1)
    m.bias.data.zero_()
  elif isinstance(m, nn.BatchNorm1d):
    m.weight.data.fill_(1)
    m.bias.data.zero_()
  elif isinstance(m, nn.Linear):
    nn.init.xavier_uniform_(m.weight.data)
    if m.bias is not None:
      m.bias.data.zero_()
    
# —————————————————
self.weight = Parameter(torch.Tensor(out_features, in_features))
self.bias = Parameter(torch.FloatTensor(out_features))
nn.init.xavier_uniform_(self.weight)
nn.init.zero_(self.bias)
nn.init.constant_(m, initm)
# nn.init.kaiming_uniform_()
# self.weight.data.normal_(std=0.001)

2. 模型参数分组weight_decay

def separate_bn_prelu_params(model, ignored_params=[]):
  bn_prelu_params = []
  for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
      ignored_params += list(map(id, m.parameters()))  
      bn_prelu_params += m.parameters()
    if isinstance(m, nn.BatchNorm1d):
      ignored_params += list(map(id, m.parameters()))  
      bn_prelu_params += m.parameters()
    elif isinstance(m, nn.PReLU):
      ignored_params += list(map(id, m.parameters()))
      bn_prelu_params += m.parameters()
  base_params = list(filter(lambda p: id(p) not in ignored_params, model.parameters()))

  return base_params, bn_prelu_params, ignored_params

OPTIMIZER = optim.SGD([
    {'params': base_params, 'weight_decay': WEIGHT_DECAY},     
    {'params': fc_head_param, 'weight_decay': WEIGHT_DECAY * 10},
    {'params': bn_prelu_params, 'weight_decay': 0.0}
    ], lr=LR, momentum=MOMENTUM ) # , nesterov=True

Note 1:PReLU(x) = max(0,x) + a * min(0,x). Here a is a learnable parameter. When called without arguments, nn.PReLU() uses a single parameter a across all input channels. If called with nn.PReLU(nChannels), a separate a is used for each input channel.

Note 2: weight decay should not be used when learning a for good performance.

Note 3: The default number of a to learn is 1, the default initial value of a is 0.25.

3. 参数分组weight_decay?其他

第2节中的内容可以满足一般的参数分组需求,此部分可以满足更个性化的分组需求。参考:face_evoLVe_Pytorch-master

自定义schedule

def schedule_lr(optimizer):
  for params in optimizer.param_groups:
    params['lr'] /= 10.
  print(optimizer)

方法一:利用model.modules()和obj.__class__ (更普适)

# model.modules()和model.children()的区别:model.modules()会迭代地遍历模型的所有子层,而model.children()只会遍历模型下的一层
# 下面的关键词if 'model',源于模型定义文件。如model_resnet.py中自定义的所有nn.Module子类,都会前缀'model_resnet',所以可通过这种方式一次性筛选出自定义的模块
def separate_irse_bn_paras(model):
  paras_only_bn = []         
  paras_no_bn = []
  for layer in model.modules():
    if 'model' in str(layer.__class__):		      # eg. a=[1,2] type(a): <class 'list'> a.__class__: <class 'list'>
      continue
    if 'container' in str(layer.__class__):       # 去掉Sequential型的模块
      continue
    else:
      if 'batchnorm' in str(layer.__class__):
        paras_only_bn.extend([*layer.parameters()])
      else:
        paras_no_bn.extend([*layer.parameters()])  # extend()用于在列表末尾一次性追加另一个序列中的多个值(用新列表扩展原来的列表)

  return paras_only_bn, paras_no_bn

方法二:调用modules.parameters和named_parameters()

但是本质上,parameters()是根据named_parameters()获取,named_parameters()是根据modules()获取。使用此方法的前提是,须按下文1,2中的方式定义模型,或者利用Sequential+OrderedDict定义模型。

def separate_resnet_bn_paras(model):
  all_parameters = model.parameters()
  paras_only_bn = []

  for pname, p in model.named_parameters():
    if pname.find('bn') >= 0:
      paras_only_bn.append(p)
      
  paras_only_bn_id = list(map(id, paras_only_bn))
  paras_no_bn = list(filter(lambda p: id(p) not in paras_only_bn_id, all_parameters))
  
  return paras_only_bn, paras_no_bn

两种方法的区别

参数分组的区别,其实对应了模型构造时的区别。举例:

1、构造ResNet的basic block,在__init__()函数中定义了

self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = BatchNorm2d(planes)
self.relu = ReLU(inplace = True)
…

2、在forward()中定义

out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
…

3、对ResNet取model.name_parameters()返回的pname形如:

‘layer1.0.conv1.weight'
‘layer1.0.bn1.weight'
‘layer1.0.bn1.bias'
# layer对应conv2_x, …, conv5_x; '0'对应各layer中的block索引,比如conv2_x有3个block,对应索引为layer1.0, …, layer1.2; 'conv1'就是__init__()中定义的self.conv1

4、若构造model时采用了Sequential(),则model.name_parameters()返回的pname形如:

‘body.3.res_layer.1.weight',此处的1.weight实际对应了BN的weight,无法通过pname.find(‘bn')找到该模块。

self.res_layer = Sequential(
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
BatchNorm2d(depth),
ReLU(depth),
Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
BatchNorm2d(depth)
)

5、针对4中的情况,两种解决办法:利用OrderedDict修饰Sequential,或利用方法一

downsample = Sequential( OrderedDict([
(‘conv_ds', conv1x1(self.inplanes, planes * block.expansion, stride)),
(‘bn_ds', BatchNorm2d(planes * block.expansion)),
]))
# 如此,相应模块的pname将会带有'conv_ds',‘bn_ds'字样

以上这篇pytorch 网络参数 weight bias 初始化详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python线程池的实现实例
Nov 18 Python
Python生成随机数的方法
Jan 14 Python
Python实现在线程里运行scrapy的方法
Apr 07 Python
python编写简单爬虫资料汇总
Mar 22 Python
Python编程实现的图片识别功能示例
Aug 03 Python
python实现最长公共子序列
May 22 Python
Python过滤txt文件内重复内容的方法
Oct 21 Python
用python3教你任意Html主内容提取功能
Nov 05 Python
Python3中exp()函数用法分析
Feb 19 Python
python之PyQt按钮右键菜单功能的实现代码
Aug 17 Python
使用python matploblib库绘制准确率,损失率折线图
Jun 16 Python
Python接收手机短信的代码整理
Aug 02 Python
可视化pytorch 模型中不同BN层的running mean曲线实例
Jun 24 #Python
python3.x中安装web.py步骤方法
Jun 23 #Python
python如何删除文件、目录
Jun 23 #Python
TensorFlow保存TensorBoard图像操作
Jun 23 #Python
python和js交互调用的方法
Jun 23 #Python
virtualenv介绍及简明教程
Jun 23 #Python
python不同系统中打开方法
Jun 23 #Python
You might like
一个简单的自动发送邮件系统(一)
2006/10/09 PHP
实用函数9
2007/11/08 PHP
PHP管理依赖(dependency)关系工具 Composer 安装与使用
2014/08/18 PHP
php批量删除数据库下指定前缀的表以prefix_为例
2014/08/24 PHP
php使用PDO获取结果集的方法
2017/02/16 PHP
PHP实现添加购物车功能
2017/03/06 PHP
javascript中onmouse事件在div中失效问题的解决方法
2012/01/09 Javascript
javascript工厂方式定义对象
2014/12/26 Javascript
JavaScript编写点击查看大图的页面半透明遮罩层效果实例
2016/05/09 Javascript
用原生JS实现简单的多选框功能
2017/06/12 Javascript
JavaScript中EventLoop介绍
2018/01/22 Javascript
vue-cli V3.0版本的使用详解
2018/10/24 Javascript
angular中两种表单的区别(响应式和模板驱动表单)
2018/12/06 Javascript
详解vuex持久化插件解决浏览器刷新数据消失问题
2019/04/15 Javascript
javascript自定义日期比较函数用法示例
2019/07/22 Javascript
JS中如何轻松遍历对象属性的方式总结
2019/08/06 Javascript
详解Vue 项目中的几个实用组件(ts)
2019/10/29 Javascript
JS实现横向跑马灯效果代码
2020/04/20 Javascript
vue 实现setInterval 创建和销毁实例
2020/07/21 Javascript
Vue Element校验validate的实例
2020/09/21 Javascript
vant 中van-list的用法说明
2020/11/11 Javascript
[05:53]完美世界携手游戏风云打造 卡尔工作室观战系统篇
2013/04/22 DOTA
[03:48]显微镜下的DOTA2第四期——TP动作
2014/06/20 DOTA
Python实现 版本号对比功能的实例代码
2019/04/18 Python
python装饰器相当于函数的调用方式
2019/12/27 Python
python yield和Generator函数用法详解
2020/02/10 Python
python使用openpyxl操作excel的方法步骤
2020/05/28 Python
音频处理 windows10下python三方库librosa安装教程
2020/06/20 Python
Tenstickers法国:墙贴和装饰贴纸
2019/08/26 全球购物
一套C++笔试题面试题
2012/06/06 面试题
简述Linux文件系统通过i节点把文件的逻辑结构和物理结构转换的工作过程
2012/04/17 面试题
初一家长会邀请函
2014/01/31 职场文书
音乐器材管理制度
2014/01/31 职场文书
学校个人对照检查材料
2014/08/26 职场文书
出纳岗位职责范本
2015/03/31 职场文书
《火纹风花雪月无双》预告“神秘雇佣兵” 紫发剑客
2022/04/13 其他游戏