1.项目简介

1.数据集

cat_12数据集包含3个部分,训练集cat_12_train,测试集cat_test,以及存储图片名称及标签的train_list.txt

2.数据预处理

首先,定义一个prepare_image函数,取出文本文件中的图片路径与标签,并且打乱顺序

def prepare_image(file_path):X_train = []y_train = []with open(file_path) as f:context = f.readlines()random.shuffle(context)for str in context:str = str.strip('\n').split('\t')X_train.append('./cat_12/' + str[0])y_train.append(str[1])return X_train, y_train

再定义一个preprocess_image进行图片归一化操作,将像素值限制在0-1之间。

# 数据归一化
def preprocess_image(image):image = tf.io.read_file(image)image = tf.image.decode_jpeg(image, channels=3)image = tf.image.resize_with_pad(image, 227, 227)image = image / 255.0return image

三、主函数部分

导包:tqdm包用于打印进度条

from dataset import prepare_image, preprocess_image
import tensorflow as tf
from tqdm import tqdm
import numpy as np
from matplotlib import pyplot as plt

数据预处理:

X_train, y_train = prepare_image('./cat_12/train_list.txt')
train_images = []for i in tqdm(X_train):train_image = preprocess_image(i)train_images.append(train_image)from tensorflow import kerastrain_images = np.array(train_images)
print(train_images.shape)
y_train = keras.utils.to_categorical(y_train, 12)

定义LRN层:

class LRN(keras.layers.Layer):def __init__(self, depth_radius=5, bias=1, alpha=1, beta=0.5, **kwargs):super().__init__(**kwargs)self.depth_radius = depth_radiusself.bias = biasself.alpha = alphaself.beta = betadef call(self, input):return tf.nn.lrn(input, self.depth_radius, self.bias, self.alpha, self.beta)def get_config(self):base_config = super().get_config()return {**base_config, 'depth_radius': self.depth_radius,'bias': self.bias, 'alpha': self.alpha, 'beta': self.beta}

构建模型:

model = keras.Sequential()
# 第一层
model.add(keras.layers.Conv2D(filters=96, kernel_size=(11, 11), strides=(4, 4), input_shape=(227, 227, 3), padding="VALID",activation="relu"))
model.add(LRN())
model.add(keras.layers.MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding="same"))
# 第二层
model.add(keras.layers.Conv2D(filters=256, kernel_size=(5, 5), strides=(1, 1), padding="SAME", activation="relu"))
model.add(LRN())
model.add(keras.layers.MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding="same"))
# 第三层
model.add(keras.layers.Conv2D(filters=384, kernel_size=(5, 5), strides=(1, 1), padding="SAME", activation="relu"))
# 第四层
model.add(keras.layers.Conv2D(filters=384, kernel_size=(5, 5), strides=(1, 1), padding="SAME", activation="relu"))
# 第五层
model.add(keras.layers.Conv2D(filters=256, kernel_size=(5, 5), strides=(1, 1), padding="SAME", activation="relu"))
model.add(keras.layers.MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding="same"))
# 第六层
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(4096, activation="relu"))
model.add(keras.layers.Dropout(0.5))
# 第七层
model.add(keras.layers.Dense(4096, activation="relu"))
model.add(keras.layers.Dropout(0.5))
# 第八层
model.add(keras.layers.Dense(12, activation="softmax"))# keras.utils.plot_model(model=model, to_file='AlexNet.png', show_shapes=True)model.compile(loss="categorical_crossentropy", optimizer="sgd", metrics=["accuracy"])
history = model.fit(train_images, y_train, epochs=50, batch_size=16, validation_split=0.2)model.save('the_AlexNet_model.h5')

打印误差曲线:

def show_training_history(train_history, train, val):plt.plot(train_history[train], linestyle='-', color='b')plt.plot(train_history[val], linestyle='--', color='r')plt.xlabel('Epoch', fontsize=12)plt.ylabel('train', fontsize=12)plt.legend(['train', 'validation'], loc='lower right')plt.show()show_training_history(history.history, 'loss', 'val_loss')
show_training_history(history.history, 'acc', 'val_acc')

