Siamese networks are commonly used for few-shot learning and belong to the family of meta-learning methods.

A Siamese network uses a CNN as its feature extractor. Samples of different classes share the same CNN, and a fully connected layer added after the CNN decides whether the two inputs belong to the same class. In other words, it is a binary classification problem.

The Siamese network implemented here takes as input a pair of samples drawn randomly from either the same class or different classes: the pair is labeled 1 if the two samples share a class and 0 otherwise. Note that same-class and different-class pairs should be kept balanced. The code below makes the details concrete.
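As a minimal sketch of this balanced pair sampling (the `make_pairs` helper and its toy `labels` list are hypothetical, not the dataset class used later in this post), alternating positive and negative pairs keeps the two kinds in equal proportion:

```python
import random

def make_pairs(labels, n_pairs, seed=0):
    """Build n_pairs (i, j, target) index pairs, alternating positive
    (target=1) and negative (target=0) so the two kinds stay balanced."""
    rng = random.Random(seed)
    by_class = {}
    for idx, y in enumerate(labels):
        by_class.setdefault(y, []).append(idx)
    classes = list(by_class)
    pairs = []
    for k in range(n_pairs):
        i = rng.randrange(len(labels))
        if k % 2 == 0:  # positive pair: second index drawn from the same class
            j = rng.choice(by_class[labels[i]])
            pairs.append((i, j, 1))
        else:           # negative pair: second index drawn from another class
            other = rng.choice([c for c in classes if c != labels[i]])
            pairs.append((i, rng.choice(by_class[other]), 0))
    return pairs

labels = [0, 0, 1, 1, 2, 2]
pairs = make_pairs(labels, 10)
```

Exactly half the pairs carry target 1, which is the balance property the paragraph above asks for.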

Figure 1: feature extraction

Figure 2: contrastive loss

Same-class pairs have similarity label 1; different-class pairs have similarity label 0.
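Written out (a standard formulation, consistent with the `ContrastiveLoss` code later in this post; $y$ is the pair label, $m$ the margin, $f$ the shared feature extractor):

```latex
L(x_1, x_2, y) = \tfrac{1}{2}\left[\, y\, d^2 + (1 - y)\, \max(m - d,\ 0)^2 \right],
\qquad d = \lVert f(x_1) - f(x_2) \rVert_2
```

A same-class pair is pulled together by its squared distance; a different-class pair is pushed apart only while it is closer than the margin.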

Figure 3: the triplet method (triplet loss)

$\alpha$: the margin, a hyperparameter ($\alpha > 0$) that sets how far apart different classes are expected to be. Writing $d(a, p)$ for the anchor-positive distance and $d(a, n)$ for the anchor-negative distance: if $d(a, p) + \alpha \le d(a, n)$, there is no loss; otherwise the loss is $d(a, p) - d(a, n) + \alpha$.

As a formula:

$$L = \max\big(d(a, p) - d(a, n) + \alpha,\ 0\big)$$

Another approach is to feed triplets into the network: randomly draw an anchor sample, then draw a same-class sample as the positive and a different-class sample as the negative, forming a triplet.

The corresponding loss function is the triplet loss; this approach tends to give better results.

The triplet loss is given here first; the corresponding custom dataset will be added another day.

import torch.nn as nn
import torch.nn.functional as F


class TripletLoss(nn.Module):
    """Triplet loss.
    Takes embeddings of an anchor sample, a positive sample and a negative sample."""

    def __init__(self, margin):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative, size_average=True):
        distance_positive = (anchor - positive).pow(2).sum(1)  # squared distance; add .pow(.5) for the true Euclidean distance
        distance_negative = (anchor - negative).pow(2).sum(1)  # .pow(.5)
        losses = F.relu(distance_positive - distance_negative + self.margin)
        return losses.mean() if size_average else losses.sum()
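A quick numeric check of the same formula outside PyTorch (plain NumPy, with made-up 2-D embeddings) shows the hinge behavior: the loss vanishes once the negative is farther from the anchor than the positive by at least the margin.

```python
import numpy as np

def triplet_loss(anchor, positive, negative, margin):
    """max(d(a,p)^2 - d(a,n)^2 + margin, 0), averaged over the batch,
    mirroring the squared distances used in TripletLoss above."""
    d_pos = ((anchor - positive) ** 2).sum(axis=1)
    d_neg = ((anchor - negative) ** 2).sum(axis=1)
    return np.maximum(d_pos - d_neg + margin, 0.0).mean()

a = np.array([[0.0, 0.0]])
p = np.array([[1.0, 0.0]])       # squared distance to anchor: 1
n_far = np.array([[3.0, 0.0]])   # squared distance 9: beyond the margin, no loss
n_near = np.array([[0.0, 1.0]])  # squared distance 1: violates the margin

loss_far = triplet_loss(a, p, n_far, 1.0)    # 1 - 9 + 1 < 0  -> 0.0
loss_near = triplet_loss(a, p, n_near, 1.0)  # 1 - 1 + 1      -> 1.0
```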

The complete code is given below, split across three files: siamese_dataset, model, and main().

The following code is the custom dataset.

import torch
import torch.nn as nn
import torchvision
import torchvision.datasets as dataset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader,Dataset
import torchvision.utils
import numpy as np
import random
from torch.utils.data.sampler import BatchSampler
from PIL import Image


class SiameseMNIST(Dataset):
    """
    Train: For each sample creates randomly a positive or a negative pair
    Test: Creates fixed pairs for testing
    """

    def __init__(self, mnist_dataset):
        self.mnist_dataset = mnist_dataset
        self.train = self.mnist_dataset.train
        self.transform = self.mnist_dataset.transform
        if self.train:
            self.train_labels = self.mnist_dataset.targets
            self.train_data = self.mnist_dataset.data
            self.labels_set = set(self.train_labels.numpy())
            self.label_to_indices = {label: np.where(self.train_labels.numpy() == label)[0]
                                     for label in self.labels_set}
        else:
            # generate fixed pairs for testing
            self.test_labels = self.mnist_dataset.targets
            self.test_data = self.mnist_dataset.data
            self.labels_set = set(self.test_labels.numpy())
            self.label_to_indices = {label: np.where(self.test_labels.numpy() == label)[0]
                                     for label in self.labels_set}
            random_state = np.random.RandomState(29)
            positive_pairs = [[i,
                               random_state.choice(self.label_to_indices[self.test_labels[i].item()]),
                               1]
                              for i in range(0, len(self.test_data), 2)]
            negative_pairs = [[i,
                               random_state.choice(self.label_to_indices[
                                   np.random.choice(list(self.labels_set - set([self.test_labels[i].item()])))]),
                               0]
                              for i in range(1, len(self.test_data), 2)]
            self.test_pairs = positive_pairs + negative_pairs

    def __getitem__(self, index):
        if self.train:
            target = np.random.randint(0, 2)
            img1, label1 = self.train_data[index], self.train_labels[index].item()
            if target == 1:
                siamese_index = index
                while siamese_index == index:
                    siamese_index = np.random.choice(self.label_to_indices[label1])
            else:
                siamese_label = np.random.choice(list(self.labels_set - set([label1])))
                siamese_index = np.random.choice(self.label_to_indices[siamese_label])
            img2 = self.train_data[siamese_index]
        else:
            img1 = self.test_data[self.test_pairs[index][0]]
            img2 = self.test_data[self.test_pairs[index][1]]
            target = self.test_pairs[index][2]
        img1 = Image.fromarray(img1.numpy(), mode='L')
        img2 = Image.fromarray(img2.numpy(), mode='L')
        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        return (img1, img2), target

    def __len__(self):
        return len(self.mnist_dataset)
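To illustrate the `label_to_indices` lookup built in `__init__` (with a toy label array standing in for the MNIST targets):

```python
import numpy as np

# Toy stand-in for mnist_dataset.targets.numpy()
labels = np.array([3, 1, 3, 0, 1, 3])
labels_set = set(labels)

# Same construction as in SiameseMNIST.__init__:
# maps each class label to the array of sample indices with that label
label_to_indices = {label: np.where(labels == label)[0] for label in labels_set}
```

Sampling a positive partner for index `i` is then a `choice` from `label_to_indices[labels[i]]`, and a negative partner comes from any other key.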

Because MNIST is a simple dataset, the model is also kept simple. The key part is the contrastive loss function.

import torch
import torch.nn as nn
import torch.nn.functional as F


class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        self.cnn1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5),
            nn.PReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, kernel_size=5),
            nn.PReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.fc1 = nn.Sequential(
            nn.Linear(64 * 4 * 4, 256),
            nn.PReLU(),
            nn.Linear(256, 256),
            nn.PReLU(),
            nn.Linear(256, 2)
        )
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward_once(self, x):
        output = self.cnn1(x)
        output = output.view(output.size()[0], -1)
        output = self.fc1(output)
        return output

    def forward(self, input1, input2):
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)
        return output1, output2


class ContrastiveLoss(nn.Module):
    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        self.eps = 1e-9

    def forward(self, output1, output2, target, size_average=True):
        distances = (output1 - output2).pow(2).sum(1)
        loss = 0.5 * (target.float() * distances +
                      (1 - target).float() * F.relu(self.margin - (distances + self.eps).sqrt()).pow(2))
        return loss.mean() if size_average else loss.sum()
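The contrastive loss can be sanity-checked numerically outside PyTorch (plain NumPy, made-up 2-D embeddings): same-class pairs are penalized by their squared distance, and different-class pairs only while they sit inside the margin.

```python
import numpy as np

def contrastive_loss(out1, out2, target, margin=2.0, eps=1e-9):
    """0.5 * [ y * d^2 + (1 - y) * max(margin - d, 0)^2 ], mean over the batch,
    mirroring ContrastiveLoss above (d^2 = squared Euclidean distance)."""
    d2 = ((out1 - out2) ** 2).sum(axis=1)
    d = np.sqrt(d2 + eps)
    loss = 0.5 * (target * d2 + (1 - target) * np.maximum(margin - d, 0.0) ** 2)
    return loss.mean()

o1 = np.array([[0.0, 0.0], [0.0, 0.0]])
o2 = np.array([[0.0, 0.0], [3.0, 0.0]])
y_same = np.array([1.0, 1.0])  # same-class pairs: mean(0.5 * d^2) = 2.25
y_diff = np.array([0.0, 0.0])  # different-class: second pair is past the margin
```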

The training and validation main program; just change the file paths as needed.

import sys
import os
import torch
import torch.nn as nn
import torchvision
from torch.optim.lr_scheduler import ExponentialLR, MultiStepLR
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import torchvision.datasets as dataset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
import random
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F
import nibabel as nib
import argparse
from tqdm import tqdm
import visdom
from Siamese_minist import SiameseMNIST
from siamese_model import SiameseNetwork, ContrastiveLoss

parser = argparse.ArgumentParser()
parser.add_argument('--train_dir', type=str, default='./data')
parser.add_argument('--test_dir', type=str, default=None)
parser.add_argument('--batch_size', type=int, default=128, help='batch size')
parser.add_argument('--epochs', type=int, default=20, help='number epoch to training')
parser.add_argument('--lr', type=float, default=1e-3, help='initial learning rate')
parser.add_argument('--nw', type=int, default=16, help='DataLoader num_workers')
parser.add_argument('--save_path', type=str, default='/home/yang/cnn3d/mutipule_calssification/weight_path/Siamese_model.pth', help='model weight save path')
parser.add_argument('--train_path', type=str, default='/home/yang/cnn3d/mutipule_calssification/data_selected/category_10/train_data', help='training data path')
parser.add_argument('--test_path', type=str, default='/home/yang/cnn3d/mutipule_calssification/data_selected/category_10/test_data', help='test data path')
parser.add_argument('--margin', type=float, default=1.0, help='contrastive loss margin ')
parser.add_argument('--gamma', type=float, default=0.95, help='optimizer scheduler gamma')

torch.manual_seed(1)
opt = parser.parse_args()
print(opt)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# visdom visualization: first run `python3 -m visdom.server` in a terminal
viz = visdom.Visdom()
train_dataset_path = opt.train_path
test_dataset_path = opt.test_path
mean, std = 0.1307, 0.3081
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((mean,), (std,))])
minist_path = "/home/yang/cnn3d/mutipule_calssification/SiameseNetwork/mnist"
minist_train = dataset.MNIST(minist_path, train=True, transform=transform, download=False)
minist_test = dataset.MNIST(minist_path, train=False, transform=transform, download=False)

train_dataset = SiameseMNIST(minist_train)
test_dataset = SiameseMNIST(minist_test)

# plain (non-pair) loaders, used only for the embedding visualization
train_loader = DataLoader(minist_train, batch_size=64)
test_loader = DataLoader(minist_test, batch_size=64)

train_dataloader = DataLoader(train_dataset, shuffle=True, num_workers=opt.nw, batch_size=opt.batch_size)
test_dataloader = DataLoader(test_dataset, shuffle=True, batch_size=opt.batch_size, num_workers=opt.nw)

net = SiameseNetwork().to(device)
criterion = ContrastiveLoss(margin=opt.margin)
optimizer = optim.Adam(net.parameters(), lr=opt.lr)
scheduler = ExponentialLR(optimizer, gamma=opt.gamma)
scheduler1 = MultiStepLR(optimizer, [10, 20], gamma=0.1)
def show_plot(iteration, loss):
    plt.plot(iteration, loss)
    plt.show()


mnist_classes = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728',
          '#9467bd', '#8c564b', '#e377c2', '#7f7f7f',
          '#bcbd22', '#17becf']


def plot_embeddings(embeddings, targets, xlim=None, ylim=None):
    plt.figure(figsize=(10, 10))
    for i in range(10):
        inds = np.where(targets == i)[0]
        plt.scatter(embeddings[inds, 0], embeddings[inds, 1], alpha=0.5, color=colors[i])
    if xlim:
        plt.xlim(xlim[0], xlim[1])
    if ylim:
        plt.ylim(ylim[0], ylim[1])
    plt.legend(mnist_classes)


def extract_embeddings(dataloader, model, cuda=True):
    with torch.no_grad():
        model.eval()
        embeddings = np.zeros((len(dataloader.dataset), 2))
        labels = np.zeros(len(dataloader.dataset))
        k = 0
        for images, target in dataloader:
            if cuda:
                images = images.to(device)
            embeddings[k:k + len(images)] = model.forward_once(images).data.cpu().numpy()
            labels[k:k + len(images)] = target.numpy()
            k += len(images)
    return embeddings, labels


viz.line([0.], [0.], win='train_loss', opts=dict(title='training Loss'))
viz.line([0.], [0.], win='val_loss', opts=dict(title='valuation Loss'))


def main():
    net.train()
    counter = []
    loss_history = []
    iteration_number = 0
    global_step = 0.0
    val_step = 0.0
    for epoch in range(opt.epochs):
        train_loss = 0.0
        train_bar = tqdm(train_dataloader, file=sys.stdout)
        for index, data in enumerate(train_bar):
            (image0, image1), label = data
            image0, image1, label = image0.to(device), image1.to(device), label.to(device)
            optimizer.zero_grad()
            output1, output2 = net(image0, image1)
            loss_contrastive = criterion(output1, output2, label)
            loss_contrastive.backward()
            train_loss += loss_contrastive.item()
            global_step += 1
            optimizer.step()
            viz.line([loss_contrastive.item()], [global_step], win='train_loss',
                     opts=dict(title='training Loss'), update='append')
            if index % 10 == 0:
                iteration_number += 10
                counter.append(iteration_number)
                loss_history.append(loss_contrastive.item())
        print("Epoch number {} Current loss {}".format(epoch + 1, train_loss / len(train_dataloader)))
        print("learning rate in epoch %d: %f" % (epoch + 1, optimizer.param_groups[0]['lr']))
        scheduler1.step()
        if epoch % 5 == 0:
            net.eval()
            with torch.no_grad():
                loss = 0.0
                val_bar = tqdm(test_dataloader, file=sys.stdout)
                for index, data in enumerate(val_bar):
                    val_step += 1
                    (val_image0, val_image1), val_label = data
                    val_image0, val_image1, val_label = val_image0.to(device), val_image1.to(device), val_label.to(device)
                    output1, output2 = net(val_image0, val_image1)
                    loss_contrastive = criterion(output1, output2, val_label)
                    loss += loss_contrastive.item()
                    viz.line([loss_contrastive.item()], [val_step], win='val_loss',
                             opts=dict(title='valuation loss'), update='append')
                print('epoch %d| valuation Loss:%.4f' % (epoch, loss / len(test_dataloader)))
    # torch.save(net.state_dict(), opt.save_path)
    show_plot(counter, loss_history)


def valuation():
    net.eval()
    dataiter = iter(test_dataloader)
    with torch.no_grad():
        num = 0.0
        x0, _, label1 = next(dataiter)
        min_distance = 10
        predic_label = None
        for i in range(len(test_dataset) - 1):
            _, x1, label2 = next(dataiter)
            output1, output2 = net(Variable(x0).cuda(), Variable(x1).cuda())
            euclidean_distance = F.pairwise_distance(output1, output2)
            if euclidean_distance < min_distance:
                min_distance = euclidean_distance
                predic_label = label2
        if predic_label == label1:
            num += 1
        print('min distance: ', min_distance)
        print('predicted label', predic_label)


if __name__ == '__main__':
    main()
    # visualize the clustering of the learned embeddings
    train_embeddings, train_labels = extract_embeddings(train_loader, net)
    # figure 1: training data
    plot_embeddings(train_embeddings, train_labels)
    val_embeddings, val_labels = extract_embeddings(test_loader, net)
    # figure 2: test data
    plot_embeddings(val_embeddings, val_labels)
    plt.show()

Results:

Training loss curve:

Embedding visualization on the training set:
