Pytorch深度学习(一):前馈神经网络(FNN)

参考B站课程:《PyTorch深度学习实践》完结合集
传送门:《PyTorch深度学习实践》完结合集

一、线性模型:

已知数据:
x=[1,2,3],y=[2,4,6]x=[1,2,3],\quad y=[2,4,6]x=[1,2,3],y=[2,4,6]
预测x=4x=4x=4时,yyy等于多少
建立线性模型
y^=w∗x,x,y^∈R\hat{y}=w*x,\quad x,\hat{y}\in\mathbb{R}y^​=w∗x,x,y^​∈R
估计权重www(weight),再来预测x=4x=4x=4时yyy的值。
定义MSE(Mean Square Error):
cost=1N∑n=1N(y^n−yn)2cost=\frac{1}{N}\sum_{n=1}^{N} (\hat{y}_n-y_n)^2cost=N1​n=1∑N​(y^​n​−yn​)2

二、普通方法

在可能的区间内(0,4.1)(0,\;4.1)(0,4.1)内遍历来试探最优的权重www使得误差MSE最小

import numpy as np
from matplotlib import pyplot as pltxdata = [1, 2, 3]
ydata = [2, 4, 6]def forward(x):return x * wdef loss(x, y):ypred = forward(x)return (ypred-y)**2wlist = []      # weight
mse = []    # mean squre error
for w in np.arange(0, 4.1, 0.1):print('w=', w)lsum = 0for x_val, y_val in zip(xdata, ydata):ypred_val = forward(x_val)lossval = loss(x_val, y_val)lsum += lossvalprint('\t', x_val, y_val, ypred_val, lossval)print('MSE=', lsum/3)wlist.append(w)mse.append(lsum/3)plt.plot(wlist, mse)
plt.ylabel('Mean squre error')
plt.xlabel('weight')
plt.show()


可见最佳的权重在2.0附近,但这种算法是不够好的,首先区间不好找,其次遍历的步长会影响精度,并且遍历会耗费了大量的计算力。

三、梯度下降法(Gradient Descent)

这是经典的优化算法,在局部确定MSE下降速度最快的方向(梯度的反方向)来求得局部解.
cost=1N∑n=1N(y^n−yn)2=1N∑n=1N(xn∗w−yn)2cost=\frac{1}{N}\sum_{n=1}^{N} (\hat{y}_n-y_n)^2=\frac{1}{N}\sum_{n=1}^{N} (x_n*w-y_n)^2cost=N1​n=1∑N​(y^​n​−yn​)2=N1​n=1∑N​(xn​∗w−yn​)2
则梯度方向
∂cost∂w=2N∑n=1Nxn(xn∗w−yn)\frac{\partial cost}{\partial w}=\frac{2}{N}\sum_{n=1}^{N} x_n(x_n *w -y_n)∂w∂cost​=N2​n=1∑N​xn​(xn​∗w−yn​)
更新权重www
w=w−0.01∗∂cost∂ww = w-0.01*\frac{\partial cost}{\partial w}w=w−0.01∗∂w∂cost​
其中,0.01是学习率,其不宜取得过大

import numpy as np
from matplotlib import pyplot as pltxdata = [1, 2, 3]
ydata = [2, 4, 6]w = 1.0     # stating weight
costlist = []def forward(x):return x * wdef cost(xs, ys):cost = 0for x, y in zip(xs, ys):ypred = forward(x)cost += (ypred - y) **2return cost / len(xs)def gradient(xs, ys):grad = 0for x, y in zip(xs, ys):grad += 2 * x * (x * w -y)return grad / len(xs)print('Predict(before training', 4, forward(4))     # starting valuefor epoch in range(100):cost_val = cost(xdata, ydata)costlist.append(cost_val)grad_val = gradient(xdata, ydata)w -= 0.01 * grad_val    # study rate = 0.01print('Epoch:', epoch, 'w=', w, 'loss=', cost_val)print('Predict(after training)', 4, forward(4))plt.plot(range(100), costlist)
plt.xlabel('epoch')
plt.ylabel('cost')
plt.title('Gradient Descent')
plt.show()


