手把手教你使用TensorFlow2实现RNN


Posted in Python onJuly 15, 2021
目录
  • 概述
  • 权重共享
  • 计算过程:
  • 案例
    • 数据集
    • RNN 层
    • 获取数据
  • 完整代码

 

概述

RNN (Recurrent Netural Network) 是用于处理序列数据的神经网络. 所谓序列数据, 即前面的输入和后面的输入有一定的联系.

手把手教你使用TensorFlow2实现RNN

 

权重共享

传统神经网络:

手把手教你使用TensorFlow2实现RNN

RNN:

手把手教你使用TensorFlow2实现RNN

RNN 的权重共享和 CNN 的权重共享类似, 不同时刻共享一个权重, 大大减少了参数数量.

 

计算过程:

手把手教你使用TensorFlow2实现RNN

计算状态 (State)

手把手教你使用TensorFlow2实现RNN

计算输出:

手把手教你使用TensorFlow2实现RNN

 

案例

 

数据集

IBIM 数据集包含了来自互联网的 50000 条关于电影的评论, 分为正面评价和负面评价.

 

RNN 层

class RNN(tf.keras.Model):

    def __init__(self, units):
        super(RNN, self).__init__()

        # 初始化 [b, 64] (b 表示 batch_size)
        self.state0 = [tf.zeros([batch_size, units])]
        self.state1 = [tf.zeros([batch_size, units])]

        # [b, 80] => [b, 80, 100]
        self.embedding = tf.keras.layers.Embedding(total_words, embedding_len, input_length=max_review_len)

        self.rnn_cell0 = tf.keras.layers.SimpleRNNCell(units=units, dropout=0.2)
        self.rnn_cell1 = tf.keras.layers.SimpleRNNCell(units=units, dropout=0.2)

        # [b, 80, 100] => [b, 64] => [b, 1]
        self.out_layer = tf.keras.layers.Dense(1)

    def call(self, inputs, training=None):
        """

        :param inputs: [b, 80]
        :param training:
        :return:
        """

        state0 = self.state0
        state1 = self.state1

        x = self.embedding(inputs)

        for word in tf.unstack(x, axis=1):
            out0, state0 = self.rnn_cell0(word, state0, training=training)
            out1, state1 = self.rnn_cell1(out0, state1, training=training)

        # [b, 64] -> [b, 1]
        x = self.out_layer(out1)

        prob = tf.sigmoid(x)

        return prob

 

获取数据

def get_data():
    # 获取数据
    (X_train, y_train), (X_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=total_words)

    # 更改句子长度
    X_train = tf.keras.preprocessing.sequence.pad_sequences(X_train, maxlen=max_review_len)
    X_test = tf.keras.preprocessing.sequence.pad_sequences(X_test, maxlen=max_review_len)

    # 调试输出
    print(X_train.shape, y_train.shape)  # (25000, 80) (25000,)
    print(X_test.shape, y_test.shape)  # (25000, 80) (25000,)

    # 分割训练集
    train_db = tf.data.Dataset.from_tensor_slices((X_train, y_train))
    train_db = train_db.shuffle(10000).batch(batch_size, drop_remainder=True)

    # 分割测试集
    test_db = tf.data.Dataset.from_tensor_slices((X_test, y_test))
    test_db = test_db.batch(batch_size, drop_remainder=True)

    return train_db, test_db

 

完整代码

import tensorflow as tf


class RNN(tf.keras.Model):

    def __init__(self, units):
        super(RNN, self).__init__()

        # 初始化 [b, 64]
        self.state0 = [tf.zeros([batch_size, units])]
        self.state1 = [tf.zeros([batch_size, units])]

        # [b, 80] => [b, 80, 100]
        self.embedding = tf.keras.layers.Embedding(total_words, embedding_len, input_length=max_review_len)

        self.rnn_cell0 = tf.keras.layers.SimpleRNNCell(units=units, dropout=0.2)
        self.rnn_cell1 = tf.keras.layers.SimpleRNNCell(units=units, dropout=0.2)

        # [b, 80, 100] => [b, 64] => [b, 1]
        self.out_layer = tf.keras.layers.Dense(1)

    def call(self, inputs, training=None):
        """

        :param inputs: [b, 80]
        :param training:
        :return:
        """

        state0 = self.state0
        state1 = self.state1

        x = self.embedding(inputs)

        for word in tf.unstack(x, axis=1):
            out0, state0 = self.rnn_cell0(word, state0, training=training)
            out1, state1 = self.rnn_cell1(out0, state1, training=training)

        # [b, 64] -> [b, 1]
        x = self.out_layer(out1)

        prob = tf.sigmoid(x)

        return prob


# 超参数
total_words = 10000  # 文字数量
max_review_len = 80  # 句子长度
embedding_len = 100  # 词维度
batch_size = 1024  # 一次训练的样本数目
learning_rate = 0.0001  # 学习率
iteration_num = 20  # 迭代次数
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)  # 优化器
loss = tf.losses.BinaryCrossentropy(from_logits=True)  # 损失
model = RNN(64)

