你能通过“优雅的单宁香”、“成熟的黑醋栗香气”或“浓郁的酒香”这样的描述,预测葡萄酒的价格吗?事实证明,机器学习模型可以。

在这篇文章中,我将解释我是如何利用Keras(tf.keras)建立一个Wide & Deep神经网络,并基于产品描述来预测葡萄酒的价格。对于那些刚接触Keras的人来说,这个用于构建ML模型的TensorFlow API,已经是更高级别的方法了。如果你想直接获取代码,可以去GitHub上查找。你也可以在浏览器中直接运行这个模型,用Colab无需进行设置。

GitHub:https://github.com/sararob/keras-wine-model

colab:https://colab.research.google.com/github/sararob/keras-wine-model/blob/master/keras-wide-deep.ipynb

利用Keras建立Wide & Deep神经网络

我最近一直在用Sequential模型API构建很多Keras模型,但这次我想尝试一下Functional API。Sequential API是Keras的最佳入门方法,它可以让你轻松地将模型定义为层堆栈。而Functional API允许更多灵活性,最适合应用于多重输入模型或组合模型。Functional API的一个实例,就是在Keras中实现一个Wide & Deep网络。因为已经有很多关于Wide & Deep方面的资源,所以我不会描述太多细节。

Wide & Deep Learning:https://research.googleblog.com/2016/06/wide-deep-learning-better-together-with.html

在你用Wide & Deep网络来解决你的ML问题之前,最好确保它非常适合你想要预测的问题。如果你有一个预测任务,输入和输出之间有相对直接的关系,那么一个wide模型可能就足够了。Wide模型是具有稀少特征向量的模型,或者说是大多为零值向量的模型。另一方面,多层深度网络在图像或语音识别等任务中表现良好,在这样的情况下,输入和输出之间可能存在意想不到的关系。如果你的预测任务可以从这两种模型中获益(推荐模型或带有文本输入的模型都是很好的例子),wide & deep可能就会非常适合你的问题。在这种情况下,我分别尝试了wide模型和deep模型,又将它们结合起来,结果发现wide & deep组合精确度最高。

数据集:预测葡萄酒的价格

我们将使用Kaggle的葡萄酒数据集来测试:能否通过描述和种类预测一瓶葡萄酒的价格?

这个问题很适合wide & deep learning,因为它涉及到文本输入,而且葡萄酒的描述和价格之间并没有显著的相关性。我们不能肯定地说,在描述中有“果味浓”的葡萄酒更贵,或者有“单宁柔和”的葡萄酒更便宜。此外,当我们将文本输入到我们的模型中时,有多种方法来表示文本,并且上述两者都可以导致不同类型的见解。而且两者皆包含wide(词袋)和deep(embedding)特征,结合两者可以使我们从文本中获得更多的意义。这个数据集有很多不同的潜在特征,但是我们只使用描述和种类这两种特征,这样结构相对简单。

下面是这个数据集的输入样本和预测:

输入

描述:酒瓶里冒出浓郁的香草味,即使是在这个不佳葡萄酒酿造期,果香也毫不逊色。这样强烈的水果酸中,含有草本成分,水果、酸、药草和香草以相同的比例迅速作用,生成美味的酒。密封的瓶体,这款酒年份不长,需要过酒换瓶或继续贮存,以完美地出现在世人面前。

种类:Pinot Noir(黑皮诺)

预测

价格:45美元

首先,我们要构建这个模型,需要导入:

由于我们的模型的输出(预测)是具体的价格(数字),我们就直接把价格数值输入到模型中进行训练和评估。这个模型的完整代码可以在GitHub上找到。这里我只列出重点。

首先,下载数据并将其转换为Pandas数据帧:

之后,我们将它分为训练集和测试集并提取特征和标签:

第一部分:wide模型

特征1:葡萄酒描述

为了创建我们的文本描述的wide代表,我们将使用bag of words model。简单解释下bag of words model:它可以在模型的每个输入中寻找单词。你可以把每一个输入想象成一个拼字块游戏,每一块都是一个单词而不是一个分解的字母。用这个模型无需考虑到描述中单词的顺序,只需查找一个单词是否存在。

我们不会去查看数据集中每个描述中存在的每个词,而是将我们的词袋限制在数据集中的12 000个单词中(内置的Keras工具可以创建这个词汇表)。这就可以代表wide,因为对于每个描述,模型的输入都是12000元素宽的向量,其中1 s和0s分别表示在特定的描述中,来自我们的词汇表的词是否存在。

Keras中有一些用于文本预处理的便利工具,我们用这种工具将文本描述转换成词袋。用bag of words model,我们通常只希望在词汇表中,找到数据集中所有词的子集。在本例中,我使用了12000个单词,但这是一个超参数,所以你可以进行调整(尝试一些数值,看看哪些在数据集上的效果最好)。我们可以使用Keras Tokenizer class来创建词袋:

然后用texts_to_matrix函数将每个描述转换为词袋向量:

特征2:葡萄酒种类

