Pytorch中的数据集划分&正则化方法


Posted in Python onMay 27, 2021

1.训练集&验证集&测试集

训练集:训练数据

验证集:验证不同算法(比如利用网格搜索对超参数进行调整等),检验哪种更有效

测试集:正确评估分类器的性能

正常流程:验证集会记录每个时间戳的参数,在加载test数据前会加载那个最好的参数,再来评估。比方说训练完6000个epoch后,发现在第3520个epoch的validation表现最好,测试时会加载第3520个epoch的参数。

import  torch
import  torch.nn as nn
import  torch.nn.functional as F
import  torch.optim as optim
from    torchvision import datasets, transforms
#超参数
batch_size=200
learning_rate=0.01
epochs=10
#获取训练数据
train_db = datasets.MNIST('../data', train=True, download=True,   #train=True则得到的是训练集
                   transform=transforms.Compose([                 #transform进行数据预处理
                       transforms.ToTensor(),                     #转成Tensor类型的数据
                       transforms.Normalize((0.1307,), (0.3081,)) #进行数据标准化(减去均值除以方差)
                   ]))
#DataLoader把训练数据分成多个小组,此函数每次抛出一组数据。直至把所有的数据都抛出。就是做一个数据的初始化
train_loader = torch.utils.data.DataLoader(train_db, batch_size=batch_size, shuffle=True)
#获取测试数据
test_db = datasets.MNIST('../data', train=False,
                   transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                   ]))
test_loader = torch.utils.data.DataLoader(test_db, batch_size=batch_size, shuffle=True)
#将训练集拆分成训练集和验证集
print('train:', len(train_db), 'test:', len(test_db))                              #train: 60000 test: 10000
train_db, val_db = torch.utils.data.random_split(train_db, [50000, 10000])
print('db1:', len(train_db), 'db2:', len(val_db))                                  #db1: 50000 db2: 10000
train_loader = torch.utils.data.DataLoader(train_db, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_db, batch_size=batch_size, shuffle=True)
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.model = nn.Sequential(         #定义网络的每一层,
            nn.Linear(784, 200),
            nn.ReLU(inplace=True),
            nn.Linear(200, 200),
            nn.ReLU(inplace=True),
            nn.Linear(200, 10),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        x = self.model(x)
        return x
net = MLP()
#定义sgd优化器,指明优化参数、学习率,net.parameters()得到这个类所定义的网络的参数[[w1,b1,w2,b2,...]
optimizer = optim.SGD(net.parameters(), lr=learning_rate)
criteon = nn.CrossEntropyLoss()
for epoch in range(epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1, 28*28)          #将二维的图片数据摊平[样本数,784]
        logits = net(data)                   #前向传播
        loss = criteon(logits, target)       #nn.CrossEntropyLoss()自带Softmax
        optimizer.zero_grad()                #梯度信息清空
        loss.backward()                      #反向传播获取梯度
        optimizer.step()                     #优化器更新
        if batch_idx % 100 == 0:             #每100个batch输出一次信息
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))
    #验证集用来检测训练是否过拟合
    val_loss = 0
    correct = 0
    for data, target in val_loader:
        data = data.view(-1, 28 * 28)
        logits = net(data)
        val_loss += criteon(logits, target).item()
        pred = logits.data.max(dim=1)[1]
        correct += pred.eq(target.data).sum()
    val_loss /= len(val_loader.dataset)
    print('\nVAL set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        val_loss, correct, len(val_loader.dataset),
        100. * correct / len(val_loader.dataset)))
#测试集用来评估
test_loss = 0
correct = 0                                         #correct记录正确分类的样本数
for data, target in test_loader:
    data = data.view(-1, 28 * 28)
    logits = net(data)
    test_loss += criteon(logits, target).item()     #其实就是criteon(logits, target)的值,标量
    pred = logits.data.max(dim=1)[1]                #也可以写成pred=logits.argmax(dim=1)
    correct += pred.eq(target.data).sum()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

2.正则化

正则化可以解决过拟合问题。

2.1L2范数(更常用)

在定义优化器的时候设定weigth_decay,即L2范数前面的λ参数。

optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, weight_decay=0.01)

2.2L1范数

Pytorch没有直接可以调用的方法,实现如下:

Pytorch中的数据集划分&正则化方法

3.动量(Momentum)

Adam优化器内置了momentum,SGD需要手动设置。

optimizer = torch.optim.SGD(model.parameters(), args=lr, momentum=args.momentum, weight_decay=args.weight_decay)

4.学习率衰减

torch.optim.lr_scheduler 中提供了基于多种epoch数目调整学习率的方法。

4.1torch.optim.lr_scheduler.ReduceLROnPlateau:基于测量指标对学习率进行动态的下降

torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08)

训练过程中,optimizer会把learning rate 交给scheduler管理,当指标(比如loss)连续patience次数还没有改进时,需要降低学习率,factor为每次下降的比例。

scheduler.step(loss_val)每调用一次就会监听一次loss_val。

