目录

前言

一、DeepLab简述

二、DeepLab怎么使用

1、配置

2、代码解析


前言

最近在接触一个项目,要用到语义分割,如果大家有什么问题,都可以跟我沟通交流。除了通过博客交流外,欢迎你加入我的QQ群,一起交流有关于区块链、机器学习、深度学习、计算机视觉有关内容。我现在处于广泛涉猎阶段,希望我们能够一起沟通。下图是我的群二维码:

一、DeepLab简述

Deeplab 是谷歌在FCN的基础上搞出来的。Google 研究团队开源在 Tensorflow 中进行语义图像分割(Semantic Image Segmentation)模型 DeepLab-v3+。

当然,因为主要是应用,不是深入研究,所以不会很详细的说明deeplab的各个版本。

论文下载:https://arxiv.org/pdf/1802.02611.pdf

代码下载:https://github.com/tensorflow/models/tree/master/research/deeplab

二、DeepLab怎么使用

1、配置

配置最重要,特别是我们要做的是深度学习与计算机视觉,不仅要求硬件,还要求软件。

在这里,我们不需要一个特别强大的GPU(有了最好),一个win7或者win10的电脑(win8也可以,但是win8现在很少使用,大多数程序员都偏爱win7)

这个代码是python写的,所以需要用到python,我使用的是python3.6,大家可以根据自己的需要去下载。软件用的是pycharm。

安装教程:请点击安装配置教程,即可查看。

pycharm设置教程:请点击pycharm设置教程,即可查看。

2、代码解析

全部代码如下

