前言

之前很多人在,如何进行XXX的识别,对应的神经网络如何搭建。对应神经网络怎么搭建,我也是照本宣科,只能说看得懂而已,没有对这块进行深入的研究,但是现在tensorflow,paddle这些工具,都提供了非常成熟的神经网络进行直接使用。
本文对过往的一些文章进行改造,使用已经集成的神经网络,简单的实现多个种类的动物识别。

环境

tensorflow:2.9.1
keras:2.9.0
os:windows10
gpu:RTX3070
cuda:cuda_11.4.r11.4
如何安装tensorflow就不在做赘述,要重点说明 tensorflow与keras版本的不同会引起不同工具类的使用。

数据准备

链接: https://pan.baidu.com/s/1J7yRsTS2o0LcVkbKKJD-Bw 提取码: 6666
解压之后,结构如下

代码

一、模型训练代码(animalv2_model_train.py)

导入

import osimport plotly.express as px
import matplotlib.pyplot as plt
from IPython.display import clear_output as cls
import numpy as np
from glob import glob
import pandas as pd# Model
import keras
from keras.models import Sequential, load_model
from keras.layers import GlobalAvgPool2D as GAP, Dense, Dropout
from keras.preprocessing.image import ImageDataGenerator# Callbacks
from keras.callbacks import EarlyStopping, ModelCheckpoint# 模型与处理工具
import tensorflow as tf
from tensorflow.keras.applications import ResNet50V2
from tensorflow.keras.utils import load_img, img_to_array

数据集合处理

root_path = './animal/Animals_Classification/Animal-Data-V2/Data-V2/Training Data/'
valid_path = './animal/Animals_Classification/Animal-Data-V2/Data-V2/Validation Data/'
test_path = './animal/Animals_Classification/Animal-Data-V2/Data-V2/Testing Data/'
# 动物种类
class_names = sorted(os.listdir(root_path))
n_classes = len(class_names)print(f"Total Number of Classes : {n_classes} \nClass Names : {class_names}")class_dis = [len(os.listdir(root_path+name)) for name in class_names]fig = px.pie(names=class_names, values=class_dis, title="Training Class Distribution", hole=0.4)
fig.update_layout({'title':{'x':0.48}})
fig.show()fig = px.bar(x=class_names, y=class_dis, title="Training Class Distribution", color=class_names)
fig.update_layout({'title':{'x':0.48}})
fig.show()# 归一化
train_gen = ImageDataGenerator(rescale=1/255., rotation_range=10, horizontal_flip=True)
valid_gen = ImageDataGenerator(rescale=1/255.)
test_gen = ImageDataGenerator(rescale=1/255)# Load Data
train_ds = train_gen.flow_from_directory(root_path, class_mode='binary', target_size=(256,256), shuffle=True, batch_size=32)
valid_ds = valid_gen.flow_from_directory(valid_path, class_mode='binary', target_size=(256,256), shuffle=True, batch_size=32)
test_ds = test_gen.flow_from_directory(test_path, class_mode='binary', target_size=(256,256), shuffle=True, batch_size=32)

结果如下:

Total Number of Classes : 10
Class Names : ['Cat', 'Cow', 'Dog', 'Elephant', 'Gorilla', 'Hippo', 'Monkey', 'Panda', 'Tiger', 'Zebra']
Found 20000 images belonging to 10 classes.
Found 1000 images belonging to 10 classes.
Found 1907 images belonging to 10 classes.

图片展示

