为高效找到使损失函数的值最小的参数,关于最优化(optimization)提了很多方法。

http://ruder.io/optimizing-gradient-descent/  本网页介绍的原理可以参考。下面这段代码取自一本教材,比较经典。

# coding: utf-8
import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDictclass SGD:"""随机梯度下降法(Stochastic Gradient Descent)"""def __init__(self, lr=0.01):self.lr = lrdef update(self, params, grads):for key in params.keys():params[key] -= self.lr * grads[key] class Momentum:"""Momentum SGD"""def __init__(self, lr=0.01, momentum=0.9):self.lr = lrself.momentum = momentumself.v = Nonedef update(self, params, grads):if self.v is None:self.v = {}for key, val in params.items():                                self.v[key] = np.zeros_like(val)for key in params.keys():self.v[key] = self.momentum*self.v[key] - self.lr*grads[key] params[key] += self.v[key]class Nesterov:"""Nesterov's Accelerated Gradient (http://arxiv.org/abs/1212.0901)"""def __init__(self, lr=0.01, momentum=0.9):self.lr = lrself.momentum = momentumself.v = Nonedef update(self, params, grads):if self.v is None:self.v = {}for key, val in params.items():self.v[key] = np.zeros_like(val)for key in params.keys():self.v[key] *= self.momentumself.v[key] -= self.lr * grads[key]params[key] += self.momentum * self.momentum * self.v[key]params[key] -= (1 + self.momentum) * self.lr * grads[key]class AdaGrad:"""AdaGrad"""def __init__(self, lr=0.01):self.lr = lrself.h = Nonedef update(self, params, grads):if self.h is None:self.h = {}for key, val in params.items():self.h[key] = np.zeros_like(val)for key in params.keys():self.h[key] += grads[key] * grads[key]params[key] -= self.lr * grads[key] / (np.sqrt(self.h[key]) + 1e-7)class RMSprop:"""RMSprop"""def __init__(self, lr=0.01, decay_rate = 0.99):self.lr = lrself.decay_rate = decay_rateself.h = Nonedef update(self, params, grads):if self.h is None:self.h = {}for key, val in params.items():self.h[key] = np.zeros_like(val)for key in params.keys():self.h[key] *= self.decay_rateself.h[key] += (1 - self.decay_rate) * grads[key] * grads[key]params[key] -= self.lr * grads[key] / (np.sqrt(self.h[key]) + 1e-7)class Adam:"""Adam (http://arxiv.org/abs/1412.6980v8)"""def __init__(self, lr=0.001, beta1=0.9, beta2=0.999):self.lr = lrself.beta1 = beta1self.beta2 = beta2self.iter = 0self.m = Noneself.v = Nonedef update(self, params, grads):if self.m is None:self.m, self.v = {}, {}for key, val in params.items():self.m[key] = np.zeros_like(val)self.v[key] = np.zeros_like(val)self.iter += 1lr_t  = self.lr * np.sqrt(1.0 - self.beta2**self.iter) / (1.0 - self.beta1**self.iter)         for key in params.keys():self.m[key] += (1 - self.beta1) * (grads[key] - self.m[key])self.v[key] += (1 - self.beta2) * (grads[key]**2 - self.v[key])params[key] -= lr_t * self.m[key] / (np.sqrt(self.v[key]) + 1e-7)def f(x, y):return x**2 / 20.0 + y**2def df(x, y):return x / 10.0, 2.0*yinit_pos = (-7.0, 2.0)
params = {}
params['x'], params['y'] = init_pos[0], init_pos[1]
grads = {}
grads['x'], grads['y'] = 0, 0optimizers = OrderedDict()
optimizers["SGD"] = SGD(lr=0.95)
optimizers["Momentum"] = Momentum(lr=0.1)
optimizers["AdaGrad"] = AdaGrad(lr=1.5)
optimizers["Adam"] = Adam(lr=0.3)idx = 1for key in optimizers:optimizer = optimizers[key]x_history = []y_history = []params['x'], params['y'] = init_pos[0], init_pos[1]for i in range(30):x_history.append(params['x'])y_history.append(params['y'])grads['x'], grads['y'] = df(params['x'], params['y'])optimizer.update(params, grads)x = np.arange(-10, 10, 0.01)y = np.arange(-5, 5, 0.01)X, Y = np.meshgrid(x, y) Z = f(X, Y)# for simple contour line  mask = Z > 7Z[mask] = 0# plot plt.subplot(2, 2, idx)idx += 1plt.plot(x_history, y_history, 'o-', color="red")plt.contour(X, Y, Z)#绘制等高线plt.ylim(-10, 10)plt.xlim(-10, 10)plt.plot(0, 0, '+')plt.title(key)plt.xlabel("x")plt.ylabel("y")plt.subplots_adjust(wspace =0, hspace =0)#调整子图间距
plt.show()

