python神经网络 使用Keras构建RNN训练


Posted in Python onMay 04, 2022

Keras中构建RNN的重要函数

1、SimpleRNN

SimpleRNN用于在Keras中构建普通的简单RNN层,在使用前需要import。

from keras.layers import SimpleRNN

在实际使用时,需要用到几个参数。

model.add(
    SimpleRNN(
        batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
        output_dim = CELL_SIZE,
    )
)

其中,batch_input_shape代表RNN输入数据的shape,shape的内容分别是每一次训练使用的BATCH,TIME_STEPS表示这个RNN按顺序输入的时间点的数量,INPUT_SIZE表示每一个时间点的输入数据大小。
CELL_SIZE代表训练每一个时间点的神经元数量。

2、model.train_on_batch

与之前的训练CNN网络和普通分类网络不同,RNN网络在建立时就规定了batch_input_shape,所以训练的时候也需要一定量一定量的传入训练数据。
model.train_on_batch在使用前需要对数据进行处理。获取指定BATCH大小的训练集。

X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
index_start += BATCH_SIZE

具体训练过程如下:

for i in range(500):
    X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
    Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
    index_start += BATCH_SIZE
    cost = model.train_on_batch(X_batch,Y_batch)
    if index_start >= X_train.shape[0]:
        index_start = 0
    if i%100 == 0:
        ## acc
        cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
        ## W,b = model.layers[0].get_weights()
        print("accuracy:",accuracy)
        x = X_test[1].reshape(1,28,28)

全部代码

这是一个RNN神经网络的例子,用于识别手写体。

import numpy as np
from keras.models import Sequential
from keras.layers import SimpleRNN,Activation,Dense ## 全连接层
from keras.datasets import mnist
from keras.utils import np_utils
from keras.optimizers import Adam

TIME_STEPS = 28
INPUT_SIZE = 28
BATCH_SIZE = 50
index_start = 0
OUTPUT_SIZE = 10
CELL_SIZE = 75
LR = 1e-3

(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
 
X_train = X_train.reshape(-1,28,28)/255
X_test = X_test.reshape(-1,28,28)/255

Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)

model = Sequential()

# conv1
model.add(
    SimpleRNN(
        batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
        output_dim = CELL_SIZE,
    )
)
model.add(Dense(OUTPUT_SIZE))
model.add(Activation("softmax"))
adam = Adam(LR)
## compile
model.compile(loss = 'categorical_crossentropy',optimizer = adam,metrics = ['accuracy'])

## tarin
for i in range(500):
    X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
    Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
    index_start += BATCH_SIZE
    cost = model.train_on_batch(X_batch,Y_batch)
    if index_start >= X_train.shape[0]:
        index_start = 0
    if i%100 == 0:
        ## acc
        cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
        ## W,b = model.layers[0].get_weights()
        print("accuracy:",accuracy)

实验结果为:

10000/10000 [==============================] - 1s 147us/step
accuracy: 0.09329999938607215
…………………………
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9395000022649765
10000/10000 [==============================] - 1s 109us/step
accuracy: 0.9422999995946885
10000/10000 [==============================] - 1s 114us/step
accuracy: 0.9534000000357628
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9566000008583069
10000/10000 [==============================] - 1s 113us/step
accuracy: 0.950799999833107
10000/10000 [==============================] - 1s 116us/step
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9474999988079071
10000/10000 [==============================] - 1s 111us/step
accuracy: 0.9515000003576278
10000/10000 [==============================] - 1s 114us/step
accuracy: 0.9288999977707862
10000/10000 [==============================] - 1s 115us/step
accuracy: 0.9487999993562698

以上就是python神经网络使用Keras构建RNN训练的详细内容!


Tags in this post...

