pytorch版本的SRCNN代码一共分为6个.py文件,结构如下:

  • datasets.py
  • models.py
  • prepare.py
  • utils.py
  • test.py
  • train.py

  以上文件不分先后,执行时通过import…或者from…import…语句进行调用。以下解释import部分均省略,个别例外。

prepare.py

  readme.md中给出了不同放大倍数下的训练数据,验证数据和测试数据的下载地址。如果下载了直接把对应的路径写好就可以执行了,这里我们使用自己下载的数据通过使用prepare.py来制作训练和验证的h5格式的数据集。

import argparse
import glob
import h5py
import numpy as np
import PIL.Image as pil_image
from utils import convert_rgb_to_y#该函数用来创建自己的h5数据,包括俩个函数:对训练数据的处理和验证部分的处理。
def train(args):h5_file = h5py.File(args.output_path, 'w')'''def是python的关键字,用来定义函数。这里通过def定义名为train的函数,函数的参数为args,args这个参数通过外部命令行传入output的路径,通过h5py.File()方法的w模式--创建文件自己自写,已经存在的文件会被覆盖,文件的路径是通过args.output_path来传入'''lr_patches = []hr_patches = []'''创建俩个空列表:lr_patches和hr_patches(通过ctrl左键该变量名查看在其他位置的引用)'''for image_path in sorted(glob.glob('{}/*'.format(args.images_dir))):'''这部分代码的目的就是搜索指定文件夹下的文件并排序,for这一句包含了几个知识点:1.{}.format():-->格式化输出函数,从args.images_dir路径中格式化输出路径2.glob.glob():-->返回所有匹配的文件路径列表,将1得到的路径中的所有文件返回3.sorted():-->排序,将2得到的所有文件按照某种顺序返回,,默认是升序4.for x in *:   -->循换输出'''hr = pil_image.open(image_path).convert('RGB')'''1.***.open():是PIL图像库的函数,用来从image_path中加载图像2.***.convert():是PIL图像库的函数,用来转换图像的模式'''hr_width = (hr.width // args.scale) * args.scalehr_height = (hr.height // args.scale) * args.scalehr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)lr = hr.resize((hr_width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC)lr = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)#缩放处理hr = np.array(hr).astype(np.float32)lr = np.array(lr).astype(np.float32)hr = convert_rgb_to_y(hr)lr = convert_rgb_to_y(lr)'''" / "  表示浮点数除法,返回浮点结果;" // " 表示整数除法,返回不大于结果的一个最大的整数,也就是向下取整这里的hr是输入的原图,先进行mod和缩放的预处理,lr是hr在mod之后经过scale的结果,得到的lr再经过缩放处理得到最终要用的lr的图片resize():缩放操作np.array():将列表list或元组tuple转换为ndarray数组astype():转换数组的数据类型convert_rgb_to_y():将图像从RGB格式转换为Y通道格式的图片假设原始输入图像为(321,481,3)-->依次为高,宽,通道数1.先mod,之后hr的图像尺寸为(320,480,3)2.对hr图像进行双三次上采样放大操作3.将hr//scale进行双三次上采样放大操作之后×scale得到lr4.接着进行通道数转换和类型转换'''for i in range(0, lr.shape[0] - args.patch_size + 1, args.stride):'''图像的shape是宽度、高度和通道数,shape[0]是指图像的高度=320;shape[1]是图像的宽度=480;shape[2]是指图像的通道数'''for j in range(0, lr.shape[1] - args.patch_size + 1, args.stride):lr_patches.append(lr[i:i + args.patch_size, j:j + args.patch_size])hr_patches.append(hr[i:i + args.patch_size, j:j + args.patch_size])lr_patches = np.array(lr_patches)hr_patches = np.array(hr_patches)#把得到的数据转化为数组类型h5_file.create_dataset('lr', data=lr_patches)h5_file.create_dataset('hr', data=hr_patches)h5_file.close()def eval(args):h5_file = h5py.File(args.output_path, 'w')lr_group = h5_file.create_group('lr')hr_group = h5_file.create_group('hr')for i, image_path in enumerate(sorted(glob.glob('{}/*'.format(args.images_dir)))):hr = pil_image.open(image_path).convert('RGB')hr_width = (hr.width // args.scale) * args.scalehr_height = (hr.height // args.scale) * args.scalehr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)lr = hr.resize((hr_width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC)lr = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)hr = np.array(hr).astype(np.float32)lr = np.array(lr).astype(np.float32)hr = convert_rgb_to_y(hr)lr = convert_rgb_to_y(lr)lr_group.create_dataset(str(i), data=lr)hr_group.create_dataset(str(i), data=hr)h5_file.close()if __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--images-dir', type=str,default='/home/dushuai/word/SRCNN_pytorch/evaldata')parser.add_argument('--output-path', type=str,default='/home/dushuai/word/SRCNN_pytorch/evalout/evalout.h5')parser.add_argument('--patch-size', type=int, default=33)parser.add_argument('--stride', type=int, default=14)parser.add_argument('--scale', type=int, default=2)parser.add_argument('--eval', action='store_true')args = parser.parse_args()if not args.eval:train(args)else:eval(args)
'''
最后这个if..else..要注意一下,是和parser传入的最后一个参数有关的,它是用来决定使用哪个函数来生成h5文件,因为有俩个不同的函数train和eval生成对应的h5文件。该参数的具体使用方法如下
'''

实验:action

  在我看来这是个很鸡肋的参数设置,但是存在即合理,我们只需要明白它就ok了。

import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--eval', action='store_false')
args = parser.parse_args()
def main():x = args.evalprint(x)if __name__ == '__main__':main()

  可以看到我上边的action=‘store_false’,但是边一个是直接在IDE中run的结果是True,而我通过命令行运行得到的结果却是false,这是为什么?


  顾名思义,store_flase就是存储一个bool值false,也就是说在该参数在被激活时它会输出store存储的值也就是这里我通过命令行得到的值,而IDE得到的值没有激活该参数,得到的是它的默认值True.

import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--eval', action='store_false')
args = parser.parse_args()def a():print('a')
def b():print('b')
def main():x = args.evalprint(x)if not args.eval:print(args.eval)a()else:print(args.eval)b()
if __name__ == '__main__':main()


  在SRCNN的预处理中可以通过修改action中store的值也可以通过if not args.eval来调整函数运行哪个函数来得到对应的结果。

datasets.py

一共包含俩个类TrainDataset()和EvalDataset(),分别用来加载prepare.py制作的训练和验证俩个数据集的。这部分想自己写,但是发现了一篇不错的博客,传送门在此

models.py

这部分更为简单,首先定义了模型类SRCNN,它继承自父类nn.Module。super这句是对继承自父类的属性进行初始化。接下来就是对卷积层的定义和前向传播的定义。

utils.py

这个utils.py相当于是工具类,定义网络需要使用的各种函数。这个文件一共包括了四个函数和一个类,至于test和train都很简单,很容易看懂,略

参考文献:
1.if name == ‘main’:
2.Python之argparse

SRCNN-pytoch代码讲解相关推荐

  1. 手把手教你如何做建模竞赛(baseline代码讲解)

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 1.大赛背景 随着科技发展,银行陆续打造了线上线下.丰富多样的客户触 ...

  2. 【资源】Faster R-CNN原理及代码讲解电子书

    <Faster R-CNN原理及代码讲解>是首发于GiantPandaCV公众号的教程,针对陈云大佬实现的Faster R-CNN代码讲解,Github链接如下: https://gith ...

  3. 激光-视觉-IMU-GPS融合SLAM算法梳理和代码讲解

    应用背景介绍 自主导航是机器人与自动驾驶的核心功能,而SLAM技术是实现自主导航的前提与关键.现有的机器人与自动驾驶车辆往往会安装激光雷达,相机,IMU,GPS等多种模态的传感器,而且已有许多优秀的激 ...

  4. 彻底剖析激光-视觉-IMU-GPS融合SLAM算法:理论推导、代码讲解和实战

    应用背景介绍 自主导航是机器人与自动驾驶的核心功能,而SLAM技术是实现自主导航的前提与关键.现有的机器人与自动驾驶车辆往往会安装激光雷达,相机,IMU,GPS等多种模态的传感器,而且已有许多优秀的激 ...

  5. 彻底搞透视觉三维重建:原理剖析、代码讲解、及优化改进

    视觉三维重建 = 定位定姿 + 稠密重建 + surface reconstruction +纹理贴图.三维重建技术是计算机视觉的重要技术之一,基于视觉的三维重建技术通过深度数据获取.预处理.点云配准 ...

  6. mysql多表联查分页_sqlserver多表联合查询和多表分页查询的代码讲解

    sqlserver多表联合查询和多表分页查询的代码讲解 发布时间:2020-05-14 14:42:07 来源:亿速云 阅读:700 作者:Leah 这篇文章主要为大家详细介绍了sqlserver多表 ...

  7. python中的object是什么意思_Python object类中的特殊方法代码讲解

    python版本:3.8class object: """ The most base type """ # del obj.xxx或del ...

  8. 三层代码讲解--第一课

    主题:三层代码讲解--第一课 主持人:老吴 时间:2004-05-24 2004-05-24 10:47:00 天之痕_若虹(86278566) 請教大家一個問題好嗎 2004-05-24 10:47 ...

  9. WPF第一章(XAML前台标记语言(Chapter02代码讲解))

    XAML前台标记语言(Chapter2代码讲解)     很不好意思,工作有点忙,博客停了两天.相对于一门语言的学习,理论知识和实践必不可少,大多数时间我们要用,对于代码也是,一边不行可以看两遍,实在 ...

  10. python代码大全p-python处理写入数据代码讲解

    首先要利用python进行读取整个文件,然后逐行读取,最后写入数据.具体实现步骤参考如下: 步骤一.读取整个文件 先在当前目录下创建一个TXT文件,例如文件名为'pi_digits.txt'的文本文件 ...

最新文章

  1. 双屏全屏跳回到主屏_双屏双倍乐趣?华硕灵耀X2 Duo笔记本评测
  2. 计算机在中职教育中的运用论文,中职计算机教育的相关论文(2)
  3. [VN2020 公开赛]CSRe
  4. 云炬Android开发笔记 5-8文件下载功能设计与实现
  5. Push代码:Git@github.com: Permission denied (publickey)
  6. HDU -2546饭卡(01背包+贪心)
  7. ibm服务器和微软,微软与IBM不得不说的事情
  8. 2018-2019 20165226 Exp9 Web安全基础
  9. 数据库MySQL/mariadb知识点——数据类型
  10. Mysql实战练习之简单图书管理系统
  11. JDE 开发-部分系统函数
  12. 随记:PNP和NPN三极管区别
  13. Kinect for Windows SDK开发入门(五):景深数据处理 下
  14. 计算机思维在化学上的应用,【科学思维】化隐性为显性思想在化学中的应用
  15. Adjustment Office
  16. 小米五怎么设置锁屏显示无服务器,小米手机怎么设置锁屏状态下不能关机 - 卡饭网...
  17. 集成学习-Bagging和Pasting
  18. 超全!常用的 70 个数据分析网址
  19. dns的基本设定(一)
  20. hrbust 2155 钱多多【水题】

热门文章

  1. 现代经济中的货币创造
  2. 基于pywifi库的暴力破解wifi方法
  3. 移动硬盘无法在Mac上装载如何修复?
  4. 100 余个超实用网站
  5. SPSS入门教程——如何分析两个变量之间的关联度?
  6. 【传智播客】Javaweb程序设计任务教程 黑马程序员 第一章 课后答案
  7. 网络安全和CTF相关内容
  8. 推荐系统之基于用户行为数据的协同过滤(Collaborative Filtering)
  9. ANSYS workbench 有限元分析 学习
  10. 计算机考研408真题和答案