python神经网络 使用Keras构建RNN训练


Posted in Python onMay 04, 2022

Keras中构建RNN的重要函数

1、SimpleRNN

SimpleRNN用于在Keras中构建普通的简单RNN层,在使用前需要import。

from keras.layers import SimpleRNN

在实际使用时,需要用到几个参数。

model.add(
    SimpleRNN(
        batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
        output_dim = CELL_SIZE,
    )
)

其中,batch_input_shape代表RNN输入数据的shape,shape的内容分别是每一次训练使用的BATCH,TIME_STEPS表示这个RNN按顺序输入的时间点的数量,INPUT_SIZE表示每一个时间点的输入数据大小。
CELL_SIZE代表训练每一个时间点的神经元数量。

2、model.train_on_batch

与之前的训练CNN网络和普通分类网络不同,RNN网络在建立时就规定了batch_input_shape,所以训练的时候也需要一定量一定量的传入训练数据。
model.train_on_batch在使用前需要对数据进行处理。获取指定BATCH大小的训练集。

X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
index_start += BATCH_SIZE

具体训练过程如下:

for i in range(500):
    X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
    Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
    index_start += BATCH_SIZE
    cost = model.train_on_batch(X_batch,Y_batch)
    if index_start >= X_train.shape[0]:
        index_start = 0
    if i%100 == 0:
        ## acc
        cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
        ## W,b = model.layers[0].get_weights()
        print("accuracy:",accuracy)
        x = X_test[1].reshape(1,28,28)

全部代码

这是一个RNN神经网络的例子,用于识别手写体。

import numpy as np
from keras.models import Sequential
from keras.layers import SimpleRNN,Activation,Dense ## 全连接层
from keras.datasets import mnist
from keras.utils import np_utils
from keras.optimizers import Adam

TIME_STEPS = 28
INPUT_SIZE = 28
BATCH_SIZE = 50
index_start = 0
OUTPUT_SIZE = 10
CELL_SIZE = 75
LR = 1e-3

(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
 
X_train = X_train.reshape(-1,28,28)/255
X_test = X_test.reshape(-1,28,28)/255

Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)

model = Sequential()

# conv1
model.add(
    SimpleRNN(
        batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
        output_dim = CELL_SIZE,
    )
)
model.add(Dense(OUTPUT_SIZE))
model.add(Activation("softmax"))
adam = Adam(LR)
## compile
model.compile(loss = 'categorical_crossentropy',optimizer = adam,metrics = ['accuracy'])

## tarin
for i in range(500):
    X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
    Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
    index_start += BATCH_SIZE
    cost = model.train_on_batch(X_batch,Y_batch)
    if index_start >= X_train.shape[0]:
        index_start = 0
    if i%100 == 0:
        ## acc
        cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
        ## W,b = model.layers[0].get_weights()
        print("accuracy:",accuracy)

实验结果为:

10000/10000 [==============================] - 1s 147us/step
accuracy: 0.09329999938607215
…………………………
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9395000022649765
10000/10000 [==============================] - 1s 109us/step
accuracy: 0.9422999995946885
10000/10000 [==============================] - 1s 114us/step
accuracy: 0.9534000000357628
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9566000008583069
10000/10000 [==============================] - 1s 113us/step
accuracy: 0.950799999833107
10000/10000 [==============================] - 1s 116us/step
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9474999988079071
10000/10000 [==============================] - 1s 111us/step
accuracy: 0.9515000003576278
10000/10000 [==============================] - 1s 114us/step
accuracy: 0.9288999977707862
10000/10000 [==============================] - 1s 115us/step
accuracy: 0.9487999993562698

以上就是python神经网络使用Keras构建RNN训练的详细内容!


Tags in this post...

