1. #!/usr/bin/env python3

  2. # -*- coding: utf-8 -*-

  3. """

  4. Created on Sun Sep 30 17:12:12 2018

  5. 这是用keras搭建的vgg16网络

  6. 这是很经典的cnn,在图像和时间序列分析方面有很多的应用

  7. @author: lg

  8. """

  9. #################

  10. import keras

  11. from keras import regularizers

  12. from keras.datasets import cifar10

  13. from keras.models import Sequential

  14. from keras.layers import Conv2D, MaxPooling2D, Dense, Dropout, Flatten, BatchNormalization

  15. from keras.optimizers import SGD

  16. import os

  17. import argparse

  18. import random

  19. import numpy as np

  20. from scipy.misc import imread, imresize, imsave

  21. import pickle

  22. from sklearn.model_selection import StratifiedKFold

  23. parser = argparse.ArgumentParser()

  24. parser.add_argument('--train_dir', default='./train/')

  25. parser.add_argument('--test_dir', default='./test/')

  26. parser.add_argument('--log_dir', default='./')

  27. parser.add_argument('--batch_size', default=16)

  28. parser.add_argument('--gpu', type=int, default=0)

  29. args = parser.parse_args()

  30. os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

  31. type_list = ['cat', 'dog']

  32. def vgg_16_net():

  33. model = Sequential()

  34. model.add(Conv2D(64, (3, 3), input_shape=(32, 32, 3), padding='same', activation='relu', name='conv1_block'))

  35. model.add(Conv2D(64, (3, 3), activation='relu', padding='same', name='conv2_block'))

  36. model.add(BatchNormalization())

  37. model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))

  38. model.add(Conv2D(128, (3, 3), activation='relu', padding='same', name='conv3_block'))

  39. model.add(Conv2D(128, (3, 3), activation='relu', padding='same', name='conv4_block'))

  40. model.add(BatchNormalization())

  41. model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))

  42. model.add(Conv2D(256, (3, 3), activation='relu', padding='same', name='conv5_block'))

  43. model.add(Conv2D(256, (3, 3), activation='relu', padding='same', name='conv6_block'))

  44. model.add(Conv2D(256, (1, 1), activation='relu', padding='same', name='conv7_block'))

  45. model.add(BatchNormalization())

  46. model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))

  47. model.add(Conv2D(512, (3, 3), activation='relu', padding='same', name='conv8_block'))

  48. model.add(Conv2D(512, (3, 3), activation='relu', padding='same', name='conv9_block'))

  49. model.add(Conv2D(512, (1, 1), activation='relu', padding='same', name='conv10_block'))

  50. model.add(BatchNormalization())

  51. model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))

  52. model.add(Conv2D(512, (3, 3), activation='relu', padding='same', name='conv11_block'))

  53. model.add(Conv2D(512, (3, 3), activation='relu', padding='same', name='conv12_block'))

  54. model.add(Conv2D(512, (1, 1), activation='relu', padding='same', name='conv13_block'))

  55. model.add(BatchNormalization())

  56. model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))

  57. model.add(Flatten())

  58. model.add(Dense(2048, activation='relu'))

  59. model.add(Dropout(0.5))

  60. model.add(Dense(4096, activation='relu'))

  61. model.add(Dropout(0.5))

  62. model.add(Dense(10, activation='softmax'))

  63. #model.add(Dense(1, activation='sigmoid'))

  64. return model

  65. def prepare_data():

  66. file_dict1 = unpickle('F:/cifar10/cifar-10-batches-py/data_batch_1')

  67. label = file_dict1[b'labels']

  68. image = file_dict1[b'data']

  69. print(type(image))

  70. file_dict2 = unpickle('F:/cifar10/cifar-10-batches-py/data_batch_2')

  71. label = label + file_dict2[b'labels']

  72. image = np.vstack((image, file_dict2[b'data']))

  73. file_dict3 = unpickle('F:/cifar10/cifar-10-batches-py/data_batch_3')

  74. label = label + file_dict3[b'labels']

  75. image = np.vstack((image, file_dict3[b'data']))

  76. file_dict4 = unpickle('F:/cifar10/cifar-10-batches-py/data_batch_4')

  77. label = label + file_dict4[b'labels']

  78. image = np.vstack((image, file_dict4[b'data']))

  79. file_dict5 = unpickle('F:/cifar10/cifar-10-batches-py/data_batch_5')

  80. label = label + file_dict5[b'labels']

  81. image = np.vstack((image, file_dict5[b'data']))

  82. image = np.reshape(image/255, (-1, 32, 32, 3))

  83. label = keras.utils.to_categorical(label, 10)

  84. #seed = 7

  85. #np.random.seed(seed)

  86. #train_data, test_data, train_label, test_label = train_test_split(image, label, test_size=0.2, random_state=0)

  87. #train_num = int(len(label) * 0.8 )

  88. #train_data, train_label, test_data, test_label = image[0:train_num], label[0:train_num], image[train_num:], label[train_num:]

  89. # (X_train, y_train), (X_test, y_test) = cifar10.load_data()

  90. # train_data = np.reshape(X_train/255, (-1, 32, 32, 3))

  91. # train_label = keras.utils.to_categorical(y_train, 10)

  92. # test_data = np.reshape(X_test/255, (-1, 32, 32, 3))

  93. # test_label = keras.utils.to_categorical(y_test, 10)

  94. # return train_data, train_label, test_data, test_label

  95. return image, label

  96. def unpickle(file):

  97. with open(file, 'rb') as fo:

  98. file_dict = pickle.load(fo, encoding='bytes')

  99. return file_dict

  100. def train():

  101. kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=7)

  102. data, label = prepare_data()

  103. index = 1

  104. for train, test in kfold.split(data, label.argmax(1)):

  105. model = vgg_16_net()

  106. sgd = SGD(lr=0.001, decay=1e-8, momentum=0.9, nesterov=True)

  107. model.compile(optimizer=sgd, loss='categorical_crossentropy', metrics=['accuracy'])

  108. model.fit(data[train], label[train], validation_data=(data[test], label[test]), epochs=20, batch_size=32, shuffle=True)

  109. #train_data, train_label = prepare_data()

  110. #model.fit(train_data, train_label, batch_size=64, epochs=20, shuffle=True, validation_split=0.2)

  111. model.save_weights(args.log_dir + 'model_' + str(index) + '.h5')

  112. index = index + 1

  113. if __name__ == '__main__':

  114. try:

  115. train()

  116. #model.load_weights(args.log_dir + '/model.h5')

  117. #predict(model)

  118. except Exception as err:

  119. print(err)

