我的第一个Keras神经网络二分类模型

  • 目标
  • 网络结构
  • 实现
    • 数据
    • 模型
    • 验证
  • 小结

目标

使用Keras 训练一个简单的二分类模型,对下图中的点分类,其中训练特征为点的坐标(x, y),红色标签为0,蓝色标签为1。

网络结构

二分类神经网络模型结构如下,其中:

  1. 输入层为点的坐标(x, y)。
  2. 输出层为点的标签[0, 1], 激活函数为sigmoid。
  3. 模型只包含一个隐藏层,隐藏层包含50个神经单元,激活函数为relu。

实现

数据

使用sklearn.makemoons()函数生成1000个测试样本,并按照7:3的比例拆成训练/测试集。


from sklearn.datasets import make_moons
from keras.models import Sequential
from keras.layers import Dense
from keras.callbacks import EarlyStopping
from keras.callbacks import ModelCheckpoint
from matplotlib import pyplot
from keras.models import load_model
from matplotlib import pyplot
from pandas import DataFrame
from keras.layers import Dropout
import numpy as np# generate 2d classification dataset
X, y = make_moons(n_samples=1000, noise=0.1, random_state=1)
# scatter plot, dots colored by class value
df = DataFrame(dict(x=X[:,0], y=X[:,1], label=y))
colors = {0:'red', 1:'blue'}
fig, ax = pyplot.subplots()
grouped = df.groupby('label')
for key, group in grouped:group.plot(ax=ax, kind='scatter', x='x', y='y', label=key, color=colors[key])
pyplot.show()# split into train and test
n_train = int(X.shape[0] * 0.7)
trainX, testX = X[:n_train, :], X[n_train:, :]
trainy, testy = y[:n_train], y[n_train:]

模型

编辑模型,绘制训练结果:

  1. model.add(Dense(50, input_dim=2, activation=‘relu’)) 描述了第一个隐藏层的结构。
  2. model.add(Dense(1, activation=‘sigmoid’)) 描述了输出层的结构。
  3. 模型的损失函数(Loss)为 ‘binary_crossentropy’。
  4. 模型采用 ‘adam’ 梯度下降搜索算法。
  5. 模型执行1000(epochs=1000)次训练。

# define model
model = Sequential()
model.add(Dense(50, input_dim=2, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])# fit model
history = model.fit(trainX, trainy, validation_data=(testX, testy), epochs=1000, verbose=1)
pyplot.plot(history.history['loss'], label='train')
pyplot.plot(history.history['val_loss'], label='test')
pyplot.legend()
pyplot.show()

验证

最后选择两个特殊的点(右下角蓝色与左上角红色)验证一下模型有效性:


testArray = np.array([[2.0, -1.0], [-2.0, 1.0]])
print(model.predict(testArray))[[1.0000000e+00] (蓝色)[1.4053326e-24] (红色)]

小结

  1. trainX是 (700x2)的二维数据,其中每一行是一个训练数据,总共700个训练数据。
  2. trainy是700维的向量,总共700个标签标记红色(0)或者蓝色(1)。
  3. 将训练/测试数据堆成一个二维数组可以有效利用cpu/gpu的并行计算加快训练速度。

