Pytorch 如何实现常用正则化


Posted in Python onMay 27, 2021

Stochastic Depth

论文:Deep Networks with Stochastic Depth

本文的正则化针对于ResNet中的残差结构,类似于dropout的原理,训练时对模块进行随机的删除,从而提升模型的泛化能力。

Pytorch 如何实现常用正则化

对于上述的ResNet网络,模块越在后面被drop掉的概率越大。

作者直觉上认为前期提取的低阶特征会被用于后面的层。

第一个模块保留的概率为1,之后保留概率随着深度线性递减。

对一个模块的drop函数可以采用如下的方式实现:

def drop_connect(inputs, p, training):
    """ Drop connect. """
    if not training: return inputs # 测试阶段
    batch_size = inputs.shape[0]
    keep_prob = 1 - p
    random_tensor = keep_prob
    random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
    # 以样本为单位生成模块是否被drop的01向量
    binary_tensor = torch.floor(random_tensor) 
    # 因为越往后越容易被drop,所以没有被drop的值就要通过除keep_prob来放大
    output = inputs / keep_prob * binary_tensor
    return output

在Pytorch建立的Module类中,具有forward函数

可以在forward函数中进行drop:

def forward(self, x):
 x=...
 if stride == 1 and in_planes == out_planes:
        if drop_connect_rate:
            x = drop_connect(x, p=drop_connect_rate, training=self.training)
        x = x + inputs  # skip connection
    return x

主函数:

for idx, block in enumerate(self._blocks):
    drop_connect_rate = self._global_params.drop_connect_rate
    if drop_connect_rate:
        drop_connect_rate *= float(idx) / len(self._blocks)
    x = block(x, drop_connect_rate=drop_connect_rate)

补充:pytorch中的L2正则化实现方法

搭建神经网络时需要使用L2正则化等操作来防止过拟合,而pytorch不像TensorFlow能在任意卷积函数中添加L2正则化的超参,那怎么在pytorch中实现L2正则化呢?

方法如下:超级简单!

optimizer = torch.optim.Adam(net.parameters(), lr=0.001, weight_decay=5.0)

torch.optim.Adam()参数中的 weight_decay=5.0 即为L2正则化(只是pytorch换了名字),其数值即为L2正则化的惩罚系数,一般设置为1、5、10(根据需要设置,默认为0,不使用L2正则化)。

注:

pytorch中的优化函数L2正则化默认对所有网络参数进行惩罚,且只能实现L2正则化,如需只惩罚指定网络层参数或采用L1正则化,只能自己定义。。。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python类继承用法实例分析
May 27 Python
在Django中管理Users和Permissions以及Groups的方法
Jul 23 Python
python+ffmpeg视频并发直播压力测试
Mar 06 Python
分分钟入门python语言
Mar 20 Python
pandas 使用均值填充缺失值列的小技巧分享
Jul 04 Python
python实现发送form-data数据的方法详解
Sep 27 Python
python生成器推导式用法简单示例
Oct 08 Python
Python API自动化框架总结
Nov 12 Python
python 爬虫 实现增量去重和定时爬取实例
Feb 28 Python
Python修改DBF文件指定列
Dec 19 Python
Opencv 图片的OCR识别的实战示例
Mar 02 Python
pip install命令安装扩展库整理
Mar 02 Python
PyTorch 实现L2正则化以及Dropout的操作
Python开发之QT解决无边框界面拖动卡屏问题(附带源码)
pytorch 实现在测试的时候启用dropout
使用Python脚本对GiteePages进行一键部署的使用说明
教你使用Python pypinyin库实现汉字转拼音
基于tensorflow权重文件的解读
May 26 #Python
解决Python字典查找报Keyerror的问题
You might like
IIS下配置Php+Mysql+zend的图文教程
2006/12/08 PHP
PHP CodeBase:将时间显示为"刚刚""n分钟/小时前"的方法详解
2013/06/06 PHP
初识ThinkPHP控制器
2016/04/07 PHP
PHP随机获取未被微信屏蔽的域名(微信域名检测)
2017/03/19 PHP
IE7提供XMLHttpRequest对象为兼容
2007/03/08 Javascript
基于jquery实现的类似百度搜索的输入框自动完成功能
2011/08/23 Javascript
windows系统下简单nodejs安装及环境配置
2013/01/08 NodeJs
Javascript倒计时页面跳转实例小结
2013/09/11 Javascript
JavaScript实现的日期控件具体代码
2013/11/18 Javascript
jQuery简单实现隐藏以及显示特效
2015/02/26 Javascript
JavaScript取得键盘按下方向键是哪个的方法
2015/08/04 Javascript
Javascript实现商品秒杀倒计时(时间与服务器时间同步)
2015/09/16 Javascript
JS实现添加,替换,删除节点元素的方法
2016/06/30 Javascript
详解bootstrap用dropdown-menu实现上下文菜单
2017/09/22 Javascript
React组件重构之嵌套+继承及高阶组件详解
2018/07/19 Javascript
vue+iview+less 实现换肤功能
2018/08/17 Javascript
基于vue中对鼠标划过事件的处理方式详解
2018/08/22 Javascript
Vue常见面试题整理【值得收藏】
2018/09/20 Javascript
JavaScript实现随机点名器实例详解
2019/05/07 Javascript
JS 逻辑判断不要只知道用 if-else 和 switch条件判断(小技巧)
2020/05/27 Javascript
python使用Tkinter显示网络图片的方法
2015/04/24 Python
详解Python使用simplejson模块解析JSON的方法
2016/03/24 Python
python实现朴素贝叶斯算法
2018/11/19 Python
python实现滑雪者小游戏
2020/02/22 Python
五种Python转义表示法
2020/11/27 Python
韩国三星集团旗下时尚品牌官网:SSF SHOP
2016/08/02 全球购物
有趣的流行文化T恤、马克杯、手机壳和更多:Look Human
2019/01/07 全球购物
澳洲本土太阳镜品牌:Quay Australia
2019/07/29 全球购物
2014年机关植树节活动方案
2014/02/27 职场文书
给校长的建议书300字
2014/05/16 职场文书
员工培训协议书
2014/09/15 职场文书
义诊活动通知
2015/04/24 职场文书
故意伤害罪辩护词
2015/05/21 职场文书
小学语文继续教育研修日志
2015/11/13 职场文书
祝福语集锦:朋友新店开业祝福语
2019/12/10 职场文书
win10音频服务未响应怎么解决?win10音频服务未响应未修复的解决方法
2022/08/14 数码科技