pytorch 利用lstm做mnist手写数字识别分类的实例


Posted in Python onJanuary 10, 2020

代码如下,U我认为对于新手来说最重要的是学会rnn读取数据的格式。

# -*- coding: utf-8 -*-
"""
Created on Tue Oct 9 08:53:25 2018
@author: www
"""
 
import sys
sys.path.append('..')
 
import torch
import datetime
from torch.autograd import Variable
from torch import nn
from torch.utils.data import DataLoader
 
from torchvision import transforms as tfs
from torchvision.datasets import MNIST
 
#定义数据
data_tf = tfs.Compose([
   tfs.ToTensor(),
   tfs.Normalize([0.5], [0.5])
])
train_set = MNIST('E:/data', train=True, transform=data_tf, download=True)
test_set = MNIST('E:/data', train=False, transform=data_tf, download=True)
 
train_data = DataLoader(train_set, 64, True, num_workers=4)
test_data = DataLoader(test_set, 128, False, num_workers=4)
 
#定义模型
class rnn_classify(nn.Module):
   def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):
     super(rnn_classify, self).__init__()
     self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers)#使用两层lstm
     self.classifier = nn.Linear(hidden_feature, num_class)#将最后一个的rnn使用全连接的到最后的输出结果
     
   def forward(self, x):
     #x的大小为(batch,1,28,28),所以我们需要将其转化为rnn的输入格式(28,batch,28)
     x = x.squeeze() #去掉(batch,1,28,28)中的1,变成(batch, 28,28)
     x = x.permute(2, 0, 1)#将最后一维放到第一维,变成(batch,28,28)
     out, _ = self.rnn(x) #使用默认的隐藏状态,得到的out是(28, batch, hidden_feature)
     out = out[-1,:,:]#取序列中的最后一个,大小是(batch, hidden_feature)
     out = self.classifier(out) #得到分类结果
     return out
     
net = rnn_classify()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(net.parameters(), 1e-1)
 
#定义训练过程
def get_acc(output, label):
  total = output.shape[0]
  _, pred_label = output.max(1)
  num_correct = (pred_label == label).sum().item()
  return num_correct / total
  
  
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
  if torch.cuda.is_available():
    net = net.cuda()
  prev_time = datetime.datetime.now()
  for epoch in range(num_epochs):
    train_loss = 0
    train_acc = 0
    net = net.train()
    for im, label in train_data:
      if torch.cuda.is_available():
        im = Variable(im.cuda()) # (bs, 3, h, w)
        label = Variable(label.cuda()) # (bs, h, w)
      else:
        im = Variable(im)
        label = Variable(label)
      # forward
      output = net(im)
      loss = criterion(output, label)
      # backward
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
 
      train_loss += loss.item()
      train_acc += get_acc(output, label)
 
    cur_time = datetime.datetime.now()
    h, remainder = divmod((cur_time - prev_time).seconds, 3600)
    m, s = divmod(remainder, 60)
    time_str = "Time %02d:%02d:%02d" % (h, m, s)
    if valid_data is not None:
      valid_loss = 0
      valid_acc = 0
      net = net.eval()
      for im, label in valid_data:
        if torch.cuda.is_available():
          im = Variable(im.cuda())
          label = Variable(label.cuda())
        else:
          im = Variable(im)
          label = Variable(label)
        output = net(im)
        loss = criterion(output, label)
        valid_loss += loss.item()
        valid_acc += get_acc(output, label)
      epoch_str = (
        "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
        % (epoch, train_loss / len(train_data),
          train_acc / len(train_data), valid_loss / len(valid_data),
          valid_acc / len(valid_data)))
    else:
      epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
             (epoch, train_loss / len(train_data),
             train_acc / len(train_data)))
    prev_time = cur_time
    print(epoch_str + time_str)
    
train(net, train_data, test_data, 10, optimizer, criterion)

