pytorch 利用lstm做mnist手写数字识别分类的实例


Posted in Python onJanuary 10, 2020

代码如下,U我认为对于新手来说最重要的是学会rnn读取数据的格式。

# -*- coding: utf-8 -*-
"""
Created on Tue Oct 9 08:53:25 2018
@author: www
"""
 
import sys
sys.path.append('..')
 
import torch
import datetime
from torch.autograd import Variable
from torch import nn
from torch.utils.data import DataLoader
 
from torchvision import transforms as tfs
from torchvision.datasets import MNIST
 
#定义数据
data_tf = tfs.Compose([
   tfs.ToTensor(),
   tfs.Normalize([0.5], [0.5])
])
train_set = MNIST('E:/data', train=True, transform=data_tf, download=True)
test_set = MNIST('E:/data', train=False, transform=data_tf, download=True)
 
train_data = DataLoader(train_set, 64, True, num_workers=4)
test_data = DataLoader(test_set, 128, False, num_workers=4)
 
#定义模型
class rnn_classify(nn.Module):
   def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):
     super(rnn_classify, self).__init__()
     self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers)#使用两层lstm
     self.classifier = nn.Linear(hidden_feature, num_class)#将最后一个的rnn使用全连接的到最后的输出结果
     
   def forward(self, x):
     #x的大小为(batch,1,28,28),所以我们需要将其转化为rnn的输入格式(28,batch,28)
     x = x.squeeze() #去掉(batch,1,28,28)中的1,变成(batch, 28,28)
     x = x.permute(2, 0, 1)#将最后一维放到第一维,变成(batch,28,28)
     out, _ = self.rnn(x) #使用默认的隐藏状态,得到的out是(28, batch, hidden_feature)
     out = out[-1,:,:]#取序列中的最后一个,大小是(batch, hidden_feature)
     out = self.classifier(out) #得到分类结果
     return out
     
net = rnn_classify()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(net.parameters(), 1e-1)
 
#定义训练过程
def get_acc(output, label):
  total = output.shape[0]
  _, pred_label = output.max(1)
  num_correct = (pred_label == label).sum().item()
  return num_correct / total
  
  
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
  if torch.cuda.is_available():
    net = net.cuda()
  prev_time = datetime.datetime.now()
  for epoch in range(num_epochs):
    train_loss = 0
    train_acc = 0
    net = net.train()
    for im, label in train_data:
      if torch.cuda.is_available():
        im = Variable(im.cuda()) # (bs, 3, h, w)
        label = Variable(label.cuda()) # (bs, h, w)
      else:
        im = Variable(im)
        label = Variable(label)
      # forward
      output = net(im)
      loss = criterion(output, label)
      # backward
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
 
      train_loss += loss.item()
      train_acc += get_acc(output, label)
 
    cur_time = datetime.datetime.now()
    h, remainder = divmod((cur_time - prev_time).seconds, 3600)
    m, s = divmod(remainder, 60)
    time_str = "Time %02d:%02d:%02d" % (h, m, s)
    if valid_data is not None:
      valid_loss = 0
      valid_acc = 0
      net = net.eval()
      for im, label in valid_data:
        if torch.cuda.is_available():
          im = Variable(im.cuda())
          label = Variable(label.cuda())
        else:
          im = Variable(im)
          label = Variable(label)
        output = net(im)
        loss = criterion(output, label)
        valid_loss += loss.item()
        valid_acc += get_acc(output, label)
      epoch_str = (
        "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
        % (epoch, train_loss / len(train_data),
          train_acc / len(train_data), valid_loss / len(valid_data),
          valid_acc / len(valid_data)))
    else:
      epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
             (epoch, train_loss / len(train_data),
             train_acc / len(train_data)))
    prev_time = cur_time
    print(epoch_str + time_str)
    
train(net, train_data, test_data, 10, optimizer, criterion)

