TensorFlow2.0教程-使用RNN实现文本分类

原文地址:https://blog.csdn.net/qq_31456593/article/details/89923645

Tensorflow 2.0 教程持续更新 :https://blog.csdn.net/qq_31456593/article/details/88606284

本教程主要由tensorflow2.0官方教程的个人学习复现笔记整理而来,并借鉴了一些keras构造神经网络的方法,中文讲解,方便喜欢阅读中文教程的朋友,tensorflow官方教程:https://www.tensorflow.org

完整tensorflow2.0教程代码请看https://github.com/czy36mengfei/tensorflow2_tutorials_chinese (欢迎star)

Tensorflow2.0部分教程内容:

TensorFlow 2.0 教程- Keras 快速入门
TensorFlow 2.0 教程-keras 函数api
TensorFlow 2.0 教程-使用keras训练模型
TensorFlow 2.0 教程-用keras构建自己的网络层
TensorFlow 2.0 教程-keras模型保存和序列化

1.使用tensorflow_datasets 构造输入数据

!pip install -q tensorflow_datasets
[31mspacy 2.0.18 has requirement numpy>=1.15.0, but you'll have numpy 1.14.3 which is incompatible.[0m
[31mplotnine 0.5.1 has requirement matplotlib>=3.0.0, but you'll have matplotlib 2.2.2 which is incompatible.[0m
[31mplotnine 0.5.1 has requirement pandas>=0.23.4, but you'll have pandas 0.23.0 which is incompatible.[0m
[31mneo4j-driver 1.6.2 has requirement neotime==1.0.0, but you'll have neotime 1.7.2 which is incompatible.[0m
[31mmizani 0.5.3 has requirement pandas>=0.23.4, but you'll have pandas 0.23.0 which is incompatible.[0m
[31mfastai 0.7.0 has requirement torch<0.4, but you'll have torch 0.4.1 which is incompatible.[0m
[33mYou are using pip version 10.0.1, however version 19.1 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.[0m
import tensorflow_datasets as tfds
dataset, info = tfds.load('imdb_reviews/subwords8k', with_info=True,as_supervised=True)

获取训练集、测试集

train_dataset, test_dataset = dataset['train'], dataset['test']

获取tokenizer对象,用进行字符处理级id转换(这里先转换成subword,再转换为id)等操作

tokenizer = info.features['text'].encoder
print('vocabulary size: ', tokenizer.vocab_size)
vocabulary size:  8185

token对象测试

sample_string = 'Hello word , Tensorflow'
tokenized_string = tokenizer.encode(sample_string)
print('tokened id: ', tokenized_string)# 解码会原字符串
src_string = tokenizer.decode(tokenized_string)
print('original string: ', src_string)
tokened id:  [4025, 222, 2621, 1199, 6307, 2327, 2934]
original string:  Hello word , Tensorflow

解出每个subword

for t in tokenized_string:print(str(t)+'->['+tokenizer.decode([t])+ ']')
4025->[Hell]
222->[o ]
2621->[word]
1199->[ , ]
6307->[Ten]
2327->[sor]
2934->[flow]

构建批次训练集

BUFFER_SIZE=10000
BATCH_SIZE = 64train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.padded_batch(BATCH_SIZE, train_dataset.output_shapes)test_dataset = test_dataset.padded_batch(BATCH_SIZE, test_dataset.output_shapes)

模型构建

下面,因为此处的句子是变长的,所以只能使用序列模型,而不能使用keras的函数api

# def get_model():
#     inputs = tf.keras.Input((1240,))
#     emb = tf.keras.layers.Embedding(tokenizer.vocab_size, 64)(inputs)
#     h1 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64))(emb)
#     h1 = tf.keras.layers.Dense(64, activation='relu')(h1)
#     outputs = tf.keras.layers.Dense(1, activation='sigmoid')(h1)
#     model = tf.keras.Model(inputs, outputs)
#     return modeldef get_model():model = tf.keras.Sequential([tf.keras.layers.Embedding(tokenizer.vocab_size, 64),tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64)),tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(1, activation='sigmoid')])return model
model = get_model()
model.compile(loss='binary_crossentropy',optimizer='adam',metrics=['accuracy'])

模型训练

history = model.fit(train_dataset, epochs=10,validation_data=test_dataset)
Epoch 1/10
391/391 [==============================] - 827s 2s/step - loss: 0.5606 - accuracy: 0.7068 - val_loss: 0.0000e+00 - val_accuracy:
....
Epoch 10/10
391/391 [==============================] - 791s 2s/step - loss: 0.1333 - accuracy: 0.9548 - val_loss: 0.6117 - val_accuracy: 0.8199
# 查看训练过程
import matplotlib.pyplot as pltdef plot_graphs(history, string):plt.plot(history.history[string])plt.plot(history.history['val_'+string])plt.xlabel('epochs')plt.ylabel(string)plt.legend([string, 'val_'+string])plt.show()plot_graphs(history, 'accuracy')

plot_graphs(history, 'loss')

测试

test_loss, test_acc = model.evaluate(test_dataset)
print('test loss: ', test_loss)
print('test acc: ', test_acc)
    391/Unknown - 68s 174ms/step - loss: 0.6117 - accuracy: 0.8199test loss:  0.6117385012262008
test acc:  0.81988

上述模型不会mask掉序列的padding,所以如果在有padding的寻列上训练,测试没有padding的序列时可能有所偏差。

def pad_to_size(vec, size):zeros = [0] * (size-len(vec))vec.extend(zeros)return vecdef sample_predict(sentence, pad=False):tokened_sent = tokenizer.encode(sentence)if pad:tokened_sent = pad_to_size(tokened_sent, 64)pred = model.predict(tf.expand_dims(tokened_sent, 0))return pred
# 没有padding的情况
sample_pred_text = ('The movie was cool. The animation and the graphics ''were out of this world. I would recommend this movie.')
predictions = sample_predict(sample_pred_text, pad=False)
print(predictions)
[[0.2938048]]
# 有paddin的情况
sample_pred_text = ('The movie was cool. The animation and the graphics ''were out of this world. I would recommend this movie.')
predictions = sample_predict(sample_pred_text, pad=True)
print (predictions)
[[0.42541984]]

堆叠更多的lstm层

from tensorflow.keras import layers
model = keras.Sequential(
[layers.Embedding(tokenizer.vocab_size, 64),layers.Bidirectional(layers.LSTM(64, return_sequences=True)),layers.Bidirectional(layers.LSTM(32)),layers.Dense(64, activation='relu'),layers.Dense(1, activation='sigmoid')
])
model.compile(loss=tf.keras.losses.binary_crossentropy,optimizer=tf.keras.optimizers.Adam(),metrics=['accuracy'])
history=model.fit(train_dataset, epochs=6, validation_data=test_dataset)
Epoch 1/6
391/391 [==============================] - 1646s 4s/step - loss: 0.5270 - accuracy: 0.7414 - val_loss: 0.0000e+00 - val_accuracy: ....
Epoch 6/6
391/391 [==============================] - 1622s 4s/step - loss: 0.1619 - accuracy: 0.9430 - val_loss: 0.5484 - val_accuracy: 0.7808
plot_graphs(history, 'accuracy')

plot_graphs(history, 'loss')

res = model.evaluate(test_dataset)
print(res)
    391/Unknown - 125s 320ms/step - loss: 0.5484 - accuracy: 0.7808[0.5484032468570162, 0.78084]

TensorFlow2.0教程-使用RNN实现文本分类相关推荐

  1. TensorFlow2.0教程-使用keras训练模型

    TensorFlow2.0教程-使用keras训练模型 Tensorflow 2.0 教程持续更新: https://blog.csdn.net/qq_31456593/article/details ...

  2. tensorflow2.0教程- Keras 快速入门

    tensorflow2.0教程-tensorflow.keras 快速入门 Tensorflow 2.0 教程持续更新: https://blog.csdn.net/qq_31456593/artic ...

  3. TensorFlow2.0教程-keras 函数api

    TensorFlow2.0教程-keras 函数api Tensorflow 2.0 教程持续更新: https://blog.csdn.net/qq_31456593/article/details ...

  4. TensorFlow2.0 教程-图像分类

    TensorFlow2.0 教程-图像分类 Tensorflow 2.0 教程持续更新: https://blog.csdn.net/qq_31456593/article/details/88606 ...

  5. 【论文复现】使用RNN进行文本分类

    写在前面 这是文本分类任务的第二个系列----基于RNN的文本分类实现(Text RNN) 复现的论文是2016年复旦大学IJCAI 上的发表的关于循环神经网络在多任务文本分类上的应用:Recurre ...

  6. tensorflow2.0 基于LSTM模型的文本生成

    春水碧于天,画船听雨眠 基于LSTM模型的唐诗文本生成 实验基本要求 实验背景 实验数据下载 LSTM模型分析 实验过程 文本预处理 编解码模型 LSTM模型设置 实验代码 实验结果 总结 致谢 实验 ...

  7. 【深度学习】利用tensorflow2.0卷积神经网络进行卫星图片分类实例操作详解

    本文的应用场景是对于卫星图片数据的分类,图片总共1400张,分为airplane和lake两类,也就是一个二分类的问题,所有的图片已经分别放置在2_class文件夹下的两个子文件夹中.下面将从这个实例 ...

  8. [Python人工智能] 二十.基于Keras+RNN的文本分类vs基于传统机器学习的文本分类

    从本专栏开始,作者正式研究Python深度学习.神经网络及人工智能相关知识.前一篇文章分享了循环神经网络RNN的原理知识,并采用Keras实现手写数字识别的RNN分类案例及可视化呈现.这篇文章作者将带 ...

  9. 神经网络与深度学习理论,tensorflow2.0教程,cnn

    *免责声明: 1\此方法仅提供参考 2\搬了其他博主的操作方法,以贴上路径. 3* 场景一:神经网络与深度学习理论 场景二:tensorflow的安装 场景三:numpy包介绍 场景四:机器学习基础 ...

最新文章

  1. Doxygen生成代码关系调用图
  2. UA MATH524 复变函数13 补充:留数计算的例题
  3. .NET、C#和ASP.NET,ASP.NET MVC 四者之间的区别
  4. codeforces B. Fox and Cross 解题报告
  5. java中implements是什么意思_java中extends与implements区别
  6. mongovue mysql_MongoDB 客户端 MongoVue
  7. 实训三:交换机恢复出厂设置以及基本配置
  8. adb shell screencap/screenrecord(三级命令)
  9. edem颗粒替换_Altair EDEM Professional 2020.2安装教程(附替换补丁)
  10. 打印机主流的指令类型(ESC命令集+CPCL命令集+TSPL命令集)...
  11. Linux下把ncsi设置成OCP模式,一种支持NCSI信号管理功能自动切换的电路及服务器的制作方法...
  12. taocat服务器的作用,随笔2_tww
  13. Tomasulo算法与记分牌算法的区别
  14. 小程序或者公众号授权给第三方平台流程
  15. Python合并有相同列的两个表格
  16. 城市选择插件 V-Distpicker 组件详解以及全套用法
  17. java趣味编程心形_求源代码!(迪卡尔心形图案)
  18. 杂言乱谈,以后的日志很辉煌?
  19. 【146期】面试官问:说一说 RabbitMQ 的几种工作模式和优化建议?
  20. 树的结点数+蒲丰投针概率

热门文章

  1. 佛系计算机二级第二弹
  2. 淘宝开放平台淘宝店铺OAuth2.0订单商品接口接入解决方案
  3. 个人申办在职人才引进
  4. 关于matlab如何使用符号变量建立符号函数表达式
  5. cookie的组成与作用
  6. wps怎么做时间线_ 在家办公总是做不好时间管理怎么办?学会加减乘除轻松搞定...
  7. word插入隐形表格,轻松实现公式居中、公式编号右对齐
  8. python文件与数据格式化
  9. swing- 使用颜色画笔装饰你的容器背景
  10. VLAN划分及配置注意事项