使用 PyTorch 实现 MLP 并在 MNIST 数据集上验证方式


Posted in Python onJanuary 08, 2020

简介

这是深度学习课程的第一个实验,主要目的就是熟悉 Pytorch 框架。MLP 是多层感知器,我这次实现的是四层感知器,代码和思路参考了网上的很多文章。个人认为,感知器的代码大同小异,尤其是用 Pytorch 实现,除了层数和参数外,代码都很相似。

Pytorch 写神经网络的主要步骤主要有以下几步:

1 构建网络结构

2 加载数据集

3 训练神经网络(包括优化器的选择和 Loss 的计算)

4 测试神经网络

下面将从这四个方面介绍 Pytorch 搭建 MLP 的过程。

项目代码地址:lab1

过程

构建网络结构

神经网络最重要的就是搭建网络,第一步就是定义网络结构。我这里是创建了一个四层的感知器,参数是根据 MNIST 数据集设定的,网络结构如下:

# 建立一个四层感知机网络
class MLP(torch.nn.Module):  # 继承 torch 的 Module
  def __init__(self):
    super(MLP,self).__init__()  # 
    # 初始化三层神经网络 两个全连接的隐藏层,一个输出层
    self.fc1 = torch.nn.Linear(784,512) # 第一个隐含层 
    self.fc2 = torch.nn.Linear(512,128) # 第二个隐含层
    self.fc3 = torch.nn.Linear(128,10)  # 输出层
    
  def forward(self,din):
    # 前向传播, 输入值:din, 返回值 dout
    din = din.view(-1,28*28)    # 将一个多行的Tensor,拼接成一行
    dout = F.relu(self.fc1(din))  # 使用 relu 激活函数
    dout = F.relu(self.fc2(dout))
    dout = F.softmax(self.fc3(dout), dim=1) # 输出层使用 softmax 激活函数
    # 10个数字实际上是10个类别,输出是概率分布,最后选取概率最大的作为预测值输出
    return dout

网络结构其实很简单,设置了三层 Linear。隐含层激活函数使用 Relu; 输出层使用 Softmax。网上还有其他的结构使用了 droupout,我觉得入门的话有点高级,而且放在这里并没有什么用,搞得很麻烦还不能提高准确率。

加载数据集

第二步就是定义全局变量,并加载 MNIST 数据集:

# 定义全局变量
n_epochs = 10   # epoch 的数目
batch_size = 20 # 决定每次读取多少图片

# 定义训练集个测试集,如果找不到数据,就下载
train_data = datasets.MNIST(root = './data', train = True, download = True, transform = transforms.ToTensor())
test_data = datasets.MNIST(root = './data', train = True, download = True, transform = transforms.ToTensor())
# 创建加载器
train_loader = torch.utils.data.DataLoader(train_data, batch_size = batch_size, num_workers = 0)
test_loader = torch.utils.data.DataLoader(test_data, batch_size = batch_size, num_workers = 0)

这里参数很多,所以就有很多需要注意的地方了:

root 参数的文件夹即使不存在也没关系,会自动创建

transform 参数,如果不知道要对数据集进行什么变化,这里可自动忽略

batch_size 参数的大小决定了一次训练多少数据,相当于定义了每个 epoch 中反向传播的次数

num_workers 参数默认是 0,即不并行处理数据;我这里设置大于 0 的时候,总是报错,建议设成默认值

如果不理解 epoch 和 batch_size,可以上网查一下资料。(我刚开始学深度学习的时候也是不懂的)

训练神经网络

第三步就是训练网络了,代码如下:

# 训练神经网络
def train():
  # 定义损失函数和优化器
  lossfunc = torch.nn.CrossEntropyLoss()
  optimizer = torch.optim.SGD(params = model.parameters(), lr = 0.01)
  # 开始训练
  for epoch in range(n_epochs):
    train_loss = 0.0
    for data,target in train_loader:
      optimizer.zero_grad()  # 清空上一步的残余更新参数值
      output = model(data)  # 得到预测值
      loss = lossfunc(output,target) # 计算两者的误差
      loss.backward()     # 误差反向传播, 计算参数更新值
      optimizer.step()    # 将参数更新值施加到 net 的 parameters 上
      train_loss += loss.item()*data.size(0)
    train_loss = train_loss / len(train_loader.dataset)
    print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch + 1, train_loss))

训练之前要定义损失函数和优化器,这里其实有很多学问,但本文就不讲了,理论太多了。

