在这次实战训练中,首先对下载的猫十二数据集进行预处理,使用了tensorflow构建resnet模型,在学习率调度上,使用了1周期调度,并且使用了动量优化和Nesterov加速梯度

1.导包

from tensorflow import keras
import tensorflow as tf
from keras.preprocessing import image
import random
from matplotlib import pyplot as plt
import cv2
from tqdm import tqdm
import numpy as np
import math

2.数据预处理

cat_12数据集包含3个部分,训练集cat_12_train,测试集cat_test,以及存储图片名称及标签的train_list.txt

(1)定义prepare_image函数从文件中分离路径和标签

def prepare_image(file_path):X_train = []y_train = []with open(file_path) as f:context = f.readlines()random.shuffle(context)for str in context:str = str.strip('\n').split('\t')X_train.append('./image/cat_12/' + str[0])y_train.append(str[1])return X_train, y_train

(2)定义preprocess_image函数进行图像的归一化

def preprocess_image(img):img = image.load_img(img, target_size=(224, 224))img = image.img_to_array(img)img = img / 255.0return img

(3)定义plot_image函数打印图像

def plot_image(images,classes):fig,axes = plt.subplots(4,3,figsize=(60,60),sharex=True)fig.subplots_adjust(hspace=0.3,wspace=0.3)for i,ax in enumerate(axes.flat):image = cv2.imread(images[i])image = cv2.resize(image,(224,224))ax.imshow(cv2.cvtColor(image,cv2.COLOR_BGR2RGB),cmap="hsv")ax.set_xlabel("Breed:{}".format(classes[i]))ax.xaxis.label.set_size(60)ax.set_yticks([])plt.show()

3.展示数据集

X_train, y_train = prepare_image('./image/cat_12/train_list.txt')
print(X_train)

x_array = np.array(X_train)
y_array = np.array(y_train)
y_unique = np.unique(y_array)
print(y_unique)
['0' '1' '10' '11' '2' '3' '4' '5' '6' '7' '8' '9']
imgs = []
classes = []
for i in y_unique:sort = x_array[y_array==i]idx = np.random.randint(len(sort)-1)imgs.append(sort[idx])classes.append(i)
print(imgs)
plot_image(imgs,classes)

4.准备数据集

train_images = []
for i in tqdm(X_train):train_image = preprocess_image(i)train_images.append(train_image)
train_images = np.array(train_images)
y_train = keras.utils.to_categorical(y_train, 12)
print(train_images.shape)

5. 基于tensorflow构建ResNet

# 构建卷积块
def conv_2d(x,filters,kernel_size,strides,padding="same"):x = keras.layers.Conv2D(filters,kernel_size=kernel_size,strides=strides,padding=padding)(x)x = keras.layers.BatchNormalization()(x)x = keras.activations.relu(x)return x
# 构建残差块
def resual_block(inputs,filters,strides):x = inputsx = conv_2d(x,filters=filters,kernel_size=1,strides=strides)x = conv_2d(x,filters=filters,kernel_size=3,strides=1)x = conv_2d(x,filters=4*filters,kernel_size=1,strides=1)x_short = conv_2d(inputs,filters=4*filters,kernel_size=1,strides=strides)x = keras.layers.Add()([x,x_short])x = keras.activations.relu(x)return x
# 构建resnet_152
def resnet(input_shape,n_classes=1000):x_input = keras.layers.Input(shape=input_shape)x = conv_2d(x_input,filters=64,kernel_size=7,strides=2)x = keras.layers.MaxPooling2D(pool_size=(3,3),strides=2,padding="same")(x)# input 64*3x = resual_block(x,64,strides=1)x = resual_block(x,64,strides=1)x = resual_block(x,64,strides=1)# input 128*8x = resual_block(x,128,strides=2)for i in range(7):x = resual_block(x,128,strides=1)# input 256*36x = resual_block(x,256,strides=2)for i in range(35):x = resual_block(x,256,strides=1)# input 512*3x = resual_block(x,512,strides=2)for i in range(2):x = resual_block(x,512,strides=1)# 全局平均池化x = keras.layers.GlobalAveragePooling2D()(x)output = keras.layers.Dense(n_classes,activation="softmax")(x)model = keras.models.Model(inputs=[x_input],outputs=[output])return model
model = resnet([224,224,3],12)

6.1周期调度

k = keras.backend
class One_Cycle(keras.callbacks.Callback):def __init__(self,interations,max_rate,min_rate=None,start_rate=None,last_interations=None):self.interations = interationsself.max_rate = max_rateself.min_rate = min_rate or self.max_rate/10000self.start_rate = start_rate or self.max_rate/100self.last_interations = last_interations or self.interations//10+1self.half_interations = (self.interations - self.last_interations)//2self.interation = 0self.loss = []self.learning_rate = []self.numbers = []def _interpolate(self,iter1,iter2,rate1,rate2):return ((rate2-rate1)*(self.interation-iter1)/(iter2-iter1)+rate1)def on_batch_begin(self,batch,logs=None):if self.interation < self.half_interations:rate = self._interpolate(0,self.half_interations,self.start_rate,self.max_rate)elif self.interation < 2*self.half_interations:rate = self._interpolate(self.half_interations,2*self.half_interations,self.max_rate,self.start_rate)else:rate = self._interpolate(2*self.half_interations,self.interations,self.max_rate,self.min_rate)self.interation += 1k.set_value(self.model.optimizer.learning_rate,rate)def on_batch_end(self,batch,logs=None):self.learning_rate.append(k.get_value(self.model.optimizer.learning_rate))self.loss.append(logs["loss"])self.numbers.append(self.interation)

7.训练模型

