完整工程代码点击这里。

import cv2
import torch.nn as nn
import torch
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import numpy as np
import random
from collections import Counter
torch.manual_seed(10)#固定每次初始化模型的权重#-----------------加载图像数据------------
img = cv2.imread('olivettifaces.jpg')
h = int(img.shape[0]/20)
w = int(img.shape[1]/20)
IMG = []
label = []
id = 0
for i in range(0,20*h,h):for j in range(0,20*w,w):IMG.append(img[i:i+h,j:j+w,:].reshape(3,h,w)/255)label.append(int(id/10))id += 1# 对训练集进行切割,然后进行训练
X_train,X_val,Y_train,Y_val = train_test_split(IMG,label,test_size=0.2)#-------------生成数据集-----------------x_train = []
y_train = []
x_val = []
y_val = []
for i in range(len(X_train)):for j in range(i+1,len(X_train)):if Y_train[i] == Y_train[j]:x_train.append([X_train[i],X_train[j]])y_train.append(1)else:key = random.randint(1,10)if key>=2:continuex_train.append([X_train[i],X_train[j]])y_train.append(0)for i in range(len(X_val)):for j in range(i+1,len(X_val)):if Y_val[i] == Y_val[j]:x_val.append([X_val[i],X_val[j]])y_val.append(1)else:key = random.randint(1,10)if key>=2:continuex_val.append([X_val[i],X_val[j]])y_val.append(0)x_train = torch.from_numpy(np.array(x_train)).to(torch.float32)
y_train = np.array(y_train)
x_val = torch.from_numpy(np.array(x_val)).to(torch.float32)
y_val = np.array(y_val)print('train',Counter(y_train),'val',Counter(y_val))
#------------------搭建网络框架------------
class Siamese(nn.Module):def __init__(self):super(Siamese, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(in_channels=3,#通道数目,刚输入的图片是彩色的三通道数目out_channels=16,kernel_size=3,stride=1,padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2))self.conv2 = nn.Sequential(nn.Conv2d(in_channels=16,out_channels=32,kernel_size=3,stride=1,padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2))self.l1 = nn.Linear(4928,300)#输出300个节点self.l2 = nn.Linear(300,1)#输出1个节点self.l3 = nn.Sigmoid()def forward(self, x1,x2):out1 = self.conv1(x1)out1 = self.conv2(out1)out1 = out1.view(out1.size(0),-1)out1 = self.l1(out1)out2 = self.conv1(x2)out2 = self.conv2(out2)out2 = out2.view(out2.size(0),-1)out2 = self.l1(out2)out = torch.abs(out1-out2)#计算均值误差out = self.l2(out)out = self.l3(out)return outtraining_step = 500#迭代次数
batch_size = 256#每个批次的大小
learning_rate = 0.005
model = Siamese()optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate)#定义优化器
loss_func = nn.BCELoss() #定义损失函数#开始迭代
for step in range(training_step):print('step=',step)M_train = len(x_train)M_val = len(x_val)with tqdm(np.arange(0,M_train,batch_size), desc='Training...') as tbar:for index in tbar:L = indexR = min(M_train,index+batch_size)#-----------------训练内容------------------train_pre = model(x_train[L:R,0],x_train[L:R,1])     # 喂给 model训练数据 x, 输出预测值train_loss = loss_func(train_pre, torch.from_numpy(y_train[L:R].reshape(R-L,1)).to(torch.float))val_pre = model(x_val[:,0],x_val[:,1])val_loss = loss_func(val_pre, torch.from_numpy(y_val.reshape(M_val,1)).to(torch.float))#----------- -----计算准确率----------------train_acc = np.sum((np.array(train_pre.data)>=0.5)==(y_train[L:R].reshape(R-L,1)>=0.5))/(R-L) val_acc = np.sum((np.array(val_pre.data)>=0.5)==(y_val.reshape(M_val,1)>=0.5))/M_val #---------------打印在进度条上--------------tbar.set_postfix(train_loss=float(train_loss.data),train_acc=train_acc,val_loss=float(val_loss.data),val_acc=val_acc)tbar.update()  # 默认参数n=1,每update一次,进度+n#-----------------反向传播更新---------------optimizer.zero_grad()   # 清空上一步的残余更新参数值train_loss.backward()         # 以训练集的误差进行反向传播, 计算参数更新值optimizer.step()        # 将参数更新值施加到 net 的 parameters 上

训练结果

pytorch实现孪生神经网络对人脸相似度进行识别相关推荐

  1. 神经网络学习小记录52——Pytorch搭建孪生神经网络(Siamese network)比较图片相似性

    神经网络学习小记录52--Pytorch搭建孪生神经网络(Siamese network)比较图片相似性 学习前言 什么是孪生神经网络 代码下载 孪生神经网络的实现思路 一.预测部分 1.主干网络介绍 ...

  2. pytorch搭建孪生网络比较人脸相似性

    参考文献: 神经网络学习小记录52--Pytorch搭建孪生神经网络(Siamese network)比较图片相似性_Bubbliiiing的博客-CSDN博客_神经网络图片相似性 Python - ...

  3. 单样本学习:使用孪生神经网络进行人脸识别

    这篇文章简要介绍单样本学习,以孪生神经网络(Siamese neural network)进行人脸识别的例子,分享了作者从论文 FaceNet 以及 deeplearning.ai 中学到的内容. 图 ...

  4. 孪生神经网络_基于局部和全局孪生网络的鲁棒的人脸跟踪

    论文名称 Siamese local and global networks for robust face tracking 引用:Qi, Yuankai, et al. "Siamese ...

  5. Siamese网络(孪生神经网络)详解

    SiameseFC Siamese网络(孪生神经网络) 本文参考文章: Siamese背景 Siamese网络解决的问题 要解决什么问题? 用了什么方法解决? 应用的场景: Siamese的创新 Si ...

  6. 孪生神经网络--一个简单神奇的结构

    本文转载自机器学习小知识. 01 名字的由来 Siamese和Chinese有点像.Siam是古时候泰国的称呼,中文译作暹罗.Siamese也就是"暹罗"人或"泰国&qu ...

  7. 孪生神经网络_驾驶习惯也能识人?基于时空孪生神经网络的轨迹识别

    ⬆⬆⬆ 点击蓝字 关注我们 AI TIME欢迎每一位AI爱好者的加入! 前言: 给定一组单独的人员(例如行人,出租车司机)的历史轨迹以及由特定人员生成的一组新轨迹,轨迹识别问题旨在验证传入的轨迹是否是 ...

  8. [深度学习概念]·Siamese network 孪生神经网络简介

    Siamese network 孪生神经网络--一个简单神奇的结构 名字的由来 Siamese和Chinese有点像.Siam是古时候泰国的称呼,中文译作暹罗.Siamese也就是"暹罗&q ...

  9. Siamese network 孪生神经网络--一个简单神奇的结构

    转自: 作者:fighting41love 链接:https://www.jianshu.com/p/92d7f6eaacf5 1.名字的由来 Siamese和Chinese有点像.Siam是古时候泰 ...

  10. 基于深度卷积神经网络进行人脸识别的原理是什么?

    原文:https://www.zhihu.com/question/60759296 基于深度卷积神经网络进行人脸识别的原理是什么? 这里的人脸识别包括但不限于:人脸检测,人脸对齐,身份验证识别,和表 ...

最新文章

  1. python编写赛车游戏单机版_使用Keras和DDPG玩赛车游戏(自动驾驶)
  2. 如何使用cmd进入打印机选项_怎样用命令行方式添加打印机端口? (已解决)
  3. 第十届蓝桥杯(含题目文件下载)
  4. springboot加载外部xml_SpringBoot读取外部配置文件的方法
  5. git 和 github 关系?
  6. [bzoj 2555]Substring
  7. JAVA如何判断两个字符串是否相等(亲测第二种方式)
  8. DVWA-暴力破解-对‘g0tmi1k’文章的学习笔记
  9. 计算机网络是互相连接的自治系统,自治系统内ip子网和sdn子网的互联机制imisa-江苏计算机网络.pdf...
  10. 后台返回数据打印是[object object]的,报错:SyntaxError: JSON.parse: expected property name or ‘}‘ at line 1 column
  11. 通过cookie保存并读取用户登录信息实例
  12. matlab将矩阵分解成lu,10行代码实现矩阵的LU分解(matlab)
  13. typedef 指向函数的指针
  14. 认识 URL 及其编码
  15. oracle 体系结构及内存管理 15_存储结构
  16. listview mysql源码_用ListView实现对数据库的内容显示
  17. 量子计算机编程教程,量子信息与量子计算简明教程 PDF扫描版[12MB]
  18. lol进服务器时文件损坏,英雄联盟文件损坏怎么修复2018 | 手游网游页游攻略大全...
  19. 在手机/平板上安装kali系统
  20. linux系统格式化硬盘

热门文章

  1. word怎么显示计算机数字,如何键入word2007圆圈数字1到10及以上?
  2. word三线表标题两条线之间如何出现空白间隔(论文必备)
  3. 三菱fx2n做从站的modbus通讯_三菱PLC的通讯与编程,附实际案例
  4. 遗传算法求解TSP问题及MTATLAB代码
  5. 融云 SDK 5.0.0 功能迭代
  6. 计算机的发展导致了计算思维的诞生,尔雅电子计算机的诞生(上)
  7. 串口硬盘如何应用于并口硬盘计算机,串口并口硬盘连接具体步骤(转)
  8. 深圳雷赛智能自动控制软件使用说明(运动控制卡)
  9. php获取qq头像地址,使用PHP语言通过邮箱获取全球公认的Gravatar头像地址
  10. CSS 实现文字头像(圆角头像文字内容)