def show_images(GRID=[5, 5], model=None, size=(20, 20), data=train_ds):n_rows = GRID[0]n_cols = GRID[1]n_images = n_cols * n_rowsi = 1plt.figure(figsize=size)for images, labels in data:id = np.random.randint(len(images))image, label = images[id], class_names[int(labels[id])]plt.subplot(n_rows, n_cols, i)plt.imshow(image)if model is None:title = f"Class : {label}"else:pred = class_names[int(np.argmax(model.predict(image[np.newaxis, ...])))]title = f"Org : {label}, Pred : {pred}"cls()plt.title(title)plt.axis('off')i += 1if i >= (n_images + 1):breakplt.tight_layout()plt.show()def load_image(path):image = tf.cast(tf.image.resize(img_to_array(load_img(path))/255., (256,256)), tf.float32)return image
def show_image(image, title=None):plt.imshow(image)plt.axis('off')plt.title(title)show_images(data=train_ds)
show_images(data=valid_ds)
show_images(data=test_ds)path = './animal/Animals_Classification/Animal-Data-V2/Data-V2/Interesting Data/'
interesting_images = [glob(path + name + "/*") for name in class_names]# Interesting Cat Images
for name in class_names:plt.figure(figsize=(25, 8))cat_interesting = interesting_images[class_names.index(name)]for i, i_path in enumerate(cat_interesting):name = i_path.split("/")[-1].split(".")[0]image = load_image(i_path)plt.subplot(1,len(cat_interesting),i+1)show_image(image, title=name.title())plt.show()

模型训练

with tf.device("/GPU:0"):## 定义网络base_model = ResNet50V2(input_shape=(256,256,3), include_top=False)base_model.trainable = Falsecls()# 设计参数name = "ResNet50V2"model = Sequential([base_model,GAP(),Dense(256, activation='relu', kernel_initializer='he_normal'),Dropout(0.2),Dense(n_classes, activation='softmax')], name=name)# Callbacks# 容忍度为3,在容忍度之内就结束训练cbs = [EarlyStopping(patience=3, restore_best_weights=True), ModelCheckpoint(name + "_V2.h5", save_best_only=True)]# Modelopt = tf.keras.optimizers.Adam(learning_rate=2e-3)model.compile(loss='sparse_categorical_crossentropy', optimizer=opt, metrics=['accuracy'])# Model Traininghistory = model.fit(train_ds, validation_data=valid_ds, callbacks=cbs, epochs=50)data = pd.DataFrame(history.history)

模型训练

运行上面代码,我电脑的配置差不多需要1700+s(PS:可以换一下内存大一些的显卡比如 RTX40XX )
执行结果为如下:

2022-11-29 17:43:01.082836: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX AVX2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-11-29 17:43:01.449655: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 5472 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3070, pci bus id: 0000:01:00.0, compute capability: 8.6
Epoch 1/50
2022-11-29 17:43:18.284528: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8204
2022-11-29 17:43:21.378441: I tensorflow/stream_executor/cuda/cuda_blas.cc:1786] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
625/625 [==============================] - 292s 457ms/step - loss: 0.2227 - accuracy: 0.9361 - val_loss: 0.1201 - val_accuracy: 0.9630
Epoch 2/50
625/625 [==============================] - 217s 348ms/step - loss: 0.1348 - accuracy: 0.9596 - val_loss: 0.1394 - val_accuracy: 0.9610
Epoch 3/50
625/625 [==============================] - 218s 349ms/step - loss: 0.1193 - accuracy: 0.9641 - val_loss: 0.1452 - val_accuracy: 0.9620
Epoch 4/50
625/625 [==============================] - 219s 350ms/step - loss: 0.1035 - accuracy: 0.9690 - val_loss: 0.1147 - val_accuracy: 0.9690
Epoch 5/50
625/625 [==============================] - 221s 354ms/step - loss: 0.0897 - accuracy: 0.9736 - val_loss: 0.1117 - val_accuracy: 0.9730
Epoch 6/50
625/625 [==============================] - 219s 351ms/step - loss: 0.0817 - accuracy: 0.9747 - val_loss: 0.1347 - val_accuracy: 0.9640
Epoch 7/50
625/625 [==============================] - 219s 351ms/step - loss: 0.0818 - accuracy: 0.9740 - val_loss: 0.1126 - val_accuracy: 0.9700
Epoch 8/50
625/625 [==============================] - 219s 350ms/step - loss: 0.0731 - accuracy: 0.9785 - val_loss: 0.1366 - val_accuracy: 0.9680

验证模型

验证模型代码(animalv2_model_evaluate.py)

