基于TPS(Thin Plate Spines)的STN网络是OCR领域CVPR论文《Robust Scene Text Recognition with Automatic Rectification》中提出的RARE网络的一部分,RARE网络的基本结构为空间变换网络(STN)+序列识别网络(SRN)。

import torch
import torch.nn as nn
import torch.nn.functional as function
import numpy as npclass LocalizationNetwork(nn.Module):"""空间变换网络1.读入输入图片,并利用其卷积网络提取特征2.使用特征计算基准点,基准点的个数由参数fiducial指定,参数channel指定输入图像的通道数3.计算基准点的方法是使用两个全连接层将卷积网络输出的特征进行降维,从而得到基准点集合"""def __init__(self, fiducial, channel):"""初始化方法:param fiducial: 基准点的数量:param channel: 输入图像通道数"""super(LocalizationNetwork, self).__init__()self.fiducial = fiducial # 指定基准点个数self.channel = channel   # 指定输入图像的通道数# 提取特征使用的卷积网络self.ConvNet = nn.Sequential(nn.Conv2d(self.channel, 64, 3, 1, padding=1, bias=False),nn.BatchNorm2d(64), nn.ReLU(True),  # [N, 64, H, W]nn.MaxPool2d(2, 2),  # [N, 64, H/2, W/2]nn.Conv2d(64, 128, 3, 1, padding=1, bias=False),nn.BatchNorm2d(128), nn.ReLU(True),  # [N, 128, H/2, W/2]nn.MaxPool2d(2, 2),  # [N, 128, H/4, W/4]nn.Conv2d(128, 256, 3, 1, padding=1, bias=False),nn.BatchNorm2d(256), nn.ReLU(True),  # [N, 256, H/4, W/4]nn.MaxPool2d(2, 2),  # [N, 256, H/8, W/8]nn.Conv2d(256, 512, 3, 1, padding=1, bias=False),nn.BatchNorm2d(512), nn.ReLU(True),  # [N, 512, H/8, W/8]nn.AdaptiveAvgPool2d(1))  # [N, 512, 1, 1]# 计算基准点使用的两个全连接层self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True))self.localization_fc2 = nn.Linear(256, self.fiducial * 2)# 将全连接层2的参数初始化为0self.localization_fc2.weight.data.fill_(0)"""全连接层2的偏移量bias需要进行初始化,以便符合RARE Paper中所介绍的三种初始化形式,三种初始化方式详见https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Shi_Robust_Scene_Text_CVPR_2016_paper.pdf,Fig. 6 (a)下初始化方法为三种当中的第一种"""ctrl_pts_x = np.linspace(-1.0, 1.0, fiducial // 2)ctrl_pts_y_top = np.linspace(0.0, -1.0, fiducial // 2)ctrl_pts_y_bottom = np.linspace(1.0, 0.0, fiducial // 2)ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)# 修改全连接层2的偏移量self.localization_fc2.bias.data = torch.from_numpy(initial_bias).float().view(-1)def forward(self, x):"""前向传播方法:param x: 输入图像,规模[batch_size, C, H, W]:return: 输出基准点集合C,用于图像校正,规模[batch_size, fiducial, 2]"""# 获取batch_sizebatch_size = x.size(0)# 提取特征features = self.ConvNet(x).view(batch_size, -1)# 使用特征计算基准点集合Cfeatures = self.localization_fc1(features)C = self.localization_fc2(features).view(batch_size, self.fiducial, 2)return Cclass GridGenerator(nn.Module):"""网格生成网络Grid Generator of RARE, which produces P_prime by multipling T with P."""def __init__(self, fiducial, output_size):"""初始化方法:param fiducial: 基准点与基本基准点的个数:param output_size: 校正后图像的规模基本基准点是被校正后的图片的基准点集合"""super(GridGenerator, self).__init__()self.eps = 1e-6# 基准点与基本基准点的个数self.fiducial = fiducial# 校正后图像的规模self.output_size = output_size # 假设为[w, h]# 论文公式当中的C',C'是基本基准点,也就是被校正后的图片的基准点集合self.C_primer = self._build_C_primer(self.fiducial)# 论文公式当中的P',P'是校正后的图片的像素坐标集合,规模为[h·w, 2],集合中有n个元素,每个元素对应校正图片的一个像素的坐标self.P_primer = self._build_P_primer(self.output_size)# 如果使用多GPU,则需要寄存器缓存register bufferself.register_buffer("inv_delta_C_primer",torch.tensor(self._build_inv_delta_C_primer(self.fiducial, self.C_primer)).float())self.register_buffer("P_primer_hat",torch.tensor(self._build_P_primer_hat(self.fiducial, self.C_primer, self.P_primer)).float())def _build_C_primer(self, fiducial):"""构建基本基准点集合C',即被校正后的图片的基准点,应该是一个矩形的fiducial个点集合:param fiducial: 基本基准点的个数,跟基准点个数相同该方法生成C'的方法与前面的空间变换网络相同"""ctrl_pts_x = np.linspace(-1.0, 1.0, fiducial // 2)ctrl_pts_y_top = -1 * np.ones(fiducial // 2)ctrl_pts_y_bottom = np.ones(fiducial // 2)ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)C_primer = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)return C_primerdef _build_P_primer(self, output_size):"""构建校正图像像素坐标集合P',构建的方法为按照像素靠近中心的程度形成等差数列作为像素横纵坐标值:param output_size: 模型输出的规模:return : 校正图像的像素坐标集合"""w, h = output_size# 等差数列output_grid_xoutput_grid_x = (np.arange(-w, w, 2) + 1.0) / w# 等差数列output_grid_youtput_grid_y = (np.arange(-h, h, 2) + 1.0) / h"""使用np.meshgrid将output_grid_x中每个元素与output_grid_y中每个元素组合形成一个坐标注意,若output_grid_x的规模为[w], output_grid_y为[h],则生成的元素矩阵规模为[h, w, 2]"""P_primer = np.stack(np.meshgrid(output_grid_x, output_grid_y), axis=2)# 在返回时将P'进行降维,将P'从[h, w, 2]降为[h·w, 2]return P_primer.reshape([-1, 2])  # [HW, 2]def _build_inv_delta_C_primer(self, fiducial, C_primer):"""计算deltaC'的逆,该矩阵为常量矩阵,在确定了fiducial与C'之后deltaC'也同时被确定:param fiducial: 基准点与基本基准点的个数:param C_primer: 基本基准点集合C':return: deltaC'的逆"""# 计算C'梯度公式中的R,R中的元素rij等于dij的平方再乘dij的平方的自然对数,dij是C'中第i个元素与C'中第j个元素的欧式距离,R矩阵是个对称矩阵R = np.zeros((fiducial, fiducial), dtype=float)# 对称矩阵可以简化for循环for i in range(0, fiducial):for j in range(i, fiducial):R[i, j] = R[j, i] = np.linalg.norm(C_primer[i] - C_primer[j])np.fill_diagonal(R, 1)  # 填充对称矩阵对角线元素,都为1R = (R ** 2) * np.log(R ** 2)  # 或者R = 2 * (R ** 2) * np.log(R)# 使用不同矩阵进行拼接,组成deltaC'delta_C_primer = np.concatenate([np.concatenate([np.ones((fiducial, 1)), C_primer, R], axis=1),       # 规模[fiducial, 1+2+fiducial],deltaC'计算公式的第一行np.concatenate([np.zeros((1, 3)), np.ones((1, fiducial))], axis=1),  # 规模[1, 3+fiducial],deltaC'计算公式的第二行np.concatenate([np.zeros((2, 3)), np.transpose(C_primer)], axis=1)   # 规模[2, 3+fiducial],deltaC'计算公式的第三行], axis=0)                                                               # 规模[fiducial+3, fiducial+3]# 调用np.linalg.inv求deltaC'的逆inv_delta_C_primer = np.linalg.inv(delta_C_primer)return inv_delta_C_primerdef _build_P_primer_hat(self, fiducial, C_primer, P_primer):"""求^P',即论文公式当中由校正后图片像素坐标经过变换矩阵T后反推得到的原图像素坐标P集合公式当中的P'帽,P = T^P':param fiducial: 基准点与基本基准点的个数:param C_primer: 基本基准点集合C',规模[fiducial, 2]:param P_primer: 校正图像的像素坐标集合,规模[h·w, 2]:return: ^P',规模[h·w, fiducial+3]"""n = P_primer.shape[0]  # P_primer的规模为[h·w, 2],即n=h·w# PAPER: d_{i,k} is the euclidean distance between p'_i and c'_kP_primer_tile = np.tile(np.expand_dims(P_primer, axis=1), (1, fiducial, 1))  # 规模变化 [h·w, 2] -> [h·w, 1, 2] -> [h·w, fiducial, 2]C_primer = np.expand_dims(C_primer, axis=0)                                  # 规模变化 [fiducial, 2] -> [1, fiducial, 2]# 此处相减是对于P_primer_tile的每一行都减去C_primer,因为这两个矩阵规模不一样dist = P_primer_tile - C_primer# 计算求^P'公式中的dik,dik为P'中第i个点与C'中第k个点的欧氏距离r_norm = np.linalg.norm(dist, ord=2, axis=2, keepdims=False)                 # 规模 [h·w, fiducial]# r'ik = d^2ik·lnd^2ikr = 2 * np.multiply(np.square(r_norm), np.log(r_norm + self.eps))# ^P'i = [1, x'i, y'i, r'i1,......, r'ik]的转置,k=fiducialP_primer_hat = np.concatenate([np.ones((n, 1)), P_primer, r], axis=1)        # 规模 经过垂直拼接[h·w, 1],[h·w, 2],[h·w, fiducial]形成[h·w, fiducial+3]return P_primer_hatdef _build_batch_P(self, batch_C):"""求本batch每一张图片的原图像素坐标集合P:param batch_C: 本batch原图的基准点集合C:return: 本batch的原图像素坐标集合P,规模[batch_size, h, w, 2]"""# 获取batch_sizebatch_size = batch_C.size(0)# 将本batch的基准点集合进行扩展,使其规模从[batch_size, fiducial, x] -> [batch_size, fiducial+3, 2]batch_C_padding = torch.cat((batch_C, torch.zeros(batch_size, 3, 2).float()), dim=1)# 按照论文求解T的公式求T,规模变化[fiducial+3, fiducial+3] × [batch_size, fiducial+3, 2] -> [batch_size, fiducial+3, 2]batch_T = torch.matmul(self.inv_delta_C_primer, batch_C_padding)# 按照论文公式求原图像素坐标的公式求解本batch的原图像素坐标集合P,P = T^P'# [h·w, fiducial+3] × [batch_size, fiducial+3, 2] -> [batch_size, h·w, 2]batch_P = torch.matmul(self.P_primer_hat, batch_T)# 将P从[batch_size, h·w, 2]转换到[batch_size, h, w, 2]return batch_P.reshape([batch_size, self.output_size[1], self.output_size[0], 2])def forward(self, batch_C):return self._build_batch_P(batch_C)class TPSSpatialTransformerNetwork(nn.Module):"""Rectification Network of RARE, namely TPS based STN"""def __init__(self, fiducial, input_size, output_size, channel):"""Based on RARE TPS:param fiducial: number of fiducial points:param input_size: (w, h) of the input image:param output_size: (w, h) of the rectified image:param channel: input image channel"""super(TPSSpatialTransformerNetwork, self).__init__()self.fiducial = fiducialself.input_size = input_sizeself.output_size = output_sizeself.channel = channelself.LNet = LocalizationNetwork(self.fiducial, self.channel)self.GNet = GridGenerator(self.fiducial, self.output_size)def forward(self, x):""":param x: batch input image [batch_size, c, w, h]:return: rectified image [batch_size, c, h, w]"""# 求原图的基准点集合CC = self.LNet(x)  # [batch_size, fiducial, 2]# 求原图对应校正图像素的像素坐标集合PP = self.GNet(C) # [batch_size, h, w, 2]# 按照P对x进行采样,对于越界的位置在网格中采用边界的pixel value进行填充rectified = function.grid_sample(x, P, padding_mode='border', align_corners=True)  #规模[batch_size, c, h, w]print(np.shape(rectified))return rectifiedif __name__ == '__main__':tps = TPSSpatialTransformerNetwork(6, (128, 64), (128, 64), 3)# input size: [batch_size, channel_num, w, h]input = torch.randn((1, 3, 128, 64))tps(input)a = [1, 2, 3]b = [4, 5]P_primer = np.stack(np.meshgrid(a, b), axis=2)print(P_primer)

