Pytorch 如何实现常用正则化


Posted in Python onMay 27, 2021

Stochastic Depth

论文:Deep Networks with Stochastic Depth

本文的正则化针对于ResNet中的残差结构,类似于dropout的原理,训练时对模块进行随机的删除,从而提升模型的泛化能力。

Pytorch 如何实现常用正则化

对于上述的ResNet网络,模块越在后面被drop掉的概率越大。

作者直觉上认为前期提取的低阶特征会被用于后面的层。

第一个模块保留的概率为1,之后保留概率随着深度线性递减。

对一个模块的drop函数可以采用如下的方式实现:

def drop_connect(inputs, p, training):
    """ Drop connect. """
    if not training: return inputs # 测试阶段
    batch_size = inputs.shape[0]
    keep_prob = 1 - p
    random_tensor = keep_prob
    random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
    # 以样本为单位生成模块是否被drop的01向量
    binary_tensor = torch.floor(random_tensor) 
    # 因为越往后越容易被drop,所以没有被drop的值就要通过除keep_prob来放大
    output = inputs / keep_prob * binary_tensor
    return output

在Pytorch建立的Module类中,具有forward函数

可以在forward函数中进行drop:

def forward(self, x):
 x=...
 if stride == 1 and in_planes == out_planes:
        if drop_connect_rate:
            x = drop_connect(x, p=drop_connect_rate, training=self.training)
        x = x + inputs  # skip connection
    return x

主函数:

for idx, block in enumerate(self._blocks):
    drop_connect_rate = self._global_params.drop_connect_rate
    if drop_connect_rate:
        drop_connect_rate *= float(idx) / len(self._blocks)
    x = block(x, drop_connect_rate=drop_connect_rate)

补充:pytorch中的L2正则化实现方法

搭建神经网络时需要使用L2正则化等操作来防止过拟合,而pytorch不像TensorFlow能在任意卷积函数中添加L2正则化的超参,那怎么在pytorch中实现L2正则化呢?

方法如下:超级简单!

optimizer = torch.optim.Adam(net.parameters(), lr=0.001, weight_decay=5.0)

torch.optim.Adam()参数中的 weight_decay=5.0 即为L2正则化(只是pytorch换了名字),其数值即为L2正则化的惩罚系数,一般设置为1、5、10(根据需要设置,默认为0,不使用L2正则化)。

注:

pytorch中的优化函数L2正则化默认对所有网络参数进行惩罚,且只能实现L2正则化,如需只惩罚指定网络层参数或采用L1正则化,只能自己定义。。。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python两个整数相除得到浮点数值的方法
Mar 18 Python
为Python的web框架编写前端模版的教程
Apr 30 Python
Python中的自省(反射)详解
Jun 02 Python
python选择排序算法实例总结
Jul 01 Python
详解如何利用Cython为Python代码加速
Jan 27 Python
Python获取指定文件夹下的文件名的方法
Feb 06 Python
Python使用pandas处理CSV文件的实例讲解
Jun 22 Python
python中的常量和变量代码详解
Jul 25 Python
python+influxdb+shell编写区域网络状况表
Jul 27 Python
python3+PyQt5 数据库编程--增删改实例
Jun 17 Python
pandas进行时间数据的转换和计算时间差并提取年月日
Jul 06 Python
树莓派4B+opencv4+python 打开摄像头的实现方法
Oct 18 Python
PyTorch 实现L2正则化以及Dropout的操作
Python开发之QT解决无边框界面拖动卡屏问题(附带源码)
pytorch 实现在测试的时候启用dropout
使用Python脚本对GiteePages进行一键部署的使用说明
教你使用Python pypinyin库实现汉字转拼音
基于tensorflow权重文件的解读
May 26 #Python
解决Python字典查找报Keyerror的问题
You might like
法压式咖啡之制作法
2021/03/03 冲泡冲煮
PHP模板引擎SMARTY
2006/10/09 PHP
php 数学运算验证码实现代码
2009/10/11 PHP
ezSQL PHP数据库操作类库
2010/05/16 PHP
利用PHP访问MySql数据库的逻辑操作以及增删改查的实例讲解
2017/08/30 PHP
php 判断IP为有效IP地址的方法
2018/01/28 PHP
javascript getElementsByTagName
2011/01/31 Javascript
使用JavaScript链式编程实现模拟Jquery函数
2014/12/21 Javascript
JavaScript判断变量是否为数组的方法(Array)
2016/02/24 Javascript
Bootstrap表格和栅格分页实例详解
2016/05/20 Javascript
浅谈js的异步执行
2016/10/18 Javascript
jQuery Ajax File Upload实例源码
2016/12/12 Javascript
JS通过调用微信API实现微信支付功能的方法示例
2017/06/29 Javascript
Vue学习笔记之表单输入控件绑定
2017/09/05 Javascript
JavaScript使用atan2来绘制箭头和曲线的实例
2017/09/14 Javascript
详解用场景去理解函数柯里化(入门篇)
2019/04/11 Javascript
javascript实现前端input密码输入强度验证
2020/06/24 Javascript
[02:43]中国五虎出征TI3视频
2013/08/02 DOTA
详解Python中的__new__、__init__、__call__三个特殊方法
2016/06/02 Python
python实现随机漫步算法
2018/08/27 Python
情人节快乐! python绘制漂亮玫瑰
2020/08/18 Python
python解析xml文件方式(解析、更新、写入)
2020/03/05 Python
解决pycharm编辑区显示yaml文件层级结构遇中文乱码问题
2020/04/27 Python
Python Django搭建网站流程图解
2020/06/13 Python
Python常用数据分析模块原理解析
2020/07/20 Python
python 爬取英雄联盟皮肤并下载的示例
2020/12/04 Python
CSS3 transform的skew属性值图文详解
2014/07/21 HTML / CSS
x-ua-compatible content=”IE=7, IE=9″意思理解
2013/07/22 HTML / CSS
沃尔玛旗下墨西哥超市:Bodega Aurrera
2020/11/13 全球购物
找工作最新求职信
2013/12/22 职场文书
2014年度安全生产目标管理责任书
2014/07/25 职场文书
医学专业大学生职业生涯规划书
2014/10/25 职场文书
2014年招商引资工作总结
2014/11/22 职场文书
运动会广播稿300字
2015/08/19 职场文书
导游词之重庆渣滓洞
2020/01/08 职场文书
html5移动端禁止长按图片保存的实现
2021/04/20 HTML / CSS