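Text detection differs from general object detection in that a text line is a sequence (of characters, parts of characters, or groups of characters) rather than a single isolated object. This is both an advantage and a difficulty. The advantage is that characters on the same text line share context, so the line can be modeled with sequence methods such as an RNN. The difficulty is that the whole line must be detected as one unit even though its characters may look very different and lie far apart, which is harder than detecting a single object. For this reason, the authors argue that predicting the vertical position of text (the top and bottom edges of the bounding box) is easier than predicting its horizontal position (the left and right edges).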
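The idea of regressing only the vertical coordinates can be sketched as follows. This is a minimal illustration, assuming the relative encoding that the `bbox_transfor_inv` function in this post decodes (center offset divided by the anchor height, plus a log height ratio). The helper names `encode_vertical` and `decode_vertical` and the sample numbers are hypothetical.

```python
import math

ANCHOR_WIDTH = 16  # every anchor shares this fixed width, so x is never regressed


def encode_vertical(cy, h, anchor_cy, anchor_h):
    """Turn a ground-truth center y / height into regression targets (v_c, v_h)."""
    v_c = (cy - anchor_cy) / anchor_h
    v_h = math.log(h / anchor_h)
    return v_c, v_h


def decode_vertical(v_c, v_h, anchor_cy, anchor_h):
    """Invert encode_vertical: recover center y and height from the targets."""
    cy = v_c * anchor_h + anchor_cy
    h = math.exp(v_h) * anchor_h
    return cy, h


# Hypothetical anchor (center y = 40, height = 33) and text slice (center y = 44, height = 30).
v_c, v_h = encode_vertical(44.0, 30.0, 40.0, 33.0)
cy, h = decode_vertical(v_c, v_h, 40.0, 33.0)
print(round(cy, 6), round(h, 6))
```

Because the width is fixed, the network only has to learn two numbers per anchor instead of four, which is the simplification the paragraph above describes.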


  • PyTorch (latest version)
  • Ubuntu 18.04
  • OpenCV
  • Pillow
  • NumPy


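  • Preface
  • 1. Dataset preparation
  • 2. Label preparation
  • 3. Model training
    • Source code link:
    • Dataset link:
  • 4. Text detection (CTPN): complete code
  • 5. Training results
  • 6. Loading the CTPN text detection model and validating it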












Source code link: https://pan.baidu.com/s/1RNRaObQBnWaM_Rwd4KYQYg
Extraction code: 4s6s

You only need to set the dataset path in config.py.


Dataset link: https://pan.baidu.com/s/1dOscxy1fkobW_g3VOM2qcQ
Extraction code: win6

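This dataset (1.6 GB) is an open-source Tianchi dataset. If you are interested, you can download it and train the model yourself. If you are short on time, you can skip training and simply load the downloaded **CTPN.pth** model to try it out. Note: this is an open-source project.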






import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # expose only the first GPU to PyTorch

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from PIL import Image

prob_thresh = 0.5  # minimum foreground probability for a text proposal
gpu = torch.cuda.is_available()
device = torch.device('cuda:0' if gpu else 'cpu')


class basic_conv(nn.Module):
    """Conv2d optionally followed by BatchNorm and ReLU."""

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, relu=True, bn=True, bias=True):
        super(basic_conv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU(inplace=True) if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x


class CTPN_Model(nn.Module):
    def __init__(self):
        super().__init__()
        # No ImageNet weights are needed because the checkpoint below supplies them.
        # (torchvision >= 0.13; older versions use pretrained=False instead.)
        base_model = models.vgg16(weights=None)
        layers = list(base_model.features)[:-1]
        self.base_layers = nn.Sequential(*layers)  # block5_conv3 output
        self.rpn = basic_conv(512, 512, 3, 1, 1, bn=False)
        self.brnn = nn.GRU(512, 128, bidirectional=True, batch_first=True)
        self.lstm_fc = basic_conv(256, 512, 1, 1, relu=True, bn=False)
        self.rpn_class = basic_conv(512, 10 * 2, 1, 1, relu=False, bn=False)
        self.rpn_regress = basic_conv(512, 10 * 2, 1, 1, relu=False, bn=False)

    def forward(self, x):
        x = self.base_layers(x)
        # rpn
        x = self.rpn(x)  # [b, c, h, w]

        x1 = x.permute(0, 2, 3, 1).contiguous()  # channels last [b, h, w, c]
        b = x1.size()  # b, h, w, c
        x1 = x1.view(b[0] * b[1], b[2], b[3])  # each feature-map row is one sequence

        x2, _ = self.brnn(x1)

        xsz = x.size()
        x3 = x2.view(xsz[0], xsz[2], xsz[3], 256)  # [b, h, w, 256]

        x3 = x3.permute(0, 3, 1, 2).contiguous()  # channels first [b, c, h, w]
        x3 = self.lstm_fc(x3)
        x = x3

        cls = self.rpn_class(x)
        regr = self.rpn_regress(x)

        cls = cls.permute(0, 2, 3, 1).contiguous()
        regr = regr.permute(0, 2, 3, 1).contiguous()

        # 10 anchors per feature-map location, 2 values per anchor
        cls = cls.view(cls.size(0), cls.size(1) * cls.size(2) * 10, 2)
        regr = regr.view(regr.size(0), regr.size(1) * regr.size(2) * 10, 2)

        return cls, regr


weights = '/home/zc/桌面/pythonProject2/ocr_master/checkpoints/CTPN.pth'  # path to the CTPN checkpoint
model = CTPN_Model()
model.load_state_dict(torch.load(weights, map_location=device)['model_state_dict'])
model.to(device)  # the input tensor is moved to `device`, so the model must be there too
model.eval()
IMAGE_MEAN = [123.68, 116.779, 103.939]  # per-channel mean subtracted from the RGB input


def gen_anchor(featuresize, scale):
    """
    Generate base anchors for the feature map: [H x W][10][4],
    then reshape [H x W][10][4] to [H x W x 10][4].
    """
    heights = [11, 16, 23, 33, 48, 68, 97, 139, 198, 283]
    widths = [16, 16, 16, 16, 16, 16, 16, 16, 16, 16]

    # k = 10 anchor sizes (h, w)
    heights = np.array(heights).reshape(len(heights), 1)
    widths = np.array(widths).reshape(len(widths), 1)

    base_anchor = np.array([0, 0, 15, 15])
    # center x, y
    xt = (base_anchor[0] + base_anchor[2]) * 0.5
    yt = (base_anchor[1] + base_anchor[3]) * 0.5

    # x1 y1 x2 y2
    x1 = xt - widths * 0.5
    y1 = yt - heights * 0.5
    x2 = xt + widths * 0.5
    y2 = yt + heights * 0.5
    base_anchor = np.hstack((x1, y1, x2, y2))

    h, w = featuresize
    shift_x = np.arange(0, w) * scale
    shift_y = np.arange(0, h) * scale
    # apply shift
    anchor = []
    for i in shift_y:
        for j in shift_x:
            anchor.append(base_anchor + [j, i, j, i])
    return np.array(anchor).reshape((-1, 4))


def bbox_transfor_inv(anchor, regr):
    """Decode the regression output into predicted boxes."""
    Cya = (anchor[:, 1] + anchor[:, 3]) * 0.5
    ha = anchor[:, 3] - anchor[:, 1] + 1

    Vcx = regr[0, :, 0]
    Vhx = regr[0, :, 1]

    Cyx = Vcx * ha + Cya
    hx = np.exp(Vhx) * ha
    xt = (anchor[:, 0] + anchor[:, 2]) * 0.5

    # The width stays fixed at 16; only the vertical extent is regressed.
    x1 = xt - 16 * 0.5
    y1 = Cyx - hx * 0.5
    x2 = xt + 16 * 0.5
    y2 = Cyx + hx * 0.5
    bbox = np.vstack((x1, y1, x2, y2)).transpose()

    return bbox


def clip_box(bbox, im_shape):
    # x1 >= 0
    bbox[:, 0] = np.maximum(np.minimum(bbox[:, 0], im_shape[1] - 1), 0)
    # y1 >= 0
    bbox[:, 1] = np.maximum(np.minimum(bbox[:, 1], im_shape[0] - 1), 0)
    # x2 < im_shape[1]
    bbox[:, 2] = np.maximum(np.minimum(bbox[:, 2], im_shape[1] - 1), 0)
    # y2 < im_shape[0]
    bbox[:, 3] = np.maximum(np.minimum(bbox[:, 3], im_shape[0] - 1), 0)

    return bbox


def filter_bbox(bbox, minsize):
    """Return the indices of boxes whose width and height are both at least minsize."""
    ws = bbox[:, 2] - bbox[:, 0] + 1
    hs = bbox[:, 3] - bbox[:, 1] + 1
    keep = np.where((ws >= minsize) & (hs >= minsize))[0]
    return keep


def nms(dets, thresh):
    """Greedy non-maximum suppression; dets columns are x1, y1, x2, y2, score."""
    x1 = dets[:, 0]
    y1 = dets[:, 1]
    x2 = dets[:, 2]
    y2 = dets[:, 3]
    scores = dets[:, 4]

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])

        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        ovr = inter / (areas[i] + areas[order[1:]] - inter)

        inds = np.where(ovr <= thresh)[0]
        order = order[inds + 1]
    return keep


class TextLineCfg:
    SCALE = 600
    MAX_SCALE = 1200
    TEXT_PROPOSALS_WIDTH = 16
    MIN_NUM_PROPOSALS = 2
    MIN_RATIO = 0.5
    LINE_MIN_SCORE = 0.9
    MAX_HORIZONTAL_GAP = 60
    TEXT_PROPOSALS_MIN_SCORE = 0.7
    TEXT_PROPOSALS_NMS_THRESH = 0.3
    MIN_V_OVERLAPS = 0.6
    MIN_SIZE_SIM = 0.6


class Graph:
    def __init__(self, graph):
        self.graph = graph

    def sub_graphs_connected(self):
        sub_graphs = []
        for index in range(self.graph.shape[0]):
            # Start a chain at every node that has a successor but no precursor.
            if not self.graph[:, index].any() and self.graph[index, :].any():
                v = index
                sub_graphs.append([v])
                while self.graph[v, :].any():
                    v = np.where(self.graph[v, :])[0][0]
                    sub_graphs[-1].append(v)
        return sub_graphs


class TextProposalGraphBuilder:
    """
    Build Text proposals into a graph.
    """

    def get_successions(self, index):
        box = self.text_proposals[index]
        results = []
        for left in range(int(box[0]) + 1, min(int(box[0]) + TextLineCfg.MAX_HORIZONTAL_GAP + 1, self.im_size[1])):
            adj_box_indices = self.boxes_table[left]
            for adj_box_index in adj_box_indices:
                if self.meet_v_iou(adj_box_index, index):
                    results.append(adj_box_index)
            if len(results) != 0:
                return results
        return results

    def get_precursors(self, index):
        box = self.text_proposals[index]
        results = []
        for left in range(int(box[0]) - 1, max(int(box[0] - TextLineCfg.MAX_HORIZONTAL_GAP), 0) - 1, -1):
            adj_box_indices = self.boxes_table[left]
            for adj_box_index in adj_box_indices:
                if self.meet_v_iou(adj_box_index, index):
                    results.append(adj_box_index)
            if len(results) != 0:
                return results
        return results

    def is_succession_node(self, index, succession_index):
        precursors = self.get_precursors(succession_index)
        if self.scores[index] >= np.max(self.scores[precursors]):
            return True
        return False

    def meet_v_iou(self, index1, index2):
        def overlaps_v(index1, index2):
            h1 = self.heights[index1]
            h2 = self.heights[index2]
            y0 = max(self.text_proposals[index2][1], self.text_proposals[index1][1])
            y1 = min(self.text_proposals[index2][3], self.text_proposals[index1][3])
            return max(0, y1 - y0 + 1) / min(h1, h2)

        def size_similarity(index1, index2):
            h1 = self.heights[index1]
            h2 = self.heights[index2]
            return min(h1, h2) / max(h1, h2)

        return overlaps_v(index1, index2) >= TextLineCfg.MIN_V_OVERLAPS and \
               size_similarity(index1, index2) >= TextLineCfg.MIN_SIZE_SIM

    def build_graph(self, text_proposals, scores, im_size):
        self.text_proposals = text_proposals
        self.scores = scores
        self.im_size = im_size
        self.heights = text_proposals[:, 3] - text_proposals[:, 1] + 1

        # boxes_table[x] lists the proposals whose left edge is at column x
        boxes_table = [[] for _ in range(self.im_size[1])]
        for index, box in enumerate(text_proposals):
            boxes_table[int(box[0])].append(index)
        self.boxes_table = boxes_table

        # np.bool was removed in NumPy 1.24; the built-in bool is equivalent here
        graph = np.zeros((text_proposals.shape[0], text_proposals.shape[0]), bool)

        for index, box in enumerate(text_proposals):
            successions = self.get_successions(index)
            if len(successions) == 0:
                continue
            succession_index = successions[np.argmax(scores[successions])]
            if self.is_succession_node(index, succession_index):
                # NOTE: a box can have multiple successions(precursors) if multiple successions(precursors)
                # have equal scores.
                graph[index, succession_index] = True
        return Graph(graph)


class TextProposalConnectorOriented:
    """
    Connect text proposals into text lines
    """

    def __init__(self):
        self.graph_builder = TextProposalGraphBuilder()

    def group_text_proposals(self, text_proposals, scores, im_size):
        graph = self.graph_builder.build_graph(text_proposals, scores, im_size)
        return graph.sub_graphs_connected()

    def fit_y(self, X, Y, x1, x2):
        # len(X) != 0
        # if X only include one point, the function will get line y=Y[0]
        if np.sum(X == X[0]) == len(X):
            return Y[0], Y[0]
        p = np.poly1d(np.polyfit(X, Y, 1))
        return p(x1), p(x2)

    def get_text_lines(self, text_proposals, scores, im_size):
        """
        text_proposals: boxes
        """
        # tp = text proposal
        # Build the graph first to find which proposals make up each text line.
        tp_groups = self.group_text_proposals(text_proposals, scores, im_size)

        text_lines = np.zeros((len(tp_groups), 8), np.float32)

        for index, tp_indices in enumerate(tp_groups):
            text_line_boxes = text_proposals[list(tp_indices)]  # all proposals of this text line
            X = (text_line_boxes[:, 0] + text_line_boxes[:, 2]) / 2  # center x of each proposal
            Y = (text_line_boxes[:, 1] + text_line_boxes[:, 3]) / 2  # center y of each proposal

            z1 = np.polyfit(X, Y, 1)  # least-squares line through the proposal centers

            x0 = np.min(text_line_boxes[:, 0])  # smallest x of the text line
            x1 = np.max(text_line_boxes[:, 2])  # largest x of the text line

            offset = (text_line_boxes[0, 2] - text_line_boxes[0, 0]) * 0.5  # half the proposal width

            # Fit a line through the top-left corners of all proposals, then evaluate it
            # at the leftmost and rightmost x of the text line.
            lt_y, rt_y = self.fit_y(text_line_boxes[:, 0], text_line_boxes[:, 1], x0 + offset, x1 - offset)
            # Do the same with the bottom-left corners.
            lb_y, rb_y = self.fit_y(text_line_boxes[:, 0], text_line_boxes[:, 3], x0 + offset, x1 - offset)

            # The text line score is the mean score of its proposals.
            score = scores[list(tp_indices)].sum() / float(len(tp_indices))

            text_lines[index, 0] = x0
            text_lines[index, 1] = min(lt_y, rt_y)  # smaller y of the top edge
            text_lines[index, 2] = x1
            text_lines[index, 3] = max(lb_y, rb_y)  # larger y of the bottom edge
            text_lines[index, 4] = score  # text line score
            text_lines[index, 5] = z1[0]  # slope k and intercept b of the center line
            text_lines[index, 6] = z1[1]
            height = np.mean((text_line_boxes[:, 3] - text_line_boxes[:, 1]))  # mean proposal height
            text_lines[index, 7] = height + 2.5

        # np.float was removed in NumPy 1.24; np.float64 is the same type
        text_recs = np.zeros((len(text_lines), 9), np.float64)
        index = 0
        for line in text_lines:
            # Intercepts of the top and bottom edges, derived from the center line and the height.
            b1 = line[6] - line[7] / 2
            b2 = line[6] + line[7] / 2
            x1 = line[0]
            y1 = line[5] * line[0] + b1  # top-left
            x2 = line[2]
            y2 = line[5] * line[2] + b1  # top-right
            x3 = line[0]
            y3 = line[5] * line[0] + b2  # bottom-left
            x4 = line[2]
            y4 = line[5] * line[2] + b2  # bottom-right
            disX = x2 - x1
            disY = y2 - y1
            width = np.sqrt(disX * disX + disY * disY)  # text line width

            fTmp0 = y3 - y1  # text line height
            fTmp1 = fTmp0 * disY / width
            x = np.fabs(fTmp1 * disX / width)  # compensation for the slant
            y = np.fabs(fTmp1 * disY / width)
            if line[5] < 0:
                x1 -= x
                y1 += y
                x4 += x
                y4 -= y
            else:
                x2 += x
                y2 += y
                x3 -= x
                y3 -= y
            text_recs[index, 0] = x1
            text_recs[index, 1] = y1
            text_recs[index, 2] = x2
            text_recs[index, 3] = y2
            text_recs[index, 4] = x3
            text_recs[index, 5] = y3
            text_recs[index, 6] = x4
            text_recs[index, 7] = y4
            text_recs[index, 8] = line[4]
            index = index + 1

        return text_recs


def get_det_boxes(image, display=True, expand=True):
    image_c = image.copy()
    h, w = image.shape[:2]
    image = image.astype(np.float32) - IMAGE_MEAN
    image = torch.from_numpy(image.transpose(2, 0, 1)).unsqueeze(0).float()

    with torch.no_grad():
        image = image.to(device)
        cls, regr = model(image)
        cls_prob = F.softmax(cls, dim=-1).cpu().numpy()
        regr = regr.cpu().numpy()
        anchor = gen_anchor((int(h / 16), int(w / 16)), 16)
        bbox = bbox_transfor_inv(anchor, regr)
        bbox = clip_box(bbox, [h, w])

        # keep anchors classified as text
        fg = np.where(cls_prob[0, :, 1] > prob_thresh)[0]
        select_anchor = bbox[fg, :]
        select_score = cls_prob[0, fg, 1]
        select_anchor = select_anchor.astype(np.int32)
        keep_index = filter_bbox(select_anchor, 16)

        # nms
        select_anchor = select_anchor[keep_index]
        select_score = select_score[keep_index]
        select_score = np.reshape(select_score, (select_score.shape[0], 1))
        nmsbox = np.hstack((select_anchor, select_score))
        keep = nms(nmsbox, 0.3)
        select_anchor = select_anchor[keep]
        select_score = select_score[keep]

        # connect proposals into text lines
        textConn = TextProposalConnectorOriented()
        text = textConn.get_text_lines(select_anchor, select_score, [h, w])

        # expand each text line horizontally by 10 px
        if expand:
            for idx in range(len(text)):
                text[idx][0] = max(text[idx][0] - 10, 0)
                text[idx][2] = min(text[idx][2] + 10, w - 1)
                text[idx][4] = max(text[idx][4] - 10, 0)
                text[idx][6] = min(text[idx][6] + 10, w - 1)

        if display:
            for box in select_anchor:
                # OpenCV expects plain Python ints, not NumPy integers
                pt1 = (int(box[0]), int(box[1]))
                pt2 = (int(box[2]), int(box[3]))
                print(pt1, pt2)
                cv2.rectangle(image_c, pt1, pt2, (0, 0, 0))

    # Return every detected text line (not just the last proposal) and the annotated image.
    return text, image_c


def single_pic_proc(image_file):
    image = np.array(Image.open(image_file).convert('RGB'))
    _, img = get_det_boxes(image)
    return img


if __name__ == '__main__':
    # Given an image path, run detection and save the image with the boxes drawn on it.
    url = '/home/zc/桌面/pythonProject2/imgs/91110101MA00BEU57K.jpg'
    img = single_pic_proc(url)
    Image.fromarray(img).save('./op.jpg')
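
How `sub_graphs_connected` turns pairwise links into text lines can be hard to see from the class alone. The sketch below replays the same chain-following logic on a tiny hand-made adjacency matrix, using plain Python lists instead of NumPy. The function name `chains` and the five example proposals are hypothetical and only serve as an illustration.

```python
def chains(adj):
    """Follow successor links from every start node.

    A start node has at least one outgoing edge and no incoming edge,
    which mirrors the condition used in Graph.sub_graphs_connected.
    """
    n = len(adj)
    groups = []
    for i in range(n):
        has_incoming = any(adj[r][i] for r in range(n))
        has_outgoing = any(adj[i])
        if not has_incoming and has_outgoing:
            v = i
            groups.append([v])
            while any(adj[v]):
                v = adj[v].index(True)  # first successor, like np.where(...)[0][0]
                groups[-1].append(v)
    return groups


# Five hypothetical proposals: 0 -> 2 -> 3 form one line, 1 -> 4 form another.
adj = [[False] * 5 for _ in range(5)]
adj[0][2] = adj[2][3] = adj[1][4] = True
lines = chains(adj)
print(lines)  # [[0, 2, 3], [1, 4]]
```

Note that a proposal with no links at all never starts a chain, so isolated boxes are dropped at this stage.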

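The anchor layout produced by `gen_anchor` can also be checked by hand. The sketch below is a NumPy-free illustration that reuses the ten heights, the fixed width of 16 and the stride of 16 from the code above. The helper names `anchors_at` and `anchor_count` and the 480 x 640 image size are assumptions made for the example.

```python
HEIGHTS = [11, 16, 23, 33, 48, 68, 97, 139, 198, 283]
WIDTH = 16
STRIDE = 16


def anchors_at(row, col):
    """Return the 10 anchors (x1, y1, x2, y2) for one feature-map cell."""
    cx = cy = (0 + 15) * 0.5  # center of the base box [0, 0, 15, 15]
    boxes = []
    for hgt in HEIGHTS:
        x1 = cx - WIDTH * 0.5 + col * STRIDE
        y1 = cy - hgt * 0.5 + row * STRIDE
        x2 = cx + WIDTH * 0.5 + col * STRIDE
        y2 = cy + hgt * 0.5 + row * STRIDE
        boxes.append((x1, y1, x2, y2))
    return boxes


def anchor_count(img_h, img_w):
    """Number of anchors for an image, i.e. rows of the cls/regr outputs."""
    return (img_h // STRIDE) * (img_w // STRIDE) * len(HEIGHTS)


print(anchors_at(0, 0)[0])     # (-0.5, 2.0, 15.5, 13.0)
print(anchors_at(0, 1)[0][0])  # 15.5
print(anchor_count(480, 640))  # 12000
```

The count matches the second dimension of the `cls` and `regr` tensors returned by `CTPN_Model.forward`, which is why the decoded boxes line up one-to-one with the anchors.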
