pytorch 利用lstm做mnist手写数字识别分类的实例


Posted in Python onJanuary 10, 2020

代码如下,U我认为对于新手来说最重要的是学会rnn读取数据的格式。

# -*- coding: utf-8 -*-
"""
Created on Tue Oct 9 08:53:25 2018
@author: www
"""
 
import sys
sys.path.append('..')
 
import torch
import datetime
from torch.autograd import Variable
from torch import nn
from torch.utils.data import DataLoader
 
from torchvision import transforms as tfs
from torchvision.datasets import MNIST
 
#定义数据
data_tf = tfs.Compose([
   tfs.ToTensor(),
   tfs.Normalize([0.5], [0.5])
])
train_set = MNIST('E:/data', train=True, transform=data_tf, download=True)
test_set = MNIST('E:/data', train=False, transform=data_tf, download=True)
 
train_data = DataLoader(train_set, 64, True, num_workers=4)
test_data = DataLoader(test_set, 128, False, num_workers=4)
 
#定义模型
class rnn_classify(nn.Module):
   def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):
     super(rnn_classify, self).__init__()
     self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers)#使用两层lstm
     self.classifier = nn.Linear(hidden_feature, num_class)#将最后一个的rnn使用全连接的到最后的输出结果
     
   def forward(self, x):
     #x的大小为(batch,1,28,28),所以我们需要将其转化为rnn的输入格式(28,batch,28)
     x = x.squeeze() #去掉(batch,1,28,28)中的1,变成(batch, 28,28)
     x = x.permute(2, 0, 1)#将最后一维放到第一维,变成(batch,28,28)
     out, _ = self.rnn(x) #使用默认的隐藏状态,得到的out是(28, batch, hidden_feature)
     out = out[-1,:,:]#取序列中的最后一个,大小是(batch, hidden_feature)
     out = self.classifier(out) #得到分类结果
     return out
     
net = rnn_classify()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(net.parameters(), 1e-1)
 
#定义训练过程
def get_acc(output, label):
  total = output.shape[0]
  _, pred_label = output.max(1)
  num_correct = (pred_label == label).sum().item()
  return num_correct / total
  
  
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
  if torch.cuda.is_available():
    net = net.cuda()
  prev_time = datetime.datetime.now()
  for epoch in range(num_epochs):
    train_loss = 0
    train_acc = 0
    net = net.train()
    for im, label in train_data:
      if torch.cuda.is_available():
        im = Variable(im.cuda()) # (bs, 3, h, w)
        label = Variable(label.cuda()) # (bs, h, w)
      else:
        im = Variable(im)
        label = Variable(label)
      # forward
      output = net(im)
      loss = criterion(output, label)
      # backward
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
 
      train_loss += loss.item()
      train_acc += get_acc(output, label)
 
    cur_time = datetime.datetime.now()
    h, remainder = divmod((cur_time - prev_time).seconds, 3600)
    m, s = divmod(remainder, 60)
    time_str = "Time %02d:%02d:%02d" % (h, m, s)
    if valid_data is not None:
      valid_loss = 0
      valid_acc = 0
      net = net.eval()
      for im, label in valid_data:
        if torch.cuda.is_available():
          im = Variable(im.cuda())
          label = Variable(label.cuda())
        else:
          im = Variable(im)
          label = Variable(label)
        output = net(im)
        loss = criterion(output, label)
        valid_loss += loss.item()
        valid_acc += get_acc(output, label)
      epoch_str = (
        "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
        % (epoch, train_loss / len(train_data),
          train_acc / len(train_data), valid_loss / len(valid_data),
          valid_acc / len(valid_data)))
    else:
      epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
             (epoch, train_loss / len(train_data),
             train_acc / len(train_data)))
    prev_time = cur_time
    print(epoch_str + time_str)
    
train(net, train_data, test_data, 10, optimizer, criterion)

