手把手教你使用TensorFlow2实现RNN


Posted in Python onJuly 15, 2021
目录
  • 概述
  • 权重共享
  • 计算过程:
  • 案例
    • 数据集
    • RNN 层
    • 获取数据
  • 完整代码

 

概述

RNN (Recurrent Netural Network) 是用于处理序列数据的神经网络. 所谓序列数据, 即前面的输入和后面的输入有一定的联系.

手把手教你使用TensorFlow2实现RNN

 

权重共享

传统神经网络:

手把手教你使用TensorFlow2实现RNN

RNN:

手把手教你使用TensorFlow2实现RNN

RNN 的权重共享和 CNN 的权重共享类似, 不同时刻共享一个权重, 大大减少了参数数量.

 

计算过程:

手把手教你使用TensorFlow2实现RNN

计算状态 (State)

手把手教你使用TensorFlow2实现RNN

计算输出:

手把手教你使用TensorFlow2实现RNN

 

案例

 

数据集

IBIM 数据集包含了来自互联网的 50000 条关于电影的评论, 分为正面评价和负面评价.

 

RNN 层

class RNN(tf.keras.Model):

    def __init__(self, units):
        super(RNN, self).__init__()

        # 初始化 [b, 64] (b 表示 batch_size)
        self.state0 = [tf.zeros([batch_size, units])]
        self.state1 = [tf.zeros([batch_size, units])]

        # [b, 80] => [b, 80, 100]
        self.embedding = tf.keras.layers.Embedding(total_words, embedding_len, input_length=max_review_len)

        self.rnn_cell0 = tf.keras.layers.SimpleRNNCell(units=units, dropout=0.2)
        self.rnn_cell1 = tf.keras.layers.SimpleRNNCell(units=units, dropout=0.2)

        # [b, 80, 100] => [b, 64] => [b, 1]
        self.out_layer = tf.keras.layers.Dense(1)

    def call(self, inputs, training=None):
        """

        :param inputs: [b, 80]
        :param training:
        :return:
        """

        state0 = self.state0
        state1 = self.state1

        x = self.embedding(inputs)

        for word in tf.unstack(x, axis=1):
            out0, state0 = self.rnn_cell0(word, state0, training=training)
            out1, state1 = self.rnn_cell1(out0, state1, training=training)

        # [b, 64] -> [b, 1]
        x = self.out_layer(out1)

        prob = tf.sigmoid(x)

        return prob

 

获取数据

def get_data():
    # 获取数据
    (X_train, y_train), (X_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=total_words)

    # 更改句子长度
    X_train = tf.keras.preprocessing.sequence.pad_sequences(X_train, maxlen=max_review_len)
    X_test = tf.keras.preprocessing.sequence.pad_sequences(X_test, maxlen=max_review_len)

    # 调试输出
    print(X_train.shape, y_train.shape)  # (25000, 80) (25000,)
    print(X_test.shape, y_test.shape)  # (25000, 80) (25000,)

    # 分割训练集
    train_db = tf.data.Dataset.from_tensor_slices((X_train, y_train))
    train_db = train_db.shuffle(10000).batch(batch_size, drop_remainder=True)

    # 分割测试集
    test_db = tf.data.Dataset.from_tensor_slices((X_test, y_test))
    test_db = test_db.batch(batch_size, drop_remainder=True)

    return train_db, test_db

 

完整代码

import tensorflow as tf


class RNN(tf.keras.Model):

    def __init__(self, units):
        super(RNN, self).__init__()

        # 初始化 [b, 64]
        self.state0 = [tf.zeros([batch_size, units])]
        self.state1 = [tf.zeros([batch_size, units])]

        # [b, 80] => [b, 80, 100]
        self.embedding = tf.keras.layers.Embedding(total_words, embedding_len, input_length=max_review_len)

        self.rnn_cell0 = tf.keras.layers.SimpleRNNCell(units=units, dropout=0.2)
        self.rnn_cell1 = tf.keras.layers.SimpleRNNCell(units=units, dropout=0.2)

        # [b, 80, 100] => [b, 64] => [b, 1]
        self.out_layer = tf.keras.layers.Dense(1)

    def call(self, inputs, training=None):
        """

        :param inputs: [b, 80]
        :param training:
        :return:
        """

        state0 = self.state0
        state1 = self.state1

        x = self.embedding(inputs)

        for word in tf.unstack(x, axis=1):
            out0, state0 = self.rnn_cell0(word, state0, training=training)
            out1, state1 = self.rnn_cell1(out0, state1, training=training)

        # [b, 64] -> [b, 1]
        x = self.out_layer(out1)

        prob = tf.sigmoid(x)

        return prob


# 超参数
total_words = 10000  # 文字数量
max_review_len = 80  # 句子长度
embedding_len = 100  # 词维度
batch_size = 1024  # 一次训练的样本数目
learning_rate = 0.0001  # 学习率
iteration_num = 20  # 迭代次数
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)  # 优化器
loss = tf.losses.BinaryCrossentropy(from_logits=True)  # 损失
model = RNN(64)

