梯度下降法求解多元线性回归问题

使用梯度下降法求解一元线性回归的方法也可以被推广到求解多元线性回归问题。

这是多元线性回归的模型:

其中的 X 和 W 都是 m+1 维的向量。

下图为它的损失函数:

它也是一个高维空间中的凸函数,因此也可以使用梯度下降法来求解。
下图为它的权值更新算法:

代入偏导数,

可以得到最终的迭代公式:

问题描述

依然是房价预测的问题,这是一个二元线性回归问题。

需要注意的是,如果直接使用上图中的数据 x1 和 x2 来训练模型,就会因为面积(x1)值远远大于房间(x2)值而造成在学习过程中占主导,甚至决定性的地位,这显然是不合理的。

那应该怎么解决呢?

这时候应该将各个属性值进行归一化。

归一化

归一化又被称为标准化,是将数据的值限制在一定的范围之内。

在机器学习中,对所有属性进行归一化处理就是让它们处于同一个范围、同一个数量级下。这样才能更加的具有合理性。

使用归一化处理后,不仅可以使得模型更快的收敛到最优解,还可以提高学习器的精度。

归一化可以分为线性归一化、线性归一化、非线性映射归一化。

1、线性归一化

线性归一化是对原始数据的线性变换,转换函数如下:

线性归一化实现对原始数据的 等比例缩放。
归一化之后,所有的数据都会被映射到 [0,1] 之间。

这种归一化方法适合于样本数据分布比较均匀,比较集中的情况,而如果最大值或最小值不稳定,或者和绝大多数数据差距比较大的情况,使用这种方法得到的结果也会不稳定,为了避免这种情况,在实际应用中,可以使用经验常量来代替最大值和最小值。

2、标准差归一化

将数据集归一化为均值为0,方差为1的标准正态分布,转换函数如下:

其中,μ是均值,σ是标准差。
标准差归一化适合于样本近似于正态分布或者最大值和最小值未知的情况,有时最大值和最小值处于孤立点的情况也适用。

3、非线性映射归一化

对原始数据的非线性变换。常用的映射方法有指数、对数和正切等。非线性映射归一化适合于数据分化比较大的情况,也就是有的数据特别大、有的比较小。通过这种非线性映射归一化后,可以使数据变的更加均匀或者有特点。

样本数据的归一化需要根据实际数据的分布情况和特点来决定采用哪种方法。

这里的数据归一化方式选择线性归一化,归一化结果如下:

import numpy as np
import matplotlib.pyplot as plt# 第一步:加载数据
# area 是商品房面积
area = np.array([137.97, 104.50, 100.00, 124.32, 79.20, 99.00, 124.00, 114.00,106.69, 138.05, 53.75, 46.91, 68.00, 63.02, 81.26, 86.21])  # (16, )# room 是商品房房间数
room = np.array([3, 2, 2, 3, 1, 2, 3, 2,2, 3, 1, 1, 1, 1, 2, 2])# 第二步:样本数据归一化 —— 采用线性归一化
# x1 是商品房面积归一化后的结果
x1 = (area - area.min()) / (area.max() - area.min())
# x2 是商品房房间数归一化后的结果
x2 = (room - room.min()) / (room.max() - room.min())print(x1)
"""
[0.99912223 0.63188501 0.58251042 0.84935264 0.3542901  0.571538290.84584156 0.73612025 0.65591398 1.         0.07504937 0.0.23140224 0.17676103 0.37689269 0.43120474]可以看出最大值被归一化为1, 最小值被归一化为 0 。
"""
print(x2)
"""
[1.  0.5 0.5 1.  0.  0.5 1.  0.5 0.5 1.  0.  0.  0.  0.  0.5 0.5]
"""

代码实现

第一步:加载样本数据集,area,room,price

第二步:数据处理 — 样本数据归一化,X,Y

第三步:设置超参数 学习率,迭代次数

第四步:设置模型参数初值 W0 (w0、w1、w2)

第五步:训练模型 W
这里的训练模型的公式如下:

第六步:结果可视化