最开始没有使用交叉验证,但是测试集与验证集的准确率一直维持在50%~60%,基本属于盲猜系列。原因大概是数据量太多,进行随机划分时,测试数据的分类不是很均匀。所以采用了交叉验证的方式,最终测试集与训练集的准确率能够达到99%,应该是有点过拟合了,结果还是非常满意的。

VGG16的10折交叉验证实现cifar10的分类(keras实现)相关推荐

  1. R语言caret包构建xgboost模型实战:特征工程(连续数据离散化、因子化、无用特征删除)、配置模型参数(随机超参数寻优、10折交叉验证)并训练模型

    R语言caret包构建xgboost模型实战:特征工程(连续数据离散化.因子化.无用特征删除).配置模型参数(随机超参数寻优.10折交叉验证)并训练模型 目录

  2. matlab.10折交叉验证

    clc clear all % 导入数据 data = load('F:\work_matlab\Matlab\wdbc.txt'); [data_r, data_c] = size(data); % ...

  3. 机器学习 - 随机森林手动10 折交叉验证

    随机森林的 10 折交叉验证 再回到之前的随机森林(希望还没忘记,机器学习算法-随机森林初探(1)) library(randomForest) set.seed(304) rf1000 <- ...

  4. 10折交叉验证(10-fold Cross Validation)与留一法(Leave-One-Out)、分层采样(Stratification)

    10折交叉验证我们构建一个分类器,输入为运动员的身高.体重,输出为其从事的体育项目-体操.田径或篮球. 一旦构建了分类器,我们就可能有兴趣回答类似下述的问题: 1. 该分类器的精确率怎么样? 2. 该 ...

  5. 《写给程序员的数据挖掘实践指南》——5.2. 10折交叉验证的例子

    本节书摘来自异步社区出版社<写给程序员的数据挖掘实践指南>一书中的第5章,第5.2节,作者:[美]Ron Zacharski(扎哈尔斯基),更多章节内容可以访问云栖社区"异步社区 ...

  6. 《机器学习》课后习题 3.4 选择两个 UCI 数据集,比较 10 折交叉验证法和留 法所估计出的对率回归的错误率.

    参考了han同学的答案,数据集也可在han同学的github上下载. 3.4 选择两个 UCI 数据集,比较 10 折交叉验证法和留 法所估计出的对率回归的错误率. import numpy as n ...

  7. 10折交叉验证(10-fold Cross Validation)与留一法(Leave-One-Out)、分层采样(Stratification)...

    10折交叉验证 我们构建一个分类器,输入为运动员的身高.体重,输出为其从事的体育项目-体操.田径或篮球. 一旦构建了分类器,我们就可能有兴趣回答类似下述的问题: 1. 该分类器的精确率怎么样? 2. ...

  8. 十折交叉验证10-fold cross validation, 数据集划分 训练集 验证集 测试集

    机器学习 数据挖掘 数据集划分 训练集 验证集 测试集 Q:如何将数据集划分为测试数据集和训练数据集? A:three ways: 1.像sklearn一样,提供一个将数据集切分成训练集和测试集的函数 ...

  9. k折交叉验证法python实现_Jason Brownlee专栏| 如何解决不平衡分类的k折交叉验证-不平衡分类系列教程(十)...

    作者:Jason Brownlee 编译:Florence Wong – AICUG 本文系AICUG翻译原创,如需转载请联系(微信号:834436689)以获得授权 在对不可见示例进行预测时,模型评 ...

  10. 机器学习代码实战——K折交叉验证(K Fold Cross Validation)

    文章目录 1.实验目的 2.导入数据和必要模块 3.比较不同模型预测准确率 3.1.逻辑回归 3.2.决策树 3.3.支持向量机 3.4.随机森林 1.实验目的 使用sklearn库中的鸢尾花数据集, ...

