论文地址:https://arxiv.org/pdf/1808.04560.pdf

代码地址:https://github.com/weichen582/RetinexNet

解析目录:https://zhuanlan.zhihu.com/p/88761829


整个模型架构被实现为一个类:

class lowlight_enhance(object):

其构造函数实现了网络结构的搭建、损失函数的定义、训练的配置和参数的初始化,具体如下。

网络结构的搭建(该部分包括低/正常光照图像输入的定义以及Decom-Net、Enhance-Net和重建这三部分的对接,注意这里并没有对Rlow进行去噪的部分):

# build the model
self.input_low = tf.placeholder(tf.float32, [None, None, None, 3], name='input_low')
self.input_high = tf.placeholder(tf.float32, [None, None, None, 3], name='input_high')[R_low, I_low] = DecomNet(self.input_low, layer_num=self.DecomNet_layer_num)
[R_high, I_high] = DecomNet(self.input_high, layer_num=self.DecomNet_layer_num)I_delta = RelightNet(I_low, R_low)I_low_3 = concat([I_low, I_low, I_low])
I_high_3 = concat([I_high, I_high, I_high])
I_delta_3 = concat([I_delta, I_delta, I_delta])self.output_R_low = R_low
self.output_I_low = I_low_3
self.output_I_delta = I_delta_3
self.output_S = R_low * I_delta_3

损失函数的定义(该部分包括低/正常光照图像的重建损失、反射分量一致性损失、光照分量平滑损失以及最后分别计算的Decom-Net和Enhance-Net的总损失):

# loss
self.recon_loss_low = tf.reduce_mean(tf.abs(R_low * I_low_3 - self.input_low))
self.recon_loss_high = tf.reduce_mean(tf.abs(R_high * I_high_3 - self.input_high))
self.recon_loss_mutal_low = tf.reduce_mean(tf.abs(R_high * I_low_3 - self.input_low))
self.recon_loss_mutal_high = tf.reduce_mean(tf.abs(R_low * I_high_3 - self.input_high))
self.equal_R_loss = tf.reduce_mean(tf.abs(R_low - R_high))
self.relight_loss = tf.reduce_mean(tf.abs(R_low * I_delta_3 - self.input_high))self.Ismooth_loss_low = self.smooth(I_low, R_low)
self.Ismooth_loss_high = self.smooth(I_high, R_high)
self.Ismooth_loss_delta = self.smooth(I_delta, R_low)self.loss_Decom = self.recon_loss_low + self.recon_loss_high + 0.001 * self.recon_loss_mutal_low + 0.001 * self.recon_loss_mutal_high + 0.1 * self.Ismooth_loss_low + 0.1 * self.Ismooth_loss_high + 0.01 * self.equal_R_loss
self.loss_Relight = self.relight_loss + 3 * self.Ismooth_loss_delta

训练的配置(该部分包括学习率以及Decom-Net和Enhance-Net的优化器设置):

self.lr = tf.placeholder(tf.float32, name='learning_rate')
optimizer = tf.train.AdamOptimizer(self.lr, name='AdamOptimizer')self.var_Decom = [var for var in tf.trainable_variables() if 'DecomNet' in var.name]
self.var_Relight = [var for var in tf.trainable_variables() if 'RelightNet' in var.name]self.train_op_Decom = optimizer.minimize(self.loss_Decom, var_list = self.var_Decom)
self.train_op_Relight = optimizer.minimize(self.loss_Relight, var_list = self.var_Relight)

训练参数的初始化:

self.sess.run(tf.global_variables_initializer())self.saver_Decom = tf.train.Saver(var_list = self.var_Decom)
self.saver_Relight = tf.train.Saver(var_list = self.var_Relight)print("[*] Initialize model successfully...")

接下来是该类的一些成员函数。

def gradient(self, input_tensor, direction):self.smooth_kernel_x = tf.reshape(tf.constant([[0, 0], [-1, 1]], tf.float32), [2, 2, 1, 1])self.smooth_kernel_y = tf.transpose(self.smooth_kernel_x, [1, 0, 2, 3])if direction == "x":kernel = self.smooth_kernel_xelif direction == "y":kernel = self.smooth_kernel_yreturn tf.abs(tf.nn.conv2d(input_tensor, kernel, strides=[1, 1, 1, 1], padding='SAME'))

