Pytorch中的数据集划分&正则化方法


Posted in Python onMay 27, 2021

1.训练集&验证集&测试集

训练集:训练数据

验证集:验证不同算法(比如利用网格搜索对超参数进行调整等),检验哪种更有效

测试集:正确评估分类器的性能

正常流程:验证集会记录每个时间戳的参数,在加载test数据前会加载那个最好的参数,再来评估。比方说训练完6000个epoch后,发现在第3520个epoch的validation表现最好,测试时会加载第3520个epoch的参数。

import  torch
import  torch.nn as nn
import  torch.nn.functional as F
import  torch.optim as optim
from    torchvision import datasets, transforms
#超参数
batch_size=200
learning_rate=0.01
epochs=10
#获取训练数据
train_db = datasets.MNIST('../data', train=True, download=True,   #train=True则得到的是训练集
                   transform=transforms.Compose([                 #transform进行数据预处理
                       transforms.ToTensor(),                     #转成Tensor类型的数据
                       transforms.Normalize((0.1307,), (0.3081,)) #进行数据标准化(减去均值除以方差)
                   ]))
#DataLoader把训练数据分成多个小组,此函数每次抛出一组数据。直至把所有的数据都抛出。就是做一个数据的初始化
train_loader = torch.utils.data.DataLoader(train_db, batch_size=batch_size, shuffle=True)
#获取测试数据
test_db = datasets.MNIST('../data', train=False,
                   transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                   ]))
test_loader = torch.utils.data.DataLoader(test_db, batch_size=batch_size, shuffle=True)
#将训练集拆分成训练集和验证集
print('train:', len(train_db), 'test:', len(test_db))                              #train: 60000 test: 10000
train_db, val_db = torch.utils.data.random_split(train_db, [50000, 10000])
print('db1:', len(train_db), 'db2:', len(val_db))                                  #db1: 50000 db2: 10000
train_loader = torch.utils.data.DataLoader(train_db, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_db, batch_size=batch_size, shuffle=True)
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.model = nn.Sequential(         #定义网络的每一层,
            nn.Linear(784, 200),
            nn.ReLU(inplace=True),
            nn.Linear(200, 200),
            nn.ReLU(inplace=True),
            nn.Linear(200, 10),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        x = self.model(x)
        return x
net = MLP()
#定义sgd优化器,指明优化参数、学习率,net.parameters()得到这个类所定义的网络的参数[[w1,b1,w2,b2,...]
optimizer = optim.SGD(net.parameters(), lr=learning_rate)
criteon = nn.CrossEntropyLoss()
for epoch in range(epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1, 28*28)          #将二维的图片数据摊平[样本数,784]
        logits = net(data)                   #前向传播
        loss = criteon(logits, target)       #nn.CrossEntropyLoss()自带Softmax
        optimizer.zero_grad()                #梯度信息清空
        loss.backward()                      #反向传播获取梯度
        optimizer.step()                     #优化器更新
        if batch_idx % 100 == 0:             #每100个batch输出一次信息
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))
    #验证集用来检测训练是否过拟合
    val_loss = 0
    correct = 0
    for data, target in val_loader:
        data = data.view(-1, 28 * 28)
        logits = net(data)
        val_loss += criteon(logits, target).item()
        pred = logits.data.max(dim=1)[1]
        correct += pred.eq(target.data).sum()
    val_loss /= len(val_loader.dataset)
    print('\nVAL set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        val_loss, correct, len(val_loader.dataset),
        100. * correct / len(val_loader.dataset)))
#测试集用来评估
test_loss = 0
correct = 0                                         #correct记录正确分类的样本数
for data, target in test_loader:
    data = data.view(-1, 28 * 28)
    logits = net(data)
    test_loss += criteon(logits, target).item()     #其实就是criteon(logits, target)的值,标量
    pred = logits.data.max(dim=1)[1]                #也可以写成pred=logits.argmax(dim=1)
    correct += pred.eq(target.data).sum()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

2.正则化

正则化可以解决过拟合问题。

2.1L2范数(更常用)

在定义优化器的时候设定weigth_decay,即L2范数前面的λ参数。

optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, weight_decay=0.01)

2.2L1范数

Pytorch没有直接可以调用的方法,实现如下:

Pytorch中的数据集划分&正则化方法

3.动量(Momentum)

Adam优化器内置了momentum,SGD需要手动设置。

optimizer = torch.optim.SGD(model.parameters(), args=lr, momentum=args.momentum, weight_decay=args.weight_decay)

4.学习率衰减

torch.optim.lr_scheduler 中提供了基于多种epoch数目调整学习率的方法。

4.1torch.optim.lr_scheduler.ReduceLROnPlateau:基于测量指标对学习率进行动态的下降

torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08)

训练过程中,optimizer会把learning rate 交给scheduler管理,当指标(比如loss)连续patience次数还没有改进时,需要降低学习率,factor为每次下降的比例。

