猫狗二分类问题（Kaggle Dogs vs. Cats）
我的方案（Keras 实现）：

import itertools
import os
import random
import sys

import cv2
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from keras.callbacks import EarlyStopping
from keras.callbacks import ReduceLROnPlateau
from keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPool2D
from keras.models import Sequential
from keras.models import load_model
from keras.optimizers import RMSprop
from keras.preprocessing.image import ImageDataGenerator
from keras.preprocessing.image import load_img, img_to_array
from keras.utils.np_utils import to_categorical  # convert to one-hot-encoding
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
# Base directory for the data: the script's own directory (sys.path[0]),
# or the current working directory when sys.path[0] is '' (interactive use).
root_path = sys.path[0] or os.getcwd()
# Image directories, both built from root_path so that the listing below and
# the image loading in train()/test() resolve against the same place.  A '/'
# is appended (accepted as a separator on Windows too) because callers build
# file paths by plain concatenation, e.g. ``train_path + name``.
train_path = os.path.join(root_path, 'input', 'train') + '/'
test_path = os.path.join(root_path, 'input', 'test') + '/'
# Regular files in each directory.  Sorted so the order (and therefore the
# seeded train/validation split in train()) does not depend on the file
# system; os.listdir raises a clear FileNotFoundError if a directory is
# missing, unlike next(os.walk(...)) which raised a bare StopIteration.
train_list = sorted(name for name in os.listdir(train_path)
                    if os.path.isfile(os.path.join(train_path, name)))
test_list = sorted(name for name in os.listdir(test_path)
                   if os.path.isfile(os.path.join(test_path, name)))
IMAGE_WIDTH = 224
IMAGE_HEIGHT = 224
# Alternative paths when running inside a Kaggle kernel:
# train_path = "../input/train/"
# test_path = "../input/test/"
# train_list = next(os.walk(train_path))[2]
# test_list = next(os.walk(test_path))[2]

# Get the image label from the image path
def get_img_label(img_paths):
    """Return one-hot class labels for Dogs-vs-Cats image paths.

    The class is taken from the file-name prefix before the first '.'
    (``cat.123.jpg`` -> ``cat``): ``'cat'`` maps to class 0 and every other
    prefix maps to class 1 (dog).

    Args:
        img_paths: iterable of image file paths.

    Returns:
        ``to_categorical(labels, 2)`` -- one one-hot row per input path.
    """
    # os.path.basename extracts the file name; on Windows it also handles
    # '\\' separators, which the previous split("/") missed.
    img_labels = [
        0 if os.path.basename(img_path).split('.')[0] == 'cat' else 1
        for img_path in img_paths
    ]
    return to_categorical(img_labels, 2)
# Read an image
def load_batch_image(img_path, train_set=True, target_size=(IMAGE_WIDTH, IMAGE_HEIGHT)):
    """Load one image file as an array resized to ``target_size``.

    Args:
        img_path: path of the image file.
        train_set: when truthy the raw pixel values are returned unchanged
            (for training they are rescaled afterwards by ``train_datagen``,
            whose ``rescale`` is 1/255); when falsy they are divided by 255
            here.
        target_size: forwarded unchanged to ``load_img``.

    Returns:
        The array produced by ``img_to_array``, divided by 255.0 when
        ``train_set`` is falsy.
    """
    pixels = img_to_array(load_img(img_path, target_size=target_size))
    if not train_set:
        pixels = pixels / 255.0
    return pixels
