pytorch实现MNIST手写体识别


Posted in Python onFebruary 14, 2020

本文实例为大家分享了pytorch实现MNIST手写体识别的具体代码,供大家参考,具体内容如下

实验环境

pytorch 1.4
Windows 10
python 3.7
cuda 10.1(我笔记本上没有可以使用cuda的显卡)

实验过程

1. 确定我们要加载的库

import torch
import torch.nn as nn
import torchvision #这里面直接加载MNIST数据的方法
import torchvision.transforms as transforms # 将数据转为Tensor
import torch.optim as optim 
import torch.utils.data.dataloader as dataloader

2. 加载数据

这里使用所有数据进行训练,再使用所有数据进行测试

train_set = torchvision.datasets.MNIST(
 root='./data', # 文件存储位置
 train=True,
 transform=transforms.ToTensor(),
 download=True
)

train_dataloader = dataloader.DataLoader(dataset=train_set,shuffle=False,batch_size=100)# dataset可以省

'''
dataloader返回(images,labels)
其中,
images维度:[batch_size,1,28,28]
labels:[batch_size],即图片对应的
'''

test_set = torchvision.datasets.MNIST(
 root='./data',
 train=False,
 transform=transforms.ToTensor(),
 download=True
)

test_dataloader = dataloader.DataLoader(test_set,batch_size=100,shuffle=False) # dataset可以省

3. 定义神经网络模型

这里使用全神经网络作为模型

class NeuralNet(nn.Module):
 def __init__(self,in_num,h_num,out_num):
 super(NeuralNet,self).__init__()
 self.ln1 = nn.Linear(in_num,h_num)
 self.ln2 = nn.Linear(h_num,out_num)
 self.relu = nn.ReLU()
 
 def forward(self,x):
 return self.ln2(self.relu(self.ln1(x)))

4. 模型训练

in_num = 784 # 输入维度
h_num = 500 # 隐藏层维度
out_num = 10 # 输出维度
epochs = 30 # 迭代次数
learning_rate = 0.001
USE_CUDA = torch.cuda.is_available() # 定义是否可以使用cuda

model = NeuralNet(in_num,h_num,out_num) # 初始化模型
optimizer = optim.Adam(model.parameters(),lr=learning_rate) # 使用Adam
loss_fn = nn.CrossEntropyLoss() # 损失函数

for e in range(epochs):
 for i,data in enumerate(train_dataloader):
 (images,labels) = data
 images = images.reshape(-1,28*28) # [batch_size,784]
 if USE_CUDA:
  images = images.cuda() # 使用cuda
  labels = labels.cuda() # 使用cuda
  
 y_pred = model(images) # 预测
 loss = loss_fn(y_pred,labels) # 计算损失
 
 optimizer.zero_grad()
 loss.backward()
 optimizer.step()
 
 n = e * i +1
 if n % 100 == 0:
  print(n,'loss:',loss.item())

训练模型的loss部分截图如下:

pytorch实现MNIST手写体识别

5. 测试模型

with torch.no_grad():
 total = 0
 correct = 0
 for (images,labels) in test_dataloader:
 images = images.reshape(-1,28*28)
 if USE_CUDA:
  images = images.cuda()
  labels = labels.cuda()
  
 result = model(images)
 prediction = torch.max(result, 1)[1] # 这里需要有[1],因为它返回了概率还有标签
 total += labels.size(0)
 correct += (prediction == labels).sum().item()
 
 print("The accuracy of total {} images: {}%".format(total, 100 * correct/total))

实验结果

最终实验的正确率达到:98.22%

pytorch实现MNIST手写体识别

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用MONGODB入门实例
May 11 Python
python中文分词,使用结巴分词对python进行分词(实例讲解)
Nov 14 Python
Python模块文件结构代码详解
Feb 03 Python
详解python之协程gevent模块
Jun 14 Python
快速排序的四种python实现(推荐)
Apr 03 Python
pytorch 固定部分参数训练的方法
Aug 17 Python
Python中turtle库的使用实例
Sep 09 Python
Python3和PyCharm安装与环境配置【图文教程】
Feb 14 Python
Pytest框架之fixture的详细使用教程
Apr 07 Python
增大python字体的方法步骤
Jul 05 Python
Python selenium爬取微信公众号文章代码详解
Aug 12 Python
Python OpenCV快速入门教程
Apr 17 Python
Python3.7实现验证码登录方式代码实例
Feb 14 #Python
Python逐行读取文件内容的方法总结
Feb 14 #Python
Python3和PyCharm安装与环境配置【图文教程】
Feb 14 #Python
python对Excel的读取的示例代码
Feb 14 #Python
Python安装依赖(包)模块方法详解
Feb 14 #Python
python 项目目录结构设置
Feb 14 #Python
wxpython自定义下拉列表框过程图解
Feb 14 #Python
You might like
两种php去除二维数组的重复项方法
2015/11/04 PHP
PHP如何实现订单的延时处理详解
2017/12/30 PHP
php设计模式之享元模式分析【星际争霸游戏案例】
2020/03/23 PHP
20款效果非常棒的 jQuery 插件小结分享
2011/11/18 Javascript
JavaScript判断DOM何时加载完毕的技巧
2012/11/11 Javascript
jquery ajax 简单范例(界面+后台)
2013/11/19 Javascript
json属性名为什么要双引号(个人猜测)
2014/07/31 Javascript
js怎么覆盖原有方法实现重写
2014/09/04 Javascript
实用框架(iframe)操作代码
2014/10/23 Javascript
通过sails和阿里大于实现短信验证
2017/01/04 Javascript
javascript判断回文数详解及实现代码
2017/02/03 Javascript
JS闭包可被利用的常见场景小结
2017/04/09 Javascript
vue2 mint-ui loadmore实现下拉刷新,上拉更多功能
2018/03/21 Javascript
Webpack中雪碧图插件使用详解
2018/05/25 Javascript
在 vue-cli v3.0 中使用 SCSS/SASS的方法
2018/06/14 Javascript
vue2.0项目实现路由跳转的方法详解
2018/06/21 Javascript
JS使用canvas中的measureText方法测量字体宽度示例
2019/02/02 Javascript
JS中的算法与数据结构之集合(Set)实例详解
2019/08/20 Javascript
[51:26]DOTA2上海特级锦标赛主赛事日 - 2 胜者组第一轮#3Secret VS OG第二局
2016/03/03 DOTA
用Python实现一个简单的能够发送带附件的邮件程序的教程
2015/04/08 Python
python中字符串类型json操作的注意事项
2017/05/02 Python
对Python中9种生成新对象的方法总结
2018/05/23 Python
python验证码识别教程之利用滴水算法分割图片
2018/06/05 Python
Python @property使用方法解析
2019/09/17 Python
Python continue语句实例用法
2020/02/06 Python
最小二乘法及其python实现详解
2020/02/24 Python
浅谈keras中的batch_dot,dot方法和TensorFlow的matmul
2020/06/18 Python
python链表类中获取元素实例方法
2021/02/23 Python
JavaScript获取当前url根目录(路径)
2014/02/19 面试题
公司庆典活动邀请函
2014/01/09 职场文书
电子信息专业自荐书
2014/02/04 职场文书
个人简历求职信范文
2015/03/20 职场文书
2015年度保密工作总结
2015/04/24 职场文书
团日活动总结格式
2015/05/11 职场文书
民事代理词范文
2015/05/25 职场文书
Android开发之底部导航栏的快速实现
2022/04/28 Java/Android