n_epochs = 100
one_cycle = One_Cycle(math.ceil(len(X_train)//16)*n_epochs,0.1,min_rate=1e-5)
optimizer = keras.optimizers.SGD(learning_rate=0.001,momentum=0.9,nesterov=True)
model.compile(loss="categorical_crossentropy",optimizer=optimizer,metrics=["accuracy"])
history = model.fit(train_images,y_train,epochs=n_epochs,validation_split=0.2,batch_size=16,callbacks=[one_cycle])

8.绘制损失变化情况

def show_training_history(train_history, train, val):plt.plot(train_history[train], linestyle='-', color='b')plt.plot(train_history[val], linestyle='--', color='r')plt.xlabel('Epoch', fontsize=12)plt.ylabel('train', fontsize=12)plt.legend(['train', 'validation'], loc='lower right')plt.show()

基于ResNet的猫十二分类相关推荐

  1. 【五一创作】使用Resnet残差网络对图像进行分类(猫十二分类,模型定义、训练、保存、预测)(二)

    使用Resnet残差网络对图像进行分类 (猫十二分类,模型定义.训练.保存.预测)(二) 目录 (6).数据集划分 (7).训练集增强 (8).装载数据集 (9).初始化模型 (10).模型训练 (1 ...

  2. 【五一创作】使用Resnet残差网络对图像进行分类(猫十二分类,模型定义、训练、保存、预测)(一)

    使用Resnet残差网络对图像进行分类 (猫十二分类,模型定义.训练.保存.预测)(一) 目录 一.项目简介 二.环境说明 1.安装库 2.导入需要的库 三.分类过程 (1).解压数据集 (2).相关 ...

  3. 基于AlexNet网络的猫十二分类

    1.项目简介 1.数据集 cat_12数据集包含3个部分,训练集cat_12_train,测试集cat_test,以及存储图片名称及标签的train_list.txt 2.数据预处理 首先,定义一个p ...

  4. pytorch 猫狗二分类 resnet

    深度学习(猫狗二分类) 题目要求 数据获取与预处理 网络模型 模型原理 Resnet背景 Resnet原理 代码实现 模型构建 训练过程 批验证过程 单一验证APP 运行结果 训练结果 批验证结果 A ...

  5. 基于Tensorflow的英文评论二分类CNN模型

    基于Tensorflow的英文评论二分类模型 前言 经过机器学习生成的模型,可以判断英语的肯定或否定含义,减轻了人的工作量,使得对大量意见进行归集,判断成为可能 ==>源代码Github下载 导 ...

  6. 猫狗二分类与四种天气多分类

    提示:文章用于学习记录 文章目录 前言 一.猫狗图像分类 1.1 数据预处理 1.2 构建神经网络 二.四种天气图片数据分类(pytorch) 总结 前言 常见分类网络结构可以分为两部分,一部分是特征 ...

  7. 猫狗二分类实战(PyTorch)

    PyTorch实战指南 文章目录 PyTorch实战指南 比赛介绍 文件组织架构 关于`__init__.py` 数据加载 模型定义 工具函数 配置文件 main.py 训练 验证 测试 帮助函数 使 ...

  8. 基于深度学习的简单二分类(招聘信息的真假)

    招聘数据真假分类 此次机器学习课程大作业-招聘数据真假分类,是一个二分类问题.训练集中共有14304个样本,每个样本有18个特征,目标是判断不含有标签的招聘信息的真假性. 利用Pandas读取训练集和 ...

  9. 【小白学习PyTorch教程】七、基于乳腺癌数据集​​构建Logistic 二分类模型

    「@Author:Runsen」 在逻辑回归中预测的目标变量不是连续的,而是离散的.可以应用逻辑回归的一个示例是电子邮件分类:标识为垃圾邮件或非垃圾邮件.图片分类.文字分类都属于这一类. 在这篇博客中 ...

最新文章

  1. 转mosquitto auth plugin 编译配置
  2. php 上传大文件涉及的配置,upload_max_filesize,post_max_size
  3. VC++基于APR实现禁止某个业务(开发行为控制软件用得着,例如上班禁止上QQ)...
  4. js函数提示 vscode_工欲善其事,必先利其器,VSCode高效插件
  5. [蓝桥杯][2019年第十届真题c/c++B组]完全二叉树的权值
  6. 华为eNSP静态路由原理与配置实例详解
  7. 递归下降算法语法分析c语言
  8. traceroute、tracert服务的工作原理
  9. 某次TPMC测试一直上不去的原因
  10. “闽南金三角”——世丰管道福建漳州高级水电工程师会议
  11. python编程基础-上海交通大学版答案
  12. Python dataframe绘制饼图_Python可视化29|matplotlib-饼图(pie)
  13. 拼多多关键搜索、商品列表接口、商品详情接口
  14. Android本地视频播放器开发--搜索本地视频(1)
  15. KMeans聚类 K值的确定以及初始类簇中心点的选取
  16. codeblocks-13.12mingw 配置opencv-3.1.0(一)
  17. 解决STC8串口2的txd发送脚无法使用的问题
  18. c语言笔试题目,C语言考试题库及答案2015.doc
  19. 【数据集】中国各类水文专业常用数据集合集
  20. P23 (**) Extract a given number of randomly selected elements from a list.

热门文章

  1. Java项目中使用Freemarker生成Word文档
  2. 三只松鼠3次方新品魅力何在?
  3. Express响应方法
  4. 【PS基础】-照片拼接基础
  5. 浅谈UEBA基本实现步骤
  6. 物联网iot私有云平台搭建
  7. Stochastic Weight Averaging (SWA) 随机权重平均
  8. AFX_VIRTUAL
  9. 木纹标识lisp_AutoLisp学习笔记:变量类型
  10. iOS-使用CoreLocation定位