训练过程就是两层 for 循环:外层是遍历训练集的次数;内层是每次的批次(batch)。最后,输出每个 epoch 的 loss。(每次训练的目的是使 loss 函数减小,以达到训练集上更高的准确率)

测试神经网络

最后,就是在测试集上进行测试,代码如下:

# 在数据集上测试神经网络
def test():
  correct = 0
  total = 0
  with torch.no_grad(): # 训练集中不需要反向传播
    for data in test_loader:
      images, labels = data
      outputs = model(images)
      _, predicted = torch.max(outputs.data, 1)
      total += labels.size(0)
      correct += (predicted == labels).sum().item()
  print('Accuracy of the network on the test images: %d %%' % (
    100 * correct / total))
  return 100.0 * correct / total

这个测试的代码是同学给我的,我觉得这个测试的代码特别好,很简洁,一直用的这个。

代码首先设置 torch.no_grad(),定义后面的代码不需要计算梯度,能够节省一些内存空间。然后,对测试集中的每个 batch 进行测试,统计总数和准确数,最后计算准确率并输出。

通常是选择边训练边测试的,这里先就按步骤一步一步来做。

有的测试代码前面要加上 model.eval(),表示这是训练状态。但这里不需要,如果没有 Batch Normalization 和 Dropout 方法,加和不加的效果是一样的。

完整代码

'''
系统环境: Windows10
Python版本: 3.7
PyTorch版本: 1.1.0
cuda: no
'''
import torch
import torch.nn.functional as F  # 激励函数的库
from torchvision import datasets
import torchvision.transforms as transforms
import numpy as np

# 定义全局变量
n_epochs = 10   # epoch 的数目
batch_size = 20 # 决定每次读取多少图片

# 定义训练集个测试集,如果找不到数据,就下载
train_data = datasets.MNIST(root = './data', train = True, download = True, transform = transforms.ToTensor())
test_data = datasets.MNIST(root = './data', train = True, download = True, transform = transforms.ToTensor())
# 创建加载器
train_loader = torch.utils.data.DataLoader(train_data, batch_size = batch_size, num_workers = 0)
test_loader = torch.utils.data.DataLoader(test_data, batch_size = batch_size, num_workers = 0)


# 建立一个四层感知机网络
class MLP(torch.nn.Module):  # 继承 torch 的 Module
  def __init__(self):
    super(MLP,self).__init__()  # 
    # 初始化三层神经网络 两个全连接的隐藏层,一个输出层
    self.fc1 = torch.nn.Linear(784,512) # 第一个隐含层 
    self.fc2 = torch.nn.Linear(512,128) # 第二个隐含层
    self.fc3 = torch.nn.Linear(128,10)  # 输出层
    
  def forward(self,din):
    # 前向传播, 输入值:din, 返回值 dout
    din = din.view(-1,28*28)    # 将一个多行的Tensor,拼接成一行
    dout = F.relu(self.fc1(din))  # 使用 relu 激活函数
    dout = F.relu(self.fc2(dout))
    dout = F.softmax(self.fc3(dout), dim=1) # 输出层使用 softmax 激活函数
    # 10个数字实际上是10个类别,输出是概率分布,最后选取概率最大的作为预测值输出
    return dout

# 训练神经网络
def train():
  #定义损失函数和优化器
  lossfunc = torch.nn.CrossEntropyLoss()
  optimizer = torch.optim.SGD(params = model.parameters(), lr = 0.01)
  # 开始训练
  for epoch in range(n_epochs):
    train_loss = 0.0
    for data,target in train_loader:
      optimizer.zero_grad()  # 清空上一步的残余更新参数值
      output = model(data)  # 得到预测值
      loss = lossfunc(output,target) # 计算两者的误差
      loss.backward()     # 误差反向传播, 计算参数更新值
      optimizer.step()    # 将参数更新值施加到 net 的 parameters 上
      train_loss += loss.item()*data.size(0)
    train_loss = train_loss / len(train_loader.dataset)
    print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch + 1, train_loss))
    # 每遍历一遍数据集,测试一下准确率
    test()

# 在数据集上测试神经网络
def test():
  correct = 0
  total = 0
  with torch.no_grad(): # 训练集中不需要反向传播
    for data in test_loader:
      images, labels = data
      outputs = model(images)
      _, predicted = torch.max(outputs.data, 1)
      total += labels.size(0)
      correct += (predicted == labels).sum().item()
  print('Accuracy of the network on the test images: %d %%' % (
    100 * correct / total))
  return 100.0 * correct / total

# 声明感知器网络
model = MLP()

