4.1 多层感知机(分类)

这篇文章开始就是深度学习了。多层感知机的架构是这样:

输入层除了提供数据之外,不干任何事情。隐层和输出层的每个节点都计算一次线性变换,并应用非线性激活函数。隐层的激活函数是压缩性质的函数。输出层的激活函数取决于标签的取值范围。

其本质上相当于广义线性回归模型的集成。

操作步骤

导入所需的包。

import tensorflow as tf
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import sklearn.datasets as ds
import sklearn.model_selection as ms

导入数据,并进行预处理。我们使用鸢尾花数据集所有样本。根据萼片长度和花瓣长度预测样本是不是杂色鸢尾(第二种)。要注意杂色鸢尾在另外两种之间,所以它不是线性问题。

iris = ds.load_iris()x_ = iris.data[:, [0, 2]]
y_ = (iris.target == 1).astype(int)
y_ = np.expand_dims(y_ , 1)x_train, x_test, y_train, y_test = \ms.train_test_split(x_, y_, train_size=0.7, test_size=0.3)

定义超参数。

变量 含义
n_input 样本特征数
n_epoch 迭代数
n_hidden1 隐层 1 的单元数
n_hidden2 隐层 2 的单元数
lr 学习率
threshold 如果输出超过这个概率,将样本判定为正样本
n_input = 2
n_hidden1 = 4
n_hidden2 = 4
n_epoch = 2000
lr = 0.05
threshold = 0.5

搭建模型。要注意隐层的激活函数使用了目前暂时最优的 ELU。由于这个是二分类问题,输出层激活函数只能是 Sigmoid。

变量 含义
x 输入
y 真实标签
w_l{1,2,3} {1,2,3}层的权重
b_l{1,2,3} {1,2,3}层的偏置
z_l{1,2,3} {1,2,3}层的中间变量,前一层输出的线性变换
a_l{1,2,3} {1,2,3}层的输出,其中a_l3样本是正样本的概率
x = tf.placeholder(tf.float64, [None, n_input])
y = tf.placeholder(tf.float64, [None, 1])
w_l1 = tf.Variable(np.random.rand(n_input, n_hidden1))
b_l1 = tf.Variable(np.random.rand(1, n_hidden1))
w_l2 = tf.Variable(np.random.rand(n_hidden1, n_hidden2))
b_l2 = tf.Variable(np.random.rand(1, n_hidden2))
w_l3 = tf.Variable(np.random.rand(n_hidden2, 1))
b_l3 = tf.Variable(np.random.rand(1, 1))
z_l1 = x @ w_l1 + b_l1
a_l1 = tf.nn.elu(z_l1)
z_l2 = a_l1 @ w_l2 + b_l2
a_l2 = tf.nn.elu(z_l2)
z_l3 = a_l2 @ w_l3 + b_l3
a_l3 = tf.sigmoid(z_l3)

定义交叉熵损失、优化操作、和准确率度量指标。

变量 含义
loss 损失
op 优化操作
y_hat 标签的预测值
acc 准确率
loss = - tf.reduce_mean(y * tf.log(a_l3) + (1 - y) * tf.log(1 - a_l3))
op = tf.train.AdamOptimizer(lr).minimize(loss)y_hat = tf.to_double(a_l3 > threshold)
acc = tf.reduce_mean(tf.to_double(tf.equal(y_hat, y)))

使用训练集训练模型。

losses = []
accs = []with tf.Session() as sess:sess.run(tf.global_variables_initializer())for e in range(n_epoch):_, loss_ = sess.run([op, loss], feed_dict={x: x_train, y: y_train})losses.append(loss_)

使用测试集计算准确率。

        acc_ = sess.run(acc, feed_dict={x: x_test, y: y_test})accs.append(acc_)

每一百步打印损失和度量值。

        if e % 100 == 0:print(f'epoch: {e}, loss: {loss_}, acc: {acc_}')

