python神经网络 使用Keras构建RNN训练


Posted in Python onMay 04, 2022

Keras中构建RNN的重要函数

1、SimpleRNN

SimpleRNN用于在Keras中构建普通的简单RNN层,在使用前需要import。

from keras.layers import SimpleRNN

在实际使用时,需要用到几个参数。

model.add(
    SimpleRNN(
        batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
        output_dim = CELL_SIZE,
    )
)

其中,batch_input_shape代表RNN输入数据的shape,shape的内容分别是每一次训练使用的BATCH,TIME_STEPS表示这个RNN按顺序输入的时间点的数量,INPUT_SIZE表示每一个时间点的输入数据大小。
CELL_SIZE代表训练每一个时间点的神经元数量。

2、model.train_on_batch

与之前的训练CNN网络和普通分类网络不同,RNN网络在建立时就规定了batch_input_shape,所以训练的时候也需要一定量一定量的传入训练数据。
model.train_on_batch在使用前需要对数据进行处理。获取指定BATCH大小的训练集。

X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
index_start += BATCH_SIZE

具体训练过程如下:

for i in range(500):
    X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
    Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
    index_start += BATCH_SIZE
    cost = model.train_on_batch(X_batch,Y_batch)
    if index_start >= X_train.shape[0]:
        index_start = 0
    if i%100 == 0:
        ## acc
        cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
        ## W,b = model.layers[0].get_weights()
        print("accuracy:",accuracy)
        x = X_test[1].reshape(1,28,28)

全部代码

这是一个RNN神经网络的例子,用于识别手写体。

import numpy as np
from keras.models import Sequential
from keras.layers import SimpleRNN,Activation,Dense ## 全连接层
from keras.datasets import mnist
from keras.utils import np_utils
from keras.optimizers import Adam

TIME_STEPS = 28
INPUT_SIZE = 28
BATCH_SIZE = 50
index_start = 0
OUTPUT_SIZE = 10
CELL_SIZE = 75
LR = 1e-3

(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
 
X_train = X_train.reshape(-1,28,28)/255
X_test = X_test.reshape(-1,28,28)/255

Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)

model = Sequential()

# conv1
model.add(
    SimpleRNN(
        batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
        output_dim = CELL_SIZE,
    )
)
model.add(Dense(OUTPUT_SIZE))
model.add(Activation("softmax"))
adam = Adam(LR)
## compile
model.compile(loss = 'categorical_crossentropy',optimizer = adam,metrics = ['accuracy'])

## tarin
for i in range(500):
    X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
    Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
    index_start += BATCH_SIZE
    cost = model.train_on_batch(X_batch,Y_batch)
    if index_start >= X_train.shape[0]:
        index_start = 0
    if i%100 == 0:
        ## acc
        cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
        ## W,b = model.layers[0].get_weights()
        print("accuracy:",accuracy)

实验结果为:

10000/10000 [==============================] - 1s 147us/step
accuracy: 0.09329999938607215
…………………………
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9395000022649765
10000/10000 [==============================] - 1s 109us/step
accuracy: 0.9422999995946885
10000/10000 [==============================] - 1s 114us/step
accuracy: 0.9534000000357628
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9566000008583069
10000/10000 [==============================] - 1s 113us/step
accuracy: 0.950799999833107
10000/10000 [==============================] - 1s 116us/step
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9474999988079071
10000/10000 [==============================] - 1s 111us/step
accuracy: 0.9515000003576278
10000/10000 [==============================] - 1s 114us/step
accuracy: 0.9288999977707862
10000/10000 [==============================] - 1s 115us/step
accuracy: 0.9487999993562698

以上就是python神经网络使用Keras构建RNN训练的详细内容!


Tags in this post...

