pytorch实现MNIST手写体识别


Posted in Python onFebruary 14, 2020

本文实例为大家分享了pytorch实现MNIST手写体识别的具体代码,供大家参考,具体内容如下

实验环境

pytorch 1.4
Windows 10
python 3.7
cuda 10.1(我笔记本上没有可以使用cuda的显卡)

实验过程

1. 确定我们要加载的库

import torch
import torch.nn as nn
import torchvision #这里面直接加载MNIST数据的方法
import torchvision.transforms as transforms # 将数据转为Tensor
import torch.optim as optim 
import torch.utils.data.dataloader as dataloader

2. 加载数据

这里使用所有数据进行训练,再使用所有数据进行测试

train_set = torchvision.datasets.MNIST(
 root='./data', # 文件存储位置
 train=True,
 transform=transforms.ToTensor(),
 download=True
)

train_dataloader = dataloader.DataLoader(dataset=train_set,shuffle=False,batch_size=100)# dataset可以省

'''
dataloader返回(images,labels)
其中,
images维度:[batch_size,1,28,28]
labels:[batch_size],即图片对应的
'''

test_set = torchvision.datasets.MNIST(
 root='./data',
 train=False,
 transform=transforms.ToTensor(),
 download=True
)

test_dataloader = dataloader.DataLoader(test_set,batch_size=100,shuffle=False) # dataset可以省

3. 定义神经网络模型

这里使用全神经网络作为模型

class NeuralNet(nn.Module):
 def __init__(self,in_num,h_num,out_num):
 super(NeuralNet,self).__init__()
 self.ln1 = nn.Linear(in_num,h_num)
 self.ln2 = nn.Linear(h_num,out_num)
 self.relu = nn.ReLU()
 
 def forward(self,x):
 return self.ln2(self.relu(self.ln1(x)))

4. 模型训练

in_num = 784 # 输入维度
h_num = 500 # 隐藏层维度
out_num = 10 # 输出维度
epochs = 30 # 迭代次数
learning_rate = 0.001
USE_CUDA = torch.cuda.is_available() # 定义是否可以使用cuda

model = NeuralNet(in_num,h_num,out_num) # 初始化模型
optimizer = optim.Adam(model.parameters(),lr=learning_rate) # 使用Adam
loss_fn = nn.CrossEntropyLoss() # 损失函数

for e in range(epochs):
 for i,data in enumerate(train_dataloader):
 (images,labels) = data
 images = images.reshape(-1,28*28) # [batch_size,784]
 if USE_CUDA:
  images = images.cuda() # 使用cuda
  labels = labels.cuda() # 使用cuda
  
 y_pred = model(images) # 预测
 loss = loss_fn(y_pred,labels) # 计算损失
 
 optimizer.zero_grad()
 loss.backward()
 optimizer.step()
 
 n = e * i +1
 if n % 100 == 0:
  print(n,'loss:',loss.item())

训练模型的loss部分截图如下:

pytorch实现MNIST手写体识别

5. 测试模型

with torch.no_grad():
 total = 0
 correct = 0
 for (images,labels) in test_dataloader:
 images = images.reshape(-1,28*28)
 if USE_CUDA:
  images = images.cuda()
  labels = labels.cuda()
  
 result = model(images)
 prediction = torch.max(result, 1)[1] # 这里需要有[1],因为它返回了概率还有标签
 total += labels.size(0)
 correct += (prediction == labels).sum().item()
 
 print("The accuracy of total {} images: {}%".format(total, 100 * correct/total))

实验结果

最终实验的正确率达到:98.22%

pytorch实现MNIST手写体识别

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现自动登录人人网并访问最近来访者实例
Sep 26 Python
Python多线程编程(一):threading模块综述
Apr 05 Python
对numpy的array和python中自带的list之间相互转化详解
Apr 13 Python
使用Python来开发微信功能
Jun 13 Python
Python中__slots__属性介绍与基本使用方法
Sep 05 Python
python实现多进程代码示例
Oct 31 Python
python实现播放音频和录音功能示例代码
Dec 30 Python
对python For 循环的三种遍历方式解析
Feb 01 Python
Pandas中DataFrame的分组/分割/合并的实现
Jul 16 Python
python构建指数平滑预测模型示例
Nov 21 Python
记一次django内存异常排查及解决方法
Aug 07 Python
pytorch Dropout过拟合的操作
May 27 Python
Python3.7实现验证码登录方式代码实例
Feb 14 #Python
Python逐行读取文件内容的方法总结
Feb 14 #Python
Python3和PyCharm安装与环境配置【图文教程】
Feb 14 #Python
python对Excel的读取的示例代码
Feb 14 #Python
Python安装依赖(包)模块方法详解
Feb 14 #Python
python 项目目录结构设置
Feb 14 #Python
wxpython自定义下拉列表框过程图解
Feb 14 #Python
You might like
Yii框架中sphinx索引配置方法解析
2016/10/18 PHP
PHP PDOStatement::errorCode讲解
2019/01/31 PHP
PHP正则之正向预查与反向预查讲解与实例
2020/04/06 PHP
php自动加载代码实例详解
2021/02/26 PHP
WordPress 插件——CoolCode使用方法与下载
2007/07/02 Javascript
javascript中"/"运算符常见错误
2010/10/13 Javascript
基于JQUERY的多级联动代码
2012/01/24 Javascript
jQuery使用动态渲染表单功能完成ajax文件下载
2013/01/15 Javascript
关于jQuery中的each方法(jQuery到底干了什么)
2014/03/05 Javascript
php和js对数据库图片进行等比缩放示例
2014/04/28 Javascript
基于Jquery插件实现跨域异步上传文件功能
2016/04/26 Javascript
vue插件tab选项卡使用小结
2016/10/27 Javascript
jQuery插件ajaxFileUpload使用详解
2017/01/10 Javascript
原生js实现返回顶部缓冲效果
2017/01/18 Javascript
详解如何使用vue-cli脚手架搭建Vue.js项目
2017/05/19 Javascript
vue基于Element构建自定义树的示例代码
2017/09/19 Javascript
jquery应用实例分享_实现手风琴特效
2018/02/01 jQuery
vue vue-Router默认hash模式修改为history需要做的修改详解
2018/09/13 Javascript
jQuery使用$.extend(true,object1, object2);实现深拷贝对象的方法分析
2019/03/06 jQuery
jquery ui 实现 tab标签功能示例【测试可用】
2019/07/25 jQuery
连接pandas以及数组转pandas的方法
2019/06/28 Python
python如果快速判断数字奇数偶数
2019/11/13 Python
Pandas时间序列重采样(resample)方法中closed、label的作用详解
2019/12/10 Python
Python搭建HTTP服务过程图解
2019/12/14 Python
购买大码女装:Lane Bryant
2016/09/07 全球购物
全球领先美式家具品牌:Ashley爱室丽家居
2017/08/07 全球购物
在C语言中"指针和数组等价"到底是什么意思?
2014/03/24 面试题
期末自我鉴定
2014/01/23 职场文书
绿化先进工作者事迹材料
2014/01/30 职场文书
会计学专业自荐信
2014/06/25 职场文书
2014年秋季开学典礼致辞
2014/08/02 职场文书
教师节倡议书
2014/08/30 职场文书
教师个人自我剖析材料
2014/09/29 职场文书
2015年度电厂个人工作总结
2015/05/13 职场文书
学生检讨书范文
2019/06/24 职场文书
《原神》新角色演示“神里绫人:林隐泓洄” 宠妹狂魔
2022/04/03 其他游戏