网络训练

这一节主要是训练网络来获得网络中的权重。把神经网络当作输入变量x到输出变量y的参数化非线性函数。使用与1.1节的方法对多项式拟合的问题,所以需要最小化平方和误差。现在给定一个由输入向量{xn}组成的训练集,和对应的目标tn,最小化误差函数:

回归问题

假定t服从高斯分布,均值与x相关,神经网络的输出为:

p(t|x,w)=N(t|y(x,w),β−1)p(t|x,w)=N(t|y(x,w),β−1)

p(t|x,w) = N(t|y(x,w),\beta^{-1})
将输出单元激活函数取成恒等函数,这样的话可以近似任何从想到y的连续函数。给定N个独立同分布的观测数据集X={x1,x2,,,}以及对应的目标变量t={t1,t2,,,tn},构成似然函数:

p(t|X,w,β)=∏p(tn|xn,w,β)p(t|X,w,β)=∏p(tn|xn,w,β)

p(t|X,w,\beta)=\prod p(tn|xn,w,\beta)
然后再取负对数,就得到误差函数:

β2∑[y(x,w)−tn]2−N2lnβ+N2ln2πβ2∑[y(x,w)−tn]2−N2ln⁡β+N2ln⁡2π

\frac{\beta}{2}\sum{[y(x,w)-tn]}^2-\frac{N}{2}\ln\beta+\frac{N}{2}\ln2\pi
首先考虑w。把最大似然函数等价于最小平方和误差函数:

E(w)=12∑[y(xn,w)−tn]2E(w)=12∑[y(xn,w)−tn]2

E(w)=\frac{1}{2}\sum{[y(xn,w)-tn]}^2
通过最小化得到的w被记作wML。
注意一点的是,在实际中神经网络函数y(xn,w)的非线性性质导致误差函数E(w)不是凸函数,在应用中寻找的是似然函数的局部最大值,对英语的误差的局部最小值。
现在使用已经找到的wML来寻找β:

1β=1N∑[y(xn,wML)−tn]21β=1N∑[y(xn,wML)−tn]2

\frac{1}{\beta}=\frac{1}{N}\sum{[y(xn,wML)-tn]}^2
现在考虑多个目标变量,假定w和x条件下,目标变量是相互独立的,那么目标变量的条件分布:

p(t|x,w)=N(t|y(x,w),ββ−1I)p(t|x,w)=N(t|y(x,w),ββ−⁡1⁡I)

p(t|x,w)=N(t|y(x,w),\sideset{}{^-1}\beta I)
这种情况下噪声的精度为:

1βML=1NK∑(y(xn,wML)−tn)21βML=1NK∑(y(xn,wML)−tn)2

\frac{1}{\beta ML}=\frac{1}{NK}\sum(y(xn,wML)-tn)^2

二分类

一个单一目标变量t,且t=1为C1、t=0为C2。现在只考虑单一输出,激活函数:

y=σ(a)=11+exp(−a)y=σ(a)=11+exp(−a)

y=\sigma(a)=\frac{1}{1+exp(-a)}
现在给定输入,目标变量的条件概率分布是一个伯努利分布:

p(t|x,w)=y(x,w)t[1−y(x,w)]1−tp(t|x,w)=y(x,w)t[1−y(x,w)]1−t

p(t|x,w)=y(x,w)^t{[1-y(x,w)]^{1-t}}
如果是考虑一个由独立的观测组成的训练集,由负对数似然函数就是由一个交叉熵误差函数:

E(w)=−∑[tnlnyn+(1−tn)ln(1−yn)]E(w)=−∑[tnln⁡yn+(1−tn)ln⁡(1−yn)]

E(w)=-\sum[tn\ln yn+(1-tn)\ln(1-yn)]
有人提出在分类问题上,使用交叉熵误差函数会使训练速度更快,同时提高泛化能力。
现在有k个二元分类问题,则目标向量的条件概率:

p(t|x,w)=∏(yk(x,w)tk[1−yk(x,w)]1−tk)p(t|x,w)=∏(yk(x,w)tk[1−yk(x,w)]1−tk)

p(t|x,w)=\prod(yk(x,w)^{tk}[1-yk(x,w)]^{1-tk})
就可以推出误差函数:

E(w)=−∑∑[tnklnynk+(1−tnk)ln(1−ynk)]E(w)=−∑∑[tnkln⁡ynk+(1−tnk)ln⁡(1−ynk)]

E(w)=-\sum \sum[tnk\ln ynk+(1-tnk)\ln(1-ynk)]
假设使用标准的两层神经网络,第一层的权向量由各个输出共享,而在线性模型中每个分类问题都是独立解决。第一层被看成进行非线性的特征提取,不同的输出之间共享特征可以节省计算量。
前面这些都是说明了对于不同问题,要选取不同的误差函数,可以有效地进行计算提高计算能力和泛化能力。

参数最优化

