pytorch 利用lstm做mnist手写数字识别分类的实例


Posted in Python onJanuary 10, 2020

代码如下,U我认为对于新手来说最重要的是学会rnn读取数据的格式。

# -*- coding: utf-8 -*-
"""
Created on Tue Oct 9 08:53:25 2018
@author: www
"""
 
import sys
sys.path.append('..')
 
import torch
import datetime
from torch.autograd import Variable
from torch import nn
from torch.utils.data import DataLoader
 
from torchvision import transforms as tfs
from torchvision.datasets import MNIST
 
#定义数据
data_tf = tfs.Compose([
   tfs.ToTensor(),
   tfs.Normalize([0.5], [0.5])
])
train_set = MNIST('E:/data', train=True, transform=data_tf, download=True)
test_set = MNIST('E:/data', train=False, transform=data_tf, download=True)
 
train_data = DataLoader(train_set, 64, True, num_workers=4)
test_data = DataLoader(test_set, 128, False, num_workers=4)
 
#定义模型
class rnn_classify(nn.Module):
   def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):
     super(rnn_classify, self).__init__()
     self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers)#使用两层lstm
     self.classifier = nn.Linear(hidden_feature, num_class)#将最后一个的rnn使用全连接的到最后的输出结果
     
   def forward(self, x):
     #x的大小为(batch,1,28,28),所以我们需要将其转化为rnn的输入格式(28,batch,28)
     x = x.squeeze() #去掉(batch,1,28,28)中的1,变成(batch, 28,28)
     x = x.permute(2, 0, 1)#将最后一维放到第一维,变成(batch,28,28)
     out, _ = self.rnn(x) #使用默认的隐藏状态,得到的out是(28, batch, hidden_feature)
     out = out[-1,:,:]#取序列中的最后一个,大小是(batch, hidden_feature)
     out = self.classifier(out) #得到分类结果
     return out
     
net = rnn_classify()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(net.parameters(), 1e-1)
 
#定义训练过程
def get_acc(output, label):
  total = output.shape[0]
  _, pred_label = output.max(1)
  num_correct = (pred_label == label).sum().item()
  return num_correct / total
  
  
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
  if torch.cuda.is_available():
    net = net.cuda()
  prev_time = datetime.datetime.now()
  for epoch in range(num_epochs):
    train_loss = 0
    train_acc = 0
    net = net.train()
    for im, label in train_data:
      if torch.cuda.is_available():
        im = Variable(im.cuda()) # (bs, 3, h, w)
        label = Variable(label.cuda()) # (bs, h, w)
      else:
        im = Variable(im)
        label = Variable(label)
      # forward
      output = net(im)
      loss = criterion(output, label)
      # backward
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
 
      train_loss += loss.item()
      train_acc += get_acc(output, label)
 
    cur_time = datetime.datetime.now()
    h, remainder = divmod((cur_time - prev_time).seconds, 3600)
    m, s = divmod(remainder, 60)
    time_str = "Time %02d:%02d:%02d" % (h, m, s)
    if valid_data is not None:
      valid_loss = 0
      valid_acc = 0
      net = net.eval()
      for im, label in valid_data:
        if torch.cuda.is_available():
          im = Variable(im.cuda())
          label = Variable(label.cuda())
        else:
          im = Variable(im)
          label = Variable(label)
        output = net(im)
        loss = criterion(output, label)
        valid_loss += loss.item()
        valid_acc += get_acc(output, label)
      epoch_str = (
        "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
        % (epoch, train_loss / len(train_data),
          train_acc / len(train_data), valid_loss / len(valid_data),
          valid_acc / len(valid_data)))
    else:
      epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
             (epoch, train_loss / len(train_data),
             train_acc / len(train_data)))
    prev_time = cur_time
    print(epoch_str + time_str)
    
train(net, train_data, test_data, 10, optimizer, criterion)

