神经网络如何调参、超参数的最优化方法、python实现

  • 一、what is 超参数
  • 二、超参数优化实验

一、what is 超参数

超参数是什么,其实就是,各层神经元数量、batch大小、学习率等人为设定的一些数。

数据集分为训练数据、测试数据、验证数据。

用测试数据评估超参数值的好坏,就可能导致超参数的值被调整为只拟合测试数据,所以加了个验证数据。

训练数据用于参数的学习,验证数据用于超参数的性能评估。

进行超参数最优化,重要的是,逐渐缩小超参数好值存在范围。

一开始大致设定一个范围,从范围中随机采样出超参数,用这个采样值进行识别精度评估,根据这个结果缩小超参数好值范围,然后重复上述操作。研究发现,随机采样效果好点。

二、超参数优化实验

接下来用MNISIT数据集进行超参数最优化,参考斯坦福大学的实验。

实验:最优化学习率和控制权值衰减强度系数这两个参数。

实验中,权值衰减系数初始范围1e- 8到1e- 4,学习率初始范围1e- 6到1e- 2。

随机采样体现在下面代码:

weight_decay = 10 ** np.random.uniform(-8, -4)lr = 10 ** np.random.uniform(-6, -2)

实验结果:

结果可以看出,学习率在0.001到0.01之间,权值衰减系数在1e-8到1e-6之间时,学习可以顺利进行。

观察可以使学习顺利进行的超参数范围,从而缩小值的范围。

然后可以从缩小的范围中继续缩小,然后选个最终值。

=========== Hyper-Parameter Optimization Result ===========
Best-1(val acc:0.8) | lr:0.008986830875594513, weight decay:3.716187805144909e-07
Best-2(val acc:0.76) | lr:0.007815234765792472, weight decay:8.723036800420108e-08
Best-3(val acc:0.73) | lr:0.004924088836198354, weight decay:5.044414627324654e-07
Best-4(val acc:0.7) | lr:0.006838530258012433, weight decay:7.678322790416307e-06
Best-5(val acc:0.69) | lr:0.0037618568422154793, weight decay:6.384663995933291e-08
Best-6(val acc:0.69) | lr:0.004818463383741305, weight decay:4.875486288914377e-08
Best-7(val acc:0.65) | lr:0.004659925318439445, weight decay:1.4968108648982665e-05
Best-8(val acc:0.64) | lr:0.005664124223619111, weight decay:6.070191899324037e-06
Best-9(val acc:0.56) | lr:0.003954240835144594, weight decay:1.5725686195018805e-06
Best-10(val acc:0.5) | lr:0.002554755378245952, weight decay:4.481334628759244e-08
Best-11(val acc:0.5) | lr:0.002855983685917335, weight decay:1.9598718051356917e-05
Best-12(val acc:0.47) | lr:0.004592998586693871, weight decay:4.888121831499798e-05
Best-13(val acc:0.47) | lr:0.0025326736070483947, weight decay:3.200796060402024e-05
Best-14(val acc:0.44) | lr:0.002645798359877985, weight decay:5.0830237860839325e-06
Best-15(val acc:0.42) | lr:0.001942571686958991, weight decay:3.0673143794194257e-06
Best-16(val acc:0.37) | lr:0.001289748323175032, weight decay:2.3690338828642213e-06
Best-17(val acc:0.36) | lr:0.0017017390582746337, weight decay:9.176068035802207e-05
Best-18(val acc:0.3) | lr:0.0015961247160317246, weight decay:1.3527453417413358e-08
Best-19(val acc:0.28) | lr:0.002261959202515378, weight decay:6.004620370338303e-05
Best-20(val acc:0.26) | lr:0.0008799239275589458, weight decay:4.600825912333848e-07

# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # 为了导入父目录的文件而进行的设定
import numpy as np
import matplotlib.pyplot as plt
from dataset.mnist import load_mnist
from common.multi_layer_net import MultiLayerNet
from common.util import shuffle_dataset
from common.trainer import Trainer(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True)# 为了实现高速化,减少训练数据
x_train = x_train[:500]
t_train = t_train[:500]# 分割验证数据
validation_rate = 0.20
validation_num = int(x_train.shape[0] * validation_rate)
x_train, t_train = shuffle_dataset(x_train, t_train)
x_val = x_train[:validation_num]
t_val = t_train[:validation_num]
x_train = x_train[validation_num:]
t_train = t_train[validation_num:]def __train(lr, weight_decay, epocs=50):network = MultiLayerNet(input_size=784, hidden_size_list=[100, 100, 100, 100, 100, 100],output_size=10, weight_decay_lambda=weight_decay)trainer = Trainer(network, x_train, t_train, x_val, t_val,epochs=epocs, mini_batch_size=100,optimizer='sgd', optimizer_param={'lr': lr}, verbose=False)trainer.train()return trainer.test_acc_list, trainer.train_acc_list# 超参数的随机搜索======================================
optimization_trial = 100
results_val = {}
results_train = {}
for _ in range(optimization_trial):# 指定搜索的超参数的范围===============weight_decay = 10 ** np.random.uniform(-8, -4)lr = 10 ** np.random.uniform(-6, -2)# ================================================val_acc_list, train_acc_list = __train(lr, weight_decay)print("val acc:" + str(val_acc_list[-1]) + " | lr:" + str(lr) + ", weight decay:" + str(weight_decay))key = "lr:" + str(lr) + ", weight decay:" + str(weight_decay)results_val[key] = val_acc_listresults_train[key] = train_acc_list# 绘制图形========================================================
print("=========== Hyper-Parameter Optimization Result ===========")
graph_draw_num = 20
col_num = 5
row_num = int(np.ceil(graph_draw_num / col_num))
i = 0for key, val_acc_list in sorted(results_val.items(), key=lambda x:x[1][-1], reverse=True):print("Best-" + str(i+1) + "(val acc:" + str(val_acc_list[-1]) + ") | " + key)plt.subplot(row_num, col_num, i+1)plt.title("Best-" + str(i+1))plt.ylim(0.0, 1.0)if i % 5: plt.yticks([])plt.xticks([])x = np.arange(len(val_acc_list))plt.plot(x, val_acc_list)plt.plot(x, results_train[key], "--")i += 1if i >= graph_draw_num:breakplt.show()