Python 相关文章推荐
python使用线程封装的一个简单定时器类实例
May 16 Python
Python进行数据提取的方法总结
Aug 22 Python
Python模拟三级菜单效果
Sep 11 Python
答题辅助python代码实现
Jan 16 Python
python+numpy+matplotalib实现梯度下降法
Aug 31 Python
python 定义n个变量方法 (变量声明自动化)
Nov 10 Python
python实现nao机器人手臂动作控制
Apr 29 Python
选择python进行数据分析的理由和优势
Jun 25 Python
Win下PyInstaller 安装和使用教程
Dec 25 Python
python torch.utils.data.DataLoader使用方法
Apr 02 Python
Python selenium自动化测试模型图解
Apr 15 Python
在Matplotlib图中插入LaTex公式实例
Apr 17 Python
python神经网络学习 使用Keras进行回归运算
May 04 #Python
python神经网络学习 使用Keras进行简单分类
May 04 #Python
python神经网络 tf.name_scope 和 tf.variable_scope 的区别
May 04 #Python
Python3使用Qt5来实现简易的五子棋小游戏
May 02 #Python
python开发制作好看的时钟效果
关于的python五子棋的算法
python开发人人对战的五子棋小游戏
You might like
dedecms 制作模板中使用的全局标记图文教程
2007/03/11 PHP
PHP以及MYSQL日期比较方法
2012/11/29 PHP
PHP模拟登陆163邮箱发邮件及获取通讯录列表的方法
2015/03/07 PHP
仿校内登陆框,精美,给那些很厉害但是没有设计天才的程序员
2008/11/24 Javascript
jquery绑定原理 简单解析与实现代码分享
2011/09/06 Javascript
js检测输入内容全为空格的方法
2014/05/03 Javascript
jQuery基础知识小结
2014/12/22 Javascript
JavaScript设置body高度为浏览器高度的方法
2015/02/09 Javascript
Nodejs中session的简单使用及通过session实现身份验证的方法
2016/02/04 NodeJs
javascript事件冒泡简单示例
2016/06/20 Javascript
JS解决iframe之间通信和自适应高度的问题
2016/08/24 Javascript
轻松实现jquery选项卡切换效果
2016/10/10 Javascript
手把手教你搭建ES6的开发运行环境
2017/07/11 Javascript
vue props传值失败 输出undefined的解决方法
2018/09/11 Javascript
ionic+html5+API实现双击返回键退出应用
2019/09/17 Javascript
npx create-react-app xxx创建项目报错的解决办法
2020/02/17 Javascript
Vue3配置axios跨域实现过程解析
2020/11/25 Vue.js
[02:07]DOTA2新英雄展现中国元素,完美“圣典”亮相央视
2016/12/19 DOTA
[58:57]2018DOTA2亚洲邀请赛3月29日小组赛B组 Effect VS VGJ.T
2018/03/30 DOTA
python高并发异步服务器核心库forkcore使用方法
2013/11/26 Python
浅析Python中的多重继承
2015/04/28 Python
Python实现的Excel文件读写类
2015/07/30 Python
tensorflow识别自己手写数字
2018/03/14 Python
pycharm 将django中多个app放到同个文件夹apps的处理方法
2018/05/30 Python
Scrapy框架爬取Boss直聘网Python职位信息的源码
2019/02/22 Python
详解python tkinter包获取本地绝对路径(以获取图片并展示)
2020/09/04 Python
Python批量获取并保存手机号归属地和运营商的示例
2020/10/09 Python
利用pipenv和pyenv管理多个相互独立的Python虚拟开发环境
2020/11/01 Python
HTML5网页录音和上传到服务器支持PC、Android,支持IOS微信功能
2019/04/26 HTML / CSS
高校生生产实习自我鉴定
2013/09/21 职场文书
超市仓管员岗位职责范本
2014/09/18 职场文书
婚庆开业庆典主持词
2015/06/30 职场文书
小学入学感言
2015/08/01 职场文书
2016年共产党员公开承诺书
2016/03/24 职场文书
Go中的条件语句Switch示例详解
2021/08/23 Golang
Android Studio 计算器开发
2022/05/20 Java/Android