点击上方“算法数据侠”,选择“星标”公众号第一时间获取最新推文与资源分享小侠客们好呀,我是oubahe。继续为各位小侠客带来深度学习实战的实用小技巧,Sklearn和Tensorflow想必是大家做机器学习和深度学习时很熟悉的两个Python库,其中sklearn中有很多机器学习算法、数据预处理以及参数寻优的函数API,keras则可以快速实现你的神经网络结构。那么如何让sklearn和keras相遇而完美结合呢?换句话说,我们建立的深度学习网络模型有很多,通过Python和Tensorflow或者Pytorch构建一个神经网络模型非常方便,那么要想取得一个好的模型效果,就需要对神经网络模型进行调参,单一的人工调参是非常繁琐的(超参数一个一个调的感觉确实美滋滋~),往往不容易取得一个好的效果。我们其实可以借助sklearn来自动参数搜索:Tensorflow2中提供sklearn包装器,分别用于分类的tensorflow.keras.wrappers.scikit_learn.KerasClassifier和用于回归的tensorflow.keras.wrappers.scikit_learn.KerasRegressor。下面就带各位小侠客来看一下通过Sklearn网格搜索GridsearchCV进行Tensorflow调参的方法。来吧,展示~

01

调整batch_size和epochs

首先我们可以使用网格搜索对batch_size和epochs这两个参数进行调整,我们可以根据自己的需要设置待选参数值,在这里我们设置batch_size 为 [10, 20, 40, 60, 80, 100]且epochs 为 [10, 50, 100]。具体实现过程如下:

import numpyfrom tensorflow.keras.wrappers.scikit_learn import KerasClassifierfrom sklearn.model_selection import GridSearchCVfrom tensorflow.keras.models import Sequentialfrom tensorflow.keras.layers import Dense# Function to create deep learning model, required for KerasClassifierdef create_model():    # create model    model = Sequential()    model.add(Dense(12, input_dim=8, activation='relu'))    model.add(Dense(1, activation='sigmoid'))    # Compile model    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])    return model# fix random seed for reproducibilityseed = 7numpy.random.seed(seed)# load dataset (这里可以自行载入数据集)dataset = numpy.loadtxt("diabetes.csv", delimiter=",")# split into input (X) and output (Y) variablesX = dataset[:,0:8]Y = dataset[:,8]# create modelmodel = KerasClassifier(build_fn=create_model, verbose=0)# define the grid search parametersbatch_size = [10, 20, 40, 60, 80, 100]  # 设定的超参数batch_size取值范围epochs = [10, 50, 100]  # 设定的超参数epochs取值范围param_grid = dict(batch_size=batch_size, epochs=epochs)grid = GridSearchCV(estimator=model, param_grid=param_grid, n_jobs=-1)grid_result = grid.fit(X, Y)# summarize resultsprint("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))means = grid_result.cv_results_['mean_test_score']stds = grid_result.cv_results_['std_test_score']params = grid_result.cv_results_['params']for mean, stdev, param in zip(means, stds, params):    print("%f (%f) with: %r" % (mean, stdev, param))

通过上图中展示的训练结果图,我们可以看到batch_size为20和epochs为100个时达到最佳准确度68%。

02

调整优化算法

其次,深度网络中的优化算法有sgd、adam、RMSprop等,如何选择一个合适的优化算法是非常重要的。下面通过一个例子来展示如何通过网格搜索挑选优化算法。具体实现流程和第一步的流程类似,只是搜索变量换成了优化器optimizer。

import numpyfrom tensorflow.keras.wrappers.scikit_learn import KerasClassifierfrom sklearn.model_selection import GridSearchCVfrom tensorflow.keras.models import Sequentialfrom tensorflow.keras.layers import Dense# Function to create deep learning model, required for KerasClassifierdef create_model(optimizer='adam'):    # create model    model = Sequential()    model.add(Dense(12, input_dim=8, activation='relu'))    model.add(Dense(1, activation='sigmoid'))    # Compile model    model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])    return model# fix random seed for reproducibilityseed = 7numpy.random.seed(seed)# load dataset (这里可以自行载入数据集)dataset = numpy.loadtxt("diabetes.csv", delimiter=",")# split into input (X) and output (Y) variablesX = dataset[:,0:8]Y = dataset[:,8]# create modelmodel = KerasClassifier(build_fn=create_model, verbose=0)# define the grid search parametersoptimizer = ['SGD', 'RMSprop', 'Adagrad', 'Adadelta', 'Adam', 'Adamax', 'Nadam'] # 设定的超参数optimizer取值param_grid = dict(optimizer=optimizer)grid = GridSearchCV(estimator=model, param_grid=param_grid, n_jobs=-1)grid_result = grid.fit(X, Y)# summarize resultsprint("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))means = grid_result.cv_results_['mean_test_score']stds = grid_result.cv_results_['std_test_score']params = grid_result.cv_results_['params']for mean, stdev, param in zip(means, stds, params):    print("%f (%f) with: %r" % (mean, stdev, param))

由上图中的结果可知,针对不同的问题,不同的优化函数取得的结果确实是不一样的,从本例结果可以看到采用adam优化算法取得最优结果。到这里各位小侠客应该理解了如何通过Sklearn中的网格搜索来对Tensorflow构建的深度模型调参,本文只列出来几个网络参数,其他如学习率以及神经元数量等参数的调整方法也是一样的,只需要将待选参数输入进去就可以等待运行结果。各位小侠客如果有兴趣可以找一个数据集然后按照本文实例中的代码自己运行一遍,相信你将会有不一样的收获哟。