神经网络如何调参、超参数的最优化方法、python实现相关推荐

  1. 【Coursera】深度神经网络的改进:超参数调整、正则化和优化(更新中2023/04/12)

    文章目录 1 Setting up your Machine Learning Application 1.1 Train / Dev / Test sets 1.2 Bias / Variance ...

  2. XGBoost-Python完全调参指南-参数解释篇

    本文转自XGBoost-Python完全调参指南-参数解释篇.对XGBoost有需要小伙伴可以参看.并在此基础上,添加了一些内容. 在analytics vidhya上看到一篇<Complete ...

  3. 单个GPU无法训练GPT-3,但有了这个,你能调优超参数了

    视学算法报道 编辑:陈萍.小舟 模型越大,超参数(HP)调优成本越高,微软联合 OpenAI 提出 HP 调优新范式,单个 GPU 上就可以调优 GPT-3 超参数. 伟大的科学成就不能仅靠反复试验取 ...

  4. 深度学习笔记第二门课 改善深层神经网络 第三周 超参数调试、Batch正则化和程序框架...

    本文是吴恩达老师的深度学习课程[1]笔记部分. 作者:黄海广[2] 主要编写人员:黄海广.林兴木(第四所有底稿,第五课第一二周,第三周前三节).祝彦森:(第三课所有底稿).贺志尧(第五课第三周底稿). ...

  5. 关于神经网络的调参经验技巧和调参顺序

    文章目录 前言 一.写在前面 超参数和参数区分 超参数选取尺度的考虑 二.调参顺序 1.learning rate 2.batch size 3.其他参数 3.1Hidden Units(隐含层单元数 ...

  6. 机器学习之超参数调优——超参数调优的方法

    超参数调优的方法 概述 网格搜索 随机搜索 贝叶斯优化算法 概述 对于很多算法工程师来说, 超参数调优是件非常头疼的事.除了根据经验设定所谓的"合 理值"之外, 一般很难找到合理的 ...

  7. Caret模型训练和调参更多参数解读(2)

    trainControl函数控制参数 trainControl函数用于定义train函数运行的一些参数,如交叉验证方式.模型评估函数.模型选择标准.调参方式等. 部分参数解释如下: method: 重 ...

  8. android 摄像头调参,摄像头参数调整方法和装置、电子设备和存储介质与流程

    技术特征: 1.一种摄像头参数调整数调整方法,包括: 对参考摄像头采集的第一图像进行检测,获取第一检测结果: 对待调参摄像头采集的第二图像进行检测,获取第二检测结果,其中,所述第一图像与所述第二图像是 ...

  9. 深度学习——夏侯南溪的深度神经网络的调参日志

    2019年12月4日: MTCNN--人脸检测和关键回归的CNN级联模型 baseline1: P-Net:lr = 0.001, batch = 256 Q-Net:lr = 0.001, batc ...

最新文章

  1. javascript基础系列:javascript中的变量和数据类型(一)
  2. mac pycharm 卸载_Mac上Virtual Box虚拟机Linux系统安装
  3. python索引用法_python 列表索引问题
  4. 使用客户端登陆ftp 500 OOPS: cannot change directory:/home/virftp解决
  5. 随想录(移动app下的生活)
  6. (转载)C语言右移运算符的问题(特别当与取反运算符一起时)
  7. QCache 缓存(类似于map的模板类,逻辑意义上的缓存Cache,方便管理,默认类似于LRU的淘汰算法)...
  8. wacom数位板怎么调压感_手绘板压感是什么 数位板压感怎么调【教程】
  9. JSP学科竞赛管理系统
  10. 多图详解IT架构师完整知识体系及技术栈
  11. 世界著名激励大师约翰·库缇斯的传奇人生
  12. 色彩三原色,RGB,CMYK
  13. STM32F1读取MLX90614ESF非接触式温度传感器
  14. 用计算机怎么谈黑人团队,光遇黑人抬棺乐谱怎么弹奏 计算机演奏乐谱16
  15. VON本源的内幕是什么?VON公链尊皇社区为你揭秘!
  16. HDOJ--1162--Eddy's picture
  17. 火影忍者服务器维护时间,1月4日停机更新公告
  18. 更新TKK失败,请检查网络连接的解决办法
  19. 异构群体机器人协作任务分配(群体智能论文学习)
  20. vue-amap 实现定位+跑步路程+跑步时间计算功能

热门文章

  1. Python openpyxl打开有公式的excel表取值错误的解决办法,Python openpyxl获取excel有公式的单元格的数值错误,Python操作excel(.xlsx)封装类
  2. Pytorch torchvision完成Faster-rcnn目标检测demo及源码详解
  3. python实验二报告_20172304 2019-2020-2 《Python程序设计》实验二报告
  4. [ZJOI2007]时态同步 树形DP
  5. mysql错误码1709_MySQL5.6出现ERROR 1709 (HY000): Index column size too large问题的解决方法...
  6. python 找质数的个数_用Python打造一款文件搜索工具,所有功能自己定义!
  7. thinkphp v5.0.11漏洞_ThinkPHP5丨远程代码执行漏洞动态分析
  8. windows7电脑删除文件特别慢怎么回事
  9. 400 bad request的原因意思和解决方法
  10. form:radiobuttons单选按钮i-check选中触发