pytorch实现MNIST手写体识别


Posted in Python onFebruary 14, 2020

本文实例为大家分享了pytorch实现MNIST手写体识别的具体代码,供大家参考,具体内容如下

实验环境

pytorch 1.4
Windows 10
python 3.7
cuda 10.1(我笔记本上没有可以使用cuda的显卡)

实验过程

1. 确定我们要加载的库

import torch
import torch.nn as nn
import torchvision #这里面直接加载MNIST数据的方法
import torchvision.transforms as transforms # 将数据转为Tensor
import torch.optim as optim 
import torch.utils.data.dataloader as dataloader

2. 加载数据

这里使用所有数据进行训练,再使用所有数据进行测试

train_set = torchvision.datasets.MNIST(
 root='./data', # 文件存储位置
 train=True,
 transform=transforms.ToTensor(),
 download=True
)

train_dataloader = dataloader.DataLoader(dataset=train_set,shuffle=False,batch_size=100)# dataset可以省

'''
dataloader返回(images,labels)
其中,
images维度:[batch_size,1,28,28]
labels:[batch_size],即图片对应的
'''

test_set = torchvision.datasets.MNIST(
 root='./data',
 train=False,
 transform=transforms.ToTensor(),
 download=True
)

test_dataloader = dataloader.DataLoader(test_set,batch_size=100,shuffle=False) # dataset可以省

3. 定义神经网络模型

这里使用全神经网络作为模型

class NeuralNet(nn.Module):
 def __init__(self,in_num,h_num,out_num):
 super(NeuralNet,self).__init__()
 self.ln1 = nn.Linear(in_num,h_num)
 self.ln2 = nn.Linear(h_num,out_num)
 self.relu = nn.ReLU()
 
 def forward(self,x):
 return self.ln2(self.relu(self.ln1(x)))

4. 模型训练

in_num = 784 # 输入维度
h_num = 500 # 隐藏层维度
out_num = 10 # 输出维度
epochs = 30 # 迭代次数
learning_rate = 0.001
USE_CUDA = torch.cuda.is_available() # 定义是否可以使用cuda

model = NeuralNet(in_num,h_num,out_num) # 初始化模型
optimizer = optim.Adam(model.parameters(),lr=learning_rate) # 使用Adam
loss_fn = nn.CrossEntropyLoss() # 损失函数

for e in range(epochs):
 for i,data in enumerate(train_dataloader):
 (images,labels) = data
 images = images.reshape(-1,28*28) # [batch_size,784]
 if USE_CUDA:
  images = images.cuda() # 使用cuda
  labels = labels.cuda() # 使用cuda
  
 y_pred = model(images) # 预测
 loss = loss_fn(y_pred,labels) # 计算损失
 
 optimizer.zero_grad()
 loss.backward()
 optimizer.step()
 
 n = e * i +1
 if n % 100 == 0:
  print(n,'loss:',loss.item())

训练模型的loss部分截图如下:

pytorch实现MNIST手写体识别

5. 测试模型

with torch.no_grad():
 total = 0
 correct = 0
 for (images,labels) in test_dataloader:
 images = images.reshape(-1,28*28)
 if USE_CUDA:
  images = images.cuda()
  labels = labels.cuda()
  
 result = model(images)
 prediction = torch.max(result, 1)[1] # 这里需要有[1],因为它返回了概率还有标签
 total += labels.size(0)
 correct += (prediction == labels).sum().item()
 
 print("The accuracy of total {} images: {}%".format(total, 100 * correct/total))

实验结果

最终实验的正确率达到:98.22%

pytorch实现MNIST手写体识别

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现dict版图遍历示例
Feb 19 Python
Python入门篇之文件
Oct 20 Python
python计算对角线有理函数插值的方法
May 07 Python
python实现的简单FTP上传下载文件实例
Jun 30 Python
django之常用命令详解
Jun 30 Python
Python模块WSGI使用详解
Feb 02 Python
opencv python 傅里叶变换的使用
Jul 21 Python
Python 的字典(Dict)是如何存储的
Jul 05 Python
利用Python的sympy包求解一元三次方程示例
Nov 22 Python
Python pygame绘制文字制作滚动文字过程解析
Dec 12 Python
python实现快速文件格式批量转换的方法
Oct 16 Python
Pytorch实现图像识别之数字识别(附详细注释)
May 11 Python
Python3.7实现验证码登录方式代码实例
Feb 14 #Python
Python逐行读取文件内容的方法总结
Feb 14 #Python
Python3和PyCharm安装与环境配置【图文教程】
Feb 14 #Python
python对Excel的读取的示例代码
Feb 14 #Python
Python安装依赖(包)模块方法详解
Feb 14 #Python
python 项目目录结构设置
Feb 14 #Python
wxpython自定义下拉列表框过程图解
Feb 14 #Python
You might like
浅析PHP中Collection 类的设计
2013/06/21 PHP
laravel容器延迟加载以及auth扩展详解
2015/03/02 PHP
php阿拉伯数字转中文人民币大写
2015/12/21 PHP
PHP下载文件的函数实例代码
2016/05/18 PHP
PHP基于单例模式编写PDO类的方法
2016/09/13 PHP
万能的php分页类
2017/07/06 PHP
Laravel框架分页实现方法分析
2018/06/12 PHP
Display SQL Server Version Information
2007/06/21 Javascript
Javascript 表单之间的数据传递代码
2008/12/04 Javascript
JS input 数字验证代码
2009/07/30 Javascript
IE 上下滚动展示模仿Marquee机制
2009/12/20 Javascript
EXTJS记事本 当CompositeField遇上RowEditor
2011/07/31 Javascript
javascript学习笔记(十四) window对象使用介绍
2012/06/20 Javascript
node+express+jade制作简单网站指南
2014/11/26 Javascript
JavaScript 实现完美兼容多浏览器的复制功能代码
2015/04/28 Javascript
基于jQuey实现鼠标滑过变色(整行变色)
2015/12/07 Javascript
JS针对Array的各种操作汇总
2016/11/29 Javascript
JS实现的样式切换功能tableCSS实例
2016/12/30 Javascript
JS实现图片点击后出现模态框效果
2017/05/03 Javascript
mui框架移动开发初体验详解
2017/10/11 Javascript
Vue.directive 自定义指令的问题小结
2018/03/04 Javascript
vue的常用组件操作方法应用分析
2018/04/13 Javascript
mpvue构建小程序的方法(步骤+地址)
2018/05/22 Javascript
深入理解移动前端开发之viewport
2018/10/19 Javascript
微信小程序下拉框功能的实例代码
2018/11/06 Javascript
解决layui的使用以及针对select、radio等表单组件不显示的问题
2019/09/05 Javascript
解决vue-cli@3.xx安装不成功的问题及搭建ts-vue项目
2020/02/09 Javascript
python机器学习理论与实战(一)K近邻法
2021/01/28 Python
解决Python 命令行执行脚本时,提示导入的包找不到的问题
2019/01/19 Python
ubuntu 16.04下python版本切换的方法
2019/06/14 Python
Hotels.com香港酒店网:你的自由行酒店订房专家
2018/01/22 全球购物
工商管理专业大学生职业生涯规划范文
2014/03/09 职场文书
小学学校评估方案
2014/06/08 职场文书
小学领导班子对照材料
2014/08/23 职场文书
乡镇务虚会发言材料
2014/10/20 职场文书
优秀共产党员事迹材料
2014/12/18 职场文书