pytorch实现MNIST手写体识别


Posted in Python onFebruary 14, 2020

本文实例为大家分享了pytorch实现MNIST手写体识别的具体代码,供大家参考,具体内容如下

实验环境

pytorch 1.4
Windows 10
python 3.7
cuda 10.1(我笔记本上没有可以使用cuda的显卡)

实验过程

1. 确定我们要加载的库

import torch
import torch.nn as nn
import torchvision #这里面直接加载MNIST数据的方法
import torchvision.transforms as transforms # 将数据转为Tensor
import torch.optim as optim 
import torch.utils.data.dataloader as dataloader

2. 加载数据

这里使用所有数据进行训练,再使用所有数据进行测试

train_set = torchvision.datasets.MNIST(
 root='./data', # 文件存储位置
 train=True,
 transform=transforms.ToTensor(),
 download=True
)

train_dataloader = dataloader.DataLoader(dataset=train_set,shuffle=False,batch_size=100)# dataset可以省

'''
dataloader返回(images,labels)
其中,
images维度:[batch_size,1,28,28]
labels:[batch_size],即图片对应的
'''

test_set = torchvision.datasets.MNIST(
 root='./data',
 train=False,
 transform=transforms.ToTensor(),
 download=True
)

test_dataloader = dataloader.DataLoader(test_set,batch_size=100,shuffle=False) # dataset可以省

3. 定义神经网络模型

这里使用全神经网络作为模型

class NeuralNet(nn.Module):
 def __init__(self,in_num,h_num,out_num):
 super(NeuralNet,self).__init__()
 self.ln1 = nn.Linear(in_num,h_num)
 self.ln2 = nn.Linear(h_num,out_num)
 self.relu = nn.ReLU()
 
 def forward(self,x):
 return self.ln2(self.relu(self.ln1(x)))

4. 模型训练

in_num = 784 # 输入维度
h_num = 500 # 隐藏层维度
out_num = 10 # 输出维度
epochs = 30 # 迭代次数
learning_rate = 0.001
USE_CUDA = torch.cuda.is_available() # 定义是否可以使用cuda

model = NeuralNet(in_num,h_num,out_num) # 初始化模型
optimizer = optim.Adam(model.parameters(),lr=learning_rate) # 使用Adam
loss_fn = nn.CrossEntropyLoss() # 损失函数

for e in range(epochs):
 for i,data in enumerate(train_dataloader):
 (images,labels) = data
 images = images.reshape(-1,28*28) # [batch_size,784]
 if USE_CUDA:
  images = images.cuda() # 使用cuda
  labels = labels.cuda() # 使用cuda
  
 y_pred = model(images) # 预测
 loss = loss_fn(y_pred,labels) # 计算损失
 
 optimizer.zero_grad()
 loss.backward()
 optimizer.step()
 
 n = e * i +1
 if n % 100 == 0:
  print(n,'loss:',loss.item())

训练模型的loss部分截图如下:

pytorch实现MNIST手写体识别

5. 测试模型

with torch.no_grad():
 total = 0
 correct = 0
 for (images,labels) in test_dataloader:
 images = images.reshape(-1,28*28)
 if USE_CUDA:
  images = images.cuda()
  labels = labels.cuda()
  
 result = model(images)
 prediction = torch.max(result, 1)[1] # 这里需要有[1],因为它返回了概率还有标签
 total += labels.size(0)
 correct += (prediction == labels).sum().item()
 
 print("The accuracy of total {} images: {}%".format(total, 100 * correct/total))

实验结果

最终实验的正确率达到:98.22%

pytorch实现MNIST手写体识别

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用python删除java文件头上版权信息的方法
Jul 31 Python
python中的五种异常处理机制介绍
Sep 02 Python
Python多维/嵌套字典数据无限遍历的实现
Nov 04 Python
python去掉行尾的换行符方法
Jan 04 Python
Python实现一个转存纯真IP数据库的脚本分享
May 21 Python
Python实现简单的文本相似度分析操作详解
Jun 16 Python
python 读取视频,处理后,实时计算帧数fps的方法
Jul 10 Python
python粘包问题及socket套接字编程详解
Jun 29 Python
python字符串格式化方式解析
Oct 19 Python
python访问hdfs的操作
Jun 06 Python
浅析Python迭代器的高级用法
Jul 16 Python
Python Django 后台管理之后台模型属性详解
Apr 25 Python
Python3.7实现验证码登录方式代码实例
Feb 14 #Python
Python逐行读取文件内容的方法总结
Feb 14 #Python
Python3和PyCharm安装与环境配置【图文教程】
Feb 14 #Python
python对Excel的读取的示例代码
Feb 14 #Python
Python安装依赖(包)模块方法详解
Feb 14 #Python
python 项目目录结构设置
Feb 14 #Python
wxpython自定义下拉列表框过程图解
Feb 14 #Python
You might like
php笔记之:有规律大文件的读取与写入的分析
2013/04/26 PHP
PHP动态柱状图实现方法
2015/03/30 PHP
Laravel使用Caching缓存数据减轻数据库查询压力的方法
2016/03/15 PHP
又拍云异步上传实例教程详解
2016/04/19 PHP
比较全面的event对像在IE与FF中的区别 推荐
2009/09/21 Javascript
让低版本浏览器支持input的placeholder属性(js方法)
2013/04/03 Javascript
jQuery实现等比例缩放大图片让大图片自适应页面布局
2013/10/16 Javascript
Jquery通过Ajax方式来提交Form表单的具体实现
2013/11/07 Javascript
JS获取当前日期时间并定时刷新示例
2021/03/04 Javascript
九种原生js动画效果
2015/11/11 Javascript
理解javascript定时器中的单线程
2016/02/23 Javascript
jQuery基础知识点总结(DOM操作)
2016/06/01 Javascript
原生JS轮播图插件
2017/02/09 Javascript
你点的 ES6一些小技巧,请查收
2018/04/25 Javascript
webpack4+react多页面架构的实现
2018/10/25 Javascript
详解如何在Node.js的httpServer中接收前端发送的arraybuffer数据
2018/11/11 Javascript
vue项目中使用bpmn-自定义platter的示例代码
2020/05/11 Javascript
Vue this.$router.push(参数)实现页面跳转操作
2020/09/09 Javascript
[49:11]完美世界DOTA2联赛PWL S3 INK ICE vs DLG 第二场 12.20
2020/12/23 DOTA
纯Python开发的nosql数据库CodernityDB介绍和使用实例
2014/10/23 Python
Python闭包的两个注意事项(推荐)
2017/03/20 Python
Django分页查询并返回jsons数据(中文乱码解决方法)
2018/08/02 Python
深入浅析python 协程与go协程的区别
2019/05/09 Python
python中对数据进行各种排序的方法
2019/07/02 Python
基于Python3.7.1无法导入Numpy的解决方式
2020/03/09 Python
Python 日期时间datetime 加一天,减一天,加减一小时一分钟,加减一年
2020/04/16 Python
阿迪达斯英国官方网站:adidas英国
2019/08/13 全球购物
长曲棍球装备:Lacrosse Monkey
2020/12/02 全球购物
公司年会晚宴演讲稿
2014/01/06 职场文书
医学专业职业生涯规划范文
2014/02/05 职场文书
运动会横幅标语
2014/06/17 职场文书
党员弘扬焦裕禄精神思想汇报
2014/09/10 职场文书
向国旗敬礼活动小结
2014/09/27 职场文书
2014年学生会干事工作总结
2014/11/07 职场文书
pytorch Dropout过拟合的操作
2021/05/27 Python
Python简易开发之制作计算器
2022/04/28 Python