pytorch 利用lstm做mnist手写数字识别分类的实例


Posted in Python onJanuary 10, 2020

代码如下,U我认为对于新手来说最重要的是学会rnn读取数据的格式。

# -*- coding: utf-8 -*-
"""
Created on Tue Oct 9 08:53:25 2018
@author: www
"""
 
import sys
sys.path.append('..')
 
import torch
import datetime
from torch.autograd import Variable
from torch import nn
from torch.utils.data import DataLoader
 
from torchvision import transforms as tfs
from torchvision.datasets import MNIST
 
#定义数据
data_tf = tfs.Compose([
   tfs.ToTensor(),
   tfs.Normalize([0.5], [0.5])
])
train_set = MNIST('E:/data', train=True, transform=data_tf, download=True)
test_set = MNIST('E:/data', train=False, transform=data_tf, download=True)
 
train_data = DataLoader(train_set, 64, True, num_workers=4)
test_data = DataLoader(test_set, 128, False, num_workers=4)
 
#定义模型
class rnn_classify(nn.Module):
   def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):
     super(rnn_classify, self).__init__()
     self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers)#使用两层lstm
     self.classifier = nn.Linear(hidden_feature, num_class)#将最后一个的rnn使用全连接的到最后的输出结果
     
   def forward(self, x):
     #x的大小为(batch,1,28,28),所以我们需要将其转化为rnn的输入格式(28,batch,28)
     x = x.squeeze() #去掉(batch,1,28,28)中的1,变成(batch, 28,28)
     x = x.permute(2, 0, 1)#将最后一维放到第一维,变成(batch,28,28)
     out, _ = self.rnn(x) #使用默认的隐藏状态,得到的out是(28, batch, hidden_feature)
     out = out[-1,:,:]#取序列中的最后一个,大小是(batch, hidden_feature)
     out = self.classifier(out) #得到分类结果
     return out
     
net = rnn_classify()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(net.parameters(), 1e-1)
 
#定义训练过程
def get_acc(output, label):
  total = output.shape[0]
  _, pred_label = output.max(1)
  num_correct = (pred_label == label).sum().item()
  return num_correct / total
  
  
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
  if torch.cuda.is_available():
    net = net.cuda()
  prev_time = datetime.datetime.now()
  for epoch in range(num_epochs):
    train_loss = 0
    train_acc = 0
    net = net.train()
    for im, label in train_data:
      if torch.cuda.is_available():
        im = Variable(im.cuda()) # (bs, 3, h, w)
        label = Variable(label.cuda()) # (bs, h, w)
      else:
        im = Variable(im)
        label = Variable(label)
      # forward
      output = net(im)
      loss = criterion(output, label)
      # backward
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
 
      train_loss += loss.item()
      train_acc += get_acc(output, label)
 
    cur_time = datetime.datetime.now()
    h, remainder = divmod((cur_time - prev_time).seconds, 3600)
    m, s = divmod(remainder, 60)
    time_str = "Time %02d:%02d:%02d" % (h, m, s)
    if valid_data is not None:
      valid_loss = 0
      valid_acc = 0
      net = net.eval()
      for im, label in valid_data:
        if torch.cuda.is_available():
          im = Variable(im.cuda())
          label = Variable(label.cuda())
        else:
          im = Variable(im)
          label = Variable(label)
        output = net(im)
        loss = criterion(output, label)
        valid_loss += loss.item()
        valid_acc += get_acc(output, label)
      epoch_str = (
        "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
        % (epoch, train_loss / len(train_data),
          train_acc / len(train_data), valid_loss / len(valid_data),
          valid_acc / len(valid_data)))
    else:
      epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
             (epoch, train_loss / len(train_data),
             train_acc / len(train_data)))
    prev_time = cur_time
    print(epoch_str + time_str)
    
train(net, train_data, test_data, 10, optimizer, criterion)

