文章目录

  • 1 为什么有这么一篇文章
  • 2 获取并保存数据集分割预测结果
  • 3 deeplab.get_miou_png()函数代码解析
  • 4 感谢链接

1 为什么有这么一篇文章

其实之前有写过deeplabv3+图像输入->处理->输出全过程,里面包含了如下内容:


该有的似乎都有了,只是想着大家平时针对数据集操作还挺多的,保存数据集的分割预测结果也是一小部分工作内容,故又加了这一篇,内容和上述文章区别不是很大,很容易。

2 获取并保存数据集分割预测结果

get_miou.py代码中,给出了下列代码,完成图片从输入到得到数据集预测结果灰度图的全部过程。

import osfrom PIL import Image
from tqdm import tqdm# ----------------------------------------------------------#
#   DeeplabV3表示分割网络结构,其代码在deeplab.py中,解读见下一节
# ----------------------------------------------------------#
from deeplab import DeeplabV3
# ---------------------------------------------------------------------#
#   compute_mIoU和show_results,其代码在utils/utils_metrics.py中,
#   解读见链接:会有的
#   !本文中并未用到!
# ---------------------------------------------------------------------#
from utils.utils_metrics import compute_mIoU, show_results"""
进行指标评估需要注意:
该文件生成的图为灰度图,因为值比较小,按照PNG形式的图看是没有显示效果的,所以看到近似全黑的图是正常的。
"""
if __name__ == "__main__":#---------------------------------------------------------------------------##   miou_mode用于指定该文件运行时计算的内容#   miou_mode为0代表整个miou计算流程,包括获得预测结果、计算miou。#   miou_mode为1代表仅仅获得预测结果。#   miou_mode为2代表仅仅计算miou。          !!本文中并未用到!!#---------------------------------------------------------------------------#miou_mode       = 1#------------------------------------##   分类个数+1、如2+1#   VOC数据集,所需要区分的类的个数+1#------------------------------------#num_classes     = 21#--------------------------------------------##   区分的种类,和json_to_dataset里面的一样#   种类名称,此例为VOC#--------------------------------------------#name_classes    = ["background","aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]# name_classes    = ["_background_","cat","dog"]#-------------------------------------------------------------------##   指向VOC数据集所在的文件夹#   默认指向根目录下的VOC数据集#   链接:https://pan.baidu.com/s/1OZfxoyVUKlESsyqs1nuuuw 提取码:wlna#-------------------------------------------------------------------#VOCdevkit_path  = '../VOCdevkit'#--------------------------------------------##   image_ids:['图片名1', '图片名2',...]#--------------------------------------------#image_ids       = open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Segmentation/val.txt"),'r').read().splitlines() gt_dir          = os.path.join(VOCdevkit_path, "VOC2007/SegmentationClass/")miou_out_path   = "miou_out"#-------------------------------------------------##   pred_dir预测结果png图片路径,只有8位深度,灰度图#   正常jpg,RGB三通道,24位深度#   彩色png,RGBA四通道,32位深度#-------------------------------------------------#pred_dir        = os.path.join(miou_out_path, 'detection-results')  #-------------------------------------------------##   获得预测结果,输出为8位深度的灰度图#-------------------------------------------------#if miou_mode == 0 or miou_mode == 1:if not os.path.exists(pred_dir):os.makedirs(pred_dir)#-----------------------------------------------------------------------------------##   下方有给出代码#   详细解读见:https://blog.csdn.net/weixin_45377629/article/details/124124238#-----------------------------------------------------------------------------------#print("Load model.")deeplab = DeeplabV3()print("Load model done.")print("Get predict result.")for image_id in tqdm(image_ids):image_path  = os.path.join(VOCdevkit_path, "VOC2007/JPEGImages/"+image_id+".jpg")image       = Image.open(image_path)# ------------------------------------##   image是png图片,8位深度,灰度图#   deeplab.get_miou_png(image)见下方解读#   # image size:(原图宽, 原图高)# ------------------------------------#image       = deeplab.get_miou_png(image)   image.save(os.path.join(pred_dir, image_id + ".png"))print("Get predict result done.")

