文章目录

  • 1、导入模型

  • 2、定义加载函数

  • 3、定义批量加载函数

  • 4、加载数据

  • 5、定义数据预处理及训练模型的一些超参数

  • 6、定义数据增强模型

  • 7、构建模型

  • 7.1 构建多层感知器(MLP)

  • 7.2 创建一个类似卷积层的patch层

  • 7.3 查看由patch层随机生成的图像块

  • 7.4构建patch 编码层( encoding layer)

  • 7.5构建ViT模型

  • 8、编译、训练模型

  • 9、查看运行结果

使用Transformer来提升模型的性能
最近几年,Transformer体系结构已成为自然语言处理任务的实际标准,
但其在计算机视觉中的应用还受到限制。在视觉上,注意力要么与卷积网络结合使用,
要么用于替换卷积网络的某些组件,同时将其整体结构保持在适当的位置。2020年10月22日,谷歌人工智能研究院发表一篇题为“An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale”的文章。文章将图像切割成一个个图像块,组成序列化的数据输入Transformer执行图像分类任务。当对大量数据进行预训练并将其传输到多个中型或小型图像识别数据集(如ImageNet、CIFAR-100、VTAB等)时,与目前的卷积网络相比,Vision Transformer(ViT)获得了出色的结果,同时所需的计算资源也大大减少。
这里我们以ViT我模型,实现对数据CiFar10的分类工作,模型性能得到进一步的提升。

1、导入模型

import os
import math
import numpy as np
import pickle as p
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
from tensorflow.keras import layers
import tensorflow_addons as tfa
%matplotlib inline

这里使用了TensorFlow_addons模块,它实现了核心 TensorFlow 中未提供的新功能。
tensorflow_addons的安装要注意与tf的版本对应关系,请参考:
https://github.com/tensorflow/addons。
安装addons时要注意其版本与tensorflow版本的对应,具体关系以上这个链接有。

2、定义加载函数

def load_CIFAR_data(data_dir):"""load CIFAR data"""images_train=[]labels_train=[]for i in range(5):f=os.path.join(data_dir,'data_batch_%d' % (i+1))print('loading ',f)# 调用 load_CIFAR_batch( )获得批量的图像及其对应的标签image_batch,label_batch=load_CIFAR_batch(f)images_train.append(image_batch)labels_train.append(label_batch)Xtrain=np.concatenate(images_train)Ytrain=np.concatenate(labels_train)del image_batch ,label_batchXtest,Ytest=load_CIFAR_batch(os.path.join(data_dir,'test_batch'))print('finished loadding CIFAR-10 data')# 返回训练集的图像和标签,测试集的图像和标签
return (Xtrain,Ytrain),(Xtest,Ytest)

3、定义批量加载函数

def load_CIFAR_batch(filename):""" load single batch of cifar """  with open(filename, 'rb')as f:# 一个样本由标签和图像数据组成#  (3072=32x32x3)# ...# data_dict = p.load(f, encoding='bytes')images= data_dict[b'data']labels = data_dict[b'labels']# 把原始数据结构调整为: BCWHimages = images.reshape(10000, 3, 32, 32)# tensorflow处理图像数据的结构:BWHC# 把通道数据C移动到最后一个维度images = images.transpose (0,2,3,1)labels = np.array(labels)return images, labels

4、加载数据

data_dir = r'C:\Users\wumg\jupyter-ipynb\data\cifar-10-batches-py'
(x_train,y_train),(x_test,y_test) = load_CIFAR_data(data_dir)

把数据转换为dataset格式

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))

5、定义数据预处理及训练模型的一些超参数

num_classes = 10
input_shape = (32, 32, 3)learning_rate = 0.001
weight_decay = 0.0001
batch_size = 256
num_epochs = 10
image_size = 72  # We'll resize input images to this size
patch_size = 6  # Size of the patches to be extract from the input images
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
num_heads = 4
transformer_units = [projection_dim * 2,projection_dim,
]  # Size of the transformer layers
transformer_layers = 8
mlp_head_units = [2048, 1024]  # Size of the dense layers of the final classifier

6、定义数据增强模型

data_augmentation = keras.Sequential([layers.experimental.preprocessing.Normalization(),layers.experimental.preprocessing.Resizing(image_size, image_size),layers.experimental.preprocessing.RandomFlip("horizontal"),layers.experimental.preprocessing.RandomRotation(factor=0.02),layers.experimental.preprocessing.RandomZoom(height_factor=0.2, width_factor=0.2),],name="data_augmentation",
)
# 使预处理层的状态与正在传递的数据相匹配
#Compute the mean and the variance of the training data for normalization.
data_augmentation.layers[0].adapt(x_train)

