Pytorch 如何实现常用正则化


Posted in Python onMay 27, 2021

Stochastic Depth

论文:Deep Networks with Stochastic Depth

本文的正则化针对于ResNet中的残差结构,类似于dropout的原理,训练时对模块进行随机的删除,从而提升模型的泛化能力。

Pytorch 如何实现常用正则化

对于上述的ResNet网络,模块越在后面被drop掉的概率越大。

作者直觉上认为前期提取的低阶特征会被用于后面的层。

第一个模块保留的概率为1,之后保留概率随着深度线性递减。

对一个模块的drop函数可以采用如下的方式实现:

def drop_connect(inputs, p, training):
    """ Drop connect. """
    if not training: return inputs # 测试阶段
    batch_size = inputs.shape[0]
    keep_prob = 1 - p
    random_tensor = keep_prob
    random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
    # 以样本为单位生成模块是否被drop的01向量
    binary_tensor = torch.floor(random_tensor) 
    # 因为越往后越容易被drop,所以没有被drop的值就要通过除keep_prob来放大
    output = inputs / keep_prob * binary_tensor
    return output

在Pytorch建立的Module类中,具有forward函数

可以在forward函数中进行drop:

def forward(self, x):
 x=...
 if stride == 1 and in_planes == out_planes:
        if drop_connect_rate:
            x = drop_connect(x, p=drop_connect_rate, training=self.training)
        x = x + inputs  # skip connection
    return x

主函数:

for idx, block in enumerate(self._blocks):
    drop_connect_rate = self._global_params.drop_connect_rate
    if drop_connect_rate:
        drop_connect_rate *= float(idx) / len(self._blocks)
    x = block(x, drop_connect_rate=drop_connect_rate)

补充:pytorch中的L2正则化实现方法

搭建神经网络时需要使用L2正则化等操作来防止过拟合,而pytorch不像TensorFlow能在任意卷积函数中添加L2正则化的超参,那怎么在pytorch中实现L2正则化呢?

方法如下:超级简单!

optimizer = torch.optim.Adam(net.parameters(), lr=0.001, weight_decay=5.0)

torch.optim.Adam()参数中的 weight_decay=5.0 即为L2正则化(只是pytorch换了名字),其数值即为L2正则化的惩罚系数,一般设置为1、5、10(根据需要设置,默认为0,不使用L2正则化)。

注:

pytorch中的优化函数L2正则化默认对所有网络参数进行惩罚,且只能实现L2正则化,如需只惩罚指定网络层参数或采用L1正则化,只能自己定义。。。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 实时遍历日志文件
Apr 12 Python
Python编程实现双击更新所有已安装python模块的方法
Jun 05 Python
Python3.6简单反射操作示例
Jun 14 Python
Django中的Model操作表的实现
Jul 24 Python
Flask模拟实现CSRF攻击的方法
Jul 24 Python
Python拼接字符串的7种方法总结
Nov 01 Python
python实现对指定字符串补足固定长度倍数截断输出的方法
Nov 15 Python
python 实现调用子文件下的模块方法
Dec 07 Python
Django文件存储 默认存储系统解析
Aug 02 Python
pytorch构建多模型实例
Jan 15 Python
8种常用的Python工具
Aug 05 Python
快速创建python 虚拟环境
Nov 28 Python
PyTorch 实现L2正则化以及Dropout的操作
Python开发之QT解决无边框界面拖动卡屏问题(附带源码)
pytorch 实现在测试的时候启用dropout
使用Python脚本对GiteePages进行一键部署的使用说明
教你使用Python pypinyin库实现汉字转拼音
基于tensorflow权重文件的解读
May 26 #Python
解决Python字典查找报Keyerror的问题
You might like
与空气斗智斗勇的经典《Overlord》,传说中的“无稽之谈”
2020/04/09 日漫
php获取twitter最新消息的方法
2015/04/14 PHP
PHP基于文件存储实现缓存的方法
2015/07/20 PHP
php多进程应用场景实例详解
2019/07/22 PHP
php设计模式之迭代器模式实例分析【星际争霸游戏案例】
2020/04/07 PHP
jquery 表单取值常用代码
2009/12/22 Javascript
Node.js中对通用模块的封装方法
2014/06/06 Javascript
基于AngularJS+HTML+Groovy实现登录功能
2016/02/17 Javascript
javascript瀑布流式图片懒加载实例解析与优化
2016/02/23 Javascript
下雪了 javascript实现雪花飞舞
2020/08/02 Javascript
jquery 无限极下拉菜单的简单实例(精简浓缩版)
2016/05/31 Javascript
Bootstrap select实现下拉框多选效果
2016/12/23 Javascript
NodeJS遍历文件生产文件列表功能示例
2017/01/22 NodeJs
Node.js Express安装与使用教程
2018/05/11 Javascript
vue 点击按钮增加一行的方法
2018/09/07 Javascript
JavaScript 继承 封装 多态实现及原理详解
2019/07/29 Javascript
JS实现可视化音频效果的实例代码
2020/01/16 Javascript
Vue中el-form标签中的自定义el-select下拉框标签功能
2020/04/20 Javascript
基于Python log 的正确打开方式
2018/04/28 Python
python numpy 部分排序 寻找最大的前几个数的方法
2018/06/27 Python
Pycharm简单使用教程(入门小结)
2019/07/04 Python
python 链接sqlserver 写接口实例
2020/03/11 Python
python实现将range()函数生成的数字存储在一个列表中
2020/04/02 Python
浅谈python量化 双均线策略(金叉死叉)
2020/06/03 Python
python 递归相关知识总结
2021/03/03 Python
美国综合购物商城:UnbeatableSale.com
2018/11/28 全球购物
2014年巴西世界杯口号
2014/06/05 职场文书
销售提升方案
2014/06/07 职场文书
村级四风对照检查材料
2014/08/24 职场文书
2014年终个人工作总结
2014/11/07 职场文书
公司处罚决定书
2015/06/24 职场文书
2015年科普工作总结
2015/07/23 职场文书
2016清明节森林防火广播稿
2015/12/17 职场文书
初中物理教学反思
2016/02/19 职场文书
导游词之山西-五老峰
2019/10/07 职场文书
MySQL查询学习之基础查询操作
2021/05/08 MySQL