该函数实现的是通过与指定梯度算子进行卷积的方式求图像的水平/垂直梯度图。

def ave_gradient(self, input_tensor, direction):return tf.layers.average_pooling2d(self.gradient(input_tensor, direction), pool_size=3, strides=1, padding='SAME')

该函数实现的是通过平均池化的方式来对图像的水平/垂直梯度图进行平滑。

def smooth(self, input_I, input_R):input_R = tf.image.rgb_to_grayscale(input_R)return tf.reduce_mean(self.gradient(input_I, "x") * tf.exp(-10 * self.ave_gradient(input_R, "x")) + self.gradient(input_I, "y") * tf.exp(-10 * self.ave_gradient(input_R, "y")))

该函数是对光照分量平滑损失的具体实现(可对应原论文中的公式来看)。

def evaluate(self, epoch_num, eval_low_data, sample_dir, train_phase):print("[*] Evaluating for phase %s / epoch %d..." % (train_phase, epoch_num))for idx in range(len(eval_low_data)):input_low_eval = np.expand_dims(eval_low_data[idx], axis=0)if train_phase == "Decom":result_1, result_2 = self.sess.run([self.output_R_low, self.output_I_low], feed_dict={self.input_low: input_low_eval})if train_phase == "Relight":result_1, result_2 = self.sess.run([self.output_S, self.output_I_delta], feed_dict={self.input_low: input_low_eval})save_images(os.path.join(sample_dir, 'eval_%s_%d_%d.png' % (train_phase, idx + 1, epoch_num)), result_1, result_2)

该函数是对训练epoch_num次后的Decom-Net/Enhance-Net模型进行评估,并保存评估结果图。

接下来是关于模型的训练:

def train(self, train_low_data, train_high_data, eval_low_data, batch_size, patch_size, epoch, lr, sample_dir, ckpt_dir, eval_every_epoch, train_phase):

该函数中包含了预训练模型的加载、数据的读取与处理、模型的训练、评估和保存这几个部分。

assert len(train_low_data) == len(train_high_data)
numBatch = len(train_low_data) // int(batch_size)

检查所有需要参与训练的低/正常光照样本数量是否一致,若一致则计算训练集含有的batch数量。

# load pretrained model
if train_phase == "Decom":train_op = self.train_op_Decomtrain_loss = self.loss_Decomsaver = self.saver_Decom
elif train_phase == "Relight":train_op = self.train_op_Relighttrain_loss = self.loss_Relightsaver = self.saver_Relightload_model_status, global_step = self.load(saver, ckpt_dir)
if load_model_status:iter_num = global_stepstart_epoch = global_step // numBatchstart_step = global_step % numBatchprint("[*] Model restore success!")
else:iter_num = 0start_epoch = 0start_step = 0
print("[*] Not find pretrained model!")

若存在Decom-Net/Enhance-Net对应的预训练模型,则进行加载;否则从头开始训练。

# generate data for a batch
batch_input_low = np.zeros((batch_size, patch_size, patch_size, 3), dtype="float32")
batch_input_high = np.zeros((batch_size, patch_size, patch_size, 3), dtype="float32")
for patch_id in range(batch_size):h, w, _ = train_low_data[image_id].shapex = random.randint(0, h - patch_size)y = random.randint(0, w - patch_size)rand_mode = random.randint(0, 7)batch_input_low[patch_id, :, :, :] = data_augmentation(train_low_data[image_id][x : x+patch_size, y : y+patch_size, :], rand_mode)batch_input_high[patch_id, :, :, :] = data_augmentation(train_high_data[image_id][x : x+patch_size, y : y+patch_size, :], rand_mode)image_id = (image_id + 1) % len(train_low_data)if image_id == 0:tmp = list(zip(train_low_data, train_high_data))random.shuffle(list(tmp))train_low_data, train_high_data = zip(*tmp)

顺序读取训练图像,在每次读取的低/正常光照图像对上随机取patch,并进行数据扩增(具体见 中对函数data_augmentation的描述)。这里,应当注意的是,训练数据每满一个batch时将会重新打乱整个训练集。

