pytorch实现MNIST手写体识别


Posted in Python onFebruary 14, 2020

本文实例为大家分享了pytorch实现MNIST手写体识别的具体代码,供大家参考,具体内容如下

实验环境

pytorch 1.4
Windows 10
python 3.7
cuda 10.1(我笔记本上没有可以使用cuda的显卡)

实验过程

1. 确定我们要加载的库

import torch
import torch.nn as nn
import torchvision #这里面直接加载MNIST数据的方法
import torchvision.transforms as transforms # 将数据转为Tensor
import torch.optim as optim 
import torch.utils.data.dataloader as dataloader

2. 加载数据

这里使用所有数据进行训练,再使用所有数据进行测试

train_set = torchvision.datasets.MNIST(
 root='./data', # 文件存储位置
 train=True,
 transform=transforms.ToTensor(),
 download=True
)

train_dataloader = dataloader.DataLoader(dataset=train_set,shuffle=False,batch_size=100)# dataset可以省

'''
dataloader返回(images,labels)
其中,
images维度:[batch_size,1,28,28]
labels:[batch_size],即图片对应的
'''

test_set = torchvision.datasets.MNIST(
 root='./data',
 train=False,
 transform=transforms.ToTensor(),
 download=True
)

test_dataloader = dataloader.DataLoader(test_set,batch_size=100,shuffle=False) # dataset可以省

3. 定义神经网络模型

这里使用全神经网络作为模型

class NeuralNet(nn.Module):
 def __init__(self,in_num,h_num,out_num):
 super(NeuralNet,self).__init__()
 self.ln1 = nn.Linear(in_num,h_num)
 self.ln2 = nn.Linear(h_num,out_num)
 self.relu = nn.ReLU()
 
 def forward(self,x):
 return self.ln2(self.relu(self.ln1(x)))

4. 模型训练

in_num = 784 # 输入维度
h_num = 500 # 隐藏层维度
out_num = 10 # 输出维度
epochs = 30 # 迭代次数
learning_rate = 0.001
USE_CUDA = torch.cuda.is_available() # 定义是否可以使用cuda

model = NeuralNet(in_num,h_num,out_num) # 初始化模型
optimizer = optim.Adam(model.parameters(),lr=learning_rate) # 使用Adam
loss_fn = nn.CrossEntropyLoss() # 损失函数

for e in range(epochs):
 for i,data in enumerate(train_dataloader):
 (images,labels) = data
 images = images.reshape(-1,28*28) # [batch_size,784]
 if USE_CUDA:
  images = images.cuda() # 使用cuda
  labels = labels.cuda() # 使用cuda
  
 y_pred = model(images) # 预测
 loss = loss_fn(y_pred,labels) # 计算损失
 
 optimizer.zero_grad()
 loss.backward()
 optimizer.step()
 
 n = e * i +1
 if n % 100 == 0:
  print(n,'loss:',loss.item())

训练模型的loss部分截图如下:

pytorch实现MNIST手写体识别

5. 测试模型

with torch.no_grad():
 total = 0
 correct = 0
 for (images,labels) in test_dataloader:
 images = images.reshape(-1,28*28)
 if USE_CUDA:
  images = images.cuda()
  labels = labels.cuda()
  
 result = model(images)
 prediction = torch.max(result, 1)[1] # 这里需要有[1],因为它返回了概率还有标签
 total += labels.size(0)
 correct += (prediction == labels).sum().item()
 
 print("The accuracy of total {} images: {}%".format(total, 100 * correct/total))

实验结果

最终实验的正确率达到:98.22%

pytorch实现MNIST手写体识别

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
深入理解Javascript中的this关键字
Mar 27 Python
使用Scrapy爬取动态数据
Oct 21 Python
浅析python的优势和不足之处
Nov 20 Python
人工神经网络算法知识点总结
Jun 11 Python
pandas中的series数据类型详解
Jul 06 Python
关于PyTorch源码解读之torchvision.models
Aug 17 Python
Python解析json代码实例解析
Nov 25 Python
TensorFlow:将ckpt文件固化成pb文件教程
Feb 11 Python
详解Selenium-webdriver绕开反爬虫机制的4种方法
Oct 28 Python
Django REST Framework 分页(Pagination)详解
Nov 30 Python
Python实现Kerberos用户的增删改查操作
Dec 14 Python
python分分钟绘制精美地图海报
Feb 15 Python
Python3.7实现验证码登录方式代码实例
Feb 14 #Python
Python逐行读取文件内容的方法总结
Feb 14 #Python
Python3和PyCharm安装与环境配置【图文教程】
Feb 14 #Python
python对Excel的读取的示例代码
Feb 14 #Python
Python安装依赖(包)模块方法详解
Feb 14 #Python
python 项目目录结构设置
Feb 14 #Python
wxpython自定义下拉列表框过程图解
Feb 14 #Python
You might like
php序列化函数serialize() 和 unserialize() 与原生函数对比
2015/05/08 PHP
PHP浮点数精度问题汇总
2015/05/13 PHP
功能强大的PHP POST提交数据类
2016/07/15 PHP
laravel 实现向公共模板中传值 (view composer)
2019/10/22 PHP
脚本吧 - 幻宇工作室用到js,超强推荐share.js
2006/12/23 Javascript
javascript  Error 对象 错误处理
2008/05/18 Javascript
javascript 客户端验证上传图片的大小(兼容IE和火狐)
2009/08/15 Javascript
轻松创建nodejs服务器(4):路由
2014/12/18 NodeJs
jquery 插件实现瀑布流图片展示实例
2015/04/03 Javascript
js实现同一个页面多个渐变效果的方法
2015/04/10 Javascript
JavaScript 实现完美兼容多浏览器的复制功能代码
2015/04/28 Javascript
由简入繁实现Jquery树状结构的方法(推荐)
2016/06/10 Javascript
原生js封装的一些jquery方法(详解)
2016/09/20 Javascript
如何在Angular2中使用jQuery及其插件的方法
2017/02/09 Javascript
Angularjs添加排序查询功能的实例代码
2017/10/24 Javascript
webstorm中vue语法的支持详解
2018/05/09 Javascript
JS中DOM元素的attribute与property属性示例详解
2018/09/04 Javascript
JavaScript onclick事件使用方法详解
2020/05/15 Javascript
浅谈JavaScript中this的指向更改
2020/07/28 Javascript
Python实现的ini文件操作类分享
2014/11/20 Python
在Django的上下文中设置变量的方法
2015/07/20 Python
python直接访问私有属性的简单方法
2016/07/25 Python
Python中字符串的处理技巧分享
2016/09/17 Python
Python中将dataframe转换为字典的实例
2018/04/13 Python
python Kmeans算法原理深入解析
2019/08/23 Python
Python 经典算法100及解析(小结)
2019/09/13 Python
python hashlib加密实现代码
2019/10/17 Python
numpy 返回函数的上三角矩阵实例
2019/11/25 Python
python 在右键菜单中加入复制目标文件的有效存放路径(单斜杠或者双反斜杠)
2020/04/08 Python
scrapy在python爬虫中搭建出错的解决方法
2020/11/22 Python
HTML5实现移动端点击翻牌功能
2020/10/23 HTML / CSS
毕业生个人投资创业计划书
2014/01/04 职场文书
员工工作及收入证明
2014/10/28 职场文书
员工工作心得体会
2019/05/07 职场文书
MySQL之DML语言
2021/04/05 MySQL
Python道路车道线检测的实现
2021/06/27 Python