pytorch 利用lstm做mnist手写数字识别分类的实例


Posted in Python onJanuary 10, 2020

代码如下,U我认为对于新手来说最重要的是学会rnn读取数据的格式。

# -*- coding: utf-8 -*-
"""
Created on Tue Oct 9 08:53:25 2018
@author: www
"""
 
import sys
sys.path.append('..')
 
import torch
import datetime
from torch.autograd import Variable
from torch import nn
from torch.utils.data import DataLoader
 
from torchvision import transforms as tfs
from torchvision.datasets import MNIST
 
#定义数据
data_tf = tfs.Compose([
   tfs.ToTensor(),
   tfs.Normalize([0.5], [0.5])
])
train_set = MNIST('E:/data', train=True, transform=data_tf, download=True)
test_set = MNIST('E:/data', train=False, transform=data_tf, download=True)
 
train_data = DataLoader(train_set, 64, True, num_workers=4)
test_data = DataLoader(test_set, 128, False, num_workers=4)
 
#定义模型
class rnn_classify(nn.Module):
   def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):
     super(rnn_classify, self).__init__()
     self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers)#使用两层lstm
     self.classifier = nn.Linear(hidden_feature, num_class)#将最后一个的rnn使用全连接的到最后的输出结果
     
   def forward(self, x):
     #x的大小为(batch,1,28,28),所以我们需要将其转化为rnn的输入格式(28,batch,28)
     x = x.squeeze() #去掉(batch,1,28,28)中的1,变成(batch, 28,28)
     x = x.permute(2, 0, 1)#将最后一维放到第一维,变成(batch,28,28)
     out, _ = self.rnn(x) #使用默认的隐藏状态,得到的out是(28, batch, hidden_feature)
     out = out[-1,:,:]#取序列中的最后一个,大小是(batch, hidden_feature)
     out = self.classifier(out) #得到分类结果
     return out
     
net = rnn_classify()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(net.parameters(), 1e-1)
 
#定义训练过程
def get_acc(output, label):
  total = output.shape[0]
  _, pred_label = output.max(1)
  num_correct = (pred_label == label).sum().item()
  return num_correct / total
  
  
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
  if torch.cuda.is_available():
    net = net.cuda()
  prev_time = datetime.datetime.now()
  for epoch in range(num_epochs):
    train_loss = 0
    train_acc = 0
    net = net.train()
    for im, label in train_data:
      if torch.cuda.is_available():
        im = Variable(im.cuda()) # (bs, 3, h, w)
        label = Variable(label.cuda()) # (bs, h, w)
      else:
        im = Variable(im)
        label = Variable(label)
      # forward
      output = net(im)
      loss = criterion(output, label)
      # backward
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
 
      train_loss += loss.item()
      train_acc += get_acc(output, label)
 
    cur_time = datetime.datetime.now()
    h, remainder = divmod((cur_time - prev_time).seconds, 3600)
    m, s = divmod(remainder, 60)
    time_str = "Time %02d:%02d:%02d" % (h, m, s)
    if valid_data is not None:
      valid_loss = 0
      valid_acc = 0
      net = net.eval()
      for im, label in valid_data:
        if torch.cuda.is_available():
          im = Variable(im.cuda())
          label = Variable(label.cuda())
        else:
          im = Variable(im)
          label = Variable(label)
        output = net(im)
        loss = criterion(output, label)
        valid_loss += loss.item()
        valid_acc += get_acc(output, label)
      epoch_str = (
        "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
        % (epoch, train_loss / len(train_data),
          train_acc / len(train_data), valid_loss / len(valid_data),
          valid_acc / len(valid_data)))
    else:
      epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
             (epoch, train_loss / len(train_data),
             train_acc / len(train_data)))
    prev_time = cur_time
    print(epoch_str + time_str)
    
train(net, train_data, test_data, 10, optimizer, criterion)

