使用 PyTorch 实现 MLP 并在 MNIST 数据集上验证方式


Posted in Python onJanuary 08, 2020

简介

这是深度学习课程的第一个实验,主要目的就是熟悉 Pytorch 框架。MLP 是多层感知器,我这次实现的是四层感知器,代码和思路参考了网上的很多文章。个人认为,感知器的代码大同小异,尤其是用 Pytorch 实现,除了层数和参数外,代码都很相似。

Pytorch 写神经网络的主要步骤主要有以下几步:

1 构建网络结构

2 加载数据集

3 训练神经网络(包括优化器的选择和 Loss 的计算)

4 测试神经网络

下面将从这四个方面介绍 Pytorch 搭建 MLP 的过程。

项目代码地址:lab1

过程

构建网络结构

神经网络最重要的就是搭建网络,第一步就是定义网络结构。我这里是创建了一个四层的感知器,参数是根据 MNIST 数据集设定的,网络结构如下:

# 建立一个四层感知机网络
class MLP(torch.nn.Module):  # 继承 torch 的 Module
  def __init__(self):
    super(MLP,self).__init__()  # 
    # 初始化三层神经网络 两个全连接的隐藏层,一个输出层
    self.fc1 = torch.nn.Linear(784,512) # 第一个隐含层 
    self.fc2 = torch.nn.Linear(512,128) # 第二个隐含层
    self.fc3 = torch.nn.Linear(128,10)  # 输出层
    
  def forward(self,din):
    # 前向传播, 输入值:din, 返回值 dout
    din = din.view(-1,28*28)    # 将一个多行的Tensor,拼接成一行
    dout = F.relu(self.fc1(din))  # 使用 relu 激活函数
    dout = F.relu(self.fc2(dout))
    dout = F.softmax(self.fc3(dout), dim=1) # 输出层使用 softmax 激活函数
    # 10个数字实际上是10个类别,输出是概率分布,最后选取概率最大的作为预测值输出
    return dout

网络结构其实很简单,设置了三层 Linear。隐含层激活函数使用 Relu; 输出层使用 Softmax。网上还有其他的结构使用了 droupout,我觉得入门的话有点高级,而且放在这里并没有什么用,搞得很麻烦还不能提高准确率。

加载数据集

第二步就是定义全局变量,并加载 MNIST 数据集:

# 定义全局变量
n_epochs = 10   # epoch 的数目
batch_size = 20 # 决定每次读取多少图片

# 定义训练集个测试集,如果找不到数据,就下载
train_data = datasets.MNIST(root = './data', train = True, download = True, transform = transforms.ToTensor())
test_data = datasets.MNIST(root = './data', train = True, download = True, transform = transforms.ToTensor())
# 创建加载器
train_loader = torch.utils.data.DataLoader(train_data, batch_size = batch_size, num_workers = 0)
test_loader = torch.utils.data.DataLoader(test_data, batch_size = batch_size, num_workers = 0)

这里参数很多,所以就有很多需要注意的地方了:

root 参数的文件夹即使不存在也没关系,会自动创建

transform 参数,如果不知道要对数据集进行什么变化,这里可自动忽略

batch_size 参数的大小决定了一次训练多少数据,相当于定义了每个 epoch 中反向传播的次数

num_workers 参数默认是 0,即不并行处理数据;我这里设置大于 0 的时候,总是报错,建议设成默认值

如果不理解 epoch 和 batch_size,可以上网查一下资料。(我刚开始学深度学习的时候也是不懂的)

训练神经网络

第三步就是训练网络了,代码如下:

# 训练神经网络
def train():
  # 定义损失函数和优化器
  lossfunc = torch.nn.CrossEntropyLoss()
  optimizer = torch.optim.SGD(params = model.parameters(), lr = 0.01)
  # 开始训练
  for epoch in range(n_epochs):
    train_loss = 0.0
    for data,target in train_loader:
      optimizer.zero_grad()  # 清空上一步的残余更新参数值
      output = model(data)  # 得到预测值
      loss = lossfunc(output,target) # 计算两者的误差
      loss.backward()     # 误差反向传播, 计算参数更新值
      optimizer.step()    # 将参数更新值施加到 net 的 parameters 上
      train_loss += loss.item()*data.size(0)
    train_loss = train_loss / len(train_loader.dataset)
    print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch + 1, train_loss))

训练之前要定义损失函数和优化器,这里其实有很多学问,但本文就不讲了,理论太多了。

