pytorch实现MNIST手写体识别


Posted in Python onFebruary 14, 2020

本文实例为大家分享了pytorch实现MNIST手写体识别的具体代码,供大家参考,具体内容如下

实验环境

pytorch 1.4
Windows 10
python 3.7
cuda 10.1(我笔记本上没有可以使用cuda的显卡)

实验过程

1. 确定我们要加载的库

import torch
import torch.nn as nn
import torchvision #这里面直接加载MNIST数据的方法
import torchvision.transforms as transforms # 将数据转为Tensor
import torch.optim as optim 
import torch.utils.data.dataloader as dataloader

2. 加载数据

这里使用所有数据进行训练,再使用所有数据进行测试

train_set = torchvision.datasets.MNIST(
 root='./data', # 文件存储位置
 train=True,
 transform=transforms.ToTensor(),
 download=True
)

train_dataloader = dataloader.DataLoader(dataset=train_set,shuffle=False,batch_size=100)# dataset可以省

'''
dataloader返回(images,labels)
其中,
images维度:[batch_size,1,28,28]
labels:[batch_size],即图片对应的
'''

test_set = torchvision.datasets.MNIST(
 root='./data',
 train=False,
 transform=transforms.ToTensor(),
 download=True
)

test_dataloader = dataloader.DataLoader(test_set,batch_size=100,shuffle=False) # dataset可以省

3. 定义神经网络模型

这里使用全神经网络作为模型

class NeuralNet(nn.Module):
 def __init__(self,in_num,h_num,out_num):
 super(NeuralNet,self).__init__()
 self.ln1 = nn.Linear(in_num,h_num)
 self.ln2 = nn.Linear(h_num,out_num)
 self.relu = nn.ReLU()
 
 def forward(self,x):
 return self.ln2(self.relu(self.ln1(x)))

4. 模型训练

in_num = 784 # 输入维度
h_num = 500 # 隐藏层维度
out_num = 10 # 输出维度
epochs = 30 # 迭代次数
learning_rate = 0.001
USE_CUDA = torch.cuda.is_available() # 定义是否可以使用cuda

model = NeuralNet(in_num,h_num,out_num) # 初始化模型
optimizer = optim.Adam(model.parameters(),lr=learning_rate) # 使用Adam
loss_fn = nn.CrossEntropyLoss() # 损失函数

for e in range(epochs):
 for i,data in enumerate(train_dataloader):
 (images,labels) = data
 images = images.reshape(-1,28*28) # [batch_size,784]
 if USE_CUDA:
  images = images.cuda() # 使用cuda
  labels = labels.cuda() # 使用cuda
  
 y_pred = model(images) # 预测
 loss = loss_fn(y_pred,labels) # 计算损失
 
 optimizer.zero_grad()
 loss.backward()
 optimizer.step()
 
 n = e * i +1
 if n % 100 == 0:
  print(n,'loss:',loss.item())

训练模型的loss部分截图如下:

pytorch实现MNIST手写体识别

5. 测试模型

with torch.no_grad():
 total = 0
 correct = 0
 for (images,labels) in test_dataloader:
 images = images.reshape(-1,28*28)
 if USE_CUDA:
  images = images.cuda()
  labels = labels.cuda()
  
 result = model(images)
 prediction = torch.max(result, 1)[1] # 这里需要有[1],因为它返回了概率还有标签
 total += labels.size(0)
 correct += (prediction == labels).sum().item()
 
 print("The accuracy of total {} images: {}%".format(total, 100 * correct/total))

实验结果

最终实验的正确率达到:98.22%

pytorch实现MNIST手写体识别

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python获取局域网占带宽最大3个ip的方法
Jul 09 Python
在Django框架中运行Python应用全攻略
Jul 17 Python
Django REST为文件属性输出完整URL的方法
Dec 18 Python
使用python将时间转换为指定的格式方法
Nov 12 Python
python3中的logging记录日志实现过程及封装成类的操作
May 12 Python
在python3.64中安装pyinstaller库的方法步骤
Jun 02 Python
python利用os模块编写文件复制功能——copy()函数用法
Jul 13 Python
Python3爬虫关于代理池的维护详解
Jul 30 Python
Python脚本打包成可执行文件过程解析
Oct 20 Python
用python爬虫批量下载pdf的实现
Dec 01 Python
python wsgiref源码解析
Feb 06 Python
基于PyQT5制作一个桌面摸鱼工具
Feb 15 Python
Python3.7实现验证码登录方式代码实例
Feb 14 #Python
Python逐行读取文件内容的方法总结
Feb 14 #Python
Python3和PyCharm安装与环境配置【图文教程】
Feb 14 #Python
python对Excel的读取的示例代码
Feb 14 #Python
Python安装依赖(包)模块方法详解
Feb 14 #Python
python 项目目录结构设置
Feb 14 #Python
wxpython自定义下拉列表框过程图解
Feb 14 #Python
You might like
用mysql触发器自动更新memcache的实现代码
2009/10/11 PHP
php 求质素(素数) 的实现代码
2011/04/12 PHP
ie和firefox中img对象区别的困惑
2006/12/27 Javascript
仿校内登陆框,精美,给那些很厉害但是没有设计天才的程序员
2008/11/24 Javascript
jQuery实现可收缩展开的级联菜单实例代码
2013/11/27 Javascript
js 高效去除数组重复元素示例代码
2013/12/19 Javascript
javascript实现拖动元素交换位置
2015/11/29 Javascript
整理Javascript基础语法学习笔记
2015/11/29 Javascript
JavaScript的Vue.js库入门学习教程
2016/05/23 Javascript
Vue2.0如何发布项目实战
2017/07/27 Javascript
Angular2整合其他插件的方法
2018/01/20 Javascript
Vue项目报错:Uncaught SyntaxError: Unexpected token
2018/11/10 Javascript
详解CommonJS和ES6模块循环加载处理的区别
2018/12/26 Javascript
小程序实现搜索框功能
2020/03/26 Javascript
微信小程序导入Vant报错VM292:1 thirdScriptError的解决方法
2019/08/01 Javascript
不依任何赖第三方,单纯用vue实现Tree 树形控件的案例
2020/09/21 Javascript
[01:57]2018年度DOTA2最具潜力解说-完美盛典
2018/12/16 DOTA
python实现udp数据报传输的方法
2014/09/26 Python
Python变量作用范围实例分析
2015/07/07 Python
将Django框架和遗留的Web应用集成的方法
2015/07/24 Python
pycharm debug功能实现跳到循环末尾的方法
2018/11/29 Python
Python with标签使用方法解析
2020/01/17 Python
Python实现Wordcloud生成词云图的示例
2020/03/30 Python
30行Python代码实现高分辨率图像导航的方法
2020/05/22 Python
html5使用window.postMessage进行跨域实现数据交互的一次实战
2021/02/24 HTML / CSS
南京软件公司的.net程序员笔试题
2014/08/31 面试题
怎样写好自我评价呢?
2014/02/16 职场文书
北京奥运会口号
2014/06/21 职场文书
食堂厨师岗位职责
2014/08/25 职场文书
成绩单评语
2015/01/04 职场文书
新员工辞职信范文
2015/05/12 职场文书
会计试用期工作总结2015
2015/05/28 职场文书
python如何获取网络数据
2021/04/11 Python
Mysql官方性能测试工具mysqlslap的使用简介
2021/05/21 MySQL
MySql子查询IN的执行和优化的实现
2021/08/02 MySQL
Oracle 触发器trigger使用案例
2022/02/24 Oracle