Pytorch 如何实现常用正则化


Posted in Python onMay 27, 2021

Stochastic Depth

论文:Deep Networks with Stochastic Depth

本文的正则化针对于ResNet中的残差结构,类似于dropout的原理,训练时对模块进行随机的删除,从而提升模型的泛化能力。

Pytorch 如何实现常用正则化

对于上述的ResNet网络,模块越在后面被drop掉的概率越大。

作者直觉上认为前期提取的低阶特征会被用于后面的层。

第一个模块保留的概率为1,之后保留概率随着深度线性递减。

对一个模块的drop函数可以采用如下的方式实现:

def drop_connect(inputs, p, training):
    """ Drop connect. """
    if not training: return inputs # 测试阶段
    batch_size = inputs.shape[0]
    keep_prob = 1 - p
    random_tensor = keep_prob
    random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
    # 以样本为单位生成模块是否被drop的01向量
    binary_tensor = torch.floor(random_tensor) 
    # 因为越往后越容易被drop,所以没有被drop的值就要通过除keep_prob来放大
    output = inputs / keep_prob * binary_tensor
    return output

在Pytorch建立的Module类中,具有forward函数

可以在forward函数中进行drop:

def forward(self, x):
 x=...
 if stride == 1 and in_planes == out_planes:
        if drop_connect_rate:
            x = drop_connect(x, p=drop_connect_rate, training=self.training)
        x = x + inputs  # skip connection
    return x

主函数:

for idx, block in enumerate(self._blocks):
    drop_connect_rate = self._global_params.drop_connect_rate
    if drop_connect_rate:
        drop_connect_rate *= float(idx) / len(self._blocks)
    x = block(x, drop_connect_rate=drop_connect_rate)

补充:pytorch中的L2正则化实现方法

搭建神经网络时需要使用L2正则化等操作来防止过拟合,而pytorch不像TensorFlow能在任意卷积函数中添加L2正则化的超参,那怎么在pytorch中实现L2正则化呢?

方法如下:超级简单!

optimizer = torch.optim.Adam(net.parameters(), lr=0.001, weight_decay=5.0)

torch.optim.Adam()参数中的 weight_decay=5.0 即为L2正则化(只是pytorch换了名字),其数值即为L2正则化的惩罚系数,一般设置为1、5、10(根据需要设置,默认为0,不使用L2正则化)。

注:

pytorch中的优化函数L2正则化默认对所有网络参数进行惩罚,且只能实现L2正则化,如需只惩罚指定网络层参数或采用L1正则化,只能自己定义。。。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python随机生成数据后插入到PostgreSQL
Jul 28 Python
python sys.argv[]用法实例详解
May 25 Python
利用python对Excel中的特定数据提取并写入新表的方法
Jun 14 Python
python实现石头剪刀布小游戏
Jan 20 Python
Python时间和字符串转换操作实例分析
Mar 16 Python
python实现五子棋游戏
Jun 18 Python
Django实现微信小程序的登录验证功能并维护登录态
Jul 04 Python
python列表生成器迭代器实例解析
Dec 19 Python
Python中的Cookie模块如何使用
Jun 04 Python
浅析Python 责任链设计模式
Sep 11 Python
用Python自动清理电脑内重复文件,只要10行代码(自动脚本)
Jan 09 Python
Python时间操作之pytz模块使用详解
Jun 14 Python
PyTorch 实现L2正则化以及Dropout的操作
Python开发之QT解决无边框界面拖动卡屏问题(附带源码)
pytorch 实现在测试的时候启用dropout
使用Python脚本对GiteePages进行一键部署的使用说明
教你使用Python pypinyin库实现汉字转拼音
基于tensorflow权重文件的解读
May 26 #Python
解决Python字典查找报Keyerror的问题
You might like
3种平台下安装php4经验点滴
2006/10/09 PHP
require(),include(),require_once()和include_once()的异同
2007/01/02 PHP
Discuz 模板语句分析及知识技巧
2009/08/21 PHP
php 从数据库提取二进制图片的处理代码
2009/09/09 PHP
通过PHP修改Linux或Unix口令的方法分享
2012/01/30 PHP
限制文本字节数js代码
2007/03/06 Javascript
利用location.hash实现跨域iframe自适应
2010/05/04 Javascript
理解Javascript_14_函数形式参数与arguments
2010/10/20 Javascript
juqery 学习之三 选择器 可见性 元素属性
2010/11/25 Javascript
初识JQuery 实例一(first)
2011/03/16 Javascript
3种不同方式的焦点图轮播特效分享
2013/10/30 Javascript
Node.js实现批量去除BOM文件头
2014/12/20 Javascript
AngularJS中的Directive自定义一个表格
2016/01/25 Javascript
JavaScript下的时间格式处理函数Date.prototype.format
2016/01/27 Javascript
Node.js使用NodeMailer发送邮件实例代码
2017/03/06 Javascript
浅谈es6语法 (Proxy和Reflect的对比)
2017/10/24 Javascript
关于Angularjs中自定义指令一些有价值的细节和技巧小结
2018/04/22 Javascript
JavaScript求一个数组中重复出现次数最多的元素及其下标位置示例
2018/07/23 Javascript
vue.js input框之间赋值方法
2018/08/24 Javascript
Vue实现表格中对数据进行转换、处理的方法
2018/09/06 Javascript
webpack 代码分离优化快速指北
2019/05/18 Javascript
python面向对象_详谈类的继承与方法的重载
2017/06/07 Python
利用Python批量压缩png方法实例(支持过滤个别文件与文件夹)
2017/07/30 Python
利用python实现微信头像加红色数字功能
2018/03/26 Python
3行Python代码实现图像照片抠图和换底色的方法
2019/10/10 Python
解决使用python print打印函数返回值多一个None的问题
2020/04/09 Python
python super()函数的基本使用
2020/09/10 Python
美国当红的名品折扣网:Gilt Groupe
2016/08/15 全球购物
罗马尼亚购物网站:Vivantis.ro
2019/07/20 全球购物
什么是makefile? 如何编写makefile?
2012/08/08 面试题
大学自我评价
2014/02/12 职场文书
优秀教师工作感言
2014/02/16 职场文书
顶碗少年教学反思
2014/02/21 职场文书
体育教育毕业生自荐信
2014/06/29 职场文书
女生抽烟检讨书
2014/10/05 职场文书
食堂卫生管理制度
2015/08/04 职场文书