以上这篇pytorch 利用lstm做mnist手写数字识别分类的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 随机数生成的代码的详细分析
May 15 Python
Python中多线程的创建及基本调用方法
Jul 08 Python
python连接数据库的方法
Oct 19 Python
python Flask实现restful api service
Dec 04 Python
快速了解Python开发中的cookie及简单代码示例
Jan 17 Python
一文了解Python并发编程的工程实现方法
May 31 Python
python字典的setdefault的巧妙用法
Aug 07 Python
PYTHON发送邮件YAGMAIL的简单实现解析
Oct 28 Python
Python单链表原理与实现方法详解
Feb 22 Python
python如何调用java类
Jul 05 Python
学习Python需要哪些工具
Sep 04 Python
Python Django模型详解
Oct 05 Python
Tensorflow Summary用法学习笔记
Jan 10 #Python
TENSORFLOW变量作用域(VARIABLE SCOPE)
Jan 10 #Python
python numpy数组复制使用实例解析
Jan 10 #Python
关于Pytorch的MNIST数据集的预处理详解
Jan 10 #Python
详解pycharm连接不上mysql数据库的解决办法
Jan 10 #Python
pycharm双击无响应(打不开问题解决办法)
Jan 10 #Python
python ubplot使用方法解析
Jan 10 #Python
You might like
浅析Yii中使用RBAC的完全指南(用户角色权限控制)
2013/06/20 PHP
PHP 如何利用phpexcel导入数据库
2013/08/24 PHP
php生成随机数的三种方法
2014/09/10 PHP
ThinkPHP实现分页功能
2017/04/28 PHP
thinkphp框架类库扩展操作示例
2019/11/26 PHP
php开发最强大的IDE编辑的phpstorm 2020.2配置Xdebug调试的详细教程
2020/08/17 PHP
分享XmlHttpRequest调用Webservice的一点心得
2012/07/20 Javascript
jQuery实现表头固定效果的实例代码
2013/05/24 Javascript
完美解决AJAX跨域问题
2013/11/01 Javascript
jquery 模板的应用示例
2013/11/12 Javascript
js实现倒计时时钟的示例代码
2013/12/17 Javascript
Bootstrap每天必学之缩略图与警示窗
2015/11/29 Javascript
日常收藏的jquery技巧
2015/12/02 Javascript
详解用Node.js实现Restful风格webservice
2017/09/29 Javascript
总结javascript三元运算符知识点
2018/09/28 Javascript
详解vue如何使用rules对表单字段进行校验
2018/10/17 Javascript
webpack是如何实现模块化加载的方法
2019/11/06 Javascript
angular异步验证防抖踩坑实录
2019/12/01 Javascript
使用webpack搭建pixi.js开发环境
2020/02/12 Javascript
koa2的中间件功能及应用示例
2020/03/05 Javascript
vuex中store存储store.commit和store.dispatch的用法
2020/07/24 Javascript
[02:17]2016完美“圣”典风云人物:Sccc专访
2016/12/03 DOTA
python基础教程之元组操作使用详解
2014/03/25 Python
python网络编程学习笔记(四):域名系统
2014/06/09 Python
python实现简单点对点(p2p)聊天
2017/09/13 Python
对numpy中轴与维度的理解
2018/04/18 Python
python 实现视频 图像帧提取
2019/12/10 Python
HTML5使用ApplicationCache接口实现离线缓存技术解决离线难题
2012/12/13 HTML / CSS
HTML5 3D旋转相册的实现示例
2019/12/03 HTML / CSS
AmazeUI底部导航栏与分享按钮的示例代码
2020/08/18 HTML / CSS
毕业生多媒体设计求职信
2013/10/12 职场文书
求职信模版
2013/11/30 职场文书
毕业自我鉴定怎么写
2014/03/25 职场文书
写字楼租赁意向书
2014/07/30 职场文书
护士自荐信怎么写
2015/03/06 职场文书
python编写五子棋游戏
2021/05/25 Python