预处理层是在模型训练开始之前计算其状态的层。他们在训练期间不会得到更新。大多数预处理层为状态计算实现了adapt()方法。
adapt(data, batch_size=None, steps=None, reset_state=True)该函数参数说明如下:

7、构建模型

7.1 构建多层感知器(MLP)

def mlp(x, hidden_units, dropout_rate):for units in hidden_units:x = layers.Dense(units, activation=tf.nn.gelu)(x)x = layers.Dropout(dropout_rate)(x)return x

7.2 创建一个类似卷积层的patch层

class Patches(layers.Layer):def __init__(self, patch_size):super(Patches, self).__init__()self.patch_size = patch_sizedef call(self, images):batch_size = tf.shape(images)[0]patches = tf.image.extract_patches(images=images,sizes=[1, self.patch_size, self.patch_size, 1],strides=[1, self.patch_size, self.patch_size, 1],rates=[1, 1, 1, 1],padding="VALID",)patch_dims = patches.shape[-1]patches = tf.reshape(patches, [batch_size, -1, patch_dims])return patches

7.3 查看由patch层随机生成的图像块

import matplotlib.pyplot as pltplt.figure(figsize=(4, 4))
image = x_train[np.random.choice(range(x_train.shape[0]))]
plt.imshow(image.astype("uint8"))
plt.axis("off")resized_image = tf.image.resize(tf.convert_to_tensor([image]), size=(image_size, image_size)
)
patches = Patches(patch_size)(resized_image)
print(f"Image size: {image_size} X {image_size}")
print(f"Patch size: {patch_size} X {patch_size}")
print(f"Patches per image: {patches.shape[1]}")
print(f"Elements per patch: {patches.shape[-1]}")n = int(np.sqrt(patches.shape[1]))
plt.figure(figsize=(4, 4))
for i, patch in enumerate(patches[0]):ax = plt.subplot(n, n, i + 1)patch_img = tf.reshape(patch, (patch_size, patch_size, 3))plt.imshow(patch_img.numpy().astype("uint8"))plt.axis("off")

运行结果
Image size: 72 X 72
Patch size: 6 X 6
Patches per image: 144
Elements per patch: 108

7.4构建patch 编码层( encoding layer)

class PatchEncoder(layers.Layer):def __init__(self, num_patches, projection_dim):super(PatchEncoder, self).__init__()self.num_patches = num_patches#一个全连接层,其输出维度为projection_dim,没有指明激活函数self.projection = layers.Dense(units=projection_dim)#定义一个嵌入层,这是一个可学习的层#输入维度为num_patches,输出维度为projection_dimself.position_embedding = layers.Embedding(input_dim=num_patches, output_dim=projection_dim)def call(self, patch):positions = tf.range(start=0, limit=self.num_patches, delta=1)encoded = self.projection(patch) + self.position_embedding(positions)return encoded

7.5构建ViT模型

def create_vit_classifier():inputs = layers.Input(shape=input_shape)# Augment data.augmented = data_augmentation(inputs)#augmented = augmented_train_batches(inputs)    # Create patches.patches = Patches(patch_size)(augmented)# Encode patches.encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)# Create multiple layers of the Transformer block.for _ in range(transformer_layers):# Layer normalization 1.x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)# Create a multi-head attention layer.attention_output = layers.MultiHeadAttention(num_heads=num_heads, key_dim=projection_dim, dropout=0.1)(x1, x1)# Skip connection 1.x2 = layers.Add()([attention_output, encoded_patches])# Layer normalization 2.x3 = layers.LayerNormalization(epsilon=1e-6)(x2)# MLP.x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)# Skip connection 2.encoded_patches = layers.Add()([x3, x2])# Create a [batch_size, projection_dim] tensor.representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)representation = layers.Flatten()(representation)representation = layers.Dropout(0.5)(representation)# Add MLP.features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)# Classify outputs.logits = layers.Dense(num_classes)(features)# Create the Keras model.model = keras.Model(inputs=inputs, outputs=logits)
return model

该模型的处理流程如下图所示

8、编译、训练模型

def run_experiment(model):optimizer = tfa.optimizers.AdamW(learning_rate=learning_rate, weight_decay=weight_decay)model.compile(optimizer=optimizer,loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=[keras.metrics.SparseCategoricalAccuracy(name="accuracy"),keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),],)#checkpoint_filepath = r".\tmp\checkpoint"checkpoint_filepath ="model_bak.hdf5"checkpoint_callback = keras.callbacks.ModelCheckpoint(checkpoint_filepath,monitor="val_accuracy",save_best_only=True,save_weights_only=True,)history = model.fit(x=x_train,y=y_train,batch_size=batch_size,epochs=num_epochs,validation_split=0.1,callbacks=[checkpoint_callback],)model.load_weights(checkpoint_filepath)_, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)print(f"Test accuracy: {round(accuracy * 100, 2)}%")print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")return history

