pytorch 利用lstm做mnist手写数字识别分类的实例


Posted in Python onJanuary 10, 2020

代码如下,U我认为对于新手来说最重要的是学会rnn读取数据的格式。

# -*- coding: utf-8 -*-
"""
Created on Tue Oct 9 08:53:25 2018
@author: www
"""
 
import sys
sys.path.append('..')
 
import torch
import datetime
from torch.autograd import Variable
from torch import nn
from torch.utils.data import DataLoader
 
from torchvision import transforms as tfs
from torchvision.datasets import MNIST
 
#定义数据
data_tf = tfs.Compose([
   tfs.ToTensor(),
   tfs.Normalize([0.5], [0.5])
])
train_set = MNIST('E:/data', train=True, transform=data_tf, download=True)
test_set = MNIST('E:/data', train=False, transform=data_tf, download=True)
 
train_data = DataLoader(train_set, 64, True, num_workers=4)
test_data = DataLoader(test_set, 128, False, num_workers=4)
 
#定义模型
class rnn_classify(nn.Module):
   def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):
     super(rnn_classify, self).__init__()
     self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers)#使用两层lstm
     self.classifier = nn.Linear(hidden_feature, num_class)#将最后一个的rnn使用全连接的到最后的输出结果
     
   def forward(self, x):
     #x的大小为(batch,1,28,28),所以我们需要将其转化为rnn的输入格式(28,batch,28)
     x = x.squeeze() #去掉(batch,1,28,28)中的1,变成(batch, 28,28)
     x = x.permute(2, 0, 1)#将最后一维放到第一维,变成(batch,28,28)
     out, _ = self.rnn(x) #使用默认的隐藏状态,得到的out是(28, batch, hidden_feature)
     out = out[-1,:,:]#取序列中的最后一个,大小是(batch, hidden_feature)
     out = self.classifier(out) #得到分类结果
     return out
     
net = rnn_classify()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(net.parameters(), 1e-1)
 
#定义训练过程
def get_acc(output, label):
  total = output.shape[0]
  _, pred_label = output.max(1)
  num_correct = (pred_label == label).sum().item()
  return num_correct / total
  
  
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
  if torch.cuda.is_available():
    net = net.cuda()
  prev_time = datetime.datetime.now()
  for epoch in range(num_epochs):
    train_loss = 0
    train_acc = 0
    net = net.train()
    for im, label in train_data:
      if torch.cuda.is_available():
        im = Variable(im.cuda()) # (bs, 3, h, w)
        label = Variable(label.cuda()) # (bs, h, w)
      else:
        im = Variable(im)
        label = Variable(label)
      # forward
      output = net(im)
      loss = criterion(output, label)
      # backward
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
 
      train_loss += loss.item()
      train_acc += get_acc(output, label)
 
    cur_time = datetime.datetime.now()
    h, remainder = divmod((cur_time - prev_time).seconds, 3600)
    m, s = divmod(remainder, 60)
    time_str = "Time %02d:%02d:%02d" % (h, m, s)
    if valid_data is not None:
      valid_loss = 0
      valid_acc = 0
      net = net.eval()
      for im, label in valid_data:
        if torch.cuda.is_available():
          im = Variable(im.cuda())
          label = Variable(label.cuda())
        else:
          im = Variable(im)
          label = Variable(label)
        output = net(im)
        loss = criterion(output, label)
        valid_loss += loss.item()
        valid_acc += get_acc(output, label)
      epoch_str = (
        "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
        % (epoch, train_loss / len(train_data),
          train_acc / len(train_data), valid_loss / len(valid_data),
          valid_acc / len(valid_data)))
    else:
      epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
             (epoch, train_loss / len(train_data),
             train_acc / len(train_data)))
    prev_time = cur_time
    print(epoch_str + time_str)
    
train(net, train_data, test_data, 10, optimizer, criterion)