Python 相关文章推荐
python smtplib模块发送SSL/TLS安全邮件实例
Apr 08 Python
浅谈Python生成器generator之next和send的运行流程(详解)
May 08 Python
python实现word 2007文档转换为pdf文件
Mar 15 Python
python版本的仿windows计划任务工具
Apr 30 Python
python多进程下实现日志记录按时间分割
Jul 22 Python
详解一种用django_cache实现分布式锁的方式
Sep 01 Python
Python实现报警信息实时发送至邮箱功能(实例代码)
Nov 11 Python
详解python opencv、scikit-image和PIL图像处理库比较
Dec 26 Python
django 解决扩展自带User表遇到的问题
May 14 Python
Python 如何调试程序崩溃错误
Aug 03 Python
Python3压缩和解压缩实现代码
Mar 01 Python
PyTorch的Debug指南
May 07 Python
python神经网络学习 使用Keras进行回归运算
May 04 #Python
python神经网络学习 使用Keras进行简单分类
May 04 #Python
python神经网络 tf.name_scope 和 tf.variable_scope 的区别
May 04 #Python
Python3使用Qt5来实现简易的五子棋小游戏
May 02 #Python
python开发制作好看的时钟效果
关于的python五子棋的算法
python开发人人对战的五子棋小游戏
You might like
利用php+mysql来做一个功能强大的在线计算器
2010/10/12 PHP
关于mysql字符集设置了character_set_client=binary 在gbk情况下会出现表描述是乱码的情况
2013/01/06 PHP
解析使用ThinkPHP应该掌握的调试手段
2013/06/20 PHP
CodeIgniter安全相关设置汇总
2014/07/03 PHP
php实现图片等比例缩放代码
2015/07/23 PHP
PHP判断FORM表单或URL参数来的数据是否为整数的方法
2016/03/25 PHP
PHP基于IMAP收取邮件的方法示例
2017/08/07 PHP
PHP有序表查找之二分查找(折半查找)算法示例
2018/02/09 PHP
JavaScript Event学习第十一章 按键的检测
2010/02/10 Javascript
Javascript 面向对象(三)接口代码
2012/05/23 Javascript
jquery+json实现数据列表分页示例代码
2013/11/15 Javascript
JavaScript调试工具汇总
2014/12/23 Javascript
分析了一下JQuery中的extend方法实现原理
2015/02/27 Javascript
JQuery中$.each 和$(selector).each()的区别详解
2015/03/13 Javascript
深入浅析Node.js 事件循环、定时器和process.nextTick()
2018/10/22 Javascript
微信小程序实现发送验证码按钮效果
2018/12/20 Javascript
利用Dectorator分模块存储Vuex状态的实现
2019/02/05 Javascript
深入理解vue-class-component源码阅读
2019/02/18 Javascript
JavaScript命令模式原理与用法实例详解
2020/03/10 Javascript
vue实现虚拟列表功能的代码
2020/07/28 Javascript
DWR内存兼容及无法调用问题解决方案
2020/10/16 Javascript
[02:09:59]火猫TV国士无双dota2 6.82版本详解(下)
2014/09/29 DOTA
[00:53]2015国际邀请赛 中国区预选赛一触即发
2015/05/14 DOTA
[01:20]2018DOTA2亚洲邀请赛总决赛战队LGD晋级之路
2018/04/07 DOTA
python itchat实现微信自动回复的示例代码
2017/08/14 Python
Python处理命令行参数模块optpars用法实例分析
2018/05/31 Python
对python中的控制条件、循环和跳出详解
2019/06/24 Python
解决Python中pandas读取*.csv文件出现编码问题
2019/07/12 Python
Python虚拟环境venv用法详解
2020/05/25 Python
对Keras中predict()方法和predict_classes()方法的区别说明
2020/06/09 Python
实现Python3数组旋转的3种算法实例
2020/09/16 Python
意大利综合购物网站:Giordano Shop
2016/10/21 全球购物
日本高岛屋百货购物网站:TAKASHIMAYA
2019/03/24 全球购物
2014年为民办实事工作总结
2014/12/20 职场文书
2015年中学体育教师工作总结
2015/10/23 职场文书
索尼ICF-5900W收音机测评
2022/04/24 无线电