tensorflow2.0处理文本数据

知识点:
1.使用tf.data处理文本数据
2.使用模型子类构建模型
3.单词编码
4.自定义迭代
5.自定义评估函数

# !/usr/bin/python
# -*- coding: utf-8 -*-
# @Time    : 2021/11/3 9:55
# @Author  : 郑浩鑫
# @Email   : [email protected]
# @File    : class2.py
# @Software: PyCharm
'''
文本数据
'''
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import tensorflow as tf
from tensorflow.keras import models,layers,preprocessing,optimizers,losses,metrics
from tensorflow.keras.layers.experimental.preprocessing import TextVectorization
import re,string
# Locations of the IMDB train/test files (absolute, machine-specific paths).
# Each line is parsed as tab-separated fields by split_line.
train_data_path = "G:/resourceCode_github/eat_tensorflow2_in_30_days-master/data/imdb/train.csv"
test_data_path =  "G:/resourceCode_github/eat_tensorflow2_in_30_days-master/data/imdb/test.csv"
MAX_WORDS = 10000  # Only the 10000 most frequent words are considered.
MAX_LEN = 200  # Each sample is kept at a length of 200 words.
BATCH_SIZE = 20  # Batch size used by the train and test dataset pipelines.
# Build the input pipeline.
def split_line(line):
    """Parse one tab-separated line into a ``(text, label)`` pair.

    The first field is converted to a number and cast to ``tf.int32``; the
    second field is kept as the raw text. Both results get a new leading
    axis via ``tf.expand_dims(..., axis=0)``.

    Args:
        line: A string tensor holding one line of the data file.

    Returns:
        A tuple ``(text, label)``.
    """
    fields = tf.strings.split(line, "\t")
    label_number = tf.strings.to_number(fields[0])
    label = tf.expand_dims(tf.cast(label_number, tf.int32), axis=0)
    text = tf.expand_dims(fields[1], axis=0)
    return text, label


# Training pipeline: parse lines in parallel, shuffle, batch, then prefetch.
ds_train_raw = (
    tf.data.TextLineDataset(filenames=[train_data_path])
    .map(split_line, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .shuffle(buffer_size=1000)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.experimental.AUTOTUNE)
)
# Test pipeline: lines are parsed with split_line and batched; unlike the
# training pipeline there is no shuffling step.
ds_test_raw = (
    tf.data.TextLineDataset(filenames=[test_data_path])
    .map(split_line, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.experimental.AUTOTUNE)
)
# Build the vocabulary.
def clean_text(text):
    """Standardize raw text before it is split into tokens.

    Applies three steps in order: lower-casing, replacing the literal
    ``<br />`` marker with a single space, and deleting every character
    listed in ``string.punctuation``.

    Args:
        text: A string tensor.

    Returns:
        The cleaned string tensor.
    """
    # Character class matching any single ASCII punctuation character.
    punctuation_pattern = '[%s]' % re.escape(string.punctuation)

    text = tf.strings.lower(text)
    text = tf.strings.regex_replace(text, '<br />', ' ')
    return tf.strings.regex_replace(text, punctuation_pattern, '')
# Text-to-integer-sequence layer. Each argument sits on its own line so the
# inline comment can no longer swallow the arguments that follow it.
vectorize_layer = TextVectorization(
    standardize=clean_text,
    split='whitespace',
    max_tokens=MAX_WORDS - 1,  # One slot is reserved for the placeholder token.
    output_mode='int',
    output_sequence_length=MAX_LEN,
)
# Fit the vocabulary on the training texts only (labels are dropped), then
# show the first 100 vocabulary entries.
ds_text = ds_train_raw.map(lambda text, label: text)
vectorize_layer.adapt(ds_text)
print(vectorize_layer.get_vocabulary()[:100])


# Word encoding.
def _encode_example(text, label):
    """Replace the raw text of an example with its integer token sequence."""
    return vectorize_layer(text), label


ds_train = ds_train_raw.map(_encode_example).prefetch(tf.data.experimental.AUTOTUNE)
ds_test = ds_test_raw.map(_encode_example).prefetch(tf.data.experimental.AUTOTUNE)

# Demonstrates a custom (subclassed) model; in practice the Sequential or
# functional API should be preferred.
tf.keras.backend.clear_session()
class CnnModel(models.Model):
    """Small 1-D convolutional classifier over integer token sequences.

    Layers, in the order they are applied: embedding (7-dimensional vectors),
    Conv1D (16 filters, kernel size 5, ReLU), max pooling, Conv1D
    (128 filters, kernel size 2, ReLU), max pooling, flatten, and a single
    sigmoid unit. The same ``MaxPool1D`` instance is used for both pooling
    steps.
    """

    def __init__(self):
        super().__init__()

    def build(self, input_shape):
        """Create the layers, then delegate to the base-class ``build``."""
        self.embedding = layers.Embedding(MAX_WORDS, 7, input_length=MAX_LEN)
        self.conv_1 = layers.Conv1D(
            16, kernel_size=5, name="conv_1", activation="relu"
        )
        self.pool = layers.MaxPool1D()
        self.conv_2 = layers.Conv1D(
            128, kernel_size=2, name="conv_2", activation="relu"
        )
        self.flatten = layers.Flatten()
        self.dense = layers.Dense(1, activation="sigmoid")
        super().build(input_shape)

    def call(self, x):
        """Run ``x`` through every layer in sequence and return the output."""
        pipeline = (
            self.embedding,
            self.conv_1,
            self.pool,
            self.conv_2,
            self.pool,
            self.flatten,
            self.dense,
        )
        for layer in pipeline:
            x = layer(x)
        return x
