python神经网络 使用Keras构建RNN训练


Posted in Python onMay 04, 2022

Keras中构建RNN的重要函数

1、SimpleRNN

SimpleRNN用于在Keras中构建普通的简单RNN层,在使用前需要import。

from keras.layers import SimpleRNN

在实际使用时,需要用到几个参数。

model.add(
    SimpleRNN(
        batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
        output_dim = CELL_SIZE,
    )
)

其中,batch_input_shape代表RNN输入数据的shape,shape的内容分别是每一次训练使用的BATCH,TIME_STEPS表示这个RNN按顺序输入的时间点的数量,INPUT_SIZE表示每一个时间点的输入数据大小。
CELL_SIZE代表训练每一个时间点的神经元数量。

2、model.train_on_batch

与之前的训练CNN网络和普通分类网络不同,RNN网络在建立时就规定了batch_input_shape,所以训练的时候也需要一定量一定量的传入训练数据。
model.train_on_batch在使用前需要对数据进行处理。获取指定BATCH大小的训练集。

X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
index_start += BATCH_SIZE

具体训练过程如下:

for i in range(500):
    X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
    Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
    index_start += BATCH_SIZE
    cost = model.train_on_batch(X_batch,Y_batch)
    if index_start >= X_train.shape[0]:
        index_start = 0
    if i%100 == 0:
        ## acc
        cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
        ## W,b = model.layers[0].get_weights()
        print("accuracy:",accuracy)
        x = X_test[1].reshape(1,28,28)

全部代码

这是一个RNN神经网络的例子,用于识别手写体。

import numpy as np
from keras.models import Sequential
from keras.layers import SimpleRNN,Activation,Dense ## 全连接层
from keras.datasets import mnist
from keras.utils import np_utils
from keras.optimizers import Adam

TIME_STEPS = 28
INPUT_SIZE = 28
BATCH_SIZE = 50
index_start = 0
OUTPUT_SIZE = 10
CELL_SIZE = 75
LR = 1e-3

(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
 
X_train = X_train.reshape(-1,28,28)/255
X_test = X_test.reshape(-1,28,28)/255

Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)

model = Sequential()

# conv1
model.add(
    SimpleRNN(
        batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
        output_dim = CELL_SIZE,
    )
)
model.add(Dense(OUTPUT_SIZE))
model.add(Activation("softmax"))
adam = Adam(LR)
## compile
model.compile(loss = 'categorical_crossentropy',optimizer = adam,metrics = ['accuracy'])

## tarin
for i in range(500):
    X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
    Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
    index_start += BATCH_SIZE
    cost = model.train_on_batch(X_batch,Y_batch)
    if index_start >= X_train.shape[0]:
        index_start = 0
    if i%100 == 0:
        ## acc
        cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
        ## W,b = model.layers[0].get_weights()
        print("accuracy:",accuracy)

实验结果为:

10000/10000 [==============================] - 1s 147us/step
accuracy: 0.09329999938607215
…………………………
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9395000022649765
10000/10000 [==============================] - 1s 109us/step
accuracy: 0.9422999995946885
10000/10000 [==============================] - 1s 114us/step
accuracy: 0.9534000000357628
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9566000008583069
10000/10000 [==============================] - 1s 113us/step
accuracy: 0.950799999833107
10000/10000 [==============================] - 1s 116us/step
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9474999988079071
10000/10000 [==============================] - 1s 111us/step
accuracy: 0.9515000003576278
10000/10000 [==============================] - 1s 114us/step
accuracy: 0.9288999977707862
10000/10000 [==============================] - 1s 115us/step
accuracy: 0.9487999993562698

以上就是python神经网络使用Keras构建RNN训练的详细内容!


Tags in this post...

Python 相关文章推荐
从零学Python之hello world
May 21 Python
浅析Python中的多进程与多线程的使用
Apr 07 Python
Python实现把数字转换成中文
Jun 29 Python
python用装饰器自动注册Tornado路由详解
Feb 14 Python
Python2.7基于笛卡尔积算法实现N个数组的排列组合运算示例
Nov 23 Python
Django中cookie的基本使用方法示例
Feb 03 Python
django 自定义filter 判断if var in list的例子
Aug 20 Python
python ImageDraw类实现几何图形的绘制与文字的绘制
Feb 26 Python
python 实现的车牌识别项目
Jan 25 Python
python 获取域名到期时间的方法步骤
Feb 10 Python
python爬取豆瓣电影排行榜(requests)的示例代码
Feb 18 Python
详解Python中openpyxl模块基本用法
Feb 23 Python
python神经网络学习 使用Keras进行回归运算
May 04 #Python
python神经网络学习 使用Keras进行简单分类
May 04 #Python
python神经网络 tf.name_scope 和 tf.variable_scope 的区别
May 04 #Python
Python3使用Qt5来实现简易的五子棋小游戏
May 02 #Python
python开发制作好看的时钟效果
关于的python五子棋的算法
python开发人人对战的五子棋小游戏
You might like
PHP编写的图片验证码类文件分享
2016/06/06 PHP
php 使用redis锁限制并发访问类示例
2016/11/02 PHP
PHP基于MySQLI函数封装的数据库连接工具类【定义与用法】
2017/08/11 PHP
使用jquery实现图文切换效果另加特效
2013/01/20 Javascript
js 事件截取enter按键页面提交事件示例代码
2014/03/04 Javascript
Javscript删除数组中指定元素并返回新数组
2014/03/06 Javascript
js点击button按钮跳转到另一个新页面
2014/10/10 Javascript
jquery 插件实现瀑布流图片展示实例
2015/04/03 Javascript
JavaScript面向对象的实现方法小结
2015/04/14 Javascript
JavaScript中指定函数名称的相关方法
2015/06/04 Javascript
js 声明数组和向数组中添加对象变量的简单实例
2016/07/28 Javascript
利用js编写响应式侧边栏
2016/09/17 Javascript
详解在Vue中如何使用axios跨域访问数据
2017/07/07 Javascript
php 解压zip压缩包内容到指定目录的实例
2018/01/23 Javascript
详解小程序如何避免多次点击,重复触发事件
2019/04/08 Javascript
Element Badge标记的使用方法
2020/07/27 Javascript
[03:11]DOTA2上海特锦赛小组赛第一日recap精彩回顾
2016/02/28 DOTA
[44:40]KG vs LGD 2019国际邀请赛小组赛 BO2 第一场 8.15
2019/08/16 DOTA
python中使用urllib2伪造HTTP报头的2个方法
2014/07/07 Python
Python中的自省(反射)详解
2015/06/02 Python
Python找出list中最常出现元素的方法
2016/06/14 Python
Python实现类的创建与使用方法示例
2017/07/25 Python
为什么选择python编程语言入门黑客攻防 给你几个理由!
2018/02/02 Python
浅谈Django中的数据库模型类-models.py(一对一的关系)
2018/05/30 Python
python 实现将txt文件多行合并为一行并将中间的空格去掉方法
2018/12/20 Python
pytorch使用Variable实现线性回归
2019/05/21 Python
使用python-opencv读取视频,计算视频总帧数及FPS的实现
2019/12/10 Python
Anaconda+Pycharm环境下的PyTorch配置方法
2020/03/13 Python
浅谈Keras中shuffle和validation_split的顺序
2020/06/19 Python
财务学生的职业生涯发展
2014/02/11 职场文书
物流管理专业毕业生自荐信
2014/03/04 职场文书
物业总经理助理岗位职责
2014/06/29 职场文书
说好普通话圆梦你我他演讲稿
2014/09/21 职场文书
法律专业大学生职业生涯规划书:向目标一步步迈进
2014/09/22 职场文书
pytorch中的model.eval()和BN层的使用
2021/05/22 Python
vue @click.native 绑定原生点击事件
2022/04/22 Vue.js