from keras.models import load_model
import tensorflow as tf
from tensorflow.keras.utils import load_img, img_to_array
import numpy as np
import osimport matplotlib.pyplot as pltroot_path = './animal/Animals_Classification/Animal-Data-V2/Data-V2/Training Data/'class_names = sorted(os.listdir(root_path))model = load_model('./ResNet50V2_V2.h5')
model.summary()def load_image(path):image = tf.cast(tf.image.resize(img_to_array(load_img(path))/255., (256,256)), tf.float32)return imagei_path = './animal/Animals_Classification/Animal-Data-V2/Data-V2/Validation Data/Gorilla/Gorilla (3).jpeg'
image = load_image(i_path)preds = model.predict(image[np.newaxis, ...])[0]print(preds)pred_class = class_names[np.argmax(preds)]confidence_score = np.round(preds[np.argmax(preds)], 2)# Configure Title
title = f"Pred : {pred_class}\nConfidence : {confidence_score:.2}"
print(title)plt.figure(figsize=(25, 8))
plt.title(title)
plt.imshow(image)
plt.show()while True:path =  input("input:")if (path == "q!"):exit()image = load_image(path)preds = model.predict(image[np.newaxis, ...])[0]print(preds)pred_class = class_names[np.argmax(preds)]confidence_score = np.round(preds[np.argmax(preds)], 2)# Configure Titletitle = f"Pred : {pred_class}\nConfidence : {confidence_score:.2}"print(title)plt.figure(figsize=(25, 8))plt.title(title)plt.imshow(image)plt.show()

验证结果

Model: "ResNet50V2"
_________________________________________________________________Layer (type)                Output Shape              Param #
=================================================================resnet50v2 (Functional)     (None, 8, 8, 2048)        23564800  global_average_pooling2d (G  (None, 2048)             0         lobalAveragePooling2D)                                          dense (Dense)               (None, 256)               524544    dropout (Dropout)           (None, 256)               0         dense_1 (Dense)             (None, 10)                2570      =================================================================
Total params: 24,091,914
Trainable params: 527,114
Non-trainable params: 23,564,800
_________________________________________________________________
2022-11-29 20:33:15.981925: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8204
2022-11-29 20:33:18.070138: I tensorflow/stream_executor/cuda/cuda_blas.cc:1786] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
1/1 [==============================] - 3s 3s/step
[1.2199847e-09 1.0668253e-12 6.8980124e-13 1.0352933e-08 9.9999988e-014.1255888e-09 7.1100374e-08 3.0439090e-10 3.1216061e-11 2.8051938e-12]
Pred : Gorilla
Confidence : 1.0

做了一个input的能力,可以通过本地的图片地址进行验证

input:./animal/Animals_Classification/Animal-Data-V2/Data-V2/Validation Data/Zebra/Zebra-Valid (276).jpeg
1/1 [==============================] - 0s 21ms/step
[1.5658158e-12 1.6018555e-10 9.6812911e-13 6.2212702e-10 5.4042397e-095.8055113e-05 4.7865592e-12 3.4024495e-12 3.0037000e-08 9.9994195e-01]
Pred : Zebra
Confidence : 1.0

