1.数据处理 dataset.py

训练集选用T91, 测试集选用Set5
低分辨图片(x)切割成17x17,高分辨图片(y)切割成17scale x 17scale

import cv2
import glob
import numpy as np
import tensorflow as tf
from keras import backend as kdef psnr(y_true, y_pred):# assume RGB imagereturn 10.0 * k.log(1.0 / (k.mean(k.square(y_pred - y_true)))) / k.log(10.0)# 将图片裁剪成放大倍数的整数倍
def Cropping(img, scale=3):x, y = img.shape[0], img.shape[1]x = x - np.mod(x, scale)  # 取余y = y - np.mod(y, scale)img = img[0:x, 0:y]return img# 数据处理
def get_lr_hr(img, scale):hr = img / 255.    # 归一化# 先缩小再放大 得到低分辨图片lr = cv2.resize(hr, (hr.shape[1] // scale, hr.shape[0] // scale), interpolation=cv2.INTER_CUBIC)return lr, hrdef load_data(path, scale=3, cut_size=17, stride=17):LR = []  # 切片后的lrHR = []  # 切片后的hrsize = []  # (nx, ny)切片次数for i in path:img = cv2.imread(i)  # 读取图片img = cv2.cvtColor(img, cv2.COLOR_BGR2YCrCb)[:, :, 0]   # 格式转换 取Y通道imgmod = Cropping(img, scale)  # 尺寸处理lr, hr = get_lr_hr(imgmod, scale)  # 获取lr和hr# 开切h, w = lr.shape[0], lr.shape[1]nx, ny = 0, 0for x in range(0, h - cut_size + 1, stride):nx += 1for y in range(0, w - cut_size + 1, stride):ny += 1sub_lr = lr[x:x + cut_size, y:y + cut_size]  # (17, 17)sub_hr = hr[x * scale:(x + cut_size) * scale,y * scale:(y + cut_size) * scale]  # (51, 51)# 转换为便于训练的格式sub_input = sub_lr.reshape([cut_size, cut_size, 1])sub_label = sub_hr.reshape([cut_size * scale, cut_size * scale, 1])LR.append(sub_input)HR.append(sub_label)size.append((nx, ny // nx))x_train = np.array(LR)y_train = np.array(HR)return x_train, y_train, size# 合并
def merge(images, size):(nx, ny) = size[0]_, h, w, d = images.shapeimg = np.zeros((h * nx, w * ny, d))for idx, image in enumerate(images):i = idx % ny     # 取余j = idx // ny    # 取整img[j * h:j * h + h, i * w:i * w + w, :] = imagereturn imgif __name__ == "__main__":# 训练集train_path = '../train/*.bmp'train_path_list = glob.glob(train_path)# 测试集test_path = '../test/Set5/*.bmp'test_path_list = glob.glob(test_path)# 放大倍数scale = 3cut_size = 17stride = 17x_train, y_train, size = load_data(train_path_list, scale=scale, stride=stride)print(x_train.shape)print(y_train.shape)# img = merge(y_train, size)a = np.uint8(x_train[50] * 255)b = np.uint8(y_train[50] * 255)# cv2.namedWindow('1', cv2.WINDOW_NORMAL)cv2.imshow('1', a)# cv2.namedWindow('2', cv2.WINDOW_NORMAL)cv2.imshow('2', b)cv2.waitKey(0)cv2.destroyAllWindows()

x_train.shape y_train.shape

2.搭建模型 model.py

第一层:5x5x64
第二层:3x3x32
第三层:3x3x(scalescalechannel)
scale为放大倍数,channel为输入图片的通道数

最后一层是亚像素卷积,用tf.nn.depth_to_space实现

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, Activationclass ESPCN(Model):def __init__(self, scale, channel):super(ESPCN, self).__init__()self.scale = scaleself.channel = channelself.c1 = Conv2D(filters=64, kernel_size=5, strides=1, padding='same')self.a1 = Activation('relu')self.c2 = Conv2D(filters=32, kernel_size=3, strides=1, padding='same')self.a2 = Activation('relu')self.c3 = Conv2D(filters=self.scale * self.scale * self.channel, kernel_size=3, strides=1, padding='same')self.d2s = tf.nn.depth_to_spaceself.a3 = Activation('tanh')def call(self, inputs, training=None, mask=None):x = self.c1(inputs)x = self.a1(x)x = self.c2(x)x = self.a2(x)x = self.c3(x)x = self.d2s(x, self.scale, data_format='NHWC')y = self.a3(x)return y

