问题描述
如下二维特征,每个样本属于正样本(红色)或负样本(蓝色),实现二分类模型

单层神经网络
[w1w2][x1x2]+[b]=[y]\left[\begin{matrix} w_1 & w_2 \end{matrix}\right] \left[ \begin{matrix} x_1\\ x_2 \end{matrix}\right] + \left[ \begin{matrix} b \end{matrix}\right] = \left[ \begin{matrix} y \end{matrix}\right][w1​​w2​​][x1​x2​​]+[b​]=[y​]

class LogisticRegression(nn.Module):def __init__(self):super(LogisticRegression, self).__init__()self.lr = nn.Linear(2, 1)self.sm = nn.Sigmoid()def forward(self, x):out = self.lr(x)out = self.sm(out)return out

我们想查看模型的训练效果,需要把LinearLinearLinear层可视化
w1∗x1+w2∗x2+b=yw_1*x_1+w_2*x_2+b=yw1​∗x1​+w2​∗x2​+b=y
y=0y=0y=0时,x1x_1x1​和x2x_2x2​的关系如下:
x2=−w1∗x1−bw2x_2=\frac{-w_1*x_1-b}{w2}x2​=w2−w1​∗x1​−b​
取x1x_1x1​在[30,100][30,100][30,100]之间,使用神经网络LinearLinearLinear层的参数将x2x_2x2​的值求出,画出训练结果。

def vis_one_layer(logistic_model):w1, w2 = logistic_model.lr.weight.data.numpy()[0]b = logistic_model.lr.bias.data.numpy()[0]plot_x = np.arange(30, 100, 0.1)plot_y = (-w1 * plot_x - b) / w2plt.plot(plot_x, plot_y)plt.show()

多层神经网络
如果神经网络的层数不止一层,无法代入求得x1x_1x1​和x2x_2x2​的关系,应该如何可视化神经网络的预测结果?

class MyLogistic(nn.Module):def __init__(self, input_size):super().__init__()self.hidden_1 = nn.Linear(input_size, 64)self.hidden_2 = nn.Linear(64, 32)self.hidden_3 = nn.Linear(32, 16)self.output = nn.Linear(16, 1)self.relu_1 = nn.ReLU()self.relu_2 = nn.ReLU()self.relu_3 = nn.ReLU()self.sigmoid = nn.Sigmoid()def forward(self, x):x = self.hidden_1(x)x = self.relu_1(x)x = self.hidden_2(x)x = self.relu_2(x)x = self.hidden_3(x)x = self.relu_3(x)x = self.output(x)x = self.sigmoid(x)return x

思路还是通过一系列点确定神经网络划分二分类问题的区域,只不过因为预测的结果不是线性的,所以要对区域内的点使用模型预测正负性,然后用不同颜色可视化,使用plt.contourfplt.contourfplt.contourf绘制轮廓线并填充。

def vis_result(x_data, y_data, model):x_min, x_max = x_data[:, 0].min() - 1, x_data[:, 0].max() + 1y_min, y_max = x_data[:, 1].min() - 1, x_data[:, 1].max() + 1print(x_min, x_max, y_min, y_max)h = 0.1xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))z = model(torch.from_numpy(np.c_[xx.ravel(), yy.ravel()]).float())z = z.reshape(xx.shape).detach().numpy()print(z)z = [z[i] >= 0.5 for i in range(len(z))]z = np.array(z)plt.contourf(xx, yy, z, alpha=0.3)for i in range(len(y_data)):if y_data[i] == 1:plt.scatter(x_data[i][0], x_data[i][1], c='r')else:plt.scatter(x_data[i][0], x_data[i][1], c='b')plt.show()

