最后我们来看一下主函数
论文解析:https://blog.csdn.net/bofu_sun/article/details/89206531
1.
首先是调用一些库函数,同时设置运行文件时的参数

from __future__ import absolute_import, division, print_function# only keep warnings and errors
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='1'import numpy as np
import argparse
import re
import time
import tensorflow as tf
import tensorflow.contrib.slim as slimfrom monodepth_model import *
from monodepth_dataloader import *
from average_gradients import *
# 设置参数对象
parser = argparse.ArgumentParser(description='Monodepth TensorFlow implementation.')
# 模式:训练还是测试
parser.add_argument('--mode',                      type=str,   help='train or test', default='train')
# 模型名:默认monodepth
parser.add_argument('--model_name',                type=str,   help='model name', default='monodepth')
# 编码器使用vgg网络还是残差网络
parser.add_argument('--encoder',                   type=str,   help='type of encoder, vgg or resnet50', default='vgg')
# 数据集使用:作者提供kitti和cityscapes两种
parser.add_argument('--dataset',                   type=str,   help='dataset to train on, kitti, or cityscapes', default='kitti')
# 数据集路径
parser.add_argument('--data_path',                 type=str,   help='path to the data', required=True)
# 记录数据集的文件名
parser.add_argument('--filenames_file',            type=str,   help='path to the filenames text file', required=True)
# 输入图片的高度
parser.add_argument('--input_height',              type=int,   help='input height', default=256)
# 输入图片的宽度
parser.add_argument('--input_width',               type=int,   help='input width', default=512)
# batch批量数
parser.add_argument('--batch_size',                type=int,   help='batch size', default=8)
# 周期数
parser.add_argument('--num_epochs',                type=int,   help='number of epochs', default=50)
# 学习率
parser.add_argument('--learning_rate',             type=float, help='initial learning rate', default=1e-4)
# 左右图相似权值
parser.add_argument('--lr_loss_weight',            type=float, help='left-right consistency weight', default=1.0)
# 计算损失时的参数alpha
parser.add_argument('--alpha_image_loss',          type=float, help='weight between SSIM and L1 in the image loss', default=0.85)
# 梯度损失权值,用于使得视差图尽可能连续
parser.add_argument('--disp_gradient_loss_weight', type=float, help='disparity smoothness weigth', default=0.1)
# 是否生成立体图
parser.add_argument('--do_stereo',                             help='if set, will train the stereo model', action='store_true')
# 双线性插值方式
parser.add_argument('--wrap_mode',                 type=str,   help='bilinear sampler wrap mode, edge or border', default='border')
# 是否使用反卷积层(还是使用插值的方式扩张视差图尺寸)
parser.add_argument('--use_deconv',                            help='if set, will use transposed convolutions', action='store_true')
# gpu数量
parser.add_argument('--num_gpus',                  type=int,   help='number of GPUs to use for training', default=1)
# 线程数
parser.add_argument('--num_threads',               type=int,   help='number of threads to use for data loading', default=8)
# 输出测试视差图到某路径
parser.add_argument('--output_directory',          type=str,   help='output directory for test disparities, if empty outputs to checkpoint folder', default='')
# 保存权值和日志的地址
parser.add_argument('--log_directory',             type=str,   help='directory to save checkpoints and summaries', default='')
# 加载权值地址
parser.add_argument('--checkpoint_path',           type=str,   help='path to a specific checkpoint to load', default='')
# 重新训练
parser.add_argument('--retrain',                               help='if used with checkpoint_path, will restart training from step zero', action='store_true')
# 所有记录都被打开
parser.add_argument('--full_summary',                          help='if set, will keep more data for each summary. Warning: the file can become very large', action='store_true')
# 设置参数
args = parser.parse_args()

2.np.fliplr:左右翻转

>> a=magic(3)
a =  8     1     6  3     5     7  4     9     2
>> b=fliplr(a)    %左右翻转
b =  6     1     8  7     5     3  2     9     4

详见:https://blog.csdn.net/qq_18343569/article/details/50393199

3.numpy.linspace
numpy.linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None):
在指定的间隔内返回均匀间隔的数字。

4.np.meshgrid:
二维坐标系中,X轴可以取三个值1,2,3, Y轴可以取三个值7,8, 请问可以获得多少个点的坐标?
显而易见是6个:
(1,7)(2,7)(3,7)
(1,8)(2,8)(3,8)

#coding:utf-8 import numpy as np # 坐标向量
a = np.array([1,2,3]) # 坐标向量
b = np.array([7,8]) # 从坐标向量中返回坐标矩阵
# 返回list,有两个元素,第一个元素是X轴的取值,第二个元素是Y轴的取值
res = np.meshgrid(a,b)
#返回结果: [array([ [1,2,3] [1,2,3] ]), array([ [7,7,7] [8,8,8] ])]