Keras:我的第一个神经网络二分类模型相关推荐

  1. AI:神经网络IMDB电影评论二分类模型训练和评估

    AI:Keras神经网络IMDB电影评论二分类模型训练和评估,python import keras from keras.layers import Dense from keras import ...

  2. Keras之ML~P:基于Keras中建立的简单的二分类问题的神经网络模型(根据200个数据样本预测新的5个样本)——概率预测

    Keras之ML~P:基于Keras中建立的简单的二分类问题的神经网络模型(根据200个数据样本预测新的5个样本)--概率预测 目录 输出结果 核心代码 输出结果 核心代码 # -*- coding: ...

  3. Keras之ML~P:基于Keras中建立的简单的二分类问题的神经网络模型(根据200个数据样本预测新的5+1个样本)——类别预测

    Keras之ML~P:基于Keras中建立的简单的二分类问题的神经网络模型(根据200个数据样本预测新的5+1个样本)--类别预测 目录 输出结果 核心代码 输出结果 核心代码 # -*- codin ...

  4. 深度学习框架tensorflow二实战(训练一个简单二分类模型)

    导入工具包 import os import warnings warnings.filterwarnings("ignore") import tensorflow as tf ...

  5. [机器学习] 二分类模型评估指标---精确率Precision、召回率Recall、ROC|AUC

    一 为什么要评估模型? 一句话,想找到最有效的模型.模型的应用是循环迭代的过程,只有通过持续调整和调优才能适应在线数据和业务目标. 选定模型时一开始都是假设数据的分布是一定的,然而数据的分布会随着时间 ...

  6. 【Kay】机器学习——二分类模型的评价

    一.评价二分类模型的好坏 二分类问题:预测这条数据是0还是1的问题 1.混淆矩阵 数字代表个数 2.准确率.精确率.召回率 ①准确率: ②精确率(查准率): ③召回率(查全率recall) :   ④ ...

  7. 衡量二分类模型的统计指标(TN,TP,FN,FP,F1,准确,精确,召回,ROC,AUC)

    文章目录 - 衡量二分类问题的统计指标 分类结果 混淆矩阵 准确率 精确率 召回率 F1评分 推导过程 ROC曲线.AUC - 衡量二分类问题的统计指标 分类结果   二分类问题,分类结果有以下四种情 ...

  8. Python实现PSO粒子群优化循环神经网络LSTM分类模型项目实战

    说明:这是一个机器学习实战项目(附带数据+代码+文档+视频讲解),如需数据+代码+文档+视频讲解可以直接到文章最后获取. 1.项目背景 PSO是粒子群优化算法(Particle Swarm Optim ...

  9. 独家 | 教你用不到30行的Keras代码编写第一个神经网络(附代码教程)

    翻译:陈丹 校对:和中华 本文长度为3000字,建议阅读5分钟 本文为大家介绍了如何使用Keras来快速实现一个神经网络. 回忆起我第一次接触人工智能的时候,我清楚地记得有些概念看起来是多么令人畏惧. ...

  10. 用神经网络二分类人脑与电脑

    如果一个对象A,无论与B或C分类,分类的准确率都是50%,则A为一个具有智慧的对象. 用符号表示 ∵ (A,B)-n*m*k-(1,0)(0,1)  50%:50%    ①     (A,C)-n* ...

最新文章

  1. 御用导航官方网站提醒提示页_导航错误致四川青城山拥堵?交警提醒:别过度依赖导航...
  2. java之IO流(commons-IO)
  3. 泛型擦除机制、自定义注解、代理、反射
  4. nyoj-37 回文字符串
  5. Nand Flash与Nor Flash
  6. Win10系统修改MAC地址
  7. 使用SIFT匹配金馆长表情包
  8. Warning:java: 来自注释处理程序 'org.antlr.v4.runtime.misc.NullUsageProcessor' 的受支持 source 版本 'RELEASE_6' 低于
  9. 计算机教学研讨会议记录,教学教研工作会议记录3.doc
  10. Python 结合bat批处理文件 实现密码保管箱
  11. 2 电感耦合方式的射频前端
  12. python 进化树_SCHISM 构建克隆进化树
  13. 计算机网络技术广告,屏蔽QQ广告和迷你首页广告
  14. ArcGIS更多颜色调配
  15. java List/ArrayList 解惑
  16. JVM简笔—类的加载
  17. 【SpringBoot】12.SpringBoot整合Dubbo+Zookeeper
  18. 版本管理工具 git和SVN 忽略文件和目录
  19. Linux云计算之VSFTP服务器概述-安装vsftp服务器端、客户端
  20. 陕西省2011年教师资格证教育基础理论知识考试报名通知

热门文章

  1. matlab里线性规划,Matlab 中的数学建模算法 —— 线性规划函数
  2. java 系统找不到路径_java IOException:系统找不到指定的路径
  3. python可以破解网站吗_python变相破解校园网 - 『编程语言区』 - 吾爱破解 - LCG - LSG |安卓破解|病毒分析|www.52pojie.cn...
  4. 第3章 形式语言与自动机
  5. 提升生产力,7 款好用的原型图工具推荐给你
  6. hive 计算周几_HIVE 计算指定日期本周的第一天和最后一天
  7. OLED的字模提取(保姆级)---基于PCtoLCD2013
  8. 【开源项目】CircuitJS1在线电路仿真
  9. 平安普惠系统上线申请表模板
  10. MFC之打开(开发)映美精相机