pytorch  网络参数 weight bias 初始化详解


Posted in Python onJune 24, 2020

权重初始化对于训练神经网络至关重要,好的初始化权重可以有效的避免梯度消失等问题的发生。

在pytorch的使用过程中有几种权重初始化的方法供大家参考。

注意:第一种方法不推荐。尽量使用后两种方法。

# not recommend
def weights_init(m):
 classname = m.__class__.__name__
 if classname.find('Conv') != -1:
  m.weight.data.normal_(0.0, 0.02)
 elif classname.find('BatchNorm') != -1:
  m.weight.data.normal_(1.0, 0.02)
  m.bias.data.fill_(0)
# recommend
def initialize_weights(m):
 if isinstance(m, nn.Conv2d):
  m.weight.data.normal_(0, 0.02)
  m.bias.data.zero_()
 elif isinstance(m, nn.Linear):
  m.weight.data.normal_(0, 0.02)
  m.bias.data.zero_()
# recommend
def weights_init(m): 
 if isinstance(m, nn.Conv2d): 
  nn.init.xavier_normal_(m.weight.data) 
  nn.init.xavier_normal_(m.bias.data)
 elif isinstance(m, nn.BatchNorm2d):
  nn.init.constant_(m.weight,1)
  nn.init.constant_(m.bias, 0)
 elif isinstance(m, nn.BatchNorm1d):
  nn.init.constant_(m.weight,1)
  nn.init.constant_(m.bias, 0)

编写好weights_init函数后,可以使用模型的apply方法对模型进行权重初始化。

net = Residual() # generate an instance network from the Net class

net.apply(weights_init) # apply weight init

补充知识:Pytorch权值初始化及参数分组

1. 模型参数初始化

# ————————————————— 利用model.apply(weights_init)实现初始化
def weights_init(m):
  classname = m.__class__.__name__
  if classname.find('Conv') != -1:
    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
    m.weight.data.normal_(0, math.sqrt(2. / n))
    if m.bias is not None:
      m.bias.data.zero_()
  elif classname.find('BatchNorm') != -1:
    m.weight.data.fill_(1)
    m.bias.data.zero_()
  elif classname.find('Linear') != -1:
    n = m.weight.size(1)
    m.weight.data.normal_(0, 0.01)
    m.bias.data = torch.ones(m.bias.data.size())
    
# ————————————————— 直接放在__init__构造函数中实现初始化
for m in self.modules():
  if isinstance(m, nn.Conv2d):
    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
    m.weight.data.normal_(0, math.sqrt(2. / n))
    if m.bias is not None:
      m.bias.data.zero_()
  elif isinstance(m, nn.BatchNorm2d):
    m.weight.data.fill_(1)
    m.bias.data.zero_()
  elif isinstance(m, nn.BatchNorm1d):
    m.weight.data.fill_(1)
    m.bias.data.zero_()
  elif isinstance(m, nn.Linear):
    nn.init.xavier_uniform_(m.weight.data)
    if m.bias is not None:
      m.bias.data.zero_()
    
# —————————————————
self.weight = Parameter(torch.Tensor(out_features, in_features))
self.bias = Parameter(torch.FloatTensor(out_features))
nn.init.xavier_uniform_(self.weight)
nn.init.zero_(self.bias)
nn.init.constant_(m, initm)
# nn.init.kaiming_uniform_()
# self.weight.data.normal_(std=0.001)

2. 模型参数分组weight_decay

def separate_bn_prelu_params(model, ignored_params=[]):
  bn_prelu_params = []
  for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
      ignored_params += list(map(id, m.parameters()))  
      bn_prelu_params += m.parameters()
    if isinstance(m, nn.BatchNorm1d):
      ignored_params += list(map(id, m.parameters()))  
      bn_prelu_params += m.parameters()
    elif isinstance(m, nn.PReLU):
      ignored_params += list(map(id, m.parameters()))
      bn_prelu_params += m.parameters()
  base_params = list(filter(lambda p: id(p) not in ignored_params, model.parameters()))

  return base_params, bn_prelu_params, ignored_params

