Pytorch中的数据集划分&正则化方法


Posted in Python onMay 27, 2021

1.训练集&验证集&测试集

训练集:训练数据

验证集:验证不同算法(比如利用网格搜索对超参数进行调整等),检验哪种更有效

测试集:正确评估分类器的性能

正常流程:验证集会记录每个时间戳的参数,在加载test数据前会加载那个最好的参数,再来评估。比方说训练完6000个epoch后,发现在第3520个epoch的validation表现最好,测试时会加载第3520个epoch的参数。

import  torch
import  torch.nn as nn
import  torch.nn.functional as F
import  torch.optim as optim
from    torchvision import datasets, transforms
#超参数
batch_size=200
learning_rate=0.01
epochs=10
#获取训练数据
train_db = datasets.MNIST('../data', train=True, download=True,   #train=True则得到的是训练集
                   transform=transforms.Compose([                 #transform进行数据预处理
                       transforms.ToTensor(),                     #转成Tensor类型的数据
                       transforms.Normalize((0.1307,), (0.3081,)) #进行数据标准化(减去均值除以方差)
                   ]))
#DataLoader把训练数据分成多个小组,此函数每次抛出一组数据。直至把所有的数据都抛出。就是做一个数据的初始化
train_loader = torch.utils.data.DataLoader(train_db, batch_size=batch_size, shuffle=True)
#获取测试数据
test_db = datasets.MNIST('../data', train=False,
                   transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                   ]))
test_loader = torch.utils.data.DataLoader(test_db, batch_size=batch_size, shuffle=True)
#将训练集拆分成训练集和验证集
print('train:', len(train_db), 'test:', len(test_db))                              #train: 60000 test: 10000
train_db, val_db = torch.utils.data.random_split(train_db, [50000, 10000])
print('db1:', len(train_db), 'db2:', len(val_db))                                  #db1: 50000 db2: 10000
train_loader = torch.utils.data.DataLoader(train_db, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_db, batch_size=batch_size, shuffle=True)
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.model = nn.Sequential(         #定义网络的每一层,
            nn.Linear(784, 200),
            nn.ReLU(inplace=True),
            nn.Linear(200, 200),
            nn.ReLU(inplace=True),
            nn.Linear(200, 10),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        x = self.model(x)
        return x
net = MLP()
#定义sgd优化器,指明优化参数、学习率,net.parameters()得到这个类所定义的网络的参数[[w1,b1,w2,b2,...]
optimizer = optim.SGD(net.parameters(), lr=learning_rate)
criteon = nn.CrossEntropyLoss()
for epoch in range(epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1, 28*28)          #将二维的图片数据摊平[样本数,784]
        logits = net(data)                   #前向传播
        loss = criteon(logits, target)       #nn.CrossEntropyLoss()自带Softmax
        optimizer.zero_grad()                #梯度信息清空
        loss.backward()                      #反向传播获取梯度
        optimizer.step()                     #优化器更新
        if batch_idx % 100 == 0:             #每100个batch输出一次信息
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))
    #验证集用来检测训练是否过拟合
    val_loss = 0
    correct = 0
    for data, target in val_loader:
        data = data.view(-1, 28 * 28)
        logits = net(data)
        val_loss += criteon(logits, target).item()
        pred = logits.data.max(dim=1)[1]
        correct += pred.eq(target.data).sum()
    val_loss /= len(val_loader.dataset)
    print('\nVAL set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        val_loss, correct, len(val_loader.dataset),
        100. * correct / len(val_loader.dataset)))
#测试集用来评估
test_loss = 0
correct = 0                                         #correct记录正确分类的样本数
for data, target in test_loader:
    data = data.view(-1, 28 * 28)
    logits = net(data)
    test_loss += criteon(logits, target).item()     #其实就是criteon(logits, target)的值,标量
    pred = logits.data.max(dim=1)[1]                #也可以写成pred=logits.argmax(dim=1)
    correct += pred.eq(target.data).sum()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

2.正则化

正则化可以解决过拟合问题。

2.1L2范数(更常用)

在定义优化器的时候设定weigth_decay,即L2范数前面的λ参数。

optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, weight_decay=0.01)

2.2L1范数

Pytorch没有直接可以调用的方法,实现如下:

Pytorch中的数据集划分&正则化方法

3.动量(Momentum)

Adam优化器内置了momentum,SGD需要手动设置。

optimizer = torch.optim.SGD(model.parameters(), args=lr, momentum=args.momentum, weight_decay=args.weight_decay)

4.学习率衰减

torch.optim.lr_scheduler 中提供了基于多种epoch数目调整学习率的方法。

4.1torch.optim.lr_scheduler.ReduceLROnPlateau:基于测量指标对学习率进行动态的下降

torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08)

训练过程中,optimizer会把learning rate 交给scheduler管理,当指标(比如loss)连续patience次数还没有改进时,需要降低学习率,factor为每次下降的比例。