得到决策边界:

    x_plt = x_[:, 0]y_plt = x_[:, 1]c_plt = y_.ravel()x_min = x_plt.min() - 1x_max = x_plt.max() + 1y_min = y_plt.min() - 1y_max = y_plt.max() + 1x_rng = np.arange(x_min, x_max, 0.05)y_rng = np.arange(y_min, y_max, 0.05)x_rng, y_rng = np.meshgrid(x_rng, y_rng)model_input = np.asarray([x_rng.ravel(), y_rng.ravel()]).Tmodel_output = sess.run(y_hat, feed_dict={x: model_input}).astype(int)c_rng = model_output.reshape(x_rng.shape)

输出:

epoch: 0, loss: 8.951598255929909, acc: 0.28888888888888886
epoch: 100, loss: 0.5002945631529941, acc: 0.7333333333333333
epoch: 200, loss: 0.10712651780120697, acc: 0.9333333333333333
epoch: 300, loss: 0.08321807852608396, acc: 0.9333333333333333
epoch: 400, loss: 0.08013835031876741, acc: 0.9333333333333333
epoch: 500, loss: 0.07905186419367002, acc: 0.9333333333333333
epoch: 600, loss: 0.07850865683940819, acc: 0.9333333333333333
epoch: 700, loss: 0.07808251016428093, acc: 0.9333333333333333
epoch: 800, loss: 0.07780712763974691, acc: 0.9333333333333333
epoch: 900, loss: 0.07759866398922599, acc: 0.9333333333333333
epoch: 1000, loss: 0.07744327666591566, acc: 0.9333333333333333
epoch: 1100, loss: 0.07731295774932465, acc: 0.9333333333333333
epoch: 1200, loss: 0.07721162022836371, acc: 0.9333333333333333
epoch: 1300, loss: 0.07712807776857629, acc: 0.9333333333333333
epoch: 1400, loss: 0.07735547120278226, acc: 0.9333333333333333
epoch: 1500, loss: 0.07700215794853897, acc: 0.9333333333333333
epoch: 1600, loss: 0.07695230759382654, acc: 0.9333333333333333
epoch: 1700, loss: 0.07690933782097598, acc: 0.9333333333333333
epoch: 1800, loss: 0.07687191279304387, acc: 0.9333333333333333
epoch: 1900, loss: 0.07683911419647445, acc: 0.9333333333333333

绘制整个数据集以及决策边界。

plt.figure()
cmap = mpl.colors.ListedColormap(['r', 'b'])
plt.scatter(x_plt, y_plt, c=c_plt, cmap=cmap)
plt.contourf(x_rng, y_rng, c_rng, alpha=0.2, linewidth=5, cmap=cmap)
plt.title('Data and Model')
plt.xlabel('Petal Length (cm)')
plt.ylabel('Sepal Length (cm)')
plt.show()

绘制训练集上的损失。

plt.figure()
plt.plot(losses)
plt.title('Loss on Training Set')
plt.xlabel('#epoch')
plt.ylabel('Cross Entropy')
plt.show()

绘制测试集上的准确率。

plt.figure()
plt.plot(accs)
plt.title('Accurary on Testing Set')
plt.xlabel('#epoch')
plt.ylabel('Accurary')
plt.show()

扩展阅读

  • DeepLearningAI 笔记:浅层神经网络
  • DeepLearningAI 笔记:深层神经网络

