pytorch实现MNIST手写体识别


Posted in Python onFebruary 14, 2020

本文实例为大家分享了pytorch实现MNIST手写体识别的具体代码,供大家参考,具体内容如下

实验环境

pytorch 1.4
Windows 10
python 3.7
cuda 10.1(我笔记本上没有可以使用cuda的显卡)

实验过程

1. 确定我们要加载的库

import torch
import torch.nn as nn
import torchvision #这里面直接加载MNIST数据的方法
import torchvision.transforms as transforms # 将数据转为Tensor
import torch.optim as optim 
import torch.utils.data.dataloader as dataloader

2. 加载数据

这里使用所有数据进行训练,再使用所有数据进行测试

train_set = torchvision.datasets.MNIST(
 root='./data', # 文件存储位置
 train=True,
 transform=transforms.ToTensor(),
 download=True
)

train_dataloader = dataloader.DataLoader(dataset=train_set,shuffle=False,batch_size=100)# dataset可以省

'''
dataloader返回(images,labels)
其中,
images维度:[batch_size,1,28,28]
labels:[batch_size],即图片对应的
'''

test_set = torchvision.datasets.MNIST(
 root='./data',
 train=False,
 transform=transforms.ToTensor(),
 download=True
)

test_dataloader = dataloader.DataLoader(test_set,batch_size=100,shuffle=False) # dataset可以省

3. 定义神经网络模型

这里使用全神经网络作为模型

class NeuralNet(nn.Module):
 def __init__(self,in_num,h_num,out_num):
 super(NeuralNet,self).__init__()
 self.ln1 = nn.Linear(in_num,h_num)
 self.ln2 = nn.Linear(h_num,out_num)
 self.relu = nn.ReLU()
 
 def forward(self,x):
 return self.ln2(self.relu(self.ln1(x)))

4. 模型训练

in_num = 784 # 输入维度
h_num = 500 # 隐藏层维度
out_num = 10 # 输出维度
epochs = 30 # 迭代次数
learning_rate = 0.001
USE_CUDA = torch.cuda.is_available() # 定义是否可以使用cuda

model = NeuralNet(in_num,h_num,out_num) # 初始化模型
optimizer = optim.Adam(model.parameters(),lr=learning_rate) # 使用Adam
loss_fn = nn.CrossEntropyLoss() # 损失函数

for e in range(epochs):
 for i,data in enumerate(train_dataloader):
 (images,labels) = data
 images = images.reshape(-1,28*28) # [batch_size,784]
 if USE_CUDA:
  images = images.cuda() # 使用cuda
  labels = labels.cuda() # 使用cuda
  
 y_pred = model(images) # 预测
 loss = loss_fn(y_pred,labels) # 计算损失
 
 optimizer.zero_grad()
 loss.backward()
 optimizer.step()
 
 n = e * i +1
 if n % 100 == 0:
  print(n,'loss:',loss.item())

训练模型的loss部分截图如下:

pytorch实现MNIST手写体识别

5. 测试模型

with torch.no_grad():
 total = 0
 correct = 0
 for (images,labels) in test_dataloader:
 images = images.reshape(-1,28*28)
 if USE_CUDA:
  images = images.cuda()
  labels = labels.cuda()
  
 result = model(images)
 prediction = torch.max(result, 1)[1] # 这里需要有[1],因为它返回了概率还有标签
 total += labels.size(0)
 correct += (prediction == labels).sum().item()
 
 print("The accuracy of total {} images: {}%".format(total, 100 * correct/total))

实验结果

最终实验的正确率达到:98.22%

pytorch实现MNIST手写体识别

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python单链表的简单实现方法
Sep 23 Python
Python操作MongoDB数据库PyMongo库使用方法
Apr 27 Python
python 判断参数为Nonetype类型或空的实例
Oct 30 Python
Python File(文件) 方法整理
Feb 18 Python
让Python脚本暂停执行的几种方法(小结)
Jul 11 Python
python处理自动化任务之同时批量修改word里面的内容的方法
Aug 23 Python
python中自带的三个装饰器的实现
Nov 08 Python
python常用运维脚本实例小结
Feb 14 Python
将自己的数据集制作成TFRecord格式教程
Feb 17 Python
Python序列化pickle模块使用详解
Mar 05 Python
Django 如何实现文件上传下载
Apr 08 Python
用Python制作灯光秀短视频的思路详解
Apr 13 Python
Python3.7实现验证码登录方式代码实例
Feb 14 #Python
Python逐行读取文件内容的方法总结
Feb 14 #Python
Python3和PyCharm安装与环境配置【图文教程】
Feb 14 #Python
python对Excel的读取的示例代码
Feb 14 #Python
Python安装依赖(包)模块方法详解
Feb 14 #Python
python 项目目录结构设置
Feb 14 #Python
wxpython自定义下拉列表框过程图解
Feb 14 #Python
You might like
php按百分比生成缩略图的代码分享
2014/05/10 PHP
Mootools 1.2教程 排序类和方法简介
2009/09/15 Javascript
基于jquery的15款幻灯片插件
2011/04/10 Javascript
Javascript insertAfter() 实现函数代码
2011/10/12 Javascript
JavaScript打印网页指定区域的例子
2014/05/03 Javascript
jquery专业的导航菜单特效代码分享
2015/08/29 Javascript
JS使用cookie设置样式的方法
2016/06/30 Javascript
微信小程序 Page()函数详解
2016/10/17 Javascript
Bootstrap Modal对话框如何在关闭时触发事件
2016/12/02 Javascript
Vue 2.x教程之基础API
2017/03/06 Javascript
JS仿QQ好友列表展开、收缩功能(第二篇)
2017/07/07 Javascript
微信小程序 功能函数小结(手机号验证*、密码验证*、获取验证码*)
2017/12/08 Javascript
利用vue + element实现表格分页和前端搜索的方法
2017/12/25 Javascript
JavaScript中EventLoop介绍
2018/01/22 Javascript
Vue 理解之白话 getter/setter详解
2019/04/16 Javascript
javascript实现视频弹幕效果(两个版本)
2019/11/28 Javascript
vue路由缓存的几种实现方式小结
2020/02/02 Javascript
js编写简易的计算器
2020/07/29 Javascript
Vue实现返回顶部按钮实例代码
2020/10/21 Javascript
[03:55]显微镜下的DOTA2特别篇——430灰烬之灵神级操作
2014/06/24 DOTA
[42:32]完美世界DOTA2联赛循环赛 Magma vs PXG BO2第二场 10.28
2020/10/28 DOTA
30分钟搭建Python的Flask框架并在上面编写第一个应用
2015/03/30 Python
多版本Python共存的配置方法
2017/05/22 Python
详解Python 序列化Serialize 和 反序列化Deserialize
2017/08/20 Python
Python中property属性实例解析
2018/02/10 Python
python 使用pdfminer3k 读取PDF文档的例子
2019/08/27 Python
python字典与json转换的方法总结
2020/12/28 Python
CSS3 实现弹幕的示例代码
2017/08/07 HTML / CSS
365 Tickets英国:全球景点门票
2019/07/06 全球购物
计算机专业个人求职自荐信
2013/09/21 职场文书
新闻学专业大学生职业生涯规划范文
2014/03/02 职场文书
常住证明范本
2015/06/23 职场文书
大学生暑假实习总结
2015/07/13 职场文书
2016年寒假学习心得体会
2015/10/09 职场文书
2016领导干部廉洁自律心得体会
2016/01/13 职场文书
Nginx 常用配置
2022/05/15 Servers