OPTIMIZER = optim.SGD([
    {'params': base_params, 'weight_decay': WEIGHT_DECAY},     
    {'params': fc_head_param, 'weight_decay': WEIGHT_DECAY * 10},
    {'params': bn_prelu_params, 'weight_decay': 0.0}
    ], lr=LR, momentum=MOMENTUM ) # , nesterov=True

Note 1:PReLU(x) = max(0,x) + a * min(0,x). Here a is a learnable parameter. When called without arguments, nn.PReLU() uses a single parameter a across all input channels. If called with nn.PReLU(nChannels), a separate a is used for each input channel.

Note 2: weight decay should not be used when learning a for good performance.

Note 3: The default number of a to learn is 1, the default initial value of a is 0.25.

3. 参数分组weight_decay?其他

第2节中的内容可以满足一般的参数分组需求,此部分可以满足更个性化的分组需求。参考:face_evoLVe_Pytorch-master

自定义schedule

def schedule_lr(optimizer):
  for params in optimizer.param_groups:
    params['lr'] /= 10.
  print(optimizer)

方法一:利用model.modules()和obj.__class__ (更普适)

# model.modules()和model.children()的区别:model.modules()会迭代地遍历模型的所有子层,而model.children()只会遍历模型下的一层
# 下面的关键词if 'model',源于模型定义文件。如model_resnet.py中自定义的所有nn.Module子类,都会前缀'model_resnet',所以可通过这种方式一次性筛选出自定义的模块
def separate_irse_bn_paras(model):
  paras_only_bn = []         
  paras_no_bn = []
  for layer in model.modules():
    if 'model' in str(layer.__class__):		      # eg. a=[1,2] type(a): <class 'list'> a.__class__: <class 'list'>
      continue
    if 'container' in str(layer.__class__):       # 去掉Sequential型的模块
      continue
    else:
      if 'batchnorm' in str(layer.__class__):
        paras_only_bn.extend([*layer.parameters()])
      else:
        paras_no_bn.extend([*layer.parameters()])  # extend()用于在列表末尾一次性追加另一个序列中的多个值(用新列表扩展原来的列表)

  return paras_only_bn, paras_no_bn

方法二:调用modules.parameters和named_parameters()

但是本质上,parameters()是根据named_parameters()获取,named_parameters()是根据modules()获取。使用此方法的前提是,须按下文1,2中的方式定义模型,或者利用Sequential+OrderedDict定义模型。

def separate_resnet_bn_paras(model):
  all_parameters = model.parameters()
  paras_only_bn = []

  for pname, p in model.named_parameters():
    if pname.find('bn') >= 0:
      paras_only_bn.append(p)
      
  paras_only_bn_id = list(map(id, paras_only_bn))
  paras_no_bn = list(filter(lambda p: id(p) not in paras_only_bn_id, all_parameters))
  
  return paras_only_bn, paras_no_bn

两种方法的区别

参数分组的区别,其实对应了模型构造时的区别。举例:

1、构造ResNet的basic block,在__init__()函数中定义了

self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = BatchNorm2d(planes)
self.relu = ReLU(inplace = True)
…

2、在forward()中定义

out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
…

3、对ResNet取model.name_parameters()返回的pname形如:

‘layer1.0.conv1.weight'
‘layer1.0.bn1.weight'
‘layer1.0.bn1.bias'
# layer对应conv2_x, …, conv5_x; '0'对应各layer中的block索引,比如conv2_x有3个block,对应索引为layer1.0, …, layer1.2; 'conv1'就是__init__()中定义的self.conv1

4、若构造model时采用了Sequential(),则model.name_parameters()返回的pname形如:

‘body.3.res_layer.1.weight',此处的1.weight实际对应了BN的weight,无法通过pname.find(‘bn')找到该模块。

self.res_layer = Sequential(
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
BatchNorm2d(depth),
ReLU(depth),
Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
BatchNorm2d(depth)
)

5、针对4中的情况,两种解决办法:利用OrderedDict修饰Sequential,或利用方法一