# train
_, loss = self.sess.run([train_op, train_loss], feed_dict={self.input_low: batch_input_low, self.input_high: batch_input_high, self.lr: lr[epoch]})print("%s Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.6f" % (train_phase, epoch + 1, batch_id + 1, numBatch, time.time() - start_time, loss))
iter_num += 1

训练一个iter并打印相关信息。

# evalutate the model and save a checkpoint file for it
if (epoch + 1) % eval_every_epoch == 0:self.evaluate(epoch + 1, eval_low_data, sample_dir=sample_dir, train_phase=train_phase)self.save(saver, iter_num, ckpt_dir, "RetinexNet-%s" % train_phase)

每训练eval_every_epoch次评估并保存一次模型。

保存指定iter的模型:

def save(self, saver, iter_num, ckpt_dir, model_name):if not os.path.exists(ckpt_dir):os.makedirs(ckpt_dir)print("[*] Saving model %s" % model_name)saver.save(self.sess, os.path.join(ckpt_dir, model_name), global_step=iter_num)

加载最新的模型:

def load(self, saver, ckpt_dir):ckpt = tf.train.get_checkpoint_state(ckpt_dir)if ckpt and ckpt.model_checkpoint_path:full_path = tf.train.latest_checkpoint(ckpt_dir)try:global_step = int(full_path.split('/')[-1].split('-')[-1])except ValueError:global_step = Nonesaver.restore(self.sess, full_path)return True, global_stepelse:print("[*] Failed to load model from %s" % ckpt_dir)return False, 0

最后是关于模型的测试(其中test_high_data并没有用到):

def test(self, test_low_data, test_high_data, test_low_data_names, save_dir, decom_flag):

该函数中包含了模型的加载、模型的测试和结果图的保存这几个部分。

tf.global_variables_initializer().run()print("[*] Reading checkpoint...")
load_model_status_Decom, _ = self.load(self.saver_Decom, './model/Decom')
load_model_status_Relight, _ = self.load(self.saver_Relight, './model/Relight')
if load_model_status_Decom and load_model_status_Relight:print("[*] Load weights successfully...")

初始化所有参数并加载最新的Decom-Net和Enhance-Net模型。

print("[*] Testing...")
for idx in range(len(test_low_data)):print(test_low_data_names[idx])[_, name] = os.path.split(test_low_data_names[idx])suffix = name[name.find('.') + 1:]name = name[:name.find('.')]input_low_test = np.expand_dims(test_low_data[idx], axis=0)[R_low, I_low, I_delta, S] = self.sess.run([self.output_R_low, self.output_I_low, self.output_I_delta, self.output_S], feed_dict = {self.input_low: input_low_test})if decom_flag == 1:save_images(os.path.join(save_dir, name + "_R_low." + suffix), R_low)save_images(os.path.join(save_dir, name + "_I_low." + suffix), I_low)save_images(os.path.join(save_dir, name + "_I_delta." + suffix), I_delta)save_images(os.path.join(save_dir, name + "_S." + suffix), S)

遍历测试样本进行测试,并保存最终结果图(可自行指定是否保存Decom-Net的分解结果)。

欢迎关注公众号:huangxiaobai880

