Pytorch 如何实现常用正则化


Posted in Python onMay 27, 2021

Stochastic Depth

论文:Deep Networks with Stochastic Depth

本文的正则化针对于ResNet中的残差结构,类似于dropout的原理,训练时对模块进行随机的删除,从而提升模型的泛化能力。

Pytorch 如何实现常用正则化

对于上述的ResNet网络,模块越在后面被drop掉的概率越大。

作者直觉上认为前期提取的低阶特征会被用于后面的层。

第一个模块保留的概率为1,之后保留概率随着深度线性递减。

对一个模块的drop函数可以采用如下的方式实现:

def drop_connect(inputs, p, training):
    """ Drop connect. """
    if not training: return inputs # 测试阶段
    batch_size = inputs.shape[0]
    keep_prob = 1 - p
    random_tensor = keep_prob
    random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
    # 以样本为单位生成模块是否被drop的01向量
    binary_tensor = torch.floor(random_tensor) 
    # 因为越往后越容易被drop,所以没有被drop的值就要通过除keep_prob来放大
    output = inputs / keep_prob * binary_tensor
    return output

在Pytorch建立的Module类中,具有forward函数

可以在forward函数中进行drop:

def forward(self, x):
 x=...
 if stride == 1 and in_planes == out_planes:
        if drop_connect_rate:
            x = drop_connect(x, p=drop_connect_rate, training=self.training)
        x = x + inputs  # skip connection
    return x

主函数:

for idx, block in enumerate(self._blocks):
    drop_connect_rate = self._global_params.drop_connect_rate
    if drop_connect_rate:
        drop_connect_rate *= float(idx) / len(self._blocks)
    x = block(x, drop_connect_rate=drop_connect_rate)

补充:pytorch中的L2正则化实现方法

搭建神经网络时需要使用L2正则化等操作来防止过拟合,而pytorch不像TensorFlow能在任意卷积函数中添加L2正则化的超参,那怎么在pytorch中实现L2正则化呢?

方法如下:超级简单!

optimizer = torch.optim.Adam(net.parameters(), lr=0.001, weight_decay=5.0)

torch.optim.Adam()参数中的 weight_decay=5.0 即为L2正则化(只是pytorch换了名字),其数值即为L2正则化的惩罚系数,一般设置为1、5、10(根据需要设置,默认为0,不使用L2正则化)。

注:

pytorch中的优化函数L2正则化默认对所有网络参数进行惩罚,且只能实现L2正则化,如需只惩罚指定网络层参数或采用L1正则化,只能自己定义。。。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python smtplib模块发送SSL/TLS安全邮件实例
Apr 08 Python
讲解Python中if语句的嵌套用法
May 14 Python
Python搭建FTP服务器的方法示例
Jan 19 Python
使用DataFrame删除行和列的实例讲解
Apr 08 Python
Django forms表单 select下拉框的传值实例
Jul 19 Python
解决python3 requests headers参数不能有中文的问题
Aug 21 Python
python针对mysql数据库的连接、查询、更新、删除操作示例
Sep 11 Python
python的命名规则知识点总结
Oct 04 Python
python中dict()的高级用法实现
Nov 13 Python
6行Python代码实现进度条效果(Progress、tqdm、alive-progress​​​​​​​和PySimpleGUI库)
Jan 06 Python
关于Django Models CharField 参数说明
Mar 31 Python
Python Spyder 调出缩进对齐线的操作
Feb 26 Python
PyTorch 实现L2正则化以及Dropout的操作
Python开发之QT解决无边框界面拖动卡屏问题(附带源码)
pytorch 实现在测试的时候启用dropout
使用Python脚本对GiteePages进行一键部署的使用说明
教你使用Python pypinyin库实现汉字转拼音
基于tensorflow权重文件的解读
May 26 #Python
解决Python字典查找报Keyerror的问题
You might like
解析php开发中的中文编码问题
2013/08/08 PHP
jquery动态加载图片数据练习代码
2011/08/04 Javascript
自己写的兼容ie和ff的在线文本编辑器类似ewebeditor
2012/12/12 Javascript
js弹出的对话窗口永远保持居中显示
2012/12/15 Javascript
浅析jQuery对select操作小结(遍历option,操作option)
2013/07/04 Javascript
jquery 按钮状态效果 正常、移上、按下
2013/08/12 Javascript
jQuery遍历json中多个map的方法
2015/02/12 Javascript
jQuery EasyUI 菜单与按钮之创建简单的菜单和链接按钮
2015/11/18 Javascript
JS验证邮件地址格式方法小结
2015/12/01 Javascript
Javascript随机标签云代码实例
2016/06/21 Javascript
详细介绍RxJS在Angular中的应用
2017/09/23 Javascript
代码整洁之道(重构)
2018/10/25 Javascript
基于Vue实现平滑过渡的拖拽排序功能
2019/06/12 Javascript
Python列表append和+的区别浅析
2015/02/02 Python
python基于windows平台锁定键盘输入的方法
2015/03/05 Python
Python实现合并字典的方法
2015/07/07 Python
python中函数总结之装饰器闭包详解
2016/06/12 Python
python实现多层感知器MLP(基于双月数据集)
2019/01/18 Python
python爬虫 execjs安装配置及使用
2019/07/30 Python
python从内存地址上加载python对象过程详解
2020/01/08 Python
Python+OpenCV实现图像的全景拼接
2020/03/05 Python
Django集成MongoDB实现过程解析
2020/12/01 Python
如何用python开发Zeroc Ice应用
2021/01/29 Python
详解HTML5中垂直上下居中的解决方案
2017/12/20 HTML / CSS
美国LOGO设计公司:The Logo Company
2018/07/16 全球购物
大学生职业生涯规划范文
2014/01/22 职场文书
小班重阳节活动方案
2014/02/08 职场文书
学校节能宣传周活动总结
2014/07/09 职场文书
交通事故赔偿协议书
2014/10/16 职场文书
企业承诺书格式范文
2015/04/28 职场文书
创业计划书之旅游网站
2019/09/06 职场文书
五年级作文之成长
2019/09/16 职场文书
MySQL注入基础练习
2021/05/30 MySQL
Java中的继承、多态以及封装
2022/04/11 Java/Android
Java 轮询锁使用时遇到问题
2022/05/11 Java/Android
SpringBoot详解自定义Stater的应用
2022/07/15 Java/Android