详见:https://blog.csdn.net/littlehaes/article/details/83543459
5. np.clip:将数字强制在一个范围内

def post_process_disparity(disp):_, h, w = disp.shape                       # 获取视差图的三个维度l_disp = disp[0,:,:]                       # 获取左视差图r_disp = np.fliplr(disp[1,:,:])            # 获取右视差图m_disp = 0.5 * (l_disp + r_disp) l, _ = np.meshgrid(np.linspace(0, 1, w), np.linspace(0, 1, h))      # 在0和1之间取宽度*高度个数l_mask = 1.0 - np.clip(20 * (l - 0.05), 0, 1)                       # 生成左标志r_mask = np.fliplr(l_mask)                                          # 生成右标志return r_mask * l_disp + l_mask * r_disp + (1.0 - l_mask - r_mask) * m_disp           #返回???# 记录txt文件中的行数
def count_text_lines(file_path):f = open(file_path, 'r')            # 打开文件lines = f.readlines()               # 记录行数f.close()                           # 关闭文件return len(lines)                   # 返回行数

6.tf.ConfigProto(allow_soft_placement=True)
在tf中,通过命令 “with tf.device(’/cpu:0’):”,允许手动设置操作运行的设备。如果手动设置的设备不存在或者不可用,就会导致tf程序等待或异常,为了防止这种情况,可以设置tf.ConfigProto()中参数allow_soft_placement=True,允许tf自动选择一个存在并且可用的设备来运行操作。
7. tf.train.Coordinator:Coordinator类用来管理在Session中的多个线程,可以用来同时停止多个工作线程并且向那个在等待所有工作线程终止的程序报告异常,该线程捕获到这个异常之后就会终止所有线程。使用 tf.train.Coordinator()来创建一个线程管理器(协调器)对象。
9. 训练函数

def train(params):"""Training loop."""with tf.Graph().as_default(), tf.device('/cpu:0'):global_step = tf.Variable(0, trainable=False)                       # 记录迭代次数# OPTIMIZERnum_training_samples = count_text_lines(args.filenames_file)        # 训练样本在txt文件中steps_per_epoch = np.ceil(num_training_samples / params.batch_size).astype(np.int32)     #设置每个周期进行迭代次数num_total_steps = params.num_epochs * steps_per_epoch               # 设置迭代总次数start_learning_rate = args.learning_rate                            # 设置起始learning_rate # 设定调整学习率的界限,训练3/5和4/5的样本时调整学习率boundaries = [np.int32((3/5) * num_total_steps), np.int32((4/5) * num_total_steps)]      # 学习率分别调整为原来的1/2和1/4values = [args.learning_rate, args.learning_rate / 2, args.learning_rate / 4]       # 设定整个learningrate变化情况# tf.train.piecewise_constant:当走到一定步长时更改学习率learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)# Adam 是一种反向传播方法opt_step = tf.train.AdamOptimizer(learning_rate)                #设置优化器,用于计算loss并进行反向传播# 输出总样本数和步数print("total number of samples: {}".format(num_training_samples))print("total number of steps: {}".format(num_total_steps))# 读取数据集文件dataloader = MonodepthDataloader(args.data_path, args.filenames_file, params, args.dataset, args.mode)left  = dataloader.left_image_batchright = dataloader.right_image_batch# split for each gpu# 为多个gpu分配任务left_splits  = tf.split(left,  args.num_gpus, 0)right_splits = tf.split(right, args.num_gpus, 0)# 梯度数组tower_grads  = []# 损失数组tower_losses = []reuse_variables = Nonewith tf.variable_scope(tf.get_variable_scope()):# 对于每个gpufor i in range(args.num_gpus):with tf.device('/gpu:%d' % i):# 建立monodepth模型model = MonodepthModel(params, args.mode, left_splits[i], right_splits[i], reuse_variables, i)#记录损失loss = model.total_losstower_losses.append(loss)reuse_variables = Truegrads = opt_step.compute_gradients(loss)                 #根据loss计算梯度tower_grads.append(grads)                                 # 添加梯度grads = average_gradients(tower_grads)                               # 计算batch平均梯度apply_gradient_op = opt_step.apply_gradients(grads, global_step=global_step)          #应用梯度total_loss = tf.reduce_mean(tower_losses)                            # 计算平均损失# 记录学习率、损失tf.summary.scalar('learning_rate', learning_rate, ['model_0'])tf.summary.scalar('total_loss', total_loss, ['model_0'])summary_op = tf.summary.merge_all('model_0')# SESSION# 配置设备,自动寻找设备config = tf.ConfigProto(allow_soft_placement=True)sess = tf.Session(config=config)# SAVER# 保存日志summary_writer = tf.summary.FileWriter(args.log_directory + '/' + args.model_name, sess.graph)train_saver = tf.train.Saver()# COUNT PARAMS# 记录计算量total_num_parameters = 0for variable in tf.trainable_variables():total_num_parameters += np.array(variable.get_shape().as_list()).prod()print("number of trainable parameters: {}".format(total_num_parameters))# INIT# 初始化全局变量sess.run(tf.global_variables_initializer())sess.run(tf.local_variables_initializer())# 多线程管理器coordinator = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess=sess, coord=coordinator)# LOAD CHECKPOINT IF SET# 保存训练好的权值等数据if args.checkpoint_path != '':train_saver.restore(sess, args.checkpoint_path.split(".")[0])if args.retrain:sess.run(global_step.assign(0))# GO!start_step = global_step.eval(session=sess)            # 获取最后的步数start_time = time.time()                               # 开始时间计时for step in range(start_step, num_total_steps):        # 开始循环训练before_op_time = time.time()                       # 每次训练计时_, loss_value = sess.run([apply_gradient_op, total_loss])      #运行网络duration = time.time() - before_op_time            # 记录每次运行的时间if step and step % 100 == 0:                       # 每训练一百次记录一下examples_per_sec = params.batch_size / duration                      # 每秒处理样本数量time_sofar = (time.time() - start_time) / 3600                       # 到现在为止所用时间,单位:小时training_time_left = (num_total_steps / step - 1.0) * time_sofar     # 距离训练结束剩余时间# 输出记录内容print_string = 'batch {:>6} | examples/s: {:4.2f} | loss: {:.5f} | time elapsed: {:.2f}h | time left: {:.2f}h'print(print_string.format(step, examples_per_sec, loss_value, time_sofar, training_time_left))# 添加日志内容summary_str = sess.run(summary_op)summary_writer.add_summary(summary_str, global_step=step)# 每训练一万次保存一次权值等数据if step and step % 10000 == 0:train_saver.save(sess, args.log_directory + '/' + args.model_name + '/model', global_step=step)#保存日志train_saver.save(sess, args.log_directory + '/' + args.model_name + '/model', global_step=num_total_steps)

