使用 PyTorch 实现 MLP 并在 MNIST 数据集上验证方式


Posted in Python onJanuary 08, 2020

简介

这是深度学习课程的第一个实验,主要目的就是熟悉 Pytorch 框架。MLP 是多层感知器,我这次实现的是四层感知器,代码和思路参考了网上的很多文章。个人认为,感知器的代码大同小异,尤其是用 Pytorch 实现,除了层数和参数外,代码都很相似。

Pytorch 写神经网络的主要步骤主要有以下几步:

1 构建网络结构

2 加载数据集

3 训练神经网络(包括优化器的选择和 Loss 的计算)

4 测试神经网络

下面将从这四个方面介绍 Pytorch 搭建 MLP 的过程。

项目代码地址:lab1

过程

构建网络结构

神经网络最重要的就是搭建网络,第一步就是定义网络结构。我这里是创建了一个四层的感知器,参数是根据 MNIST 数据集设定的,网络结构如下:

# 建立一个四层感知机网络
class MLP(torch.nn.Module):  # 继承 torch 的 Module
  def __init__(self):
    super(MLP,self).__init__()  # 
    # 初始化三层神经网络 两个全连接的隐藏层,一个输出层
    self.fc1 = torch.nn.Linear(784,512) # 第一个隐含层 
    self.fc2 = torch.nn.Linear(512,128) # 第二个隐含层
    self.fc3 = torch.nn.Linear(128,10)  # 输出层
    
  def forward(self,din):
    # 前向传播, 输入值:din, 返回值 dout
    din = din.view(-1,28*28)    # 将一个多行的Tensor,拼接成一行
    dout = F.relu(self.fc1(din))  # 使用 relu 激活函数
    dout = F.relu(self.fc2(dout))
    dout = F.softmax(self.fc3(dout), dim=1) # 输出层使用 softmax 激活函数
    # 10个数字实际上是10个类别,输出是概率分布,最后选取概率最大的作为预测值输出
    return dout

网络结构其实很简单,设置了三层 Linear。隐含层激活函数使用 Relu; 输出层使用 Softmax。网上还有其他的结构使用了 droupout,我觉得入门的话有点高级,而且放在这里并没有什么用,搞得很麻烦还不能提高准确率。

加载数据集

第二步就是定义全局变量,并加载 MNIST 数据集:

# 定义全局变量
n_epochs = 10   # epoch 的数目
batch_size = 20 # 决定每次读取多少图片

# 定义训练集个测试集,如果找不到数据,就下载
train_data = datasets.MNIST(root = './data', train = True, download = True, transform = transforms.ToTensor())
test_data = datasets.MNIST(root = './data', train = True, download = True, transform = transforms.ToTensor())
# 创建加载器
train_loader = torch.utils.data.DataLoader(train_data, batch_size = batch_size, num_workers = 0)
test_loader = torch.utils.data.DataLoader(test_data, batch_size = batch_size, num_workers = 0)

这里参数很多,所以就有很多需要注意的地方了:

root 参数的文件夹即使不存在也没关系,会自动创建

transform 参数,如果不知道要对数据集进行什么变化,这里可自动忽略

batch_size 参数的大小决定了一次训练多少数据,相当于定义了每个 epoch 中反向传播的次数

num_workers 参数默认是 0,即不并行处理数据;我这里设置大于 0 的时候,总是报错,建议设成默认值

如果不理解 epoch 和 batch_size,可以上网查一下资料。(我刚开始学深度学习的时候也是不懂的)

训练神经网络

第三步就是训练网络了,代码如下:

# 训练神经网络
def train():
  # 定义损失函数和优化器
  lossfunc = torch.nn.CrossEntropyLoss()
  optimizer = torch.optim.SGD(params = model.parameters(), lr = 0.01)
  # 开始训练
  for epoch in range(n_epochs):
    train_loss = 0.0
    for data,target in train_loader:
      optimizer.zero_grad()  # 清空上一步的残余更新参数值
      output = model(data)  # 得到预测值
      loss = lossfunc(output,target) # 计算两者的误差
      loss.backward()     # 误差反向传播, 计算参数更新值
      optimizer.step()    # 将参数更新值施加到 net 的 parameters 上
      train_loss += loss.item()*data.size(0)
    train_loss = train_loss / len(train_loader.dataset)
    print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch + 1, train_loss))

训练之前要定义损失函数和优化器,这里其实有很多学问,但本文就不讲了,理论太多了。