downsample = Sequential( OrderedDict([
(‘conv_ds', conv1x1(self.inplanes, planes * block.expansion, stride)),
(‘bn_ds', BatchNorm2d(planes * block.expansion)),
]))
# 如此,相应模块的pname将会带有'conv_ds',‘bn_ds'字样

以上这篇pytorch 网络参数 weight bias 初始化详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
通过C++学习Python
Jan 20 Python
举例讲解Python设计模式编程中对抽象工厂模式的运用
Mar 02 Python
代码分析Python地图坐标转换
Feb 08 Python
python 每天如何定时启动爬虫任务(实现方法分享)
May 21 Python
基于Django URL传参 FORM表单传数据 get post的用法实例
May 28 Python
python批量修改图片后缀的方法(png到jpg)
Oct 25 Python
Python参数解析模块sys、getopt、argparse使用与对比分析
Apr 02 Python
python,Django实现的淘宝客登录功能示例
Jun 12 Python
python 实现的发送邮件模板【普通邮件、带附件、带图片邮件】
Jul 06 Python
PyCharm专业最新版2019.1安装步骤(含激活码)
Oct 09 Python
tensorflow之自定义神经网络层实例
Feb 07 Python
python实现批量修改文件名
Mar 23 Python
可视化pytorch 模型中不同BN层的running mean曲线实例
Jun 24 #Python
python3.x中安装web.py步骤方法
Jun 23 #Python
python如何删除文件、目录
Jun 23 #Python
TensorFlow保存TensorBoard图像操作
Jun 23 #Python
python和js交互调用的方法
Jun 23 #Python
virtualenv介绍及简明教程
Jun 23 #Python
python不同系统中打开方法
Jun 23 #Python
You might like
《PHP编程最快明白》第六讲:Mysql数据库操作
2010/11/01 PHP
5种PHP创建数组的实例代码分享
2014/01/17 PHP
网站防止被刷票的一些思路与方法
2015/01/08 PHP
PHP实现图片不变型裁剪及图片按比例裁剪的方法
2016/01/14 PHP
漂亮的widgets,支持换肤和后期开发新皮肤(2007-4-27已更新1.7alpha)
2007/04/27 Javascript
JQuery为textarea添加maxlength属性的代码
2010/04/07 Javascript
JavaScript 继承机制的实现(待续)
2010/05/18 Javascript
window.open不被拦截的实现代码
2012/08/22 Javascript
DWR实现模拟Google搜索效果实现原理及代码
2013/01/30 Javascript
转换字符串为json对象的方法详解
2013/11/29 Javascript
JS+CSS实现弹出全屏灰黑色透明遮罩效果的方法
2014/12/20 Javascript
javascript顺序加载图片的方法
2015/07/18 Javascript
使用coffeescript编写node.js项目的方法汇总
2015/08/05 Javascript
jQuery实现类似标签风格的导航菜单效果代码
2015/08/25 Javascript
jQuery实现微信长按识别二维码功能
2016/08/26 Javascript
简单模拟node.js中require的加载机制
2016/10/27 Javascript
Javascript中八种遍历方法的执行速度深度对比
2017/04/25 Javascript
微信小程序中时间戳和日期的相互转换问题
2018/07/09 Javascript
Bootstrap 按钮样式与使用代码详解
2018/12/09 Javascript
解决layui弹框失效的问题
2019/09/09 Javascript
JS中FileReader类实现文件上传及时预览功能
2020/03/27 Javascript
vue 解决provide和inject响应的问题
2020/11/12 Javascript
JS实现超级好看的鼠标小尾巴特效
2020/12/01 Javascript
python库lxml在linux和WIN系统下的安装
2018/06/24 Python
python字符串替换第一个字符串的方法
2019/06/26 Python
pandas读取CSV文件时查看修改各列的数据类型格式
2019/07/07 Python
解决Django layui {{}}冲突的问题
2019/08/29 Python
python 的 openpyxl模块 读取 Excel文件的方法
2019/09/09 Python
python 在sql语句中使用%s,%d,%f说明
2020/06/06 Python
thinkphp5 路由分发原理
2021/03/18 PHP
计算机专业个人求职自荐信
2013/09/21 职场文书
大二自我鉴定
2014/01/31 职场文书
四年大学自我鉴定
2014/02/17 职场文书
岗位职责怎么写
2014/03/14 职场文书
普通话演讲稿
2014/09/03 职场文书
2014年加油站工作总结
2014/12/04 职场文书