Pytorch 如何实现常用正则化


Posted in Python onMay 27, 2021

Stochastic Depth

论文:Deep Networks with Stochastic Depth

本文的正则化针对于ResNet中的残差结构,类似于dropout的原理,训练时对模块进行随机的删除,从而提升模型的泛化能力。

Pytorch 如何实现常用正则化

对于上述的ResNet网络,模块越在后面被drop掉的概率越大。

作者直觉上认为前期提取的低阶特征会被用于后面的层。

第一个模块保留的概率为1,之后保留概率随着深度线性递减。

对一个模块的drop函数可以采用如下的方式实现:

def drop_connect(inputs, p, training):
    """ Drop connect. """
    if not training: return inputs # 测试阶段
    batch_size = inputs.shape[0]
    keep_prob = 1 - p
    random_tensor = keep_prob
    random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
    # 以样本为单位生成模块是否被drop的01向量
    binary_tensor = torch.floor(random_tensor) 
    # 因为越往后越容易被drop,所以没有被drop的值就要通过除keep_prob来放大
    output = inputs / keep_prob * binary_tensor
    return output

在Pytorch建立的Module类中,具有forward函数

可以在forward函数中进行drop:

def forward(self, x):
 x=...
 if stride == 1 and in_planes == out_planes:
        if drop_connect_rate:
            x = drop_connect(x, p=drop_connect_rate, training=self.training)
        x = x + inputs  # skip connection
    return x

主函数:

for idx, block in enumerate(self._blocks):
    drop_connect_rate = self._global_params.drop_connect_rate
    if drop_connect_rate:
        drop_connect_rate *= float(idx) / len(self._blocks)
    x = block(x, drop_connect_rate=drop_connect_rate)

补充:pytorch中的L2正则化实现方法

搭建神经网络时需要使用L2正则化等操作来防止过拟合,而pytorch不像TensorFlow能在任意卷积函数中添加L2正则化的超参,那怎么在pytorch中实现L2正则化呢?

方法如下:超级简单!

optimizer = torch.optim.Adam(net.parameters(), lr=0.001, weight_decay=5.0)

torch.optim.Adam()参数中的 weight_decay=5.0 即为L2正则化(只是pytorch换了名字),其数值即为L2正则化的惩罚系数,一般设置为1、5、10(根据需要设置,默认为0,不使用L2正则化)。

注:

pytorch中的优化函数L2正则化默认对所有网络参数进行惩罚,且只能实现L2正则化,如需只惩罚指定网络层参数或采用L1正则化,只能自己定义。。。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python的tkinter布局之简单的聊天窗口实现方法
Sep 03 Python
python实现根据窗口标题调用窗口的方法
Mar 13 Python
用C++封装MySQL的API的教程
May 06 Python
python学习笔记之调用eval函数出现invalid syntax错误问题
Oct 18 Python
Python抓取电影天堂电影信息的代码
Apr 07 Python
有趣的python小程序分享
Dec 05 Python
Python文本处理之按行处理大文件的方法
Apr 09 Python
详谈pandas中agg函数和apply函数的区别
Apr 20 Python
对python中Matplotlib的坐标轴的坐标区间的设定实例讲解
May 25 Python
对python中assert、isinstance的用法详解
Nov 27 Python
Python流程控制常用工具详解
Feb 24 Python
Python极值整数的边界探讨分析
Sep 15 Python
PyTorch 实现L2正则化以及Dropout的操作
Python开发之QT解决无边框界面拖动卡屏问题(附带源码)
pytorch 实现在测试的时候启用dropout
使用Python脚本对GiteePages进行一键部署的使用说明
教你使用Python pypinyin库实现汉字转拼音
基于tensorflow权重文件的解读
May 26 #Python
解决Python字典查找报Keyerror的问题
You might like
php正则修正符用法实例详解
2016/12/29 PHP
javascript 装载iframe子页面,自适应高度
2009/03/20 Javascript
Js从头学起(基本数据类型和引用类型的参数传递详细分析)
2012/02/16 Javascript
jQuery循环动画与获取组件尺寸的方法
2015/02/02 Javascript
JS实现左右拖动改变内容显示区域大小的方法
2015/10/13 Javascript
jsonp跨域请求数据实现手机号码查询实例分析
2015/12/12 Javascript
轮播的简单实现方法
2016/07/28 Javascript
AngularJS入门教程之静态模板详解
2016/08/18 Javascript
Jquery组件easyUi实现手风琴(折叠面板)示例
2016/08/23 Javascript
js内置对象处理_打印学生成绩单的简单实现
2016/09/24 Javascript
Bootstrap企业网站实战项目4
2016/10/14 Javascript
详解Angularjs在控制器(controller.js)中使用过滤器($filter)格式化日期/时间实例
2017/02/17 Javascript
基于Bootstrap漂亮简洁的CSS3价格表(附源码下载)
2017/02/28 Javascript
利用angularjs1.4制作的简易滑动门效果
2017/02/28 Javascript
jquery实现图片平滑滚动详解
2017/03/22 jQuery
详解利用 Express 托管静态文件的方法
2017/09/18 Javascript
JavaScript设计模式之工厂模式和抽象工厂模式定义与用法分析
2018/07/26 Javascript
微信小程序自定义底部导航带跳转功能
2018/11/27 Javascript
Vue 数组和对象更新,但是页面没有刷新的解决方式
2019/11/09 Javascript
d3.js实现图形缩放平移
2019/12/19 Javascript
vue+element获取el-table某行的下标,根据下标操作数组对象方式
2020/08/07 Javascript
Python的Django框架安装全攻略
2015/07/15 Python
Python Numpy 数组的初始化和基本操作
2018/03/13 Python
Django 后台获取文件列表 InMemoryUploadedFile的例子
2019/08/07 Python
python中执行smtplib失败的处理方法
2020/07/01 Python
CSS3解决移动页面上点击链接触发色块的问题
2016/06/03 HTML / CSS
临床医学专业毕业生的自我评价
2013/10/17 职场文书
如何写一份好的自荐信
2014/01/02 职场文书
你的创业计划书怎样才能打动风投
2014/02/06 职场文书
英语教师岗位职责
2014/03/16 职场文书
纪检干部现实表现材料
2014/08/21 职场文书
幼儿园园务工作总结2015
2015/05/18 职场文书
2016年端午节红领巾广播稿
2015/12/18 职场文书
JavaScript严格模式不支持八进制的问题讲解
2021/11/07 Javascript
《辉夜大小姐想让我告白》第三季正式预告
2022/03/20 日漫
python 单机五子棋对战游戏
2022/04/28 Python