训练过程就是两层 for 循环:外层是遍历训练集的次数;内层是每次的批次(batch)。最后,输出每个 epoch 的 loss。(每次训练的目的是使 loss 函数减小,以达到训练集上更高的准确率)

测试神经网络

最后,就是在测试集上进行测试,代码如下:

# 在数据集上测试神经网络
def test():
  correct = 0
  total = 0
  with torch.no_grad(): # 训练集中不需要反向传播
    for data in test_loader:
      images, labels = data
      outputs = model(images)
      _, predicted = torch.max(outputs.data, 1)
      total += labels.size(0)
      correct += (predicted == labels).sum().item()
  print('Accuracy of the network on the test images: %d %%' % (
    100 * correct / total))
  return 100.0 * correct / total

这个测试的代码是同学给我的,我觉得这个测试的代码特别好,很简洁,一直用的这个。

代码首先设置 torch.no_grad(),定义后面的代码不需要计算梯度,能够节省一些内存空间。然后,对测试集中的每个 batch 进行测试,统计总数和准确数,最后计算准确率并输出。

通常是选择边训练边测试的,这里先就按步骤一步一步来做。

有的测试代码前面要加上 model.eval(),表示这是训练状态。但这里不需要,如果没有 Batch Normalization 和 Dropout 方法,加和不加的效果是一样的。

完整代码

'''
系统环境: Windows10
Python版本: 3.7
PyTorch版本: 1.1.0
cuda: no
'''
import torch
import torch.nn.functional as F  # 激励函数的库
from torchvision import datasets
import torchvision.transforms as transforms
import numpy as np

# 定义全局变量
n_epochs = 10   # epoch 的数目
batch_size = 20 # 决定每次读取多少图片

# 定义训练集个测试集,如果找不到数据,就下载
train_data = datasets.MNIST(root = './data', train = True, download = True, transform = transforms.ToTensor())
test_data = datasets.MNIST(root = './data', train = True, download = True, transform = transforms.ToTensor())
# 创建加载器
train_loader = torch.utils.data.DataLoader(train_data, batch_size = batch_size, num_workers = 0)
test_loader = torch.utils.data.DataLoader(test_data, batch_size = batch_size, num_workers = 0)


# 建立一个四层感知机网络
class MLP(torch.nn.Module):  # 继承 torch 的 Module
  def __init__(self):
    super(MLP,self).__init__()  # 
    # 初始化三层神经网络 两个全连接的隐藏层,一个输出层
    self.fc1 = torch.nn.Linear(784,512) # 第一个隐含层 
    self.fc2 = torch.nn.Linear(512,128) # 第二个隐含层
    self.fc3 = torch.nn.Linear(128,10)  # 输出层
    
  def forward(self,din):
    # 前向传播, 输入值:din, 返回值 dout
    din = din.view(-1,28*28)    # 将一个多行的Tensor,拼接成一行
    dout = F.relu(self.fc1(din))  # 使用 relu 激活函数
    dout = F.relu(self.fc2(dout))
    dout = F.softmax(self.fc3(dout), dim=1) # 输出层使用 softmax 激活函数
    # 10个数字实际上是10个类别,输出是概率分布,最后选取概率最大的作为预测值输出
    return dout

# 训练神经网络
def train():
  #定义损失函数和优化器
  lossfunc = torch.nn.CrossEntropyLoss()
  optimizer = torch.optim.SGD(params = model.parameters(), lr = 0.01)
  # 开始训练
  for epoch in range(n_epochs):
    train_loss = 0.0
    for data,target in train_loader:
      optimizer.zero_grad()  # 清空上一步的残余更新参数值
      output = model(data)  # 得到预测值
      loss = lossfunc(output,target) # 计算两者的误差
      loss.backward()     # 误差反向传播, 计算参数更新值
      optimizer.step()    # 将参数更新值施加到 net 的 parameters 上
      train_loss += loss.item()*data.size(0)
    train_loss = train_loss / len(train_loader.dataset)
    print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch + 1, train_loss))
    # 每遍历一遍数据集,测试一下准确率
    test()

# 在数据集上测试神经网络
def test():
  correct = 0
  total = 0
  with torch.no_grad(): # 训练集中不需要反向传播
    for data in test_loader:
      images, labels = data
      outputs = model(images)
      _, predicted = torch.max(outputs.data, 1)
      total += labels.size(0)
      correct += (predicted == labels).sum().item()
  print('Accuracy of the network on the test images: %d %%' % (
    100 * correct / total))
  return 100.0 * correct / total

# 声明感知器网络
model = MLP()

