TensorFlow搭建VGG-Siamese网络


  • Siamese原理

Siamese网络,中文称为孪生网络。大致结构如下图所示:

Siamese网络有两个输入,一个输出。其中,两个输入经过相同的网络层知道成为一个n维向量,再对这个n维向量进行求距离,对此距离应用softmax函数,得到输出的结果。

例如,使用Siamese做一个人脸识别,那么输入就是两个人脸图像,若是同一个人输出1,若是不同的人则输出0。

首先,我们制作一个输入为(h, w, c),输出为(1, 128)的VGG模型,这里不使用完整的模型,我称为VGG-lite版。

import tensorflow as tf
from tensorflow import keras
from keras import backend as K
from tensorflow.keras import layers, Sequential
from tensorflow.keras.layers import Conv2D, ZeroPadding2D, Activation, MaxPooling2D, Dropout, Flatten, Dense, Lambda, Input
from tensorflow.keras.models import Model# 这里实现一个VGG网络,返回的是一个128维向量,用于siamese的输入
def VGG(X_input):X = X_inputX = Conv2D(64, (3,3), padding = 'same',activation='relu')(X)X = Conv2D(64, (3,3), padding = 'same',activation='relu')(X)X = MaxPooling2D(pool_size=(2,2), strides=2)(X)X = Conv2D(128, (3,3), padding = 'same',activation='relu')(X)X = Conv2D(128, (3,3), padding = 'same',activation='relu')(X)X = Dropout(0.4)(X)X = MaxPooling2D(pool_size=(2,2), strides=2)(X)X = Conv2D(256, (3,3), padding = 'same',activation='relu')(X)# X = MaxPooling2D(pool_size=(2,2), strides=2)(X)X = Conv2D(256, (3,3), padding = 'same',activation='relu')(X)X = Conv2D(256, (3,3), padding = 'same',activation='relu')(X)X = Dropout(0.4)(X)X = MaxPooling2D(pool_size=(2,2), strides=2)(X)X = Conv2D(512, (3,3), padding = 'same',activation='relu')(X)X = MaxPooling2D(pool_size=(2,2), strides=2)(X)X = Conv2D(512, (3,3), padding = 'same',activation='relu')(X)X = Conv2D(512, (3,3), padding = 'same',activation='relu')(X)X = Dropout(0.4)(X)X = MaxPooling2D(pool_size=(2,2), strides=2)(X)X = Flatten()(X)X = Dense(1024, activation="relu")(X)X = Dense(128, activation="relu")(X)X = Lambda(lambda  x: K.l2_normalize(x,axis=1))(X)return X

这里对模型不再详细解释,只解释下对X的最后一步操作:通过keras.layers中的Lambda将128维的X进行L2正则化再输出。

对于模型构建其他部分的疑问,可以参考我的前两份文章。

接下来,我们要制作一个可以接受两个输入的模型。在TensorFlow中,只需要在定义模型的函数中,使用多次Input()即可获得多个输入。

def VGG_Siamese(input_shape):# 接收两个输入,X1_input和X2_input.X1_input = Input(input_shape)X2_input = Input(input_shape)X1 = ZeroPadding2D((3, 3), name='layer1')(X1_input)X2 = ZeroPadding2D((3, 3), name='layer2')(X2_input)X1 = VGG(X1)X2 = VGG(X2)print(X1)print(X2)l1_distance_layer = Lambda(lambda tensors: K.abs(tensors[0] - tensors[1]))l1_distance = l1_distance_layer([X1, X2])    X = Dense(512, activation='relu')(l1_distance)X = Dense(2, activation='softmax')(X)model = Model(inputs = [X1_input, X2_input], outputs = X)return model

在使用Input()获得两个输入后,将两个输入一同经过了VGG()函数,这说明两个输入会经历相同的卷积网络成为两个128维向量。而

 l1_distance_layer = Lambda(lambda tensors: K.abs(tensors[0] - tensors[1]))l1_distance = l1_distance_layer([X1, X2])

这两句是将得到的两个128维向量进行距离求和,使用差值绝对值求得,得到的结果也是一个128维向量。

