python神经网络 使用Keras构建RNN训练


Posted in Python onMay 04, 2022

Keras中构建RNN的重要函数

1、SimpleRNN

SimpleRNN用于在Keras中构建普通的简单RNN层,在使用前需要import。

from keras.layers import SimpleRNN

在实际使用时,需要用到几个参数。

model.add(
    SimpleRNN(
        batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
        output_dim = CELL_SIZE,
    )
)

其中,batch_input_shape代表RNN输入数据的shape,shape的内容分别是每一次训练使用的BATCH,TIME_STEPS表示这个RNN按顺序输入的时间点的数量,INPUT_SIZE表示每一个时间点的输入数据大小。
CELL_SIZE代表训练每一个时间点的神经元数量。

2、model.train_on_batch

与之前的训练CNN网络和普通分类网络不同,RNN网络在建立时就规定了batch_input_shape,所以训练的时候也需要一定量一定量的传入训练数据。
model.train_on_batch在使用前需要对数据进行处理。获取指定BATCH大小的训练集。

X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
index_start += BATCH_SIZE

具体训练过程如下:

for i in range(500):
    X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
    Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
    index_start += BATCH_SIZE
    cost = model.train_on_batch(X_batch,Y_batch)
    if index_start >= X_train.shape[0]:
        index_start = 0
    if i%100 == 0:
        ## acc
        cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
        ## W,b = model.layers[0].get_weights()
        print("accuracy:",accuracy)
        x = X_test[1].reshape(1,28,28)

全部代码

这是一个RNN神经网络的例子,用于识别手写体。

import numpy as np
from keras.models import Sequential
from keras.layers import SimpleRNN,Activation,Dense ## 全连接层
from keras.datasets import mnist
from keras.utils import np_utils
from keras.optimizers import Adam

TIME_STEPS = 28
INPUT_SIZE = 28
BATCH_SIZE = 50
index_start = 0
OUTPUT_SIZE = 10
CELL_SIZE = 75
LR = 1e-3

(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
 
X_train = X_train.reshape(-1,28,28)/255
X_test = X_test.reshape(-1,28,28)/255

Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)

model = Sequential()

# conv1
model.add(
    SimpleRNN(
        batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
        output_dim = CELL_SIZE,
    )
)
model.add(Dense(OUTPUT_SIZE))
model.add(Activation("softmax"))
adam = Adam(LR)
## compile
model.compile(loss = 'categorical_crossentropy',optimizer = adam,metrics = ['accuracy'])

## tarin
for i in range(500):
    X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
    Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
    index_start += BATCH_SIZE
    cost = model.train_on_batch(X_batch,Y_batch)
    if index_start >= X_train.shape[0]:
        index_start = 0
    if i%100 == 0:
        ## acc
        cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
        ## W,b = model.layers[0].get_weights()
        print("accuracy:",accuracy)

实验结果为:

10000/10000 [==============================] - 1s 147us/step
accuracy: 0.09329999938607215
…………………………
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9395000022649765
10000/10000 [==============================] - 1s 109us/step
accuracy: 0.9422999995946885
10000/10000 [==============================] - 1s 114us/step
accuracy: 0.9534000000357628
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9566000008583069
10000/10000 [==============================] - 1s 113us/step
accuracy: 0.950799999833107
10000/10000 [==============================] - 1s 116us/step
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9474999988079071
10000/10000 [==============================] - 1s 111us/step
accuracy: 0.9515000003576278
10000/10000 [==============================] - 1s 114us/step
accuracy: 0.9288999977707862
10000/10000 [==============================] - 1s 115us/step
accuracy: 0.9487999993562698

以上就是python神经网络使用Keras构建RNN训练的详细内容!


Tags in this post...

Python 相关文章推荐
学习python的几条建议分享
Feb 10 Python
Python 迭代器工具包【推荐】
May 06 Python
Django 使用Ajax进行前后台交互的示例讲解
May 28 Python
完美解决python中ndarray 默认用科学计数法显示的问题
Jul 14 Python
Python3.6简单的操作Mysql数据库的三个实例
Oct 17 Python
Python数据抓取爬虫代理防封IP方法
Dec 23 Python
python 反编译exe文件为py文件的实例代码
Jun 27 Python
Python解决pip install时出现的Could not fetch URL问题
Aug 01 Python
Pandas DataFrame中的tuple元素遍历的实现
Oct 23 Python
Pytorch之保存读取模型实例
Dec 30 Python
在Python中使用K-Means聚类和PCA主成分分析进行图像压缩
Apr 10 Python
jupyter 导入csv文件方式
Apr 21 Python
python神经网络学习 使用Keras进行回归运算
May 04 #Python
python神经网络学习 使用Keras进行简单分类
May 04 #Python
python神经网络 tf.name_scope 和 tf.variable_scope 的区别
May 04 #Python
Python3使用Qt5来实现简易的五子棋小游戏
May 02 #Python
python开发制作好看的时钟效果
关于的python五子棋的算法
python开发人人对战的五子棋小游戏
You might like
PHP的变量总结 新手推荐
2011/04/18 PHP
php Calender(日历)代码分享
2014/01/03 PHP
PHP编程文件处理类SplFileObject和SplFileInfo用法实例分析
2017/07/22 PHP
6个常见的 PHP 安全性攻击实例和阻止方法
2020/12/16 PHP
PHP7变量处理机制修改
2021/03/09 PHP
基于Jquery的实现回车键Enter切换焦点
2010/09/14 Javascript
理解Javascript_09_Function与Object
2010/10/16 Javascript
js中字符替换函数String.replace()使用技巧
2011/08/14 Javascript
如何实现chrome浏览器关闭页面时弹出“确定要离开此面吗?”
2015/03/05 Javascript
JS获取表格内指定单元格html内容的方法
2015/03/31 Javascript
获取IE浏览器Cookie信息的方法
2017/01/23 Javascript
详解angular应用容器化部署
2018/08/14 Javascript
深入理解es6块级作用域的使用
2019/03/28 Javascript
微信小程序实现搜索指定景点周边美食、酒店
2019/05/18 Javascript
vue-router的钩子函数用法实例分析
2019/10/26 Javascript
Vue axios获取token临时令牌封装案例
2020/09/11 Javascript
vue的$http的get请求要加上params操作
2020/11/12 Javascript
ES6的循环与可迭代对象示例详解
2021/01/31 Javascript
Python实现备份文件实例
2014/09/16 Python
Python编程实现的简单Web服务器示例
2017/06/22 Python
Python WSGI的深入理解
2018/08/01 Python
Python第三方库h5py_读取mat文件并显示值的方法
2019/02/08 Python
Python中zip()函数的解释和可视化(实例详解)
2020/02/16 Python
pytorch 模型的train模式与eval模式实例
2020/02/20 Python
Python tkinter布局与按钮间距设置方式
2020/03/04 Python
python对XML文件的操作实现代码
2020/03/27 Python
windows10 pycharm下安装pyltp库和加载模型实现语义角色标注的示例代码
2020/05/07 Python
详解FireFox下Canvas使用图像合成绘制SVG的Bug
2019/07/10 HTML / CSS
快速实现一个简单的canvas迷宫游戏的示例
2018/07/04 HTML / CSS
军训学生自我鉴定
2014/02/12 职场文书
我爱我家教学反思
2014/05/01 职场文书
升学宴演讲稿
2014/09/01 职场文书
开业庆典活动策划方案
2014/09/21 职场文书
关于拾金不昧的感谢信
2015/01/21 职场文书
学校学习型党组织建设心得体会
2019/06/21 职场文书
Python进度条的使用
2021/05/17 Python