二维特征逻辑回归预测结果可视化相关推荐

  1. OpenCV之feature2d 模块. 2D特征框架(2)特征描述 使用FLANN进行特征点匹配 使用二维特征点(Features2D)和单映射(Homography)寻找已知物体 平面物体检测

    特征描述 目标 在本教程中,我们将涉及: 使用 DescriptorExtractor 接口来寻找关键点对应的特征向量. 特别地: 使用 SurfDescriptorExtractor 以及它的函数  ...

  2. 二维特征分类的基础_3D 分割分类总结

    三维深度学习的几种方法: 多视角(multi-view):通过多视角二维图片组合为三维物体,此方法将传统CNN应用于多张二维视角的图片,特征被view pooling procedure聚合起来形成三 ...

  3. 二维特征分类的基础_带你搞懂朴素贝叶斯分类算法

    贝叶斯分类是一类分类算法的总称,这类算法均以贝叶斯定理为基础,故统称为贝叶斯分类.而朴素朴素贝叶斯分类是贝叶斯分类中最简单,也是常见的一种分类方法.这篇文章我尽可能用直白的话语总结一下我们学习会上讲到 ...

  4. 二维特征分类的基础_纹理特征1:灰度共生矩阵(GLCM)

    GLCM复习备用: 纹理分析是对图像灰度(浓淡)空间分布模式的提取和分析.纹理分析在遥感图像.X射线照片.细胞图像判读和处理方面有广泛的应用.关于纹理,还没有一个统一的数学模型.它起源于表征纺织品表面 ...

  5. 逻辑回归预测瘀血阻络证||LogRegression 二分类 python3|五折交叉验证

    要求 把数据集分为训练集和测试集使用逻辑回归训练.预测,得出相应的分类指标准确率accuracy,精确率precision,召回率recall,F1-score,并画出最终的ROC曲线,得出AUC值. ...

  6. 吴恩达机器学习(二十六) 数据压缩与可视化、PCA

    文章目录 1.数据压缩 2.数据可视化 3.PCA 1.数据压缩   降维也是一种无监督学习的方法,降维并不需要使用数据的标签.   降维的其中一个目的是数据压缩,数据压缩不仅能够压缩数据,使用较少的 ...

  7. R语言逻辑回归预测分析付费用户

    原文链接:http://tecdat.cn/?p=967 对于某企业新用户,会利用大数据来分析该用户的信息来确定是否为付费用户,弄清楚用户属性,从而针对性的进行营销,提高运营人员的办事效率(点击文末& ...

  8. Kaggle泰坦尼克号船难--逻辑回归预测生存率

    Kaggle泰坦尼克号船难–逻辑回归预测生存率#一.题目 https://www.kaggle.com/c/titanic 二.题意分析 train.csv中有891条泰坦尼克号乘客的数据,包括这些乘 ...

  9. 使用逻辑回归预测用户是否会购买SUV

    往期推荐 机器学习100天学习计划 - 第1天 数据预处理 机器学习100天学习计划 - 第2天 线性回归 机器学习100天学习计划 - 第3天 多元线性回归 这是机器学习100天学习计划的第4天,我 ...

最新文章

  1. 继承和多态 1.0 -- 继承概念(is-a、has-a,赋值兼容规则,隐藏重定义)
  2. 分布式事务篇——第二章:分布式事务解决之2PC剖析
  3. 最近看Kafka源码,着实被它的客户端缓冲池技术优雅到了
  4. 工作区 暂存区 版本库之间的关系
  5. Hystrix能解决的问题
  6. 情人节:找一个程序员当老公的10大好处
  7. 2015年,我们一起经历的IT安全事件
  8. java asm 全称,java ASM
  9. XMPP即时通讯机制
  10. CS188-Project 4
  11. AOV网与拓扑排序、拓扑排序算法
  12. 【转】加班与加薪的秘密:一位华为工程师的经验分享
  13. 空手套白狼案例,18个月零成本开了 3 家健身房,分红400多万!
  14. 【mysql数据导入】数据导入时的几种方法
  15. 从未改过的网名,一如既往的孤荷凌寒——我的信息技术之路之五
  16. 网易互娱2017实习生招聘游戏研发工程师在线笔试第二场(神奇的数)
  17. RedHat(RHEL)6.2 X64 Oracle11g X64 安装参考文档
  18. mobi怎么在iphone上打开?
  19. html video 控件,HTML video controls 属性
  20. 杨辉三角c语言杭电,杭电 杨辉三角

热门文章

  1. 第四章——绕翼型的不可压缩流动
  2. 关于安卓Facebook接入时的坑
  3. Arduino 开发入门 学习笔记 Arduino编程基础
  4. 虚拟机如何设置静态IP
  5. 微PE装Win10详细教程:UEFI+GPT方式
  6. android 截屏函数_android截屏功能实现代码
  7. 懒汉模式在多线程中的问题
  8. 定时播放音乐程序之三:MCI设备的播放和控制
  9. python 典型变量分析
  10. 运算放大器的datasheet参数介绍