10.测试函数

# 测试函数
def test(params):"""Test function."""# 加载数据集和左右两图dataloader = MonodepthDataloader(args.data_path, args.filenames_file, params, args.dataset, args.mode)left  = dataloader.left_image_batchright = dataloader.right_image_batch# 加载模型model = MonodepthModel(params, args.mode, left, right)# SESSION# 设置设备config = tf.ConfigProto(allow_soft_placement=True)sess = tf.Session(config=config)# SAVER# 数据储存train_saver = tf.train.Saver()# INIT# 初始化变量sess.run(tf.global_variables_initializer())sess.run(tf.local_variables_initializer())# 配置调节器和多线程coordinator = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess=sess, coord=coordinator)# RESTORE# 配置存储路径,存储权值数据if args.checkpoint_path == '':restore_path = tf.train.latest_checkpoint(args.log_directory + '/' + args.model_name)else:restore_path = args.checkpoint_path.split(".")[0]train_saver.restore(sess, restore_path)# 加载要测试的图片样本数量num_test_samples = count_text_lines(args.filenames_file)# 测试print('now testing {} files'.format(num_test_samples))disparities    = np.zeros((num_test_samples, params.height, params.width), dtype=np.float32)disparities_pp = np.zeros((num_test_samples, params.height, params.width), dtype=np.float32)# 生成视差图for step in range(num_test_samples):# 运行左图预测disp = sess.run(model.disp_left_est[0])disparities[step] = disp[0].squeeze()disparities_pp[step] = post_process_disparity(disp.squeeze())print('done.')print('writing disparities.')# 保存视差图数据if args.output_directory == '':output_directory = os.path.dirname(args.checkpoint_path)else:output_directory = args.output_directorynp.save(output_directory + '/disparities.npy',    disparities)np.save(output_directory + '/disparities_pp.npy', disparities_pp)print('done.')# 主函数,执行操作并导入参数
def main(_):params = monodepth_parameters(encoder=args.encoder,height=args.input_height,width=args.input_width,batch_size=args.batch_size,num_threads=args.num_threads,num_epochs=args.num_epochs,do_stereo=args.do_stereo,wrap_mode=args.wrap_mode,use_deconv=args.use_deconv,alpha_image_loss=args.alpha_image_loss,disp_gradient_loss_weight=args.disp_gradient_loss_weight,lr_loss_weight=args.lr_loss_weight,full_summary=args.full_summary)if args.mode == 'train':train(params)elif args.mode == 'test':test(params)# 执行主函数
if __name__ == '__main__':tf.app.run()

