配置

tensorflow2.4.0
python3.6
猫狗大战数据集

代码

VGG16网络很著名,这里不再介绍。
keras里有预训练好的VGG16,tensorflow2.0以后的版本中已经集成了keras。
解释在代码中。

import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras import applications
from tensorflow.keras.layers import Dropout, Flatten, Dense
from tensorflow.keras.optimizers import SGD
import pickle
import numpy as np# 开启GPU加速
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:tf.config.experimental.set_memory_growth(gpu, True)OUT_CATEGORIES = 2  # 分类数
batch_size = 2  # 批量大小
epochs = 50  # 迭代次数
imgSize = 256
def model():img_shape = (imgSize, imgSize, 3)# 加载不包含全连接层的VGG16网络base_model = applications.VGG16(weights='imagenet', include_top=False, input_shape=img_shape)base_model.summary()# 根据分类数目增加自定义的全连接层,并与VGG16连接model_out = Sequential()model_out.add(Flatten(input_shape=base_model.output_shape[1:]))model_out.add(Dense(256, activation='relu'))model_out.add(Dropout(0.5))model_out.add(Dense(OUT_CATEGORIES, activation='sigmoid'))model = Model(inputs=base_model.input, outputs=model_out(base_model.output))model.compile(loss='binary_crossentropy', optimizer=SGD(lr=0.0001, momentum=0.9),metrics=['accuracy'])  # 损失函数为二进制交叉熵,优化器为SGDreturn modelpickle_in = open("x.pickle", "rb")
x = pickle.load(pickle_in)pickle_in = open("y.pickle", "rb")
y = pickle.load(pickle_in)
y = np.array(y)
# 数据集分割为训练集和测试集
train_num = int(x.shape[0] * 0.7)
test_num = x.shape[0] - train_num
# x归一化,制作数据集时没有归一化
x = x/255
# 打乱
state = np.random.get_state()
np.random.shuffle(x)
np.random.set_state(state)
np.random.shuffle(y)train_x = x[0:train_num, :, :, :]
test_x = x[train_num:train_num+test_num, :, :, :]
train_y = y[0:train_num, :]
test_y = y[train_num:train_num+test_num, :]
# 将label转换
train_label = keras.utils.to_categorical(train_y, OUT_CATEGORIES)
test_label = keras.utils.to_categorical(test_y, OUT_CATEGORIES)model = model()
model.fit(train_x, train_label, batch_size=batch_size, epochs=epochs, validation_data=(test_x, test_label), shuffle=True)
model.save("catDog.h5")

