文章目录

  • 原理以及公式
    • 【1】一元线性回归问题
    • 【2】多元线性回归问题
    • 【3】学习率
    • 【4】流程分析(一元线性回归)
    • 【5】流程分析(多元线性回归)
      • 归一化原理以及每种归一化适用的场合
  • 一元线性回归代码以及可视化结果
  • 多元线性回归代码以及可视化结果
  • 总结

原理以及公式

【1】一元线性回归问题

原函数是一元函数(关于x),它的损失函数是二元函数(关于w和b)

这里介绍两种损失函数:平方损失函数和均方差损失函数

【2】多元线性回归问题

X和W都是m+1维的向量,损失函数是高维空间中的凸函数

【3】学习率

学习率属于超参数(超参数:在开始学习之前设置,不是通过训练得到的)
可以选择在迭代次数增加时减少学习率大小.
下图是学习率正常或较小、稍大、过大的迭代图。

【4】流程分析(一元线性回归)

过程分析:

1、加载样本数据x,y
2、设置超参数学习率,迭代次数
3、设置模型参数初值w0, b0
4、训练模型w, b
5、结果可视化

                                                     流程图:

【5】流程分析(多元线性回归)

归一化原理以及每种归一化适用的场合

线性归一化:适用于样本分布均匀且集中的情况,如果最大值(或者最小值)不稳定,和绝大数样本数据相差较大,使用这种方法得到的结果也不稳定.为了抑制这个问题,在实际问题中可以用经验值来代替最大值和最小值
标准差归一化适用于样本近似正态分布,或者最大最小值未知的情况,有时当最大最小值处于孤立点时也可以使用标准差归一化
非线性映射归一化,通常用于数据分化较大的情况(有的很大有的很小)
总结:样本属性归一化需要根据属性样本分布规律定制

过程分析:

加载样本数据area,room,price
数据处理归一化,X,Y
设置超参数学习率,迭代次数
设置模型参数初值W0(w0,w1,w2)
训练模型W
结果可视化

一元线性回归代码以及可视化结果

#解析法实现一元线性回归
# #Realization of one variable linear regression by analytic method
#导入库
import numpy as np
import matplotlib.pyplot as plt
#设置字体
plt.rcParams['font.sans-serif'] =['SimHei']
#加载样本数据
x=np.array([137.97,104.50,100.00,124.32,79.20,99.00,124.00,114.00,106.69,138.05,53.75,46.91,68.00,63.02,81.26,86.21])
y=np.array([145.00,110.00,93.00,116.00,65.32,104.00,118.00,91.00,62.00,133.00,51.00,45.00,78.50,69.65,75.69,95.30])
#设置超参数,学习率
learn_rate=0.00001
#迭代次数
iter=100
#每10次迭代显示一下效果
display_step=10
#设置模型参数初值
np.random.seed(612)
w=np.random.randn()
b=np.random.randn()
#训练模型
#存放每次迭代的损失值
mse=[]
for i in range(0,iter+1):#求偏导dL_dw=np.mean(x*(w*x+b-y))dL_db=np.mean(w*x+b-y)#更新模型参数w=w-learn_rate*dL_dwb=b-learn_rate*dL_db#得到估计值pred=w*x+b#计算损失(均方误差)Loss=np.mean(np.square(y-pred))/2mse.append(Loss)#显示模型#plt.plot(x,pred)if i%display_step==0:print("i:%i,Loss:%f,w:%f,b:%f"%(i,mse[i],w,b))
#模型和数据可视化
plt.figure(figsize=(20,4))
plt.subplot(1,3,1)
#绘制散点图
#张量和数组都可以作为散点函数的输入提供点坐标
plt.scatter(x,y,color="red",label="销售记录")
plt.scatter(x,pred,color="blue",label="梯度下降法")
plt.plot(x,pred,color="blue")#设置坐标轴的标签文字和字号
plt.xlabel("面积(平方米)",fontsize=14)
plt.xlabel("价格(万元)",fontsize=14)#在左上方显示图例
plt.legend(loc="upper left")#损失变化可视化
plt.subplot(1,3,2)
plt.plot(mse)
plt.xlabel("迭代次数",fontsize=14)
plt.ylabel("损失值",fontsize=14)
#估计值与标签值比较可视化
plt.subplot(1,3,3)
plt.plot(y,color="red",marker="o",label="销售记录")
plt.plot(pred,color="blue",marker="o",label="梯度下降法")
plt.legend()
plt.xlabel("sample",fontsize=14)
plt.ylabel("price",fontsize=14)
#显示整个绘图
plt.show()