再之后,将得到的128维向量经过全连接层与512维、2维(即classes维)连接,得到一个二维向量,这个二维向量使用"softmax"激活函数,得到预测结果。


通过上面的两个函数,我们已经完成了模型的构建,接下来,我们从处理数据集开始,讲解如何对此模型进行训练。

笔者选用的数据集是LFW数据集,各位可以自行选择数据集,下面介绍一种简单的数据集处理方法(LFW数据集有pairs.txt文件,处理方式与下面介绍的不一致,这并不影响,因为得到的数据集形式是相同的):

  • 因为不同数据集可能有不同的初步获取方式,因此这里假设我们获得了dataset_x(图像)、dataset_y(标签).

对于数据处理的思想是:首先取数据集中的任一图片,然后再随机取另一张图片(不要与第一张图片相同),将第一张图片加入X_L(这是一个list),将第二张图片加入X_R,如果两张图片是同一个人,将1加入labels(这是标签集),如果两张图片不是同一个人,将0加入labels。具体操作如下:

X_L = []
X_R = []
labels = []
for i in range(dataset_x.shape[0]):for j in range(4):  # 每个数据与四个其他数据对比a = random.randint(0,dataset_x.shape[0]-1)while a == i:a = random.randint(0,dataset_x.shape[0]-1)X_L.append(dataset_x[i])X_R.append(dataset_x[a])if dataset_y[i] == dataset_y[a]:labels.append(1)else:labels.append(0)

这样,我们得到了一个具有两个图片并且已经标志其是否为同一人的数据集。但是我们对于数据集的处理还没有完成,如果使用以上的数据集去进行训练,会有多个错误产生。

  • TensorFlow的模型训练应接收带有shape方法的数据集,而我们上面的数据集是list类型,不具有shape方法,要使其得到此方法,可按如下处理:
import numpy as np
X_L = np.array(X_L)
X_R = np.array(X_R)
labels = np.array(labels)

numpy.array()方法将list转化为array,具有shape方法。到这里,数据处理仍没有结束。还记得我们模型最后的输出吗?应该是(?, 2)维的向量,而我们的labels是(?, 1)维向量,这是怎么回事?
这里我们的labels向量使用0和1代表两种结果,因此对于每对图片都只有一个标签。要处理这个问题,有两种解决方案。

  • 第一种解决方案是,将模型最后的输出激活函数换为’sigmoid’并改为1维。这样便与标签集维数相同。
  • 第二种解决方案是,将标签转为2维,并且要与softmax输出匹配,即转化为独热编码。(0->(1,0), 1->(0,1)).

这里我们采用第二种解决方案

labels = to_categorical(labels, num_classes=2)

现在,我们可以获取我们的模型了:

model = VGG_Siamese(input_shape=x_train[0].shape)

设置模型参数:

model.compile(optimizer='adam', loss="categorical_crossentropy", metrics=['accuracy'])
# 如果刚才采用第一种解决方案,将loss改为'binary_crossentropy'

参数设置完毕后,可以开始训练模型了:

model.fit([X_L, X_R], labels, validation_split=0.2, batch_size=32, epochs=30, verbose=1)

这里只为了演示如何构建Siamese模型,因此选用的模型较简单,训练效果并不优秀,但是便于理解Siamese的工作原理和创建方式,为了优化训练效果,可以自己动手尝试更换模型进行训练。

  • 训练完成后,可以将模型保存:
save_path = "./weights/my_weight" # 填文件地址和名称
model.save_weights(save_path) # 保存权重
model.save(save_path+'h5') # 保存模型和权重