if __name__ == '__main__':
  train()

10 个 epoch 的训练效果,最后能达到大约 85% 的准确率。可以适当增加 epoch,但代码里没有用 gpu 运行,可能会比较慢。

以上这篇使用 PyTorch 实现 MLP 并在 MNIST 数据集上验证方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python Web框架Pylons中使用MongoDB的例子
Dec 03 Python
Python制作简易注册登录系统
Dec 15 Python
利用信号如何监控Django模型对象字段值的变化详解
Nov 27 Python
Pycharm设置去除显示的波浪线方法
Oct 28 Python
Python 实现数据结构中的的栈队列
May 16 Python
Python实现微信好友的数据分析
Dec 16 Python
基于python读取.mat文件并取出信息
Dec 16 Python
解决Python列表字符不区分大小写的问题
Dec 19 Python
在tensorflow中设置使用某一块GPU、多GPU、CPU的操作
Feb 07 Python
keras实现多GPU或指定GPU的使用介绍
Jun 17 Python
keras 回调函数Callbacks 断点ModelCheckpoint教程
Jun 18 Python
基于logstash实现日志文件同步elasticsearch
Aug 06 Python
Pycharm 2020最新永久激活码(附最新激活码和插件)
Sep 17 #Python
将matplotlib绘图嵌入pyqt的方法示例
Jan 08 #Python
pyinstaller还原python代码过程图解
Jan 08 #Python
python Tensor和Array对比分析
Jan 08 #Python
Pycharm小白级简单使用教程
Jan 08 #Python
python如何实现不可变字典inmutabledict
Jan 08 #Python
PyQt5 closeEvent关闭事件退出提示框原理解析
Jan 08 #Python
You might like
PHP 验证码不显示只有一个小红叉的解决方法
2013/09/30 PHP
Mac版PhpStorm之XAMPP整合apache服务器配置的图文教程详解
2016/10/13 PHP
php cookie用户登录的详解及实例代码
2017/01/03 PHP
thinkPHP框架实现生成条形码的方法示例
2018/06/06 PHP
(currentStyle)javascript为何有时用style得不到已设定的CSS的属性
2007/08/15 Javascript
return false,对阻止事件默认动作的一些测试代码
2010/11/17 Javascript
使用jQuery动态加载js脚本文件的方法
2014/04/03 Javascript
Clipboard.js 无需Flash的JavaScript复制粘贴库
2015/10/02 Javascript
AngularJS 作用域详解及示例代码
2016/08/17 Javascript
AngularJS模板加载用法详解
2016/11/04 Javascript
基于Marquee.js插件实现的跑马灯效果示例
2017/01/25 Javascript
Angularjs 依赖压缩及自定义过滤器写法
2017/02/04 Javascript
浅谈vue中改elementUI默认样式引发的static与assets的区别
2018/02/03 Javascript
小程序图片剪裁加旋转的示例代码
2018/07/10 Javascript
跨域解决之JSONP和CORS的详细介绍
2018/11/21 Javascript
优雅的处理vue项目异常实战记录
2019/06/05 Javascript
javascript数组常见操作方法实例总结【连接、添加、删除、去重、排序等】
2019/06/13 Javascript
vue打开其他项目页面并传入数据详解
2020/11/25 Vue.js
[15:35]教你分分钟做大人:天怒法师
2014/10/30 DOTA
Python实现一个简单的MySQL类
2015/01/07 Python
总结Python中逻辑运算符的使用
2015/05/13 Python
在CentOS上配置Nginx+Gunicorn+Python+Flask环境的教程
2016/06/07 Python
python学习必备知识汇总
2017/09/08 Python
基于python神经卷积网络的人脸识别
2018/05/24 Python
Python3.4学习笔记之 idle 清屏扩展插件用法分析
2019/03/01 Python
Django用户认证系统 Web请求中的认证解析
2019/08/02 Python
python 初始化一个定长的数组实例
2019/12/02 Python
opencv3/python 鼠标响应操作详解
2019/12/11 Python
python利用platform模块获取系统信息
2020/10/09 Python
matplotlib交互式数据光标mpldatacursor的实现
2021/02/03 Python
世界上最好的帽子:Tilley
2016/11/27 全球购物
动物学专业毕业生求职信
2013/10/11 职场文书
新闻传播专业求职信
2014/07/22 职场文书
企业开业庆典答谢词
2015/01/20 职场文书
实习单位鉴定意见
2015/06/04 职场文书
MySQL 存储过程的优缺点分析
2021/05/20 MySQL