Otto 分类问题

这里,我们将对Otto数据集进行分类。

  • 本文主要参考 2.3 Introduction to Keras。个人觉得这是一个很好Keras教程,希望大家也去学习学习。
  • 关于Otto,可以在 otto group 找到更多详细的材料
  • 本文主要关注代码的实现,具体细节和基本概念不会详细展开

让我们开始吧

就像以前说过的那样,处理一个问题主要分为三个部分:数据准备,模型构建和模型优化

导入模块

这里遇到了新的模块

  • StandardScaler 用于归一化,感觉很好使。详见StandardScaler
  • LabelEncoder 配合np_utils用于One-hot编码,详见LabelEncoder。注意和OneHotEncoder的区别。
  • EarlyStopping 当监测值不再改善时,该回调函数将中止训练。详见EarlyStopping
  • ModelCheckpoint 保存模型。详见ModelCheckpoint
import numpy as np
import pandas as pdfrom sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import LabelEncoderfrom keras.utils import np_utils
from keras.models import Sequential
from keras.layers.core import Dense, Activation, Dropoutfrom keras.callbacks import EarlyStopping, ModelCheckpoint

## 数据准备 读取数据。数据可以在 [otto group](https://www.kaggle.com/c/otto-group-product-classification-challenge/data) 找到

train_path = './data/train.csv'
test_path = './data/test.csv'df = pd.read_csv(train_path)

观察数据。有93个特征,最后一列是种类,第一列的id对于训练没有任何作用。

df.head()
id feat_1 feat_2 feat_3 feat_4 feat_5 feat_6 feat_7 feat_8 feat_9 feat_85 feat_86 feat_87 feat_88 feat_89 feat_90 feat_91 feat_92 feat_93 target
0 1 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 Class_1
1 2 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 Class_1
2 3 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 Class_1
3 4 1 0 0 1 6 1 5 0 0 0 1 2 0 0 0 0 0 0 Class_1
4 5 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 0 Class_1

5 rows × 95 columns

导入数据。

  • 第一列id对训练没用,所以我们不需要它。
  • train 和 test 两个文件有所区别(test中没有给出target)
def load_data(path, train=True):df = pd.read_csv(path)X = df.values.copy()if train:np.random.shuffle(X)X, label = X[:, 1:-1].astype(np.float32), X[:, -1]return X, labelelse:X, ids = X[:, 1:].astype(np.float32), X[:, 0].astype(str)return X, ids
X_train, y_train = load_data(train_path)
X_test, ids = load_data(test_path, train=False)

预处理,训练数据和测试数据一起归一化,以免忘记了

def preprocess_data(X, scaler=None):if not scaler:scaler = StandardScaler()scaler.fit(X)X = scaler.transform(X)return X, scaler
X_train, scaler = preprocess_data(X_train)
X_test, _ = preprocess_data(X_test, scaler)

One-hot 编码

def preprocess_label(labels, encoder=None, categorical=True):if not encoder:encoder = LabelEncoder()encoder.fit(labels)y = encoder.transform(labels).astype(np.int32)if categorical:y = np_utils.to_categorical(y)return y, encoder
y_train, encoder = preprocess_label(y_train)

搭建网络模型

dim = X_train.shape[1]
print(dim, 'dims')
print('Building model')nb_classes = y_train.shape[1]model = Sequential()model.add(Dense(256, input_shape=(dim, )))
model.add(Activation('relu'))
model.add(Dropout(0.5))model.add(Dense(128))
model.add(Activation('relu'))
model.add(Dropout(0.5))model.add(Dense(64))
model.add(Activation('relu'))
model.add(Dropout(0.5))model.add(Dense(nb_classes))
model.add(Activation('softmax'))
93 dims
Building model
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
batch_size = 128
epochs = 2

训练,同时保持最佳模型

fBestModel = 'best_model.h5'
early_stop = EarlyStopping(monitor='val_acc', patience=5, verbose=1)
best_model = ModelCheckpoint(fBestModel, verbose=0, save_best_only=True)model.fit(X_train, y_train, epochs=epochs, batch_size=batch_size, verbose=1, validation_split=0.1, callbacks=[best_model, early_stop])
Train on 55690 samples, validate on 6188 samples
Epoch 1/2
55690/55690 [==============================] - 2s 42us/step - loss: 0.5256 - acc: 0.7967 - val_loss: 0.5268 - val_acc: 0.7982
Epoch 2/2
55690/55690 [==============================] - 2s 42us/step - loss: 0.5251 - acc: 0.7991 - val_loss: 0.5256 - val_acc: 0.8017<keras.callbacks.History at 0x13551adaef0>

预测并保存结果。将结果保存为Kaggle上要求的格式,然后提交了测试结果,得到了0.5左右的分数,据说大概前50%左右

prediction = model.predict(X_test)
num_pre = prediction.shape[0]
columns = ['Class_'+str(post+1) for post in range(9)]df2 = pd.DataFrame({'id' : range(1,num_pre+1)})
df3 = pd.DataFrame(prediction, columns=columns)df_pre = pd.concat([df2, df3], axis=1)
df_pre.to_csv('predition.csv', index=False)

