pytorch实现MNIST手写体识别


Posted in Python onFebruary 14, 2020

本文实例为大家分享了pytorch实现MNIST手写体识别的具体代码,供大家参考,具体内容如下

实验环境

pytorch 1.4
Windows 10
python 3.7
cuda 10.1(我笔记本上没有可以使用cuda的显卡)

实验过程

1. 确定我们要加载的库

import torch
import torch.nn as nn
import torchvision #这里面直接加载MNIST数据的方法
import torchvision.transforms as transforms # 将数据转为Tensor
import torch.optim as optim 
import torch.utils.data.dataloader as dataloader

2. 加载数据

这里使用所有数据进行训练,再使用所有数据进行测试

train_set = torchvision.datasets.MNIST(
 root='./data', # 文件存储位置
 train=True,
 transform=transforms.ToTensor(),
 download=True
)

train_dataloader = dataloader.DataLoader(dataset=train_set,shuffle=False,batch_size=100)# dataset可以省

'''
dataloader返回(images,labels)
其中,
images维度:[batch_size,1,28,28]
labels:[batch_size],即图片对应的
'''

test_set = torchvision.datasets.MNIST(
 root='./data',
 train=False,
 transform=transforms.ToTensor(),
 download=True
)

test_dataloader = dataloader.DataLoader(test_set,batch_size=100,shuffle=False) # dataset可以省

3. 定义神经网络模型

这里使用全神经网络作为模型

class NeuralNet(nn.Module):
 def __init__(self,in_num,h_num,out_num):
 super(NeuralNet,self).__init__()
 self.ln1 = nn.Linear(in_num,h_num)
 self.ln2 = nn.Linear(h_num,out_num)
 self.relu = nn.ReLU()
 
 def forward(self,x):
 return self.ln2(self.relu(self.ln1(x)))

4. 模型训练

in_num = 784 # 输入维度
h_num = 500 # 隐藏层维度
out_num = 10 # 输出维度
epochs = 30 # 迭代次数
learning_rate = 0.001
USE_CUDA = torch.cuda.is_available() # 定义是否可以使用cuda

model = NeuralNet(in_num,h_num,out_num) # 初始化模型
optimizer = optim.Adam(model.parameters(),lr=learning_rate) # 使用Adam
loss_fn = nn.CrossEntropyLoss() # 损失函数

for e in range(epochs):
 for i,data in enumerate(train_dataloader):
 (images,labels) = data
 images = images.reshape(-1,28*28) # [batch_size,784]
 if USE_CUDA:
  images = images.cuda() # 使用cuda
  labels = labels.cuda() # 使用cuda
  
 y_pred = model(images) # 预测
 loss = loss_fn(y_pred,labels) # 计算损失
 
 optimizer.zero_grad()
 loss.backward()
 optimizer.step()
 
 n = e * i +1
 if n % 100 == 0:
  print(n,'loss:',loss.item())

训练模型的loss部分截图如下:

pytorch实现MNIST手写体识别

5. 测试模型

with torch.no_grad():
 total = 0
 correct = 0
 for (images,labels) in test_dataloader:
 images = images.reshape(-1,28*28)
 if USE_CUDA:
  images = images.cuda()
  labels = labels.cuda()
  
 result = model(images)
 prediction = torch.max(result, 1)[1] # 这里需要有[1],因为它返回了概率还有标签
 total += labels.size(0)
 correct += (prediction == labels).sum().item()
 
 print("The accuracy of total {} images: {}%".format(total, 100 * correct/total))

实验结果

最终实验的正确率达到:98.22%

