在《python深度学习》这本书中。

一、21页mnist十分类

导入数据集

from keras.datasets import mnist

(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

初始数据维度:

>>> train_images.shape

(60000, 28, 28)

>>> len(train_labels)

60000

>>> train_labels

array([5, 0, 4, ..., 5, 6, 8], dtype=uint8)

数据预处理:

train_images = train_images.reshape((60000, 28 * 28))

train_images = train_images.astype('float32') / 255

train_labels = to_categorical(train_labels)

之后:

print(train_images, type(train_images), train_images.shape, train_images.dtype)

print(train_labels, type(train_labels), train_labels.shape, train_labels.dtype)

结果:

[[0. 0. 0. ... 0. 0. 0.]

[0. 0. 0. ... 0. 0. 0.]

[0. 0. 0. ... 0. 0. 0.]

...

[0. 0. 0. ... 0. 0. 0.]

[0. 0. 0. ... 0. 0. 0.]

[0. 0. 0. ... 0. 0. 0.]] (60000, 784) float32

[[0. 0. 0. ... 0. 0. 0.]

[1. 0. 0. ... 0. 0. 0.]

[0. 0. 0. ... 0. 0. 0.]

...

[0. 0. 0. ... 0. 0. 0.]

[0. 0. 0. ... 0. 0. 0.]

[0. 0. 0. ... 0. 1. 0.]] (60000, 10) float32

二、51页IMDB二分类

导入数据:

from keras.datasets import imdb (train_data, train_labels), (test_data, test_labels) = imdb.load_data(num_words=10000)

参数 num_words=10000 的意思是仅保留训练数据中前 10 000 个最常出现的单词。

train_data和test_data都是numpy.ndarray类型,都是一维的(共25000个元素,相当于25000个list),其中每个list代表一条评论,每个list中的每个元素的值范围在0-9999 ,代表10000个最常见单词的每个单词的索引,每个list长度不一,因为每条评论的长度不一,例如train_data中的list最短的为11,最长的为189。

train_labels和test_labels都是含25000个元素(元素的值要不0或者1,代表两类)的list。

数据预处理:

# 将整数序列编码为二进制矩阵

def vectorize_sequences(sequences, dimension=10000):

# Create an all-zero matrix of shape (len(sequences), dimension)

results = np.zeros((len(sequences), dimension))

for i, sequence in enumerate(sequences):

results[i, sequence] = 1. # set specific indices of results[i] to 1s

return results

x_train = vectorize_sequences(train_data)

x_test = vectorize_sequences(test_data)

第一种方式:shape为(25000,)

y_train = np.asarray(train_labels).astype('float32') #就用这种方式就行了

y_test = np.asarray(test_labels).astype('float32')

第二种方式:shape为(25000,1)

y_train = np.asarray(train_labels).astype('float32').reshape(25000, 1)

y_test = np.asarray(test_labels).astype('float32').reshape(25000, 1)

第三种方式:shape为(25000,2)

y_train = to_categorical(train_labels) #变成one-hot向量

y_test = to_categorical(test_labels)

第三种方式,相当于把二分类看成了多分类,所以网络的结构同时需要更改,

最后输出的维度:1->2

最后的激活函数:sigmoid->softmax

损失函数:binary_crossentropy->categorical_crossentropy

预处理之后,train_data和test_data变成了shape为(25000,10000),dtype为float32的ndarray(one-hot向量),train_labels和test_labels变成了shape为(25000,)的一维ndarray,或者(25000,1)的二维ndarray,或者shape为(25000,2)的one-hot向量。

注:

1.sigmoid对应binary_crossentropy,softmax对应categorical_crossentropy

2.网络的所有输入和目标都必须是浮点数张量

补充知识:keras输入数据的方法:model.fit和model.fit_generator

1.第一种,普通的不用数据增强的

from keras.datasets import mnist,cifar10,cifar100

(X_train, y_train), (X_valid, Y_valid) = cifar10.load_data()

model.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch, shuffle=True,

verbose=1, validation_data=(X_valid, Y_valid), )

2.第二种,带数据增强的 ImageDataGenerator,可以旋转角度、平移等操作。

from keras.preprocessing.image import ImageDataGenerator

(trainX, trainY), (testX, testY) = cifar100.load_data()

trainX = trainX.astype('float32')

testX = testX.astype('float32')

trainX /= 255.

testX /= 255.

Y_train = np_utils.to_categorical(trainY, nb_classes)

Y_test = np_utils.to_categorical(testY, nb_classes)

generator = ImageDataGenerator(rotation_range=15,

width_shift_range=5./32,

height_shift_range=5./32)

generator.fit(trainX, seed=0)

