在做吴恩达老师的深度学习课程作业时,发现决策边界函数不好理解plot_decision_boundary(model , X , y)。将此函数理解记录下:
作业地址:https://blog.csdn.net/u013733326/article/details/79702148
绘制梯度下降算法图形或是决策边界,核心便在于知道plt.contourf函数的用法

plt.contourf函数

这里参考https://blog.csdn.net/qq_44669578/article/details/103348076?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-8.nonecase&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-8.nonecase

plt.contourf用来画出不同分类的边界线,也常常用来绘制等高线
1.生成数据点

x = np.arange(-5, 5, 1)
y = np.arange(0, 20, 2)
xx, yy = np.meshgrid(x, y)


2.对不同类的数据进行标记,即生成Z

z = np.square(xx) - yy > 0


3.生成边界图

plt.contourf(xx, yy, z, cmap=plt.cm.Spectral)
plt.scatter(xx, yy, c=z)
plt.show()


如果点设置得更密集一点,分界处会更光滑,可以看到完整的抛物线

完整代码

x = np.arange(-5, 5, 0.1)
y = np.arange(0, 20, 0.2)
xx, yy = np.meshgrid(x, y)z = np.square(xx) - yy > 0
#plt.cm.Spectral,在这的意思就是颜色会随Z的值变化
plt.contourf(xx, yy, z, cmap=plt.cm.Spectral)
plt.scatter(xx, yy, c=z)
plt.show()

2.plot_decision_boundary()函数理解

先附上完整代码

def plot_decision_boundary(model, X, y):x_min, x_max = X[0, :].min() - 1, X[0, :].max() + 1y_min, y_max = X[1, :].min() - 1, X[1, :].max() + 1h = 0.01xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))Z = model(np.c_[xx.ravel(), yy.ravel()]) Z = Z.reshape(xx.shape)plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)plt.ylabel('x2')plt.xlabel('x1')plt.scatter(X[0, :], X[1, :], c=np.squeeze(y), cmap=plt.cm.Spectral)

效果图

调用训练好的模型

绘制此决策边界的思路是:首先已经通过了神经网络拟合出了输入特征和标签的函数关系,然后生成间距很小的网格覆盖这些点(函数中用h表示网格点之间的距离),将网格的坐标送入训练好的神经网络,神经网络会为每个网格坐标输出一个预测值,注意,这个网格点非常多,占满了整张图,图中的红色和蓝色其实就是经过预测后的网格点所填充的颜色,以预测值0.5进行划分,预测值为0.5的这条线也即为红点和蓝点的区分线。

首先按照样本大小,以0.01的间距生成网格

 x_min, x_max = X[0, :].min() - 1, X[0, :].max() + 1y_min, y_max = X[1, :].min() - 1, X[1, :].max() + 1h = 0.01
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))

然后用训练好的模型对网格上的所有的点进行预测,返回一个只含0/1的矩阵
然后改变矩阵Z的形状,因为在contourf函数中,当 X,Y,Z 都是 2 维数组时,它们的形状必须相同。

Z = model(np.c_[xx.ravel(), yy.ravel()])  #ravel()函数将多维降至一维,默认行优先
Z = Z.reshape(xx.shape)

然后画图,plt.cm.Spectral,在这的意思就是颜色会随Z的值变化

plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
plt.ylabel('x2')
plt.xlabel('x1')
plt.scatter(X[0, :], X[1, :], c=np.squeeze(y), cmap=plt.cm.Spectral)