多元线性回归代码以及可视化结果

#解析法实现多元线性回归
#Realization of multiple linear regression by analytic method
#导入库与模块
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
#=======================【1】加载样本数据===============================================
area=np.array([137.97,104.50,100.00,124.32,79.20,99.00,124.00,114.00,106.69,138.05,53.75,46.91,68.00,63.02,81.26,86.21])
room=np.array([3,2,2,3,1,2,3,2,2,3,1,1,1,1,2,2])
price=np.array([145.00,110.00,93.00,116.00,65.32,104.00,118.00,91.00,62.00,133.00,51.00,45.00,78.50,69.65,75.69,95.30])
num=len(area) #样本数量
#=======================【2】数据处理===============================================
x0=np.ones(num)
#归一化处理,这里使用线性归一化
x1=(area-area.min())/(area.max()-area.min())
x2=(room-room.min())/(room.max()-room.min())
#堆叠属性数组,构造属性矩阵
#从(16,)到(16,3),因为新出现的轴是第二个轴所以axis为1
X=np.stack((x0,x1,x2),axis=1)
print(X)
#得到形状为一列的数组
Y=price.reshape(-1,1)
print(Y)
#=======================【3】设置超参数===============================================
learn_rate=0.001
#迭代次数
iter=500
#每10次迭代显示一下效果
display_step=50
#=======================【4】设置模型参数初始值===============================================
np.random.seed(612)
W=np.random.randn(3,1)
#=======================【4】训练模型=============================================
mse=[]
for i in range(0,iter+1):#求偏导dL_dW=np.matmul(np.transpose(X),np.matmul(X,W)-Y)   #XT(XW-Y)#更新模型参数W=W-learn_rate*dL_dW#得到估计值PRED=np.matmul(X,W)#计算损失(均方误差)Loss=np.mean(np.square(Y-PRED))/2mse.append(Loss)#显示模型#plt.plot(x,pred)if i % display_step==0:print("i:%i,Loss:%f"%(i,mse[i]))
#=======================【5】结果可视化============================================
plt.rcParams['font.sans-serif'] =['SimHei']
plt.figure(figsize=(12,4))
#损失变化可视化
plt.subplot(1,2,1)
plt.plot(mse)
plt.xlabel("迭代次数",fontsize=14)
plt.ylabel("损失值",fontsize=14)
#估计值与标签值比较可视化
plt.subplot(1,2,2)
PRED=PRED.reshape(-1)
plt.plot(price,color="red",marker="o",label="销售记录")
plt.plot(PRED,color="blue",marker="o",label="预测房价")
plt.xlabel("sample",fontsize=14)
plt.ylabel("price",fontsize=14)
plt.legend()
plt.show()

总结

注意点:选择归一化方式


喜欢的话点个赞和关注呗!

