pytorch  网络参数 weight bias 初始化详解


Posted in Python onJune 24, 2020

权重初始化对于训练神经网络至关重要,好的初始化权重可以有效的避免梯度消失等问题的发生。

在pytorch的使用过程中有几种权重初始化的方法供大家参考。

注意:第一种方法不推荐。尽量使用后两种方法。

# not recommend
def weights_init(m):
 classname = m.__class__.__name__
 if classname.find('Conv') != -1:
  m.weight.data.normal_(0.0, 0.02)
 elif classname.find('BatchNorm') != -1:
  m.weight.data.normal_(1.0, 0.02)
  m.bias.data.fill_(0)
# recommend
def initialize_weights(m):
 if isinstance(m, nn.Conv2d):
  m.weight.data.normal_(0, 0.02)
  m.bias.data.zero_()
 elif isinstance(m, nn.Linear):
  m.weight.data.normal_(0, 0.02)
  m.bias.data.zero_()
# recommend
def weights_init(m): 
 if isinstance(m, nn.Conv2d): 
  nn.init.xavier_normal_(m.weight.data) 
  nn.init.xavier_normal_(m.bias.data)
 elif isinstance(m, nn.BatchNorm2d):
  nn.init.constant_(m.weight,1)
  nn.init.constant_(m.bias, 0)
 elif isinstance(m, nn.BatchNorm1d):
  nn.init.constant_(m.weight,1)
  nn.init.constant_(m.bias, 0)

编写好weights_init函数后,可以使用模型的apply方法对模型进行权重初始化。

net = Residual() # generate an instance network from the Net class

net.apply(weights_init) # apply weight init

补充知识:Pytorch权值初始化及参数分组

1. 模型参数初始化

# ————————————————— 利用model.apply(weights_init)实现初始化
def weights_init(m):
  classname = m.__class__.__name__
  if classname.find('Conv') != -1:
    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
    m.weight.data.normal_(0, math.sqrt(2. / n))
    if m.bias is not None:
      m.bias.data.zero_()
  elif classname.find('BatchNorm') != -1:
    m.weight.data.fill_(1)
    m.bias.data.zero_()
  elif classname.find('Linear') != -1:
    n = m.weight.size(1)
    m.weight.data.normal_(0, 0.01)
    m.bias.data = torch.ones(m.bias.data.size())
    
# ————————————————— 直接放在__init__构造函数中实现初始化
for m in self.modules():
  if isinstance(m, nn.Conv2d):
    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
    m.weight.data.normal_(0, math.sqrt(2. / n))
    if m.bias is not None:
      m.bias.data.zero_()
  elif isinstance(m, nn.BatchNorm2d):
    m.weight.data.fill_(1)
    m.bias.data.zero_()
  elif isinstance(m, nn.BatchNorm1d):
    m.weight.data.fill_(1)
    m.bias.data.zero_()
  elif isinstance(m, nn.Linear):
    nn.init.xavier_uniform_(m.weight.data)
    if m.bias is not None:
      m.bias.data.zero_()
    
# —————————————————
self.weight = Parameter(torch.Tensor(out_features, in_features))
self.bias = Parameter(torch.FloatTensor(out_features))
nn.init.xavier_uniform_(self.weight)
nn.init.zero_(self.bias)
nn.init.constant_(m, initm)
# nn.init.kaiming_uniform_()
# self.weight.data.normal_(std=0.001)

2. 模型参数分组weight_decay

def separate_bn_prelu_params(model, ignored_params=[]):
  bn_prelu_params = []
  for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
      ignored_params += list(map(id, m.parameters()))  
      bn_prelu_params += m.parameters()
    if isinstance(m, nn.BatchNorm1d):
      ignored_params += list(map(id, m.parameters()))  
      bn_prelu_params += m.parameters()
    elif isinstance(m, nn.PReLU):
      ignored_params += list(map(id, m.parameters()))
      bn_prelu_params += m.parameters()
  base_params = list(filter(lambda p: id(p) not in ignored_params, model.parameters()))

  return base_params, bn_prelu_params, ignored_params