11.tf.app.run():入口函数
如果你的代码中的入口函数不叫main(),而是一个其他名字的函数,如test(),则你应该这样写入口tf.app.run(test)
如果你的代码中的入口函数叫main(),则你就可以把入口写成tf.app.run()

monodepth无监督卷积神经网络深度估计代码解析(三)相关推荐

  1. monodepth无监督卷积神经网络深度估计代码解析(一)

    论文解析:https://blog.csdn.net/bofu_sun/article/details/89206531 近期在做深度估计相关的毕业设计,发现monodepth项目比较不错,决定尝试一 ...

  2. 最强无监督单目深度估计Baseline--MonoViT--简介与代码复现

    1. 无监督单目深度估计 单目深度估计是指,借助于深度神经网络,从单张输入视图中推理场景的稠密深度信息:该技术可以广泛用于自动驾驶.虚拟现实.增强现实等依赖于三维场景感知理解的领域,同时也可以为其他视 ...

  3. 最新开源无监督单目深度估计方法,解决复杂室内场景难训练问题,效果性能远超SOTA...

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 一.摘要 无监督单目深度估计算法已经被证明能够在驾驶场景(如KITTI数据集)中得到精确的结果,然而最 ...

  4. ECCV2022 | 基于整合IMU运动动力学的无监督单目深度估计

    点击下方卡片,关注"自动驾驶之心"公众号 ADAS巨卷干货,即可获取 点击进入→自动驾驶之心技术交流群 0. 引言 虽然近年来无监督单目深度学习取得了很大的进展,但仍然存在一些基本 ...

  5. 无监督单目深度估计 Unsupervised Monocular Depth Estimation with Left-Right Consistency 论文方法分析

    最近在做深度估计相关的毕业设计,一般的基于深度学习单目深度估计算法都是基于监督学习的方法,也就是说我希望输入一张拍摄到的单目照片,将它通过卷积神经网络后生成一张深度图.在这个过程中我们就要求需要有大量 ...

  6. CoSTA:用于空间转录组分析的无监督卷积神经网络学习方法

    2021年8月,来自美国研究人员在<BMC Bioinformatics>杂志发表了题为"CoSTA: unsupervised convolutional neural net ...

  7. 基于无监督的单深度估计

    Atlas: End-to-End 3D Scene Reconstruction from Posed Images(论文复现) 从姿势图像进行端到端的3D场景重建,该论文发表在2020的ECCV上 ...

  8. 深度学习——day9(外 Q1 2021)基于多尺度特征融合的深度监督卷积神经网络路面裂缝检测

    基于多尺度特征融合的深度监督卷积神经网络路面裂缝检测 导图和笔记资源下载 三级目录# (外 Q1 2021)基于多尺度特征融合的深度监督卷积神经网络路面裂缝检测 chap2 传统裂纹检测方法 1)Tr ...

  9. 粒度语义感知表示增强的自监督单目深度估计 Fine-grained Semantics-aware Representation Enhancement

    Fine-grained Semantics-aware Representation Enhancement for Self-supervised Monocular Depth Estimati ...

最新文章

  1. Lua中的基本函数库
  2. 转《两个个很形象的依赖注入的比喻》
  3. 打开共享文件闪退怎么解决_文件共享解决方案-随时随地共享同步访问文件
  4. 解决grub引导错误的问题
  5. CTF短秘钥的RSA解密
  6. django升级问题
  7. python大气校正_Sentinel-2卫星影像的大气校正方法
  8. hive -e和hive -f的区别(转)
  9. javaone_JavaOne 2012:向上,向上和向外:使用Akka扩展软件
  10. 铃木uy125最高时速_五菱宏光mini EV月销三万辆,铃木是否后悔退出中国?
  11. sys python3 常用_python之sys模块【获取参数】
  12. 2017年总结-致毕业半年的自己
  13. 深度学习中的数据增强方法
  14. UIImageView contentModel
  15. java开发web应用开发,Java Web应用开发概述
  16. 华为云数据库三大优势
  17. xcode 真机调试无法选择对应设备 “ineligible devices“
  18. part实现实现单个(上传图片和文件上传)
  19. Java实现堆,最大堆,最小堆,左高树,左低树
  20. ssh远程连接报错:WARNING: POSSIBLE DNS SPOOFING DETECTED(已解决)

热门文章

  1. 软件库,CDKey卡密充值,php源码
  2. 网络安全的内容有哪些,需要学哪些知识点
  3. 十次方学习——nodejs(1)
  4. 从api获得当前用户信息
  5. 网站服务器怎么配置,怎么配置自己的网站服务器
  6. 第11章,从感知机到支持向量机
  7. 《沟通的方法》作者序读后感
  8. 什么是销售管理?销售管理的五大职能
  9. python 冷门知识点_Python中的10条冷门知识
  10. 极米4K激光电视新品:一杯敬坚果, 一杯敬百度