结果输出:

该文件生成的图为灰度图,因为值比较小,按照PNG形式的图看是没有显示效果的,所以看到近似全黑的图是正常的。

3 deeplab.get_miou_png()函数代码解析

通过deeplab.py完成get_miou.pyimage= deeplab.get_miou_png(image),用来获取数据集分割结果灰度图。deeplabv3+网络结构详细介绍可见 DeeplabV3+网络结构详解,通过网络结构获取8位深度的分割结果灰度图见下方代码。

import colorsys
import copy
import timeimport cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch import FloatTensor, nn, tensor#---------------------------------------------------------------------------------#
#   DeepLab网络及代码
#   详细介绍可见 https://blog.csdn.net/weixin_45377629/article/details/124083978
#---------------------------------------------------------------------------------#
from nets.deeplabv3_plus import DeepLab
#----------------------------------------------------------------------------------#
#   三个函数代码下方给出
#   cvtColor:          将图像转换成RGB图像,防止灰度图在预测时报错。
#   preprocess_input:  归一化
#   resize_image:      对输入图像进行resize,letterbox_image方式,不失真resize
#----------------------------------------------------------------------------------#
from utils.utils import cvtColor, preprocess_input, resize_image#-----------------------------------------------------------------------------------#
#   使用自己训练好的模型预测需要修改3个参数
#   model_path、backbone和num_classes都需要修改!
#   如果出现shape不匹配,一定要注意训练时的model_path、backbone和num_classes的修改
#-----------------------------------------------------------------------------------#
class DeeplabV3(object):_defaults = {#-------------------------------------------------------------------##   model_path指向logs文件夹下的权值文件#   训练好后logs文件夹下存在多个权值文件,选择验证集损失较低的即可。#   验证集损失较低不代表miou较高,仅代表该权值在验证集上泛化性能较好。#   链接:https://pan.baidu.com/s/1TrBlnZUd6xwxUvgFjbz7TQ 提取码:cj80#-------------------------------------------------------------------#"model_path"        : 'model_data/deeplab_mobilenetv2.pth',#----------------------------------------##   所需要区分的类的个数+1#----------------------------------------#"num_classes"       : 21,#----------------------------------------##   所使用的的主干网络:#   mobilenet  #----------------------------------------#"backbone"          : "mobilenet",#----------------------------------------##   输入图片的大小#----------------------------------------#"input_shape"       : [512, 512],#----------------------------------------##   下采样的倍数,一般可选的为8和16#   与训练时设置的一样即可#----------------------------------------#"downsample_factor" : 16,#-------------------------------------------------##   mix_type参数用于控制检测结果的可视化方式##   mix_type = 0的时候代表原图与生成的图进行混合#   mix_type = 1的时候代表仅保留生成的图#   mix_type = 2的时候代表仅扣去背景,仅保留原图中的目标#   下方有给出三种可视化结果的区别#-------------------------------------------------#"mix_type"          : 0,#-------------------------------##   是否使用Cuda#   没有GPU可以设置成False#-------------------------------#"cuda"              : False,}#---------------------------------------------------##   初始化Deeplab#---------------------------------------------------#def __init__(self, **kwargs):#---------------------------------------------------##   _defaults字典原来是这么用起来的#---------------------------------------------------#self.__dict__.update(self._defaults)for name, value in kwargs.items():#-----------------------------------------------##   设置属性 name 值,即self.name==value#-----------------------------------------------#setattr(self, name, value)#---------------------------------------------------##   画框设置不同的颜色#---------------------------------------------------#if self.num_classes <= 21:self.colors = [ (0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128), (0, 128, 128), (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0), (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128), (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128), (128, 64, 12)]else:hsv_tuples = [(x / self.num_classes, 1., 1.) for x in range(self.num_classes)]self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))self.colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), self.colors))#---------------------------------------------------##   获得模型#---------------------------------------------------#self.generate()#---------------------------------------------------##   获得所有的分类#---------------------------------------------------#def generate(self):#-----------------------------------------------------------------------------------##   载入模型与权值#   详细介绍可见 https://blog.csdn.net/weixin_45377629/article/details/124083978#-----------------------------------------------------------------------------------#self.net = DeepLab(num_classes=self.num_classes, backbone=self.backbone, downsample_factor=self.downsample_factor, pretrained=False)device      = torch.device('cuda' if torch.cuda.is_available() else 'cpu')self.net.load_state_dict(torch.load(self.model_path, map_location=device))self.net    = self.net.eval()print('{} model, and classes loaded.'.format(self.model_path))if self.cuda:self.net = nn.DataParallel(self.net)self.net = self.net.cuda()#---------------------------------------------------##   预测图片,得到灰度图结果#---------------------------------------------------#def get_miou_png(self, image):#---------------------------------------------------------##   在这里将图像转换成RGB图像,防止灰度图在预测时报错。#   代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB#---------------------------------------------------------#image       = cvtColor(image)orininal_h  = np.array(image).shape[0]orininal_w  = np.array(image).shape[1]#---------------------------------------------------------##   给图像增加灰条,实现不失真的resize#   也可以直接resize进行识别#---------------------------------------------------------#image_data, nw, nh  = resize_image(image, (self.input_shape[1],self.input_shape[0]))#---------------------------------------------------------##   添加上batch_size维度#---------------------------------------------------------#image_data  = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)with torch.no_grad():images = torch.from_numpy(image_data)if self.cuda:images = images.cuda()#---------------------------------------------------##   图片传入网络进行预测#   VOC为例,self.net(images) shape:torch.size([1,21,512,512])#   pr :tensor, shape:torch.size([21,512,512])pr = self.net(images)[0]#---------------------------------------------------##   取出每一个像素点的种类#   pr.permute(1,2,0):通道交换#   F.softmax(input, dim=-1):在行上softmax,和为1#   pr :array, shape:(512,512,21)#---------------------------------------------------#pr = F.softmax(pr.permute(1,2,0),dim = -1).cpu().numpy()#----------------------------------------------------##   将灰条部分截取掉#   letterbox_image一般会引入灰条#   pr :array, shape:(512,512,21),有灰条w、h尺寸会变#----------------------------------------------------#pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]#---------------------------------------------------##   进行图片的resize#   灰条去掉后,resize回原图大小#   pr :array, shape:(orininal_w, orininal_h,21)#---------------------------------------------------#pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation = cv2.INTER_LINEAR)#---------------------------------------------------##   取出每一个像素点的种类#   pr :array, shape:(orininal_w, orininal_h)#---------------------------------------------------#pr = pr.argmax(axis=-1)image = Image.fromarray(np.uint8(pr))   # size:(orininal_w, orininal_h)return image

4 感谢链接

https://blog.csdn.net/weixin_44791964/article/details/120113686
https://www.bilibili.com/video/BV173411q7xF?p=15

【DeeplabV3+ get_miou_png】DeeplabV3+获取数据集预测结果灰度图相关推荐

