python神经网络 使用Keras构建RNN训练


Posted in Python onMay 04, 2022

Keras中构建RNN的重要函数

1、SimpleRNN

SimpleRNN用于在Keras中构建普通的简单RNN层,在使用前需要import。

from keras.layers import SimpleRNN

在实际使用时,需要用到几个参数。

model.add(
    SimpleRNN(
        batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
        output_dim = CELL_SIZE,
    )
)

其中,batch_input_shape代表RNN输入数据的shape,shape的内容分别是每一次训练使用的BATCH,TIME_STEPS表示这个RNN按顺序输入的时间点的数量,INPUT_SIZE表示每一个时间点的输入数据大小。
CELL_SIZE代表训练每一个时间点的神经元数量。

2、model.train_on_batch

与之前的训练CNN网络和普通分类网络不同,RNN网络在建立时就规定了batch_input_shape,所以训练的时候也需要一定量一定量的传入训练数据。
model.train_on_batch在使用前需要对数据进行处理。获取指定BATCH大小的训练集。

X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
index_start += BATCH_SIZE

具体训练过程如下:

for i in range(500):
    X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
    Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
    index_start += BATCH_SIZE
    cost = model.train_on_batch(X_batch,Y_batch)
    if index_start >= X_train.shape[0]:
        index_start = 0
    if i%100 == 0:
        ## acc
        cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
        ## W,b = model.layers[0].get_weights()
        print("accuracy:",accuracy)
        x = X_test[1].reshape(1,28,28)

全部代码

这是一个RNN神经网络的例子,用于识别手写体。

import numpy as np
from keras.models import Sequential
from keras.layers import SimpleRNN,Activation,Dense ## 全连接层
from keras.datasets import mnist
from keras.utils import np_utils
from keras.optimizers import Adam

TIME_STEPS = 28
INPUT_SIZE = 28
BATCH_SIZE = 50
index_start = 0
OUTPUT_SIZE = 10
CELL_SIZE = 75
LR = 1e-3

(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
 
X_train = X_train.reshape(-1,28,28)/255
X_test = X_test.reshape(-1,28,28)/255

Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)

model = Sequential()

# conv1
model.add(
    SimpleRNN(
        batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
        output_dim = CELL_SIZE,
    )
)
model.add(Dense(OUTPUT_SIZE))
model.add(Activation("softmax"))
adam = Adam(LR)
## compile
model.compile(loss = 'categorical_crossentropy',optimizer = adam,metrics = ['accuracy'])

## tarin
for i in range(500):
    X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
    Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
    index_start += BATCH_SIZE
    cost = model.train_on_batch(X_batch,Y_batch)
    if index_start >= X_train.shape[0]:
        index_start = 0
    if i%100 == 0:
        ## acc
        cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
        ## W,b = model.layers[0].get_weights()
        print("accuracy:",accuracy)

实验结果为:

10000/10000 [==============================] - 1s 147us/step
accuracy: 0.09329999938607215
…………………………
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9395000022649765
10000/10000 [==============================] - 1s 109us/step
accuracy: 0.9422999995946885
10000/10000 [==============================] - 1s 114us/step
accuracy: 0.9534000000357628
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9566000008583069
10000/10000 [==============================] - 1s 113us/step
accuracy: 0.950799999833107
10000/10000 [==============================] - 1s 116us/step
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9474999988079071
10000/10000 [==============================] - 1s 111us/step
accuracy: 0.9515000003576278
10000/10000 [==============================] - 1s 114us/step
accuracy: 0.9288999977707862
10000/10000 [==============================] - 1s 115us/step
accuracy: 0.9487999993562698

以上就是python神经网络使用Keras构建RNN训练的详细内容!


Tags in this post...