最新文章

  1. mysql总是出现-_mysql 总是莫名其妙的关闭:报错 -问答-阿里云开发者社区-阿里云...
  2. json schema多种形式_什么是JSON Schema?及其应用方式......
  3. 前端学习(3121):react-hello-react的state的简写方式
  4. 中考数学不准使用计算机,中考数学蒙题技巧
  5. nginx配置中location匹配规则详解
  6. java基础第九天_多线程、自动拆装箱
  7. bigint最大有多少位_《追光吧哥哥》21位艺人靠实力成团?不见得,节目赛制本不公平...
  8. css链接,列表,表格
  9. quartsu仿真8:二五十计数器74290的基本功能
  10. BULK INSERT, 实战手记:让百万级数据瞬间导入SQL Server
  11. Remember The Word-Trie
  12. 62. WWW 服务器
  13. 台式机装苹果系统_苹果、华为出手,ARM取代X86芯片这也是国产CPU的巨大机会
  14. EDA学习1.3之开关的封装
  15. 锐角云CEO许胜:因为认同,所以入行
  16. 国内优秀的PHP商城系统整理
  17. linux查看nbu数据库命令,NBU基本常用命令
  18. 【转载】双微信分享发生TransactionTooLargeException 异常记录
  19. 谁说QTP不能多线程 - 当Python遇上QTP
  20. 电脑关闭休眠模式清理 C盘内存

热门文章

  1. 从 0 开始了解 Docker(ubuntu )
  2. 推荐的MyBatis传参方式List、数组等
  3. 转:关于BFC的初步了解以及常见使用
  4. RHEL 5服务篇—LAMP平台的部署及应用
  5. Tsys1.1使用经验(汇集中)
  6. 五种提高 SQL 性能的方法
  7. 微信小程序-day1
  8. 05-Vue报错 Uncaught SyntaxError: Identifier has already been declared和路由
  9. Luogu P1967 货车运输 倍增+最大生成树
  10. CICD - Teamcity 配置之一: 数据库自动部署