model.fit_generator(generator.flow(trainX, Y_train, batch_size=batch_size),

steps_per_epoch=len(trainX) // batch_size, epochs=nb_epoch,

callbacks=callbacks,

validation_data=(testX, Y_test),

validation_steps=testX.shape[0] // batch_size, verbose=1)

以上这篇keras分类模型中的输入数据与标签的维度实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

本文标题: keras分类模型中的输入数据与标签的维度实例

本文地址: http://www.cppcns.com/jiaoben/python/324490.html

python输入数据的维度_keras分类模型中的输入数据与标签的维度实例相关推荐

  1. 检测恶意软件分类模型中的概念漂移

    科研笔记 论文题目-检测恶意软件分类模型中的概念漂移 共形预测 (conformal prediction)是一种置信度预测器,它生成具有用户定义的错误率的预测.在某个置信度水平下,所有预测范围的那部 ...

  2. python 两点曲线_python机器学习分类模型评估

    python机器学习分类模型评估 1.混淆矩阵 在分类任务下,预测结果(Predicted Condition)与正确标记(True Condition)之间存在四种不同的组合,构成混淆矩阵(适用于多 ...

  3. python编程:10种分类模型评估的方法及Python实现

    本文的文字及图片来源于网络,仅供学习.交流使用,不具有任何商业用途,版权归原作者所有,如有问题请及时联系我们以作处理 想要学习Python?有问题得不到第一时间解决?来看看这里,满足你的需求,资料都已 ...

  4. NLP 模型“解语如神”的诀窍:在文本分类模型中注入外部词典

    一. 引言 现实世界的文本表述如恒河沙数,以惊人的速度变换着,人工智能(AI)在快速识别形形色色的文本之前,必须经过充足的训练数据洗礼.然而,面对复杂多变的文本表述,NLP 模型往往无法从有限的训练数 ...

  5. Python实现直方图梯度提升分类模型(HistGradientBoostingClassifier算法)并基于网格搜索进行优化同时绘制PDP依赖图项目实战

    说明:这是一个机器学习实战项目(附带数据+代码+文档+视频讲解),如需数据+代码+文档+视频讲解可以直接到文章最后获取. 1.项目背景 基于直方图的梯度提升分类树.此估算器对缺失值 (NaN) 具有原 ...

  6. 机器学习分类模型中的评价指标介绍:准确率、精确率、召回率、ROC曲线

    文章来源:https://blog.csdn.net/wf592523813/article/details/95202448 1 二分类评价指标 准确率,精确率,召回率,F1-Score, AUC, ...

  7. 【项目实战】Python实现深度神经网络RNN-LSTM分类模型(医学疾病诊断)

    说明:这是一个机器学习实战项目(附带数据+代码+视频+文档),如需数据+完整代码可以直接到文章最后获取. 1.项目背景 随着互联网+的不断深入,我们已步入人工智能时代,机器学习作为人工智能的一个分支越 ...

  8. python编程练习:模拟水文模型中的水箱模型(tank model),不含参数率定过程

    一.水箱模型结构 二.代码 import matplotlib.pyplot as plt import numpy as np import math import time # 设置标签为负号可显 ...

  9. 分类机器学习中,某一标签占比太大(标签稀疏),如何学习?

    链接:https://www.zhihu.com/question/372186043 编辑:深度学习与计算机视觉 声明:仅做学术分享,侵删 假设10000个数据,有100个1,200个2,其余全是0 ...

最新文章

  1. mysql数据库连接jar_mysql数据库连接包
  2. java 里面write,java 中 System.out.println()和System.out.write()的区别
  3. 飞鸽传书下载2013
  4. Kubernetes—动态存储卷配置(StorageClass资源)(十二)
  5. easyPR源码解析之plate_judge.h
  6. C#学习笔记之-----倒序输出字符串
  7. Linux Hackers/Suspicious Account Detection
  8. .net pdf转图片_PDF转图片要怎么转?两分钟解决!
  9. desktop.ini winxp之马上更新图标
  10. LeetCode 98 验证二叉搜索树
  11. 人脸识别接口_双目模组摄像头人脸识别技术中活体检测
  12. 从DOS中装操作系统时要加载smartdrv命令
  13. 格式化输出latex数字罗马字体
  14. 腾讯云服务器手动建立WordPress个人站点Windows系统教程-Unirech腾讯云国际版代充
  15. android studio 跳转后保留原页面数据_这些技巧和习惯,让你的原生 Android 手机更好用(上篇)...
  16. 苹果宣布 2022 年 Apple 设计大奖得主
  17. 「学习笔记」自适应辛普森法
  18. selenium_Selenium4 Alpha –期望什么?
  19. 性能监控:top命令
  20. TP5后端,VUE前端请求聚合数据驾照题库

热门文章

  1. 痞子衡嵌入式:无线通信技术协议全搜罗 - 索引
  2. MongoDB 小试牛刀
  3. nexus-3本地下载jar的settipng.xml配置
  4. Java删除文件和目录
  5. 动态创建模板列并绑定数据(GridView,Repeater,DataGrid)
  6. 【安卓】基于SharedPreferences实现用户登录信息的存储
  7. ZoomIt – 屏幕标注、电子画笔 [小工具]
  8. zk的数据目录:`version-2`
  9. 阿里DataV案例:制作实时销售大屏流程
  10. Scala元组数据的访问