loss低但精确度低_低光照图像增强网络-RetinexNet(model.py解析【2】)相关推荐

  1. 低代码开发平台_低代码开发平台系列:6、低代码是编程技术发展大势所趋

    一.低代码是一种编程技术低代码是快速开发工具/技术的一种,属于软件开发/编程工具/技术领域,主要应用于企业软件开发领域.借助低代码工具,使用者无需编码即可实现企业软件系统常见功能的交付:少量编码扩展更 ...

  2. 低代码开发平台_低代码开发平台测评——伙伴云

    ​本次测评的产品严格来说不算低代码开发平台,它自己给自己的定位更多是全流程数据生产力平台.不过它依然具备应用搭建的关键要素,而且在数据管理方面还比较出彩,所以不能放过它--伙伴云,这款由Discuz! ...

  3. 容错性低是什么意思_低容错颜值表示什么意思 低容错颜值什么梗

    关于一个人的颜值,我们通常都是用高和低来区分好看还是不好看,不过最近又出了一个新的词哦,那就是"低容错颜值",这是说颜值高还是不高呢?下面八宝网小编带来:低容错颜值表示什么意思 低 ...

  4. 手机投屏到电脑_低延迟,传声音

    手机投屏到电脑_低延迟,传声音 1.为什么发这个博客 2.所需硬件,软件 硬件 软件 3.具体步骤 1.投屏画面 2.投声音到电脑 3.过程中遇到的问题 4.存在的缺陷 5.哈哈 1.为什么发这个博客 ...

  5. 高内聚低耦合通俗理解_带你从入门到精通——「高内聚低耦合」

    如果这是第二次看到我的文章,欢迎订阅z哥的公号(跨界架构师)哦~ 本文长度为2871字,建议阅读8分钟. 坚持原创,每一篇都是用心之作- 下面的这个场景你可能会觉得很熟悉(Z哥我又要出演了): Z哥: ...

  6. 添加多浏览器支持是什么意思_低gi什么意思,减肥期间一定要多吃低gi的食物吗?- 理财技巧...

    摘要: 低gi什么意思?现代健康生活已经成为了新时代的标志之一,健康低脂的饮食已经成为大众追捧的潮流之一.但是,你了解低gi是什么吗?下面就和小编一起去看一下. 低gi什么意思?现代健康生活已经成为了 ...

  7. hmc830相位噪声_低相位噪声电压控制振荡器(VCO)和稳定基准电压构成的频率合成器...

    新兴的PLL + VCO (集成电压控制振荡器的锁相环)技术能够针对蜂窝/4G.微波无线电军事等应用快速开发低相位噪声频率合成器,ADI集成频综产品的频率覆盖为25 MHz到13.6 GHz. 蜂窝/ ...

  8. 复杂网络代码_低代码的兴起,程序员要拒绝还是拥抱

    低代码是一种近些年兴起的企业软件快速开发技术和工具.借助低代码使用者无需编码即可完成企业应用的常用功能,少量编码扩展出更多功能.低代码凭借低门槛.高效率和易集成等特性,被越来越多的软件开发团队青睐.G ...

  9. java高内聚低耦合什么意思_高内聚低耦合什么意思?合理通俗解释

    我们常听一些厉害的程序员说过高内聚.低耦合,小伙伴们知道它们是什么意思吗?下面听小编为你解析一下. 什么是低耦合? 官方的说,耦合就是元素与元素之间的连接.感知与依赖量度.元素代表什么?这里的元素代指 ...

最新文章

  1. 在CentOS 6.5 x86_64上安装libunwind的问题
  2. RabbitMQ之消息确认机制(事务+Confirm)
  3. import lombok 报错_lombok
  4. 磁珠与电感的区别,看了就灰常明白了
  5. 青岛经济职业学校有计算机专业吗,青岛经济职业学校
  6. 荣耀:目前还在观望鸿蒙,未来的对手是苹果
  7. 同一域名端口下,通过nginx部署多个vue项目
  8. mysql limti_MYSQL分页 limint
  9. iotop监视磁盘I/O
  10. java session 例子_JavaWeb——HttpSession常用方法示例
  11. 有没有五金产品展开计算机软件,拆单软件功能介绍
  12. 利用Druid Monitor做数据库连接异常排查
  13. ★如何引导客户需求?几个经典的案例分析!
  14. 今天收到一封非常牛B的离职信
  15. 姜小白的Python日记Day8 字符串编码转换与函数简介
  16. Invalid Component definition:header
  17. 一文读懂ADAS系统
  18. 关于Eclipse的使用入门
  19. 《Vue+Spring Boot前后端分离开发实战》专著累计发行上万册
  20. PCIE 2.0协议概念基本科普

热门文章

  1. Java学习第1天:序言,基础及配置tomcat
  2. JPDA 架构研究5 - Agent利用环境指针访问VM (内存管理篇)
  3. socket websocket
  4. peripheralStateNotificationCB
  5. MySql 自动更新时间为当前时间
  6. hive 导入hdfs数据_将数据加载或导入运行在基于HDFS的数据湖之上的Hive表中的另一种方法。
  7. linux分辨率和用户有关吗,Linux系统在高分屏非正常分辨率显示
  8. 美团骑手检测出虚假定位_在虚假信息活动中检测协调
  9. 余弦相似度和欧氏距离_欧氏距离和余弦相似度
  10. leetcode 148. 排序链表(归并排序)