Python 相关文章推荐
Python调用C/C++动态链接库的方法详解
Jul 22 Python
基于python绘制科赫雪花
Jun 22 Python
python生成以及打开json、csv和txt文件的实例
Nov 16 Python
详解Python 解压缩文件
Apr 09 Python
python中数据库like模糊查询方式
Mar 02 Python
Python enumerate() 函数如何实现索引功能
Jun 29 Python
使用python批量修改XML文件中图像的depth值
Jul 22 Python
pycharm导入源码的具体步骤
Aug 04 Python
UI自动化定位常用实现方法代码示例
Oct 27 Python
python之np.argmax()及对axis=0或者1的理解
Jun 02 Python
Python函数式编程中itertools模块详解
Sep 15 Python
Python  序列化反序列化和异常处理的问题小结
Dec 24 Python
python神经网络学习 使用Keras进行回归运算
May 04 #Python
python神经网络学习 使用Keras进行简单分类
May 04 #Python
python神经网络 tf.name_scope 和 tf.variable_scope 的区别
May 04 #Python
Python3使用Qt5来实现简易的五子棋小游戏
May 02 #Python
python开发制作好看的时钟效果
关于的python五子棋的算法
python开发人人对战的五子棋小游戏
You might like
PHP系统流量分析的程序
2006/10/09 PHP
php使用pdo连接并查询sql数据库的方法
2014/12/24 PHP
几个优化WordPress中JavaScript加载体验的插件介绍
2015/12/17 PHP
深入了解PHP中的Array数组和foreach
2016/11/06 PHP
浅谈PHP错误类型及屏蔽方法
2017/05/27 PHP
jquery入门—访问DOM对象方法
2013/01/07 Javascript
JS+CSS实现自适应选项卡宽度的圆角滑动门效果
2015/09/15 Javascript
jQuery实用技巧必备(中)
2015/11/03 Javascript
详解Bootstrap四种图片样式
2016/01/04 Javascript
javascript闭包概念简单解析(推荐)
2016/06/03 Javascript
Bootstrap弹出框(modal)垂直居中的问题及解决方案详解
2016/06/12 Javascript
jQuery实现div横向拖拽排序的简单实例
2016/07/13 Javascript
谈谈对JavaScript原生拖放的深入理解
2016/09/20 Javascript
JavaScript版经典游戏之扫雷游戏完整示例【附demo源码下载】
2016/12/12 Javascript
详解.vue文件中监听input输入事件(oninput)
2017/09/19 Javascript
jQuery常见的遍历DOM操作详解
2018/09/05 jQuery
在vue中使用cookie记住用户上次选择的实例(本次例子中为下拉框)
2020/09/11 Javascript
[16:56]heroes英雄教学 司夜刺客
2014/09/18 DOTA
简单的Python抓taobao图片爬虫
2014/10/26 Python
python创建关联数组(字典)的方法
2015/05/04 Python
pyenv命令管理多个Python版本
2017/03/26 Python
caffe binaryproto 与 npy相互转换的实例讲解
2018/07/09 Python
解决tensorflow1.x版本加载saver.restore目录报错的问题
2018/07/26 Python
pyqt5移动鼠标显示坐标的方法
2019/06/21 Python
浅谈keras.callbacks设置模型保存策略
2020/06/18 Python
python将字典内容写入json文件的实例代码
2020/08/12 Python
CSS3制作日历实现代码
2012/01/21 HTML / CSS
IE10 Error.stack 让脚本调试更加方便快捷
2013/04/22 HTML / CSS
canvas实现手机的手势解锁的步骤详细
2020/03/16 HTML / CSS
Ray-Ban雷朋奥地利官网:全球领先的太阳眼镜品牌
2020/10/12 全球购物
三八红旗手先进事迹材料
2014/05/13 职场文书
护士节策划方案
2014/05/19 职场文书
文艺晚会策划方案
2014/06/11 职场文书
基层党员干部四风问题整改方向和措施
2014/09/25 职场文书
亮剑观后感
2015/06/05 职场文书
Redis入门基础常用操作命令整理
2022/06/01 Redis