#!--*-- coding:utf-8 --*--# Deeplab Demoimport os
import tarfilefrom matplotlib import gridspec
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import tempfile
from six.moves import urllib
import tensorflow as tfclass DeepLabModel(object):"""加载 DeepLab 模型;推断 Inference."""INPUT_TENSOR_NAME = 'ImageTensor:0'OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'INPUT_SIZE = 400FROZEN_GRAPH_NAME = 'frozen_inference_graph'def __init__(self, tarball_path):"""加载预训练模型"""self.graph = tf.Graph()graph_def = None# Extract frozen graph from tar archive.tar_file = tarfile.open(tarball_path)#得到模型中的文件for tar_info in tar_file.getmembers():if self.FROZEN_GRAPH_NAME in os.path.basename(tar_info.name):file_handle = tar_file.extractfile(tar_info)graph_def = tf.GraphDef.FromString(file_handle.read())breaktar_file.close()if graph_def is None:raise RuntimeError('Cannot find inference graph in tar archive.')with self.graph.as_default():tf.import_graph_def(graph_def, name='')self.sess = tf.Session(graph=self.graph)def run(self, image):"""Args:image:  转换为PIL.Image 类,不能直接用图片,原始图片Returns:resized_image: RGB image resized from original input image.seg_map: Segmentation map of `resized_image`."""width, height = image.sizeresize_ratio = 1.0 * self.INPUT_SIZE / max(width, height)target_size = (int(resize_ratio * width), int(resize_ratio * height))resized_image = image.convert('RGB').resize(target_size, Image.ANTIALIAS)batch_seg_map = self.sess.run(self.OUTPUT_TENSOR_NAME,feed_dict={self.INPUT_TENSOR_NAME: [np.asarray(resized_image)]})seg_map = batch_seg_map[0]return resized_image, seg_mapdef create_pascal_label_colormap():"""Creates a label colormap used in PASCAL VOC segmentation benchmark.Returns:A Colormap for visualizing segmentation results."""colormap = np.zeros((256, 3), dtype=int)# zeros:元素全为0的矩阵ind = np.arange(256, dtype=int) # 用于创建等差数组,range返回构成的list,而arange返回一个array对象for shift in reversed(range(8)): # reversed:返回的是一个把序列值经过反转之后的迭代器,通过遍历,List或者next()等方法,获取作用后的值;for channel in range(3):colormap[:, channel] |= ((ind >> channel) & 1) << shiftind >>= 3return colormapdef label_to_color_image(label):"""Adds color defined by the dataset colormap to the label.Args:label: A 2D array with integer type, storing the segmentation label.Returns:result: A 2D array with floating type. The element of the arrayis the color indexed by the corresponding element in the input labelto the PASCAL color map.Raises:ValueError: If label is not of rank 2 or its value is larger than colormap maximum entry."""if label.ndim != 2:raise ValueError('Expect 2-D input label')colormap = create_pascal_label_colormap()if np.max(label) >= len(colormap):raise ValueError('label value too large.')return colormap[label]def vis_segmentation(image, seg_map, imagefile):"""可视化三种图像."""plt.figure(figsize=(15, 5))grid_spec = gridspec.GridSpec(1, 4, width_ratios=[6, 6, 6, 1])plt.subplot(grid_spec[0])plt.imshow(image)plt.axis('off')plt.title('input image')plt.subplot(grid_spec[1])seg_image = label_to_color_image(seg_map).astype(np.uint8)# seg_image = label_to_color_image(seg_map)# seg_image.save('/str(ss)+imagefile')plt.imshow(seg_image)plt.savefig('./'+imagefile+'.png')plt.axis('off')plt.title('segmentation map')plt.subplot(grid_spec[2])plt.imshow(image)plt.imshow(seg_image, alpha=0.7)plt.axis('off')plt.title('segmentation overlay') #添加标题""""""unique_labels = np.unique(seg_map)ax = plt.subplot(grid_spec[3])plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation='nearest')ax.yaxis.tick_right()plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])plt.xticks([], [])ax.tick_params(width=0.0)plt.grid('off')plt.show()##标签名称LABEL_NAMES = np.asarray(['background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus','car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike','person', 'pottedplant', 'sheep', 'sofa', 'train', 'tv' ])FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)'''
## Tensorflow 提供的模型下载
MODEL_NAME = 'xception_coco_voctrainval'
# ['mobilenetv2_coco_voctrainaug', 'mobilenetv2_coco_voctrainval', 'xception_coco_voctrainaug', 'xception_coco_voctrainval']_DOWNLOAD_URL_PREFIX = 'http://download.tensorflow.org/models/'
_MODEL_URLS = {'mobilenetv2_coco_voctrainaug': 'deeplabv3_mnv2_pascal_train_aug_2018_01_29.tar.gz','mobilenetv2_coco_voctrainval': 'deeplabv3_mnv2_pascal_trainval_2018_01_29.tar.gz','xception_coco_voctrainaug': 'deeplabv3_pascal_train_aug_2018_01_04.tar.gz','xception_coco_voctrainval': 'deeplabv3_pascal_trainval_2018_01_04.tar.gz', }_TARBALL_NAME = 'deeplab_model.tar.gz'# model_dir = tempfile.mkdtemp()
model_dir = './'
# tf.gfile.MakeDirs(model_dir)#
download_path = os.path.join(model_dir, _TARBALL_NAME)
print('downloading model, this might take a while...')
#
urllib.request.urlretrieve(_DOWNLOAD_URL_PREFIX + _MODEL_URLS[MODEL_NAME], download_path)
print('download completed! loading DeepLab model...')'''# model_dir = '/‘# download_path = os.path.join(model_dir, _MODEL_URLS[MODEL_NAME])
MODEL = DeepLabModel('./deeplab_model.tar.gz') #修改这里更换模型# MODEL = './deeplab_model.tar.gz'
print('model loaded successfully!')##
def run_visualization(imagefile):"""DeepLab 语义分割,并可视化结果."""# orignal_im = Image.open(imagefile)# print(type(orignal_im))# orignal_im.show()print('running deeplab on image %s...' % imagefile)resized_im, seg_map = MODEL.run(Image.open(imagefile))vis_segmentation(resized_im, seg_map,imagefile)images_dir = './pictures' #修改这里更改图片路径
images = sorted(os.listdir(images_dir))
print(images)for imgfile in images:run_visualization(os.path.join(images_dir, imgfile))print('Done.')

1.引入包

我们需要引入一系列的包,当然,有些包是python自带的,有些需要自己安装,那什么的需要自己安装呢?

那些下面有红色下滑波浪线的,就是需要自己安装或配置的。

import os
import tarfilefrom matplotlib import gridspec
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import tempfile
from six.moves import urllib
import tensorflow as tf

我们直接利用Pycharm安装就好。以numpy为例:

1.依次点击File->Setting:

2.选择Project:(自己项目名称)->Project Interpreter。选择右面的 “+” 。

3.安装有如下几步:

(1)输入你要安装的包,

(2)选择你要安装的包,

(3)检查版本并查看信息,

(4)如果版本不对应自己选择版本,

(5)安装。

注:在安装PIL中可能出错。解决方案请点击:

Python报错:找不到PIL包

2.模型类

代码中需要用到一个DeepLabModel类,做相关操作。

class DeepLabModel(object):"""加载 DeepLab 模型;推断 Inference."""INPUT_TENSOR_NAME = 'ImageTensor:0'OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'INPUT_SIZE = 400FROZEN_GRAPH_NAME = 'frozen_inference_graph'def __init__(self, tarball_path):"""加载预训练模型"""self.graph = tf.Graph()graph_def = None# Extract frozen graph from tar archive.tar_file = tarfile.open(tarball_path)#得到模型中的文件for tar_info in tar_file.getmembers():if self.FROZEN_GRAPH_NAME in os.path.basename(tar_info.name):file_handle = tar_file.extractfile(tar_info)graph_def = tf.GraphDef.FromString(file_handle.read())breaktar_file.close()if graph_def is None:raise RuntimeError('Cannot find inference graph in tar archive.')with self.graph.as_default():tf.import_graph_def(graph_def, name='')self.sess = tf.Session(graph=self.graph)def run(self, image):"""Args:image:  转换为PIL.Image 类,不能直接用图片,原始图片Returns:resized_image: RGB image resized from original input image.seg_map: Segmentation map of `resized_image`."""width, height = image.sizeresize_ratio = 1.0 * self.INPUT_SIZE / max(width, height)target_size = (int(resize_ratio * width), int(resize_ratio * height))resized_image = image.convert('RGB').resize(target_size, Image.ANTIALIAS)batch_seg_map = self.sess.run(self.OUTPUT_TENSOR_NAME,feed_dict={self.INPUT_TENSOR_NAME: [np.asarray(resized_image)]})seg_map = batch_seg_map[0]return resized_image, seg_map

3.函数实现1

我们创建了三个函数。

def create_pascal_label_colormap():"""Creates a label colormap used in PASCAL VOC segmentation benchmark.Returns:A Colormap for visualizing segmentation results."""colormap = np.zeros((256, 3), dtype=int)# zeros:元素全为0的矩阵ind = np.arange(256, dtype=int) # 用于创建等差数组,range返回构成的list,而arange返回一个array对象for shift in reversed(range(8)): # reversed:返回的是一个把序列值经过反转之后的迭代器,通过遍历,List或者next()等方法,获取作用后的值;for channel in range(3):colormap[:, channel] |= ((ind >> channel) & 1) << shiftind >>= 3return colormapdef label_to_color_image(label):"""Adds color defined by the dataset colormap to the label.Args:label: A 2D array with integer type, storing the segmentation label.Returns:result: A 2D array with floating type. The element of the arrayis the color indexed by the corresponding element in the input labelto the PASCAL color map.Raises:ValueError: If label is not of rank 2 or its value is larger than colormap maximum entry."""if label.ndim != 2:raise ValueError('Expect 2-D input label')colormap = create_pascal_label_colormap()if np.max(label) >= len(colormap):raise ValueError('label value too large.')return colormap[label]def vis_segmentation(image, seg_map, imagefile):"""可视化三种图像."""plt.figure(figsize=(15, 5))grid_spec = gridspec.GridSpec(1, 4, width_ratios=[6, 6, 6, 1])plt.subplot(grid_spec[0])plt.imshow(image)plt.axis('off')plt.title('input image')plt.subplot(grid_spec[1])seg_image = label_to_color_image(seg_map).astype(np.uint8)# seg_image = label_to_color_image(seg_map)# seg_image.save('/str(ss)+imagefile')plt.imshow(seg_image)plt.savefig('./'+imagefile+'.png')plt.axis('off')plt.title('segmentation map')plt.subplot(grid_spec[2])plt.imshow(image)plt.imshow(seg_image, alpha=0.7)plt.axis('off')plt.title('segmentation overlay') #添加标题""""""unique_labels = np.unique(seg_map)ax = plt.subplot(grid_spec[3])plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation='nearest')ax.yaxis.tick_right()plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])plt.xticks([], [])ax.tick_params(width=0.0)plt.grid('off')plt.show()

这三个函数用于图像的语义分割,并可视化。

4.设置标签

设置标签及其对应颜色,会调用我们上面写好的函数。

##标签名称LABEL_NAMES = np.asarray(['background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus','car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike','person', 'pottedplant', 'sheep', 'sofa', 'train', 'tv' ])FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)

这三个函数用于图像的语义分割,并可视化。

5.下载模型

这部分要下载模型,第一次使用,要将下面多行注释取消,第二次使用,因为已经下载过模型,就无需再次下载,直接执行即可,所以把代码继续注释就好。


'''
## Tensorflow 提供的模型下载
MODEL_NAME = 'xception_coco_voctrainval'
# ['mobilenetv2_coco_voctrainaug', 'mobilenetv2_coco_voctrainval', 'xception_coco_voctrainaug', 'xception_coco_voctrainval']_DOWNLOAD_URL_PREFIX = 'http://download.tensorflow.org/models/'
_MODEL_URLS = {'mobilenetv2_coco_voctrainaug': 'deeplabv3_mnv2_pascal_train_aug_2018_01_29.tar.gz','mobilenetv2_coco_voctrainval': 'deeplabv3_mnv2_pascal_trainval_2018_01_29.tar.gz','xception_coco_voctrainaug': 'deeplabv3_pascal_train_aug_2018_01_04.tar.gz','xception_coco_voctrainval': 'deeplabv3_pascal_trainval_2018_01_04.tar.gz', }_TARBALL_NAME = 'deeplab_model.tar.gz'# model_dir = tempfile.mkdtemp()
model_dir = './'
# tf.gfile.MakeDirs(model_dir)#
download_path = os.path.join(model_dir, _TARBALL_NAME)
print('downloading model, this might take a while...')
#
urllib.request.urlretrieve(_DOWNLOAD_URL_PREFIX + _MODEL_URLS[MODEL_NAME], download_path)
print('download completed! loading DeepLab model...')'''# model_dir = '/‘# download_path = os.path.join(model_dir, _MODEL_URLS[MODEL_NAME])
MODEL = DeepLabModel('./deeplab_model.tar.gz') #修改这里更换模型
# MODEL = './deeplab_model.tar.gz'print('model loaded successfully!')

6.函数实现2

再次创建一个函数,把语义分割可视化封装好。

def run_visualization(imagefile):"""DeepLab 语义分割,并可视化结果."""# orignal_im = Image.open(imagefile)# print(type(orignal_im))# orignal_im.show()print('running deeplab on image %s...' % imagefile)resized_im, seg_map = MODEL.run(Image.open(imagefile))vis_segmentation(resized_im, seg_map,imagefile)

7.调用函数

最后就是调用函数,执行语义分割了。

images_dir = './pictures' #修改这里更改图片路径
images = sorted(os.listdir(images_dir))
print(images)for imgfile in images:run_visualization(os.path.join(images_dir, imgfile))print('Done.')

输出图像如下:

大家有什么问题可以在下面评论,我们一起相互交流。

【语义分割 - DeepLab学习笔记】之DeepLab初体验相关推荐

  1. PyTorch学习笔记:PyTorch初体验

    PyTorch学习笔记:PyTorch初体验 一.在Anaconda里安装PyTorch 1.进入虚拟环境mlcc 2.安装PyTorch 二.在PyTorch创建张量 1.启动mlcc环境下的Spy ...

  2. Python学习笔记:pandas初体验

    Python学习笔记:pandas初体验 一.安装pandas模块 1.安装Python3.7 注意:必须勾选pip(python install package). 2.配置环境变量 3.下载 pa ...

  3. 语义分割(学习笔记)

    语义分割基础网络(学习笔记) 参考博客:1. 图像语义分割模型综述_李明朔的博客-CSDN博客_语义分割模型 2.史上最全语义分割综述(FCN,UNet,SegNet,Deeplab,ASPP-)_一 ...

  4. 深度学习-【语义分割】学习笔记4 膨胀卷积(Dilated convolution)

    文章目录 膨胀卷积 为什么需要膨胀卷积 gridding effect 连续使用三次膨胀卷积--1 连续使用三次膨胀卷积--2 连续使用三次膨胀卷积--3 Understanding Convolut ...

  5. mysql初体验学习笔记_MySQL数据库初体验

    克隆表,将数据表的数据记录生成到新的表中 方法一: create table test01 like GCC01; #通过 LIKE 方法,复制GCC01表结构生成test01表 insert int ...

  6. CSS学习笔记——动画animation初体验(小熊案例)

    文章目录 一.动画概念 二.动画基本用法 1.动画模板 2.动画的使用 3.小熊案例 总结 一.动画概念 动画是CSS3的一个重要用法,动画使元素逐渐从一种样式变为另一种样式.您可以随意更改任意数量的 ...

  7. 语义分割——语义分割数据集的笔记

    1 语义分割数据集的笔记 1.1 Pascal VOC data sets Pascal VOC是一个语义分割的数据集: 官方链接:http://host.robots.ox.ac.uk/pascal ...

  8. Vue快速上手笔记1 - 使用初体验

    Vue快速上手笔记1 - 使用初体验 博主:李俊才 邮箱:291148484@163.com 若本文中存在的错误请告知博主更正 希望对大家有所帮助 专题目录:https://blog.csdn.net ...

  9. JeecgBoot学习_Online表单初体验

    JeecgBoot学习Online表单初体验 在这几个地方改起来 在页面属性里选择类型 再建一个商品表 为了使得选择有树形结构 我们修改这两处 备注: 1.字典table 填写树控件对应的表名 { 例 ...

最新文章

  1. Sqlite3的安装Windows
  2. 通信基站电池,再也不怕丢了
  3. 三个基本原理和概念 - 计算机图形学、数据加密、数据挖掘
  4. 我的iOS学习历程 - UISlider(简单的设置一组图片动画)
  5. 关于jupyter几个不得不知道的tips
  6. DOCTYPE html PUBLIC 指定了 HTML 文档遵循的文档类型定义
  7. C#中使用jieba.NET、WordCloudSharp制作词云图
  8. [笔记]java-package
  9. 四舍五入_从四舍五入谈起
  10. 【转】ELK 日志分析系统
  11. 通过set赋值,与select赋值的区别
  12. hutool获取5天前的日期_连载|日 产 物 流 管 理 方 式(5)
  13. 混淆的艺术-(苍井空变凤姐)Proguard源码分析(二)Proguard参数解析
  14. oracle EBS在中国的客户
  15. Python入门篇(下)
  16. 2016年11月份各省市报考时间与报名链接 软考 系统集成项目管理工程师
  17. 微信朋友圈信息流可能采用算法排序吗?
  18. ANDROID 学习笔记(三) UI THREAD AsyncTask 使用
  19. 电脑硬件升级完全解决方案-1
  20. 将RTSP流保存为本地TS文件

热门文章

  1. Ancient Printer
  2. 高光谱图像选择波段的研究
  3. 800C的GRPS通信第一步
  4. 服务器备用电源的原理,备自投接线原理图_备用电源自动投入原理和接线
  5. Swift中的访问权限
  6. 百度地图开发自定义图标无法显示的问题
  7. break跳出多层while循环标志位
  8. oracle中job任务停用启用,Oracle案例详细分析:Job任务停止执行
  9. 简单因式分解100道及答案_整式的乘除与因式分解拔高题习题.doc
  10. js+css实现花瓣飘落效果