pytorch实现MNIST手写体识别


Posted in Python onFebruary 14, 2020

本文实例为大家分享了pytorch实现MNIST手写体识别的具体代码,供大家参考,具体内容如下

实验环境

pytorch 1.4
Windows 10
python 3.7
cuda 10.1(我笔记本上没有可以使用cuda的显卡)

实验过程

1. 确定我们要加载的库

import torch
import torch.nn as nn
import torchvision #这里面直接加载MNIST数据的方法
import torchvision.transforms as transforms # 将数据转为Tensor
import torch.optim as optim 
import torch.utils.data.dataloader as dataloader

2. 加载数据

这里使用所有数据进行训练,再使用所有数据进行测试

train_set = torchvision.datasets.MNIST(
 root='./data', # 文件存储位置
 train=True,
 transform=transforms.ToTensor(),
 download=True
)

train_dataloader = dataloader.DataLoader(dataset=train_set,shuffle=False,batch_size=100)# dataset可以省

'''
dataloader返回(images,labels)
其中,
images维度:[batch_size,1,28,28]
labels:[batch_size],即图片对应的
'''

test_set = torchvision.datasets.MNIST(
 root='./data',
 train=False,
 transform=transforms.ToTensor(),
 download=True
)

test_dataloader = dataloader.DataLoader(test_set,batch_size=100,shuffle=False) # dataset可以省

3. 定义神经网络模型

这里使用全神经网络作为模型

class NeuralNet(nn.Module):
 def __init__(self,in_num,h_num,out_num):
 super(NeuralNet,self).__init__()
 self.ln1 = nn.Linear(in_num,h_num)
 self.ln2 = nn.Linear(h_num,out_num)
 self.relu = nn.ReLU()
 
 def forward(self,x):
 return self.ln2(self.relu(self.ln1(x)))

4. 模型训练

in_num = 784 # 输入维度
h_num = 500 # 隐藏层维度
out_num = 10 # 输出维度
epochs = 30 # 迭代次数
learning_rate = 0.001
USE_CUDA = torch.cuda.is_available() # 定义是否可以使用cuda

model = NeuralNet(in_num,h_num,out_num) # 初始化模型
optimizer = optim.Adam(model.parameters(),lr=learning_rate) # 使用Adam
loss_fn = nn.CrossEntropyLoss() # 损失函数

for e in range(epochs):
 for i,data in enumerate(train_dataloader):
 (images,labels) = data
 images = images.reshape(-1,28*28) # [batch_size,784]
 if USE_CUDA:
  images = images.cuda() # 使用cuda
  labels = labels.cuda() # 使用cuda
  
 y_pred = model(images) # 预测
 loss = loss_fn(y_pred,labels) # 计算损失
 
 optimizer.zero_grad()
 loss.backward()
 optimizer.step()
 
 n = e * i +1
 if n % 100 == 0:
  print(n,'loss:',loss.item())

训练模型的loss部分截图如下:

pytorch实现MNIST手写体识别

5. 测试模型

with torch.no_grad():
 total = 0
 correct = 0
 for (images,labels) in test_dataloader:
 images = images.reshape(-1,28*28)
 if USE_CUDA:
  images = images.cuda()
  labels = labels.cuda()
  
 result = model(images)
 prediction = torch.max(result, 1)[1] # 这里需要有[1],因为它返回了概率还有标签
 total += labels.size(0)
 correct += (prediction == labels).sum().item()
 
 print("The accuracy of total {} images: {}%".format(total, 100 * correct/total))

实验结果

最终实验的正确率达到:98.22%

pytorch实现MNIST手写体识别

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现脚本锁功能(同时只能执行一个脚本)
May 10 Python
对python中list的拷贝与numpy的array的拷贝详解
Jan 29 Python
关于不懂Chromedriver如何配置环境变量问题解决方法
Jun 12 Python
matplotlib命令与格式之tick坐标轴日期格式(设置日期主副刻度)
Aug 06 Python
对python中UDP,socket的使用详解
Aug 22 Python
Python英文文章词频统计(14份剑桥真题词频统计)
Oct 13 Python
Python使用QQ邮箱发送邮件报错smtplib.SMTPAuthenticationError
Dec 20 Python
jupyter notebook中新建cell的方法与快捷键操作
Apr 22 Python
Python xlrd模块导入过程及常用操作
Jun 10 Python
sklearn中的交叉验证的实现(Cross-Validation)
Feb 22 Python
Python&Matlab实现灰狼优化算法的示例代码
Mar 21 Python
Python简易开发之制作计算器
Apr 28 Python
Python3.7实现验证码登录方式代码实例
Feb 14 #Python
Python逐行读取文件内容的方法总结
Feb 14 #Python
Python3和PyCharm安装与环境配置【图文教程】
Feb 14 #Python
python对Excel的读取的示例代码
Feb 14 #Python
Python安装依赖(包)模块方法详解
Feb 14 #Python
python 项目目录结构设置
Feb 14 #Python
wxpython自定义下拉列表框过程图解
Feb 14 #Python
You might like
PHP脚本的10个技巧(2)
2006/10/09 PHP
windows下配置apache+php+mysql时出现问题的处理方法
2014/06/20 PHP
php实现批量压缩图片文件大小的脚本
2014/07/04 PHP
php获取数据库结果集方法(推荐)
2017/06/01 PHP
PHP扩展mcrypt实现的AES加密功能示例
2019/01/29 PHP
Js 本页面传值实现代码
2009/05/17 Javascript
关于JavaScript的一些看法
2009/05/27 Javascript
Jquery绑定事件(bind和live的区别介绍)
2013/08/23 Javascript
jquery和css3实现的炫酷时尚的菜单导航
2014/09/01 Javascript
JavaScript实现带标题的图片轮播特效
2015/05/20 Javascript
Underscore源码分析
2015/12/30 Javascript
一个仿微博登陆邮箱提示框js开发案例
2016/07/28 Javascript
微信小程序 刷新上拉下拉不会断详细介绍
2017/05/11 Javascript
JS中Safari浏览器中的Date
2017/07/17 Javascript
探索Vue高阶组件的使用
2018/01/08 Javascript
jQuery实现鼠标移到某个对象时弹出显示层功能
2018/08/23 jQuery
JavaScript读写二进制数据的方法详解
2018/09/09 Javascript
Vue项目中使用jsonp抓取跨域数据的方法
2019/11/10 Javascript
nodejs开发一个最简单的web服务器实例讲解
2020/01/02 NodeJs
Python自动化测试Eclipse+Pydev 搭建开发环境
2016/08/15 Python
python3使用pyqt5制作一个超简单浏览器的实例
2017/10/19 Python
解决python大批量读写.doc文件的问题
2018/05/08 Python
解决python "No module named pip" 的问题
2018/10/13 Python
Python实现八皇后问题示例代码
2018/12/09 Python
tensorflow 实现自定义layer并添加到计算图中
2020/02/04 Python
Python xlrd/xlwt 创建excel文件及常用操作
2020/09/24 Python
Python读写Excel表格的方法
2021/03/02 Python
全球采购的街头服饰和帽子:Urban Excess
2020/10/28 全球购物
请解释接口的显式实现有什么意义
2012/05/26 面试题
27个经典Linux面试题及答案,你知道几个?
2013/01/10 面试题
企业趣味活动方案
2014/08/21 职场文书
假释思想汇报范文
2014/10/11 职场文书
幼儿园教师个人总结
2015/02/05 职场文书
开学典礼观后感
2015/06/15 职场文书
CSS的class与id常用的命名规则
2021/05/18 HTML / CSS
vue2实现provide inject传递响应式
2021/05/21 Vue.js