原文链接:http://tecdat.cn/?p=12693

原文出处:拓端数据部落公众号


介绍

在本教程中,我们将讨论一种非常强大的优化(或自动化)算法,即网格搜索算法。它最常用于机器学习模型中的超参数调整。我们将学习如何使用Python来实现它,以及如何将其应用到实际应用程序中,以了解它如何帮助我们为模型选择最佳参数并提高准确性

前提

阅读本教程,您最好对Python或其他某种编程语言有基本的了解,也具有机器学习的基本知识,但这不是必需的。除此之外,本文是初学者友好的,任何人都可以关注。

安装

要完成本教程,您需要在系统中安装以下库/框架:

  1. Python 3
  2. NumPy
  3. Pandas
  4. Keras
  5. Scikit-Learn

它们的安装都非常简单-您可以单击它们各自的网站,获取各自的详细安装说明。通常,可以使用pip安装软件包:

$ pip install numpy pandas tensorflow keras scikit-learn

什么是网格搜索?

网格搜索本质上是一种优化算法,可让你从提供的参数选项列表中选择最适合优化问题的参数,从而使“试验和误差”方法自动化。尽管它可以应用于许多优化问题,但是由于其在机器学习中的使用而获得最广为人知的参数,该参数可以使模型获得最佳精度。

假设您的模型采用以下三个参数作为输入:

  1. 隐藏层数[2,4]
  2. 每层中的神经元数量[5,10]
  3. 神经元数[10,50]

如果对于每个参数输入,我们希望尝试两个选项(如上面的方括号中所述),则总计总共2 ^3 = 8个不同的组合(例如,一个可能的组合为[2,5,10])。手动执行此操作会很麻烦。

现在,假设我们有10个不同的输入参数,并且想为每个参数尝试5个可能的值。每当我们希望更改参数值,重新运行代码并跟踪所有参数组合的结果时,都需要从我们这边进行手动输入。网格搜索可自动执行该过程,因为它仅获取每个参数的可能值并运行代码以尝试所有可能的组合,输出每个组合的结果,并输出可提供最佳准确性的组合。

网格搜索实施

让我们将网格搜索应用于实际应用程序。讨论机器学习和数据预处理这一部分不在本教程的讨论范围之内,因此我们只需要运行其代码并深入讨论Grid Search的引入部分即可。

我们将使用糖尿病数据集,该数据集包含有关患者是否基于不同属性(例如血糖,葡萄糖浓度,血压等)的糖尿病信息。使用read_csv()方法。

以下脚本导入所需的库:

from sklearn.model_selection import GridSearchCV, KFold
from keras.models import Sequential
from keras.optimizers import Adam
import sys
import pandas as pd
import numpy as np

以下脚本导入数据集并设置数据集的列标题。

df = pd.read_csv(data_path, names=columns)

让我们看一下数据集的前5行:

df.head()

输出:

如你所见,这5行都是用来描述每一列的标签,因此它们对我们没有用。我们将从删除这些非数据行开始,然后将所有NaN值替换为0:

df.dropna(inplace=True) # 删除所有缺失值的行

以下脚本将数据分为变量和标签集,并将标准化应用于数据集:

# 变换和显示训练数据
X_standardized = scaler.transform(X)

以下方法创建了我们简单的深度学习模型:

    # 创建模型model = Sequential()model.add(Dense(8, input_dim=8, kernel_initializer='normal', activation='relu'))#编译模型model.compile(loss='binary_crossentropy', optimizer=adam, metrics=['accuracy'])

这是加载数据集,对其进行预处理并创建机器学习模型所需的部分代码。因为我们只对Grid Search的功能感兴趣,所以我没有进行训练/测试拆分,我们将模型拟合到整个数据集。

在下一节中,我们将开始了解Grid Search如何通过优化参数使训练模型变得更轻松。

在没有网格搜索的情况下训练模型

在下面的代码中,我们将随机决定或根据直觉决定的参数值创建模型,并查看模型的性能:


model = create_model(learn_rate, dropout_rate)

输出:

Epoch 1/1
130/130 [==============================] - 0s 2ms/step - loss: 0.6934 - accuracy: 0.6000

正如看到的,我们得到的精度是60.00%。这是相当低的。

使用网格搜索优化超参数

如果不使用Grid Search,则可以直接fit()在上面创建的模型上调用方法。但是,要使用网格搜索,我们需要将一些参数传递给create_model()函数。此外,我们需要使用不同的选项声明我们的网格,我们希望为每个参数尝试这些选项。让我们分部分进行。

首先,我们修改create_model()函数以接受调用函数的参数:


# 创建模型
Classifier(create_model, verbose=1)

现在,我们准备实现网格搜索算法并在其上拟合数据集:


# 建立和拟合GridSearch
GridSearch(estimator=mode)

输出:

Best: 0.7959183612648322, using {'batch_size': 10, 'dropout_rate': 0.2, 'epochs': 10, 'learn_rate': 0.02}

在输出中,我们可以看到它为我们提供了最佳精度的参数组合。

可以肯定地说,网格搜索在Python中非常容易实现,并且在人工方面节省了很多时间。您可以列出所有您想要调整的参数,声明要测试的值,运行您的代码。您无需再输入任何信息。找到最佳参数组合后,您只需将其用于最终模型即可。

结论

总结起来,我们了解了什么是Grid Search,它如何帮助我们优化模型以及它带来的诸如自动化的好处。此外,我们学习了如何使用Python语言在几行代码中实现它。为了了解其有效性,我们还训练了带有和不带有Grid Search的机器学习模型,使用Grid Search的准确性提高了19%。