网络复现之基于TPS的STN网络相关推荐

  1. 基于php的网络教学平台,基于PHP技术的网络教学平台的设计与实现

    崔静静+++项小书+++吴燕红 摘要:该文基于简易.灵活的PHP语言及Sql Server数据库技术,设计并实现网络教学平台.该平台重点实现了在线测试.在线答疑等功能,为课堂教学提供了有益补充,增强了 ...

  2. unet是残差网络吗_基于深度监督残差网络的肝脏及肝肿瘤分割

    摘要: 针对医生手动对肝脏肿瘤CT图像分割耗时,耗力,且易受主观判断影响的问题,该研究提出一种深度监督残差网络(Deeply Supervised Residual Unet,DS-ResUnet)算 ...

  3. 计算机应用 网络管理开发,基于XML的iBAC网络管理系统的研究与开发-计算机应用技术专业论文.docx...

    ⅢY ⅢY iii■l 洲8 mmj■I ㈣0 Ⅲ4 6 ㈣2 学位论文数据集 中图分类号TP311.1学科分类号520.3040 论文编号10010200705 12密级 学位授予单位代码10010 ...

  4. 【匿名网络综述】匿名分布式网络之匿名网络综述

    文章目录 1.匿名网络的一些概念 1.1 匿名网络的目标 1.2 匿名网络的匿名性划分 1.2.1 发送者(接收者)匿名性 1.2.2 发送者(接收者)不可观测性 1.3 匿名网络的一些设计要素 1. ...

  5. 基于TPS(Thin Plate Spines)的STN网络的PyTorch实现

    基于TPS(Thin Plate Spines)的STN网络是OCR领域CVPR论文<Robust Scene Text Recognition with Automatic Rectifica ...

  6. Lesson 16.1016.1116.1216.13 卷积层的参数量计算,1x1卷积核分组卷积与深度可分离卷积全连接层 nn.Sequential全局平均池化,NiN网络复现

    二 架构对参数量/计算量的影响 在自建架构的时候,除了模型效果之外,我们还需要关注模型整体的计算效率.深度学习模型天生就需要大量数据进行训练,因此每次训练中的参数量和计算量就格外关键,因此在设计卷积网 ...

  7. 虚拟网络运维----基于wireshark报文分析快速过滤(tcp,icmp,http)报文时延

    文章目录 虚拟网络运维----基于wireshark报文分析快速过滤(tcp,icmp,http)报文时延 前言 tcp协议高时延报文定位 http协议高时延报文定位 icmp协议高时延报文 虚拟网络 ...

  8. 基于轻量化重构网络的表面缺陷视觉检测

    源自:自动化学报     作者:余文勇 张阳 姚海明 石绘  编辑:OpenCV与AI深度学习 摘 要 基于深度学习的方法在某些工业产品的表面缺陷识别和分类方面表现出优异的性能, 然而大多数工业产品缺 ...

  9. 用C#实现基于TCP协议的网络通讯

    TCP协议是一个基本的网络协议,基本上所有的网络服务都是基于TCP协议的,如HTTP,FTP等等,所以要了解网络编程就必须了解基于TCP协议的编程.然而TCP协议是一个庞杂的体系,要彻底的弄清楚它的实 ...