基于AlexNet网络的猫十二分类相关推荐

  1. 【五一创作】使用Resnet残差网络对图像进行分类(猫十二分类,模型定义、训练、保存、预测)(二)

    使用Resnet残差网络对图像进行分类 (猫十二分类,模型定义.训练.保存.预测)(二) 目录 (6).数据集划分 (7).训练集增强 (8).装载数据集 (9).初始化模型 (10).模型训练 (1 ...

  2. 【五一创作】使用Resnet残差网络对图像进行分类(猫十二分类,模型定义、训练、保存、预测)(一)

    使用Resnet残差网络对图像进行分类 (猫十二分类,模型定义.训练.保存.预测)(一) 目录 一.项目简介 二.环境说明 1.安装库 2.导入需要的库 三.分类过程 (1).解压数据集 (2).相关 ...

  3. 基于ResNet的猫十二分类

    在这次实战训练中,首先对下载的猫十二数据集进行预处理,使用了tensorflow构建resnet模型,在学习率调度上,使用了1周期调度,并且使用了动量优化和Nesterov加速梯度 1.导包 from ...

  4. pytorch 猫狗二分类 resnet

    深度学习(猫狗二分类) 题目要求 数据获取与预处理 网络模型 模型原理 Resnet背景 Resnet原理 代码实现 模型构建 训练过程 批验证过程 单一验证APP 运行结果 训练结果 批验证结果 A ...

  5. (论文加源码)基于DEAP和MABHOB数据集的二分类脑电情绪识别(pytorch深度神经网络(DNN)和卷积神经网络(CNN))

    该论文发表于2021年的顶级期刊.(pytorch框架) 代码解析部分在个人主页: https://blog.csdn.net/qq_45874683/article/details/13000797 ...

  6. 猫狗二分类与四种天气多分类

    提示:文章用于学习记录 文章目录 前言 一.猫狗图像分类 1.1 数据预处理 1.2 构建神经网络 二.四种天气图片数据分类(pytorch) 总结 前言 常见分类网络结构可以分为两部分,一部分是特征 ...

  7. RDKit | 基于随机森林的化合物活性二分类模型

    基于随机森林算法的化合物二分类机器学习模型 代码示例 #导入依赖包 import pandas as pd import numpy as np from rdkit import Chem, Dat ...

  8. ML之catboost:基于自带Pool数据集实现二分类预测

    ML之catboost:基于自带Pool数据集实现二分类预测 基于自带Pool数据集实现二分类预测 输出结果 Learning rate set to 0.5 0: learn: 0.9886498 ...

  9. Unity C# 网络学习(十二)——Protobuf生成协议

    Unity C# 网络学习(十二)--Protobuf生成协议 一.安装 去Protobuf官网下载对应操作系统的protoc,用于将.proto文件生成对应语言的协议语言文件 由于我使用的是C#所以 ...

最新文章

  1. PHP的Smarty
  2. mysql 重构同步老数据_MySQL 重构查询的方式
  3. ie9怎么开兼容模式
  4. 如何修改 SAP Spartacus CMS API 默认的 endpoint
  5. epoll原理_如果这篇文章说不清epoll的本质,那就过来掐死我吧! (1)
  6. 苹果Mac轻量级网页代码编辑器:​​​​​​​​​​​​Espresso
  7. PHP之mb_strrpos使用
  8. 拓端tecdat|虎扑社区论坛数据爬虫分析报告
  9. 蓝桥杯矩阵求和_刷蓝桥杯官网习题,准备蓝桥杯的小伙伴,一起来交流吧(✪ω✪)。(2月27日更新)...
  10. html毕业设计论文,静态网页HTML设计毕业设计论文
  11. 如何使用STVP_CmdLine.exe
  12. 分享200个App移动端模板
  13. 如何查看自己windows密钥
  14. linux安装键盘鼠标失灵,在archlinux安装界面这卡住了,鼠标键盘失灵
  15. 学习淘宝分享出来的链接web检测打开原生App
  16. PHP检测字数,PHP获取word文档字数的问题
  17. 基于C语言实现的足球信息查询系统 课程报告+项目源码+演示PPT+项目截图
  18. 三维激光扫描技术的应用领域有哪些?
  19. SpringBoot学习之路---简单记录整合SpringSecurity实现登录认证授权
  20. (重温)JavaWeb--Cookie 和 Session入门总结(了解cookie和session这一篇就够了)

热门文章

  1. [Mac软件推荐] paste - 好用的剪切板记录增强工具
  2. 关于STM32的裸机多任务多线程心得
  3. 服务器怎么增加独立显卡,dellr610服务器增加独立显卡(dell服务器装显卡)
  4. 记一次锐捷网络虚拟化(VSU)故障处理
  5. 鸿蒙5G多少钱一部手机,5G+鸿蒙,就是我下一部手机的标配,不接受反驳
  6. lisp 画伯努利双纽线_伯努利双纽线的应用有哪些?
  7. Android中的接口的使用举例
  8. 用Auto.js批量删除空间说说
  9. 【智能优化算法-正弦余弦算法】基于反向正弦余弦算法求解高维优化问题附matlab代码
  10. ArcEngine(五)用ICommand接口实现放大缩小