SRCNN-pytoch代码讲解
pytorch版本的SRCNN代码一共分为6个.py文件,结构如下:
- datasets.py
- models.py
- prepare.py
- utils.py
- test.py
- train.py
以上文件不分先后,执行时通过import…或者from…import…语句进行调用。以下解释import部分均省略,个别例外。
prepare.py
readme.md中给出了不同放大倍数下的训练数据,验证数据和测试数据的下载地址。如果下载了直接把对应的路径写好就可以执行了,这里我们使用自己下载的数据通过使用prepare.py来制作训练和验证的h5格式的数据集。
import argparse
import glob
import h5py
import numpy as np
import PIL.Image as pil_image
from utils import convert_rgb_to_y#该函数用来创建自己的h5数据,包括俩个函数:对训练数据的处理和验证部分的处理。
def train(args):h5_file = h5py.File(args.output_path, 'w')'''def是python的关键字,用来定义函数。这里通过def定义名为train的函数,函数的参数为args,args这个参数通过外部命令行传入output的路径,通过h5py.File()方法的w模式--创建文件自己自写,已经存在的文件会被覆盖,文件的路径是通过args.output_path来传入'''lr_patches = []hr_patches = []'''创建俩个空列表:lr_patches和hr_patches(通过ctrl左键该变量名查看在其他位置的引用)'''for image_path in sorted(glob.glob('{}/*'.format(args.images_dir))):'''这部分代码的目的就是搜索指定文件夹下的文件并排序,for这一句包含了几个知识点:1.{}.format():-->格式化输出函数,从args.images_dir路径中格式化输出路径2.glob.glob():-->返回所有匹配的文件路径列表,将1得到的路径中的所有文件返回3.sorted():-->排序,将2得到的所有文件按照某种顺序返回,,默认是升序4.for x in *: -->循换输出'''hr = pil_image.open(image_path).convert('RGB')'''1.***.open():是PIL图像库的函数,用来从image_path中加载图像2.***.convert():是PIL图像库的函数,用来转换图像的模式'''hr_width = (hr.width // args.scale) * args.scalehr_height = (hr.height // args.scale) * args.scalehr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)lr = hr.resize((hr_width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC)lr = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)#缩放处理hr = np.array(hr).astype(np.float32)lr = np.array(lr).astype(np.float32)hr = convert_rgb_to_y(hr)lr = convert_rgb_to_y(lr)'''" / " 表示浮点数除法,返回浮点结果;" // " 表示整数除法,返回不大于结果的一个最大的整数,也就是向下取整这里的hr是输入的原图,先进行mod和缩放的预处理,lr是hr在mod之后经过scale的结果,得到的lr再经过缩放处理得到最终要用的lr的图片resize():缩放操作np.array():将列表list或元组tuple转换为ndarray数组astype():转换数组的数据类型convert_rgb_to_y():将图像从RGB格式转换为Y通道格式的图片假设原始输入图像为(321,481,3)-->依次为高,宽,通道数1.先mod,之后hr的图像尺寸为(320,480,3)2.对hr图像进行双三次上采样放大操作3.将hr//scale进行双三次上采样放大操作之后×scale得到lr4.接着进行通道数转换和类型转换'''for i in range(0, lr.shape[0] - args.patch_size + 1, args.stride):'''图像的shape是宽度、高度和通道数,shape[0]是指图像的高度=320;shape[1]是图像的宽度=480;shape[2]是指图像的通道数'''for j in range(0, lr.shape[1] - args.patch_size + 1, args.stride):lr_patches.append(lr[i:i + args.patch_size, j:j + args.patch_size])hr_patches.append(hr[i:i + args.patch_size, j:j + args.patch_size])lr_patches = np.array(lr_patches)hr_patches = np.array(hr_patches)#把得到的数据转化为数组类型h5_file.create_dataset('lr', data=lr_patches)h5_file.create_dataset('hr', data=hr_patches)h5_file.close()def eval(args):h5_file = h5py.File(args.output_path, 'w')lr_group = h5_file.create_group('lr')hr_group = h5_file.create_group('hr')for i, image_path in enumerate(sorted(glob.glob('{}/*'.format(args.images_dir)))):hr = pil_image.open(image_path).convert('RGB')hr_width = (hr.width // args.scale) * args.scalehr_height = (hr.height // args.scale) * args.scalehr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)lr = hr.resize((hr_width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC)lr = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)hr = np.array(hr).astype(np.float32)lr = np.array(lr).astype(np.float32)hr = convert_rgb_to_y(hr)lr = convert_rgb_to_y(lr)lr_group.create_dataset(str(i), data=lr)hr_group.create_dataset(str(i), data=hr)h5_file.close()if __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--images-dir', type=str,default='/home/dushuai/word/SRCNN_pytorch/evaldata')parser.add_argument('--output-path', type=str,default='/home/dushuai/word/SRCNN_pytorch/evalout/evalout.h5')parser.add_argument('--patch-size', type=int, default=33)parser.add_argument('--stride', type=int, default=14)parser.add_argument('--scale', type=int, default=2)parser.add_argument('--eval', action='store_true')args = parser.parse_args()if not args.eval:train(args)else:eval(args)
'''
最后这个if..else..要注意一下,是和parser传入的最后一个参数有关的,它是用来决定使用哪个函数来生成h5文件,因为有俩个不同的函数train和eval生成对应的h5文件。该参数的具体使用方法如下
'''
实验:action
在我看来这是个很鸡肋的参数设置,但是存在即合理,我们只需要明白它就ok了。
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--eval', action='store_false')
args = parser.parse_args()
def main():x = args.evalprint(x)if __name__ == '__main__':main()
可以看到我上边的action=‘store_false’,但是边一个是直接在IDE中run的结果是True,而我通过命令行运行得到的结果却是false,这是为什么?
顾名思义,store_flase就是存储一个bool值false,也就是说在该参数在被激活时它会输出store存储的值也就是这里我通过命令行得到的值,而IDE得到的值没有激活该参数,得到的是它的默认值True.
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--eval', action='store_false')
args = parser.parse_args()def a():print('a')
def b():print('b')
def main():x = args.evalprint(x)if not args.eval:print(args.eval)a()else:print(args.eval)b()
if __name__ == '__main__':main()
在SRCNN的预处理中可以通过修改action中store的值也可以通过if not args.eval来调整函数运行哪个函数来得到对应的结果。
datasets.py
一共包含俩个类TrainDataset()和EvalDataset(),分别用来加载prepare.py制作的训练和验证俩个数据集的。这部分想自己写,但是发现了一篇不错的博客,传送门在此
models.py
这部分更为简单,首先定义了模型类SRCNN,它继承自父类nn.Module。super这句是对继承自父类的属性进行初始化。接下来就是对卷积层的定义和前向传播的定义。
utils.py
这个utils.py相当于是工具类,定义网络需要使用的各种函数。这个文件一共包括了四个函数和一个类,至于test和train都很简单,很容易看懂,略
参考文献:
1.if name == ‘main’:
2.Python之argparse
SRCNN-pytoch代码讲解相关推荐
- 手把手教你如何做建模竞赛(baseline代码讲解)
点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 1.大赛背景 随着科技发展,银行陆续打造了线上线下.丰富多样的客户触 ...
- 【资源】Faster R-CNN原理及代码讲解电子书
<Faster R-CNN原理及代码讲解>是首发于GiantPandaCV公众号的教程,针对陈云大佬实现的Faster R-CNN代码讲解,Github链接如下: https://gith ...
- 激光-视觉-IMU-GPS融合SLAM算法梳理和代码讲解
应用背景介绍 自主导航是机器人与自动驾驶的核心功能,而SLAM技术是实现自主导航的前提与关键.现有的机器人与自动驾驶车辆往往会安装激光雷达,相机,IMU,GPS等多种模态的传感器,而且已有许多优秀的激 ...
- 彻底剖析激光-视觉-IMU-GPS融合SLAM算法:理论推导、代码讲解和实战
应用背景介绍 自主导航是机器人与自动驾驶的核心功能,而SLAM技术是实现自主导航的前提与关键.现有的机器人与自动驾驶车辆往往会安装激光雷达,相机,IMU,GPS等多种模态的传感器,而且已有许多优秀的激 ...
- 彻底搞透视觉三维重建:原理剖析、代码讲解、及优化改进
视觉三维重建 = 定位定姿 + 稠密重建 + surface reconstruction +纹理贴图.三维重建技术是计算机视觉的重要技术之一,基于视觉的三维重建技术通过深度数据获取.预处理.点云配准 ...
- mysql多表联查分页_sqlserver多表联合查询和多表分页查询的代码讲解
sqlserver多表联合查询和多表分页查询的代码讲解 发布时间:2020-05-14 14:42:07 来源:亿速云 阅读:700 作者:Leah 这篇文章主要为大家详细介绍了sqlserver多表 ...
- python中的object是什么意思_Python object类中的特殊方法代码讲解
python版本:3.8class object: """ The most base type """ # del obj.xxx或del ...
- 三层代码讲解--第一课
主题:三层代码讲解--第一课 主持人:老吴 时间:2004-05-24 2004-05-24 10:47:00 天之痕_若虹(86278566) 請教大家一個問題好嗎 2004-05-24 10:47 ...
- WPF第一章(XAML前台标记语言(Chapter02代码讲解))
XAML前台标记语言(Chapter2代码讲解) 很不好意思,工作有点忙,博客停了两天.相对于一门语言的学习,理论知识和实践必不可少,大多数时间我们要用,对于代码也是,一边不行可以看两遍,实在 ...
- python代码大全p-python处理写入数据代码讲解
首先要利用python进行读取整个文件,然后逐行读取,最后写入数据.具体实现步骤参考如下: 步骤一.读取整个文件 先在当前目录下创建一个TXT文件,例如文件名为'pi_digits.txt'的文本文件 ...
最新文章
- 双屏全屏跳回到主屏_双屏双倍乐趣?华硕灵耀X2 Duo笔记本评测
- 计算机在中职教育中的运用论文,中职计算机教育的相关论文(2)
- [VN2020 公开赛]CSRe
- 云炬Android开发笔记 5-8文件下载功能设计与实现
- Push代码:Git@github.com: Permission denied (publickey)
- HDU -2546饭卡(01背包+贪心)
- ibm服务器和微软,微软与IBM不得不说的事情
- 2018-2019 20165226 Exp9 Web安全基础
- 数据库MySQL/mariadb知识点——数据类型
- Mysql实战练习之简单图书管理系统
- JDE 开发-部分系统函数
- 随记:PNP和NPN三极管区别
- Kinect for Windows SDK开发入门(五):景深数据处理 下
- 计算机思维在化学上的应用,【科学思维】化隐性为显性思想在化学中的应用
- Adjustment Office
- 小米五怎么设置锁屏显示无服务器,小米手机怎么设置锁屏状态下不能关机 - 卡饭网...
- 集成学习-Bagging和Pasting
- 超全!常用的 70 个数据分析网址
- dns的基本设定(一)
- hrbust 2155 钱多多【水题】