多输入模型多适用于问答模型或者对于时间序列模型来说有部分特征是针对样本个体而固定的,不随时间变换而发生改变的情况下。
对于模型的输入数据格式来说,有很多种方式,普通的全部数据导入,或者写成生成器等,可以逐批读取数据然后训练模型,但是当你使用tensorflow内置分布式训练,也就是多机多卡模卡MultiWorkerMirroredStrategy的时候,就必须使用Dataset格式。
因为Dataset会自动根据batch_size分发数据进行迭代训练。

如果对MultiWorkerMirroredStrategy以及MirroredStrateg两种训练模型感兴趣的可以去看看我的另外一篇文章。https://blog.csdn.net/qq_35869630/article/details/106313745

一、多输入模型

import tensorflow as tf
from tensorflow.keras.layers import LSTM, Dense, concatenate,Bidirectional
from tensorflow.keras import Input,Modeldef build_and_compile_model():inputA = Input(shape=(7, 22))inputB = Input(shape=(3,))x = LSTM(128, return_sequences=False, activation="relu")(inputA)x = Dense(128, activation="relu")(x)x = Dense(64, activation="relu")(x)x = Model(inputs=inputA, outputs=x)y = Dense(32, activation="relu")(inputB)y = Dense(16, activation="relu")(y)y = Model(inputs=inputB, outputs=y)combined = concatenate([x.output, y.output], axis=-1)z = Dense(32, activation="relu")(combined)z = Dense(1, activation="sigmoid")(z)model = Model(inputs=[x.input, y.input], outputs=z)model.compile(loss='binary_crossentropy', optimizer='rmsprop',metrics=['accuracy'])# rmspropmodel.summary()return model

介绍一下背景,我们这里构建的 模型主要是为了用于观察某个玩家7天内的22个特征数据变化,以及这个玩家的3个原本固定属性,来判断这个玩家是否违反了规定。

inputA 表示的是对时间序列长度为7,特征数为22的输入数据。
inputB 表示的是对特征数为3个的数据的输入。
x = Model(inputs=inputA, outputs=x)
y = Model(inputs=inputB, outputs=y)
构建过单输入模型的都懂这是将模型建立起来,设置输入输出层。
concatenate 将两个模型联结起来。axis有多种取值。
axis=n表示从第n个维度进行拼接,对于一个三维矩阵,axis的取值可以为[-3, -2, -1, 0, 1, 2]。虽然keras用模除允许axis的取值可以在这个范围之外,但不建议那么用。
如果看不懂可以看看大佬介绍的
https://blog.csdn.net/leviopku/article/details/82380710

简单的结构如下


这样就完成了多输入模型的构建。
那针对这样的对输入模型,怎么构建相对应的dataset呢?

二、Dataset

def create_dataset(batch_size,num_workers,clean_data):train_x, train_s_x, train_y, test_x, test_s_x, test_y = clean_data[:]# 转成 datasetg_batch_size = batch_size*num_workerssteps_per_epoch_size = int(train_x.shape[0] / g_batch_size)test_steps = int(test_x.shape[0] / g_batch_size)def func_s(x1, x2):return [x1, x2]input_dataset = tf.data.Dataset.zip((tf.data.Dataset.from_tensor_slices((train_x, train_s_x)).map(func_s),tf.data.Dataset.from_tensor_slices(train_y)))train_datasets_unbatched = input_dataset.cache().shuffle(1000)train_datasets = train_datasets_unbatched.batch(g_batch_size).repeat(-1)test_input_dataset = tf.data.Dataset.zip((tf.data.Dataset.from_tensor_slices((test_x, test_s_x)).map(func_s),tf.data.Dataset.from_tensor_slices(test_y)))test_input_dataset = test_input_dataset.cache().shuffle(1000)test_datasets = test_input_dataset.batch(g_batch_size).repeat(-1)return train_datasets,test_datasets,steps_per_epoch_size,test_steps

因为是多输入模型,所以这里用到了 tf.data.Dataset.zip 将两个输入合并起来。
repeat(-1)设置为无限循环取数。

这里其他函数也挺简单的就不多做介绍了。

直接看全部完整的代码。