TensorFlow搭建VGG-Siamese网络相关推荐

  1. 利用TensorFlow搭建CNN,DNN网络实现图像手写识别,总结。

    利用TensorFlow搭建CNN,DNN网络实现图像手写识别,总结. 摘要 一.神经网络与卷积网络的对比 1.数据处理 2.对获取到的数据进行归一化和独热编码 二.开始我们的tensorflow神经 ...

  2. TFboys:使用Tensorflow搭建深层网络分类器

    前言 根据官方文档整理而来的,主要是对Iris数据集进行分类.使用tf.contrib.learn.tf.contrib.learn快速搭建一个深层网络分类器, 步骤 导入csv数据 搭建网络分类器 ...

  3. #教计算机学画卡通人物#生成式对抗神经网络GAN原理、Tensorflow搭建网络生成卡通人脸

    生成式对抗神经网络GAN原理.Tensorflow搭建网络生成卡通人脸 下面这张图是我教计算机学画画,计算机学会之后画出来的,具体实现在下面. ▲以下是对GAN形象化地表述 ●赵某不务正业.游手好闲, ...

  4. pytorch 搭建 VGG 网络

    目录 1. VGG 网络介绍 2. 搭建VGG 网络 3. code 1. VGG 网络介绍 VGG16 的网络结构如图: VGG 网络是由卷积层和池化层构成基础的CNN 它的CONV卷积层的参数全部 ...

  5. 用TensorFlow搭建一个万能的神经网络框架(持续更新)

    博客作者:凌逆战 博客地址:https://www.cnblogs.com/LXP-Never/p/12774058.html 文章代码:https://github.com/LXP-Never/bl ...

  6. siamese网络_CVPR 2019手写签名认证的逆鉴别网络

    点击我爱计算机视觉标星,更快获取CVML新技术 本文简要介绍CVPR2019论文"Inverse Discriminative Networks for Handwritten Signat ...

  7. 如何用 TensorFlow 实现生成式对抗网络(GAN)

    我们来研究一下生成式对抗网络 GAN,并且用 TensorFlow 代码实现. 自从 Ian Goodfellow 在 14 年发表了 论文 Generative Adversarial Nets 以 ...

  8. 不到 200 行代码,教你如何用 Keras 搭建生成对抗网络(GAN)

     不到 200 行代码,教你如何用 Keras 搭建生成对抗网络(GAN) 生成对抗网络(Generative Adversarial Networks,GAN)最早由 Ian Goodfello ...

  9. 深度学习Caffe实战笔记(6)Windows caffe平台用Siamese网络跑自己的数据

    终于到了介绍如何使用Siamese网络跑自己的数据了,在网上.论坛上.群里关于用Siamese网络的资料很多,但是实战的资料很少,难道是因为太容易了吗?反正博主查阅了各种地方,几乎没有找到Siames ...

最新文章

  1. 模拟人脑项目彻底宣告失败:耗资10亿欧,10年前轰动全球,如今死得悄无声息...
  2. wamp编写html页面空白,Win10安装wampserver后打开localhost出现空白怎么办
  3. Android应用框架浅析
  4. 手机端页面要加...
  5. nbu备份oracle rac,利用NETBACKUP将备份写到磁盘上
  6. [Bootstrap]7天深入Bootstrap(1)入门准备
  7. 5G UPF与MEC
  8. Python连接MySQL数据库
  9. VBA变量和函数名中文拼音输入提示
  10. 从亏损19亿到盈利6亿,恺英网络做对了什么?
  11. Jupyter Notebook与Pycharm代码连接Docker容器中的远程服务器运行
  12. Android旅游自助项目之订票系统订票功能实现
  13. PostgreSQL数据库----pgAdmin客户端工具的使用
  14. 彻底破解加密PDF文件
  15. 【3】Kubernetes资源管理
  16. matlab ax=b x=,matlab 求解 Ax=B 时所用算法
  17. 什么是服务器集群?集群服务器有什么好处
  18. 【08月01日】A股滚动市净率PB历史新低排名
  19. 使用Ajax实现百度下拉框
  20. 【JS】JavaScript

热门文章

  1. 自制文件管理工具介绍
  2. springboot实现小程序打赏支付功能
  3. 手把手教你用AirtestIDE无线连接手机
  4. 通过例子学TLA+(六)--标准模块
  5. Python面向对象(约束,异常处理,md5加密)(五)
  6. js登录特效+ajax提交表单+异步刷新验证
  7. 1、倒车轨迹计算方法和实施!(开启+更改!)
  8. 修改图片尺寸并且缩放和裁剪
  9. 开源国产化zabbix监控平台
  10. ubuntu anaconda 清理文件 相见恨晚