Keras-3 Keras With Otto Group相关推荐

  1. keras系列︱keras是如何指定显卡且限制显存用量

    keras系列︱keras是如何指定显卡且限制显存用量 原创 2017年07月21日 10:59:24 标签: keras / gpu / 显卡 / 指定 / 限制 6630 keras在使用GPU的 ...

  2. DL之Keras:keras保存网络结构、网络拓扑图、网络模型(json、yaml、h5等)注意事项及代码实现

    DL之Keras:keras保存网络结构.网络拓扑图.网络模型(json.yaml.h5等)注意事项及代码实现 目录 keras保存网络结构.网络拓扑图.网络模型(json.yaml.h5等)注意事项 ...

  3. DL之Keras: Keras深度学习框架的注意事项(默认下载存放路径等)、使用方法之详细攻略

    DL之Keras: Keras深度学习框架的注意事项(自动下载存放路径等).使用方法之详细攻略 目录 Keras深度学习框架的注意事项 1.Keras自动下载默认数据集/模型存放位置 Windows系 ...

  4. 【keras】keras教程(参考官方文档)

    文章目录 一.callbacks篇 1.ReduceLROnPlateau 训练过程优化学习率 2.EarlyStopping 早停操作 3.ModelCheckpoint 用于设置保存的方式 4.T ...

  5. keras学习- No module named ' tensorflow.keras ' 报错,看清 tf.keras与keras

    环境描述: 系统ubantu16.04 安装anaconda  版本conda 4.5.4 创建虚拟环境 tf-gpu tensorflow-gpu版本(1.7.0-gpu, 能够import ten ...

  6. [Keras] 使用Keras调用多GPU时出现无法保存模型的解决方法

    在使用keras 的并行多路GPU时出现了模型无法 保存,在使用单个GPU时运行完全没有问题.运行出现can't pickle的问题 随后在网上找了很多解决方法.下面列举一些我实验成功的方法. 方法一 ...

  7. TensorFlow 2.0中的tf.keras和Keras有何区别?为什么以后一定要用tf.keras?

    选自pyimagesearch 作者:Adrian Rosebrock 参与:王子嘉.张倩 本文经机器之心授权转载,禁止二次转载 随着 TensorFlow 2.0 的发布,不少开发者产生了一些疑惑: ...

  8. Keras(part1)--Keras简介与安装

    学习笔记,仅供参考,有错必纠 参考自:<keras快速上手>:keras安装教程:VC14(VC2015)安装失败,0x80240017 - 未指定的错误,解决办法:运行TensorFlo ...

  9. TensorFlow2 tf.keras和keras

    keras: 1.python的神经网络api,后端可以是TensorFlow,CNTK,Theano tf.keras: 1.Tensorflow对keras的内部实现 tf.keras 1.全面支 ...

最新文章

  1. python3基本数据类型
  2. 要找到现阶段最适合自己的方法
  3. 【VS开发】PCIe体系结构的组成部件
  4. 短期目标[Till 2011-08-05]
  5. 四则运算栈c语言程序,四则运算   c语言编程
  6. 华为宣布:免费培养2000名大数据开发者!
  7. java 拟合曲线_如何通过指数曲线拟合数据
  8. 服务器控件开发之复杂属性
  9. python3x完全兼容python2x_李亚涛:一台电脑python2x与python3x如何都可以用?
  10. NameNode启动
  11. $Django 聚合函数、分组查询、F,Q查询、orm字段以及参数
  12. linux 设备/dev
  13. 传输层协议(6):TCP 连接(下-3)
  14. python多线程下载编程_Python多线程结合队列下载百度音乐代码详解
  15. 麻瓜编程python爬虫微专业_麻瓜编程·python实战·1-3作业:爬取租房信息
  16. java ajax sendrequest()请求_AJAX – 向服务器发送请求 | 菜鸟教程
  17. 安装Android SDK时无法识别JDK 10
  18. simulink与gt联合仿真问题求解
  19. CC2530F256RHAR收发器
  20. 测试 CS4344 立体声DA转换器

热门文章

  1. php显示图片缩略图,使用ThinkPHP生成缩略图及显示的方法
  2. 天津大学考研计算机专业课的教材,天津大学(专业学位)计算机技术研究生考试科目和考研参考书目...
  3. Docker安装Mysql 案例和Tomcat测试
  4. 用java实现学生管理系统
  5. whoosh mysql_使用WhooshAlchemy报错'function' object has no attribute 'config'
  6. 19_python基础—面向对象-类结构、类属性和类方法、静态方法
  7. jenkins+pytest+allure接口自动化测试(windows环境)
  8. python123第6周答案_python123 测验6: 组合数据类型 (第6周)
  9. border属性 php,如何通过CSS的border属性为图片设置边框效果
  10. 修改ubuntu默认的Python版本号