以上这篇pytorch 利用lstm做mnist手写数字识别分类的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Flask SQLAlchemy一对一,一对多的使用方法实践
Feb 10 Python
Python中的测试模块unittest和doctest的使用教程
Apr 14 Python
详解Python中contextlib上下文管理模块的用法
Jun 28 Python
python搭建虚拟环境的步骤详解
Sep 27 Python
Python 中迭代器与生成器实例详解
Mar 29 Python
Python2和Python3中print的用法示例总结
Oct 25 Python
Python实现找出数组中第2大数字的方法示例
Mar 26 Python
Python流程控制 while循环实现解析
Sep 02 Python
pytorch中的卷积和池化计算方式详解
Jan 03 Python
python如何编写win程序
Jun 08 Python
python能在浏览器能运行吗
Jun 17 Python
Python OpenCV 彩色与灰度图像的转换实现
Jun 05 Python
Tensorflow Summary用法学习笔记
Jan 10 #Python
TENSORFLOW变量作用域(VARIABLE SCOPE)
Jan 10 #Python
python numpy数组复制使用实例解析
Jan 10 #Python
关于Pytorch的MNIST数据集的预处理详解
Jan 10 #Python
详解pycharm连接不上mysql数据库的解决办法
Jan 10 #Python
pycharm双击无响应(打不开问题解决办法)
Jan 10 #Python
python ubplot使用方法解析
Jan 10 #Python
You might like
Protoss兵种对照表
2020/03/14 星际争霸
codeigniter自带数据库类使用方法说明
2014/03/25 PHP
php生成RSS订阅的方法
2015/02/13 PHP
php通过rmdir删除目录的简单用法
2015/03/18 PHP
详解PHP实现执行定时任务
2015/12/21 PHP
PHP实现在windows下配置sendmail并通过mail()函数发送邮件的方法
2017/06/20 PHP
php经典趣味算法实例代码
2020/01/21 PHP
关于火狐(firefox)及ie下event获取的两种方法
2012/12/27 Javascript
javaScript arguments 对象使用介绍
2013/10/18 Javascript
变量声明时命名与变量作为对象属性时命名的区别解析
2013/12/06 Javascript
iframe窗口高度自适应的实现方法
2014/01/08 Javascript
Javascript中数组sort和reverse用法分析
2014/12/30 Javascript
js表头排序实现方法
2015/01/16 Javascript
详解Javacript和AngularJS中的Promises
2016/02/09 Javascript
JS数组方法slice()用法实例分析
2020/01/18 Javascript
jQuery实现鼠标放置名字上显示详细内容气泡提示框效果的方法分析
2020/04/04 jQuery
jquery实现点击左右按钮切换图片
2021/01/27 jQuery
Python基础入门之seed()方法的使用
2015/05/15 Python
Python+MongoDB自增键值的简单实现
2016/11/04 Python
python中reload(module)的用法示例详解
2017/09/15 Python
Python中字符串List按照长度排序
2019/07/01 Python
Django基础知识 URL路由系统详解
2019/07/18 Python
python 实现return返回多个值
2019/11/19 Python
python+selenium定时爬取丁香园的新型冠状病毒数据并制作出类似的地图(部署到云服务器)
2020/02/09 Python
python爬虫工具例举说明
2020/11/30 Python
python des,aes,rsa加解密的实现
2021/01/16 Python
Topman美国官网:英国著名的国际平价时尚男装品牌
2017/12/22 全球购物
巴塞罗那观光通票:Barcelona Pass
2019/10/30 全球购物
团日活动总结怎么写
2014/06/25 职场文书
平面设计师岗位职责
2014/09/18 职场文书
说好普通话圆梦你我他演讲稿
2014/09/21 职场文书
财务务虚会发言材料
2014/10/20 职场文书
管辖权异议上诉状
2015/05/23 职场文书
党内外群众意见范文
2015/06/02 职场文书
CSS3 实现NES游戏机的示例代码
2021/04/21 HTML / CSS
python面向对象版学生信息管理系统
2021/06/24 Python