最初的Kaggle数据集中,葡萄酒分为632种。为了让模型更容易提取模式,我做了一些预处理,只保留了前40个种类(大约占原始数据集的65%,或者说共有96000个例子)。我们将使用Keras实用工具将每一个种类转换成整数表示,然后我们为每个表示种类的输入,创建了40个元素wide独热向量。

目前为止,我们已做好建立wide模型的准备了。

用Keras functional API创建wide模型

Keras有两种用于构建模型的API:Sequential API和Functional API。Functional API给我们提供了更多的灵活性,让我们可以对层进行定义,并将多重特征输入合并到一个层中。当我们做好准备,它也能够很容易地将我们的wide和deep模型结合到一起。使用Functional API,我们就可以在短短几行代码中定义我们的wide模型。首先,我们将输入层定义为12000个元素向量(对应词汇表中的每个单词)。然后我们将它连接到Dense输出层,以得出价格预测。

然后我们编译这个模型,这样就可以使用了:

如果我们对其本身使用wide模型,那么这里我们就要调用fit()函数进行训练,调用evaluate()函数进行评估。因为我们以后会把它和deep模型结合起来,所以我们可以在两个模型结合后在进行训练。现在,是时候建立deep模型了。

第二部分:deep模型

为了用deep代表葡萄酒描述,我们将把它作为一种embedding来表示。有很多关于word embeddings的资源,但简单来说就是它们提供了一种将词映射到向量的方法,这样类似的词在向量空间中将会更紧密地结合。

代表描述作为word embedding

为了将我们的文本描述转换为embedding层,我们首先需要将每个描述进行转换,使其成为对应于词汇表中的每个单词的整数向量。我们可以用Keras texts to sequence方法来实现这一点。

现在我们已经有了完整的描述向量,我们需要确保它们长度相同,才能把它们输入到我们的模型中。Keras也有可以作此处理的实用工具。我们用pad_sequences函数在每个描述向量中加入零点,以便它们长度相同(我将170设为最大长度,这样就无需缩短描述)。

描述被转换成长度相同的向量,我们已经准备好创建embedding层并将其输入到deep模型中。

建立deep模型

有两种方法可以创建一个embedding层:其一,我们可以用预训练得到的embedding权重(有很多开源的embedding词);其二,我们可以从词汇表中学习embedding。最好是对两者进行试验,看看哪一个在数据集上的表现更好。这里我们将使用第二种,即习得的embedding。

首先,我们将定义添加到deep模型的输入的形状。然后我们再将输入添加到embedding层。这里我使用了维度为8的embedding层(你可以尝试对embedding层的维度稍作调整)。embedding层的输出将是一个具有形状的三维向量:批处理大小,序列长度(本例中是170),embedding维度(本例中是8)。为了将我们的embedding层连接到Dense,并充分连接到输出层,我们需要先调用flatten()函数:

一旦用flatten()函数对embedding层进行了调整,就可以将它添加至模型并编译了:

第三部分:wide & deep

一旦我们成功定义了两个模型,将它们结合起来就很容易了。我们只需要创建一个层,将每个模型的输出连接起来,然后将它们合并到可以充分连接的Dense层中,将每个模型的输入和输出结合在一起,最后定义这一组合模型。显然,由于每个模型都在预测相同的事物(价格),所以每个模型的输出或标签都是相同的。还要注意的是,由于我们的模型输出是一个数值,我们不需要做任何预处理,它早已以正确的形式显示出来了:

这样我们就可以开始进行训练和评估了。你可以尝试找到最适合数据集的训练周期和批处理大小:

# Training

combined_model.fit([description_bow_train, variety_train] + [train_embed], labels_train, epochs=10, batch_size=128)

# Evaluation

combined_model.evaluate([description_bow_test, variety_test] + [test_embed], labels_test, batch_size=128)

通过受过训练的模型得出预测

终于到了最激动人心的时刻,现在让我们看看基于数据的模型性能,这样的表现是前所未有的。我们可以为受过训练的模型调用predict()函数,将其传递我们的测试数据集:

然后我们将比较测试数据集的前15种葡萄酒的实际价格与预测价格:

模型是如何进行比较的?让我们看看测试集中的三个例子:

1.酒瓶里冒出浓郁的香草味,即使是在这个不佳葡萄酒酿造期,果香也毫不逊色。这样强烈的水果酸中,含有草本成分,水果、酸、药草和香草以相同的比例迅速作用,生成美味的酒。密封的瓶体,这款酒年份不长,需要过酒换瓶或继续贮存,以完美地出现在世人面前。

预测价格:46.233624 实际价格:  45.0

2.日常酒品,干且浓郁,浆果樱桃口味,组织细腻均匀。

预测价格:9.694958实际价格:10.0

3.时尚圆瓶装,醇和的Barolo葡萄酒(产自Monforte d’Alba),喜欢浓稠多汁口感的人绝不会错过。薰衣草,甜胡椒,肉桂,白巧克力和香草混合的香气让人难忘。清爽的果酸和稳定的单宁使酸浆果的味道达到极致,口感独特。

预测价格:41.028854实际价格: 49.0

事实证明,葡萄酒的描述与价格之间存在某种联系。也许人类可能无法凭直觉进行推断,但ML模型可以。