使用VGG16网络结构训练自己的图像分类模型相关推荐

  1. 手动搭建的VGG16网络结构训练数据和使用ResNet50微调(迁移学习)训练数据对比(图像预测+前端页面显示)

    文章目录 1.VGG16训练结果: 2.微调ResNet50之后的训练结果: 3.结果分析: 4.实验效果: (1)VGG16模型预测的结果: (2)在ResNet50微调之后预测的效果: 5.相关代 ...

  2. 七、图像分类模型的部署(Datawhale组队学习)

    文章目录 前言 ONNX简介 应用场景 部署ImageNet预训练图像分类模型 导出ONNX模型 推理引擎ONNX Runtime部署-预测单张图像 前期准备 ONNX Runtime预测 推理引擎O ...

  3. 【神经网络与深度学习】CIFAR10数据集介绍,并使用卷积神经网络训练图像分类模型——[附完整训练代码]

    [神经网络与深度学习]CIFAR-10数据集介绍,并使用卷积神经网络训练模型--[附完整代码] 一.CIFAR-10数据集介绍 1.1 CIFAR-10数据集的内容 1.2 CIFAR-10数据集的结 ...

  4. Keras少量样本训练强大图像分类模型

    原文:Building powerful image classification models using very little data 作者:Francois Chollet,2016.6.2 ...

  5. ML.NET 示例:图像分类模型训练-首选API(基于原生TensorFlow迁移学习)

    ML.NET 版本 API 类型 状态 应用程序类型 数据类型 场景 机器学习任务 算法 Microsoft.ML 1.5.0 动态API 最新 控制台应用程序和Web应用程序 图片文件 图像分类 基 ...

  6. (三)mmclassification图像分类——模型训练

    (三)mmclassification图像分类--模型训练和测试 1.模型训练 1.1使用预训练模型 1.2使用自己的数据训练 1.2.1制作数据集 1.2.2修改模型参数(configs文件) (1 ...

  7. 基于paddlex图像分类模型训练(一):图像分类数据集切分:文件夹转化为imagenet训练格式

    相关博文 基于paddlex图像分类模型训练(二):训练自己的分类模型.熟悉官方demo 背景 在使用paddlex GUI训练图像分类时,内部自动对导入的分类文件夹进行细分,本文主要介绍其图像分类数 ...

  8. 使用pytorch训练你自己的图像分类模型(包括模型训练、推理预测、误差分析)

    开源代码:https://github.com/xxcheng0708/Pytorch_Image_Classifier_Template​​​​​ 使用pytorch框架搭建一个图像分类模型通常包含 ...

  9. Windows下Caffe的学习与应用(三)——使用OpenCV3调用自己训练好的Caffe模型进行图像分类

    前言 前面的博文中,我试了如何使用caffe训练得到想要的模型与其如何使用别人成熟的模型微调优化自己训练的模型,那么得到训练好的模型之后如何在自己的项目中呢,我这里使用opencv的DNN模块调用ca ...

  10. 【图像分类】如何使用 mmclassification 训练自己的分类模型

    文章目录 一.数据准备 二.模型修改 三.模型训练 四.模型效果可视化 五.如何分别计算每个类别的精确率和召回率 MMclassification 是一个分类工具库,这篇文章是简单记录一下如何用该工具 ...

最新文章

  1. 5弹出阴影遮罩_千文详述Cocos Creator弹出式对话框实现技术,着实硬核
  2. 【论文笔记】CNN for NLP
  3. MySql提示服务已经启动成功但又提示can’t connect to MySQL server解决方法,mysql服务自动停止处理方法
  4. android socket 服务,android 创建socket 通信型service
  5. 在Android应用外获取app的签名
  6. HDU - 4552 怪盗基德的挑战书(后缀数组+RMQ/KMP+dp)
  7. Diagram of Interpositioning and Default Global Scope插入和默认全局范围图
  8. Domino学习笔记之邮件发送程序
  9. 重温数据结构——(1)
  10. 排序算法(二)Sort with Swap(0,*)
  11. Noip模拟题解题报告
  12. Elastic Search 查询语法大全
  13. 请求发送者与接收者解耦——命令模式(五)
  14. java十字链表存储,图的十字链表存储结构
  15. 学习日记| javaScript在网页绘制国际象棋盘
  16. 因特尔显卡自定义分辨率_为什么从最新的英特尔显卡控制面板中移除了自定义分辨率?...
  17. idea运行javaweb项目出现“该网页无法正常运作”
  18. [AHK]按住左键可以移动鼠标下的窗口
  19. vs修改项目属性无效的原因
  20. 某人将1000元存入银行 某公司需用一台设备 某企业为了建一项目 建设期3年,共贷款700万元

热门文章

  1. Windows小技巧 – Win+R提高Windows使用效率
  2. 多种方式Map集合遍历
  3. Activiti学习记录 Activiti初始化数据库、Activiti6增加表注释字段注释
  4. 【下载一】NI 系列软件卸载工具
  5. 阿里云长视频上传以及返回播放地址
  6. Tomcat日志显示乱码问题
  7. SQL连接两张或多张表
  8. SqlServer无法连接服务器
  9. 资深制作人谈游戏策划如何入行
  10. 《人工智能:一种现代的方法》笔记(一)