3.训练 + 测试

import glob
import numpy as np
import tensorflow as tf
from model import ESPCN
from dataset import load_data, psnr, Cropping
import cv2
from skimage.measure import compare_psnr, compare_ssim# 训练集
train_path = '../291/*.bmp'
train_path_list = glob.glob(train_path)# 放大倍数 输入图片尺寸 输出图片尺寸 移动步长
scale = 3
channel = 1
cut_size = 17
stride = 7# 获取数据
x_train, y_train, train_size = load_data(train_path_list,scale=scale,cut_size=cut_size,stride=stride)test_path = '../test/Set5/*.bmp'
test_path_list = glob.glob(test_path)x_test, y_test, test_size = load_data(test_path_list,scale=scale,cut_size=cut_size,stride=stride)# 网络
model = ESPCN(scale=scale, channel=channel)# 配置
model.compile(optimizer=tf.keras.optimizers.Adam(0.01),loss=tf.keras.losses.MSE,metrics=[psnr])# 调整学习率
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.99, patience=2, mode='auto')# 保存参数
checkpoint_save_path = './ESPCN2_checkpoint/ESPCN.ckpt'
# model.load_weights(checkpoint_save_path)
cp_callback = tf.keras.callbacks.ModelCheckpoint(monitor='val_psnr',filepath=checkpoint_save_path,save_weights_only=True,save_best_only=True
)print(x_train.shape)
model.fit(x_train, y_train,batch_size=64,epochs=250,validation_data=(x_test, y_test),validation_freq=1,callbacks=[reduce_lr, cp_callback])i = 2
# 读取图片 转换YCrCb 规则化
img = cv2.imread(test_path_list[i])
img = Cropping(img, scale)
YCC = cv2.cvtColor(img, cv2.COLOR_BGR2YCrCb)
Y = YCC[:, :, 0]
color = YCC[:, :, 1:3]# 压缩图片
lr = cv2.resize(YCC, (YCC.shape[1] // scale, YCC.shape[0] // scale), interpolation=cv2.INTER_CUBIC)
lr_bgr = cv2.cvtColor(lr, cv2.COLOR_YCrCb2BGR)# 预测
h, w = lr.shape[0], lr.shape[1]
dt = lr[:, :, 0].reshape([1, h, w, 1]) / 255.
y_pre = model.predict(dt)
y_pre = y_pre.squeeze()
y_pre = y_pre * 255# 合并
result = np.zeros([YCC.shape[0], YCC.shape[1], 3], dtype=np.uint8)
result[:, :, 0] = y_pre
result[:, :, 1:3] = color
result = cv2.cvtColor(result, cv2.COLOR_YCrCb2BGR)# bicubic
bicubic = cv2.resize(lr_bgr, (YCC.shape[1], YCC.shape[0]), interpolation=cv2.INTER_CUBIC)
yB = cv2.cvtColor(bicubic, cv2.COLOR_BGR2YCrCb)[:, :, 0]cv2.imshow('original', img)
cv2.imshow('input', lr_bgr)
cv2.imshow('output1', result)
cv2.imshow('output2', bicubic)
cv2.waitKey(0)
cv2.destroyAllWindows()print('-' * 50)
psnr1 = compare_psnr(Y, y_pre, 255)
psnr2 = compare_psnr(Y, yB, 255)
print(psnr1, psnr2)     # 23.180898231092254 21.97495909635292print('-' * 50)
ssim1 = compare_ssim(Y, y_pre, win_size=11, data_range=255, multichannel=True)
ssim2 = compare_ssim(Y, yB, win_size=11, data_range=255, multichannel=True)
print(ssim1, ssim2)     # 0.8354814229799264 0.791309318119287

原图

输入

双插值

ESPCN

psnr

ssim

TensorFlow2.0 实现 ESPCN相关推荐

  1. 【深度学习】(6) tensorflow2.0使用keras高层API

    各位同学好,今天和大家分享一下TensorFlow2.0深度学习中借助keras的接口减少神经网络代码量.主要内容有: 1. metrics指标:2. compile 模型配置:3. fit 模型训练 ...

  2. 【TensorFlow2.0】(7) 张量排序、填充、复制、限幅、坐标选择

    各位同学好,今天和大家分享一下TensorFlow2.0中的一些操作.内容有: (1)排序 tf.sort().tf.argsort().top_k():(2)填充 tf.pad():(3)复制 tf ...

  3. 【TensorFlow2.0】(6) 数据统计,范数、最值、求和、均值、最值位置、唯一值、张量比较

    各位同学好,今天和大家分享一下TensorFlow2.0中的数据分析操作.内容有: (1)范数 tf.norm():(2)最值 tf.reduce_min(), tf.reduce_max()(3)求 ...

  4. 【TensorFlow2.0】(5) 数学计算、合并、分割

    各位同学好,今天和大家分享一下TensorFlow2.0中的数学运算方法.合并与分割.内容有: (1)基本运算:(2)矩阵相乘:(3)合并 tf.concat().tf.stack():(4)分割 t ...

  5. 【TensorFlow2.0】(4) 维度变换、广播

    各位同学好,今天我和大家分享一下TensorFlow2.0中有关数学计算的相关操作,主要内容有: (1) 改变维度:reshape():(2) 维度转置:transpose():(3) 增加维度:ex ...

  6. 【TensorFlow2.0】(3) 索引与切片操作

    各位同学好,今天我和大家分享一下TensorFlow2.0中索引与切片.内容有: (1) 给定每一维度的索引来获取数据:(2) 切片索引:(3) 省略号应用:(4) tf.gather() 方法:(5 ...

  7. 【TensorFlow2.0】(2) 创建tensor的方法

    各位同学好,今天和大家分享一下TensorFlow2.0中的tensor变量的创建方法.内容有: (1) 通过numpy和list创建tensor:(2) 创建全部为某个值的tensor:(3) 随机 ...

  8. 【TensorFlow2.0】(1) tensor数据类型,类型转换

    各位同学好,今天和大家分享一下TensorFlow2.0中的tensor数据类型,以及各种类型之间的相互转换方法. 1. tf.tensor 基础操作 scaler标量:1.2 vector向量:[1 ...

  9. mybatis-plus对datetime返回去掉.0_华为AI认证-TensorFlow2.0编程基础

    参考<HCIA-AI2.0培训教材><HCIA-AI2.0实验手册> 认证要求: 了解TensorFlow2.0是什么以及其特点 掌握TensorFlow2.0基础和高阶操作方 ...

最新文章

  1. 【转载】“error LNK1169: 找到一个或多个多重定义的符号”的解决方法
  2. 【安全牛学习笔记】手动漏洞挖掘(三)
  3. Leetcode-199二叉树的右视图(二叉树左视图)
  4. 找到一个或多个多重定义的符号_初中数学之相反数,总结规律,学会多重符号的化简...
  5. cordova与android通信_5:Cordova与原生交互--传值
  6. Swift开发之粒子动画的实现
  7. 二叉树和等于某值路径_Go刷LeetCode系列:二叉树(3)二叉树路径和
  8. 基于 Vue 的轻量级静态网站生成器 VuePress
  9. facebook react.js
  10. 一串数字中有两个只出现一次的数字其余都是成对相同,求这两个数
  11. Python脚本覆盖率分析方法介绍
  12. MATLAB代码:计及碳排放交易及多种需求响应的微网/虚拟电厂日前优化调度
  13. 基于React的可编辑在线简历模板
  14. 【深度学习】神经网络的学习过程
  15. signal信号值对应表
  16. Kotlin第三章:AndroidUI简介
  17. ucsd大学音乐计算机,音乐留学│综合名校UCSD音乐制作专业详解!
  18. linux docker安装nginx且测试elasticsearch分词
  19. python中seed的相关代码
  20. 2d有限元计算机仿真,超精密单点金刚石车削加工有限元仿真

热门文章

  1. Vue build打包之后,刷新页面出现404解决方案
  2. Xinlei cheng报告学习
  3. 关于meta标签详解
  4. 腾讯舆情团队谈:如何发现下一个现象级游戏?
  5. 白话文——过目不忘的sql索引是啥?
  6. RPG游戏自定义角色
  7. 导致服务器“中毒”的几种行为
  8. 三子棋的实现(C语言)
  9. QT,SSH开发——QSSH库编译成功率最高的方法
  10. 程序员开发游戏只为向女友求婚,每个关卡都是泪点!我是一个普通人,但是想成为你另一半玩家