OPTIMIZER = optim.SGD([
    {'params': base_params, 'weight_decay': WEIGHT_DECAY},     
    {'params': fc_head_param, 'weight_decay': WEIGHT_DECAY * 10},
    {'params': bn_prelu_params, 'weight_decay': 0.0}
    ], lr=LR, momentum=MOMENTUM ) # , nesterov=True

Note 1:PReLU(x) = max(0,x) + a * min(0,x). Here a is a learnable parameter. When called without arguments, nn.PReLU() uses a single parameter a across all input channels. If called with nn.PReLU(nChannels), a separate a is used for each input channel.

Note 2: weight decay should not be used when learning a for good performance.

Note 3: The default number of a to learn is 1, the default initial value of a is 0.25.

3. 参数分组weight_decay?其他

第2节中的内容可以满足一般的参数分组需求,此部分可以满足更个性化的分组需求。参考:face_evoLVe_Pytorch-master

自定义schedule

def schedule_lr(optimizer):
  for params in optimizer.param_groups:
    params['lr'] /= 10.
  print(optimizer)

方法一:利用model.modules()和obj.__class__ (更普适)

# model.modules()和model.children()的区别:model.modules()会迭代地遍历模型的所有子层,而model.children()只会遍历模型下的一层
# 下面的关键词if 'model',源于模型定义文件。如model_resnet.py中自定义的所有nn.Module子类,都会前缀'model_resnet',所以可通过这种方式一次性筛选出自定义的模块
def separate_irse_bn_paras(model):
  paras_only_bn = []         
  paras_no_bn = []
  for layer in model.modules():
    if 'model' in str(layer.__class__):		      # eg. a=[1,2] type(a): <class 'list'> a.__class__: <class 'list'>
      continue
    if 'container' in str(layer.__class__):       # 去掉Sequential型的模块
      continue
    else:
      if 'batchnorm' in str(layer.__class__):
        paras_only_bn.extend([*layer.parameters()])
      else:
        paras_no_bn.extend([*layer.parameters()])  # extend()用于在列表末尾一次性追加另一个序列中的多个值(用新列表扩展原来的列表)

  return paras_only_bn, paras_no_bn

方法二:调用modules.parameters和named_parameters()

但是本质上,parameters()是根据named_parameters()获取,named_parameters()是根据modules()获取。使用此方法的前提是,须按下文1,2中的方式定义模型,或者利用Sequential+OrderedDict定义模型。

def separate_resnet_bn_paras(model):
  all_parameters = model.parameters()
  paras_only_bn = []

  for pname, p in model.named_parameters():
    if pname.find('bn') >= 0:
      paras_only_bn.append(p)
      
  paras_only_bn_id = list(map(id, paras_only_bn))
  paras_no_bn = list(filter(lambda p: id(p) not in paras_only_bn_id, all_parameters))
  
  return paras_only_bn, paras_no_bn

两种方法的区别

参数分组的区别,其实对应了模型构造时的区别。举例:

1、构造ResNet的basic block,在__init__()函数中定义了

self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = BatchNorm2d(planes)
self.relu = ReLU(inplace = True)
…

2、在forward()中定义

out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
…

3、对ResNet取model.name_parameters()返回的pname形如:

‘layer1.0.conv1.weight'
‘layer1.0.bn1.weight'
‘layer1.0.bn1.bias'
# layer对应conv2_x, …, conv5_x; '0'对应各layer中的block索引,比如conv2_x有3个block,对应索引为layer1.0, …, layer1.2; 'conv1'就是__init__()中定义的self.conv1

4、若构造model时采用了Sequential(),则model.name_parameters()返回的pname形如:

‘body.3.res_layer.1.weight',此处的1.weight实际对应了BN的weight,无法通过pname.find(‘bn')找到该模块。

self.res_layer = Sequential(
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
BatchNorm2d(depth),
ReLU(depth),
Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
BatchNorm2d(depth)
)

5、针对4中的情况,两种解决办法:利用OrderedDict修饰Sequential,或利用方法一

