Pytorch 如何实现常用正则化


Posted in Python onMay 27, 2021

Stochastic Depth

论文:Deep Networks with Stochastic Depth

本文的正则化针对于ResNet中的残差结构,类似于dropout的原理,训练时对模块进行随机的删除,从而提升模型的泛化能力。

Pytorch 如何实现常用正则化

对于上述的ResNet网络,模块越在后面被drop掉的概率越大。

作者直觉上认为前期提取的低阶特征会被用于后面的层。

第一个模块保留的概率为1,之后保留概率随着深度线性递减。

对一个模块的drop函数可以采用如下的方式实现:

def drop_connect(inputs, p, training):
    """ Drop connect. """
    if not training: return inputs # 测试阶段
    batch_size = inputs.shape[0]
    keep_prob = 1 - p
    random_tensor = keep_prob
    random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
    # 以样本为单位生成模块是否被drop的01向量
    binary_tensor = torch.floor(random_tensor) 
    # 因为越往后越容易被drop,所以没有被drop的值就要通过除keep_prob来放大
    output = inputs / keep_prob * binary_tensor
    return output

在Pytorch建立的Module类中,具有forward函数

可以在forward函数中进行drop:

def forward(self, x):
 x=...
 if stride == 1 and in_planes == out_planes:
        if drop_connect_rate:
            x = drop_connect(x, p=drop_connect_rate, training=self.training)
        x = x + inputs  # skip connection
    return x

主函数:

for idx, block in enumerate(self._blocks):
    drop_connect_rate = self._global_params.drop_connect_rate
    if drop_connect_rate:
        drop_connect_rate *= float(idx) / len(self._blocks)
    x = block(x, drop_connect_rate=drop_connect_rate)

补充:pytorch中的L2正则化实现方法

搭建神经网络时需要使用L2正则化等操作来防止过拟合,而pytorch不像TensorFlow能在任意卷积函数中添加L2正则化的超参,那怎么在pytorch中实现L2正则化呢?

方法如下:超级简单!

optimizer = torch.optim.Adam(net.parameters(), lr=0.001, weight_decay=5.0)

torch.optim.Adam()参数中的 weight_decay=5.0 即为L2正则化(只是pytorch换了名字),其数值即为L2正则化的惩罚系数,一般设置为1、5、10(根据需要设置,默认为0,不使用L2正则化)。

注:

pytorch中的优化函数L2正则化默认对所有网络参数进行惩罚,且只能实现L2正则化,如需只惩罚指定网络层参数或采用L1正则化,只能自己定义。。。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现查找数组中任意第k大的数字算法示例
Jan 23 Python
详解Ubuntu16.04安装Python3.7及其pip3并切换为默认版本
Feb 25 Python
利用python numpy+matplotlib绘制股票k线图的方法
Jun 26 Python
Python 获取ftp服务器文件时间的方法
Jul 02 Python
django 自定义filter 判断if var in list的例子
Aug 20 Python
Python操作excel的方法总结(xlrd、xlwt、openpyxl)
Sep 02 Python
Django1.11配合uni-app发起微信支付的实现
Oct 12 Python
Python任务自动化工具tox使用教程
Mar 17 Python
浅析Python __name__ 是什么
Jul 07 Python
python爬虫泛滥的解决方法详解
Nov 25 Python
Pandas之缺失数据的实现
Jan 06 Python
使用python求解迷宫问题的三种实现方法
Mar 17 Python
PyTorch 实现L2正则化以及Dropout的操作
Python开发之QT解决无边框界面拖动卡屏问题(附带源码)
pytorch 实现在测试的时候启用dropout
使用Python脚本对GiteePages进行一键部署的使用说明
教你使用Python pypinyin库实现汉字转拼音
基于tensorflow权重文件的解读
May 26 #Python
解决Python字典查找报Keyerror的问题
You might like
PHP学习 运算符与运算符优先级
2008/06/15 PHP
php cout<<的一点看法
2010/01/24 PHP
Zend framework处理一个http请求的流程分析
2010/02/08 PHP
PHP sprintf()函数用例解析
2011/05/18 PHP
PHP判断文章里是否有图片的简单方法
2014/07/26 PHP
PHP开发的文字水印,缩略图,图片水印实现类与用法示例
2019/04/12 PHP
JavaScript中的事件处理
2008/01/16 Javascript
js DOM模型操作
2009/12/28 Javascript
超简单JS二级、多级联动的简单实例
2014/02/18 Javascript
jQuery右下角旋转环状菜单特效代码
2015/08/10 Javascript
JavaScript常用函数工具集:lao-utils
2016/03/01 Javascript
AngularJS基础 ng-open 指令简单实例
2016/08/02 Javascript
JS中用childNodes获取子元素换行会产生一个子元素
2016/12/08 Javascript
Angular实现表单验证功能
2017/11/13 Javascript
vuex 项目结构目录及一些简单配置介绍
2018/04/08 Javascript
js统计页面上每个标签的数量实例代码
2018/05/29 Javascript
vue 实现数字滚动增加效果的实例代码
2018/07/06 Javascript
在小程序中集成redux/immutable/thunk第三方库的方法
2018/08/12 Javascript
如何封装了一个vue移动端下拉加载下一页数据的组件
2019/01/06 Javascript
微信小程序授权登录解决方案的代码实例(含未通过授权解决方案)
2019/05/10 Javascript
浅谈Vue2.4.0 $attrs与inheritAttrs的具体使用
2020/03/08 Javascript
微信小程序自定义支持图片的弹窗
2020/12/21 Javascript
python文件操作之目录遍历实例分析
2015/05/20 Python
Python运算符重载用法实例
2015/05/28 Python
Python之re操作方法(详解)
2017/06/14 Python
python logging重复记录日志问题的解决方法
2018/07/12 Python
利用pyinstaller打包exe文件的基本教程
2019/05/02 Python
对Django url的几种使用方式详解
2019/08/06 Python
Python matplotlib以日期为x轴作图代码实例
2019/11/22 Python
安装完Python包然后找不到模块的解决步骤
2020/02/13 Python
Python Opencv中用compareHist函数进行直方图比较对比图片
2020/04/07 Python
flask项目集成swagger的方法
2020/12/09 Python
校园公益广告语
2014/03/13 职场文书
2014党员干部四风问题对照检查材料思想汇报
2014/09/24 职场文书
文体活动总结
2015/02/04 职场文书
横空出世观后感
2015/06/09 职场文书