03

结束语

上述内容就是Sklearn与Tensorflow完美结合的全部内容啦,主要分享了通过Sklearn中的GridSearch来帮助Tensorflow构建的深度分类或回归模型搜索取值范围内最优的超参数组合,使得深度模型达到相对最佳的学习性能,这些技巧你学废了吗!?学习的时间如此短暂,如此实用的技巧各位小侠客要学以致用哟。我是oubahe,下次再见叻~

码字虽少,原创不易。分享是快乐的源泉,小侠客们记得来个素质三连 :点击左下角分享 —> 右下角点赞—>在看本文,可以汇聚好运气召唤神龙哟~

sklearn gridsearchcv_Sklearn与Tensorflow的完美结合相关推荐

  1. python 多分类逻辑回归_机器学习实践:多分类逻辑回归(softmax回归)的sklearn实现和tensorflow实现...

    本文所有代码及数据可下载. Scikit Learn 篇:Light 版 scikit learn内置了逻辑回归,对于小规模的应用较为简单,一般使用如下代码即可 from sklearn.linear ...

  2. 【电子书+代码】Sklearn,Keras与Tensorflow机器学习实用指南

    我们都知道:Scikit-Learn,Keras,Tensorflow是机器学习工具链的重要组成部分.本书的作者,根据上述三个机器学习工具箱,融汇贯通成一个个机器学习实例,让即使对人工智能了解不多的程 ...

  3. 调用训练好的模型(tensorflow)

    使用Tensorflow框架完美保存并实现调用训练好的模型 opencv调用tf训练好的模型          主机调用,不用安装tf,不需要显卡 OpenCV的dnn模块调用TesorFlow训练的 ...

  4. 使用TensorFlow的基本步骤

    学习任务 学习使用TensorFlow,并以california的1990年的人口普查中的城市街区的房屋价值中位数作为预测目标,使用均方根误差(RMSE)评估模型的准确率,并通过调整超参数提高模型的准 ...

  5. 使用tensorflow和Keras的初级教程

    作者|Angel Das 编译|VK 来源|Towards Datas Science 介绍 人工神经网络(ANNs)是机器学习技术的高级版本,是深度学习的核心.人工神经网络涉及以下概念.输入输出层. ...

  6. 机器学习从入门到创业手记-1.3 必备的工具与框架

    今天的课程主要以自我学习为主,李里发给了每个人一张印有培训内容的表格,要求按照培训内容列表中提到的工具,将其简介都写在工具名称的后面. 李里解释道:作为机器学习的初学者刚进入这个领域时肯定是一头雾水, ...

  7. 化解谷歌AI霸权的另一种思路?开发平台的生态围剿

    来源: 脑极体 概要:无论是学界还是巨头,都只能给出规则和参考,以及一小部分示例性应用,而最终让人工智能落地产生价值的,只能是成千上万脑中闪过鬼点子的开发者. 相较移动互联网,AI将是一个更激进的开发 ...

  8. 零基础自学python教程-零基础学Python不迷茫——基本学习路线及教程

    什么是Python? 在过去的2018年里,Python成功的证明了它自己有多火,它那"简洁"与明了的语言成功的吸引了大批程序员与大数据应用这的注意,的确,它的实用性的确是配的上它 ...

  9. 深度学习端上部署工具

    深度学习端上部署工具 模型 公司 通用性别 说明 tf-lite tensorflow,开源 通用性最强,与 tensorflow 适配完美,不过性能一般 支持CPU和GPU roadmap 中预计年 ...

最新文章

  1. Nginx源码分析--基本数据类型的别名
  2. vue读取redis 值_Jmeter连接Redis,一定很容易学会吧
  3. ajax catch,promise记得写上catch
  4. AngularJS优缺点、使用场景
  5. docusign文档打不开_怎样查看 docusign pdf 电子签名
  6. webapp 中为span元素赋值
  7. Nginx配置指定媒体类型文件强制下载
  8. 让 CefSharp.WinForms 应用程序同时支持32位(x86)和64位(x64)的解决方案
  9. 《大数据》第1期“动态”——站在大数据的风口上
  10. 利用ScriptEngineManager实现字符串公式灵活计算
  11. 图像语义分割 —利用Deeplab v3+训练VOC2012数据集
  12. django admin单例对象
  13. 如何在前端删除项目中的文件_如何在macOS上恢复已删除的文件
  14. mongoDB的基本使用----飞天博客
  15. 3 docker容器
  16. 程序员应该如何学习线性代数
  17. 前后端报文传输加密方案
  18. 【WIN 07】笔记本重装系统找回预装的office
  19. 锂电池原理与使用保养
  20. python判断是工作日还是休息日

热门文章

  1. 【kafka】kafka 消费 带有 kerberos认证的服务器
  2. 【FLink】四种图 以及 数据在 taskManager 之间的流转
  3. 【ElasticSearch】ElasticSearch 嵌套查询:如何搜索嵌入的文档
  4. 【Spring】SpringMVC 初始化 流程
  5. 【clickhouse】clickhouse数据文件目录移动到新目录并建立软连接
  6. SpringBoot : SpringBoot自定义的ApplicationContext实现类
  7. java : jstack 显示虚拟机的线程快照
  8. Integer的缓存机制
  9. servlet技术是否过时
  10. 云计算教程学习入门视频课件:常用数据库排名