pytorch 利用lstm做mnist手写数字识别分类的实例


Posted in Python onJanuary 10, 2020

代码如下,U我认为对于新手来说最重要的是学会rnn读取数据的格式。

# -*- coding: utf-8 -*-
"""
Created on Tue Oct 9 08:53:25 2018
@author: www
"""
 
import sys
sys.path.append('..')
 
import torch
import datetime
from torch.autograd import Variable
from torch import nn
from torch.utils.data import DataLoader
 
from torchvision import transforms as tfs
from torchvision.datasets import MNIST
 
#定义数据
data_tf = tfs.Compose([
   tfs.ToTensor(),
   tfs.Normalize([0.5], [0.5])
])
train_set = MNIST('E:/data', train=True, transform=data_tf, download=True)
test_set = MNIST('E:/data', train=False, transform=data_tf, download=True)
 
train_data = DataLoader(train_set, 64, True, num_workers=4)
test_data = DataLoader(test_set, 128, False, num_workers=4)
 
#定义模型
class rnn_classify(nn.Module):
   def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):
     super(rnn_classify, self).__init__()
     self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers)#使用两层lstm
     self.classifier = nn.Linear(hidden_feature, num_class)#将最后一个的rnn使用全连接的到最后的输出结果
     
   def forward(self, x):
     #x的大小为(batch,1,28,28),所以我们需要将其转化为rnn的输入格式(28,batch,28)
     x = x.squeeze() #去掉(batch,1,28,28)中的1,变成(batch, 28,28)
     x = x.permute(2, 0, 1)#将最后一维放到第一维,变成(batch,28,28)
     out, _ = self.rnn(x) #使用默认的隐藏状态,得到的out是(28, batch, hidden_feature)
     out = out[-1,:,:]#取序列中的最后一个,大小是(batch, hidden_feature)
     out = self.classifier(out) #得到分类结果
     return out
     
net = rnn_classify()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(net.parameters(), 1e-1)
 
#定义训练过程
def get_acc(output, label):
  total = output.shape[0]
  _, pred_label = output.max(1)
  num_correct = (pred_label == label).sum().item()
  return num_correct / total
  
  
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
  if torch.cuda.is_available():
    net = net.cuda()
  prev_time = datetime.datetime.now()
  for epoch in range(num_epochs):
    train_loss = 0
    train_acc = 0
    net = net.train()
    for im, label in train_data:
      if torch.cuda.is_available():
        im = Variable(im.cuda()) # (bs, 3, h, w)
        label = Variable(label.cuda()) # (bs, h, w)
      else:
        im = Variable(im)
        label = Variable(label)
      # forward
      output = net(im)
      loss = criterion(output, label)
      # backward
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
 
      train_loss += loss.item()
      train_acc += get_acc(output, label)
 
    cur_time = datetime.datetime.now()
    h, remainder = divmod((cur_time - prev_time).seconds, 3600)
    m, s = divmod(remainder, 60)
    time_str = "Time %02d:%02d:%02d" % (h, m, s)
    if valid_data is not None:
      valid_loss = 0
      valid_acc = 0
      net = net.eval()
      for im, label in valid_data:
        if torch.cuda.is_available():
          im = Variable(im.cuda())
          label = Variable(label.cuda())
        else:
          im = Variable(im)
          label = Variable(label)
        output = net(im)
        loss = criterion(output, label)
        valid_loss += loss.item()
        valid_acc += get_acc(output, label)
      epoch_str = (
        "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
        % (epoch, train_loss / len(train_data),
          train_acc / len(train_data), valid_loss / len(valid_data),
          valid_acc / len(valid_data)))
    else:
      epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
             (epoch, train_loss / len(train_data),
             train_acc / len(train_data)))
    prev_time = cur_time
    print(epoch_str + time_str)
    
train(net, train_data, test_data, 10, optimizer, criterion)