可见收敛效果较好。最后一步输出的权重以及,预测值如下

Epoch: 99 w= 1.9999444396553017 loss= 1.752432687141379e-08
Predict(after training) 4 7.999777758621207

四、随机梯度下降法(Stochastic Gradient Descent)

正如前所言,DG算法可能陷入局部,一旦碰到拐点等情况就没有办法求到全局最优,于是引入随机梯度下降法
优点:有可能避免陷入局部最优
缺点:计算量比DG更大

import numpy as np
from matplotlib import pyplot as pltxdata = [1, 2, 3]
ydata = [2, 4, 6]w = 1.0     # stating weight
costlist = []def forward(x):return x * wdef loss(x,y):ypred = forward(x)return (ypred - y) ** 2def gradient(xs, ys):return 2 * x *(x*w -y)print('Predict(before training', 4, forward(4))     # starting valuefor epoch in range(100):for x,y in zip(xdata, ydata):grad = gradient(x,y)w = w - 0.01*gradprint('\t grad :', x, y, grad)l = loss(x,y)print('progress:', epoch, 'w=', w, 'loss=', l)costlist.append(l)print('Predict(after training)', 4, forward(4))plt.plot(range(100), costlist)
plt.xlabel('epoch')
plt.ylabel('cost')
plt.title('Stochastic Gradient Descent')
plt.show()

progress: 99 w= 1.9999999999999236 loss= 5.250973729513143e-26
Predict(after training) 4 7.9999999999996945

可以看到同样的迭代步数,SDG算法最后产生的误差的数量级比DG算法更小

我们比较两种算法的误差图像:

但SDG算法存在计算量相对大的缺点,所以在实际应用当中可以部分数据用DG,部分数据用SDG。

五、前馈神经网络(FNN)

以上就是一个简单的前馈神经网络的例子,它的第一次“学习”过程如下:

这里由于我们推导出了误差(MSE)关于权重www的解析式 ∂cost∂w=2N∑n=1Nxn(xn∗w−yn)\frac{\partial cost}{\partial w}=\frac{2}{N}\sum_{n=1}^{N} x_n(x_n *w -y_n)∂w∂cost​=N2​n=1∑N​xn​(xn​∗w−yn​)所以带入相关数据和即可得到梯度值,从而更新www。然而实际情况可能更为复杂,不一定能直接求出解析式,仔细观察我们知道:

  • 我们不一定非得得到解析式,只需要得到每一次误差 losslossloss “反馈” 给权重www的值,用于更新www

这就是反馈神经网络(BPNN)的基本思想,具体见下一篇文章。