以上这篇pytorch 利用lstm做mnist手写数字识别分类的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python ORM框架SQLAlchemy学习笔记之映射类使用实例和Session会话介绍
Jun 10 Python
python批量实现Word文件转换为PDF文件
Mar 15 Python
python3使用SMTP发送简单文本邮件
Jun 19 Python
python+pandas+时间、日期以及时间序列处理方法
Jul 10 Python
Python中PyQt5/PySide2的按钮控件使用实例
Aug 17 Python
Java文件与类动手动脑实例详解
Nov 10 Python
Django3.0 异步通信初体验(小结)
Dec 04 Python
Python简单实现区域生长方式
Jan 16 Python
浅谈keras中Dropout在预测过程中是否仍要起作用
Jul 09 Python
Django返回HTML文件的实现方法
Sep 17 Python
Python图像处理库PIL详细使用说明
Apr 06 Python
Django框架之路由用法
Jun 10 Python
Tensorflow Summary用法学习笔记
Jan 10 #Python
TENSORFLOW变量作用域(VARIABLE SCOPE)
Jan 10 #Python
python numpy数组复制使用实例解析
Jan 10 #Python
关于Pytorch的MNIST数据集的预处理详解
Jan 10 #Python
详解pycharm连接不上mysql数据库的解决办法
Jan 10 #Python
pycharm双击无响应(打不开问题解决办法)
Jan 10 #Python
python ubplot使用方法解析
Jan 10 #Python
You might like
php中根据某年第几天计算出日期年月日的代码
2011/02/24 PHP
浅谈PHP的反射API
2017/02/26 PHP
PHP中的empty、isset、isnull的区别与使用实例
2019/03/22 PHP
jQuery 可以拖动的div实现代码 脚本之家修正版
2009/06/26 Javascript
JS 时间显示效果代码
2009/08/23 Javascript
javascript中检测变量的类型的代码
2010/12/28 Javascript
firefox下jQuery UI Autocomplete 1.8.*中文输入修正方法
2012/09/19 Javascript
javascript动画对象支持加速、减速、缓入、缓出的实现代码
2012/09/30 Javascript
js获取单选框或复选框值及操作
2012/12/18 Javascript
AngularJS基础 ng-value 指令简单示例
2016/08/03 Javascript
Node.js配合node-http-proxy解决本地开发ajax跨域问题
2016/08/31 Javascript
jQuery双向列表选择器DIV模拟版
2016/11/01 Javascript
JavaScript中数据类型转换总结
2016/12/25 Javascript
vue-router:嵌套路由的使用方法
2017/02/21 Javascript
Angular2平滑升级到Angular4的步骤详解
2017/03/29 Javascript
从零开始学习Node.js系列教程四:多页面实现数学运算的client端和server端示例
2017/04/13 Javascript
jQuery遍历节点方法汇总(推荐)
2017/05/13 jQuery
JavaScript中常见的八个陷阱总结
2017/06/28 Javascript
vue实现树形菜单效果
2018/03/19 Javascript
微信小程序首页的分类功能和搜索功能的实现思路及代码详解
2018/09/11 Javascript
基于vue+axios+lrz.js微信端图片压缩上传方法
2019/06/25 Javascript
简单了解常用的JavaScript 库
2020/07/16 Javascript
python爬虫之自动登录与验证码识别
2020/06/15 Python
Python定时发送消息的脚本:每天跟你女朋友说晚安
2018/10/21 Python
Python运行不显示DOS窗口的解决方法
2018/10/22 Python
使用Python快速制作可视化报表的方法
2019/02/03 Python
python3实现指定目录下文件sha256及文件大小统计
2019/02/25 Python
Python环境管理virtualenv&virtualenvwrapper的配置详解
2020/07/01 Python
python3 中时间戳、时间、日期的转换和加减操作
2020/07/14 Python
利用CSS3的checked伪类实现OL的隐藏显示的方法
2010/12/18 HTML / CSS
Coggles美国/加拿大:高级国际时装零售商
2018/10/23 全球购物
在家更换处方镜片:Lensabl
2019/05/01 全球购物
超市促销活动总结
2014/07/01 职场文书
教师党员学习群众路线心得体会
2014/11/04 职场文书
2015年政务公开工作总结
2015/05/19 职场文书
《中国机长》观后感:敬畏生命,敬畏职责
2019/11/12 职场文书