以上这篇pytorch 利用lstm做mnist手写数字识别分类的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的词法分析与语法分析
May 18 Python
python动态网页批量爬取
Feb 14 Python
基于Python对象引用、可变性和垃圾回收详解
Aug 21 Python
浅析python协程相关概念
Jan 20 Python
python如何创建TCP服务端和客户端
Aug 26 Python
Django model序列化为json的方法示例
Oct 16 Python
Python中模块(Module)和包(Package)的区别详解
Aug 07 Python
Python 实现的 Google 批量翻译功能
Aug 26 Python
基于Python词云分析政府工作报告关键词
Jun 02 Python
基于Python编写一个计算器程序,实现简单的加减乘除和取余二元运算
Aug 05 Python
Python基于argparse与ConfigParser库进行入参解析与ini parser
Feb 02 Python
Python基于Opencv识别两张相似图片
Apr 25 Python
Tensorflow Summary用法学习笔记
Jan 10 #Python
TENSORFLOW变量作用域(VARIABLE SCOPE)
Jan 10 #Python
python numpy数组复制使用实例解析
Jan 10 #Python
关于Pytorch的MNIST数据集的预处理详解
Jan 10 #Python
详解pycharm连接不上mysql数据库的解决办法
Jan 10 #Python
pycharm双击无响应(打不开问题解决办法)
Jan 10 #Python
python ubplot使用方法解析
Jan 10 #Python
You might like
一个简单的PHP入门源程序
2006/10/09 PHP
PHP Cookei记录用户历史浏览信息的代码
2016/02/03 PHP
thinkPHP3.2.2框架行为扩展及demo示例
2018/06/19 PHP
TP5框架model常见操作示例小结【增删改查、聚合、时间戳、软删除等】
2020/04/05 PHP
爱恋千雪-US-AscII加密解密工具(网页加密)下载
2007/06/06 Javascript
Javascript的一种模块模式
2008/03/22 Javascript
Node.js中使用事件发射器模式实现事件绑定详解
2014/08/15 Javascript
原生js与jQuery实现简单的tab切换特效对比
2015/07/30 Javascript
js判断手机号运营商的方法
2015/10/23 Javascript
JavaScript统计网站访问次数的实现代码
2015/11/18 Javascript
js实现n秒倒计时后才可以点击的效果
2015/12/20 Javascript
Vue自定义指令介绍(2)
2016/12/08 Javascript
Js实现中国公民身份证号码有效性验证实例代码
2017/05/03 Javascript
微信小程序商品到详情的实现
2017/06/27 Javascript
angular4模块中给标签添加背景图的实现方法
2017/09/15 Javascript
js将键值对字符串转为json字符串的方法
2018/03/30 Javascript
vue iview实现动态路由和权限验证功能
2018/04/17 Javascript
使用element-ui的el-menu导航选中后刷新页面保持当前选中状态
2019/07/19 Javascript
[46:12]完美世界DOTA2联赛循环赛 DM vs Matador BO2第一场 11.04
2020/11/04 DOTA
Python实现从百度API获取天气的方法
2015/03/11 Python
在Django框架中运行Python应用全攻略
2015/07/17 Python
Python 由字符串函数名得到对应的函数(实例讲解)
2017/08/10 Python
Python 使用with上下文实现计时功能
2018/03/09 Python
python3实现指定目录下文件sha256及文件大小统计
2019/02/25 Python
Python的numpy库下的几个小函数的用法(小结)
2019/07/12 Python
使用APScheduler3.0.1 实现定时任务的方法
2019/07/22 Python
python db类用法说明
2020/07/07 Python
Python 实现PS滤镜中的径向模糊特效
2020/12/03 Python
CSS3支持IE6, 7, and 8的边框border属性
2012/12/28 HTML / CSS
HTML5通过调用canvas对象的getContext()方法来获取绘图环境
2014/06/23 HTML / CSS
eDreams德国:南欧领先的在线旅游公司
2020/12/07 全球购物
Linux如何压缩可执行文件
2013/10/21 面试题
护理学应聘自荐书范文
2014/02/05 职场文书
国际会计专业求职信
2014/08/04 职场文书
无犯罪记录证明
2014/09/19 职场文书
纯html+css实现奥运五环的示例代码
2021/08/02 HTML / CSS