scheduler.step(loss_val)每调用一次就会监听一次loss_val。

Pytorch中的数据集划分&正则化方法

4.2torch.optim.lr_scheduler.StepLR:基于epoch

torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1, last_epoch=-1)

当epoch每过stop_size时,学习率都变为初始学习率的gamma倍。

Pytorch中的数据集划分&正则化方法

5.提前停止(防止overfitting)

基于经验值。

6.Dropout随机失活

遍历每一层,设置消除神经网络中的节点概率,得到精简后的一个样本。

torch.nn.Dropout(p=dropout_prob)

p表示的示的是删除节点数的比例(Tip:tensorflow中keep_prob表示保留节点数的比例,不要混淆)

Pytorch中的数据集划分&正则化方法

测试阶段无需使用dropout,所以在train之前执行net_dropped.train()相当于启用dropout,测试之前执行net_dropped.eval()相当于不启用dropout。

Pytorch中的数据集划分&正则化方法

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python完全新手教程
Feb 08 Python
Python基于pillow判断图片完整性的方法
Sep 18 Python
详解python中 os._exit() 和 sys.exit(), exit(0)和exit(1) 的用法和区别
Jun 23 Python
Python实现好友全头像的拼接实例(推荐)
Jun 24 Python
python列表生成式与列表生成器的使用
Feb 23 Python
Python读取properties配置文件操作示例
Mar 29 Python
Scrapy-Redis结合POST请求获取数据的方法示例
May 07 Python
通过pycharm使用git的步骤(图文详解)
Jun 13 Python
python 装饰器功能与用法案例详解
Mar 06 Python
python 实现仿微信聊天时间格式化显示的代码
Apr 17 Python
基于Python实现2种反转链表方法代码实例
Jul 06 Python
Python 3.9的到来到底是意味着什么
Oct 14 Python
Pytorch 如何实现常用正则化
PyTorch 实现L2正则化以及Dropout的操作
Python开发之QT解决无边框界面拖动卡屏问题(附带源码)
pytorch 实现在测试的时候启用dropout
使用Python脚本对GiteePages进行一键部署的使用说明
教你使用Python pypinyin库实现汉字转拼音
基于tensorflow权重文件的解读
May 26 #Python
You might like
php导出word格式数据的代码实例
2013/11/25 PHP
php弹出对话框实现重定向代码
2014/01/23 PHP
使用PHP similar text计算两个字符串相似度
2015/11/06 PHP
php简单实现多语言切换的方法
2016/05/09 PHP
使用Git实现Laravel项目的自动化部署
2019/11/24 PHP
asp(javascript)全角半角转换代码 dbc2sbc
2009/08/06 Javascript
ajax更新数据后,jquery、jq失效问题
2011/03/16 Javascript
js为鼠标添加右击事件防止默认的右击菜单弹出
2013/07/29 Javascript
file控件选择上传文件确定后触发的js事件是哪个
2014/03/17 Javascript
理解javascript中的MVC模式
2016/01/28 Javascript
jQuery实现只允许输入数字和小数点的方法
2016/03/02 Javascript
AngularJS使用ocLazyLoad实现js延迟加载
2017/07/05 Javascript
jQuery基于cookie实现换肤功能实例
2017/10/14 jQuery
vue弹窗组件使用方法
2018/04/28 Javascript
vue实现键盘输入支付密码功能
2018/08/18 Javascript
vue项目初始化到登录login页面的示例
2019/10/31 Javascript
JavaScript命令模式原理与用法实例详解
2020/03/10 Javascript
如何正确解决VuePress本地访问出现资源报错404的问题
2020/12/03 Vue.js
下载糗事百科的内容_python版
2008/12/07 Python
python代码制作configure文件示例
2014/07/28 Python
python中去空格函数的用法
2014/08/21 Python
Python实现把回车符\r\n转换成\n
2015/04/23 Python
开始着手第一个Django项目
2015/07/15 Python
完美解决python遍历删除字典里值为空的元素报错问题
2016/09/11 Python
pycharm执行python时,填写参数的方法
2018/10/29 Python
python Django 创建应用过程图示详解
2019/07/29 Python
Python3.7将普通图片(png)转换为SVG图片格式(网站logo图标)动起来
2020/04/21 Python
python求解汉诺塔游戏
2020/07/09 Python
浅谈Python 命令行参数argparse写入图片路径操作
2020/07/12 Python
手摸手教你用canvas实现给图片添加平铺水印的实现
2019/08/20 HTML / CSS
机械专业个人求职自荐信格式
2013/09/21 职场文书
人力资源主管的岗位职责
2014/03/15 职场文书
感恩之星事迹材料
2014/05/03 职场文书
教育专业毕业生推荐信
2014/07/10 职场文书
涨价通知
2015/04/23 职场文书
副校长2015年教育教学工作总结
2015/07/27 职场文书