Python 相关文章推荐
Python的ORM框架SQLAlchemy入门教程
Apr 28 Python
python基于phantomjs实现导入图片
May 13 Python
在pandas中一次性删除dataframe的多个列方法
Apr 10 Python
Python 微信爬虫完整实例【单线程与多线程】
Jul 06 Python
Python中PyQt5/PySide2的按钮控件使用实例
Aug 17 Python
使用python实现希尔、计数、基数基础排序的代码
Dec 25 Python
Python configparser模块操作代码实例
Jun 08 Python
浅谈keras 的抽象后端(from keras import backend as K)
Jun 16 Python
Python pymsql模块的使用
Sep 07 Python
python 深度学习中的4种激活函数
Sep 18 Python
用Python进行栅格数据的分区统计和批量提取
May 27 Python
python3+PyQt5+Qt Designer实现界面可视化
Jun 10 Python
python神经网络学习 使用Keras进行回归运算
May 04 #Python
python神经网络学习 使用Keras进行简单分类
May 04 #Python
python神经网络 tf.name_scope 和 tf.variable_scope 的区别
May 04 #Python
Python3使用Qt5来实现简易的五子棋小游戏
May 02 #Python
python开发制作好看的时钟效果
关于的python五子棋的算法
python开发人人对战的五子棋小游戏
You might like
php中使用DOM类读取XML文件的实现代码
2011/12/14 PHP
PHP获取网页标题的3种实现方法代码实例
2014/04/11 PHP
Centos下升级php5.2到php5.4全记录(编译安装)
2015/04/03 PHP
PHP生成唯一订单号的方法汇总
2015/04/16 PHP
thinkphp配置文件路径的实现方法
2016/08/30 PHP
利用PHP抓取百度阅读的方法示例
2016/12/18 PHP
CakePHP框架Model函数定义方法示例
2017/08/04 PHP
jquery Moblie入门—hello world的示例代码学习
2013/01/08 Javascript
nodejs实现的一个简单聊天室功能分享
2014/12/06 NodeJs
jquery拖拽排序简单实现方法(效果增强版)
2016/02/16 Javascript
JS小数运算出现多为小数问题的解决方法
2016/06/02 Javascript
微信小程序 页面跳转传递值几种方法详解
2017/01/12 Javascript
详解Vue2+Echarts实现多种图表数据可视化Dashboard(附源码)
2017/03/21 Javascript
Vue.js实战之Vuex的入门教程
2017/04/01 Javascript
Vue 2.0在IE11中打开项目页面空白的问题解决
2017/07/16 Javascript
Js利用Canvas实现图片压缩功能
2017/09/13 Javascript
浅谈Vue SPA 首屏加载优化实践
2017/12/15 Javascript
一次Webpack配置文件的分离实战记录
2018/11/30 Javascript
详解CommonJS和ES6模块循环加载处理的区别
2018/12/26 Javascript
vue实现多个echarts根据屏幕大小变化而变化实例
2020/07/19 Javascript
[46:53]Secret vs Liquid 2019国际邀请赛小组赛 BO2 第一场 8.15
2019/08/17 DOTA
Python中的各种装饰器详解
2015/04/11 Python
python中的字典使用分享
2016/07/31 Python
利用python编写一个图片主色转换的脚本
2017/12/07 Python
在cmder下安装ipython以及环境的搭建
2018/10/19 Python
Python 将Matrix、Dict保存到文件的方法
2018/10/30 Python
python 实现从高分辨图像上抠取图像块
2020/01/02 Python
Python底层封装实现方法详解
2020/01/22 Python
python录音并调用百度语音识别接口的示例
2020/12/01 Python
关于逃课的检讨书
2014/01/23 职场文书
研究生毕业鉴定
2014/01/29 职场文书
旅游饭店管理专业自荐书
2014/06/28 职场文书
简单租房协议书范本
2014/08/20 职场文书
党员干部廉洁自律承诺书
2015/04/28 职场文书
诚实守信主题班会
2015/08/13 职场文书
MongoDB日志切割的三种方式总结
2021/09/15 MongoDB