以上这篇pytorch 利用lstm做mnist手写数字识别分类的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python获取当前时间的方法
Jan 14 Python
在Python中使用PIL模块对图片进行高斯模糊处理的教程
May 05 Python
python实现烟花小程序
Jan 30 Python
对Python定时任务的启动和停止方法详解
Feb 19 Python
pyqt5对用qt designer设计的窗体实现弹出子窗口的示例
Jun 19 Python
在flask中使用python-dotenv+flask-cli自定义命令(推荐)
Jan 05 Python
ubuntu 安装pyqt5和卸载pyQt5的方法
Mar 24 Python
使用Django xadmin 实现修改时间选择器为不可输入状态
Mar 30 Python
python正则表达式的懒惰匹配和贪婪匹配说明
Jul 13 Python
Python configparser模块应用过程解析
Aug 14 Python
深入了解Python 方法之类方法 & 静态方法
Aug 17 Python
python数字图像处理之图像自动阈值分割示例
Jun 28 Python
Tensorflow Summary用法学习笔记
Jan 10 #Python
TENSORFLOW变量作用域(VARIABLE SCOPE)
Jan 10 #Python
python numpy数组复制使用实例解析
Jan 10 #Python
关于Pytorch的MNIST数据集的预处理详解
Jan 10 #Python
详解pycharm连接不上mysql数据库的解决办法
Jan 10 #Python
pycharm双击无响应(打不开问题解决办法)
Jan 10 #Python
python ubplot使用方法解析
Jan 10 #Python
You might like
php 删除一个数组中的某个值.兼容多维数组!
2012/02/18 PHP
php中flush()、ob_flush()、ob_end_flush()的区别介绍
2013/02/17 PHP
php常用Stream函数集介绍
2013/06/24 PHP
PHP中“=>
2019/03/01 PHP
用htc组件制作windows选项卡
2007/01/13 Javascript
jQuery操作Select的Option上下移动及移除添加等等
2013/11/18 Javascript
动态加载jquery库的方法
2014/02/12 Javascript
Three.js学习之网格
2016/08/10 Javascript
原生JS:Date对象全面解析
2016/09/06 Javascript
jQuery is not defined 错误原因与解决方法小结
2017/03/19 Javascript
jQuery实现的手风琴侧边菜单效果
2017/03/29 jQuery
JavaScript数据结构之二叉树的查找算法示例
2017/04/13 Javascript
vue.js实现条件渲染的实例代码
2017/06/22 Javascript
详解10分钟学会vue滚动行为
2017/09/21 Javascript
微信小程序自定义tabBar组件开发详解
2020/09/24 Javascript
vue.js多页面开发环境搭建过程
2019/04/24 Javascript
Python中几种操作字符串的方法的介绍
2015/04/09 Python
PyTorch学习:动态图和静态图的例子
2020/01/06 Python
Python selenium抓取虎牙短视频代码实例
2020/03/02 Python
Python生成六万个随机,唯一的8位数字和数字组成的随机字符串实例
2020/03/03 Python
基于python代码批量处理图片resize
2020/06/04 Python
Python Merge函数原理及用法解析
2020/09/16 Python
python 发送get请求接口详解
2020/11/17 Python
美国最大的旗帜经销商:Carrot-Top
2018/02/26 全球购物
Kaufmann Mercantile官网:家居装饰、配件、户外及更多
2018/09/28 全球购物
如何处理简单的PHP错误
2015/10/14 面试题
介绍一下grep命令的使用
2015/06/12 面试题
sort命令的作用和用法
2012/11/04 面试题
如何在Shell脚本中使用函数
2015/09/06 面试题
自考生自我评价分享
2014/01/18 职场文书
银行反四风对照检查材料
2014/09/29 职场文书
万里长城导游词
2015/01/30 职场文书
自主招生专家推荐信
2015/03/26 职场文书
毕业论文致谢范文
2015/05/14 职场文书
信息技术课教学反思
2016/02/23 职场文书
情况说明书格式及范文
2019/06/24 职场文书