训练过程就是两层 for 循环:外层是遍历训练集的次数;内层是每次的批次(batch)。最后,输出每个 epoch 的 loss。(每次训练的目的是使 loss 函数减小,以达到训练集上更高的准确率)

测试神经网络

最后,就是在测试集上进行测试,代码如下:

# 在数据集上测试神经网络
def test():
  correct = 0
  total = 0
  with torch.no_grad(): # 训练集中不需要反向传播
    for data in test_loader:
      images, labels = data
      outputs = model(images)
      _, predicted = torch.max(outputs.data, 1)
      total += labels.size(0)
      correct += (predicted == labels).sum().item()
  print('Accuracy of the network on the test images: %d %%' % (
    100 * correct / total))
  return 100.0 * correct / total

这个测试的代码是同学给我的,我觉得这个测试的代码特别好,很简洁,一直用的这个。

代码首先设置 torch.no_grad(),定义后面的代码不需要计算梯度,能够节省一些内存空间。然后,对测试集中的每个 batch 进行测试,统计总数和准确数,最后计算准确率并输出。

通常是选择边训练边测试的,这里先就按步骤一步一步来做。

有的测试代码前面要加上 model.eval(),表示这是训练状态。但这里不需要,如果没有 Batch Normalization 和 Dropout 方法,加和不加的效果是一样的。

完整代码

'''
系统环境: Windows10
Python版本: 3.7
PyTorch版本: 1.1.0
cuda: no
'''
import torch
import torch.nn.functional as F  # 激励函数的库
from torchvision import datasets
import torchvision.transforms as transforms
import numpy as np

# 定义全局变量
n_epochs = 10   # epoch 的数目
batch_size = 20 # 决定每次读取多少图片

# 定义训练集个测试集,如果找不到数据,就下载
train_data = datasets.MNIST(root = './data', train = True, download = True, transform = transforms.ToTensor())
test_data = datasets.MNIST(root = './data', train = True, download = True, transform = transforms.ToTensor())
# 创建加载器
train_loader = torch.utils.data.DataLoader(train_data, batch_size = batch_size, num_workers = 0)
test_loader = torch.utils.data.DataLoader(test_data, batch_size = batch_size, num_workers = 0)


# 建立一个四层感知机网络
class MLP(torch.nn.Module):  # 继承 torch 的 Module
  def __init__(self):
    super(MLP,self).__init__()  # 
    # 初始化三层神经网络 两个全连接的隐藏层,一个输出层
    self.fc1 = torch.nn.Linear(784,512) # 第一个隐含层 
    self.fc2 = torch.nn.Linear(512,128) # 第二个隐含层
    self.fc3 = torch.nn.Linear(128,10)  # 输出层
    
  def forward(self,din):
    # 前向传播, 输入值:din, 返回值 dout
    din = din.view(-1,28*28)    # 将一个多行的Tensor,拼接成一行
    dout = F.relu(self.fc1(din))  # 使用 relu 激活函数
    dout = F.relu(self.fc2(dout))
    dout = F.softmax(self.fc3(dout), dim=1) # 输出层使用 softmax 激活函数
    # 10个数字实际上是10个类别,输出是概率分布,最后选取概率最大的作为预测值输出
    return dout

# 训练神经网络
def train():
  #定义损失函数和优化器
  lossfunc = torch.nn.CrossEntropyLoss()
  optimizer = torch.optim.SGD(params = model.parameters(), lr = 0.01)
  # 开始训练
  for epoch in range(n_epochs):
    train_loss = 0.0
    for data,target in train_loader:
      optimizer.zero_grad()  # 清空上一步的残余更新参数值
      output = model(data)  # 得到预测值
      loss = lossfunc(output,target) # 计算两者的误差
      loss.backward()     # 误差反向传播, 计算参数更新值
      optimizer.step()    # 将参数更新值施加到 net 的 parameters 上
      train_loss += loss.item()*data.size(0)
    train_loss = train_loss / len(train_loader.dataset)
    print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch + 1, train_loss))
    # 每遍历一遍数据集,测试一下准确率
    test()

# 在数据集上测试神经网络
def test():
  correct = 0
  total = 0
  with torch.no_grad(): # 训练集中不需要反向传播
    for data in test_loader:
      images, labels = data
      outputs = model(images)
      _, predicted = torch.max(outputs.data, 1)
      total += labels.size(0)
      correct += (predicted == labels).sum().item()
  print('Accuracy of the network on the test images: %d %%' % (
    100 * correct / total))
  return 100.0 * correct / total