神经网络最优化方法比较(代码理解)相关推荐

  1. Deep Learning论文笔记之(五)CNN卷积神经网络代码理解

    Deep Learning论文笔记之(五)CNN卷积神经网络代码理解 zouxy09@qq.com http://blog.csdn.net/zouxy09          自己平时看了一些论文,但 ...

  2. 图卷积神经网络GCN的一些理解以及DGL代码实例的一些讲解

    文章目录 前言 GCN 传播公式 例1 例2 DGL中的GCN实例 dgl.DGLGraph.update_all 参考 前言 近些年图神经网络十分火热,因为图数据结构其实在我们的现实生活中更常见,例 ...

  3. DeepLearning tutorial(4)CNN卷积神经网络原理简介+代码详解

    FROM: http://blog.csdn.net/u012162613/article/details/43225445 DeepLearning tutorial(4)CNN卷积神经网络原理简介 ...

  4. 【HSI】高光谱的数据集分类深度学习实战及代码理解

    [HSI]高光谱的数据集分类深度学习实战及代码理解 文章目录 [HSI]高光谱的数据集分类深度学习实战及代码理解 一.配置文件编写 二.高光谱图像的处理 2.1图像数据变换 2.2 数据整合 2.3 ...

  5. 基于粒子群算法优化的ELMAN动态递归神经网络预测-附代码

    基于粒子群算法优化的ELMAN动态递归神经网络预测及其MATAB实现 文章目录 基于粒子群算法优化的ELMAN动态递归神经网络预测及其MATAB实现 1. 模型与算法描述 1.1 ELMAN神经网络预 ...

  6. fishnet:论文阅读与代码理解

    fishnet:论文阅读与代码理解 一.论文概述 二.整体框架 三.代码理解 四.总结 fishnet论文地址:http://papers.nips.cc/paper/7356-fishnet-a-v ...

  7. MTCNN算法与代码理解—人脸检测和人脸对齐联合学习

    MTCNN算法与代码理解-人脸检测和人脸对齐联合学习 写在前面 主页:https://kpzhang93.github.io/MTCNN_face_detection_alignment/index. ...

  8. 【Pytorch深度学习实践】B站up刘二大人之BasicCNN Advanced CNN -代码理解与实现(9/9)

    这是刘二大人系列课程笔记的 最后一个笔记了,介绍的是 BasicCNN 和 AdvancedCNN ,我做图像,所以后面的RNN我可能暂时不会花时间去了解了: 写在前面: 本节把基础个高级CNN放在一 ...

  9. YOLOV5代码理解——损失函数的计算

    YOLOV5代码理解--损失函数的计算 摘要: 神经网络的训练的主要流程包括图像输入神经网络, 得到模型的输出结果,计算模型的输出与真实值的损失, 计算损失值的梯度,最后用梯度下降算法更新模型参数.损 ...

最新文章

  1. 中国工程院《全球工程前沿2020》报告在京发布
  2. ubuntu16.04 安装微信和qq
  3. 黑马程序员--java基础--其他对象
  4. c语言 误差小于10 -6,上海理工大学C语言2011期中试题和答案
  5. Android Listview 性能优化
  6. 「ROI 2017 Day 2」反物质(单调队列优化dp)
  7. AAS的完整形式是什么?
  8. vdcode C语言不能弹出运行窗口_C语言编程常见问题分析,以及错误解决办法!
  9. mysql resultmap_MyBatis ResultMap
  10. 《Java程序员职场全功略:从小工到专家》连载三:IT语言平台
  11. 题目1026:又一版 A+B
  12. Vue学习笔记(尚硅谷天禹老师)
  13. 蚂蚁课堂视频笔记思维导图-3期 七、互联网高并发解决方案
  14. R包minfi处理DNA甲基化芯片数据
  15. 计算机 在哪看是什么32位,如何查看自己的电脑是32位的还是64位
  16. iot会议纪要 20180105
  17. 【汇正财经】股票价格有哪些偏向性特征?
  18. jovi语音助手安装包_Jovi语音助手安装包下载-vivoJovi语音助手v3.1.1.0 最新版-腾牛安卓网...
  19. 迁移学习基础知识(一)——分类及应用
  20. 【有趣的Python小程序】Pygame制作键盘彩色闪烁打字游戏KeyBoardFlash

热门文章

  1. 微型计算机原理课本,微机原理与接口技术课本.doc
  2. 基于Java的RDMA高性能通信库(六):SDP - Java Socket Direct Protocol
  3. 无线宝服务器连接不上,无线网络连接不上怎么办 为什么无线网络连接不上
  4. 如何垂直居中一个浮动元素
  5. 深入理解php底层:php生命周期 [转]
  6. wpf custom control
  7. WPF ,listbox,平滑滚动的2种方式。
  8. 【新品发布】山海软件生产线pspl,包含了一个开源的混淆器
  9. Android开发小问题集
  10. 电量检测芯片BQ27510使用心得