import numpy as np
import matplotlib.pyplot as plt# 设置字体
plt.rcParams['font.sans-serif'] = ['SimHei']# 第一步:加载数据
# area 是商品房面积
area = np.array([137.97, 104.50, 100.00, 124.32, 79.20, 99.00, 124.00, 114.00,106.69, 138.05, 53.75, 46.91, 68.00, 63.02, 81.26, 86.21])  # (16, )# room 是商品房房间数
room = np.array([3, 2, 2, 3, 1, 2, 3, 2,2, 3, 1, 1, 1, 1, 2, 2])# price 是样本房价
price = np.array([145.00, 110.00, 93.00, 116.00, 65.32, 104.00, 118.00, 91.00,62.00, 133.00, 51.00, 45.00, 78.50, 69.65, 75.69, 95.30])# 第二步:数据处理
num = len(area)# 创建元素值全为1的一维数组 x0
x0 = np.ones(num)
# x1 是商品房面积归一化后的结果
x1 = (area - area.min()) / (area.max() - area.min())
# x2 是商品房房间数归一化后的结果
x2 = (room - room.min()) / (room.max() - room.min())# 将 x0、x1、x2堆叠为形状为 (16, 3) 的二维数组
X = np.stack((x0, x1, x2), axis=1)# 将 price 转换为形状为 (16, 1) 的二维数组
Y = price.reshape(-1, 1)# 第三步:设置超参数 学习率,迭代次数
learn_rate = 0.0001
itar = 1000000  # 迭代次数为1000000次display_step = 50000  # 每循环50000次显示一次训练结果# 第四步:设置模型参数的初始值
np.random.seed(612)
W = np.random.randn(3, 1)# 第五步:训练模型 W
mse = []  # 这是个Python列表, 用来保存每次迭代后的损失值# 下面使用 for 循环来实现迭代
# 循环变量从 0 开始, 到 101 结束,循环 101 次, 为了描述方便, 以后就说迭代 100 次
# 同样, 当 i 等于 10 时, 我们就说第十次迭代
for i in range(0, itar + 1):# 首先计算损失函数对 W 的偏导数dL_dW = np.matmul(np.transpose(X), np.matmul(X, W)-Y)# 然后使用迭代公式更新 WW = W - learn_rate*dL_dW# 我们希望能够观察到每次迭代的结果, 判断是否收敛或者什么时候开始收敛# 因此需要使用每次迭代后的 W 来计算损失, 并且把它显示出来# 这里的 X 形状为 (16, 3), W 形状为 (3, 1), 得到 Y_PRED 的形状为 (16, 1)Y_PRED = np.matmul(X, W)  # 使用当前这次循环得到的W, 计算所有样本的房价的估计值Loss = np.mean(np.square(Y - Y_PRED)) / 2  # 使用房价的估计值和实际值计算均方误差mse.append(Loss)  # 把得到的均方误差加入列表 mseif i % display_step == 0:print("i:%i, Loss:%f" % (i, mse[i]))"""i:0, Loss:4368.213908i:500000, Loss:79.871073i:1000000, Loss:79.871073"""
print(W)
"""
[[51.39029673]
[48.74950958]
[28.66300756]]
"""# 第六步:样本数据可视化# 创建Figure对象
plt.figure(figsize=(10, 6))plt.subplot(1, 2, 1)
plt.plot(range(0, 5000), mse[0:5000])
plt.xlabel('Iteration', color='r', fontsize=14)
plt.ylabel('Loss', color='r', fontsize=14)
plt.title("前5000次迭代的损失值变化曲线图", fontsize=14)plt.subplot(1, 2, 2)
Y_PRED = Y_PRED.reshape(-1)
plt.plot(price, color="red", marker='o', label="销售记录")
plt.plot(Y_PRED, color="blue", marker='.', label="预测房价")
plt.xlabel('Sample', color='r', fontsize=14)
plt.ylabel('Price', color='r', fontsize=14)
plt.title("估计值 & 标签值", fontsize=14)
plt.legend(loc="upper right")plt.suptitle("梯度下降法求解多元线性回归", fontsize=18)# 将创建好的图像显示出来
plt.show()

运行结果如下:

梯度下降法求解多元线性回归 — NumPy相关推荐

  1. 利用梯度下降法求解一元线性回归和多元线性回归

    文章目录 原理以及公式 [1]一元线性回归问题 [2]多元线性回归问题 [3]学习率 [4]流程分析(一元线性回归) [5]流程分析(多元线性回归) 归一化原理以及每种归一化适用的场合 一元线性回归代 ...

  2. 梯度下降法求多元线性回归及Java实现

    为什么80%的码农都做不了架构师?>>>    对于数据分析而言,我们总是极力找数学模型来描述数据发生的规律, 有的数据我们在二维空间就可以描述,有的数据则需要映射到更高维的空间.数 ...

  3. excel计算二元线性回归_用人话讲明白梯度下降Gradient Descent(以求解多元线性回归参数为例)...

    文章目录 1.梯度 2.多元线性回归参数求解 3.梯度下降 4.梯度下降法求解多元线性回归 梯度下降算法在机器学习中出现频率特别高,是非常常用的优化算法. 本文借多元线性回归,用人话解释清楚梯度下降的 ...

  4. 梯度下降法求解线性回归

    梯度下降法求解线性回归 通过梯度下降法求解简单的一元线性回归 分别通过梯度下降算法和sklearn的线性回归模型(即基于最小二乘法)解决简单的一元线性回归实际案例,通过结果对比两个算法的优缺. 通过最 ...

  5. 基于jupyter notebook的python编程-----利用梯度下降算法求解多元线性回归方程,并与最小二乘法求解进行精度对比

    基于jupyter notebook的python编程-----利用梯度下降算法求解多元线性回归方程,并与最小二乘法求解进行精度对比目录 一.梯度下降算法的基本原理 1.梯度下降算法的基本原理 二.题 ...

  6. TensorFlow实现梯度下降法求解一元和多元线性回归问题

    使用TensorFlow求解一元线性回归问题 import tensorflow as tf import numpy as np import matplotlib.pyplot as plt# 设 ...

  7. python多元线性回归代码_Python实现梯度下降算法求多元线性回归(一)

    预备知识及相关文档博客 学习吴恩达机器学习课程笔记,并用python实现算法 python numpy基本教程: numpy相关教程 数据来自于UCI的机器学习数据库: UCI的机器学习数据库 pyt ...

  8. python 梯度下降法实现一元线性回归

    一.简单过一下算法流程 ''' 梯度下降法实现一元线性回归 一元线性函数: y = ax + b 实际数据服从: y = x + 2 初始模型: y = 0.1*x + 0.1 ''' import ...

  9. 深度学习基础之-2.2用梯度下降法求解w,b

    用梯度下降法求解w,b. 预设函数 Hypothesis Function z=wx+bz = wx+bz=wx+b 损失函数 Loss Function J(w,b)=12(z−y)2J(w,b) ...

最新文章

  1. 自动驾驶的实现之路——几大关键传感器应用解析
  2. IOS开发笔记16-Object-C中的属性
  3. java.net.SocketException四大异常解决方案---转
  4. JDK1.8新特性:Stream流
  5. 深入浅出jQuery (五) 如何自定义UI-Dialog?
  6. bootstrap 单选按钮点击change事件 只触发一次_微信支付新增“确认”按钮,付错钱将成为历史?...
  7. 马尔可夫蒙特卡罗 MCMC 原理及经典实现
  8. (原创) JavaScript是什么?
  9. 删除mysql主键语句_MySQL主键添加/删除
  10. 数据中心运维管理经验39条
  11. 关于在Ubuntu安装JLink驱动的最简便方法
  12. C#/.NET 通过代码一键清理IE缓存文件/强制重置IE设置
  13. 年度双十佳广告爆笑金庸版
  14. 程序员叫啥名字_网友:什么是好程序员?腾讯员工:首先起个“配”自己的网名!...
  15. 树莓派3+安装centos
  16. 苹果审核状态为Metadata Rejected下的问题
  17. 服务器未能保存文件夹,Exchange服务器提示 Event ID 50 Ntfs (Ntfs) {延迟写入失败} Windows 无法保存文件...
  18. unity3d英语单词拼写小游戏Pics Quiz Maker With Categories 3.0
  19. 【DQN高级技巧2】DQN高估问题:Target Network和Double DQN
  20. 【我的架构之路】什么是代理服务器以及什么是负载均衡?

热门文章

  1. 201111621401-白乐乐-思维导图
  2. smartscreen筛选器阻止了这个不安全的下载
  3. 盘点:App 移动自动化测试工具
  4. Pytorch以及tensorflow中KLdivergence的计算
  5. 华为云计算相关知识点
  6. 超市信息管理系统的测试用例
  7. 生产环境下nginx代理跨域解决
  8. 数据库的多表查询操作-查询只选修了1门课程的学生,显示学号、姓名、课程名。
  9. Python机器学习日记4:监督学习算法的一些样本数据集(持续更新)
  10. Win7 64的cmd控制台进入下级目录和返回上级目录(上级目录cd .. 下级目录cd+文件名称)