决策边界绘制函数plot_decision_boundary()和plt.contourf函数详解相关推荐

  1. 决策边界绘制和plt.contourf函数讲解

    先讲解plt.contourf函数,然后用plt.contourf绘制决策边界 contourf contourf(*args, data=None, **kwargs) Plot contours. ...

  2. 【机器学习】逻辑回归案例二:鸢尾花数据分类,决策边界绘制逐步代码讲解

    逻辑回归案例二:鸢尾花数据分类,决策边界绘制逐步代码讲解 1 数据加载 2 数据EDA 3 模型创建及应用 3.1 数据切分 3.2 创建模型与分类 3.3 决策边界绘制 3.3.1 二分类决策边界绘 ...

  3. Matplot pyplot绘制单图,多子图不同样式详解,这一篇就够了

    Matplot pyplot绘制单图,多子图不同样式详解,这一篇就够了 1. 单图单线 2. 单图多线不同样式(红色圆圈.蓝色实线.绿色三角等) 3. 使用关键字字符串绘图(data 可指定依赖值为: ...

  4. python箱线图_Python 箱线图 plt.boxplot() 参数详解

    Python 绘制箱线图主要用 matplotlib 库里 pyplot 模块里的 boxplot() 函数. plt.boxplot() 参数详解 plt.boxplot(x, # 指定要绘制箱线图 ...

  5. python画三维平面-Python 绘制酷炫的三维图步骤详解

    通常我们用 Python 绘制的都是二维平面图,但有时也需要绘制三维场景图,比如像下面这样的: 这些图怎么做出来呢?今天就来分享下如何一步步绘制出三维矢量(SVG)图. 八面体 我们先以下面这个八面体 ...

  6. python画3d图-Python 绘制酷炫的三维图步骤详解

    通常我们用 Python 绘制的都是二维平面图,但有时也需要绘制三维场景图,比如像下面这样的: 这些图怎么做出来呢?今天就来分享下如何一步步绘制出三维矢量(SVG)图. 八面体 我们先以下面这个八面体 ...

  7. 【python教程入门学习】Python函数定义及传参方式详解(4种)

    这篇文章主要介绍了Python函数定义及传参方式详解(4种),文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧 一.函数初识 1.定 ...

  8. 定时器 槽函数没执行_Web服务器项目详解 07 定时器处理非活动连接(上)

    点击"两猿社" 关注我们 Web服务器详解目录 00 项目概述 01 线程同步机制包装类 02 半同步/半反应堆线程池(上) 03 半同步/半反应堆线程池(下) 04 http连接 ...

  9. php 查找键名,array_key_exists()函数搜索数组键名步骤详解

    这次给大家带来array_key_exists()函数搜索数组键名步骤详解,array_key_exists()函数搜索数组键名的注意事项有哪些,下面就是实战案例,一起来看一下. array_key_ ...

最新文章

  1. 信息服务器已停止工作,游戏服务器已停止工作
  2. 《编程原本 》一2.1 变换
  3. Linux之系统文件管理
  4. Java 接口和抽象类的区别
  5. 华为云WeLink:智能工作空间,联接无限想象
  6. linux 审计工具auditd日志audit.log时间戳转换查看
  7. golang使用go-sql-driver实现mysql增删改操作
  8. class不生效 weblogic_weblogic部署常见问题
  9. 2021年低压电工考试试卷及低压电工作业模拟考试
  10. 高斯过程分类和高斯过程回归_高斯过程回归建模入门
  11. Optitrack光学动作捕捉
  12. 计算机网络冗余码计算
  13. PS 使用画笔修复工具去除文字
  14. 小牛的net程序开发之路
  15. Android进程系列1---进程基础
  16. 交叉报表制作--Smartbi报表工具一步完成
  17. 面向数据中心,浪潮存储双剑出鞘
  18. 如何实现ps的批量处理图片
  19. Python财务分析
  20. 【求助】winfrom怎么获取视频当前播放时间

热门文章

  1. 深圳各区对企业制定行业标准和国家标准的补贴,奖励5-200万
  2. java 读取pdf签名域_Java 获取PDF中的数字签名信息
  3. Arm中国开工礼:iPhone + AirPods Pro,我酸了!
  4. 十六进制转八进制(C语言版)
  5. 认识和选用常用的几种 GPRS 模块(转)
  6. 被孙杨遮挡LOGO的安踏,到底做错了什么?
  7. AdaNet: Adaptive Structural Learning of Artificial Neural Networks
  8. Android--ImageView读取本地路径图片
  9. 在windows中要使用计算机进行高级,2017年电大计算机上机操作题(带答案)
  10. 华为p40pro android11,90Hz的华为P40Pro用了半年?最流畅的安卓旗舰?