最新文章

  1. saltstack 执行结果返回到mysql
  2. Linux下使用OTL操作mysql数据库
  3. 从零开始入门 K8s | Kubernetes API 编程利器:Operator 和 Operator Framework
  4. ios retain 与 copy 的区别
  5. SAP Spartacus全局配置模块里和layoutSlot相关的配置
  6. 瓜子二手车发12月二手车价格:汉兰达奥德赛CR-V保值率居首
  7. luoguSP1805,POJ2559-Largest Rectangle in a Histogram【单调栈】
  8. centos7手把手教你搭建zabbix监控
  9. oracle添加分区语句_Oracle表创建分区如何实现?
  10. Java同步队列(非阻塞队列与阻塞队列)——java并发容器
  11. 【MySQL】数据库命令练习题及答案
  12. win10下安装deepin双系统教程
  13. 企业盈利能力分析-毛利率、销售净利率、投资回报率、权益回报率、资产回报率...
  14. Ubuntu内核升级导致显卡冲突,升级显卡并禁用自动更新教程
  15. class SequenceFileOutputFormat takes type parameters
  16. 深入boot.img格式文件结构解析
  17. 基于量子计算的无收益标的资产欧式看涨期权定价和delta风险分析
  18. 计算机一直安装更新失败,win10系统一直安装更新失败的三种解决方法
  19. android手机号码恢复,安卓手机通讯录没有了怎么办?如何恢复手机通讯录
  20. 听说北京有个兄弟连!

热门文章

  1. 展开阅读全文代码html,展开阅读全文 js 爬虫操作
  2. android 投屏mac,MAC投屏ipad、手机
  3. 理想电流源与理想电压源
  4. iphone、ipad机型分辨率
  5. ZETA等物联网技术在新冠疫情防控中有哪些方面的应用?
  6. 高校BBS最HOT的100个笑话(不看保证后悔终身)
  7. 小米手机扩容教程_手机内部存储空间扩容方法
  8. Android应用优化之流畅度优化实操
  9. 查询加日期oracle,Oracle查询优化日期运算实例详解
  10. NOIP前的刷题记录