以上这篇pytorch 利用lstm做mnist手写数字识别分类的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用PyGame绘制图像并保存为图片文件的方法
Apr 24 Python
python简单实现基于SSL的IRC bot实例
Jun 15 Python
Linux下多个Python版本安装教程
Aug 15 Python
对python中dict和json的区别详解
Dec 18 Python
python利用selenium进行浏览器爬虫
Apr 25 Python
Python学习笔记之Django创建第一个数据库模型的方法
Aug 07 Python
python KNN算法实现鸢尾花数据集分类
Oct 24 Python
Python中的全局变量如何理解
Jun 04 Python
python和php学习哪个更有发展
Jun 17 Python
详解python 支持向量机(SVM)算法
Sep 18 Python
Python os库常用操作代码汇总
Nov 03 Python
Python 流媒体播放器的实现(基于VLC)
Apr 28 Python
Tensorflow Summary用法学习笔记
Jan 10 #Python
TENSORFLOW变量作用域(VARIABLE SCOPE)
Jan 10 #Python
python numpy数组复制使用实例解析
Jan 10 #Python
关于Pytorch的MNIST数据集的预处理详解
Jan 10 #Python
详解pycharm连接不上mysql数据库的解决办法
Jan 10 #Python
pycharm双击无响应(打不开问题解决办法)
Jan 10 #Python
python ubplot使用方法解析
Jan 10 #Python
You might like
PHP实现从远程下载文件的方法
2015/03/12 PHP
PHP调用API接口实现天气查询功能的示例
2017/09/21 PHP
ThinkPHP框架获取最后一次执行SQL语句及变量调试简单操作示例
2018/06/13 PHP
使用Grunt.js管理你项目的应用说明
2013/04/24 Javascript
document.getElementBy("id")与$("#id")有什么区别
2013/09/22 Javascript
jQuery实现左右切换焦点图
2015/04/03 Javascript
DOM 高级编程
2015/05/06 Javascript
js实现仿Discuz文本框弹出层效果
2015/08/13 Javascript
基于insertBefore制作简单的循环插空效果
2015/09/21 Javascript
JS/jquery实现一个网页内同时调用多个倒计时的方法
2017/04/27 jQuery
解决低版本的浏览器不支持es6的import问题
2018/03/09 Javascript
js表达式与运算符简单操作示例
2020/02/15 Javascript
JQuery使用数组遍历跳出each循环
2020/09/01 jQuery
[02:12]2019完美世界全国高校联赛(春季赛)报名开启
2019/03/01 DOTA
用Python中的wxPython实现最基本的浏览器功能
2015/04/14 Python
Python处理XML格式数据的方法详解
2017/03/21 Python
python+pyqt实现12306图片验证效果
2017/10/25 Python
python把数组中的数字每行打印3个并保存在文档中的方法
2018/07/17 Python
python+splinter实现12306网站刷票并自动购票流程
2018/09/25 Python
详解Python下Flask-ApScheduler快速指南
2018/11/04 Python
Python实现EXCEL表格的排序功能示例
2019/06/25 Python
pyqt5 使用cv2 显示图片,摄像头的实例
2019/06/27 Python
Django实现从数据库中获取到的数据转换为dict
2020/03/27 Python
python利用os模块编写文件复制功能——copy()函数用法
2020/07/13 Python
CSS3支持IE6, 7, and 8的边框border属性
2012/12/28 HTML / CSS
canvas像素点操作之视频绿幕抠图
2018/09/11 HTML / CSS
美国在线艺术商店:HandmadePiece
2020/11/06 全球购物
TUMI香港官网:国际领先的行李箱、背囊品牌
2021/03/01 全球购物
几个常见的软件测试问题
2016/09/07 面试题
自我评价正确写法范文
2013/12/10 职场文书
副校长竞聘演讲稿
2014/09/01 职场文书
学校领导班子对照检查材料
2014/09/24 职场文书
党员转正意见怎么写
2015/06/03 职场文书
机械原理课程设计心得体会
2016/01/15 职场文书
MySQL约束超详解
2021/09/04 MySQL
前端使用svg图片改色实现示例
2022/07/23 HTML / CSS