python神经网络 使用Keras构建RNN训练


Posted in Python onMay 04, 2022

Keras中构建RNN的重要函数

1、SimpleRNN

SimpleRNN用于在Keras中构建普通的简单RNN层,在使用前需要import。

from keras.layers import SimpleRNN

在实际使用时,需要用到几个参数。

model.add(
    SimpleRNN(
        batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
        output_dim = CELL_SIZE,
    )
)

其中,batch_input_shape代表RNN输入数据的shape,shape的内容分别是每一次训练使用的BATCH,TIME_STEPS表示这个RNN按顺序输入的时间点的数量,INPUT_SIZE表示每一个时间点的输入数据大小。
CELL_SIZE代表训练每一个时间点的神经元数量。

2、model.train_on_batch

与之前的训练CNN网络和普通分类网络不同,RNN网络在建立时就规定了batch_input_shape,所以训练的时候也需要一定量一定量的传入训练数据。
model.train_on_batch在使用前需要对数据进行处理。获取指定BATCH大小的训练集。

X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
index_start += BATCH_SIZE

具体训练过程如下:

for i in range(500):
    X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
    Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
    index_start += BATCH_SIZE
    cost = model.train_on_batch(X_batch,Y_batch)
    if index_start >= X_train.shape[0]:
        index_start = 0
    if i%100 == 0:
        ## acc
        cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
        ## W,b = model.layers[0].get_weights()
        print("accuracy:",accuracy)
        x = X_test[1].reshape(1,28,28)

全部代码

这是一个RNN神经网络的例子,用于识别手写体。

import numpy as np
from keras.models import Sequential
from keras.layers import SimpleRNN,Activation,Dense ## 全连接层
from keras.datasets import mnist
from keras.utils import np_utils
from keras.optimizers import Adam

TIME_STEPS = 28
INPUT_SIZE = 28
BATCH_SIZE = 50
index_start = 0
OUTPUT_SIZE = 10
CELL_SIZE = 75
LR = 1e-3

(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
 
X_train = X_train.reshape(-1,28,28)/255
X_test = X_test.reshape(-1,28,28)/255

Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)

model = Sequential()

# conv1
model.add(
    SimpleRNN(
        batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
        output_dim = CELL_SIZE,
    )
)
model.add(Dense(OUTPUT_SIZE))
model.add(Activation("softmax"))
adam = Adam(LR)
## compile
model.compile(loss = 'categorical_crossentropy',optimizer = adam,metrics = ['accuracy'])

## tarin
for i in range(500):
    X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
    Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
    index_start += BATCH_SIZE
    cost = model.train_on_batch(X_batch,Y_batch)
    if index_start >= X_train.shape[0]:
        index_start = 0
    if i%100 == 0:
        ## acc
        cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
        ## W,b = model.layers[0].get_weights()
        print("accuracy:",accuracy)

实验结果为:

10000/10000 [==============================] - 1s 147us/step
accuracy: 0.09329999938607215
…………………………
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9395000022649765
10000/10000 [==============================] - 1s 109us/step
accuracy: 0.9422999995946885
10000/10000 [==============================] - 1s 114us/step
accuracy: 0.9534000000357628
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9566000008583069
10000/10000 [==============================] - 1s 113us/step
accuracy: 0.950799999833107
10000/10000 [==============================] - 1s 116us/step
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9474999988079071
10000/10000 [==============================] - 1s 111us/step
accuracy: 0.9515000003576278
10000/10000 [==============================] - 1s 114us/step
accuracy: 0.9288999977707862
10000/10000 [==============================] - 1s 115us/step
accuracy: 0.9487999993562698

以上就是python神经网络使用Keras构建RNN训练的详细内容!


Tags in this post...

Python 相关文章推荐
python实现dnspod自动更新dns解析的方法
Feb 14 Python
python用来获得图片exif信息的库实例分析
Mar 16 Python
python实现下载指定网址所有图片的方法
Aug 08 Python
Python模拟登录验证码(代码简单)
Feb 06 Python
python中import学习备忘笔记
Jan 24 Python
Python获取指定文件夹下的文件名的方法
Feb 06 Python
python实现简单淘宝秒杀功能
May 03 Python
Python使用Shelve保存对象方法总结
Jan 28 Python
如何爬取通过ajax加载数据的网站
Aug 15 Python
python 默认参数相关知识详解
Sep 18 Python
Tensorflow读取并输出已保存模型的权重数值方式
Jan 04 Python
Python3爬虫中识别图形验证码的实例讲解
Jul 30 Python
python神经网络学习 使用Keras进行回归运算
May 04 #Python
python神经网络学习 使用Keras进行简单分类
May 04 #Python
python神经网络 tf.name_scope 和 tf.variable_scope 的区别
May 04 #Python
Python3使用Qt5来实现简易的五子棋小游戏
May 02 #Python
python开发制作好看的时钟效果
关于的python五子棋的算法
python开发人人对战的五子棋小游戏
You might like
图解上海144收音机
2021/03/02 无线电
PHP开发大型项目的一点经验
2006/10/09 PHP
PHP CURL 内存泄露问题解决方法
2015/02/12 PHP
Aster vs Newbee BO3 第三场2.18
2021/03/10 DOTA
JavaScript中使用正则匹配多条,且获取每条中的分组数据
2010/11/30 Javascript
js 操作select和option常用代码整理
2012/12/13 Javascript
JS将光标聚焦在文本最后的实现代码
2014/03/28 Javascript
javascript url几种编码方式详解
2016/06/06 Javascript
AngularJS入门教程之服务(Service)
2016/07/27 Javascript
基于JS实现仿百度百家主页的轮播图效果
2017/03/06 Javascript
HTML5实现微信拍摄上传照片功能
2017/04/21 Javascript
原生JS实现隐藏显示图片 JS实现点击切换图片效果
2021/01/27 Javascript
React Native 环境搭建的教程
2017/08/19 Javascript
快速解决vue在ios端下点击响应延时的问题
2018/08/27 Javascript
实例分析vue循环列表动态数据的处理方法
2018/09/28 Javascript
iview同时验证多个表单问题总结
2018/09/29 Javascript
如何使用less实现随机下雪动画详解
2019/01/02 Javascript
js实现计算器功能
2020/08/10 Javascript
python使用PythonMagick将jpg图片转换成ico图片的方法
2015/03/26 Python
在Django中创建第一个静态视图
2015/07/15 Python
基于python的多进程共享变量正确打开方式
2018/04/28 Python
使用NumPy和pandas对CSV文件进行写操作的实例
2018/06/14 Python
python 遍历目录(包括子目录)下所有文件的实例
2018/07/11 Python
python统计字符的个数代码实例
2020/02/07 Python
python实现opencv+scoket网络实时图传
2020/03/20 Python
python文件读取失败怎么处理
2020/06/23 Python
Python 在 VSCode 中使用 IPython Kernel 的方法详解
2020/09/05 Python
超级英雄、电影和电视、乐队和音乐T恤:Loud Clothing
2019/09/01 全球购物
马智宇结婚主持词
2014/04/01 职场文书
2014年巴西世界杯口号
2014/06/05 职场文书
农林经济管理专业自荐信
2014/09/01 职场文书
创业计划书之宠物店
2019/09/19 职场文书
导游词之清晏园
2019/11/22 职场文书
Apache Calcite 实现方言转换的代码
2021/04/24 Servers
新手初学Java网络编程
2021/07/07 Java/Android
阿里云国际版 使用Nginx作为HTTPS转发代理服务器
2022/05/11 Servers