手把手教你使用TensorFlow2实现RNN


Posted in Python onJuly 15, 2021
目录
  • 概述
  • 权重共享
  • 计算过程:
  • 案例
    • 数据集
    • RNN 层
    • 获取数据
  • 完整代码

 

概述

RNN (Recurrent Netural Network) 是用于处理序列数据的神经网络. 所谓序列数据, 即前面的输入和后面的输入有一定的联系.

手把手教你使用TensorFlow2实现RNN

 

权重共享

传统神经网络:

手把手教你使用TensorFlow2实现RNN

RNN:

手把手教你使用TensorFlow2实现RNN

RNN 的权重共享和 CNN 的权重共享类似, 不同时刻共享一个权重, 大大减少了参数数量.

 

计算过程:

手把手教你使用TensorFlow2实现RNN

计算状态 (State)

手把手教你使用TensorFlow2实现RNN

计算输出:

手把手教你使用TensorFlow2实现RNN

 

案例

 

数据集

IBIM 数据集包含了来自互联网的 50000 条关于电影的评论, 分为正面评价和负面评价.

 

RNN 层

class RNN(tf.keras.Model):

    def __init__(self, units):
        super(RNN, self).__init__()

        # 初始化 [b, 64] (b 表示 batch_size)
        self.state0 = [tf.zeros([batch_size, units])]
        self.state1 = [tf.zeros([batch_size, units])]

        # [b, 80] => [b, 80, 100]
        self.embedding = tf.keras.layers.Embedding(total_words, embedding_len, input_length=max_review_len)

        self.rnn_cell0 = tf.keras.layers.SimpleRNNCell(units=units, dropout=0.2)
        self.rnn_cell1 = tf.keras.layers.SimpleRNNCell(units=units, dropout=0.2)

        # [b, 80, 100] => [b, 64] => [b, 1]
        self.out_layer = tf.keras.layers.Dense(1)

    def call(self, inputs, training=None):
        """

        :param inputs: [b, 80]
        :param training:
        :return:
        """

        state0 = self.state0
        state1 = self.state1

        x = self.embedding(inputs)

        for word in tf.unstack(x, axis=1):
            out0, state0 = self.rnn_cell0(word, state0, training=training)
            out1, state1 = self.rnn_cell1(out0, state1, training=training)

        # [b, 64] -> [b, 1]
        x = self.out_layer(out1)

        prob = tf.sigmoid(x)

        return prob

 

获取数据

def get_data():
    # 获取数据
    (X_train, y_train), (X_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=total_words)

    # 更改句子长度
    X_train = tf.keras.preprocessing.sequence.pad_sequences(X_train, maxlen=max_review_len)
    X_test = tf.keras.preprocessing.sequence.pad_sequences(X_test, maxlen=max_review_len)

    # 调试输出
    print(X_train.shape, y_train.shape)  # (25000, 80) (25000,)
    print(X_test.shape, y_test.shape)  # (25000, 80) (25000,)

    # 分割训练集
    train_db = tf.data.Dataset.from_tensor_slices((X_train, y_train))
    train_db = train_db.shuffle(10000).batch(batch_size, drop_remainder=True)

    # 分割测试集
    test_db = tf.data.Dataset.from_tensor_slices((X_test, y_test))
    test_db = test_db.batch(batch_size, drop_remainder=True)

    return train_db, test_db

 

完整代码

import tensorflow as tf


class RNN(tf.keras.Model):

    def __init__(self, units):
        super(RNN, self).__init__()

        # 初始化 [b, 64]
        self.state0 = [tf.zeros([batch_size, units])]
        self.state1 = [tf.zeros([batch_size, units])]

        # [b, 80] => [b, 80, 100]
        self.embedding = tf.keras.layers.Embedding(total_words, embedding_len, input_length=max_review_len)

        self.rnn_cell0 = tf.keras.layers.SimpleRNNCell(units=units, dropout=0.2)
        self.rnn_cell1 = tf.keras.layers.SimpleRNNCell(units=units, dropout=0.2)

        # [b, 80, 100] => [b, 64] => [b, 1]
        self.out_layer = tf.keras.layers.Dense(1)

    def call(self, inputs, training=None):
        """

        :param inputs: [b, 80]
        :param training:
        :return:
        """

        state0 = self.state0
        state1 = self.state1

        x = self.embedding(inputs)

        for word in tf.unstack(x, axis=1):
            out0, state0 = self.rnn_cell0(word, state0, training=training)
            out1, state1 = self.rnn_cell1(out0, state1, training=training)

        # [b, 64] -> [b, 1]
        x = self.out_layer(out1)

        prob = tf.sigmoid(x)

        return prob


# 超参数
total_words = 10000  # 文字数量
max_review_len = 80  # 句子长度
embedding_len = 100  # 词维度
batch_size = 1024  # 一次训练的样本数目
learning_rate = 0.0001  # 学习率
iteration_num = 20  # 迭代次数
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)  # 优化器
loss = tf.losses.BinaryCrossentropy(from_logits=True)  # 损失
model = RNN(64)