TensorFlow HOWTO 4.1 多层感知机(分类)相关推荐

  1. TensorFlow HOWTO 4.2 多层感知机回归(时间序列)

    4.2 多层感知机回归(时间序列) 这篇教程中,我们使用多层感知机来预测时间序列,这是回归问题. 操作步骤 导入所需的包. import tensorflow as tf import numpy a ...

  2. 多层感知机 分类思想_感知和行动是产品设计需要意识到的思想形式

    多层感知机 分类思想 I spend my working days at a company that builds a social media management platform for c ...

  3. 深度学习softmax与多层感知机分类模型

    softmax 简单的分类问题 一个简单的图像分类问题,输入图像的高和宽均为2像素,色彩为灰度.图像中的4像素分别记为x1,x2,x3,x4x_1,x_2,x_3,x_4x1​,x2​,x3​,x4​ ...

  4. TensorFlow HOWTO 2.1 支持向量分类(软间隔)

    在传统机器学习方法,支持向量机算是比较厉害的方法,但是计算过程非常复杂.软间隔支持向量机通过减弱了其约束,使计算变得简单. 操作步骤 导入所需的包. import tensorflow as tf i ...

  5. TensorFlow HOWTO 2.3 支持向量分类(高斯核)

    遇到非线性可分的数据集时,我们需要使用核方法,但为了使用核方法,我们需要返回到拉格朗日对偶的推导过程,不能简单地使用 Hinge 损失. 操作步骤 导入所需的包. import tensorflow ...

  6. TensorFlow实现多层感知机MINIST分类

    TensorFlow实现多层感知机MINIST分类 TensorFlow 支持自动求导,可以使用 TensorFlow 优化器来计算和使用梯度.使用梯度自动更新用变量定义的张量.本文将使用 Tenso ...

  7. TensorFlow多层感知机实现MINIST分类

    import tensorflow as tf import tensorflow.contrib.layers as layers from tensorflow.python import deb ...

  8. 基于Tensorflow实现多层感知机网络MLPs

    正文共1232张图,1张图,预计阅读时间7分钟. github:https://github.com/sladesha/deep_learning 之前在基于Tensorflow的神经网络解决用户流失 ...

  9. TensorFlow实现多层感知机

    一.感知机的简介 在前面我们实现了一个softmax regression,也可以说是一个多分类问题的logistic regression.它和传统意义上的神经网络最大的区别就是没有隐藏层.在一个神 ...

最新文章

  1. ORACLE如何使用DBMS_METADATA.GET_DDL获取DDL语句
  2. iOS 9音频应用播放音频之iOS 9音频播放进度
  3. html友情链接效果代码,HTML友情链接代码
  4. x264 n-th pass编码时候Stats文件的含义
  5. EasyUI_datagrid
  6. 博士当中学老师是“人才浪费”?
  7. php网站渗透实战_【案例分析】记一次综合靶场实战渗透
  8. matlab在同一窗口中画多个三维图像
  9. 数字图像处理 采样定理_数字图像处理(第4版)
  10. 构建java ut运行环境
  11. Guava之Joiner笔记
  12. python爬虫之QQ空间登陆获取信息(超级详细)
  13. switch 语句 -- 超详解
  14. HHUOJ 1887 班级聚会上的游戏
  15. 太原理工大于丹计算机,太原理工大学硕士生将参加中国第30次南极考察
  16. JAVA开发离线语音识别
  17. [INSTALL_FAILED_DUPLICATE_PERMISSION perm=quicksdk_packageName.permission.JPUSH_MESSAGE pkg=com.shou
  18. snap7-c++/MFC开发笔记
  19. 二、web通信知识拓
  20. 关于badboy录制脚本时无法打开网页的一些办法

热门文章

  1. (45)FPGA面试题格雷码特点及其应用
  2. python做单元测试_如何使用python做单元测试?
  3. memset() 初始化类对象
  4. 5009. tinyfsm有限状态机
  5. jenkins构建后脚本不执行_接口管理工具ApiPost-预(后)执行脚本常用方法集合
  6. 动态加载子节点_简易数据分析 10 | Web Scraper 翻页—抓取「滚动加载」类型网页...
  7. linux路由内核实现分析(二)---FIB相关数据结构(3)
  8. 中断触发流程三(中断控制器)
  9. 嵌入式Linux系统编程学习之九基于文件描述符的文件操作
  10. Java的echo_简单的Java echo服务器问题