介绍

optuna作为调参工具适合绝大多数的机器学习框架,sklearn,xgb,lgb,pytorch等。

主要的调参原理如下:
1 采样算法
利用 suggested 参数值和评估的目标值的记录,采样器基本上不断缩小搜索空间,直到找到一个最佳的搜索空间,
其产生的参数会带来 更好的目标函数值。

  • optuna.samplers.TPESampler 实现的 Tree-structured Parzen Estimator 算法
  • optuna.samplers.CmaEsSampler 实现的 CMA-ES 算法
  • optuna.samplers.GridSampler 实现的网格搜索
  • optuna.samplers.RandomSampler 实现的随机搜索
    默认TPE
study = optuna.create_study()
print(f"Sampler is {study.sampler.__class__.__name__}")study = optuna.create_study(sampler=optuna.samplers.RandomSampler())
print(f"Sampler is {study.sampler.__class__.__name__}")study = optuna.create_study(sampler=optuna.samplers.CmaEsSampler())
print(f"Sampler is {study.sampler.__class__.__name__}")

2 剪枝算法
自动在训练的早期(也就是自动化的 early-stopping)终止无望的 trial

  • optuna.pruners.SuccessiveHalvingPruner 实现的 Asynchronous Successive Halving 算法。
  • optuna.pruners.HyperbandPruner 实现的 Hyperband 算法。
  • optuna.pruners.MedianPruner 实现的中位数剪枝算法
  • optuna.pruners.ThresholdPruner 实现的阈值剪枝算法

激活 Pruner
要打开剪枝特性的话,你需要在迭代式训练的每一步后调用 report() 和 should_prune(). report() 定期监控目标函数的中间值. should_prune() 确定终结那些没有达到预先设定条件的 trial.

import logging
import sys
import sklearn.datasets
import sklearn.linear_model
import sklearn.model_selectiondef objective(trial):iris = sklearn.datasets.load_iris()classes = list(set(iris.target))train_x, valid_x, train_y, valid_y = sklearn.model_selection.train_test_split(iris.data, iris.target, test_size=0.25, random_state=0)alpha = trial.suggest_float("alpha", 1e-5, 1e-1, log=True)clf = sklearn.linear_model.SGDClassifier(alpha=alpha)for step in range(100):clf.partial_fit(train_x, train_y, classes=classes)# Report intermediate objective value.intermediate_value = 1.0 - clf.score(valid_x, valid_y)trial.report(intermediate_value, step)# Handle pruning based on the intermediate value.if trial.should_prune():raise optuna.TrialPruned()return 1.0 - clf.score(valid_x, valid_y)# Add stream handler of stdout to show the messages
optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))
study = optuna.create_study(pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=20)

对 optuna.samplers.RandomSampler 而言 optuna.pruners.MedianPruner 是最好的。
对于 optuna.samplers.TPESampler 而言 optuna.pruners.Hyperband 是最好的。

当 Optuna 被用于机器学习时,目标函数通常返回模型的损失或者准确度。

1. Study 对象

  • Trial: 目标函数的单次调用
  • Study: 一次优化过程,包含一系列的 trials.
  • Parameter: 待优化的参数.
    在 Optuna 中,我们用 study 对象来管理优化过程。 create_study() 方法会返回一个 study 对象。该对象包含若干有用的属性,可以用于分析优化结果。
    获得参数名和参数值的字典:
    study.best_params
    获得最佳目标值:
    study.best_values

2.超参数采样

  • optuna.trial.Trial.suggest_categorical() 用于类别参数
  • optuna.trial.Trial.suggest_int() 用于整形参数
  • optuna.trial.Trial.suggest_float() 用于浮点型参数

通过可选的 steplog 参数,我们可以对整形或者浮点型参数进行离散化或者取对数操作。
这里的step比较好理解,对于整型就是步长,对于float就是离散化程度(分箱)

log开始不是特别理解,查看了optuna的源码:
对于float:If log is true, the value is sampled from the range in the log domain.
Otherwise, the value is sampled from the range in the linear domain.
还是很懵逼,看看numpy里面是怎么搞的,numpy里面有三种抽样方式:
logspace
Similar to geomspace, but with endpoints specified using log and base.
linspace
Similar to geomspace, but with arithmetic instead of geometric progression.
geomspace
Similar to logspace, but with endpoints specified directly.
举个例子比较直观:

np.linspace(0.02, 2.0, num=20)
np.geomspace(0.02, 2.0, num=20)
np.logspace(0.02, 2.0, num=20)

linspace是一列等差数列,

[ 0.02  0.12421053  0.22842105  0.33263158  0.43684211  0.541052630.64526316  0.74947368  0.85368421  0.95789474  1.06210526  1.166315791.27052632  1.37473684  1.47894737  1.58315789  1.68736842  1.791578951.89578947  2. ]

geomspace是一列等比数列

[0.02 ,  0.0254855 ,  0.03247553,  0.04138276,  0.05273302,0.06719637,  0.08562665,  0.1091119 ,  0.13903856,  0.17717336,0.22576758,  0.28768998,  0.36659614,  0.46714429,  0.59527029,0.75853804,  0.96658605,  1.23169642,  1.56951994,  2.]

logspace会计算默认计算一个basestartbase^{start}basestart和baseendbase^{end}baseend, base默认为10,计算了start和end
start=100.02=1.047,end=102=100.start=10^{0.02} =1.047, end=10^{2} =100.start=100.02=1.047,end=102=100.

[  1.04712855    1.33109952    1.69208062    2.15095626    2.734274463.47578281    4.41838095    5.61660244    7.13976982    9.0760052211.53732863   14.66613875   18.64345144   23.69937223   30.1264090438.29639507   48.68200101   61.88408121   78.6664358   100.  ]

代码示例:


import optuna
def objective(trial):# Categorical parameteroptimizer = trial.suggest_categorical("optimizer", ["MomentumSGD", "Adam"])# Integer parameternum_layers = trial.suggest_int("num_layers", 1, 3)# Integer parameter (log)num_channels = trial.suggest_int("num_channels", 32, 512, log=True)# Integer parameter (discretized)num_units = trial.suggest_int("num_units", 10, 100, step=5)# Floating point parameterdropout_rate = trial.suggest_float("dropout_rate", 0.0, 1.0)# Floating point parameter (log)learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True)# Floating point parameter (discretized)drop_path_rate = trial.suggest_float("drop_path_rate", 0.0, 1.0, step=0.1)

定义参数空间

在 Optuna 中,我们使用和 Python 语法类似的方式来定义搜索空间,其中包含条件和循环语句。
类似地,你也可以根据参数值采用分支或者循环。

# 分支
import sklearn.ensemble
import sklearn.svmdef objective(trial):classifier_name = trial.suggest_categorical("classifier", ["SVC", "RandomForest"])if classifier_name == "SVC":svc_c = trial.suggest_float("svc_c", 1e-10, 1e10, log=True)classifier_obj = sklearn.svm.SVC(C=svc_c)else:rf_max_depth = trial.suggest_int("rf_max_depth", 2, 32, log=True)classifier_obj = sklearn.ensemble.RandomForestClassifier(max_depth=rf_max_depth)
# 循环
import torch
import torch.nn as nn
def create_model(trial, in_size):n_layers = trial.suggest_int("n_layers", 1, 3)layers = []for i in range(n_layers):n_units = trial.suggest_int("n_units_l{}".format(i), 4, 128, log=True)layers.append(nn.Linear(in_size, n_units))layers.append(nn.ReLU())in_size = n_unitslayers.append(nn.Linear(in_size, 10))return nn.Sequential(*layers)

关于参数个数的注意事项
随着参数个数的增长,优化的难度约呈指数增长。也就是说,当你增加参数的个数的时候,优化所需要的 trial 个数会呈指数增长。因此我们不推荐增加不必要的参数。

Reference:
1.官网
2.github examples
3.Difference in output between numpy linspace and numpy logspace
4.np.geomspace