def get_multi_strategy():"""分布式模式:return:"""import tensorflow as tfimport argparseimport osimport jsonparser = argparse.ArgumentParser(description='tensorflow_test')parser.add_argument('-worker', default=0)  # --work_numargs, unknown_args = parser.parse_known_args()HOST_CONFIG = ip_listos.environ['TF_CONFIG'] = json.dumps({'cluster': {'worker': HOST_CONFIG},'task': {'type': 'worker', 'index': args.worker}})print (os.environ['TF_CONFIG'])strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()return strategydef build_and_compile_model():inputA = Input(shape=(7, 22))inputB = Input(shape=(3,))x = LSTM(128, return_sequences=False, activation="relu")(inputA)x = Dense(128, activation="relu")(x)x = Dense(64, activation="relu")(x)x = Model(inputs=inputA, outputs=x)y = Dense(32, activation="relu")(inputB)y = Dense(16, activation="relu")(y)y = Model(inputs=inputB, outputs=y)combined = concatenate([x.output, y.output], axis=0)z = Dense(32, activation="relu")(combined)z = Dense(1, activation="sigmoid")(z)model = Model(inputs=[x.input, y.input], outputs=z)model.compile(loss='binary_crossentropy', optimizer='rmsprop',metrics=['accuracy'])# rmspropmodel.summary()return model# batch_size 单机训练的size,num_workers 机器个数
def create_dataset(batch_size,num_workers,clean_data):train_x, train_s_x, train_y, test_x, test_s_x, test_y = clean_data[:]# 转成 datasetg_batch_size = batch_size*num_workerssteps_per_epoch_size = int(train_x.shape[0] / g_batch_size)test_steps = int(test_x.shape[0] / g_batch_size)def func_s(x1, x2):return [x1, x2]input_dataset = tf.data.Dataset.zip((tf.data.Dataset.from_tensor_slices((train_x, train_s_x)).map(func_s),tf.data.Dataset.from_tensor_slices(train_y)))train_datasets_unbatched = input_dataset.cache().shuffle(1000)train_datasets = train_datasets_unbatched.batch(g_batch_size).repeat(-1)test_input_dataset = tf.data.Dataset.zip((tf.data.Dataset.from_tensor_slices((test_x, test_s_x)).map(func_s),tf.data.Dataset.from_tensor_slices(test_y)))test_input_dataset = test_input_dataset.cache().shuffle(1000)test_datasets = test_input_dataset.batch(g_batch_size).repeat(-1)return train_datasets,test_datasets,steps_per_epoch_size,test_stepsif __name__ == '__main__':sequence_length = 7batch_size = 64num_workers = 4epochs = 50strategy = get_multi_strategy()clean_d = clean_data()c_data = train_data(clean_d, sequence_length)train_datasets, test_datasets, steps_per_epoch_size, test_steps = create_dataset(batch_size, num_workers, c_data)print("steps_per_epoch_size,test_steps:", steps_per_epoch_size, test_steps)with strategy.scope():model = build_and_compile_model()print('start fit model')history = model.fit(train_datasets, epochs=epochs, steps_per_epoch=steps_per_epoch_size, shuffle=True,validation_data=test_datasets, validation_steps=test_steps)

三、总结

tensorflow多输入多输出模型结构具备多样化,dataset也要跟着进行改变,如果你不采用MultiWorkerMirroredStrategy多机多卡模型的话也可以不用dataset数据格式,根据实际情况进行调整。

顺便介绍一下MultiWorkerMirroredStrategy多机多卡模式确实能够在机器不行的情况下通过多加机器的方法起到加速模型训练的作用,但是由于还在实验阶段,并不具备所有完整的函数,因此使用的时候也要稍加注意。

我是一只前进的蚂蚁,希望能一起前行。

如果对您有一点帮助,一个赞就够了,感谢!

注:如果本篇博客有任何错误和建议,欢迎各位指出,不胜感激!!!