# 声明感知器网络
model = MLP()

if __name__ == '__main__':
  train()

10 个 epoch 的训练效果,最后能达到大约 85% 的准确率。可以适当增加 epoch,但代码里没有用 gpu 运行,可能会比较慢。

以上这篇使用 PyTorch 实现 MLP 并在 MNIST 数据集上验证方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现合并两个数组的方法
May 16 Python
使用anaconda的pip安装第三方python包的操作步骤
Jun 11 Python
Python3中正则模块re.compile、re.match及re.search函数用法详解
Jun 11 Python
使用Django连接Mysql数据库步骤
Jan 15 Python
python实现诗歌游戏(类继承)
Feb 26 Python
Python中list的交、并、差集获取方法示例
Aug 01 Python
Django应用程序入口WSGIHandler源码解析
Aug 05 Python
Python 类的魔法属性用法实例分析
Nov 21 Python
Flask和pyecharts实现动态数据可视化
Feb 26 Python
python 读取、写入txt文件的示例
Sep 27 Python
python如何调用百度识图api
Sep 29 Python
使用tensorflow 实现反向传播求导
May 26 Python
Pycharm 2020最新永久激活码(附最新激活码和插件)
Sep 17 #Python
将matplotlib绘图嵌入pyqt的方法示例
Jan 08 #Python
pyinstaller还原python代码过程图解
Jan 08 #Python
python Tensor和Array对比分析
Jan 08 #Python
Pycharm小白级简单使用教程
Jan 08 #Python
python如何实现不可变字典inmutabledict
Jan 08 #Python
PyQt5 closeEvent关闭事件退出提示框原理解析
Jan 08 #Python
You might like
一个捕获函数输出的函数
2007/02/14 PHP
支持php4、php5的mysql数据库操作类
2008/01/10 PHP
详解PHP中websocket的使用方法
2016/09/15 PHP
Yii框架where查询用法实例分析
2019/10/22 PHP
初试jQuery EasyUI 使用介绍
2010/04/01 Javascript
基于jquery的仿百度的鼠标移入图片抖动效果
2010/09/17 Javascript
JavaScript如何从listbox里同时删除多个项目
2013/10/12 Javascript
Js判断参数(String,Array,Object)是否为undefined或者值为空
2013/11/04 Javascript
jquery操作checkbox实现全选和取消全选
2014/05/02 Javascript
全面解析JavaScript的Backbone.js框架中的Router路由
2016/05/05 Javascript
BootStrap 动态添加验证项和取消验证项的实现方法
2016/09/28 Javascript
js实现文字跑马灯效果
2017/02/23 Javascript
Angular.JS中的this指向详解
2017/05/17 Javascript
BootStrap入门学习第一篇
2017/08/28 Javascript
vue跨域解决方法
2017/10/15 Javascript
JavaScript实现微信号随机切换代码
2018/03/09 Javascript
用Vue写一个分页器的示例代码
2018/04/22 Javascript
JS数组push、unshift、pop、shift方法的实现与使用方法示例
2020/04/29 Javascript
学习python的几条建议分享
2013/02/10 Python
python中使用pyhook实现键盘监控的例子
2014/07/18 Python
分享一下Python 开发者节省时间的10个方法
2015/10/02 Python
Python 正则表达式入门(初级篇)
2016/12/07 Python
用 Python 连接 MySQL 的几种方式详解
2018/04/04 Python
python 读取DICOM头文件的实例
2018/05/07 Python
解决win7操作系统Python3.7.1安装后启动提示缺少.dll文件问题
2019/07/15 Python
python处理excel绘制雷达图
2019/10/18 Python
Python openpyxl模块实现excel读写操作
2020/06/30 Python
MoviePy简介及Python视频剪辑自动化
2020/12/18 Python
Python3自带工具2to3.py 转换 Python2.x 代码到Python3的操作
2021/03/03 Python
土耳其国际性时尚购物网站:Modanisa
2018/01/19 全球购物
美国狗旅行和户外用品领先供应商:kurgo
2020/08/18 全球购物
干部考核工作总结
2015/08/12 职场文书
MySQL sql_mode的使用详解
2021/05/08 MySQL
Mybatis是这样防止sql注入的
2021/12/06 Java/Android
Python之matplotlib绘制折线图
2022/04/13 Python
使用JS前端技术实现静态图片局部流动效果
2022/08/05 Javascript