实例化类,运行模型

vit_classifier = create_vit_classifier()
history = run_experiment(vit_classifier)

运行结果
Epoch 1/10
176/176 [==============================] - 68s 333ms/step - loss: 2.6394 - accuracy: 0.2501 - top-5-accuracy: 0.7377 - val_loss: 1.5331 - val_accuracy: 0.4580 - val_top-5-accuracy: 0.9092
Epoch 2/10
176/176 [==============================] - 58s 327ms/step - loss: 1.6359 - accuracy: 0.4150 - top-5-accuracy: 0.8821 - val_loss: 1.2714 - val_accuracy: 0.5348 - val_top-5-accuracy: 0.9464
Epoch 3/10
176/176 [==============================] - 58s 328ms/step - loss: 1.4332 - accuracy: 0.4839 - top-5-accuracy: 0.9210 - val_loss: 1.1633 - val_accuracy: 0.5806 - val_top-5-accuracy: 0.9616
Epoch 4/10
176/176 [==============================] - 58s 329ms/step - loss: 1.3253 - accuracy: 0.5280 - top-5-accuracy: 0.9349 - val_loss: 1.1010 - val_accuracy: 0.6112 - val_top-5-accuracy: 0.9572
Epoch 5/10
176/176 [==============================] - 58s 330ms/step - loss: 1.2380 - accuracy: 0.5626 - top-5-accuracy: 0.9411 - val_loss: 1.0212 - val_accuracy: 0.6400 - val_top-5-accuracy: 0.9690
Epoch 6/10
176/176 [==============================] - 58s 330ms/step - loss: 1.1486 - accuracy: 0.5945 - top-5-accuracy: 0.9520 - val_loss: 0.9698 - val_accuracy: 0.6602 - val_top-5-accuracy: 0.9718
Epoch 7/10
176/176 [==============================] - 58s 330ms/step - loss: 1.1208 - accuracy: 0.6060 - top-5-accuracy: 0.9558 - val_loss: 0.9215 - val_accuracy: 0.6724 - val_top-5-accuracy: 0.9790
Epoch 8/10
176/176 [==============================] - 58s 330ms/step - loss: 1.0643 - accuracy: 0.6248 - top-5-accuracy: 0.9621 - val_loss: 0.8709 - val_accuracy: 0.6944 - val_top-5-accuracy: 0.9768
Epoch 9/10
176/176 [==============================] - 58s 330ms/step - loss: 1.0119 - accuracy: 0.6446 - top-5-accuracy: 0.9640 - val_loss: 0.8290 - val_accuracy: 0.7142 - val_top-5-accuracy: 0.9784
Epoch 10/10
176/176 [==============================] - 58s 330ms/step - loss: 0.9740 - accuracy: 0.6615 - top-5-accuracy: 0.9666 - val_loss: 0.8175 - val_accuracy: 0.7096 - val_top-5-accuracy: 0.9806
313/313 [==============================] - 9s 27ms/step - loss: 0.8514 - accuracy: 0.7032 - top-5-accuracy: 0.9773
Test accuracy: 70.32%
Test top 5 accuracy: 97.73%
In [15]:
从结果看可以来看,测试精度已达70%,这是一个较大提升!