keras训练完以后怎么预测_使用Keras建立Wide Deep神经网络,通过描述预测葡萄酒价格...相关推荐

  1. pytorch神经网络因素预测_实战:使用PyTorch构建神经网络进行房价预测

    微信公号:ilulaoshi / 个人网站:lulaoshi.info 本文将学习一下如何使用PyTorch创建一个前馈神经网络(或者叫做多层感知机,Multiple-Layer Perceptron ...

  2. 【Matlab风电功率预测】粒子群算法优化BP神经网络风电功率预测【含源码 347期】

    一.代码运行视频(哔哩哔哩) [Matlab风电功率预测]粒子群算法优化BP神经网络风电功率预测[含源码 347期] 二.matlab版本及参考文献 1 matlab版本 2014a 2 参考文献 [ ...

  3. python逻辑回归训练预测_[Python] 机器学习笔记 基于逻辑回归的分类预测

    导学问题 什么是逻辑回归(一),逻辑回归的推导(二 3),损失函数的推导(二 4) 逻辑回归与SVM的异同 逻辑回归和SVM都用来做分类,都是基于回归的概念 SVM的处理方法是只考虑 support ...

  4. 神经网络 并行预测_研究人员研究了为什么神经网络可以有效地进行预测

    人工智能,机器学习和神经网络是日常生活中越来越多的术语.面部识别,对象检测以及人的分类和分割是机器学习算法的常见任务,这些算法现在已得到广泛使用.所有这些过程的基础都是机器学习,这意味着计算机可以捕获 ...

  5. keras训练完以后怎么预测_还在使用“龟速”的单显卡训练模型?动动手,让TPU节省你的时间...

    点击上方关注,All in AI中国 本文将介绍如何使用Keras和Google CoLaboratory与TPU一起训练LSTM模型,与本地计算机上的GPU相比,这样训练能大大缩短训练时间. 很长一 ...

  6. keras训练完以后怎么预测_一文告诉你如何将Keras模型保存到文件中,并再次加载它们来进行预测。...

    Keras是一个用于深度学习的简单而强大的Python库. 鉴于深度学习模式可能需要数小时.数天甚至数周的时间来培训,了解如何保存并将其从磁盘中加载是很重要的. 在本文中,您将发现如何将Keras模型 ...

  7. keras训练完以后怎么预测_农村小孩只有户口,没有承包地,以后怎么养老?看完我安心了...

    阅读本文前,请您先点击上面的蓝色字体"三农荟",再点击"关注",这样您就可以继续免费收到最新情感文章了.每天都有分享.完全是免费订阅,请放心关注. 农村小孩,只 ...

  8. python训练好的图片验证_利用keras加载训练好的.H5文件,并实现预测图片

    我就废话不多说了,直接上代码吧! import matplotlib matplotlib.use('Agg') import os from keras.models import load_mod ...

  9. keras时间序列数据预测_使用Keras的时间序列数据中的异常检测

    keras时间序列数据预测 Anomaly Detection in time series data provides e-commerce companies, finances the insi ...

最新文章

  1. 【9.22校内测试】【可持久化并查集(主席树实现)】【DP】【点双联通分量/割点】...
  2. TensorFlow学习笔记(十一)读取自己的数据进行训练
  3. Yet Another Array Partitioning Task CodeForces - 1114B(思维)
  4. python剑指offer面试题_剑指Offer(Python语言)面试题38
  5. 【笔试记录】2021/3/10阿里
  6. 课题开题报告范文样本_成都汽车职业技术学校举行 2020年省、市、区课题开题报告会...
  7. C#的发展历程第五 - C# 7开始进入快速迭代道路
  8. 数据挖掘10大算法(1)——PageRank
  9. WinDbg / SOS Cheat Sheet
  10. MySQL安装与卸载教程
  11. 联想 缺少计算机所需的介质驱动程序,u盘安装win10显示缺少介质驱动最佳解决方法...
  12. Oracle和MySQL新增只有查询权限用户
  13. 洛谷P3939填颜色
  14. Linux性能检查命令总结
  15. UE4材质03_纹理采样及UV
  16. Git安装教程(Windows安装超详细教程)
  17. java-php-python-旅游景区预约管理系统计算机毕业设计
  18. kotlin中标准函数的使用(with、also、aply、let、run)
  19. Qt串口通信-qextserialport
  20. 三维交互开发(1)-Quest3D与程序的通信

热门文章

  1. 【Python培训基础知识】单例模式
  2. 直接依赖,间接依赖,可选依赖,排除依赖,依赖冲突
  3. 为什么匿名内部类参数必须为final类型
  4. (转)软件测试的分类软件测试生命周期
  5. 如何打造一流的视觉AI技术
  6. java 多维数组转化为字符串
  7. 跨平台网络游戏趋势和优势
  8. c语言饭卡管理系统链表文件,C语言《学生信息管理系统》链表+文件操作
  9. 机器学习-----有监督,无监督,半监督学习的简单阐释
  10. 2018-3-26论文(GWO和WOA)中Table1--Table3中的benchmark函数F1-F23图形