pytorch 利用lstm做mnist手写数字识别分类的实例


Posted in Python onJanuary 10, 2020

代码如下,U我认为对于新手来说最重要的是学会rnn读取数据的格式。

# -*- coding: utf-8 -*-
"""
Created on Tue Oct 9 08:53:25 2018
@author: www
"""
 
import sys
sys.path.append('..')
 
import torch
import datetime
from torch.autograd import Variable
from torch import nn
from torch.utils.data import DataLoader
 
from torchvision import transforms as tfs
from torchvision.datasets import MNIST
 
#定义数据
data_tf = tfs.Compose([
   tfs.ToTensor(),
   tfs.Normalize([0.5], [0.5])
])
train_set = MNIST('E:/data', train=True, transform=data_tf, download=True)
test_set = MNIST('E:/data', train=False, transform=data_tf, download=True)
 
train_data = DataLoader(train_set, 64, True, num_workers=4)
test_data = DataLoader(test_set, 128, False, num_workers=4)
 
#定义模型
class rnn_classify(nn.Module):
   def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):
     super(rnn_classify, self).__init__()
     self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers)#使用两层lstm
     self.classifier = nn.Linear(hidden_feature, num_class)#将最后一个的rnn使用全连接的到最后的输出结果
     
   def forward(self, x):
     #x的大小为(batch,1,28,28),所以我们需要将其转化为rnn的输入格式(28,batch,28)
     x = x.squeeze() #去掉(batch,1,28,28)中的1,变成(batch, 28,28)
     x = x.permute(2, 0, 1)#将最后一维放到第一维,变成(batch,28,28)
     out, _ = self.rnn(x) #使用默认的隐藏状态,得到的out是(28, batch, hidden_feature)
     out = out[-1,:,:]#取序列中的最后一个,大小是(batch, hidden_feature)
     out = self.classifier(out) #得到分类结果
     return out
     
net = rnn_classify()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(net.parameters(), 1e-1)
 
#定义训练过程
def get_acc(output, label):
  total = output.shape[0]
  _, pred_label = output.max(1)
  num_correct = (pred_label == label).sum().item()
  return num_correct / total
  
  
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
  if torch.cuda.is_available():
    net = net.cuda()
  prev_time = datetime.datetime.now()
  for epoch in range(num_epochs):
    train_loss = 0
    train_acc = 0
    net = net.train()
    for im, label in train_data:
      if torch.cuda.is_available():
        im = Variable(im.cuda()) # (bs, 3, h, w)
        label = Variable(label.cuda()) # (bs, h, w)
      else:
        im = Variable(im)
        label = Variable(label)
      # forward
      output = net(im)
      loss = criterion(output, label)
      # backward
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
 
      train_loss += loss.item()
      train_acc += get_acc(output, label)
 
    cur_time = datetime.datetime.now()
    h, remainder = divmod((cur_time - prev_time).seconds, 3600)
    m, s = divmod(remainder, 60)
    time_str = "Time %02d:%02d:%02d" % (h, m, s)
    if valid_data is not None:
      valid_loss = 0
      valid_acc = 0
      net = net.eval()
      for im, label in valid_data:
        if torch.cuda.is_available():
          im = Variable(im.cuda())
          label = Variable(label.cuda())
        else:
          im = Variable(im)
          label = Variable(label)
        output = net(im)
        loss = criterion(output, label)
        valid_loss += loss.item()
        valid_acc += get_acc(output, label)
      epoch_str = (
        "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
        % (epoch, train_loss / len(train_data),
          train_acc / len(train_data), valid_loss / len(valid_data),
          valid_acc / len(valid_data)))
    else:
      epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
             (epoch, train_loss / len(train_data),
             train_acc / len(train_data)))
    prev_time = cur_time
    print(epoch_str + time_str)
    
train(net, train_data, test_data, 10, optimizer, criterion)

