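ResNet Residual Network: A Usage Example
ResNet networks
The deep residual network (ResNet) was a milestone in the history of CNN-based image recognition, and it performs strongly across many datasets. ResNet is short for Residual Network. The family is widely used for image classification and as the backbone of many computer vision models; typical variants include ResNet-50 and ResNet-101. ResNet demonstrated that networks can be made much deeper (with many more hidden layers) than was previously practical.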
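This post is adapted from the TensorFlow-Fully-Convolutional-Image-Classification tutorial on LearnOpenCV and requires TensorFlow 2.1 or later. Where the original tutorial downloads the pretrained backbone weights at runtime, this version loads them from a local file instead.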
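Using the pretrained model
The weights file resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5 can be downloaded directly from https://storage.googleapis.com/tensorflow/keras-applications/resnet/ . Put it in the working directory, because the script loads it by its bare file name.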
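Loading the file from disk skips the integrity check that the commented-out `get_file(..., file_hash=...)` call in the script would have done. Below is a minimal stdlib-only sketch for checking a manually downloaded file against the hash stored in `WEIGHTS_HASHES` in the helper module. The helper names `file_md5` and `weights_file_ok` are hypothetical, and treating the 32-character hash as MD5 is an assumption based on how Keras's `get_file` picks an algorithm for hashes of that length.

```python
import hashlib
from pathlib import Path

# Assumed MD5 of the no-top ResNet50 weights, copied from WEIGHTS_HASHES in the helper module
EXPECTED_MD5 = "4d473c1dd8becc155b73f8504c6f6626"


def file_md5(path, chunk_size=1 << 20):
    """Compute the MD5 hex digest of a file, reading it in 1 MiB chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def weights_file_ok(path, expected=EXPECTED_MD5):
    """Return True only if the file exists and its MD5 matches the expected hash."""
    p = Path(path)
    return p.is_file() and file_md5(p) == expected


if __name__ == "__main__":
    # Prints False if the file is missing or corrupted
    print(weights_file_ok("resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5"))
```

Running it once after downloading catches a truncated or corrupted file before `model.load_weights` fails with a less obvious error.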
Downloading the model directly
Results
An Arabian camel (dromedary) is detected
A tiger shark is detected
Other notes
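In convolutions, zero padding is usually handled with padding="SAME", but that is not always flexible enough. When you want to control the zero padding yourself, use tf.keras.layers.ZeroPadding2D.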
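The difference is easiest to see by working out the shapes. A NumPy-only sketch, assuming NHWC (channels_last) input as in the script: `zero_pad_nhwc` mimics what `ZeroPadding2D(padding=((top, bottom), (left, right)))` does, and `same_padding` follows TensorFlow's rule for how many zeros padding="SAME" adds on each side. All three helper names are hypothetical.

```python
import math

import numpy as np


def zero_pad_nhwc(x, top, bottom, left, right):
    """Zero-pad only the height and width axes of an NHWC array."""
    return np.pad(x, ((0, 0), (top, bottom), (left, right), (0, 0)))


def valid_out(size, kernel, stride):
    """Output length along one axis for an unpadded ('valid') conv or pooling."""
    return (size - kernel) // stride + 1


def same_padding(size, kernel, stride):
    """(before, after) zeros that padding='SAME' adds along one axis;
    when the total is odd, the extra zero goes after (bottom/right)."""
    out = math.ceil(size / stride)
    total = max((out - 1) * stride + kernel - size, 0)
    return total // 2, total - total // 2


x = np.ones((1, 224, 224, 3), dtype=np.float32)

# conv1 in the script: symmetric padding of 3, then a 7x7 stride-2 'valid' conv
padded = zero_pad_nhwc(x, 3, 3, 3, 3)
print(padded.shape)              # (1, 230, 230, 3)
print(valid_out(230, 7, 2))      # 112

# padding='SAME' also yields 112, but pads 2 before and 3 after
print(same_padding(224, 7, 2))   # (2, 3)

# pool1 in the script: pad 1 on each side, then 3x3 max pooling with stride 2
print(valid_out(112 + 2, 3, 2))  # 56
```

Both routes give a 112×112 output, but the SAME windows are shifted by one pixel relative to the explicitly padded ones, so explicit padding is how you reproduce a network's exact padding scheme.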
```python
import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications.resnet import preprocess_input
from tensorflow.keras.layers import (
    Activation,
    AveragePooling2D,
    BatchNormalization,
    Conv2D,
    MaxPooling2D,
    ZeroPadding2D,
)

from utils import (
    BASE_WEIGHTS_PATH,
    WEIGHTS_HASHES,
    stack1,
)


# setting FC weights to the final convolutional layer
def set_conv_weights(model, feature_extractor):
    # get pre-trained ResNet50 FC weights
    dense_layer_weights = feature_extractor.layers[-1].get_weights()
    # reshape the (2048, num_classes) kernel into a (1, 1, 2048, num_classes) 1x1 conv kernel
    weights_list = [
        tf.reshape(
            dense_layer_weights[0], (1, 1, *dense_layer_weights[0].shape),
        ).numpy(),
        dense_layer_weights[1],
    ]
    model.get_layer(name="last_conv").set_weights(weights_list)


def fully_convolutional_resnet50(
    input_shape, num_classes=1000, pretrained_resnet=True, use_bias=True,
):
    # init input layer
    img_input = Input(shape=input_shape)

    # define basic model pipeline
    x = ZeroPadding2D(padding=((3, 3), (3, 3)), name="conv1_pad")(img_input)
    x = Conv2D(64, 7, strides=2, use_bias=use_bias, name="conv1_conv")(x)
    x = BatchNormalization(axis=3, epsilon=1.001e-5, name="conv1_bn")(x)
    x = Activation("relu", name="conv1_relu")(x)
    x = ZeroPadding2D(padding=((1, 1), (1, 1)), name="pool1_pad")(x)
    x = MaxPooling2D(3, strides=2, name="pool1_pool")(x)

    # the sequence of stacked residual blocks
    x = stack1(x, 64, 3, stride1=1, name="conv2")
    x = stack1(x, 128, 4, name="conv3")
    x = stack1(x, 256, 6, name="conv4")
    x = stack1(x, 512, 3, name="conv5")

    # add avg pooling layer after feature extraction layers
    x = AveragePooling2D(pool_size=7)(x)

    # add final convolutional layer
    conv_layer_final = Conv2D(
        filters=num_classes, kernel_size=1, use_bias=use_bias, name="last_conv",
    )(x)

    # configure fully convolutional ResNet50 model
    # (tf.keras.Model instead of the private tensorflow.python.keras.engine.training)
    model = Model(img_input, x)

    # load model weights
    if pretrained_resnet:
        model_name = "resnet50"
        # configure full file name
        file_name = model_name + "_weights_tf_dim_ordering_tf_kernels_notop.h5"
        # original tutorial: download the weights and verify them against the hash
        # file_hash = WEIGHTS_HASHES[model_name]
        # weights_path = tf.keras.utils.get_file(
        #     file_name,
        #     BASE_WEIGHTS_PATH + file_name,
        #     cache_subdir="models",
        #     file_hash=file_hash,
        # )
        # this version: load the weights file from the working directory
        model.load_weights(file_name)

    # form final model
    model = Model(inputs=model.input, outputs=[conv_layer_final])

    if pretrained_resnet:
        # get model with the dense layer for further FC weights extraction
        # (weights="imagenet" still downloads the full include_top weights on
        # first use and caches them)
        resnet50_extractor = ResNet50(
            include_top=True, weights="imagenet", classes=num_classes,
        )
        # set ResNet50 FC-layer weights to final convolutional layer
        set_conv_weights(model=model, feature_extractor=resnet50_extractor)

    return model


if __name__ == "__main__":
    # read ImageNet class ids to a list of labels
    with open("imagenet_classes.txt") as f:
        labels = [line.strip() for line in f.readlines()]

    # read image
    original_image = cv2.imread("camel.jpg")

    # convert image to the RGB format
    image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)

    # pre-process image (cast to float first so the mean subtraction is not done in uint8)
    image = preprocess_input(image.astype(np.float32))

    # add the batch dimension: NHWC tf.tensor of shape (1, H, W, 3)
    image = tf.expand_dims(image, 0)

    # load modified resnet50 model with pre-trained ImageNet weights
    model = fully_convolutional_resnet50(input_shape=(image.shape[-3:]))

    # Perform inference.
    # Instead of a 1×1000 vector, we will get a
    # 1×1000×n×m output (i.e. a probability map
    # of size n × m for each of the 1000 classes,
    # where n and m depend on the size of the image).
    preds = model.predict(image)
    preds = tf.transpose(preds, perm=[0, 3, 1, 2])
    preds = tf.nn.softmax(preds, axis=1)
    print("Response map shape : ", preds.shape)

    # find the class with the maximum score in the n × m output map
    pred = tf.math.reduce_max(preds, axis=1)
    class_idx = tf.math.argmax(preds, axis=1)
    print(class_idx)

    row_max = tf.math.reduce_max(pred, axis=1)
    row_idx = tf.math.argmax(pred, axis=1)
    col_idx = tf.math.argmax(row_max, axis=1)

    predicted_class = int(
        tf.gather_nd(
            class_idx, (0, tf.gather_nd(row_idx, (0, col_idx[0])), col_idx[0]),
        )
    )

    # print top predicted class
    print("Predicted Class : ", labels[predicted_class], predicted_class)

    # find the n × m score map for the predicted class
    score_map = tf.expand_dims(preds[0, predicted_class, :, :], 0).numpy()
    score_map = score_map[0]

    # resize score map to the original image size
    score_map = cv2.resize(
        score_map, (original_image.shape[1], original_image.shape[0]),
    )

    # binarize score map
    _, score_map_for_contours = cv2.threshold(
        score_map, 0.65, 1, type=cv2.THRESH_BINARY,
    )
    score_map_for_contours = score_map_for_contours.astype(np.uint8).copy()

    # find the contour of the binary blob (OpenCV 4.x returns two values)
    contours, _ = cv2.findContours(
        score_map_for_contours, mode=cv2.RETR_EXTERNAL, method=cv2.CHAIN_APPROX_SIMPLE,
    )

    # find bounding box around the object.
    rect = cv2.boundingRect(contours[0])

    # apply score map as a mask to original image
    score_map = score_map - np.min(score_map[:])
    score_map = score_map / np.max(score_map[:])

    score_map = cv2.cvtColor(score_map, cv2.COLOR_GRAY2BGR)
    masked_image = (original_image * score_map).astype(np.uint8)

    # display bounding box
    cv2.rectangle(
        masked_image, rect[:2], (rect[0] + rect[2], rect[1] + rect[3]), (0, 0, 255), 2,
    )

    # display images
    cv2.imshow("Original Image", original_image)
    cv2.imshow("scaled_score_map", score_map)
    cv2.imshow("activations_and_bbox", masked_image)
    cv2.waitKey(0)
```
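The script relies on two facts that can be checked without TensorFlow. First, a Dense kernel of shape (C, K), reshaped to (1, 1, C, K) as in `set_conv_weights`, acts as a 1×1 convolution that applies the same Dense layer at every position of the feature map. Second, the chain of `reduce_max`/`argmax` calls picks the (class, row, column) with the highest softmax score. A NumPy sketch with small made-up sizes (C=8, K=5, a 3×4 map) and random weights standing in for ResNet50's 2048 features and 1000 classes; `conv1x1_nhwc` and `locate_top_class` are hypothetical helpers, and the softmax is taken over the last (channel) axis instead of transposing to NCHW first, which gives the same probabilities.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical small sizes: C_in features, num_classes classes, an n x m feature map
C_in, num_classes, n, m = 8, 5, 3, 4

# Dense layer parameters: kernel (C_in, num_classes) and bias (num_classes,)
kernel = rng.normal(size=(C_in, num_classes)).astype(np.float32)
bias = rng.normal(size=(num_classes,)).astype(np.float32)

# Same reshape as set_conv_weights: (C_in, K) -> (1, 1, C_in, K), a 1x1 HWIO conv kernel
conv_kernel = kernel.reshape(1, 1, C_in, num_classes)


def conv1x1_nhwc(x, w, b):
    """1x1 convolution on an NHWC array with a (1, 1, C_in, K) kernel."""
    return np.einsum("nhwc,ck->nhwk", x, w[0, 0]) + b


def softmax(z, axis=-1):
    """Numerically stable softmax along the given axis."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def locate_top_class(logits_nhwk):
    """Return (class_id, row, col, score) of the highest softmax score in the map."""
    probs = softmax(logits_nhwk, axis=-1)[0]  # (n, m, K)
    row, col, cls = np.unravel_index(np.argmax(probs), probs.shape)
    return int(cls), int(row), int(col), float(probs[row, col, cls])


features = rng.normal(size=(1, n, m, C_in)).astype(np.float32)
logits = conv1x1_nhwc(features, conv_kernel, bias)  # (1, n, m, K)

# At every position the 1x1 conv equals the Dense layer applied to that feature vector
dense_at_0_1 = features[0, 0, 1] @ kernel + bias
print(np.allclose(logits[0, 0, 1], dense_at_0_1, atol=1e-5))  # True
print(locate_top_class(logits))
```

A single `argmax` over the flattened (n, m, K) probabilities finds the same global maximum as the row/column bookkeeping in the script (ties aside).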
Here is utils.py (the main script imports it as utils):
```python
from tensorflow.keras.layers import (
    Activation,
    Add,
    BatchNormalization,
    Conv2D,
)

# https://github.com/tensorflow/tensorflow/blob/2b96f3662bd776e277f86997659e61046b56c315/tensorflow/python/keras/applications/resnet.py#L32
BASE_WEIGHTS_PATH = (
    "https://storage.googleapis.com/tensorflow/keras-applications/resnet/"
)
WEIGHTS_HASHES = {
    "resnet50": "4d473c1dd8becc155b73f8504c6f6626",
}


# https://github.com/tensorflow/tensorflow/blob/2b96f3662bd776e277f86997659e61046b56c315/tensorflow/python/keras/applications/resnet.py#L262
def stack1(x, filters, blocks, stride1=2, name=None):
    """A set of stacked residual blocks.

    Arguments:
        x: input tensor.
        filters: integer, filters of the bottleneck layer in a block.
        blocks: integer, blocks in the stacked blocks.
        stride1: default 2, stride of the first layer in the first block.
        name: string, stack label.

    Returns:
        Output tensor for the stacked blocks.
    """
    x = block1(x, filters, stride=stride1, name=name + "_block1")
    for i in range(2, blocks + 1):
        x = block1(x, filters, conv_shortcut=False, name=name + "_block" + str(i))
    return x


# https://github.com/tensorflow/tensorflow/blob/2b96f3662bd776e277f86997659e61046b56c315/tensorflow/python/keras/applications/resnet.py#L217
def block1(x, filters, kernel_size=3, stride=1, conv_shortcut=True, name=None):
    """A residual block.

    Arguments:
        x: input tensor.
        filters: integer, filters of the bottleneck layer.
        kernel_size: default 3, kernel size of the bottleneck layer.
        stride: default 1, stride of the first layer.
        conv_shortcut: default True, use convolution shortcut if True,
            otherwise identity shortcut.
        name: string, block label.

    Returns:
        Output tensor for the residual block.
    """
    # channels_last format
    bn_axis = 3

    if conv_shortcut:
        shortcut = Conv2D(4 * filters, 1, strides=stride, name=name + "_0_conv")(x)
        shortcut = BatchNormalization(
            axis=bn_axis, epsilon=1.001e-5, name=name + "_0_bn",
        )(shortcut)
    else:
        shortcut = x

    x = Conv2D(filters, 1, strides=stride, name=name + "_1_conv")(x)
    x = BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name=name + "_1_bn")(x)
    x = Activation("relu", name=name + "_1_relu")(x)

    x = Conv2D(filters, kernel_size, padding="SAME", name=name + "_2_conv")(x)
    x = BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name=name + "_2_bn")(x)
    x = Activation("relu", name=name + "_2_relu")(x)

    x = Conv2D(4 * filters, 1, name=name + "_3_conv")(x)
    x = BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name=name + "_3_bn")(x)

    x = Add(name=name + "_add")([shortcut, x])
    x = Activation("relu", name=name + "_out")(x)
    return x
```
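Because every layer in the model is a convolution or a pooling, the size of the n × m response map follows from the strides alone. A pure-Python sketch that traces one spatial axis through the layer settings used in `fully_convolutional_resnet50` and `stack1`/`block1`; the helper names are hypothetical, and 480×640 is an arbitrary example input, not the size of the test images.

```python
def valid_out(size, kernel, stride):
    """Output length along one axis for an unpadded ('valid') conv or pooling."""
    return (size - kernel) // stride + 1


def response_map_side(size):
    """Trace one spatial axis through the fully convolutional ResNet50.

    conv1: pad 3, 7x7 conv, stride 2; pool1: pad 1, 3x3 max pool, stride 2;
    conv2_x keeps the size (stride1=1); conv3_x..conv5_x halve it via the
    stride-2 1x1 conv in each stage's first block (the 3x3 convs use SAME);
    finally AveragePooling2D(pool_size=7), whose stride defaults to 7.
    """
    s = valid_out(size + 6, 7, 2)  # conv1_pad + conv1_conv
    s = valid_out(s + 2, 3, 2)     # pool1_pad + pool1_pool
    for _ in range(3):             # conv3, conv4, conv5 downsample by 2
        s = valid_out(s, 1, 2)
    if s < 7:
        raise ValueError("input too small for the 7x7 average pooling")
    return valid_out(s, 7, 7)


def response_map_shape(height, width):
    """(n, m) size of the class score map for an input of height x width."""
    return response_map_side(height), response_map_side(width)


print(response_map_shape(224, 224))  # (1, 1)
print(response_map_shape(480, 640))  # (2, 2)
```

A 224×224 input reproduces the single prediction of an ordinary classifier; larger images give larger maps, which is what lets the main script localize the object.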
Code and images
All of the code, images, and the model