以上这篇pytorch 利用lstm做mnist手写数字识别分类的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Django框架中方法的访问和查找
Jul 15 Python
听歌识曲--用python实现一个音乐检索器的功能
Nov 15 Python
python 换位密码算法的实例详解
Jul 19 Python
Python 中 Virtualenv 和 pip 的简单用法详解
Aug 18 Python
python list删除元素时要注意的坑点分享
Apr 18 Python
基于anaconda下强大的conda命令介绍
Jun 11 Python
python 对txt中每行内容进行批量替换的方法
Jul 11 Python
python实现从pdf文件中提取文本,并自动翻译的方法
Nov 28 Python
Python3.5 Pandas模块缺失值处理和层次索引实例详解
Apr 23 Python
Python time库基本使用方法分析
Dec 13 Python
详解Flask前后端分离项目案例
Jul 24 Python
python实现ROA算子边缘检测算法
Apr 05 Python
Tensorflow Summary用法学习笔记
Jan 10 #Python
TENSORFLOW变量作用域(VARIABLE SCOPE)
Jan 10 #Python
python numpy数组复制使用实例解析
Jan 10 #Python
关于Pytorch的MNIST数据集的预处理详解
Jan 10 #Python
详解pycharm连接不上mysql数据库的解决办法
Jan 10 #Python
pycharm双击无响应(打不开问题解决办法)
Jan 10 #Python
python ubplot使用方法解析
Jan 10 #Python
You might like
供参考的 php 学习提高路线分享
2011/10/23 PHP
php给图片加文字水印
2015/07/31 PHP
PHP pthreads v3下的Volatile简介与使用方法示例
2020/02/21 PHP
firefox中用javascript实现鼠标位置的定位
2007/06/17 Javascript
IE6-IE9中tbody的innerHTML不能赋值的解决方法
2014/09/26 Javascript
javascript常用功能汇总
2015/07/05 Javascript
jquery实现叠层3D文字特效代码分享
2015/08/21 Javascript
JavaScript数组去重的两种方法推荐
2016/04/05 Javascript
js基于cookie方式记住返回页面用法示例
2016/05/27 Javascript
JQuery学习总结【二】
2016/12/01 Javascript
JS 组件系列之 bootstrap treegrid 组件封装过程
2017/04/28 Javascript
vue+mockjs模拟数据实现前后端分离开发的实例代码
2017/08/08 Javascript
vue3.0 CLI - 2.3 - 组件 home.vue 中学习指令和绑定
2018/09/14 Javascript
jquery实现Ajax请求的几种常见方式总结
2019/05/28 jQuery
JS中数据结构与算法---排序算法(Sort Algorithm)实例详解
2019/06/17 Javascript
vue.js 2.0实现简单分页效果
2019/07/29 Javascript
vue 解决文本框被键盘遮住的问题
2019/11/06 Javascript
微信小程序使用GoEasy实现websocket实时通讯
2020/05/19 Javascript
JavaScript实现手机号码 3-4-4格式并控制新增和删除时光标的位置
2020/06/02 Javascript
JS实现无限轮播无倒退效果
2020/09/21 Javascript
python通过字典dict判断指定键值是否存在的方法
2015/03/21 Python
《Python之禅》中对于Python编程过程中的一些建议
2015/04/03 Python
WINDOWS 同时安装 python2 python3 后 pip 错误的解决方法
2017/03/16 Python
Python简单生成随机姓名的方法示例
2017/12/27 Python
Python操作MySQL数据库的三种方法总结
2018/01/30 Python
通过python的matplotlib包将Tensorflow数据进行可视化的方法
2019/01/09 Python
Python使用while循环花式打印乘法表
2019/01/28 Python
Python集中化管理平台Ansible介绍与YAML简介
2019/06/12 Python
python能做哪些生活有趣的事情
2020/09/09 Python
比较一下entity bean和session bean
2013/12/27 面试题
《广玉兰》教学反思
2014/04/14 职场文书
网吧消防安全责任书
2014/07/29 职场文书
井冈山红色之旅心得体会
2014/10/07 职场文书
村党的群众路线教育实践活动工作总结
2014/10/25 职场文书
工程催款通知书
2015/04/17 职场文书
法律进社区活动总结
2015/05/07 职场文书