本文主要讲述如何使用keras来微调VGG16模型在kaggle的猫狗大战的数据集上实现迁移学习,精度达到了97%,在多训练几个epoch会更高,如果本文有错误的地方欢迎大家斧正,有什么问题也欢迎大家与我交流讨论。

一、对数据集进行预处理

首先在kaggle的官网下载猫狗的数据集https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition/data

读取图像数据存储到数组中,然后归一化到0-1之间,将图像大小统一resize为100x100

设置数据生成器,减小内存负担。

def data_generator(all_img_name, all_label, batch_size, h=100, w=100):"""该函数用于生成批量的数据,用于fit_generator批量训练:param all_train_name::param batch_size::return:"""batches = len(all_img_name) // batch_sizewhile True:for i in range(batches):name_batch = all_img_name[i * batch_size: (i + 1) * batch_size]label_batch = all_label[i * batch_size: (i + 1) * batch_size]# label 转化为one-hot编码Y = to_categorical(label_batch, num_classes=2)X = np.array([])for j in range(batch_size):img_path = name_batch[j]labels = label_batch# 读取imgimg = cv.imread(img_path)# resizeimg = cv.resize(img, (h, w))/255.0if len(X.shape) < 3:X = img[np.newaxis, :, :]else:X = np.concatenate((X, img[np.newaxis, :, :]), axis=0)yield (X, Y)

二、定义网络结构

在这里我们以VGG16的结构为基础,去掉其后面的全连接层,加上自己设计的3个全连接层,下图左图为我使用的网络结构右图为VGG16的网络结构,我修改了网络的输入大小,vgg16的大小为224x224我的大小为100x100

                   

keras封装了vgg模型的函数

model = keras.applications.vgg16.VGG16(include_top=True, weights='imagenet', input_tensor=None, input_shape=None, pooling=None, classes=1000)

你可以在下面这个地址下载VGG16的权重(如果你设置weights参数为imagenet的话他会自动下载)

https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5

参数
include_top:是否保留顶层的3个全连接网络
weights:None代表随机初始化,即不加载预训练权重。'imagenet'代表加载预训练权重,这里输入权重的地址,如果你填的是imagenet那么他会自动从github上下载VGG16的权重
input_tensor:可填入Keras tensor作为模型的图像输出tensor
input_shape:可选,仅当include_top=False有效,应为长为3的tuple,指明输入图片的shape,图片的宽高必须大于48,如(200,200,3)
返回值
pooling:当include_top=False时,该参数指定了池化方式。None代表不池化,最后一个卷积层的输出为4D张量。‘avg’代表全局平均池化,‘max’代表全局最大值池化。
classes:可选,图片分类的类别数,仅当include_top=True并且不加载预训练权重时可用。

def vgg_model(vgg_weights_path):# 定义模型base_model = VGG16(weights=vgg_weights_path,include_top=False, input_shape=(100, 100, 3))x = Flatten()(base_model.output)x = Dense(1024, activation="relu")(x)x = Dense(200, activation="relu")(x)y_pred = Dense(2, activation="softmax")(x)model = Model(inputs=base_model.input, outputs=y_pred)model.summary()return model

训练

使用sklearn.model_selection的train_test_split函数来划分训练集与测试集

train_X, test_X, train_Y, test_Y = train_test_split(x_name, y, test_size=0.2, random_state=0)

使用keras里面的fit_generator()函数来进行批量训练