Tensorflow多输入模型构建以及Dataset数据构建相关推荐

  1. R语言使用glm函数构建泊松对数线性回归模型处理三维列联表数据构建饱和模型、使用summary函数获取模型汇总统计信息

    R语言使用glm函数构建泊松对数线性回归模型处理三维列联表数据构建饱和模型.使用summary函数获取模型汇总统计信息 目录

  2. R语言使用glm函数构建泊松对数线性回归模型处理三维列联表数据构建饱和模型、使用step函数基于AIC指标实现逐步回归筛选最佳模型、使用summary函数查看简单模型的汇总统计信息

    R语言使用glm函数构建泊松对数线性回归模型处理三维列联表数据构建饱和模型.使用step函数基于AIC指标实现逐步回归筛选最佳模型.使用summary函数查看简单模型的汇总统计信息 目录

  3. 利用TensorFlow2.0为胆固醇、血脂、血压数据构建时序深度学习模型(python完整源代码)

    背景数据描述 胆固醇.高血脂.高血压是压在广大中年男性头上的三座大山,如何有效的监控他们,做到早发现.早预防.早治疗尤为关键,趁着这个假期我就利用TF2.0构建了一套时序预测模型,一来是可以帮我预发疾 ...

  4. TensorFlow构建二维数据拟合模型(1)

    知识图谱 TensorFlow运行机制 TensorFlow是基于计算图的深度学习编程模型 Tensor表示张量,其实质上是某种类型的多维数组 Flow表示基于数据流图的计算,实质上是张量在不同节点间 ...

  5. 深度学习-Pytorch:项目标准流程【构建、保存、加载神经网络模型;数据集构建器Dataset、数据加载器DataLoader(线性回归案例、手写数字识别案例)】

    1.拿到文本,分词,清晰数据(去掉停用词语): 2.建立word2index.index2word表 3.准备好预训练好的word embedding 4.做好DataSet / Dataloader ...

  6. python血压测量程序代码_利用TensorFlow2.0为胆固醇、血脂、血压数据构建时序深度学习模型(python源代码)...

    背景数据描述 胆固醇.高血脂.高血压是压在广大中年男性头上的三座大山,如何有效的监控他们,做到早发现.早预防.早治疗尤为关键,趁着这个假期我就利用TF2.0构建了一套时序预测模型,一来是可以帮我预发疾 ...

  7. TensorFlow构建二维数据拟合模型(3)

    占位符与数据喂入机制 placeholder是TensorFlow提供的占位符节点,由tf.placeholder()函数创建,其实质上也是一种变量.占位符没有初始值,只会分配必要的内存,其值由会话中 ...

  8. R语言构建文本分类模型:文本数据预处理、构建词袋模型(bag of words)、构建xgboost文本分类模型、基于自定义函数构建xgboost文本分类模型

    R语言构建文本分类模型:文本数据预处理.构建词袋模型(bag of words).构建xgboost文本分类模型.基于自定义函数构建xgboost文本分类模型 目录

  9. R语言构建文本分类模型并使用LIME进行模型解释实战:文本数据预处理、构建词袋模型、构建xgboost文本分类模型、基于文本训练数据以及模型构建LIME解释器解释多个测试语料的预测结果并可视化

    R语言构建文本分类模型并使用LIME进行模型解释实战:文本数据预处理.构建词袋模型.构建xgboost文本分类模型.基于文本训练数据以及模型构建LIME解释器解释多个测试语料的预测结果并可视化 目录

  10. excel数据输入模型前的转换

    对于excel的数据输入神经网络前要进行数据类型转换,不然好像会有问题,如果能直接输入请指教.下面讲讲对excel数据的转换. 首先对原始数据进行解释一下:下图是部分训练数据,前20列是特征,第21列 ...

最新文章

  1. 《网络攻防实践》第二周作业
  2. python读取大文件csv内存溢出_Python,内存错误,csv文件太大
  3. Java CAS AtomicInteger使用
  4. python解决open()函数、xlrd.open_workbook()函数文件名包含中文,sheet名包含中文报错的问题
  5. Linux的Nginx七:对比|模块
  6. delphi中combobox键值对
  7. 图片image和byte处理,fileupload上传图片
  8. Windows任务管理 连接用户登录信息 通用类[C#版]
  9. Dynamics CRM 2013 installation
  10. 单反相机的常用的几个参数之间的关系
  11. echarts隐藏之后的显示问题
  12. cf——Sasha and a Bit of Relax(dp,math)
  13. RK3288 USB触摸屏无法使用,需要添加PID和VID
  14. python语言的多行注释以什么开头和结尾_python多行注释
  15. Organon将收购Forendo Pharma
  16. CDPSE-数据隐私解决方案工程师
  17. 人工智能应用实例:图片降噪
  18. python Linux下的安装
  19. python 按键精灵识图_Python实现按键精灵(二)-找图找色
  20. 7-2 jmu-Java-03面向对象-06-继承覆盖综合练习-Person、Student、Employee、Company (15分)

热门文章

  1. matlab生成面导出stl格式,导出建模文件到STL格式时需要注意的问题
  2. JavaMail概述
  3. 页面置换算法java_页面置换算法之Clock算法
  4. Java定时任务自动调用方法
  5. 华为推出首款折叠屏5G手机;微信“上车”时间已定;社区团购暗潮涌动
  6. java:单例模式的五种实现方式
  7. VS2015社区版安装教程
  8. linux命令行 teamview,linux centos 命令行 安装 teamviewer 启动 停止
  9. mysql 多重循环_SQL循环语句 详解
  10. 圣经与超级计算机,圣经创世纪里的时间概念和爱因斯坦相对论