我们把E(w)看作位于权空间的一个曲面。咋权空间中走一小步,从w到w+δw,误差函数变为:δE=δw.T▽E(w)。由于E(w)是w光滑连续函数,则最小值位于权空间中误差函数梯度等于零的位置上:

∇E(w)=0∇E(w)=0

\nabla E(w) = 0
如果最小值不在这个位置上,就沿着

−∇E(w)−∇E(w)

- \nabla E(w)方向走,进一步减小误差。
但是在误差函数上可能存在很多个驻点,梯度为零的那个驻点就是最小值,所以又提出了通过迭代的数值方法来计算最小值:

wτ+1=wτ+Δwτwτ+1=wτ+Δwτ

w^{\tau+1}=w^\tau + \Delta w^\tau

局部二次近似

这个是为了判断所找的驻点就是最小值。
现在把误差函数E(w’)在w出泰勒展开:

E(w)=E(w′)+(w−w′)Tb+12(w−w′)TH(w−w′)E(w)=E(w′)+(w−w′)Tb+12(w−w′)TH(w−w′)

E(w)=E(w')+(w-w')^Tb+\frac{1}{2}(w-w')^TH(w-w')
其中b的定义:

b≡∇Ew=w′b≡∇Ew=w′

b\equiv\nabla E_{w=w'}
海森矩阵H为:

(H)ij=∂E∂wi∂wj|w=w′(H)ij=∂E∂wi∂wj|w=w′

(H)_{ij}=\frac{\partial E}{\partial w_{i} \partial w_{j}} |_{w=w'}
则梯度的局部近似为:

∇E=b+H(w−w′)∇E=b+H(w−w′)

\nabla E = b + H(w-w')
考虑一个特殊情况,在误差函数最小值点w‘附近的二次近似。因为误差函数在那点处为零,所示公式变成了:

E(w)=E(w∗)+12(w−w∗)TH(w−w∗)E(w)=E(w∗)+12(w−w∗)TH(w−w∗)

E(w)=E(w*)+\frac{1}{2}(w-w*)^TH(w-w*)
其中H的特征值方程:

Hui=λiuiHui=λiui

Hu_{i}=\lambda _{i}u_{i}

uTiuj=δijuiTuj=δij

u_{i} ^T u_{j} = \delta_{ij}

w−w∗=∑αiμiw−w∗=∑αiμi

w-w*=\sum \alpha _{i}\mu_{i}
这样误差公式可以得到:

E(w)=E(w∗)+12∑λiα2iE(w)=E(w∗)+12∑λiαi2

E(w)=E(w*)+\frac{1}{2}\sum \lambda_{i}\alpha_{i}^2
其中H是正定矩阵。
这样得到一个验证方法:当求出一个驻点,可以通过海森矩阵是否是正定矩阵来判断是否是最小值点。

使用梯度信息

文中提到使用梯度信息可以大幅度加快找到极小值点的速度。
在上面给出的误差函数的二次近似中,误差函数由w和b确定,包含w(w+3)/2个元素,W是w的维度。这个二次近似的极小值点的位置因此依赖于O(W2)个参数。如果不使用梯度信息,必须进行O(w2)次函数求值,每次求值都需要O(w)个步骤。因此需要O(W3)的计算复杂度。
而使用梯度信息,由于每次计算E的梯度都有W条信息,预计找到极小值需要O(w)次梯度。通过反向误差传播,这样的计算需要O(W)步骤,可以在O(W2)步骤内找到极小值。

梯度下降最优化

这一节讲的是梯度下降的优化,主要介绍了梯度下降,批梯度下降、随机梯度下降。
最开始提出使用迭代数值的方式求得w的值,公式为:

wτ+1=wτ+Δwτwτ+1=wτ+Δwτ

w^{\tau+1}=w^\tau + \Delta w^\tau
现在更改更新的方式,每一次权值更新都是在负梯度方向上进行移动:

wτ+1=wτ−η∇E(wτ)wτ+1=wτ−η∇E(wτ)

w^{\tau+1}=w^\tau - \eta\nabla E( w^\tau)
这种方法被称为梯度下降,权值向量沿着误差函数下降速度最快的方向移动。
也还有批量梯度法和共轭梯度法、拟牛顿法,这些算法都有:误差函数在每次迭代时总是减小,除非权向量达到了局部或者全局的最小值。
现在讲解在线梯度下降。
基于一组独立观测的最大似然函数的误差函数由一个求和式构成,求和式每一项对应着一个数据点:

E(w)=∑En(w)E(w)=∑En(w)

E(w)=\sum E_{n}(w),在线梯度也被称作顺序梯度下降或者随机梯度下降,使得权向量的更新每次只依赖于一个数据点:

wτ+1=wτ−η∇E(wτ)wτ+1=wτ−η∇E(wτ)

w^{\tau+1}=w^\tau - \eta\nabla E( w^\tau)

PRML5.2--网络训练相关推荐

  1. Wide Deep的OneFlow网络训练

    Wide & Deep的OneFlow网络训练 HugeCTR是英伟达提供的一种高效的GPU框架,专为点击率(CTR)估计训练而设计. OneFlow对标HugeCTR搭建了Wide & ...

  2. 二值网络训练--A Empirical Study of Binary Neural Networks' Optimisation

    A Empirical Study of Binary Neural Networks' Optimisation ICLR2019 https://github.com/mi-lad/studyin ...

  3. 【深度学习】快照集成等网络训练优化算法系列

    [深度学习]快照集成等网络训练优化算法系列 文章目录 1 什么是快照集成? 2 什么是余弦退火学习率? 3 权重空间中的解决方案 4 局部与全局最优解 5 特别数据增强 6 机器学习中解决数据不平衡问 ...

  4. 图像识别python cnn_MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(一)...

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 全连接神经网络是深度学习的基础,理解它就可以掌握深度学习的核心概念:前向传播.反向误差传递.权重.学习 ...

  5. DL之AlexNet:AlexNet算法的架构详解、损失函数、网络训练和学习之详细攻略

    DL之AlexNet:AlexNet算法的架构详解.损失函数.网络训练和学习之详细攻略 相关文章 Dataset:数据集集合(CV方向数据集)--常见的计算机视觉图像数据集大集合(建议收藏,持续更新) ...

  6. 如何绘制caffe网络训练曲线

    本系列文章由 @yhl_leo 出品,转载请注明出处. 文章链接: http://blog.csdn.net/yhl_leo/article/details/51774966 当我们设计好网络结构后, ...

  7. MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(三)...

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 在前两篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网 ...

  8. [深度学习-实践]Tensorflow 2.x应用ResNet SeNet网络训练cifar10数据集的模型在测试集上准确率 86%-87%,含完整代码

    环境 tensorflow 2.1 最好用GPU Cifar10数据集 CIFAR-10 数据集的分类是机器学习中一个公开的基准测试问题.任务的目标对一组32x32 RGB的图像进行分类,这个数据集涵 ...

  9. 设置随机种子之后,网络训练结果仍然不同的解决方法(针对随机采样的数据集)torch设置随机种子,num_worker对数据采样的影响。

    网络训练结果无法复现 设置随机种子 应该为torch, numpy,以及Python设置随机种子,并提高torch卷积精度. def set_seed(seed):random.seed(seed)n ...

  10. Faster-Rcnn 网络训练医学乳腺DDSM图像不能预测到定位框问题及其训练问题

    在faster-rcnn 网络训练中,不同的数据集所带来的问题是不同的,首先明确医学数据集以及常见的RGB数据的区别: 1.医学数据集是超分辨率数据集(DDSM), 其长宽值较高, 而常用的RGB图像 ...

最新文章

  1. mysql 各种恢复_Mysql数据库备份和还原常用的命令
  2. 论文盘点:基于图卷积GNN的多目标跟踪算法解析
  3. java更新新的知识要怎么知道_晟司小蒙告诉你,Java技术知识点,不定时更新!!!...
  4. JavaScript基础部分
  5. 双系统安装和ros安装踩坑
  6. 小米架构调整:拆分成立人工智能部,直接向CEO雷军汇报
  7. 是否可以在git中预览藏匿内容?
  8. 轻松学会硬盘还原卡的安装和使用
  9. AbstractQueuedSynchronizer浅析
  10. 一本通1373:鱼塘钓鱼(fishing)
  11. Postman高级用法
  12. 分式加法JAVA程序_分式加减运算的八种技巧,有几种方法学校老师没讲过,记得收藏...
  13. 解决Windows10更新后点击左下角开始图标无反应【报错0x800f081f】或点击个性化提示【ms-settings:personalisation-background】错误
  14. SDHC ADMA和SDMA区别
  15. 第二届金融交易技术大会拥抱Fin Tech-创新、科技、融合在沪圆满落幕!
  16. ▷Scratch课堂丨【编程趣味卡3】制作音乐
  17. 6-4 计算圆柱体的表面积(函数名隐藏)
  18. 用Dynamips构建能够与真实机器通信的IPSec ***环境
  19. [codeforces 718E]Matvey's Birthday
  20. 工作第十四周:整理收藏夹、旧文章有感

热门文章

  1. 2016 下半年网络工程师上午真题及解析
  2. Zigbee 概念理解
  3. uniapp 运行模拟器 (MUMU)
  4. cityengine快速创建城市模型
  5. 【C++】八皇后问题(竖列递进)
  6. Spark SQL 内置函数(五)Aggregate Functions(基于 Spark 3.2.0)
  7. HTML5七夕情人节表白网页(流星雨3D旋转相册) HTML+CSS+JS 求婚 html生日快乐祝福代码网页 520情人节告白代码 程序员表白源码 3D旋转相册 js烟花代码
  8. 我的idea偏好设置
  9. redis服务器配置
  10. mysql修改视图定义_MySQL修改视图