1. 代码下载地址:https://github.com/hellochick/PSPNet-tensorflow
  2. 下载预训练模型地址(需翻墙):https://drive.google.com/drive/folders/1S90PWzXEX_GNzulG1f2eTHvsruITgqsm?usp=sharing  放在./model/文件夹下,并将checkpoint后的格式.txt删除。
  3. 修改train.py文件
IMG_MEAN = np.array((103.939, 116.779, 123.68), dtype=np.float32)
BATCH_SIZE = 2
DATA_DIRECTORY = ''
DATA_LIST_PATH = './list/train.txt'  #训练数据的
IGNORE_LABEL = 255
INPUT_SIZE = '320,320'  #训练图片的大小
LEARNING_RATE = 1e-3
MOMENTUM = 0.9
NUM_CLASSES = 2  #分类的类别
NUM_STEPS = 60001
POWER = 0.9
RANDOM_SEED = 1234
WEIGHT_DECAY = 0.0001
RESTORE_FROM = './'
SNAPSHOT_DIR = './model/'
SAVE_NUM_IMAGES = 4
SAVE_PRED_EVERY = 50   #保存模型的训练步数

inference.txt文件
ADE20k_param = {'crop_size': [320, 320],   #修改尺寸和输入图片大小相同
'num_classes': 2,
'model': PSPNet50}
cityscapes_param = {'crop_size': [320, 320], #修改尺寸和输入图片大小相同
'num_classes': 2,
'model': PSPNet101}

修改tool.py
使用网络训练自己的数据,首先建立自己的数据集,可以使用labelme等软件制作标签,将标签图像转化为二值图,具体方法可以自行查阅相关资料,修改tool.py中代码将前景和背景颜色修改为自己的标签的颜色,我们的为二分类所以只修改为两种颜色。
IMG_MEAN = np.array((103.939, 116.779, 123.68), dtype=np.float32)
label_colours = [(0, 0, 0),# (244, 35, 231), (69, 69, 69)        #背景色
# 0 = road, 1 = sidewalk, 2 = building
#,(102, 102, 156), (190, 153, 153), (153, 153, 153)
# 3 = wall, 4 = fence, 5 = pole
#,(250, 170, 29), (219, 219, 0), (106, 142, 35)
# 6 = traffic light, 7 = traffic sign, 8 = vegetation
#,(152, 250, 152), (69, 129, 180), (219, 19, 60)
# 9 = terrain, 10 = sky, 11 = person
#,(255, 0, 0), (0, 0, 142), (0, 0, 69)
# 12 = rider, 13 = car, 14 = truck
#,(0, 60, 100), (0, 79, 100), (0, 0, 230)
# 15 = bus, 16 = train, 17 = motocycle
(255, 255, 255)]  #前景色
# 18 = bicycle
matfn = './utils/color150.mat'

针对报错:ValueError: Negative dimension size caused by subtracting 90 from 40 for 'conv5_3_pool1' (op: 'AvgPool') with input shapes: [4,40,40,2048].
训练要求图片的大小为720*720(我们将训练和测试图片大小改为320*×320),修改model代码:
(self.feed('conv5_3/relu')
.avg_pool(40, 40, 40, 40, name='conv5_3_pool1')
.conv(1, 1, 512, 1, 1, biased=False, relu=False, name='conv5_3_pool1_conv')
.batch_normalization(relu=True, name='conv5_3_pool1_conv_bn')
.resize_bilinear(shape, name='conv5_3_pool1_interp'))
(self.feed('conv5_3/relu')
.avg_pool(30, 30, 30, 30, name='conv5_3_pool2')
.conv(1, 1, 512, 1, 1, biased=False, relu=False, name='conv5_3_pool2_conv')
.batch_normalization(relu=True, name='conv5_3_pool2_conv_bn')
.resize_bilinear(shape, name='conv5_3_pool2_interp'))
(self.feed('conv5_3/relu')
.avg_pool(20, 20, 20, 20, name='conv5_3_pool3')
.conv(1, 1, 512, 1, 1, biased=False, relu=False, name='conv5_3_pool3_conv')
.batch_normalization(relu=True, name='conv5_3_pool3_conv_bn')
.resize_bilinear(shape, name='conv5_3_pool3_interp'))
(self.feed('conv5_3/relu')
.avg_pool(10, 10, 10, 10, name='conv5_3_pool6')
.conv(1, 1, 512, 1, 1, biased=False, relu=False, name='conv5_3_pool6_conv')
.batch_normalization(relu=True, name='conv5_3_pool6_conv_bn')
.resize_bilinear(shape, name='conv5_3_pool6_interp'))

  1. 训练:
python train.py --update-mean-var --train-beta-gamma 2>&1 | tee log/train.log
  1. 测试单张数据:
python inference.py --img-path=./input/029.jpg --dataset cityscapes
很抱歉好长时间没有登陆账号,请关注公共号共同学习AI算法相关知识或与我进行沟通