scheduler.step(loss_val)每调用一次就会监听一次loss_val。

Pytorch中的数据集划分&正则化方法

4.2torch.optim.lr_scheduler.StepLR:基于epoch

torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1, last_epoch=-1)

当epoch每过stop_size时,学习率都变为初始学习率的gamma倍。

Pytorch中的数据集划分&正则化方法

5.提前停止(防止overfitting)

基于经验值。

6.Dropout随机失活

遍历每一层,设置消除神经网络中的节点概率,得到精简后的一个样本。

torch.nn.Dropout(p=dropout_prob)

p表示的示的是删除节点数的比例(Tip:tensorflow中keep_prob表示保留节点数的比例,不要混淆)

Pytorch中的数据集划分&正则化方法

测试阶段无需使用dropout,所以在train之前执行net_dropped.train()相当于启用dropout,测试之前执行net_dropped.eval()相当于不启用dropout。

Pytorch中的数据集划分&正则化方法

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python函数缺省值与引用学习笔记分享
Feb 10 Python
python获得图片base64编码示例
Jan 16 Python
深入理解Python 代码优化详解
Oct 27 Python
python3实现读取chrome浏览器cookie
Jun 19 Python
Python使用selenium实现网页用户名 密码 验证码自动登录功能
May 16 Python
在Python中将函数作为另一个函数的参数传入并调用的方法
Jan 22 Python
python redis连接 有序集合去重的代码
Aug 04 Python
python 实现二维列表转置
Dec 02 Python
Python面向对象之私有属性和私有方法应用案例分析
Dec 31 Python
Python通过VGG16模型实现图像风格转换操作详解
Jan 16 Python
Python使用monkey.patch_all()解决协程阻塞问题
Apr 15 Python
使用tensorflow 实现反向传播求导
May 26 Python
Pytorch 如何实现常用正则化
PyTorch 实现L2正则化以及Dropout的操作
Python开发之QT解决无边框界面拖动卡屏问题(附带源码)
pytorch 实现在测试的时候启用dropout
使用Python脚本对GiteePages进行一键部署的使用说明
教你使用Python pypinyin库实现汉字转拼音
基于tensorflow权重文件的解读
May 26 #Python
You might like
php表单提交问题的解决方法
2011/04/12 PHP
php 的加密函数 md5,crypt,base64_encode 等使用介绍
2012/04/09 PHP
PHP常量及变量区别原理详解
2020/08/14 PHP
JavaScript中令你抓狂的魔术变量
2006/11/30 Javascript
Javascript Select操作大集合
2009/05/26 Javascript
JScript 脚本实现文件下载 一般用于下载木马
2009/10/29 Javascript
JavaScript语言核心数据类型和变量使用介绍
2013/08/23 Javascript
js实现背景图片感应鼠标变化的方法
2015/02/28 Javascript
jQuery实现鼠标划过展示大图的方法
2015/03/09 Javascript
javascript移动开发中touch触摸事件详解
2016/03/18 Javascript
AngularJS 前台分页实现的示例代码
2018/06/07 Javascript
JavaScript设计模式之代理模式简单实例教程
2018/07/03 Javascript
关于微信小程序获取小程序码并接受buffer流保存为图片的方法
2019/06/07 Javascript
bootstrap 日期控件 datepicker被弹出框dialog覆盖的解决办法
2019/07/09 Javascript
js实现验证码功能
2020/07/24 Javascript
[02:52]DOTA2新手基础教程 米波
2014/01/21 DOTA
[43:47]完美世界DOTA2联赛PWL S3 LBZS vs Phoenix 第一场 12.09
2020/12/11 DOTA
Python学习资料
2007/02/08 Python
浅谈对yield的初步理解
2017/05/29 Python
CentOS 7下安装Python 3.5并与Python2.7兼容并存详解
2017/07/07 Python
对python3 一组数值的归一化处理方法详解
2018/07/11 Python
解决Python找不到ssl模块问题 No module named _ssl的方法
2019/04/29 Python
Django使用中间件解决前后端同源策略问题
2019/09/02 Python
Python 切分数组实例解析
2019/11/07 Python
Python箱型图处理离群点的例子
2019/12/09 Python
Python调用.NET库的方法步骤
2019/12/27 Python
python求前n个阶乘的和实例
2020/04/02 Python
移动端Html5中百度地图的点击事件
2019/01/31 HTML / CSS
app内嵌H5 webview 本地缓存问题的解决
2020/10/19 HTML / CSS
美国在线家装零售商:Build.com
2016/09/02 全球购物
七年级政治教学反思
2014/02/03 职场文书
设立有限责任公司出资协议书
2014/11/01 职场文书
2014年信息技术工作总结
2014/12/16 职场文书
先进个人评语大全
2015/01/04 职场文书
MySQL去除密码登录告警的方法
2022/04/20 MySQL
vue css 相对路径导入问题级踩坑记录
2022/06/05 Vue.js