# Build a data iterator
def get_dataset_shuffle(X_sample, batch_size, train_set=True):
    """Yield ``(x, y)`` batches from one shuffled pass over ``X_sample``.

    Only ``len(X_sample) // batch_size`` full batches are produced; the
    remaining ``len(X_sample) % batch_size`` paths are dropped so every batch
    holds exactly ``batch_size`` images.

    Args:
        X_sample: sequence of image paths; labels come from ``get_img_label``.
        batch_size: number of images per batch.
        train_set: forwarded to ``load_batch_image``.  Truthy keeps raw pixel
            values (``train_datagen`` rescales them); falsy divides by 255.

    Yields:
        ``(x, y)`` where ``x`` stacks the loaded images and ``y`` holds the
        matching rows of the one-hot labels from ``get_img_label``.
    """
    # Shuffle a copy so the caller's list is not reordered as a side effect.
    paths = list(X_sample)
    random.shuffle(paths)
    batch_num = len(paths) // batch_size
    if batch_num == 0:
        # Fewer paths than one batch: yield nothing instead of failing while
        # splitting into zero batches.
        return
    paths = paths[:batch_num * batch_size]
    labels = get_img_label(paths)
    for start in range(0, len(paths), batch_size):
        stop = start + batch_size
        x = np.array([load_batch_image(path, train_set) for path in paths[start:stop]])
        y = np.array(labels[start:stop])
        yield x, y


# Data augmentation for training batches.  rescale=1/255 normalises the raw
# pixels that load_batch_image(..., train_set=True) returns.
train_datagen = ImageDataGenerator(
    rescale=1. / 255,
    rotation_range=10,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
)


def build_model():
    """Build and compile the CNN used for cat/dog classification.

    Returns:
        A compiled Keras ``Sequential`` model: two conv blocks, a 256-unit
        dense layer and a 2-way softmax output, trained with categorical
        cross-entropy on one-hot labels and reporting accuracy.
    """
    model = Sequential()
    # NOTE: Keras' load_img takes target_size as (height, width); this file
    # passes (IMAGE_WIDTH, IMAGE_HEIGHT) there, so input_shape uses the same
    # order to stay consistent with the loaded arrays.
    model.add(Conv2D(filters=32, kernel_size=(5, 5), padding='Same',
                     activation='relu', input_shape=(IMAGE_WIDTH, IMAGE_HEIGHT, 3)))
    model.add(Conv2D(filters=32, kernel_size=(5, 5), padding='Same', activation='relu'))
    model.add(MaxPool2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))
    model.add(Conv2D(filters=64, kernel_size=(3, 3), padding='Same', activation='relu'))
    model.add(Conv2D(filters=64, kernel_size=(3, 3), padding='Same', activation='relu'))
    model.add(MaxPool2D(pool_size=(2, 2), strides=(2, 2)))
    model.add(Dropout(0.25))
    model.add(Flatten())
    model.add(Dense(256, activation="relu"))
    model.add(Dropout(0.5))
    # softmax (not sigmoid): categorical_crossentropy on one-hot labels
    # expects the two outputs to form a probability distribution.
    model.add(Dense(2, activation="softmax"))
    # The learning rate is passed positionally because its keyword is `lr` in
    # old Keras but `learning_rate` in newer releases; `decay=0.0` was the old
    # default and is rejected by newer optimizers, so it is omitted.
    optimizer = RMSprop(0.001, rho=0.9, epsilon=1e-08)
    model.compile(optimizer=optimizer, loss="categorical_crossentropy", metrics=["accuracy"])
    return model


def _evaluate(model, paths, batch_size):
    """Return the mean (loss, accuracy) of ``model`` over full batches of ``paths``.

    Uses ``test_on_batch`` (available in old and new Keras) instead of the
    deprecated ``evaluate_generator``; all batches have the same size, so
    the plain mean equals a per-sample average.  Returns ``(nan, nan)`` when
    ``paths`` holds fewer than one batch.
    """
    totals = np.array([0.0, 0.0])
    n_batches = 0
    for x, y in get_dataset_shuffle(paths, batch_size, False):
        totals += model.test_on_batch(x, y)
        n_batches += 1
    if n_batches == 0:
        return float('nan'), float('nan')
    means = totals / n_batches
    return means[0], means[1]


