神经网络如何调参、超参数的最优化方法、python实现
神经网络如何调参、超参数的最优化方法、python实现
- 一、what is 超参数
- 二、超参数优化实验
一、what is 超参数
超参数是什么,其实就是,各层神经元数量、batch大小、学习率等人为设定的一些数。
数据集分为训练数据、测试数据、验证数据。
用测试数据评估超参数值的好坏,就可能导致超参数的值被调整为只拟合测试数据,所以加了个验证数据。
训练数据用于参数的学习,验证数据用于超参数的性能评估。
进行超参数最优化,重要的是,逐渐缩小超参数好值存在范围。
一开始大致设定一个范围,从范围中随机采样出超参数,用这个采样值进行识别精度评估,根据这个结果缩小超参数好值范围,然后重复上述操作。研究发现,随机采样效果好点。
二、超参数优化实验
接下来用MNISIT数据集进行超参数最优化,参考斯坦福大学的实验。
实验:最优化学习率和控制权值衰减强度系数这两个参数。
实验中,权值衰减系数初始范围1e- 8到1e- 4,学习率初始范围1e- 6到1e- 2。
随机采样体现在下面代码:
weight_decay = 10 ** np.random.uniform(-8, -4)lr = 10 ** np.random.uniform(-6, -2)
实验结果:
结果可以看出,学习率在0.001到0.01之间,权值衰减系数在1e-8到1e-6之间时,学习可以顺利进行。
观察可以使学习顺利进行的超参数范围,从而缩小值的范围。
然后可以从缩小的范围中继续缩小,然后选个最终值。
=========== Hyper-Parameter Optimization Result ===========
Best-1(val acc:0.8) | lr:0.008986830875594513, weight decay:3.716187805144909e-07
Best-2(val acc:0.76) | lr:0.007815234765792472, weight decay:8.723036800420108e-08
Best-3(val acc:0.73) | lr:0.004924088836198354, weight decay:5.044414627324654e-07
Best-4(val acc:0.7) | lr:0.006838530258012433, weight decay:7.678322790416307e-06
Best-5(val acc:0.69) | lr:0.0037618568422154793, weight decay:6.384663995933291e-08
Best-6(val acc:0.69) | lr:0.004818463383741305, weight decay:4.875486288914377e-08
Best-7(val acc:0.65) | lr:0.004659925318439445, weight decay:1.4968108648982665e-05
Best-8(val acc:0.64) | lr:0.005664124223619111, weight decay:6.070191899324037e-06
Best-9(val acc:0.56) | lr:0.003954240835144594, weight decay:1.5725686195018805e-06
Best-10(val acc:0.5) | lr:0.002554755378245952, weight decay:4.481334628759244e-08
Best-11(val acc:0.5) | lr:0.002855983685917335, weight decay:1.9598718051356917e-05
Best-12(val acc:0.47) | lr:0.004592998586693871, weight decay:4.888121831499798e-05
Best-13(val acc:0.47) | lr:0.0025326736070483947, weight decay:3.200796060402024e-05
Best-14(val acc:0.44) | lr:0.002645798359877985, weight decay:5.0830237860839325e-06
Best-15(val acc:0.42) | lr:0.001942571686958991, weight decay:3.0673143794194257e-06
Best-16(val acc:0.37) | lr:0.001289748323175032, weight decay:2.3690338828642213e-06
Best-17(val acc:0.36) | lr:0.0017017390582746337, weight decay:9.176068035802207e-05
Best-18(val acc:0.3) | lr:0.0015961247160317246, weight decay:1.3527453417413358e-08
Best-19(val acc:0.28) | lr:0.002261959202515378, weight decay:6.004620370338303e-05
Best-20(val acc:0.26) | lr:0.0008799239275589458, weight decay:4.600825912333848e-07
# coding: utf-8
import sys, os
sys.path.append(os.pardir) # 为了导入父目录的文件而进行的设定
import numpy as np
import matplotlib.pyplot as plt
from dataset.mnist import load_mnist
from common.multi_layer_net import MultiLayerNet
from common.util import shuffle_dataset
from common.trainer import Trainer(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True)# 为了实现高速化,减少训练数据
x_train = x_train[:500]
t_train = t_train[:500]# 分割验证数据
validation_rate = 0.20
validation_num = int(x_train.shape[0] * validation_rate)
x_train, t_train = shuffle_dataset(x_train, t_train)
x_val = x_train[:validation_num]
t_val = t_train[:validation_num]
x_train = x_train[validation_num:]
t_train = t_train[validation_num:]def __train(lr, weight_decay, epocs=50):network = MultiLayerNet(input_size=784, hidden_size_list=[100, 100, 100, 100, 100, 100],output_size=10, weight_decay_lambda=weight_decay)trainer = Trainer(network, x_train, t_train, x_val, t_val,epochs=epocs, mini_batch_size=100,optimizer='sgd', optimizer_param={'lr': lr}, verbose=False)trainer.train()return trainer.test_acc_list, trainer.train_acc_list# 超参数的随机搜索======================================
optimization_trial = 100
results_val = {}
results_train = {}
for _ in range(optimization_trial):# 指定搜索的超参数的范围===============weight_decay = 10 ** np.random.uniform(-8, -4)lr = 10 ** np.random.uniform(-6, -2)# ================================================val_acc_list, train_acc_list = __train(lr, weight_decay)print("val acc:" + str(val_acc_list[-1]) + " | lr:" + str(lr) + ", weight decay:" + str(weight_decay))key = "lr:" + str(lr) + ", weight decay:" + str(weight_decay)results_val[key] = val_acc_listresults_train[key] = train_acc_list# 绘制图形========================================================
print("=========== Hyper-Parameter Optimization Result ===========")
graph_draw_num = 20
col_num = 5
row_num = int(np.ceil(graph_draw_num / col_num))
i = 0for key, val_acc_list in sorted(results_val.items(), key=lambda x:x[1][-1], reverse=True):print("Best-" + str(i+1) + "(val acc:" + str(val_acc_list[-1]) + ") | " + key)plt.subplot(row_num, col_num, i+1)plt.title("Best-" + str(i+1))plt.ylim(0.0, 1.0)if i % 5: plt.yticks([])plt.xticks([])x = np.arange(len(val_acc_list))plt.plot(x, val_acc_list)plt.plot(x, results_train[key], "--")i += 1if i >= graph_draw_num:breakplt.show()
神经网络如何调参、超参数的最优化方法、python实现相关推荐
- 【Coursera】深度神经网络的改进:超参数调整、正则化和优化(更新中2023/04/12)
文章目录 1 Setting up your Machine Learning Application 1.1 Train / Dev / Test sets 1.2 Bias / Variance ...
- XGBoost-Python完全调参指南-参数解释篇
本文转自XGBoost-Python完全调参指南-参数解释篇.对XGBoost有需要小伙伴可以参看.并在此基础上,添加了一些内容. 在analytics vidhya上看到一篇<Complete ...
- 单个GPU无法训练GPT-3,但有了这个,你能调优超参数了
视学算法报道 编辑:陈萍.小舟 模型越大,超参数(HP)调优成本越高,微软联合 OpenAI 提出 HP 调优新范式,单个 GPU 上就可以调优 GPT-3 超参数. 伟大的科学成就不能仅靠反复试验取 ...
- 深度学习笔记第二门课 改善深层神经网络 第三周 超参数调试、Batch正则化和程序框架...
本文是吴恩达老师的深度学习课程[1]笔记部分. 作者:黄海广[2] 主要编写人员:黄海广.林兴木(第四所有底稿,第五课第一二周,第三周前三节).祝彦森:(第三课所有底稿).贺志尧(第五课第三周底稿). ...
- 关于神经网络的调参经验技巧和调参顺序
文章目录 前言 一.写在前面 超参数和参数区分 超参数选取尺度的考虑 二.调参顺序 1.learning rate 2.batch size 3.其他参数 3.1Hidden Units(隐含层单元数 ...
- 机器学习之超参数调优——超参数调优的方法
超参数调优的方法 概述 网格搜索 随机搜索 贝叶斯优化算法 概述 对于很多算法工程师来说, 超参数调优是件非常头疼的事.除了根据经验设定所谓的"合 理值"之外, 一般很难找到合理的 ...
- Caret模型训练和调参更多参数解读(2)
trainControl函数控制参数 trainControl函数用于定义train函数运行的一些参数,如交叉验证方式.模型评估函数.模型选择标准.调参方式等. 部分参数解释如下: method: 重 ...
- android 摄像头调参,摄像头参数调整方法和装置、电子设备和存储介质与流程
技术特征: 1.一种摄像头参数调整数调整方法,包括: 对参考摄像头采集的第一图像进行检测,获取第一检测结果: 对待调参摄像头采集的第二图像进行检测,获取第二检测结果,其中,所述第一图像与所述第二图像是 ...
- 深度学习——夏侯南溪的深度神经网络的调参日志
2019年12月4日: MTCNN--人脸检测和关键回归的CNN级联模型 baseline1: P-Net:lr = 0.001, batch = 256 Q-Net:lr = 0.001, batc ...
最新文章
- javascript基础系列:javascript中的变量和数据类型(一)
- mac pycharm 卸载_Mac上Virtual Box虚拟机Linux系统安装
- python索引用法_python 列表索引问题
- 使用客户端登陆ftp 500 OOPS: cannot change directory:/home/virftp解决
- 随想录(移动app下的生活)
- (转载)C语言右移运算符的问题(特别当与取反运算符一起时)
- QCache 缓存(类似于map的模板类,逻辑意义上的缓存Cache,方便管理,默认类似于LRU的淘汰算法)...
- wacom数位板怎么调压感_手绘板压感是什么 数位板压感怎么调【教程】
- JSP学科竞赛管理系统
- 多图详解IT架构师完整知识体系及技术栈
- 世界著名激励大师约翰·库缇斯的传奇人生
- 色彩三原色,RGB,CMYK
- STM32F1读取MLX90614ESF非接触式温度传感器
- 用计算机怎么谈黑人团队,光遇黑人抬棺乐谱怎么弹奏 计算机演奏乐谱16
- VON本源的内幕是什么?VON公链尊皇社区为你揭秘!
- HDOJ--1162--Eddy's picture
- 火影忍者服务器维护时间,1月4日停机更新公告
- 更新TKK失败,请检查网络连接的解决办法
- 异构群体机器人协作任务分配(群体智能论文学习)
- vue-amap 实现定位+跑步路程+跑步时间计算功能
热门文章
- Python openpyxl打开有公式的excel表取值错误的解决办法,Python openpyxl获取excel有公式的单元格的数值错误,Python操作excel(.xlsx)封装类
- Pytorch torchvision完成Faster-rcnn目标检测demo及源码详解
- python实验二报告_20172304 2019-2020-2 《Python程序设计》实验二报告
- [ZJOI2007]时态同步 树形DP
- mysql错误码1709_MySQL5.6出现ERROR 1709 (HY000): Index column size too large问题的解决方法...
- python 找质数的个数_用Python打造一款文件搜索工具,所有功能自己定义!
- thinkphp v5.0.11漏洞_ThinkPHP5丨远程代码执行漏洞动态分析
- windows7电脑删除文件特别慢怎么回事
- 400 bad request的原因意思和解决方法
- form:radiobuttons单选按钮i-check选中触发