Pytorch中的数据集划分&正则化方法

4.2torch.optim.lr_scheduler.StepLR:基于epoch

torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1, last_epoch=-1)

当epoch每过stop_size时,学习率都变为初始学习率的gamma倍。

Pytorch中的数据集划分&正则化方法

5.提前停止(防止overfitting)

基于经验值。

6.Dropout随机失活

遍历每一层,设置消除神经网络中的节点概率,得到精简后的一个样本。

torch.nn.Dropout(p=dropout_prob)

p表示的示的是删除节点数的比例(Tip:tensorflow中keep_prob表示保留节点数的比例,不要混淆)

Pytorch中的数据集划分&正则化方法

测试阶段无需使用dropout,所以在train之前执行net_dropped.train()相当于启用dropout,测试之前执行net_dropped.eval()相当于不启用dropout。

Pytorch中的数据集划分&正则化方法

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
gearman的安装启动及python API使用实例
Jul 08 Python
Python Web框架Flask中使用七牛云存储实例
Feb 08 Python
简单谈谈python中的多进程
Nov 06 Python
答题辅助python代码实现
Jan 16 Python
python 爬虫 批量获取代理ip的实例代码
May 22 Python
Django 大文件下载实现过程解析
Aug 01 Python
Django 框架模型操作入门教程
Nov 05 Python
python计算波峰波谷值的方法(极值点)
Feb 18 Python
详解python logging日志传输
Jul 01 Python
Expected conditions模块使用方法汇总代码解析
Aug 13 Python
梳理总结Python开发中需要摒弃的18个坏习惯
Jan 22 Python
python开发制作好看的时钟效果
May 02 Python
Pytorch 如何实现常用正则化
PyTorch 实现L2正则化以及Dropout的操作
Python开发之QT解决无边框界面拖动卡屏问题(附带源码)
pytorch 实现在测试的时候启用dropout
使用Python脚本对GiteePages进行一键部署的使用说明
教你使用Python pypinyin库实现汉字转拼音
基于tensorflow权重文件的解读
May 26 #Python
You might like
解析php安全性问题中的:Null 字符问题
2013/06/21 PHP
PHP下通过QRCode类库创建中间带网站LOGO的二维码
2014/07/12 PHP
PHP实现的同步推荐操作API接口案例分析
2016/11/30 PHP
PHP实现的DES加密解密类定义与用法示例
2020/11/02 PHP
求得div 下 img的src地址的js代码
2007/02/28 Javascript
javascript实现二分查找法实现代码
2007/11/12 Javascript
javascript使用中为什么10..toString()正常而10.toString()出错呢
2013/01/11 Javascript
GridView中获取被点击行中的DropDownList和TextBox中的值
2013/07/18 Javascript
js解决select下拉选不中问题
2014/10/14 Javascript
jQuery实现用户输入自动完成功能
2017/02/13 Javascript
vue省市区三联动下拉选择组件的实现
2017/04/28 Javascript
浅谈如何使用 webpack 优化资源
2017/10/20 Javascript
JavaScript面向对象编程小游戏---贪吃蛇代码实例
2019/05/15 Javascript
bootstrap-table+treegrid实现树形表格
2019/07/26 Javascript
JS如何实现网站中PC端和手机端自动识别并跳转对应的代码
2020/01/08 Javascript
[45:46]2014 DOTA2国际邀请赛中国区预选赛5.21 HGT VS DT
2014/05/23 DOTA
[56:41]2018DOTA2亚洲邀请赛 3.31 小组赛 A组 Newbee vs OG
2018/04/01 DOTA
python实现带验证码网站的自动登陆实现代码
2015/01/12 Python
Python命令行参数解析模块optparse使用实例
2015/04/13 Python
python绘制随机网络图形示例
2019/11/21 Python
python闭包、深浅拷贝、垃圾回收、with语句知识点汇总
2020/03/11 Python
Django模型中字段属性choice使用说明
2020/03/30 Python
Python如何读取、写入JSON数据
2020/07/28 Python
Python3 + Appium + 安卓模拟器实现APP自动化测试并生成测试报告
2021/01/27 Python
Cinque网上商店:德国服装品牌
2019/03/17 全球购物
英国在线药房和在线医生:LloydsPharmacy
2019/10/21 全球购物
夏季奶茶店创业计划书
2014/01/16 职场文书
技校个人求职信范文
2014/01/25 职场文书
员工年终自我评价
2014/09/14 职场文书
学生夜不归宿检讨书
2014/09/23 职场文书
2014年高中生自我评价范文
2014/09/26 职场文书
质监局领导班子对照检查材料思想汇报
2014/09/27 职场文书
nginx反向代理配置去除前缀案例教程
2021/07/26 Servers
用Python实现屏幕截图详解
2022/01/22 Python
python中字符串String及其常见操作指南(方法、函数)
2022/04/06 Python
css3属性选择器 “~”(波浪号) “,”(逗号) “+”(加号)和 “>”(大于号)
2022/04/19 HTML / CSS