基于tensorflow的ResNet50V2网络识别动物相关推荐

  1. 基于tensorflow、CNN网络识别花卉的种类(图像识别)

    基于tensorflow.CNN网络识别花卉的种类 这是一个图像识别项目,基于 tensorflow,现有的 CNN 网络可以识别四种花的种类.适合新手对使用 tensorflow进行一个完整的图像识 ...

  2. 猫狗大战——基于TensorFlow的猫狗识别(2)

    微信公众号:龙跃十二 我是小玉,一个平平无奇的小天才! 上篇文章我们说了关于猫狗大战这个项目的一些准备工作,接下来,我们看看具体的代码详解. 猫狗大战--基于TensorFlow的猫狗识别(1) 文件 ...

  3. 基于Tensorflow实现声纹识别

    前言 本章介绍如何使用Tensorflow实现简单的声纹识别模型,首先你需要熟悉音频分类,没有了解的可以查看这篇文章<基于Tensorflow实现声音分类>.基于这个知识基础之上,我们训练 ...

  4. python神经网络库识别验证码_基于TensorFlow 使用卷积神经网络识别字符型图片验证码...

    本项目使用卷积神经网络识别字符型图片验证码,其基于TensorFlow 框架.它封装了非常通用的校验.训练.验证.识别和调用 API,极大地减低了识别字符型验证码花费的时间和精力. 项目地址:http ...

  5. 基于TensorFlow的简单验证码识别

    TensorFlow 可以用来实现验证码识别的过程,这里识别的验证码是图形验证码,首先用标注好的数据来训练一个模型,然后再用模型来实现这个验证码的识别. 生成验证码 首先生成验证码,这里使用 Pyth ...

  6. 基于TensorFlow Lite的人声识别在端上的实现

    通过TensorFlow Lite,移动终端.IoT设备可以在端上实现声音识别,这可以应用在安防.医疗监护等领域.来自阿里巴巴闲鱼技术互动组仝辉和上叶通过TensorFlow Lite实现了一套完整的 ...

  7. 基于TensorFlow的手写体数字识别

    目录 一.MNIST数据集介绍 二.原理 2.1.卷积神经网络简介( convolutional neural network 简称CNN) 2.1.1卷积运算过程 2.1.2滑动的步长 2.1.3卷 ...

  8. 猫狗大战——基于TensorFlow的猫狗识别(1)

    微信公众号:龙跃十二 我是小玉,一个平平无奇的小天才! 简介: 关于猫狗识别是机器学习和深度学习的一个经典实例,下来小玉把自己做的基于CNN卷积神经网络利用Tensorflow框架进行猫狗的识别的程序 ...

  9. python opencv生成tf模型_基于TensorFlow+ Opencv 的人脸识别 和模型训练

    一.准备工作 本次实例的anaconda 环境 (有需要的自己导入anaconda) 链接:https://pan.baidu.com/s/1IVt2ap-NYdg64uHSh-viaA 提取码:g7 ...

最新文章

  1. MySQL面试题 | 附答案解析(十七)
  2. 直播:AI时代,普通程序员该如何转人工智能(限免报名)
  3. 基于 MATLAB 的 PCM 编码解码实现
  4. BaseExecutor.query()-清空本地缓存
  5. 《你必须知道的.NET》,前言
  6. MFC消息详解 (WindowProc|OnCommand|OnNotify)
  7. 惯用过程模型_惯用的Ruby:编写漂亮的代码
  8. 互联网日报 | 4月20日 星期二 | 华为正式宣布卖车;携程在港交所挂牌上市;广州期货交易所正式揭牌...
  9. QA专题阅读小组 | 每周一起读 #09
  10. 印度打车软件Ola将登陆伦敦,或将取代被吊销伦敦执照的Uber
  11. dispatcher在java中什么含义_java-我可以使用在DispatcherServlet上下文中声...
  12. linux多线程调度设置
  13. 首都师范 博弈论 5 4 3 多人合作博弈问题 Shapley值计算之股权与控股权
  14. 【微信小程序多人开发的配置流程】
  15. Laravel和Doctrine的测试驱动开发
  16. 暗金色 rgb_杜伽TAURUS K310樱桃RGB红轴体验:做工精良、手感优秀
  17. 读《软件测试经典教程》有感
  18. c++ 圆上任意点坐标计算_已知圆上任意三点坐标如何编程来计算这个圆的圆心和半径...
  19. js 获取汉字首字母和汉字转拼音
  20. Virtualbox centos虚拟机网络互联总结

热门文章

  1. Adobe Photoshop Lightroom Classic 中文版
  2. MATLAB中assignment模块,Simulink Matlab Function 模块使用问题求教
  3. dw如何制作图片自动切换效果_dw怎么用css做图片轮播
  4. @支付宝@微信支付,世界第一要来和你们抢生意了!
  5. idea使用常用基础设置
  6. C# 打开指定目录并定位到文件
  7. 蓝桥杯 基础练习 分解质因数 python语言
  8. Android Shape渐变色
  9. python注册用户名和密码登录_Python_36用户名密码登录注册的例子
  10. java没错泄露_记一次尴尬的Java应用内存泄露排查