python神经网络 使用Keras构建RNN训练


Posted in Python onMay 04, 2022

Keras中构建RNN的重要函数

1、SimpleRNN

SimpleRNN用于在Keras中构建普通的简单RNN层,在使用前需要import。

from keras.layers import SimpleRNN

在实际使用时,需要用到几个参数。

model.add(
    SimpleRNN(
        batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
        output_dim = CELL_SIZE,
    )
)

其中,batch_input_shape代表RNN输入数据的shape,shape的内容分别是每一次训练使用的BATCH,TIME_STEPS表示这个RNN按顺序输入的时间点的数量,INPUT_SIZE表示每一个时间点的输入数据大小。
CELL_SIZE代表训练每一个时间点的神经元数量。

2、model.train_on_batch

与之前的训练CNN网络和普通分类网络不同,RNN网络在建立时就规定了batch_input_shape,所以训练的时候也需要一定量一定量的传入训练数据。
model.train_on_batch在使用前需要对数据进行处理。获取指定BATCH大小的训练集。

X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
index_start += BATCH_SIZE

具体训练过程如下:

for i in range(500):
    X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
    Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
    index_start += BATCH_SIZE
    cost = model.train_on_batch(X_batch,Y_batch)
    if index_start >= X_train.shape[0]:
        index_start = 0
    if i%100 == 0:
        ## acc
        cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
        ## W,b = model.layers[0].get_weights()
        print("accuracy:",accuracy)
        x = X_test[1].reshape(1,28,28)

全部代码

这是一个RNN神经网络的例子,用于识别手写体。

import numpy as np
from keras.models import Sequential
from keras.layers import SimpleRNN,Activation,Dense ## 全连接层
from keras.datasets import mnist
from keras.utils import np_utils
from keras.optimizers import Adam

TIME_STEPS = 28
INPUT_SIZE = 28
BATCH_SIZE = 50
index_start = 0
OUTPUT_SIZE = 10
CELL_SIZE = 75
LR = 1e-3

(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
 
X_train = X_train.reshape(-1,28,28)/255
X_test = X_test.reshape(-1,28,28)/255

Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)

model = Sequential()

# conv1
model.add(
    SimpleRNN(
        batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
        output_dim = CELL_SIZE,
    )
)
model.add(Dense(OUTPUT_SIZE))
model.add(Activation("softmax"))
adam = Adam(LR)
## compile
model.compile(loss = 'categorical_crossentropy',optimizer = adam,metrics = ['accuracy'])

## tarin
for i in range(500):
    X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
    Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
    index_start += BATCH_SIZE
    cost = model.train_on_batch(X_batch,Y_batch)
    if index_start >= X_train.shape[0]:
        index_start = 0
    if i%100 == 0:
        ## acc
        cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
        ## W,b = model.layers[0].get_weights()
        print("accuracy:",accuracy)

实验结果为:

10000/10000 [==============================] - 1s 147us/step
accuracy: 0.09329999938607215
…………………………
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9395000022649765
10000/10000 [==============================] - 1s 109us/step
accuracy: 0.9422999995946885
10000/10000 [==============================] - 1s 114us/step
accuracy: 0.9534000000357628
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9566000008583069
10000/10000 [==============================] - 1s 113us/step
accuracy: 0.950799999833107
10000/10000 [==============================] - 1s 116us/step
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9474999988079071
10000/10000 [==============================] - 1s 111us/step
accuracy: 0.9515000003576278
10000/10000 [==============================] - 1s 114us/step
accuracy: 0.9288999977707862
10000/10000 [==============================] - 1s 115us/step
accuracy: 0.9487999993562698

以上就是python神经网络使用Keras构建RNN训练的详细内容!


Tags in this post...

Python 相关文章推荐
Python读取指定目录下指定后缀文件并保存为docx
Apr 23 Python
python生成随机图形验证码详解
Nov 08 Python
Python浅复制中对象生存周期实例分析
Apr 02 Python
wxPython的安装与使用教程
Aug 31 Python
使用CodeMirror实现Python3在线编辑器的示例代码
Jan 14 Python
python对文件目录的操作方法实例总结
Jun 24 Python
在Python函数中输入任意数量参数的实例
Jul 16 Python
python super的使用方法及实例详解
Sep 25 Python
Python类反射机制使用实例解析
Dec 30 Python
Python调用Windows命令打印文件
Feb 07 Python
python操作redis数据库的三种方法
Sep 10 Python
python软件测试Jmeter性能测试JDBC Request(结合数据库)的使用详解
Jan 26 Python
python神经网络学习 使用Keras进行回归运算
May 04 #Python
python神经网络学习 使用Keras进行简单分类
May 04 #Python
python神经网络 tf.name_scope 和 tf.variable_scope 的区别
May 04 #Python
Python3使用Qt5来实现简易的五子棋小游戏
May 02 #Python
python开发制作好看的时钟效果
关于的python五子棋的算法
python开发人人对战的五子棋小游戏
You might like
牡丹941资料
2021/03/01 无线电
php session 错误
2009/05/21 PHP
php牛逼的面试题分享
2013/01/18 PHP
PHP禁止个别IP访问网站
2013/10/30 PHP
PHP和javascript常用正则表达式及用法实例
2014/07/01 PHP
PHP中__set()实例用法和基础讲解
2019/07/23 PHP
php项目中类的自动加载实例讲解
2019/09/12 PHP
Yii框架 session 数据库存储操作方法示例
2019/11/18 PHP
php封装的page分页类完整实例代码
2020/02/01 PHP
ASP.NET jQuery 实例11 通过使用jQuery validation插件简单实现用户登录页面验证功能
2012/02/03 Javascript
jquery实现叠层3D文字特效代码分享
2015/08/21 Javascript
JavaScript获取function所有参数名的方法
2015/10/30 Javascript
理解JavaScript表单的基础知识
2016/01/25 Javascript
基于jquery实现智能提示控件intellSeach.js
2016/03/17 Javascript
javascript如何创建对象
2016/08/29 Javascript
AngularJS中过滤器的使用与自定义实例代码
2016/09/17 Javascript
JavaScript纯色二维码变成彩色二维码
2020/07/23 Javascript
详解nodeJS之路径PATH模块
2017/05/31 NodeJs
Angular4学习笔记之实现绑定和分包
2017/08/01 Javascript
JS执行控制之节流模式实例分析
2018/12/21 Javascript
Vue extend的基本用法(实例详解)
2019/12/09 Javascript
[01:03:50]DOTA2-DPC中国联赛 正赛 CDEC vs DLG BO3 第二场 2月7日
2021/03/11 DOTA
Python迭代和迭代器详解
2016/11/10 Python
浅谈flask截获所有访问及before/after_request修饰器
2018/01/18 Python
Python模拟登录的多种方法(四种)
2018/06/01 Python
Python assert关键字原理及实例解析
2019/12/13 Python
Python用access判断文件是否被占用的实例方法
2020/12/17 Python
基于CSS3实现的几个小loading效果
2018/09/27 HTML / CSS
canvas实现有递增动画的环形进度条的实现方法
2019/07/10 HTML / CSS
实例讲解使用HTML5 Canvas绘制阴影效果的方法
2016/03/25 HTML / CSS
千元咖啡店的创业计划书范文
2013/12/29 职场文书
小学生自我评价范文
2014/01/25 职场文书
《翻越远方的大山》教学反思
2014/04/13 职场文书
悬空寺导游词
2015/02/05 职场文书
入党积极分子群众意见
2015/06/01 职场文书
Pytest allure 命令行参数的使用
2021/04/18 Python