# 调试输出summary
model.build(input_shape=[None, 64])
print(model.summary())

# 组合
model.compile(optimizer=optimizer, loss=loss, metrics=["accuracy"])


def get_data():
    # 获取数据
    (X_train, y_train), (X_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=total_words)

    # 更改句子长度
    X_train = tf.keras.preprocessing.sequence.pad_sequences(X_train, maxlen=max_review_len)
    X_test = tf.keras.preprocessing.sequence.pad_sequences(X_test, maxlen=max_review_len)

    # 调试输出
    print(X_train.shape, y_train.shape)  # (25000, 80) (25000,)
    print(X_test.shape, y_test.shape)  # (25000, 80) (25000,)

    # 分割训练集
    train_db = tf.data.Dataset.from_tensor_slices((X_train, y_train))
    train_db = train_db.shuffle(10000).batch(batch_size, drop_remainder=True)

    # 分割测试集
    test_db = tf.data.Dataset.from_tensor_slices((X_test, y_test))
    test_db = test_db.batch(batch_size, drop_remainder=True)

    return train_db, test_db


if __name__ == "__main__":
    # 获取分割的数据集
    train_db, test_db = get_data()

    # 拟合
    model.fit(train_db, epochs=iteration_num, validation_data=test_db, validation_freq=1)

输出结果:

Model: "rnn"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
embedding (Embedding) multiple 1000000
_________________________________________________________________
simple_rnn_cell (SimpleRNNCe multiple 10560
_________________________________________________________________
simple_rnn_cell_1 (SimpleRNN multiple 8256
_________________________________________________________________
dense (Dense) multiple 65
=================================================================
Total params: 1,018,881
Trainable params: 1,018,881
Non-trainable params: 0
_________________________________________________________________
None

(25000, 80) (25000,)
(25000, 80) (25000,)
Epoch 1/20
2021-07-10 17:59:45.150639: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:176] None of the MLIR Optimization Passes are enabled (registered 2)
24/24 [==============================] - 12s 294ms/step - loss: 0.7113 - accuracy: 0.5033 - val_loss: 0.6968 - val_accuracy: 0.4994
Epoch 2/20
24/24 [==============================] - 7s 292ms/step - loss: 0.6951 - accuracy: 0.5005 - val_loss: 0.6939 - val_accuracy: 0.4994
Epoch 3/20
24/24 [==============================] - 7s 297ms/step - loss: 0.6937 - accuracy: 0.5000 - val_loss: 0.6935 - val_accuracy: 0.4994
Epoch 4/20
24/24 [==============================] - 8s 316ms/step - loss: 0.6934 - accuracy: 0.5001 - val_loss: 0.6933 - val_accuracy: 0.4994
Epoch 5/20
24/24 [==============================] - 7s 301ms/step - loss: 0.6934 - accuracy: 0.4996 - val_loss: 0.6933 - val_accuracy: 0.4994
Epoch 6/20
24/24 [==============================] - 8s 334ms/step - loss: 0.6932 - accuracy: 0.5000 - val_loss: 0.6932 - val_accuracy: 0.4994
Epoch 7/20
24/24 [==============================] - 10s 398ms/step - loss: 0.6931 - accuracy: 0.5006 - val_loss: 0.6932 - val_accuracy: 0.4994
Epoch 8/20
24/24 [==============================] - 9s 382ms/step - loss: 0.6930 - accuracy: 0.5006 - val_loss: 0.6931 - val_accuracy: 0.4994
Epoch 9/20
24/24 [==============================] - 8s 322ms/step - loss: 0.6924 - accuracy: 0.4995 - val_loss: 0.6913 - val_accuracy: 0.5240
Epoch 10/20
24/24 [==============================] - 8s 321ms/step - loss: 0.6812 - accuracy: 0.5501 - val_loss: 0.6655 - val_accuracy: 0.5767
Epoch 11/20
24/24 [==============================] - 8s 318ms/step - loss: 0.6381 - accuracy: 0.6896 - val_loss: 0.6235 - val_accuracy: 0.7399
Epoch 12/20
24/24 [==============================] - 8s 323ms/step - loss: 0.6088 - accuracy: 0.7655 - val_loss: 0.6110 - val_accuracy: 0.7533
Epoch 13/20
24/24 [==============================] - 8s 321ms/step - loss: 0.5949 - accuracy: 0.7956 - val_loss: 0.6111 - val_accuracy: 0.7878
Epoch 14/20
24/24 [==============================] - 8s 324ms/step - loss: 0.5859 - accuracy: 0.8142 - val_loss: 0.5993 - val_accuracy: 0.7904
Epoch 15/20
24/24 [==============================] - 8s 330ms/step - loss: 0.5791 - accuracy: 0.8318 - val_loss: 0.5961 - val_accuracy: 0.7907
Epoch 16/20
24/24 [==============================] - 8s 340ms/step - loss: 0.5739 - accuracy: 0.8421 - val_loss: 0.5942 - val_accuracy: 0.7961
Epoch 17/20
24/24 [==============================] - 9s 378ms/step - loss: 0.5701 - accuracy: 0.8497 - val_loss: 0.5933 - val_accuracy: 0.8014
Epoch 18/20
24/24 [==============================] - 9s 361ms/step - loss: 0.5665 - accuracy: 0.8589 - val_loss: 0.5958 - val_accuracy: 0.8082
Epoch 19/20
24/24 [==============================] - 8s 353ms/step - loss: 0.5630 - accuracy: 0.8681 - val_loss: 0.5931 - val_accuracy: 0.7966
Epoch 20/20
24/24 [==============================] - 8s 314ms/step - loss: 0.5614 - accuracy: 0.8702 - val_loss: 0.5925 - val_accuracy: 0.7959

Process finished with exit code 0

到此这篇关于手把手教你使用TensorFlow2实现RNN的文章就介绍到这了,更多相关TensorFlow2实现RNN内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python检测远程服务器tcp端口的方法
Mar 14 Python
使用Python的Tornado框架实现一个Web端图书展示页面
Jul 11 Python
Python数据可视化正态分布简单分析及实现代码
Dec 04 Python
用tensorflow搭建CNN的方法
Mar 05 Python
详解Python计算机视觉 图像扭曲(仿射扭曲)
Mar 27 Python
Python的Tkinter点击按钮触发事件的例子
Jul 19 Python
Python timer定时器两种常用方法解析
Jan 20 Python
Python+OpenCV实现图像的全景拼接
Mar 05 Python
基于python计算滚动方差(标准差)talib和pd.rolling函数差异详解
Jun 08 Python
Python3.9 beta2版本发布了,看看这7个新的PEP都是什么
Jun 10 Python
详解python中[-1]、[:-1]、[::-1]、[n::-1]使用方法
Apr 25 Python
python图像处理 PIL Image操作实例
Apr 09 Python
一篇文章弄懂Python关键字、标识符和变量
python开发飞机大战游戏
详解Python中下划线的5种含义
Python操作CSV格式文件的方法大全
openstack中的rpc远程调用的方法
Python实现查询剪贴板自动匹配信息的思路详解
如何利用Python实现一个论文降重工具
You might like
PHP中HTTP方式下的Gzip压缩传输方法举偶
2007/02/15 PHP
php 运行效率总结(提示程序速度)
2009/11/26 PHP
php上传图片到指定位置路径保存到数据库的具体实现
2013/12/30 PHP
PHP使用缓存即时输出内容(output buffering)的方法
2015/08/03 PHP
[原创]php求圆周率的简单实现方法
2016/05/30 PHP
关于Yii中模型场景的一些简单介绍
2019/09/22 PHP
jQuery LigerUI 使用教程表格篇(1)
2012/01/18 Javascript
jQuery中focus事件用法实例
2014/12/26 Javascript
Javascript闭包用法实例分析
2015/01/23 Javascript
JavaScript tab选项卡插件实例代码
2016/02/23 Javascript
JS控制伪元素的方法汇总
2016/04/06 Javascript
JavaScript禁止用户多次提交的两种方法
2016/07/24 Javascript
JavaScript基于对象去除数组重复项的方法
2016/10/09 Javascript
Jquery Easyui进度条组件Progress使用详解(8)
2020/03/26 Javascript
JavaScript中常见的八个陷阱总结
2017/06/28 Javascript
vue-prop父组件向子组件进行传值的方法
2018/03/01 Javascript
Angular异步变同步处理方法
2018/08/13 Javascript
微信小程序跨页面数据传递事件响应实现过程解析
2019/12/19 Javascript
React Hooks 实现和由来以及解决的问题详解
2020/01/17 Javascript
超详细小程序定位地图模块全系列开发教学
2020/11/24 Javascript
python使用urllib模块开发的多线程豆瓣小站mp3下载器
2014/01/16 Python
利用Python抓取行政区划码的方法
2016/11/28 Python
Python 备份程序代码实现
2017/03/06 Python
深入了解NumPy 高级索引
2020/07/24 Python
如何使用 Python 读取文件和照片的创建日期
2020/09/05 Python
纯HTML+CSS3制作导航菜单(附源码)
2013/04/24 HTML / CSS
详解HTML5如何使用可选样式表为网站或应用添加黑暗模式
2020/04/07 HTML / CSS
俄罗斯的精英皮具:Wittchen
2018/01/29 全球购物
工程管理造价应届生求职信
2013/11/13 职场文书
集团公司总经理岗位职责
2013/12/20 职场文书
高一新生军训感言
2014/03/02 职场文书
自我查摆剖析材料
2014/10/11 职场文书
普通党员群众路线教育实践活动心得体会
2014/11/04 职场文书
大学三好学生主要事迹范文
2015/11/03 职场文书
如何撰写出一份完美的商业计划书?
2019/07/12 职场文书
Go语言中的UTF-8实现
2021/04/26 Golang