Pytorch深度学习(一):前馈神经网络(FNN)相关推荐

  1. 水很深的深度学习-Task03前馈神经网络

    本文参考 Datawhale:水很深的深度学习 深度学习(四)-前馈神经网络_未名湖畔的落叶-CSDN博客_前馈神经网络 神经元模型   在前馈神经网络中,各神经元分别属于不同的层.每一层的神经元可以 ...

  2. 深度学习3 前馈神经网络

    深度学习3 前馈神经网络 目录 深度学习3 前馈神经网络 1. 神经元模型(M-P) (1)公式 (2)运算 (3)结构 2. 感知机模型 (1)单层感知机 (2)多层感知器 (3)BP算法 1. 神 ...

  3. 深度学习(四)-前馈神经网络

      在前馈神经网络中,各神经元分别属于不同的层.每一层的神经元可以接收前一层神经元的信号,并产生信号输出到下一层.第 0 层叫输入层,最后一层叫输出层,其它中间层叫做隐藏层,相邻两层的神经元之间为全连 ...

  4. 深度学习之前馈神经网络(前向传播和误差反向传播)

    转自:https://www.cnblogs.com/Luv-GEM/p/10694471.html 这篇文章主要整理三部分内容,一是常见的三种神经网络结构:前馈神经网络.反馈神经网络和图网络:二是整 ...

  5. 深度学习:前馈神经网络

    对深度学习(或称神经网络)的探索通常从它在计算机视觉中的应用入手.计算机视觉属于人工智能领域,因深度学习技术而不断革新,并且计算机视觉的基础(光强度)是用实数来表示的,处理实数正是神经网络所擅长的. ...

  6. 深度学习入门——前馈神经网络

    前馈神经网络作为深度学习基础中的基础,是很多同学入门深度学习的必经之路.由于马上要迎来考试复习周,在这里简单记录一下学习心得. 感知机模型 感知机(perceptron)是深度学习中最基本的元素,很多 ...

  7. 【深度学习】前馈神经网络

    一.前馈神经网络 思维导图 线性问题分为两个: 1.与门 IN IN OUT 1 1 1 1 0 0 0 1 0 0 0 0 2.或门 IN IN OUT 1 1 1 1 0 1 0 1 1 0 0 ...

  8. 猿创征文|深度学习基于前馈神经网络完成鸢尾花分类

    大家我是猿童学!这次给大家带来的是基于前馈神经网络完成鸢尾花分类! 在本实验中,我们使用的损失函数为交叉熵损失:优化器为随机梯度下降法:评价指标为准确率. 一.小批量梯度下降法 在梯度下降法中,目标函 ...

  9. 深度学习~模糊神经网络(FNN)

    模糊神经网络(Fuzzy Neural Network, FNN) 背景 系统复杂度的增加,人工智能深度化发展 模糊数学创始人L. A. Zadeh, 1921. 当系统的复杂性增加时,我们使它精确化 ...

  10. PyTorch 深度学习实践 GPU版本B站 刘二大人第11讲卷积神经网络(高级篇)GPU版本

    第11讲 卷积神经网络(高级篇) GPU版本源代码 原理是基于B站 刘二大人 :传送门PyTorch深度学习实践--卷积神经网络(高级篇) 这篇基于博主错错莫:传送门 深度学习实践 第11讲博文 仅在 ...

最新文章

  1. 思科安全——企业安全棋局的“宇宙流”
  2. mysql扩展使用_mysql的扩展应用
  3. JAVA 求数组中的最大值
  4. 前端学习(1153):常量const01
  5. LeetCode 46. 全排列(回溯)
  6. @RequiresPermissionss是否可以填写多种权限标识,只要满足其一就可以访问?
  7. vSAN其实很简单-运维工程师眼里的vSAN
  8. xvidcore.dll not found视频播放问题
  9. Exchange2010安装指南
  10. 加速数据无限超高速空间免费虚拟主机无限大小 支持SSL
  11. 不支持16位应用程序,%1和64位电脑不兼容问题
  12. 腾讯笔试——安排机器 【 题目描述】小 Q 的公司最近接到 m 个任务, 第 i 个任务需要 xi 的时间去完成, 难度等级为 yi。 小 Q 拥有 n 台机器, 每台机器最长工作时间 zi, 机器等
  13. 单片机蓝桥杯——PWM呼吸灯
  14. 2022款联想小新air15和联想小新pro14哪个好
  15. Windows环境下安装RabbitMQ(官方文档中文版)
  16. 字符串操作——substr用法
  17. php mysql mvc_PHP MVC框架【Myphp】的编写
  18. Docker 使用--link出现Cannot link to /xxx, as it does not belong to异常
  19. 计算机网络16进制首部检验和,校验和
  20. 按键精灵调试三天,气到吐血!!快速开发脚本代码常见错误!绝对干货!

热门文章

  1. Qt学习笔记--文件读写(QFile、QDataStream、QTextStream)
  2. 初探:使用Jest进行React单元测试
  3. 哪些方法可以用来提高微信小程序的应用速度?
  4. 电工基础知识-配电室安全须知
  5. ASC文件 - CAN报文回放
  6. HDU 4125 Moles 线段树+KMP
  7. body onload
  8. word2019技巧:段落的段前段后单位行设置为磅
  9. 为何延时函数不起作用?
  10. AD-FMCOMMS3 使用matlab+Linux/No-OS传输QPSK信号