# Instantiate the model and build it for batches of MAX_LEN-long sequences.
model = CnnModel()
model.build(input_shape=(None, MAX_LEN))
# summary() prints the table itself and returns None, so it must not be
# wrapped in print() (that would emit an extra "None" line).
model.summary()

# Print a timestamped separator line.
@tf.function
def printbar():
    """Print a bar of ``=`` characters followed by the time as HH:MM:SS.

    The time of day is derived from ``tf.timestamp()`` modulo one day.
    """
    seconds_today = tf.timestamp() % (24 * 60 * 60)

    # NOTE(review): the +8 offset presumably converts UTC to UTC+8 — confirm.
    hour = tf.cast(seconds_today // 3600 + 8, tf.int32) % tf.constant(24)
    minute = tf.cast((seconds_today % 3600) // 60, tf.int32)
    second = tf.cast(tf.floor(seconds_today % 60), tf.int32)

    def two_digits(value):
        """Format ``value`` as text, left-padding single digits with ``0``."""
        text = tf.strings.format("{}", value)
        if tf.strings.length(text) == 1:
            padded = tf.strings.format("0{}", value)
        else:
            padded = text
        return padded

    parts = [two_digits(hour), two_digits(minute), two_digits(second)]
    timestring = tf.strings.join(parts, separator=":")
    tf.print("==========" * 8 + timestring)


# Optimizer used by the custom training step.
optimizer = optimizers.Nadam()
# Binary cross-entropy loss used by both the training and validation steps.
loss_func = losses.BinaryCrossentropy()
# Running mean of the loss values passed to it by the training step.
train_loss = metrics.Mean(name='train_loss')
# Binary accuracy accumulated over training batches.
train_metric = metrics.BinaryAccuracy(name='train_accuracy')
# Running mean of the loss values passed to it by the validation step.
valid_loss = metrics.Mean(name='valid_loss')
valid_metric = metrics.BinaryAccuracy(name='valid_accuracy')@tf.function
def train_step(model, features, labels):with tf.GradientTape() as tape:predictions = model(features,training = True)loss = loss_func(labels, predictions)gradients = tape.gradient(loss, model.trainable_variables)optimizer.apply_gradients(zip(gradients, model.trainable_variables))train_loss.update_state(loss)train_metric.update_state(labels, predictions)@tf.function
def valid_step(model, features, labels):
    """Evaluate one batch and update the validation loss and accuracy.

    Args:
        model: The model being evaluated.
        features: Batch of model inputs.
        labels: Batch of target labels.
    """
    outputs = model(features, training=False)
    valid_loss.update_state(loss_func(labels, outputs))
    valid_metric.update_state(labels, outputs)


def train_model(model, ds_train, ds_valid, epochs):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    After every epoch a timestamped bar and a line with the training and
    validation loss/accuracy are printed, then all four metric objects are
    reset.

    Args:
        model: The model to train.
        ds_train: Iterable of ``(features, labels)`` training batches.
        ds_valid: Iterable of ``(features, labels)`` validation batches.
        epochs: Number of epochs to run.
    """
    # The log template must be adapted to the metrics actually in use.
    logs = 'Epoch={},Loss:{},Accuracy:{},Valid Loss:{},Valid Accuracy:{}'

    for epoch in tf.range(1, epochs + 1):
        for features, labels in ds_train:
            train_step(model, features, labels)

        for features, labels in ds_valid:
            valid_step(model, features, labels)

        if epoch % 1 == 0:
            printbar()
            values = (
                epoch,
                train_loss.result(),
                train_metric.result(),
                valid_loss.result(),
                valid_metric.result(),
            )
            tf.print(tf.strings.format(logs, values))
            tf.print("")

        # Start the next epoch with empty accumulators.
        for tracker in (train_loss, valid_loss, train_metric, valid_metric):
            tracker.reset_states()


train_model(model, ds_train, ds_test, epochs=6)

# A model trained with a custom training loop has not been compiled, so
# model.evaluate(ds_valid) cannot be used on it directly.
# Evaluate the model.
def evaluate_model(model, ds_valid):
    """Evaluate ``model`` on ``ds_valid`` and print the loss and accuracy.

    The validation metrics are cleared before the pass, so values left over
    from an earlier run cannot leak into the result, and cleared again
    afterwards. Training metrics are left untouched: evaluation never
    updates them, so resetting them here would only discard unrelated state.

    Args:
        model: The model to evaluate.
        ds_valid: Iterable of ``(features, labels)`` batches.
    """
    valid_loss.reset_states()
    valid_metric.reset_states()

    for features, labels in ds_valid:
        valid_step(model, features, labels)

    logs = 'Valid Loss:{},Valid Accuracy:{}'
    tf.print(tf.strings.format(logs, (valid_loss.result(), valid_metric.result())))

    valid_loss.reset_states()
    valid_metric.reset_states()


evaluate_model(model, ds_test)

tensorflow2.0处理文本数据相关推荐

  1. Tensorflow2.0之文本生成莎士比亚作品

    文章目录 1.导入数据 2.创建模型 3.训练 3.1 编译模型 3.2 配置检查点 3.3 训练模型 4.预测 4.1 重建模型 4.2 生成文本 我们将使用 Andrej Karpathy 在&l ...

  2. 【TensorFlow2.0】(6) 数据统计,范数、最值、求和、均值、最值位置、唯一值、张量比较

    各位同学好,今天和大家分享一下TensorFlow2.0中的数据分析操作.内容有: (1)范数 tf.norm():(2)最值 tf.reduce_min(), tf.reduce_max()(3)求 ...

  3. tensorflow2.0 RNN文本预测

    https://blog.csdn.net/TeFuirnever/article/details/102686744

  4. tensorflow2.0下载mnist数据存放位置

    window:c/用户/账号/.keras/datasets linux:/home/用户名/.keras/datasets

  5. TensorFlow2.0教程-使用keras训练模型

    TensorFlow2.0教程-使用keras训练模型 Tensorflow 2.0 教程持续更新: https://blog.csdn.net/qq_31456593/article/details ...

  6. tensorflow2.0教程- Keras 快速入门

    tensorflow2.0教程-tensorflow.keras 快速入门 Tensorflow 2.0 教程持续更新: https://blog.csdn.net/qq_31456593/artic ...

  7. TensorFlow2.0 教程-图像分类

    TensorFlow2.0 教程-图像分类 Tensorflow 2.0 教程持续更新: https://blog.csdn.net/qq_31456593/article/details/88606 ...

  8. linux tf2 中文,ocrcn_tf2: TensorFlow2.0的中文汉字手写体识别!OCR必备,欢迎star!

    TensorFlow 2.0 中文手写字识别(汉字OCR) 在开始之前,必须要说明的是,本教程完全基于TensorFlow2.0 接口编写,请勿与其他古老的教程混为一谈,本教程除了手把手教大家完成这个 ...

  9. pip更新失败_最全Tensorflow2.0 入门教程持续更新

    最全Tensorflow 2.0 入门教程持续更新: Doit:最全Tensorflow 2.0 入门教程持续更新​zhuanlan.zhihu.com 完整tensorflow2.0教程代码请看ht ...

最新文章

  1. 六个月学php,修学六个月心得体会
  2. wsl2设置挂载_Windows下的Linux子系统安装,WSL 2下配置docker
  3. uni-app定时器清除问题
  4. Servlet之javaweb应用(二)
  5. leetcode 452. 用最少数量的箭引爆气球(贪心算法)
  6. leetcode1509. 三次操作后最大值与最小值的最小差
  7. Qt容器类(总结)(新发现的QQueue和QStack,注意全都是泛型)
  8. Flash翻书效果研究
  9. 新浪sae部署html,利用新浪sae搭建discuz x2论坛
  10. C++模板之一:函数模板.odt
  11. [SAP ABAP开发技术总结]ABAP调优——Open SQL优化
  12. warning C4251编译警告解决办法
  13. yum rpm apt-get wget 辨析
  14. 在Sun Java System Web Server上使用Quercus运行PHP
  15. excel公式识别html,POI/Excel/HTML单元格公式问题
  16. mysql_wp_replication_tutorial
  17. idea修改css,js样式浏览器没更新问题
  18. 电力系统卫星时钟同步(GPS北斗授时)组成及配置
  19. 收藏|Java程序员必看的几本基础书籍和常用工具
  20. 重庆网络公司的几种死法

热门文章

  1. DDR3 SDRAM分析
  2. Alian解读SpringBoot 2.6.0 源码(五):启动流程分析之打印Banner
  3. 基于萤火虫优化的BP神经网络(分类应用) - 附代码
  4. Pygame(七) 碰撞检测
  5. Da网络编程、正则表达式
  6. 制鞋行业ERP管理系统应用解决方案(3)
  7. UniverSeg:通用医学图像分割模型来了!
  8. 亚马逊电动玩具CPC认证测试标准要求
  9. 集成PayPal支付
  10. 利用主成分分析(PCA)法对基金进行排名