以上这篇pytorch 利用lstm做mnist手写数字识别分类的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python进行稳定可靠的文件操作详解
Dec 31 Python
python处理中文编码和判断编码示例
Feb 26 Python
Python 实现文件的全备份和差异备份详解
Dec 27 Python
使用python爬虫实现网络股票信息爬取的demo
Jan 05 Python
python数据结构之线性表的顺序存储结构
Sep 28 Python
朴素贝叶斯Python实例及解析
Nov 19 Python
浅谈Python中的全局锁(GIL)问题
Jan 11 Python
Python3.4学习笔记之 idle 清屏扩展插件用法分析
Mar 01 Python
Python如何把多个PDF文件合并代码实例
Feb 13 Python
Python中pass的作用与使用教程
Nov 13 Python
celery在python爬虫中定时操作实例讲解
Nov 27 Python
python 实现网易邮箱邮件阅读和删除的辅助小脚本
Mar 01 Python
Tensorflow Summary用法学习笔记
Jan 10 #Python
TENSORFLOW变量作用域(VARIABLE SCOPE)
Jan 10 #Python
python numpy数组复制使用实例解析
Jan 10 #Python
关于Pytorch的MNIST数据集的预处理详解
Jan 10 #Python
详解pycharm连接不上mysql数据库的解决办法
Jan 10 #Python
pycharm双击无响应(打不开问题解决办法)
Jan 10 #Python
python ubplot使用方法解析
Jan 10 #Python
You might like
php获取某个目录大小的代码
2008/09/10 PHP
鸡肋的PHP单例模式应用详解
2013/06/03 PHP
Codeigniter+PHPExcel实现导出数据到Excel文件
2014/06/12 PHP
Thinkphp5框架简单实现钩子(Hook)行为的方法示例
2019/09/03 PHP
关于laravel 日志写入失败问题汇总
2019/10/17 PHP
javascript笔试题目附答案@20081025_jb51.net
2008/10/26 Javascript
js删除所有的cookie的代码
2010/11/25 Javascript
IE事件对象(The Internet Explorer Event Object)
2012/06/27 Javascript
js判断客户端是iOS还是Android等移动终端的方法
2013/12/11 Javascript
jQuery获取及设置表单input各种类型值的方法小结
2016/05/24 Javascript
jQuery实现ajax的叠加和停止(终止ajax请求)
2016/08/08 Javascript
解决前端跨域问题方案汇总
2016/11/20 Javascript
JavaScript实现汉字转换为拼音的库文件示例
2016/12/22 Javascript
Angular4开发解决跨域问题详解
2017/08/28 Javascript
vue jsx 使用指南及vue.js 使用jsx语法的方法
2017/11/11 Javascript
jquery实现的简单轮播图功能【适合新手】
2018/08/17 jQuery
微信小程序日历效果
2018/12/29 Javascript
JS块级作用域和私有变量实例分析
2019/05/11 Javascript
[48:51]完美世界DOTA2联赛PWL S2 Magma vs InkIce 第一场 11.28
2020/12/02 DOTA
探究数组排序提升Python程序的循环的运行效率的原因
2015/04/01 Python
Go语言基于Socket编写服务器端与客户端通信的实例
2016/02/19 Python
浅谈python内置变量-reversed(seq)
2017/06/21 Python
python矩阵转换为一维数组的实例
2018/06/05 Python
使用NumPy和pandas对CSV文件进行写操作的实例
2018/06/14 Python
python爬虫之urllib3的使用示例
2018/07/09 Python
Python加密模块的hashlib,hmac模块使用解析
2020/01/02 Python
使用tensorflow实现VGG网络,训练mnist数据集方式
2020/05/26 Python
写求职信有哪些注意事项
2014/05/08 职场文书
廉洁自律演讲稿
2014/05/22 职场文书
求职信怎么写
2014/05/23 职场文书
标准版离职证明书
2014/09/12 职场文书
运动会宣传语
2015/07/13 职场文书
2016感恩父亲节主题广播稿
2015/12/18 职场文书
篮球拉拉队口号
2015/12/25 职场文书
Python入门之基础语法详解
2021/05/11 Python
浅谈MySQL之浅入深出页原理
2021/06/23 MySQL