python神经网络 使用Keras构建RNN训练


Posted in Python onMay 04, 2022

Keras中构建RNN的重要函数

1、SimpleRNN

SimpleRNN用于在Keras中构建普通的简单RNN层,在使用前需要import。

from keras.layers import SimpleRNN

在实际使用时,需要用到几个参数。

model.add(
    SimpleRNN(
        batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
        output_dim = CELL_SIZE,
    )
)

其中,batch_input_shape代表RNN输入数据的shape,shape的内容分别是每一次训练使用的BATCH,TIME_STEPS表示这个RNN按顺序输入的时间点的数量,INPUT_SIZE表示每一个时间点的输入数据大小。
CELL_SIZE代表训练每一个时间点的神经元数量。

2、model.train_on_batch

与之前的训练CNN网络和普通分类网络不同,RNN网络在建立时就规定了batch_input_shape,所以训练的时候也需要一定量一定量的传入训练数据。
model.train_on_batch在使用前需要对数据进行处理。获取指定BATCH大小的训练集。

X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
index_start += BATCH_SIZE

具体训练过程如下:

for i in range(500):
    X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
    Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
    index_start += BATCH_SIZE
    cost = model.train_on_batch(X_batch,Y_batch)
    if index_start >= X_train.shape[0]:
        index_start = 0
    if i%100 == 0:
        ## acc
        cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
        ## W,b = model.layers[0].get_weights()
        print("accuracy:",accuracy)
        x = X_test[1].reshape(1,28,28)

全部代码

这是一个RNN神经网络的例子,用于识别手写体。

import numpy as np
from keras.models import Sequential
from keras.layers import SimpleRNN,Activation,Dense ## 全连接层
from keras.datasets import mnist
from keras.utils import np_utils
from keras.optimizers import Adam

TIME_STEPS = 28
INPUT_SIZE = 28
BATCH_SIZE = 50
index_start = 0
OUTPUT_SIZE = 10
CELL_SIZE = 75
LR = 1e-3

(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
 
X_train = X_train.reshape(-1,28,28)/255
X_test = X_test.reshape(-1,28,28)/255

Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)

model = Sequential()

# conv1
model.add(
    SimpleRNN(
        batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
        output_dim = CELL_SIZE,
    )
)
model.add(Dense(OUTPUT_SIZE))
model.add(Activation("softmax"))
adam = Adam(LR)
## compile
model.compile(loss = 'categorical_crossentropy',optimizer = adam,metrics = ['accuracy'])

## tarin
for i in range(500):
    X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
    Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
    index_start += BATCH_SIZE
    cost = model.train_on_batch(X_batch,Y_batch)
    if index_start >= X_train.shape[0]:
        index_start = 0
    if i%100 == 0:
        ## acc
        cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
        ## W,b = model.layers[0].get_weights()
        print("accuracy:",accuracy)

实验结果为:

10000/10000 [==============================] - 1s 147us/step
accuracy: 0.09329999938607215
…………………………
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9395000022649765
10000/10000 [==============================] - 1s 109us/step
accuracy: 0.9422999995946885
10000/10000 [==============================] - 1s 114us/step
accuracy: 0.9534000000357628
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9566000008583069
10000/10000 [==============================] - 1s 113us/step
accuracy: 0.950799999833107
10000/10000 [==============================] - 1s 116us/step
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9474999988079071
10000/10000 [==============================] - 1s 111us/step
accuracy: 0.9515000003576278
10000/10000 [==============================] - 1s 114us/step
accuracy: 0.9288999977707862
10000/10000 [==============================] - 1s 115us/step
accuracy: 0.9487999993562698

以上就是python神经网络使用Keras构建RNN训练的详细内容!


Tags in this post...

Python 相关文章推荐
Python升级提示Tkinter模块找不到的解决方法
Aug 22 Python
详细解析Python当中的数据类型和变量
Apr 25 Python
Python制作钉钉加密/解密工具
Dec 07 Python
Python+request+unittest实现接口测试框架集成实例
Mar 16 Python
解决python Markdown模块乱码的问题
Feb 14 Python
python multiprocessing模块用法及原理介绍
Aug 20 Python
PyCharm License Activation激活码失效问题的解决方法(图文详解)
Mar 12 Python
解析Python 偏函数用法全方位实现
Jun 26 Python
python rsa-oaep加密的示例代码
Sep 23 Python
python 检测nginx服务邮件报警的脚本
Dec 31 Python
详解pytorch创建tensor函数
Mar 22 Python
在python中读取和写入CSV文件详情
Jun 28 Python
python神经网络学习 使用Keras进行回归运算
May 04 #Python
python神经网络学习 使用Keras进行简单分类
May 04 #Python
python神经网络 tf.name_scope 和 tf.variable_scope 的区别
May 04 #Python
Python3使用Qt5来实现简易的五子棋小游戏
May 02 #Python
python开发制作好看的时钟效果
关于的python五子棋的算法
python开发人人对战的五子棋小游戏
You might like
apache+mysql+php+ssl服务器之完全安装攻略
2006/09/05 PHP
常用的php对象类型判断
2008/08/27 PHP
用php实现让页面只能被百度gogole蜘蛛访问的方法
2009/12/29 PHP
洪恩在线成语词典小偷程序php版
2012/04/20 PHP
php+html5+ajax实现上传图片的方法
2016/05/14 PHP
php基于mcrypt_encrypt和mcrypt_decrypt实现字符串加密解密的方法
2016/07/12 PHP
php array_map()函数实例用法
2021/03/03 PHP
Jquery调用webService远程访问出错的解决方法
2010/05/21 Javascript
JavaScript EasyPager 分页函数
2011/05/25 Javascript
javascript高级学习笔记整理
2011/08/14 Javascript
javascript日期转换 时间戳转日期格式
2011/11/05 Javascript
Javascript算符的优先级介绍
2013/03/20 Javascript
子页向父页传值示例
2013/11/27 Javascript
js数组与字符串的相互转换方法
2014/07/09 Javascript
jQuery检测输入的字符串包含的中英文的数量
2015/04/17 Javascript
javascript 应用小技巧方法汇总
2015/07/05 Javascript
js 获取经纬度的实现方法
2016/06/20 Javascript
超全面的vue.js使用总结
2017/02/12 Javascript
从零开始学习Node.js系列教程之SQLite3和MongoDB用法分析
2017/04/13 Javascript
关于Vue.nextTick()的正确使用方法浅析
2017/08/25 Javascript
Bootstrap标签页(Tab)插件切换echarts不显示问题的解决
2018/07/13 Javascript
Vue2.5学习笔记之如何在项目中使用和配置Vue
2018/09/26 Javascript
javascript删除数组元素的七个方法示例
2019/09/09 Javascript
vue elementui 实现搜索栏公共组件封装的实例代码
2020/01/20 Javascript
JS实现图片幻灯片效果代码实例
2020/05/21 Javascript
如何使用jQuery操作Cookies方法解析
2020/09/08 jQuery
python实现发送邮件功能代码
2017/12/14 Python
PyCharm更改字体和界面样式的方法步骤
2019/09/27 Python
给老师的感谢信
2015/01/20 职场文书
感谢信范文大全
2015/01/23 职场文书
考生诚信考试承诺书
2015/04/29 职场文书
婚礼伴郎致辞
2015/07/28 职场文书
导游词之铁岭象牙山
2019/12/06 职场文书
php引用传递
2021/04/01 PHP
CSS完成视差滚动效果
2021/04/27 HTML / CSS
CSS实现背景图片全屏铺满自适应的3种方式
2022/07/07 HTML / CSS