调参神器optuna学习笔记相关推荐

  1. 调参神器贝叶斯优化(bayesian-optimization)实战篇

    今天笔者来介绍一下和调参有关的一些事情,作为算法工程师,调参是不可避免的一个工作.在坊间算法工程师有时候也被称为:调参侠.但是一个合格的算法工程师,调参这部分工作不能花费太多的气力,因为还有很多艰深的 ...

  2. DL之模型调参:深度学习算法模型优化参数之对LSTM算法进行超参数调优

    DL之模型调参:深度学习算法模型优化参数之对LSTM算法进行超参数调优 目录 基于keras对LSTM算法进行超参数调优 1.可视化LSTM模型的loss和acc曲线

  3. DL之模型调参:深度学习算法模型优化参数之对深度学习模型的超参数采用网格搜索进行模型调优(建议收藏)

    DL之模型调参:深度学习算法模型优化参数之对深度学习模型的超参数采用网格搜索进行模型调优(建议收藏) 目录 神经网络的参数调优 1.神经网络的通病-各种参数随机性 2.评估模型学习能力

  4. catBoost 神器的学习笔记

    catBoost 神器的学习笔记,记录自己看原文章的心得.第一次发文,中间有些部分也是个人理解,不足之处,敬请谅解.欢迎扔砖 ^=^catBoost 原文的标题是 "CatBoost :un ...

  5. python学习——超参数调参工具optuna

    感觉目前是适配于各种框架,机器学习框架,深度学习框架,都比较好用的一个调参框架 参考文献:1.https://github.com/optuna/optuna-examples 2.https://z ...

  6. 【数据竞赛】席卷Kaggle的调参神器,NN和树模型通吃!

    作者:杰少 Optuna技术! 简 介 目前非常多的超参寻优算法都不可避免的有下面的一个或者多个问题: 需要人为的定义搜索空间: 没有剪枝操作,导致搜索耗时巨大: 无法通过小的设置变化使其适用于大的和 ...

  7. 【代码质量】-阿里巴巴java开发手册(代码质量提升神器)学习笔记

    前言:<阿里巴巴 Java 开发手册>是阿里巴巴集团技术团队的集体智慧结晶和经验总结,有了这些前人总结的经验,可以帮助我们写出高质量的代码,同时可以减少Bug数量,少踩坑,提高代码的可读性 ...

  8. 一文掌握模型调参神器:Hyperopt

    hyperopt是一个Python库,主要使用 ①随机搜索算法 ②模拟退火算法 ③TPE算法 来对某个算法模型的最佳参数进行智能搜索,它的全称是Hyperparameter Optimization. ...

  9. 机器学习调参神器——网格搜索方法

    网格搜索方法主要用于模型调参,也就是帮助我们找到一组最合适的模型设置参数,使得模型的预测达到更好的效果,这组参数于模型训练过程中学习到的参数不同,它是需要在训练前预设好的,我们称其为超参数. 超参数的 ...

最新文章

  1. C#操作excel(多种方法比较)
  2. python基础知识-11-函数装饰器
  3. AtCoder - 2581 Meaningful Mean
  4. ae在哪里直接复制合成_AE模板里修改复制的合成如何不影响原先的合成?
  5. 关于明星投票系统的作业分享
  6. prototype.js学习(3)函数绑定
  7. C语言bmp转JPEG不用库函数,C++图片格式转换:BMP转JPEG
  8. 把已经写好的Vue项目转成uni-app项目
  9. OSPF配置实验报告
  10. echarts 默认显示图例_Echarts 饼状图显示信息,内容,值,百分比都显示的代码 更改图例等问题汇总...
  11. excel汇总报表如何做?
  12. 使用xpath批量爬取堆糖图片
  13. 【用户角色权限设计】
  14. 算法训练 Beaver's Calculator (蓝桥杯)
  15. 苹果手机Home键失灵怎么办?
  16. 围观历史上最著名的十大思想实验,一起来疯狂思考一下
  17. Docker安装好后默认路径
  18. WPS如何将金额快速改为万元显示
  19. 数字孪生微电网,搭建源网荷储一体化管控平台
  20. python登陆成功页面跳转_Python QT由登陆界面到主界面

热门文章

  1. ArcGis 中打开 shp 文件时 未知的空间参考 警告
  2. sx1268芯片手册第13章翻译
  3. 8000字详解银行业数据治理架构体系搭建
  4. WebSocket聊天室
  5. AirDIsk产品第三方Samba同步工具
  6. 给员工的一封信:在职业生涯规划的框架中工作
  7. 如何解决Redis缓存雪崩、击穿与穿透
  8. java cmos_CMOS构成的常见电路
  9. Git 相关配置 用户名、邮箱
  10. 股票L1和L2都代表是什么意思?