  1. deeplabv3+训练自己的数据集

    deeplabv3+训练自己的数据集 环境:ubuntu 16.04 + TensorFlow 1.9.1 + cuda 9.0 + cudnn 7.0 +python3.6 tensorflow 项 ...

  2. deeplabv3+——训练自己的数据集 torch1.12.0 cuda11.3

    目录 前言 环境 源码 参考博客 一.制作自己的数据集 二.训练 三.可视化 前言 环境 torch==1.12.0+cu113 cuda==11.3 显卡为 RTX3070ti tips:30系显卡 ...

  3. Python使用tpot获取最优模型、将最优模型应用于交叉验证数据集(5折)获取数据集下的最优表现,并将每一折(fold)的预测结果、概率、属于哪一折与测试集标签、结果、概率一并整合输出为结果文件

    Python使用tpot获取最优模型.将最优模型应用于交叉验证数据集(5折)获取数据集下的最优表现,并将每一折(fold)的预测结果.概率.属于哪一折与测试集标签.结果.概率一并整合输出为结果文件 目 ...

  4. Deeplabv3+训练自己的数据集(包含脚本)

    目录 前言 源码 一.环境配置 二.使用步骤 1.制作数据集 2.训练模型 3.测试 三.常见报错 总结 前言 最近在着手一个项目,需要用到语义分割这一块,最后经过慎重的考虑,最终选择deeplabv ...