downsample = Sequential( OrderedDict([
(‘conv_ds', conv1x1(self.inplanes, planes * block.expansion, stride)),
(‘bn_ds', BatchNorm2d(planes * block.expansion)),
]))
# 如此,相应模块的pname将会带有'conv_ds',‘bn_ds'字样

以上这篇pytorch 网络参数 weight bias 初始化详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中设置变量访问权限的方法
Apr 27 Python
Python2.7编程中SQLite3基本操作方法示例
Aug 09 Python
Python实现屏幕截图的两种方式
Feb 05 Python
Python实现确认字符串是否包含指定字符串的实例
May 02 Python
python获取指定字符串中重复模式最高的字符串方法
Jun 29 Python
浅谈Pycharm中的Python Console与Terminal
Jan 17 Python
python logging模块的使用总结
Jul 09 Python
关于numpy数组轴的使用详解
Dec 05 Python
Python如何省略括号方法详解
Mar 21 Python
Numpy中np.max的用法及np.maximum区别
Nov 27 Python
python爬虫利用selenium实现自动翻页爬取某鱼数据的思路详解
Dec 22 Python
python实现手机推送 代码也就10行左右
Apr 12 Python
可视化pytorch 模型中不同BN层的running mean曲线实例
Jun 24 #Python
python3.x中安装web.py步骤方法
Jun 23 #Python
python如何删除文件、目录
Jun 23 #Python
TensorFlow保存TensorBoard图像操作
Jun 23 #Python
python和js交互调用的方法
Jun 23 #Python
virtualenv介绍及简明教程
Jun 23 #Python
python不同系统中打开方法
Jun 23 #Python
You might like
php实现字符串翻转的方法
2015/03/27 PHP
Yii2框架dropDownList下拉菜单用法实例分析
2016/07/18 PHP
Yii框架中用response保存cookie,用request读取cookie的原理解析
2019/09/04 PHP
使用PHP+Redis实现延迟任务,实现自动取消订单功能
2019/11/21 PHP
Gird事件机制初级读本
2007/03/10 Javascript
固定背景实现的背景滚动特效示例分享
2013/05/19 Javascript
21个值得收藏的Javascript技巧
2014/02/04 Javascript
JS取数字小数点后两位或n位的简单方法
2016/10/24 Javascript
原生JS实现《别踩白块》游戏(兼容IE)
2017/02/20 Javascript
jQuery中map函数的两种方式
2017/04/07 jQuery
微信小程序 request接口的封装实例代码
2017/04/26 Javascript
浅谈Vuex@2.3.0 中的 state 支持函数申明
2017/11/22 Javascript
实现jquery放大镜的两种方法
2018/02/22 jQuery
Vue2.0子同级组件之间数据交互方法
2018/02/28 Javascript
详解vue的diff算法原理
2018/05/20 Javascript
layui数据表格 table.render 报错的解决方法
2019/09/29 Javascript
vue之组件内监控$store中定义变量的变化详解
2019/11/08 Javascript
vue学习笔记之slot插槽用法实例分析
2020/02/29 Javascript
JS+CSS+HTML实现“代码雨”类似黑客帝国文字下落效果
2020/03/17 Javascript
JavaScript this关键字指向常用情况解析
2020/09/02 Javascript
微信小程序canvas实现签名功能
2021/01/19 Javascript
[00:32]2018DOTA2亚洲邀请赛Liquid出场
2018/04/03 DOTA
Python中实现字符串类型与字典类型相互转换的方法
2014/08/18 Python
Python中使用copy模块实现列表(list)拷贝
2015/04/14 Python
python3.x+pyqt5实现主窗口状态栏里(嵌入)显示进度条功能
2019/07/04 Python
Python将列表中的元素转化为数字并排序的示例
2019/12/25 Python
Python tornado上传文件的功能
2020/03/26 Python
Python关键字及可变参数*args,**kw原理解析
2020/04/04 Python
Python如何实现定时器功能
2020/05/28 Python
Python实现淘宝秒杀功能的示例代码
2021/01/19 Python
HTML5+CSS3 诱人的实例:3D立方体旋转动画实例
2016/12/30 HTML / CSS
光电信息专业应届生求职信
2013/10/07 职场文书
妇女儿童发展规划实施方案
2014/03/16 职场文书
反腐倡廉警示教育活动心得体会
2014/09/04 职场文书
创业计划书之酒店
2019/08/30 职场文书
pycharm代码删除恢复的方法
2021/06/26 Python