def train():
    """Train ``build_model()`` on the training images and save it to 'weight.h5'.

    10% of ``train_list`` is held out for validation (random_state=2).  Each
    step loads one batch of raw images, draws one augmented batch from
    ``train_datagen`` and runs ``train_on_batch``.  Mean training loss and
    accuracy are printed every 200 steps, validation metrics after every
    epoch.
    """
    model = build_model()
    train_X = [train_path + item for item in train_list]
    # Set the random seed
    random_seed = 2
    # Split the train and the validation set for the fitting
    train_X, val_X = train_test_split(train_X, test_size=0.1, random_state=random_seed)
    n_epoch = 20
    batch_size = 16
    log_every = 200
    for e in range(n_epoch):
        print('epoch', e)
        batch_num = 0
        loss_sum = np.array([0.0, 0.0])
        for X_train, y_train in get_dataset_shuffle(train_X, batch_size, True):
            # flow() never ends on its own; take exactly one augmented batch.
            X_batch, y_batch = next(train_datagen.flow(X_train, y_train, batch_size=batch_size))
            loss_sum += model.train_on_batch(X_batch, y_batch)
            batch_num += 1
            if batch_num % log_every == 0:
                print("epoch %s, batch %s: train_loss = %.4f, train_acc = %.4f" % (
                    e, batch_num, loss_sum[0] / log_every, loss_sum[1] / log_every))
                loss_sum = np.array([0.0, 0.0])
        val_loss, val_acc = _evaluate(model, val_X, batch_size)
        print("val_loss = %.4f, val_acc = %.4f: " % (val_loss, val_acc))
    model.save('weight.h5')


def test(batch_size=64):
    """Predict every test image with the saved model and write 'result1.csv'.

    Predictions are placed into ``sample_submission.csv`` by matching the
    numeric id in each file name (``123.jpg`` -> 123) against its ``id``
    column, because ``test_list`` comes from a directory listing whose order
    is not the submission's row order.  The written label is the arg-max
    class (0 = cat, 1 = dog, as assigned by ``get_img_label``).

    Args:
        batch_size: number of test images sent to ``model.predict`` at once.
    """
    model = load_model('weight.h5')
    # Only files named '<int>.<ext>' can be matched to a submission id.
    names = [name for name in test_list if name.split('.')[0].isdigit()]
    predictions = {}
    for start in range(0, len(names), batch_size):
        chunk = names[start:start + batch_size]
        X_test = np.array([load_batch_image(test_path + name, False) for name in chunk])
        labels = np.argmax(model.predict(X_test), axis=1)
        for name, label in zip(chunk, labels):
            predictions[int(name.split('.')[0])] = int(label)
    # The template sits next to the test directory; resolving it from
    # test_path keeps it consistent with where the images were read.
    submission_path = os.path.join(test_path, os.pardir, 'sample_submission.csv')
    test_df = pd.read_csv(submission_path)
    test_df['label'] = test_df['id'].map(predictions)
    test_df.to_csv('result1.csv', index=False)


if __name__ == '__main__':
    train()
    test()