利用梯度下降法求解一元线性回归和多元线性回归相关推荐

  1. 【Python】梯度下降法求解一元二次函数的波谷

    import random''' drd notes: 使用梯度下降法 求y=3x^2 + 7x - 10波谷时x的值 '''def my_function(x):# drd notes:y = 3x ...

  2. 梯度下降法求解多元线性回归 — NumPy

    梯度下降法求解多元线性回归问题 使用梯度下降法求解一元线性回归的方法也可以被推广到求解多元线性回归问题. 这是多元线性回归的模型: 其中的 X 和 W 都是 m+1 维的向量. 下图为它的损失函数: ...

  3. 梯度下降法求解线性回归

    梯度下降法求解线性回归 通过梯度下降法求解简单的一元线性回归 分别通过梯度下降算法和sklearn的线性回归模型(即基于最小二乘法)解决简单的一元线性回归实际案例,通过结果对比两个算法的优缺. 通过最 ...

  4. 基于jupyter notebook的python编程-----利用梯度下降算法求解多元线性回归方程,并与最小二乘法求解进行精度对比

    基于jupyter notebook的python编程-----利用梯度下降算法求解多元线性回归方程,并与最小二乘法求解进行精度对比目录 一.梯度下降算法的基本原理 1.梯度下降算法的基本原理 二.题 ...

  5. python 梯度下降法实现一元线性回归

    一.简单过一下算法流程 ''' 梯度下降法实现一元线性回归 一元线性函数: y = ax + b 实际数据服从: y = x + 2 初始模型: y = 0.1*x + 0.1 ''' import ...

  6. 利用梯度下降法实现线性回归--python实现

    利用梯度下降法代替最小二乘法,求线性回归方程. 首先引用库 import numpy as np import matplotlib.pyplot as plt 定义相应的x和y np.random. ...

  7. 机器学习——一元线性回归和多元线性回归

    一元线性回归:梯度下降法 一元线性回归是线性回归的最简单的一种,即只有一个特征变量.首先是梯度下降法,这是比较经典的求法.一元线性回归通俗易懂地说,就是一元一次方程.只不过这里的斜率和截距要通过最小二 ...

  8. matlab重复线性回归,(MATLAB)一元线性回归和多元线性回归

    (MATLAB)一元线性回归和多元线性回归 (MATLAB)一元线性回归和多元线性回归 (MATLAB)一元线性回归和多元线性回归1.一元线性回归 2.多元线性回归2.1数据说明 2.2程序运行结果 ...

  9. 机器学习:回归分析—— 一元线性回归、多元线性回归的简单实现

    回归分析 回归分析概述 基本概念 可以解决的问题 基本步骤和分类 线性回归 一元线性回归 多元线性回归 回归分析概述 基本概念 回归分析是处理多变量间相关关系的一种数学方法.相关关系不同于函数关系,后 ...

最新文章

  1. 2019长安大学ACM校赛网络同步赛 J Binary Number(组合数学+贪心)
  2. 1057 数零壹(PAT乙级 C++实现)
  3. 条件运算符(?:)和 $替代string.Format()
  4. 审计文件服务器的5个核心要素
  5. TCP、UDP、IP 协议分析(转)
  6. python入门指南小说-Python 入门指南
  7. H3C WA2220E-AG 设置本地MAC+PSK认证:mac-and-psk
  8. 如何从 ArcView 3.3 版本的工程迁移到 ArcGIS Desktop 10 ?
  9. 德语翻译-德语在线批量翻译软件
  10. spring 通过yml格式配置log日志
  11. 如何删除双系统中的其中一个(完全删除)
  12. 国产操作系统银河麒麟V10桌面版新手小白常见问题
  13. 基于matlab的动态心形图案
  14. 网站添加Google翻译代码
  15. 文件在另一个程序中打开,无法删除~【删除文件被占用问题】(保姆级教程,五种解决办法~)
  16. 单片机C语言之学习矩阵按键
  17. Ubuntu 16.04 + cuda-8.0 + cudnn-6.0 + Tensorflow1.4和Caffe(极其简单)
  18. mysql limit 01怎么理解_MySQL limit实际用法的详细解析
  19. android 开发蓝牙电子秤,GitHub - xiangbohua/scales-bridge: scales-bridge 电子称 蓝牙电子秤 连接库...
  20. 一把王者的时间就写完了一个nginx的web集群项目

热门文章

  1. 字符集_第07期:有关 MySQL 字符集的 SQL 语句
  2. 学习笔记-AngularJs(十)
  3. 原型 - 实现自己的jQuery
  4. 一步步构建大型网站架构 [转]
  5. Hibernate关联映射(一对多/多对多)
  6. Unity插件Gaia使用介绍
  7. [转]Bing Maps Tile System 学习
  8. 二分图之匈牙利算法模版
  9. gl.vertexAtteib3f P42 讲数据传给location参数指定的attribute变量
  10. (转)求单链表是否有环,环入口和环长