history = model.fit_generator(generator=data_generator(train_X, train_Y, batch_size, h, w),steps_per_epoch=len(train_X) // batch_size,epochs=epoch, verbose=1, validation_data=data_generator(test_X, test_Y, h, w),validation_steps=len(test_X) // batch_size)

在这里我是使用GPU来进行训练的,一个epoch大概3-4分钟左右

预测

读取要预测的图片,然后通过模型,得到概率,使用np.argmax函数将概率转化为label,并显示图像

# 加载模型
model = load_model(model_path)
pred = model.predict(img_list, batch_size=4)
# print(pred)
for index, i in enumerate(pred):maxer = np.argmax(i)# 猫0狗1if maxer == 0:print(img_name_list[index], "is a cat!")img = cv.imread(os.path.join(image_path, img_name_list[index]))cv.imshow("is a cat", img)cv.waitKey(0)else:print(img_name_list[index], "is a dog!")img = cv.imread(os.path.join(image_path, img_name_list[index]))cv.imshow("is a dog", img)cv.waitKey(0)print("over!")

预测效果如下,我总共训练了4个epoch,在训练集上的精度达到了98%,在测试集上的精度达到了95%,稍微有点过拟合,加个dropout然后稍微调一下应该就好了。

我使用的环境及库版本:

python3.6(anaconda虚拟环境)

win10

emmm我觉得你只要保证tensorflow  和keras、opencv的版本没问题就可以了,当然你也不需要严格按照我的版本来,只要不报错就行(tf保持在12左右,keras保持在2.2左右,opencv保持在3.?左右),另外我的tf是gpu版的,用cpu版的也可以,不过你需要调一下自己的tf和keras版本。。。。(对新手来说配置gpu版的可能比较繁琐,建议用cpu版)

总结

keras真的是简洁方便,所有的代码都在我的github上---------》https://github.com/henryccl/dogvscat/tree/master

如何运行?

1.首先在训练前你需要下载数据集和VGG16的权重(下载地址在上面有提到)

2.配置参数,下面是我自己配的超参数,你如果想调的话可以自己改,不改也行,不过最重要的,还是要最下面的三行地址,一定要改成你自己的文件地址,不然会报错。

3.然后运行train.py文件,python train.py就可以开始训练啦!

如果你遇到问题,欢迎你和我讨论

迁移学习VGG16实现猫狗大战相关推荐

  1. 使用VGG迁移学习开启《猫狗大战挑战赛》

    文章目录 一.前言 二.加载数据集 三.数据预处理 四.构建VGG模型 五.训练VGG模型 六.保存与测试模型 七.总结 一.前言 猫狗大战挑战由Kaggle于2013年举办的,目前比赛已经结束,不过 ...

  2. 零基础实战迁移学习VGG16解决图像分类问题

    文章目录 1 前言 2 Transfer Learning 3 How to transfer? 4 代码实战:基于迁移学习对猫狗图片进行辨识 5 参考 1 前言 本文涉及到的代码均已开源,读者可自行 ...

  3. 【深度学习】一文看懂 (Transfer Learning)迁移学习(pytorch实现)

    前言 你会发现聪明人都喜欢"偷懒", 因为这样的偷懒能帮我们节省大量的时间, 提高效率. 还有一种偷懒是 "站在巨人的肩膀上". 不仅能看得更远, 还能看到更多 ...

  4. 迁移学习 Transfer Learning—通俗易懂地介绍(常见网络模型pytorch实现)

    前言 你会发现聪明人都喜欢"偷懒", 因为这样的偷懒能帮我们节省大量的时间, 提高效率. 还有一种偷懒是 "站在巨人的肩膀上". 不仅能看得更远, 还能看到更多 ...

  5. 基于tensorflow2.0实现猫狗大战(搭建网络迁移学习)

    猫狗大战是kaggle平台上的一个比赛,用于实现猫和狗的二分类问题.最近在学卷积神经网络,所以自己动手搭建了几层网络进行训练,然后利用迁移学习把别人训练好的模型直接应用于猫狗分类这个数据集,比较一下实 ...

  6. VGG16迁移学习实现

    VGG16迁移学习实现 本文讨论迁移学习,它是一个非常强大的深度学习技术,在不同领域有很多应用.动机很简单,可以打个比方来解释.假设想学习一种新的语言,比如西班牙语,那么从已经掌握的另一种语言(比如英 ...

  7. 手动搭建的VGG16网络结构训练数据和使用ResNet50微调(迁移学习)训练数据对比(图像预测+前端页面显示)

    文章目录 1.VGG16训练结果: 2.微调ResNet50之后的训练结果: 3.结果分析: 4.实验效果: (1)VGG16模型预测的结果: (2)在ResNet50微调之后预测的效果: 5.相关代 ...

  8. 1、VGG16 2、VGG19 3、ResNet50 4、Inception V3 5、Xception介绍——迁移学习

    ResNet, AlexNet, VGG, Inception: 理解各种各样的CNN架构 本文翻译自ResNet, AlexNet, VGG, Inception: Understanding va ...

  9. DL之VGG16:基于VGG16(Keras)利用Knifey-Spoony数据集对网络架构进行迁移学习

    DL之VGG16:基于VGG16(Keras)利用Knifey-Spoony数据集对网络架构迁移学习 目录 数据集 输出结果 设计思路 1.基模型 2.思路导图 核心代码 更多输出 数据集 Datas ...

  10. 【Pytorch实战6】一个完整的分类案例:迁移学习分类蚂蚁和蜜蜂(Res18,VGG16)

    参考资料: <深度学习之pytorch实战计算机视觉> Pytorch官方教程 Pytorch官方文档 本文是采用pytorch进行迁移学习的实战演练,实战目的是为了进一步学习和熟悉pyt ...

最新文章

  1. JSON http://www.cnblogs.com/haippy/archive/2012/05/20/2509329.html
  2. 如何优雅的导出Excel
  3. BOM字符(#8203;)转textNode对象
  4. 生成树生成森林c语言中文网,生成树协议(STP)基本知识及实验(使用eNSP)
  5. 360Stack裸金属服务器部署实践
  6. 小学三年级计算机基础知识课件,小学三年级信息技术基础知识ppt课件.ppt
  7. Python yaml模块
  8. 浏览器插件 如何方便查看md文件内容 markdown
  9. WEB安全之:密码穷举破解
  10. 软件测试工程师 岗位分析
  11. python蒙特卡洛方法圆周率_使用Python语言的蒙特卡洛方法计算圆周率π的一种实现...
  12. 开发谷歌浏览器插件会上瘾,搞了一个JSONViewer,一个页面格式化多条JSON,提升工作效率...
  13. 前后端对接及接口管理平台浅析
  14. explain用法和结果的含义
  15. 蓝桥杯2017年第八届C/C++ B组省赛习题题解
  16. 现在Python就业很难吗?百万程序员都在关心的问题
  17. 台式计算机32位和64位的区别,电脑系统32位和64位有哪些区别 32位和64位是什么意思 【详解】...
  18. 实现html表单下划线可输入/css实现input只显示下划线
  19. 如何获得高质量的外链
  20. 网站中加入站长流量统计代码

热门文章

  1. 证明:模n加法满足结合律
  2. ◎◎首都机场大巴最新路线时刻表◎◎
  3. 【资讯】1225- Flutter 2.10发布,稳定支持Windows
  4. ios uri正则表达式_众果搜的博客
  5. 数据预处理transforms
  6. 分布式GNN系统环境配置
  7. 【基于python的量化策略回测框架搭建】策略表现衡量指标模块
  8. python反爬中url之aes加密_python反爬之前端加密技术
  9. 【算法学习笔记】09.数据结构基础 二叉树初步练习2
  10. 运动耳机哪些好用?专业运动耳机购买指南