PSPNet-tensorflow实现并训练数据相关推荐

  1. tensorflow sigmoid 如何计算训练数据的正确率_“来自蒙娜丽莎的凝视”— 结合 TensorFlow.js 和深度学习实现...

    客座博文 / Emily Xie,软件工程师 背景 坊间传闻,当您在房间里走动时,蒙娜丽莎的眼睛会一直盯着您. 这就是所谓的"蒙娜丽莎效应".兴趣使然,我最近就编写了一个可互动的数 ...

  2. tensorflow sigmoid 如何计算训练数据的正确率_量化训练:Quantization Aware Training in Tensorflow(一)...

    本文的内容包括对神经网络模型量化的基本介绍.对Tensorflow量化训练的理解与上手实操. 此外,后续系列还对量化训练中的by pass和batch norm两种情况进行补充解释,欢迎点击浏览,量化 ...

  3. tensorflow sigmoid 如何计算训练数据的正确率_用于高级机器学习的自定义TensorFlow损失函数...

    在本文中,我们将看看: 在高级机器学习(ML)应用程序中使用自定义损失函数 定义自定义损失函数并集成到基本Tensorflow神经网络模型 一个简单的知识蒸馏学习的例子 介绍 机器学习中预定义的损失函 ...

  4. TensorFlow(2)-训练数据载入

    tensorflow 训练数据载入 1. tf.data.Dataset 2. dataset 创建数据集的方式 2.1 tf.data.Dataset.from_tensor_slices() 2. ...

  5. 关于使用tensorflow object detection API训练自己的模型-补充部分(代码,数据标注工具,训练数据,测试数据)

    之前分享过关于tensorflow object detection API训练自己的模型的几篇博客,后面有人陆续碰到一些问题,问到了我解决方法.所以在这里补充点大家可能用到的东西.声明一下,本人专业 ...

  6. Tensorflow版yolov3训练自己的数据

    Tensorflow版yolov3训练自己的数据 源代码:https://github.com/YunYang1994/TensorFlow2.0-Examples/tree/master/4-Obj ...

  7. 谷歌BERT预训练源码解析(一):训练数据生成

    目录 预训练源码结构简介 输入输出 源码解析 参数 主函数 创建训练实例 下一句预测&实例生成 随机遮蔽 输出 结果一览 预训练源码结构简介 关于BERT,简单来说,它是一个基于Transfo ...

  8. TensorFlow csv读取文件数据(代码实现)

    TensorFlow csv读取文件数据(代码实现) 大多数人了解 Pandas 及其在处理大数据文件方面的实用性.TensorFlow 提供了读取这种文件的方法. 前面章节中,介绍了如何在 Tens ...

  9. 仅50张图片训练数据的AI分类技术PK​,阿里拿下ECCV 2020竞赛冠军

    出品 | AI科技大本营(ID:rgznai100) 近日,两年一度的世界计算机视觉领域顶会ECCV 2020的各项挑战赛结果出炉,在图像分类赛中,阿里安全的高效AI分类技术超越三星.深兰科技.同济大 ...

  10. 利用卷积神经网络(VGG19)实现火灾分类(附tensorflow代码及训练集)

    源码地址 https://github.com/stephen-v/tensorflow_vgg_classify 1. VGG介绍 1.1. VGG模型结构 1.2. VGG19架构 2. 用Ten ...

最新文章

  1. Python,OpenCV中的K均值聚类——K-Means Cluster
  2. 腾讯是一只邪恶的小企鹅
  3. 谈谈我对服务熔断、服务降级的理解 专题
  4. [转]快速使用FileProvider解决Android7.0文件权限问题
  5. 12个关键词,告诉你到底什么是机器学习
  6. 【设计模式】—— 备忘录模式Memento
  7. 自定义AlertDialog控件的使用(AndroidStudio)
  8. sql联接查询_SQL联接
  9. CentOS 6.9之LVM创建,扩容
  10. python-socket作业
  11. IO之Socket网络编程
  12. mysql 校对规则名_MySQL字符集及校对规则的理解
  13. hdmi接口线_HDMI高清线注意事项
  14. 2018年将会改变人工智能的5个大数据趋势
  15. 计算机考研需要过英语六级吗,研究生毕业要过英语六级吗 研究生毕业对英语六级有要求吗...
  16. 结构方程模型中的R方改变量怎么求?
  17. 跳动爱心代码-李峋同款爱心代码(升级版)
  18. 坚持玩游戏为什么会这么容易
  19. 小技巧 打印出emoji表情
  20. java判断文件夹中的图片是否重复

热门文章

  1. 记录使用IDEA部署Tomcat时提示错误:the selected directory is not a TomEE home
  2. 再不懂ZooKeeper,就安安心心把这篇文章看完
  3. Windows mysql-64位 数据库安装
  4. 深入谈一谈iOS模块独立运行
  5. Linux目录管理类命令之ls
  6. JerseyTest
  7. maven中文乱码问题——编译错误
  8. 408业务课·计算机网络——【考研随笔】之一
  9. ZZULIOJ 1115: 数组最小值
  10. xml 属性value换行显示_跟光磊学Java开发-Java解析XML