  5. TensorFlow之DeepLabv3+训练自己的数据集

    0 背景 在之前的文章中,对 tensorflow 目标检测API进行了详细的测试,成功应用其模型做简单的检测任务.首先简单介绍下系统环境的配置 python3.6; tensorflow-gpu 1 ...

  6. 基于卷积神经网络的人脸识别(自我拍摄获取数据集)

    基于卷积神经网络的人脸识别 完整代码.数据请见:https://download.csdn.net/download/weixin_43521269/12837110 人脸识别,是基于人的脸部特征信息 ...

  7. 机器学习-sk-learn-Facebook数据集预测签到位置

    sk-learn Facebook数据集预测签到位置 本次比赛的目的是预测一个人将要签到的地方. 为了本次比赛,Facebook创建了一个虚拟世界,其中包括10公里*10公里共100平方公里的约10万 ...

  8. 泰坦尼克数据集预测分析_探索性数据分析—以泰坦尼克号数据集为例(第1部分)

    泰坦尼克数据集预测分析 Imagine your group of friends have decided to spend the vacations by travelling to an am ...

  9. Spark ML(lib)实验:利用银行营销数据集预测客户是否订阅产品

    一.实验描述 数据集来源于UCI的银行营销数据集(UCI Machine Learning Repository: Bank Marketing Data Set).数据与葡萄牙一家银行机构的直接营销 ...

最新文章

  1. 按照文字内容动态设置TableViewCell的高度
  2. java 读取远程文件夹_java读取远程共享文件 | 学步园
  3. [转载]web集群时利用memcache来同步session
  4. linux上寻找并杀死僵尸进程
  5. 聊聊 Redis 使用场景
  6. [Java] SSH框架笔记_Struts2配置问题
  7. 【Python】Python视频制作工具Manim入门,基础形状详细介绍
  8. VTK:Filtering之PerlinNoise
  9. POJ 2263 floyd思想
  10. This Android SDK requires An... ADT to the late...
  11. android framework,GitHub - zhaozepeng/Android_framework: android framework 用来快速开发的android框架...
  12. 程序员有必要参加软考吗?大一可以考的编程证书还有哪些
  13. 永辉发布元宵数据:汤圆销售明显提升,多个民生产品增长超150%
  14. vue加跨域代理静态文件404_解决vue-router history模式和跨域代理 部署到IIS时404的一些问题...
  15. Representation Flow for Action Recognition论文解读
  16. Tensorflow:TensorFlow基础(一)
  17. linux443端口无法建立连接,无法通过端口443连接到ssh
  18. 地方棋牌游戏里的家乡情结
  19. 相关性系数及其python实现
  20. 第十二届蓝桥杯省赛JAVA B组杨辉三角形个人题解

热门文章

  1. UltraISO使用和U盘安装原版系统指南
  2. 2022电工(初级)考试题库模拟考试平台操作
  3. 我的Crystal xcelsius之旅
  4. 通过RSA和DES实现网络报文加密加签(实例)
  5. java一打开就闪退怎么解决(如何解决java 闪退)
  6. 阿里云的大数据ACP认证含金量高吗?
  7. Java单元测试实践-11.Mock后Stub Spring的@Component组件
  8. Android逆向之旅—Hook神器Cydia Substrate使用详解
  9. 《操作系统》-调度算法
  10. 两化融合是从工业大国向工业强国转变必由之路