Pytorch 如何实现常用正则化


Posted in Python onMay 27, 2021

Stochastic Depth

论文:Deep Networks with Stochastic Depth

本文的正则化针对于ResNet中的残差结构,类似于dropout的原理,训练时对模块进行随机的删除,从而提升模型的泛化能力。

Pytorch 如何实现常用正则化

对于上述的ResNet网络,模块越在后面被drop掉的概率越大。

作者直觉上认为前期提取的低阶特征会被用于后面的层。

第一个模块保留的概率为1,之后保留概率随着深度线性递减。

对一个模块的drop函数可以采用如下的方式实现:

def drop_connect(inputs, p, training):
    """ Drop connect. """
    if not training: return inputs # 测试阶段
    batch_size = inputs.shape[0]
    keep_prob = 1 - p
    random_tensor = keep_prob
    random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
    # 以样本为单位生成模块是否被drop的01向量
    binary_tensor = torch.floor(random_tensor) 
    # 因为越往后越容易被drop,所以没有被drop的值就要通过除keep_prob来放大
    output = inputs / keep_prob * binary_tensor
    return output

在Pytorch建立的Module类中,具有forward函数

可以在forward函数中进行drop:

def forward(self, x):
 x=...
 if stride == 1 and in_planes == out_planes:
        if drop_connect_rate:
            x = drop_connect(x, p=drop_connect_rate, training=self.training)
        x = x + inputs  # skip connection
    return x

主函数:

for idx, block in enumerate(self._blocks):
    drop_connect_rate = self._global_params.drop_connect_rate
    if drop_connect_rate:
        drop_connect_rate *= float(idx) / len(self._blocks)
    x = block(x, drop_connect_rate=drop_connect_rate)

补充:pytorch中的L2正则化实现方法

搭建神经网络时需要使用L2正则化等操作来防止过拟合,而pytorch不像TensorFlow能在任意卷积函数中添加L2正则化的超参,那怎么在pytorch中实现L2正则化呢?

方法如下:超级简单!

optimizer = torch.optim.Adam(net.parameters(), lr=0.001, weight_decay=5.0)

torch.optim.Adam()参数中的 weight_decay=5.0 即为L2正则化(只是pytorch换了名字),其数值即为L2正则化的惩罚系数,一般设置为1、5、10(根据需要设置,默认为0,不使用L2正则化)。

注:

pytorch中的优化函数L2正则化默认对所有网络参数进行惩罚,且只能实现L2正则化,如需只惩罚指定网络层参数或采用L1正则化,只能自己定义。。。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python基础入门详解(文件输入/输出 内建类型 字典操作使用方法)
Dec 08 Python
常见的在Python中实现单例模式的三种方法
Apr 08 Python
Django框架下在URLconf中指定视图缓存的方法
Jul 23 Python
python搭建虚拟环境的步骤详解
Sep 27 Python
python3+PyQt5实现拖放功能
Apr 24 Python
python微信公众号之关注公众号自动回复
Oct 25 Python
Python API 自动化实战详解(纯代码)
Jun 11 Python
python绘制玫瑰的实现代码
Mar 02 Python
django执行数据库查询之后实现返回的结果集转json
Mar 31 Python
python怎么提高计算速度
Jun 11 Python
python3.7添加dlib模块的方法
Jul 01 Python
在Python中如何使用yield
Jun 07 Python
PyTorch 实现L2正则化以及Dropout的操作
Python开发之QT解决无边框界面拖动卡屏问题(附带源码)
pytorch 实现在测试的时候启用dropout
使用Python脚本对GiteePages进行一键部署的使用说明
教你使用Python pypinyin库实现汉字转拼音
基于tensorflow权重文件的解读
May 26 #Python
解决Python字典查找报Keyerror的问题
You might like
用php实现百度网盘图片直链的代码分享
2012/11/01 PHP
mysqli_set_charset和SET NAMES使用抉择及优劣分析
2013/01/13 PHP
将二维数组转为一维数组的2种方法
2014/05/26 PHP
Smarty foreach控制循环次数的一些方法
2015/07/01 PHP
php封装的图片(缩略图)处理类完整实例
2016/10/19 PHP
jqGrid增加时--判断开始日期与结束日期(实例解析)
2013/11/08 Javascript
jQuery插件formValidator实现表单验证
2016/05/23 Javascript
简单的vue-resourse获取json并应用到模板示例
2017/02/10 Javascript
JS图片预加载插件详解
2017/06/21 Javascript
vuejs父子组件之间数据交互详解
2017/08/09 Javascript
vue router 配置路由的方法
2018/07/26 Javascript
vue  自定义组件实现通讯录功能
2018/09/30 Javascript
vue cli使用融云实现聊天功能的实例代码
2019/04/19 Javascript
layui实现二维码弹窗、并下载到本地的方法
2019/09/25 Javascript
Vue 实现分页与输入框关键字筛选功能
2020/01/02 Javascript
react基本安装与测试示例
2020/04/27 Javascript
在Python中操作字符串之startswith()方法的使用
2015/05/20 Python
常见的python正则用法实例讲解
2016/06/21 Python
Python利用递归和walk()遍历目录文件的方法示例
2017/07/14 Python
python删除不需要的python文件方法
2018/04/24 Python
无法使用pip命令安装python第三方库的原因及解决方法
2018/06/12 Python
Python实现随机漫步功能
2018/07/09 Python
5分钟 Pipenv 上手指南
2018/12/20 Python
新年快乐! python实现绚烂的烟花绽放效果
2019/01/30 Python
python使用celery实现异步任务执行的例子
2019/08/28 Python
python爬虫利用代理池更换IP的方法步骤
2021/02/21 Python
英国假睫毛购买网站:FalseEyelashes.co.uk
2018/05/23 全球购物
信用卡工资证明格式
2014/09/13 职场文书
四风批评与自我批评范文
2014/10/14 职场文书
教师节倡议书2015
2015/04/27 职场文书
2015年煤矿工作总结
2015/04/28 职场文书
2015年反腐倡廉工作总结
2015/05/14 职场文书
民间借贷借条如何写
2015/05/26 职场文书
火烧圆明园观后感
2015/06/03 职场文书
贫民窟的百万富翁观后感
2015/06/09 职场文书
Spring boot admin 服务监控利器详解
2022/08/05 Java/Android