if __name__ == '__main__':
  train()

10 个 epoch 的训练效果,最后能达到大约 85% 的准确率。可以适当增加 epoch,但代码里没有用 gpu 运行,可能会比较慢。

以上这篇使用 PyTorch 实现 MLP 并在 MNIST 数据集上验证方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用Pycrypto库进行RSA加密的方法详解
Jun 06 Python
pandas全表查询定位某个值所在行列的方法
Apr 12 Python
Matplotlib 生成不同大小的subplots实例
May 25 Python
python3.6使用pymysql连接Mysql数据库
May 25 Python
python将txt文档每行内容循环插入数据库的方法
Dec 28 Python
局域网内python socket实现windows与linux间的消息传送
Apr 19 Python
Python对接六大主流数据库(只需三步)
Jul 31 Python
基于pytorch padding=SAME的解决方式
Feb 18 Python
在Sublime Editor中配置Python环境的详细教程
May 03 Python
python读取pdf格式文档的实现代码
Apr 01 Python
Jupyter notebook 输出部分显示不全的解决方案
Apr 24 Python
Python django中如何使用restful框架
Jun 23 Python
Pycharm 2020最新永久激活码(附最新激活码和插件)
Sep 17 #Python
将matplotlib绘图嵌入pyqt的方法示例
Jan 08 #Python
pyinstaller还原python代码过程图解
Jan 08 #Python
python Tensor和Array对比分析
Jan 08 #Python
Pycharm小白级简单使用教程
Jan 08 #Python
python如何实现不可变字典inmutabledict
Jan 08 #Python
PyQt5 closeEvent关闭事件退出提示框原理解析
Jan 08 #Python
You might like
PHP+MySql+jQuery实现的"顶"和"踩"投票功能
2016/05/21 PHP
PHP文件操作简单介绍及函数汇总
2020/12/11 PHP
通过Unicode转义序列来加密,按你说的可以算是混淆吧
2007/05/06 Javascript
原始的js代码和jquery对比体会
2013/09/10 Javascript
javascript中对Attr(dom中属性)的操作示例讲解
2013/12/02 Javascript
jquery div拖动效果示例代码
2013/12/08 Javascript
angularjs实现与服务器交互分享
2014/06/24 Javascript
JQuery插件Quicksand实现超炫的动画洗牌效果
2015/05/03 Javascript
AngularJS基础学习笔记之指令
2015/05/10 Javascript
js如何判断访问是来自搜索引擎(蜘蛛人)还是直接访问
2015/09/14 Javascript
学习AngularJs:Directive指令用法(完整版)
2016/04/26 Javascript
利用Angular+Angular-Ui实现分页(代码加简单)
2017/03/10 Javascript
javascript实现文字无缝滚动效果
2017/08/26 Javascript
在 Angular6 中使用 HTTP 请求服务端数据的步骤详解
2018/08/06 Javascript
vue-cli的工程模板与构建工具详解
2018/09/27 Javascript
在vue中使用cookie记住用户上次选择的实例(本次例子中为下拉框)
2020/09/11 Javascript
vant-ui组件调用Dialog弹窗异步关闭操作
2020/11/04 Javascript
Python写入CSV文件的方法
2015/07/08 Python
Windows下Python2与Python3两个版本共存的方法详解
2017/02/12 Python
Python中函数及默认参数的定义与调用操作实例分析
2017/07/25 Python
利用Hyperic调用Python实现进程守护
2018/01/02 Python
Python构建图像分类识别器的方法
2019/01/12 Python
django实现模板中的字符串文字和自动转义
2020/03/31 Python
使用python批量修改XML文件中图像的depth值
2020/07/22 Python
CSS3之transition实现下划线的示例代码
2018/05/30 HTML / CSS
KARATOV珠宝在线商店:俄罗斯珠宝品牌
2019/03/13 全球购物
美国厨房和园艺工具网上商店:Nestneed
2019/08/24 全球购物
美国婴儿和儿童服装购物网站:PatPat
2020/10/01 全球购物
50道外企软件测试面试题
2014/08/18 面试题
《记承天寺夜游》教学反思
2014/02/16 职场文书
常务副总经理任命书
2014/06/05 职场文书
药剂专业毕业生求职信
2014/06/24 职场文书
小学课外活动总结
2014/07/09 职场文书
杜甫草堂导游词
2015/02/03 职场文书
2015年大学生工作总结
2015/04/21 职场文书
MyBatis-Plus 批量插入数据的操作方法
2021/09/25 Java/Android