# 调试输出summary
model.build(input_shape=[None, 64])
print(model.summary())

# 组合
model.compile(optimizer=optimizer, loss=loss, metrics=["accuracy"])


def get_data():
    # 获取数据
    (X_train, y_train), (X_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=total_words)

    # 更改句子长度
    X_train = tf.keras.preprocessing.sequence.pad_sequences(X_train, maxlen=max_review_len)
    X_test = tf.keras.preprocessing.sequence.pad_sequences(X_test, maxlen=max_review_len)

    # 调试输出
    print(X_train.shape, y_train.shape)  # (25000, 80) (25000,)
    print(X_test.shape, y_test.shape)  # (25000, 80) (25000,)

    # 分割训练集
    train_db = tf.data.Dataset.from_tensor_slices((X_train, y_train))
    train_db = train_db.shuffle(10000).batch(batch_size, drop_remainder=True)

    # 分割测试集
    test_db = tf.data.Dataset.from_tensor_slices((X_test, y_test))
    test_db = test_db.batch(batch_size, drop_remainder=True)

    return train_db, test_db


if __name__ == "__main__":
    # 获取分割的数据集
    train_db, test_db = get_data()

    # 拟合
    model.fit(train_db, epochs=iteration_num, validation_data=test_db, validation_freq=1)

输出结果:

Model: "rnn"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
embedding (Embedding) multiple 1000000
_________________________________________________________________
simple_rnn_cell (SimpleRNNCe multiple 10560
_________________________________________________________________
simple_rnn_cell_1 (SimpleRNN multiple 8256
_________________________________________________________________
dense (Dense) multiple 65
=================================================================
Total params: 1,018,881
Trainable params: 1,018,881
Non-trainable params: 0
_________________________________________________________________
None

(25000, 80) (25000,)
(25000, 80) (25000,)
Epoch 1/20
2021-07-10 17:59:45.150639: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:176] None of the MLIR Optimization Passes are enabled (registered 2)
24/24 [==============================] - 12s 294ms/step - loss: 0.7113 - accuracy: 0.5033 - val_loss: 0.6968 - val_accuracy: 0.4994
Epoch 2/20
24/24 [==============================] - 7s 292ms/step - loss: 0.6951 - accuracy: 0.5005 - val_loss: 0.6939 - val_accuracy: 0.4994
Epoch 3/20
24/24 [==============================] - 7s 297ms/step - loss: 0.6937 - accuracy: 0.5000 - val_loss: 0.6935 - val_accuracy: 0.4994
Epoch 4/20
24/24 [==============================] - 8s 316ms/step - loss: 0.6934 - accuracy: 0.5001 - val_loss: 0.6933 - val_accuracy: 0.4994
Epoch 5/20
24/24 [==============================] - 7s 301ms/step - loss: 0.6934 - accuracy: 0.4996 - val_loss: 0.6933 - val_accuracy: 0.4994
Epoch 6/20
24/24 [==============================] - 8s 334ms/step - loss: 0.6932 - accuracy: 0.5000 - val_loss: 0.6932 - val_accuracy: 0.4994
Epoch 7/20
24/24 [==============================] - 10s 398ms/step - loss: 0.6931 - accuracy: 0.5006 - val_loss: 0.6932 - val_accuracy: 0.4994
Epoch 8/20
24/24 [==============================] - 9s 382ms/step - loss: 0.6930 - accuracy: 0.5006 - val_loss: 0.6931 - val_accuracy: 0.4994
Epoch 9/20
24/24 [==============================] - 8s 322ms/step - loss: 0.6924 - accuracy: 0.4995 - val_loss: 0.6913 - val_accuracy: 0.5240
Epoch 10/20
24/24 [==============================] - 8s 321ms/step - loss: 0.6812 - accuracy: 0.5501 - val_loss: 0.6655 - val_accuracy: 0.5767
Epoch 11/20
24/24 [==============================] - 8s 318ms/step - loss: 0.6381 - accuracy: 0.6896 - val_loss: 0.6235 - val_accuracy: 0.7399
Epoch 12/20
24/24 [==============================] - 8s 323ms/step - loss: 0.6088 - accuracy: 0.7655 - val_loss: 0.6110 - val_accuracy: 0.7533
Epoch 13/20
24/24 [==============================] - 8s 321ms/step - loss: 0.5949 - accuracy: 0.7956 - val_loss: 0.6111 - val_accuracy: 0.7878
Epoch 14/20
24/24 [==============================] - 8s 324ms/step - loss: 0.5859 - accuracy: 0.8142 - val_loss: 0.5993 - val_accuracy: 0.7904
Epoch 15/20
24/24 [==============================] - 8s 330ms/step - loss: 0.5791 - accuracy: 0.8318 - val_loss: 0.5961 - val_accuracy: 0.7907
Epoch 16/20
24/24 [==============================] - 8s 340ms/step - loss: 0.5739 - accuracy: 0.8421 - val_loss: 0.5942 - val_accuracy: 0.7961
Epoch 17/20
24/24 [==============================] - 9s 378ms/step - loss: 0.5701 - accuracy: 0.8497 - val_loss: 0.5933 - val_accuracy: 0.8014
Epoch 18/20
24/24 [==============================] - 9s 361ms/step - loss: 0.5665 - accuracy: 0.8589 - val_loss: 0.5958 - val_accuracy: 0.8082
Epoch 19/20
24/24 [==============================] - 8s 353ms/step - loss: 0.5630 - accuracy: 0.8681 - val_loss: 0.5931 - val_accuracy: 0.7966
Epoch 20/20
24/24 [==============================] - 8s 314ms/step - loss: 0.5614 - accuracy: 0.8702 - val_loss: 0.5925 - val_accuracy: 0.7959

Process finished with exit code 0

到此这篇关于手把手教你使用TensorFlow2实现RNN的文章就介绍到这了,更多相关TensorFlow2实现RNN内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python中subprocess模块用法实例详解
May 20 Python
Django中对通过测试的用户进行限制访问的方法
Jul 23 Python
Python中Django框架利用url来控制登录的方法
Jul 25 Python
PyChar学习教程之自定义文件与代码模板详解
Jul 17 Python
python实现决策树
Dec 21 Python
详解tensorflow实现迁移学习实例
Feb 10 Python
pandas中去除指定字符的实例
May 18 Python
解决Ubuntu pip 安装 mysql-python包出错的问题
Jun 11 Python
用python实现k近邻算法的示例代码
Sep 06 Python
python flask安装和命令详解
Apr 02 Python
Python hashlib模块加密过程解析
Nov 05 Python
python ctypes库2_指定参数类型和返回类型详解
Nov 19 Python
一篇文章弄懂Python关键字、标识符和变量
python开发飞机大战游戏
详解Python中下划线的5种含义
Python操作CSV格式文件的方法大全
openstack中的rpc远程调用的方法
Python实现查询剪贴板自动匹配信息的思路详解
如何利用Python实现一个论文降重工具
You might like
下载文件的点击数回填
2006/10/09 PHP
ThinkPHP与PHPExcel冲突解决方法
2011/08/08 PHP
PHP图片等比缩放类SimpleImage使用方法和使用实例分享
2014/04/10 PHP
php根据年月获取当月天数及日期数组的方法
2016/11/30 PHP
详解PHP序列化和反序列化原理
2018/01/15 PHP
Ubuntu上安装yaf扩展的方法
2018/01/29 PHP
javascript脚本编程解决考试分数统计问题
2008/10/18 Javascript
js通过八个点 拖动改变div大小的实现方法
2014/03/05 Javascript
jquery设置按钮停顿3秒不可用
2014/03/07 Javascript
Javascript字符串拼接小技巧(推荐)
2016/06/02 Javascript
JS移动端/H5同时选择多张图片上传并使用canvas压缩图片
2017/06/20 Javascript
用Node编写RESTful API接口的示例代码
2018/07/04 Javascript
Vue使用Clipboard.JS在h5页面中复制内容实例详解
2019/09/03 Javascript
vue与django集成打包的实现方法
2019/11/11 Javascript
js 将多个对象合并成一个对象 assign方法的实现
2020/09/24 Javascript
用vue设计一个日历表
2020/12/03 Vue.js
VUE实现吸底按钮
2021/03/04 Vue.js
在Django同1个页面中的多表单处理详解
2017/01/25 Python
Django Admin 实现外键过滤的方法
2017/09/29 Python
pandas中Timestamp类用法详解
2017/12/11 Python
Python中实现switch功能实例解析
2018/01/11 Python
Tensorflow中tf.ConfigProto()的用法详解
2020/02/06 Python
OpenCV中VideoCapture类的使用详解
2020/02/14 Python
PyCharm+Miniconda3安装配置教程详解
2021/02/16 Python
HTML5 Canvas渐进填充与透明实现图像的Mask效果
2013/07/11 HTML / CSS
HTML5声音录制/播放功能的实现代码
2018/05/03 HTML / CSS
美国名牌手表折扣网站:Jomashop
2020/05/22 全球购物
城市轨道交通工程职业规划书范文
2014/01/18 职场文书
人力资源作业细则
2014/03/03 职场文书
2014年内勤工作总结
2014/11/24 职场文书
药店收银员岗位职责
2015/04/07 职场文书
投资申请报告
2015/05/19 职场文书
音乐之声观后感
2015/06/04 职场文书
爱国之歌(8首)
2019/09/29 职场文书
创业计划书之酒厂
2019/10/14 职场文书
table不让td文字溢出操作方法
2022/12/24 HTML / CSS