以上这篇pytorch 利用lstm做mnist手写数字识别分类的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
centos下更新Python版本的步骤
Feb 12 Python
Python基于PycURL自动处理cookie的方法
Jul 25 Python
详解Python 数据库 (sqlite3)应用
Dec 07 Python
python操作mysql数据库
Mar 05 Python
pyqt 实现QlineEdit 输入密码显示成圆点的方法
Jun 24 Python
python内置模块collections知识点总结
Dec 19 Python
Python的PIL库中getpixel方法的使用
Apr 09 Python
Python字典fromkeys()方法使用代码实例
Jul 20 Python
Python3实现英文字母转换哥特式字体实例代码
Sep 01 Python
属性与 @property 方法让你的python更高效
Sep 21 Python
使用python画出逻辑斯蒂映射(logistic map)中的分叉图案例
Dec 11 Python
python可视化之颜色映射详解
Sep 15 Python
Tensorflow Summary用法学习笔记
Jan 10 #Python
TENSORFLOW变量作用域(VARIABLE SCOPE)
Jan 10 #Python
python numpy数组复制使用实例解析
Jan 10 #Python
关于Pytorch的MNIST数据集的预处理详解
Jan 10 #Python
详解pycharm连接不上mysql数据库的解决办法
Jan 10 #Python
pycharm双击无响应(打不开问题解决办法)
Jan 10 #Python
python ubplot使用方法解析
Jan 10 #Python
You might like
咖啡风味 世界咖啡主要分布分布 咖啡的生长要求
2021/03/06 新手入门
PHP中获取文件扩展名的N种方法小结
2012/02/27 PHP
关于尾递归的使用详解
2013/05/02 PHP
php获取文件内容最后一行示例
2014/01/09 PHP
php socket客户端及服务器端应用实例
2014/07/04 PHP
IE8 引入跨站数据获取功能说明
2008/07/22 Javascript
IE DOM实现存在的部分问题及解决方法
2009/07/25 Javascript
javascript json 新手入门文档
2009/12/03 Javascript
window.location不跳转的问题解决方法
2014/04/17 Javascript
深入浅析JavaScript中prototype和proto的关系
2015/11/15 Javascript
JS实现区分中英文并统计字符个数的方法示例
2018/06/09 Javascript
vue.js项目 el-input 组件 监听回车键实现搜索功能示例
2018/08/25 Javascript
JavaScript错误处理操作实例详解
2019/01/04 Javascript
详解50行代码,Node爬虫练手项目
2019/04/22 Javascript
Vue $mount实战之实现消息弹窗组件
2019/04/22 Javascript
JavaScript实现图片轮播特效
2019/10/23 Javascript
JS 数组基本用法入门示例解析
2020/01/16 Javascript
vuex页面刷新导致数据丢失的解决方案
2020/12/10 Vue.js
python使用7z解压apk包的方法
2015/04/18 Python
详解使用python的logging模块在stdout输出的两种方法
2017/05/17 Python
opencv改变imshow窗口大小,窗口位置的方法
2018/04/02 Python
pandas DataFrame 交集并集补集的实现
2019/06/24 Python
python网络编程之多线程同时接受和发送
2019/09/03 Python
django实现支付宝支付实例讲解
2019/10/17 Python
python爬虫爬取幽默笑话网站
2019/10/24 Python
python3获取文件中url内容并下载代码实例
2019/12/27 Python
django rest framework使用django-filter用法
2020/07/15 Python
css3实现超立体3D图片侧翻倾斜效果
2014/04/16 HTML / CSS
阿里旅行:飞猪
2017/01/05 全球购物
世界领先的豪华床上用品供应商之一:Bedeck Home
2019/03/18 全球购物
俄罗斯首家面向中国消费者的一站式购物网站:Wruru
2020/05/08 全球购物
学前教育学生自荐信范文
2013/12/31 职场文书
高中教师考核方案
2014/05/18 职场文书
社团个人总结范文
2015/03/05 职场文书
pycharm无法导入lxml的解决办法
2021/03/31 Python
Centos系统通过Docker安装并搭建MongoDB数据库
2022/04/12 MongoDB