拓端tecdat|Python中基于网格搜索算法优化的深度学习模型分析糖尿病数据相关推荐

  1. python网格搜索法_Python中基于网格搜索算法优化的深度学习模型分析糖尿病数据...

    介绍 在本教程中,我们将讨论一种非常强大的优化(或自动化)算法,即网格搜索算法.它最常用于机器学习模型中的超参数调整.我们将学习如何使用Python来实现它,以及如何将其应用到实际应用程序中,以了解它 ...

  2. 怎么装python的keras库_matlab调用keras深度学习模型(环境搭建)

    matlab没有直接调用tensorflow模型的接口,但是有调用keras模型的接口,而keras又是tensorflow的高级封装版本,所以就研究一下这个--可以将model-based方法和le ...

  3. 【基于 docker 的 Flask 的深度学习模型部署】

    文章目录 1.前言 2.docker简介 3.基于Falsk的REST API实现 4.编写dockerfile 5.基于docker的模型部署 1.前言 模型部署一直是深度学习算法走向落地的重要的一 ...

  4. 【NLP-NER】命名实体识别中最常用的两种深度学习模型

    命名实体识别(Named Entity Recognition,NER)是NLP中一项非常基础的任务.NER是信息提取.问答系统.句法分析.机器翻译等众多NLP任务的重要基础工具. 上一期我们介绍了N ...

  5. 单目标应用:基于麻雀搜索算法优化灰色神经网络(grey neural network)的数据预测(提供MATLAB代码)

    一.麻雀搜索算法 麻雀搜索算法(sparrow search algorithm,SSA)由Jiankai Xue等人于2020年提出,该算法是根据麻雀觅食并逃避捕食者的行为而提出的群智能优化算法.S ...

  6. 基于web端和C++的两种深度学习模型部署方式

    深度学习Author:louwillMachine Learning Lab 本文对深度学习两种模型部署方式进行总结和梳理.一种是基于web服务端的模型部署,一种是基... 深度学习 Author:l ...

  7. 【深度学习】【物联网】深度解读:深度学习在IoT大数据和流分析中的应用

    作者|Natalie 编辑|Emily AI 前线导读:在物联网时代,大量的感知器每天都在收集并产生着涉及各个领域的数据.由于商业和生活质量提升方面的诉求,应用物联网(IoT)技术对大数据流进行分析是 ...

  8. 【深度学习】深度解读:深度学习在IoT大数据和流分析中的应用

    来源:网络大数据(ID:raincent_com) 摘要:这篇论文对于使用深度学习来改进IoT领域的数据分析和学习方法进行了详细的综述. 在物联网时代,大量的感知器每天都在收集并产生着涉及各个领域的数 ...

  9. AI Earth 深度学习模型替换数值天气预报模型中的参数化方案-大气辐射传输方案

    1.背景 太阳辐射和热辐射是大气和海洋运动的最根本的驱动力.大气辐射传输过程实际上已经可以通过一种叫做LBLRTM的辐射模型精确计算,但是LBLRTM模型同时也最为耗时.因此,有各种各样的辐射传输参数 ...

  10. 深度学习模型在移动端的部署

    简介 自从 AlphaGo 出现以来,机器学习无疑是当今最火热的话题,而深度学习也成了机器学习领域内的热点,现在人工智能.大数据更是越来越贴近我们的日常生活,越来越多的人工智能应用开始在移植到移动端上 ...

最新文章

  1. MAC修改python和pip版本
  2. 获取窗口上指定控件集合 2012-08-22 16:14 498人阅读 评论(0) 收藏...
  3. LeetCode Algorithm 797. 所有可能的路径
  4. [转载] QoS的基本原理
  5. vue axios POST请求中参数以form data和request payload形式的原因
  6. 数据库设计(一对一、一对多、多对多)
  7. 哈希值+非对称加密+网络+数字签名,你真的知道怎么给游戏充钱吗
  8. Chrome浏览器显示“网站连接不安全”怎么解决?解决方法分享
  9. UVA 10534 Wavio Sequence DP LIS
  10. 吃相难看!它又又又涨价了......
  11. MATLAB eof用法,[转载]基于Matlab软件进行EOF分解、回归趋势分析
  12. python Pytesseract 动态验证码图片识别
  13. mysqldump关于--set-gtid-purged=OFF的使用(好文章!!)
  14. 随机数种子(seed)
  15. 黑客与技术提示:电脑出现文中现象说明你已经被黑客入侵
  16. 于Cd(Ⅲ)金属有机骨框架的新型造影剂Cd-MOF/Gd-DTPA/DMPE-DTPA-Gd-DMPE/
  17. 数字图像处理第六章——彩色图像处理(上)
  18. NTP服务端和客户端的部署——Chrony
  19. LCD Backlight 的分析
  20. 苹果手机怎么恢复丢失的数据?果粉必看!

热门文章

  1. XML的DTD和Schema约束
  2. 第十二周项目2 - 摩托车继承自行车和机动车
  3. Extjs4.2如何实现鼠标点击统计图时弹出窗口来展示统计的具体列表信息
  4. Q:判断链表中是否存在环的相关问题
  5. FFmpeg基础库编程开发学习笔记——视频常见格式
  6. 使用Spring自定义注解实现任务路由
  7. Android驱动工程师职位要求
  8. gcc/g++ 静态动态库 混链接.
  9. phpSQLiteAdmin - 基于Web的SQLite数据库管理工具 - OPEN 开发经验库
  10. 【iOS】编译静态库