pytorch实现MNIST手写体识别

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中操作文件之write()方法的使用教程
May 25 Python
Python实现字典的key和values的交换
Aug 04 Python
python实现将内容分行输出
Nov 05 Python
python编程开发之类型转换convert实例分析
Nov 13 Python
浅谈python中的数字类型与处理工具
Aug 02 Python
python多线程之事件Event的使用详解
Apr 27 Python
详解Python3.6的py文件打包生成exe
Jul 13 Python
Python用61行代码实现图片像素化的示例代码
Dec 10 Python
python字符串替换re.sub()实例解析
Feb 09 Python
利用Python如何实时检测自身内存占用
May 09 Python
tensorflow模型的save与restore,及checkpoint中读取变量方式
May 26 Python
python 实现百度网盘非会员上传超过500个文件的方法
Jan 07 Python
Python3.7实现验证码登录方式代码实例
Feb 14 #Python
Python逐行读取文件内容的方法总结
Feb 14 #Python
Python3和PyCharm安装与环境配置【图文教程】
Feb 14 #Python
python对Excel的读取的示例代码
Feb 14 #Python
Python安装依赖(包)模块方法详解
Feb 14 #Python
python 项目目录结构设置
Feb 14 #Python
wxpython自定义下拉列表框过程图解
Feb 14 #Python
You might like
鸡肋的PHP单例模式应用详解
2013/06/03 PHP
详解PHP内置访问资源的超时时间 time_out file_get_contents read_file
2013/06/03 PHP
phpmailer简单发送邮件的方法(附phpmailer源码下载)
2016/06/13 PHP
Yii框架实现记录日志到自定义文件的方法
2017/05/23 PHP
PHP机器学习库php-ml的简单测试和使用方法
2017/07/14 PHP
php实现socket推送技术的示例
2017/12/20 PHP
PHP自定义函数实现assign()数组分配到模板及extract()变量分配到模板功能示例
2018/05/23 PHP
js获取字符串最后一位方法汇总
2014/11/13 Javascript
动态加载js、css的实例代码
2016/05/26 Javascript
JS转换HTML转义符的方法
2016/08/24 Javascript
bootstrap suggest搜索建议插件使用详解
2017/03/25 Javascript
关于Promise 异步编程的实例讲解
2017/09/01 Javascript
Angular4 ElementRef的应用
2018/02/26 Javascript
jQuery实现浏览器之间跳转并传递参数功能【支持中文字符】
2018/03/28 jQuery
浅谈vue-cli 3.0.x 初体验
2018/04/11 Javascript
解决jQuery使用append添加的元素事件无效的问题
2018/08/30 jQuery
jQuery中each和js中forEach的区别分析
2019/02/27 jQuery
了解在JavaScript中将值转换为字符串的5种方法
2019/06/06 Javascript
微信小程序连接服务器展示MQTT数据信息的实现
2020/07/14 Javascript
[01:07:19]2018DOTA2亚洲邀请赛 4.5 淘汰赛 Mineski vs VG 第一场
2018/04/06 DOTA
在RedHat系Linux上部署Python的Celery框架的教程
2015/04/07 Python
python spyder中读取txt为图片的方法
2018/04/27 Python
python3实现猜数字游戏
2020/12/07 Python
Django中使用极验Geetest滑动验证码过程解析
2019/07/31 Python
Python 3 判断2个字典相同
2019/08/06 Python
Python爬虫:将headers请求头字符串转为字典的方法
2019/08/21 Python
Python编程快速上手——Excel表格创建乘法表案例分析
2020/02/28 Python
日本钓鱼渔具和户外用品网上商店:naturum
2016/08/07 全球购物
澳大利亚冒险体验:Adrenaline(跳伞、V8赛车、热气球等)
2017/09/18 全球购物
计算机毕业大学生推荐信
2013/12/01 职场文书
物业保洁员管理制度
2015/08/05 职场文书
小学体育教学随笔
2015/08/14 职场文书
2016庆祝国庆67周年宣传语
2015/11/25 职场文书
Python机器学习之基础概述
2021/05/19 Python
Python数据可视化之Seaborn的安装及使用
2022/04/19 Python
详解ZABBIX监控ESXI主机的问题
2022/06/21 Servers