Pytorch 如何实现常用正则化


Posted in Python onMay 27, 2021

Stochastic Depth

论文:Deep Networks with Stochastic Depth

本文的正则化针对于ResNet中的残差结构,类似于dropout的原理,训练时对模块进行随机的删除,从而提升模型的泛化能力。

Pytorch 如何实现常用正则化

对于上述的ResNet网络,模块越在后面被drop掉的概率越大。

作者直觉上认为前期提取的低阶特征会被用于后面的层。

第一个模块保留的概率为1,之后保留概率随着深度线性递减。

对一个模块的drop函数可以采用如下的方式实现:

def drop_connect(inputs, p, training):
    """ Drop connect. """
    if not training: return inputs # 测试阶段
    batch_size = inputs.shape[0]
    keep_prob = 1 - p
    random_tensor = keep_prob
    random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
    # 以样本为单位生成模块是否被drop的01向量
    binary_tensor = torch.floor(random_tensor) 
    # 因为越往后越容易被drop,所以没有被drop的值就要通过除keep_prob来放大
    output = inputs / keep_prob * binary_tensor
    return output

在Pytorch建立的Module类中,具有forward函数

可以在forward函数中进行drop:

def forward(self, x):
 x=...
 if stride == 1 and in_planes == out_planes:
        if drop_connect_rate:
            x = drop_connect(x, p=drop_connect_rate, training=self.training)
        x = x + inputs  # skip connection
    return x

主函数:

for idx, block in enumerate(self._blocks):
    drop_connect_rate = self._global_params.drop_connect_rate
    if drop_connect_rate:
        drop_connect_rate *= float(idx) / len(self._blocks)
    x = block(x, drop_connect_rate=drop_connect_rate)

补充:pytorch中的L2正则化实现方法

搭建神经网络时需要使用L2正则化等操作来防止过拟合,而pytorch不像TensorFlow能在任意卷积函数中添加L2正则化的超参,那怎么在pytorch中实现L2正则化呢?

方法如下:超级简单!

optimizer = torch.optim.Adam(net.parameters(), lr=0.001, weight_decay=5.0)

torch.optim.Adam()参数中的 weight_decay=5.0 即为L2正则化(只是pytorch换了名字),其数值即为L2正则化的惩罚系数,一般设置为1、5、10(根据需要设置,默认为0,不使用L2正则化)。

注:

pytorch中的优化函数L2正则化默认对所有网络参数进行惩罚,且只能实现L2正则化,如需只惩罚指定网络层参数或采用L1正则化,只能自己定义。。。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python去掉字符串中空格的方法
Mar 11 Python
Python Sql数据库增删改查操作简单封装
Apr 18 Python
Python实现通过文件路径获取文件hash值的方法
Apr 29 Python
Python图像处理之图像的读取、显示与保存操作【测试可用】
Jan 04 Python
python ipset管理 增删白名单的方法
Jan 14 Python
python通用读取vcf文件的类(复制粘贴即可用)
Feb 29 Python
python实现猜数游戏
Mar 27 Python
python实现处理mysql结果输出方式
Apr 09 Python
python中 _、__、__xx__()区别及使用场景
Jun 30 Python
Python爬虫新手入门之初学lxml库
Dec 20 Python
Python运算符+与+=的方法实例
Feb 18 Python
pytorch常用数据类型所占字节数对照表一览
May 17 Python
PyTorch 实现L2正则化以及Dropout的操作
Python开发之QT解决无边框界面拖动卡屏问题(附带源码)
pytorch 实现在测试的时候启用dropout
使用Python脚本对GiteePages进行一键部署的使用说明
教你使用Python pypinyin库实现汉字转拼音
基于tensorflow权重文件的解读
May 26 #Python
解决Python字典查找报Keyerror的问题
You might like
PHP中的流(streams)浅析
2015/07/02 PHP
PHP实现微信申请退款功能
2018/10/01 PHP
实例讲解php将字符串输出到HTML
2019/01/27 PHP
js保存当前路径(cookies记录)
2010/12/14 Javascript
表单验证的完整应用案例探讨
2013/03/29 Javascript
如何使Chrome控制台支持多行js模式——意外发现
2013/06/13 Javascript
JS Replace 全部替换字符的用法小结
2013/12/24 Javascript
jQuery表单验证功能实例
2015/08/28 Javascript
深入解析JavaScript中函数的Currying柯里化
2016/03/19 Javascript
JS实现的RGB网页颜色在线取色器完整实例
2016/12/21 Javascript
Validform验证时可以为空否则按照指定格式验证
2017/10/20 Javascript
Element-UI Table组件上添加列拖拽效果实现方法
2018/04/14 Javascript
使用webpack搭建react开发环境的方法
2018/05/15 Javascript
node.js环境搭建图文详解
2018/09/19 Javascript
Angular设置别名alias的方法
2018/11/08 Javascript
Vue CLI3.0中使用jQuery和Bootstrap的方法
2019/02/28 jQuery
vue-cli项目使用mock数据的方法(借助express)
2019/04/15 Javascript
JavaScript通如何过RGraph实现动态仪表盘
2020/10/15 Javascript
antd多选下拉框一行展示的实现方式
2020/10/31 Javascript
Python 冒泡,选择,插入排序使用实例
2015/02/05 Python
Python新手入门最容易犯的错误总结
2017/04/24 Python
浅谈Python 的枚举 Enum
2017/06/12 Python
Python实现的矩阵类实例
2017/08/22 Python
python的变量与赋值详细分析
2017/11/08 Python
python 删除字符串中连续多个空格并保留一个的方法
2018/12/22 Python
python开发游戏的前期准备
2019/05/05 Python
在python中实现同行输入/接收多个数据的示例
2019/07/20 Python
Python  Django 母版和继承解析
2019/08/09 Python
Python用来做Web开发的优势有哪些
2020/08/05 Python
匡威爱尔兰官网:Converse爱尔兰
2019/06/09 全球购物
会计求职信范文
2014/05/24 职场文书
新农村建设典型材料
2014/05/31 职场文书
租房安全协议书
2014/08/20 职场文书
小学运动会演讲稿
2014/08/25 职场文书
怎样写家长意见
2015/06/04 职场文书
设置IIS Express并发数
2022/07/07 Servers