9、查看运行结果

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss =history.history['val_loss']plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()),1.1])
plt.title('Training and Validation Accuracy')plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([-0.1,4.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

运行结果

作者 :吴茂贵,资深大数据和人工智能技术专家,在BI、数据挖掘与分析、数据仓库、机器学习等领域工作超过20年!在基于Spark、TensorFlow、Pytorch、Keras等机器学习和深度学习方面有大量的工程实践经验。代表作有《深入浅出Embedding:原理解析与应用实践》、《Python深度学习基于Pytorch》和《Python深度学习基于TensorFlow》。

——The  End——


往期精彩回顾适合初学者入门人工智能的路线及资料下载机器学习及深度学习笔记等资料打印机器学习在线手册深度学习笔记专辑《统计学习方法》的代码复现专辑
AI基础下载机器学习的数学基础专辑黄海广老师《机器学习课程》视频课
本站qq群851320808,加入微信群请扫码:

【深度学习】使用transformer进行图像分类相关推荐

  1. 【深度学习】Transformer在语义分割上的应用探索

    [深度学习]Transformer在语义分割上的应用探索 文章目录 1 Segmenter 2 Swin-Unet:Unet形状的纯Transformer的医学图像分割 3 复旦大学提出SETR:基于 ...

  2. 【Android,Kotlin,TFLite】移动设备集成深度学习轻模型TFlite(图像分类篇)

    深度学习.Tensorflow.TFLite.移动设备集成深度学习轻模型TFlite.图像分类篇 Why i create it? 为了创建一个易用且易于集成的TFlite加载库, 所以TFLiteL ...

  3. 深度学习经典网络解析图像分类篇(二):AlexNet

    深度学习经典网络解析图像分类篇(二):AlexNet 1.背景介绍 2.ImageNet 3.AlexNet 3.1AlexNet简介 3.2AlexNet网络架构 3.2.1第一层(CONV1) 3 ...

  4. 【深度学习前沿应用】图像分类Fine-Tuning

    [深度学习前沿应用]图像分类Fine-Tuning 作者简介:在校大学生一枚,华为云享专家,阿里云星级博主,腾云先锋(TDP)成员,云曦智划项目总负责人,全国高等学校计算机教学与产业实践资源建设专家委 ...

  5. 【深度学习】Transformer 向轻量型迈进!微软与中科院提出两路并行的 Mobile-Former...

    作者丨happy 编辑丨极市平台 导读 本文创造性的将MobileNet与Transformer进行了两路并行设计,穿插着全局与特征的双向融合,同时利用卷积与Transformer两者的优势达到&qu ...

  6. 【机器学习PAI实践十】深度学习Caffe框架实现图像分类的模型训练

    背景 我们在之前的文章中介绍过如何通过PAI内置的TensorFlow框架实验基于Cifar10的图像分类,文章链接:https://yq.aliyun.com/articles/72841.使用Te ...

  7. 【深度学习】Transformer长大了,它的兄弟姐妹们呢?(含Transformers超细节知识点)...

    最近复旦放出了一篇各种Transformer的变体的综述(重心放在对Transformer结构(模块级别和架构级别)改良模型的介绍),打算在空闲时间把这篇文章梳理一下: 知乎:https://zhua ...

  8. 【深度学习】transformer 真的快要取代计算机视觉中的 CNN 吗?

    我相信你肯定已经在自然语言领域中听说过 transformer 这种结构,因为它在 2020 年的 GPT3 上引起了巨大轰动.Transformer 不仅仅可以用于NLP,在许多其他领域表现依然非常 ...

  9. 基于深度学习模型的花卉图像分类代码_华为不止有鸿蒙!教你快速入门华为免编程深度学习神器ModelArts...

    引言: 本文介绍利用华为ModelArts进行深度学习的图像分类任务,不用一行代码. 今年8月9日,在华为史上规模最大的开发者大会上,华为正式发布全球首个基于微内核的全场景分布式OS--鸿蒙操作系统( ...

  10. 【深度学习】Transformer温故知新

    这是之前学习paddle时候的笔记,对Transformer框架进行了拆解,附图解和代码,希望对大家有帮助  写在前面 最近在学习paddle相关内容,质量比较高的参考资料好像就paddle官方文档[ ...

最新文章

  1. 个人开发者做一款Android App需要知道的事情
  2. 【攻防世界004】dmd-50
  3. php pdo连接oracle乱码,php pdo oracle中文乱码的快速解决方法
  4. 【渝粤教育】 国家开放大学2020年春季 2071美学与美育 参考试题
  5. C++11中的右值引用
  6. char 转wchar_t 及wchar_t转char
  7. python中elif和while简单介绍及注意事项(含笔记)
  8. 【Java从0到架构师】日志处理 - Log4j 1.x、JCL
  9. Posta:跨文档信息安全搜索工具
  10. 如何在学习中找到乐趣?怎样才能找到学习的乐趣
  11. 1092: 地头蛇PIPI
  12. subclass and extends
  13. Dremel和Hadoop
  14. 批量保存西瓜无水印视频的方法步骤
  15. Redis源码篇(1)——底层数据结构与对象
  16. velocity 将字符串切割按每隔3位加逗号,map集合遍历,字符串拼接,
  17. 软件、Chrome字体细到模糊发虚解决方案
  18. 数学在计算机图形学中的应用
  19. elasticsearch:使用top_hits聚合获取分组列表
  20. c语言编程题库this is a c program,C语言末复习题编程题部.doc

热门文章

  1. HTML5新增的标签
  2. 一位996、CRUD开发者的一天
  3. 手机和邮箱的正则表达式
  4. 列表(list)之一定义 添加 删除 排序 反转 索引等其他操作
  5. 安卓自动化测试——rf
  6. mycat 编辑schema.xml
  7. Python学习笔记 setdict
  8. 3.Android 优化布局(解决TextView布局)
  9. Android Studio-目录结构
  10. 触发器初接触-同步两个表的指定字段