Dogs vs. Cats相关推荐

  1. CNN入门+猫狗大战(Dogs vs. Cats)+PyTorch入门

    一些修改(修改后的代码) 修改原网络的输出方式.原网络采用的交叉熵torch.nn.CrossEntropyLoss()进行Loss计算,而这个函数内部是已经进行了softmax处理的(参考),所以网 ...

  2. PyTorch实战Kaggle之Dogs vs. Cats

    PyTorch实战Kaggle之Dogs vs. Cats 目录 1. 导包 2. 数据载入及装载 3. 数据预览 1)获取一个批次的数据 2)验证独热编码的对应关系 3)图片预览 4. 模型搭建 5 ...

  3. Dogs vs. Cats数据集

    DogsVsCats百度网盘下载 [大小]:813.56M [链接]:https://pan.baidu.com/s/1qXmQoLDfV2WnJMZ-Hxt0Ww [提取码]:n7zc kaggle ...

  4. kaggle之Dogs vs. Cats(Keras)

    数据宏观把握--->数据预处理--->导出特征向量--->载入特征向量--->构建模型--->训练模型--->预测测试集 一.数据宏观把握 训练集25000张,猫狗 ...

  5. Cats vs. Dogs(猫狗大战)数据集处理

    猫狗大战数据集 Cats vs. Dogs(猫狗大战)数据集下载地址为https://www.kaggle.com/c/dogs-vs-cats/data.这个数据集是Kaggle大数据竞赛某一年的一 ...

  6. 基于TensorFlow的Cats vs. Dogs(猫狗大战)实现和详解(1)

    2017.5.29 官方的MNIST例子里面训练数据的下载和导入都是用已经写好的脚本完成的,至于里面实现细节也没高兴去看源码,感觉写得太正式,我这个初学者不好理解.于是在优酷上找到了KevinRush ...

  7. 基于MXNet的Cats vs. Dogs(猫狗大战)实现和详解

    2019.12.8 更新完整代码 https://github.com/nickhuang1996/Dogs_vs_Cats_MXNet 具体的搭建和运行步骤可参看README.md 介绍 这个存储库 ...

  8. 每日英语——University Students Rush to Walk Stray Dogs

    University Students Rush to Walk Stray Dogs（高校"遛狗"社团火了！遛狗名额要靠抢） Rush：匆忙，冲进；Walk：行走，散步；Stray：流浪的 ...

  9. VGG16迁移学习实现

    VGG16迁移学习实现 本文讨论迁移学习,它是一个非常强大的深度学习技术,在不同领域有很多应用.动机很简单,可以打个比方来解释.假设想学习一种新的语言,比如西班牙语,那么从已经掌握的另一种语言(比如英 ...

最新文章

  1. 卷积神经网络的复杂度分析
  2. 【c语言】蓝桥杯算法提高 一元一次方程
  3. 字典生成_Python数据字典生成工具详解
  4. Hibernate中两种获取Session的方式
  5. Sql Server实用操作-SQL语句导入导出大全
  6. 组合恒等式2 五个基本的组合恒等式 更复杂的技巧与例题
  7. java元婴期(20)----java进阶(spring(4)---spring aop编程(全自动)AspectJ)
  8. 【转】phpize学习
  9. python类型和格式_json数据格式和python中字典的数据类型
  10. 4x4矩阵键盘工作原理及扫描程序_基于复杂可编程逻辑器件实现键盘接口电路的设计...
  11. Exchange server 2003迁移到2010无路由组连接器
  12. 500个爆文标题_美食爆文大放送 | 烹饪技巧从细节着手,夏日消暑美食最为应时...
  13. (转)比特币算法——SHA256算法介绍
  14. jumserver 官方文档和
  15. 中文版Origin图像数字化工具_GetData Graph Digitizer(图表数字化工具) V2.25 官方版
  16. Matlab非线性拟合函数——nlinfit
  17. PHP祝福语,日常祝福语
  18. element ui响应式布局笔记,适配笔记
  19. Google谷歌新手SEO优化教程篇【1】
  20. STM32F103_study46_The punctual atoms(STM32 The location of all interrupt service functions )

热门文章

  1. 钉钉windows端多开软件_Windows7系统便签怎么找?适合Windows系统的便签
  2. 周鸿祎力荐|纽约客16000字重磅刊文:区块链是回归互联网本质的唯一希望
  3. 简述电信运营商圈内的三大业务领域-B-M-O
  4. 极狐GitLab 和 ArgoCD 的集成实践
  5. linux can总线接收数据串口打包上传_「干货」手把手教你用Zedboard学习Linux移植和驱动开发...
  6. 适合写python的电脑_文言文的适是什么意思
  7. UGUI内核大探究(十六)InputField
  8. nvidia-smi常用选项汇总
  9. mac安装Texpad:提示无法打开,因为APPLE无法检查其是否包含恶意软件解决方案
  10. unicode转中文 C# (dotnetcore)