# 调试输出summary
model.build(input_shape=[None, 64])
print(model.summary())

# 组合
model.compile(optimizer=optimizer, loss=loss, metrics=["accuracy"])


def get_data():
    # 获取数据
    (X_train, y_train), (X_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=total_words)

    # 更改句子长度
    X_train = tf.keras.preprocessing.sequence.pad_sequences(X_train, maxlen=max_review_len)
    X_test = tf.keras.preprocessing.sequence.pad_sequences(X_test, maxlen=max_review_len)

    # 调试输出
    print(X_train.shape, y_train.shape)  # (25000, 80) (25000,)
    print(X_test.shape, y_test.shape)  # (25000, 80) (25000,)

    # 分割训练集
    train_db = tf.data.Dataset.from_tensor_slices((X_train, y_train))
    train_db = train_db.shuffle(10000).batch(batch_size, drop_remainder=True)

    # 分割测试集
    test_db = tf.data.Dataset.from_tensor_slices((X_test, y_test))
    test_db = test_db.batch(batch_size, drop_remainder=True)

    return train_db, test_db


if __name__ == "__main__":
    # 获取分割的数据集
    train_db, test_db = get_data()

    # 拟合
    model.fit(train_db, epochs=iteration_num, validation_data=test_db, validation_freq=1)

输出结果:

Model: "rnn"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
embedding (Embedding) multiple 1000000
_________________________________________________________________
simple_rnn_cell (SimpleRNNCe multiple 10560
_________________________________________________________________
simple_rnn_cell_1 (SimpleRNN multiple 8256
_________________________________________________________________
dense (Dense) multiple 65
=================================================================
Total params: 1,018,881
Trainable params: 1,018,881
Non-trainable params: 0
_________________________________________________________________
None

(25000, 80) (25000,)
(25000, 80) (25000,)
Epoch 1/20
2021-07-10 17:59:45.150639: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:176] None of the MLIR Optimization Passes are enabled (registered 2)
24/24 [==============================] - 12s 294ms/step - loss: 0.7113 - accuracy: 0.5033 - val_loss: 0.6968 - val_accuracy: 0.4994
Epoch 2/20
24/24 [==============================] - 7s 292ms/step - loss: 0.6951 - accuracy: 0.5005 - val_loss: 0.6939 - val_accuracy: 0.4994
Epoch 3/20
24/24 [==============================] - 7s 297ms/step - loss: 0.6937 - accuracy: 0.5000 - val_loss: 0.6935 - val_accuracy: 0.4994
Epoch 4/20
24/24 [==============================] - 8s 316ms/step - loss: 0.6934 - accuracy: 0.5001 - val_loss: 0.6933 - val_accuracy: 0.4994
Epoch 5/20
24/24 [==============================] - 7s 301ms/step - loss: 0.6934 - accuracy: 0.4996 - val_loss: 0.6933 - val_accuracy: 0.4994
Epoch 6/20
24/24 [==============================] - 8s 334ms/step - loss: 0.6932 - accuracy: 0.5000 - val_loss: 0.6932 - val_accuracy: 0.4994
Epoch 7/20
24/24 [==============================] - 10s 398ms/step - loss: 0.6931 - accuracy: 0.5006 - val_loss: 0.6932 - val_accuracy: 0.4994
Epoch 8/20
24/24 [==============================] - 9s 382ms/step - loss: 0.6930 - accuracy: 0.5006 - val_loss: 0.6931 - val_accuracy: 0.4994
Epoch 9/20
24/24 [==============================] - 8s 322ms/step - loss: 0.6924 - accuracy: 0.4995 - val_loss: 0.6913 - val_accuracy: 0.5240
Epoch 10/20
24/24 [==============================] - 8s 321ms/step - loss: 0.6812 - accuracy: 0.5501 - val_loss: 0.6655 - val_accuracy: 0.5767
Epoch 11/20
24/24 [==============================] - 8s 318ms/step - loss: 0.6381 - accuracy: 0.6896 - val_loss: 0.6235 - val_accuracy: 0.7399
Epoch 12/20
24/24 [==============================] - 8s 323ms/step - loss: 0.6088 - accuracy: 0.7655 - val_loss: 0.6110 - val_accuracy: 0.7533
Epoch 13/20
24/24 [==============================] - 8s 321ms/step - loss: 0.5949 - accuracy: 0.7956 - val_loss: 0.6111 - val_accuracy: 0.7878
Epoch 14/20
24/24 [==============================] - 8s 324ms/step - loss: 0.5859 - accuracy: 0.8142 - val_loss: 0.5993 - val_accuracy: 0.7904
Epoch 15/20
24/24 [==============================] - 8s 330ms/step - loss: 0.5791 - accuracy: 0.8318 - val_loss: 0.5961 - val_accuracy: 0.7907
Epoch 16/20
24/24 [==============================] - 8s 340ms/step - loss: 0.5739 - accuracy: 0.8421 - val_loss: 0.5942 - val_accuracy: 0.7961
Epoch 17/20
24/24 [==============================] - 9s 378ms/step - loss: 0.5701 - accuracy: 0.8497 - val_loss: 0.5933 - val_accuracy: 0.8014
Epoch 18/20
24/24 [==============================] - 9s 361ms/step - loss: 0.5665 - accuracy: 0.8589 - val_loss: 0.5958 - val_accuracy: 0.8082
Epoch 19/20
24/24 [==============================] - 8s 353ms/step - loss: 0.5630 - accuracy: 0.8681 - val_loss: 0.5931 - val_accuracy: 0.7966
Epoch 20/20
24/24 [==============================] - 8s 314ms/step - loss: 0.5614 - accuracy: 0.8702 - val_loss: 0.5925 - val_accuracy: 0.7959

Process finished with exit code 0

到此这篇关于手把手教你使用TensorFlow2实现RNN的文章就介绍到这了,更多相关TensorFlow2实现RNN内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python实现将内容分行输出
Nov 05 Python
对Python的Django框架中的项目进行单元测试的方法
Apr 11 Python
Python开发中爬虫使用代理proxy抓取网页的方法示例
Sep 26 Python
浅谈pandas中DataFrame关于显示值省略的解决方法
Apr 08 Python
python使用xlrd和xlwt读写Excel文件的实例代码
Sep 05 Python
python中使用zip函数出现错误的原因
Sep 28 Python
在Python中增加和插入元素的示例
Nov 01 Python
python的pytest框架之命令行参数详解(上)
Jun 27 Python
用openCV和Python 实现图片对比,并标识出不同点的方式
Dec 19 Python
python字符串,元组,列表,字典互转代码实例详解
Feb 14 Python
Django使用list对单个或者多个字段求values值实例
Mar 31 Python
Python轻量级web框架bottle使用方法解析
Jun 13 Python
一篇文章弄懂Python关键字、标识符和变量
python开发飞机大战游戏
详解Python中下划线的5种含义
Python操作CSV格式文件的方法大全
openstack中的rpc远程调用的方法
Python实现查询剪贴板自动匹配信息的思路详解
如何利用Python实现一个论文降重工具
You might like
php中count获取多维数组长度的方法
2014/11/03 PHP
php使用ZipArchive提示Fatal error: Class ZipArchive not found in的解决方法
2014/11/04 PHP
php几个预定义变量$_SERVER用法小结
2014/11/07 PHP
Windows下Apache + PHP SESSION丢失的解决过程全纪录
2015/04/07 PHP
PHP中substr_count()函数获取子字符串出现次数的方法
2016/01/07 PHP
PHP调用全国天气预报数据接口查询天气示例
2019/02/20 PHP
thinkphp框架实现路由重定义简化url访问地址的方法分析
2020/04/04 PHP
Ruffy javascript 学习笔记
2009/11/30 Javascript
深入理解JavaScript系列(7) S.O.L.I.D五大原则之开闭原则OCP
2012/01/15 Javascript
七个很有意思的PHP函数
2014/05/12 Javascript
js实现带有介绍的Select列表菜单实例
2015/08/18 Javascript
jquery可定制的在线UEditor编辑器
2015/11/17 Javascript
一个字符串中出现次数最多的字符 统计这个次数【实现代码】
2016/04/29 Javascript
Javascript oop设计模式 面向对象编程简单实例介绍
2016/12/13 Javascript
node简单实现一个更改头像功能的示例
2017/12/29 Javascript
jQuery实现的简单图片轮播效果完整示例
2018/02/08 jQuery
这15个Vue指令,让你的项目开发爽到爆
2019/10/11 Javascript
[01:32]DOTA2 2015国际邀请赛中国区预选赛第四日战报
2015/05/29 DOTA
Python批量重命名同一文件夹下文件的方法
2015/05/25 Python
Python操作MySQL模拟银行转账
2018/03/12 Python
Numpy掩码式数组详解
2018/04/17 Python
Python中遍历列表的方法总结
2019/06/27 Python
python plotly画柱状图代码实例
2019/12/13 Python
pytorch实现seq2seq时对loss进行mask的方式
2020/02/18 Python
Python Tricks 使用 pywinrm 远程控制 Windows 主机的方法
2020/07/21 Python
Python3中对json格式数据的分析处理
2021/01/28 Python
在浏览器端如何得到服务器端响应的XML数据
2012/11/24 面试题
印刷工程专业应届生求职信
2013/09/29 职场文书
给民警的表扬信
2014/01/08 职场文书
单位实习证明怎么写
2014/01/17 职场文书
优秀食品类广告词
2014/03/19 职场文书
小学生一分钟演讲稿
2014/08/26 职场文书
今日说法观后感
2015/06/08 职场文书
公司职员入党自传书
2015/06/26 职场文书
2016干部作风整顿心得体会
2016/01/22 职场文书
canvas 中如何实现物体的框选
2022/08/05 Javascript