pytorch实现MNIST手写体识别


Posted in Python onFebruary 14, 2020

本文实例为大家分享了pytorch实现MNIST手写体识别的具体代码,供大家参考,具体内容如下

实验环境

pytorch 1.4
Windows 10
python 3.7
cuda 10.1(我笔记本上没有可以使用cuda的显卡)

实验过程

1. 确定我们要加载的库

import torch
import torch.nn as nn
import torchvision #这里面直接加载MNIST数据的方法
import torchvision.transforms as transforms # 将数据转为Tensor
import torch.optim as optim 
import torch.utils.data.dataloader as dataloader

2. 加载数据

这里使用所有数据进行训练,再使用所有数据进行测试

train_set = torchvision.datasets.MNIST(
 root='./data', # 文件存储位置
 train=True,
 transform=transforms.ToTensor(),
 download=True
)

train_dataloader = dataloader.DataLoader(dataset=train_set,shuffle=False,batch_size=100)# dataset可以省

'''
dataloader返回(images,labels)
其中,
images维度:[batch_size,1,28,28]
labels:[batch_size],即图片对应的
'''

test_set = torchvision.datasets.MNIST(
 root='./data',
 train=False,
 transform=transforms.ToTensor(),
 download=True
)

test_dataloader = dataloader.DataLoader(test_set,batch_size=100,shuffle=False) # dataset可以省

3. 定义神经网络模型

这里使用全神经网络作为模型

class NeuralNet(nn.Module):
 def __init__(self,in_num,h_num,out_num):
 super(NeuralNet,self).__init__()
 self.ln1 = nn.Linear(in_num,h_num)
 self.ln2 = nn.Linear(h_num,out_num)
 self.relu = nn.ReLU()
 
 def forward(self,x):
 return self.ln2(self.relu(self.ln1(x)))

4. 模型训练

in_num = 784 # 输入维度
h_num = 500 # 隐藏层维度
out_num = 10 # 输出维度
epochs = 30 # 迭代次数
learning_rate = 0.001
USE_CUDA = torch.cuda.is_available() # 定义是否可以使用cuda

model = NeuralNet(in_num,h_num,out_num) # 初始化模型
optimizer = optim.Adam(model.parameters(),lr=learning_rate) # 使用Adam
loss_fn = nn.CrossEntropyLoss() # 损失函数

for e in range(epochs):
 for i,data in enumerate(train_dataloader):
 (images,labels) = data
 images = images.reshape(-1,28*28) # [batch_size,784]
 if USE_CUDA:
  images = images.cuda() # 使用cuda
  labels = labels.cuda() # 使用cuda
  
 y_pred = model(images) # 预测
 loss = loss_fn(y_pred,labels) # 计算损失
 
 optimizer.zero_grad()
 loss.backward()
 optimizer.step()
 
 n = e * i +1
 if n % 100 == 0:
  print(n,'loss:',loss.item())

训练模型的loss部分截图如下:

pytorch实现MNIST手写体识别

5. 测试模型

with torch.no_grad():
 total = 0
 correct = 0
 for (images,labels) in test_dataloader:
 images = images.reshape(-1,28*28)
 if USE_CUDA:
  images = images.cuda()
  labels = labels.cuda()
  
 result = model(images)
 prediction = torch.max(result, 1)[1] # 这里需要有[1],因为它返回了概率还有标签
 total += labels.size(0)
 correct += (prediction == labels).sum().item()
 
 print("The accuracy of total {} images: {}%".format(total, 100 * correct/total))

实验结果

最终实验的正确率达到:98.22%

pytorch实现MNIST手写体识别

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python爬虫模拟登录带验证码网站
Jan 22 Python
Python+Opencv识别两张相似图片
Mar 23 Python
Python网络爬虫与信息提取(实例讲解)
Aug 29 Python
wxpython实现图书管理系统
Mar 12 Python
Python爬虫小技巧之伪造随机的User-Agent
Sep 13 Python
pandas筛选某列出现编码错误的解决方法
Nov 07 Python
selenium在执行phantomjs的API并获取执行结果的方法
Dec 17 Python
Python之循环结构
Jan 15 Python
PyTorch笔记之scatter()函数的使用
Feb 12 Python
如何用Python提取10000份log中的产品信息
Jan 14 Python
python中time包实例详解
Feb 02 Python
用Python selenium实现淘宝抢单机器人
Jun 18 Python
Python3.7实现验证码登录方式代码实例
Feb 14 #Python
Python逐行读取文件内容的方法总结
Feb 14 #Python
Python3和PyCharm安装与环境配置【图文教程】
Feb 14 #Python
python对Excel的读取的示例代码
Feb 14 #Python
Python安装依赖(包)模块方法详解
Feb 14 #Python
python 项目目录结构设置
Feb 14 #Python
wxpython自定义下拉列表框过程图解
Feb 14 #Python
You might like
星际流派综述
2020/03/04 星际争霸
php MySQL与分页效率
2008/06/04 PHP
wiki-shan写的php在线加密的解密程序
2008/09/07 PHP
PHP之图片上传类实例代码(加了缩略图)
2016/06/30 PHP
php实现通过soap调用.Net的WebService asmx文件
2017/02/27 PHP
搜索附近的人PHP实现代码
2018/02/11 PHP
PHP结合jquery ajax实现上传多张图片,并限制图片大小操作示例
2019/03/01 PHP
ASP Json Parser修正版
2009/12/06 Javascript
jquery 实现checkbox全选,反选,全不选等功能代码(奇数)
2012/10/24 Javascript
javascript setTimeout和setInterval计时的区别详解
2013/06/21 Javascript
jQuery制作可自定义大小的拼图游戏
2015/03/30 Javascript
jQuery无刷新上传之uploadify3.1简单使用
2016/06/18 Javascript
js实现日历的简单算法
2017/01/24 Javascript
Angularjs中的ui-bootstrap的使用教程
2017/02/19 Javascript
使用Javascript简单计算器
2018/11/17 Javascript
js验证账户名是否重复
2020/05/26 Javascript
JavaScript中window和document用法详解
2020/07/28 Javascript
详解Vue的组件中data选项为什么必须是函数
2020/08/17 Javascript
微信小程序基于高德地图API实现天气组件(动态效果)
2020/10/22 Javascript
JavaScript前后端JSON使用方法教程
2020/11/23 Javascript
vue实现图片裁剪后上传
2020/12/16 Vue.js
JavaScript中跨域问题的深入理解
2021/03/04 Javascript
[01:08:43]DOTA2-DPC中国联赛定级赛 Phoenix vs DLG BO3第一场 1月9日
2021/03/11 DOTA
python paramiko模块学习分享
2017/08/23 Python
python实现自动发送邮件发送多人、群发、多附件的示例
2018/01/23 Python
Python实现爬虫抓取与读写、追加到excel文件操作示例
2018/06/27 Python
在python中使用with打开多个文件的方法
2019/01/07 Python
对Django项目中的ORM映射与模糊查询的使用详解
2019/07/18 Python
Tensorflow实现神经网络拟合线性回归
2019/07/19 Python
Python代码一键转Jar包及Java调用Python新姿势
2020/03/10 Python
介绍CSS3使用技巧5个
2009/04/02 HTML / CSS
国际旅客访问北美最大的汽车租赁提供商:Alamo Rent A Car
2018/06/13 全球购物
学习考察心得体会
2014/09/04 职场文书
2016小学新学期寄语
2015/12/04 职场文书
该怎么书写道歉信?
2019/07/03 职场文书
python神经网络学习 使用Keras进行简单分类
2022/05/04 Python