相信很多朋友在做视觉工作入门的时候首先都会接触到读库代码,那么在图像质量评价方向中的读库代码该如何实现呢?接下来我会给大家介绍一段详细的读库代码,代码框架我主要是从2014年kang等人利用CNN进行图像质量评价的源代码进行修改的。代码基于pytorch框架。

在进行读库代码介绍前,我们需要有一些先验知识储备。

Dataset和Dataloader

PyTorch提供torch.utils.data.DataLoader 和 torch.utils.data.Dataset允许你使用预下载的数据集或自己制作的数据。

Dataloader

        这边我直接用一段代码来解释,IQAdatast在下面提到了。这里偷个懒就不说明dataloader里的参数了,代码浅显易懂,根据参数名字也能判断。

train_data = IQAdatasets(root=config.folder_path[args.dataset], index=train_index, transform=transforms,transform_gray = transforms_gray , patch_num=args.train_patch_num)train_loader = torch.utils.data.DataLoader(train_data,batch_size=args.batch_size,shuffle=True,pin_memory=True,num_workers=0)

Dataset

我的上一篇博客提到了torchvision.datasets,链接:torchvision中的dataset

那么torchvision.datasets和这里的torch.utils.data.Dataset有什么区别呢?

1.torchvision.datasets

从名字中就可以看到,这个datasets不仅是小写,说明它作用的范围小,而且加了个s,这就说明它是复数,为什么是复数呢,是因为如果我们需要不自己设置训练数据集而去使用官方给的数据集的时候,它里边有。

2.torch.utils.data.Dataset

这个模块更偏重于自己独立的去建立一个新的训练数据集,需要自己去设定参数

代码详解

自定义数据集类必须实现三个函数:__init__, __len__, 和__getitem__

import torch
import torch.utils.data as data
import torchvision
from PIL import Image
import os
import os.pathdef pil_loader(path):with open(path, 'rb') as f:img = Image.open(f)return img.convert('RGB')def gray_loader(path):with open(path, 'rb') as f:img = Image.open(f)return img.convert('L')class IQAdatasets(data.Dataset):def __init__(self, root, index, transform,transform_gray, patch_num):imgpath = []ref_names = []for line1 in open("data/im_names.txt", "r"):line1 = line1.strip()path = os.path.join(root, line1)# print(path)imgpath.append(path)labels = []for line5 in open("data/mos.txt", "r"):line5 = float(line5.strip())labels.append(line5)sample = []for i, item in enumerate(index):for aug in range(patch_num):# print(item)sample.append((imgpath[item ], labels[item ]))self.samples = sampleself.transform = transformself.transform_gray = transform_graydef __getitem__(self, index):path, target = self.samples[index]sample = pil_loader(path)sample = self.transform(sample)sample_gra = pil_loader(path)sample_gray = self.transform_gray(sample_gra)return (sample , sample_gray), targetdef __len__(self):length = len(self.samples)return length

在上述代码中,我们重写了data.Dataset方法,在这里,我返回的主要是三通道图像和灰度图像的数据集,代码中的 data/im_names.txt是我自己从数据集中的mat文件中提取出来的标签信息,如果大家要用其他数据集的话,这里也可以跟我一样提取出来到txt中或者用h5py来提取LivefullInfo.mat等数据集中的标签。

这段代码最主要的地方就在于:

def __getitem__(self, index):path, target = self.samples[index]sample = pil_loader(path)sample = self.transform(sample)sample_gra = pil_loader(path)sample_gray = self.transform_gray(sample_gra)return (sample , sample_gray), target

这里我们返回了(sample , sample_gray), target三类数据,分别是三通道图像、灰度图像、图像标签(在IQA任务中,图像标签就是MOS值、STD值、图像路径等)

以上就是自定义数据集来做IQA任务的读库代码,现在我们来进行调试一下:

if __name__ == '__main__':device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")parser = argparse.ArgumentParser()parser.add_argument('--dataset', dest='dataset', type=str, default='live',help='Support datasets: livec|koniq-10k|bid|live|csiq|tid2013')parser.add_argument('--train_patch_num', dest='train_patch_num', type=int, default=25,help='Number of sample patches from training image')parser.add_argument('--test_patch_num', dest='test_patch_num', type=int, default=25,help='Number of sample patches from testing image')parser.add_argument('--lr', dest='lr', type=float, default=1e-5, help='Learning rate')parser.add_argument('--weight_decay', dest='weight_decay', type=float, default=5e-4, help='Weight decay')parser.add_argument('--lr_ratio', dest='lr_ratio', type=int, default=10,help='Learning rate ratio for hyper network')parser.add_argument('--batch_size', dest='batch_size', type=int, default=32, help='Batch size')parser.add_argument('--epochs', dest='epochs', type=int, default=5, help='Epochs for training')parser.add_argument('--patch_size', dest='patch_size', type=int, default=224,help='Crop size for training & testing image patches')parser.add_argument('--train_test_num', dest='train_test_num', type=int, default=10, help='Train-test times')args = parser.parse_args()sel_num = config.img_num[args.dataset]train_index = sel_num[0:int(round(0.8 * len(sel_num)))]test_index = sel_num[int(round(0.8 * len(sel_num))):len(sel_num)]transforms = torchvision.transforms.Compose([torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.RandomCrop(size=args.patch_size),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),std=(0.229, 0.224, 0.225))])transforms_gray = torchvision.transforms.Compose([torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.RandomCrop(size=args.patch_size),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize(mean=(0.5),std=(0.5))])train_data = IQAdatasets(root=config.folder_path[args.dataset], index=train_index, transform=transforms,transform_gray = transforms_gray , patch_num=args.train_patch_num)train_loader = torch.utils.data.DataLoader(train_data,batch_size=args.batch_size,shuffle=True,pin_memory=True,num_workers=0)for index , (sample , target) in tqdm(enumerate(train_loader)):x_rgb = sample[0].to(device)x_gray = sample[1].to(device)

我这里是对训练数据进行了随机裁剪成224*224的大小并进行归一化。 这段代码最独特的设计在于,我们这里对训练集数据和测试集数据进行了train_patch_num和test_patch_num的设计,多次随机裁剪,可以获取到图像的更多信息。

之后我会继续分享图像平均切割的读库代码,原理都是类似的。

图像质量评价(IQA)读库代码详细介绍相关推荐

  1. mysql 死锁 代码_MySQL死锁问题解决的代码详细介绍

    一次MySQL死锁问题解决 一.环境CentOS, MySQL 5.6.21-70, JPA 问题场景:系统有定时批量更新数据状态操作,每次更新上千条记录,表中总记录数约为500W左右. 二.错误日志 ...

  2. 【机器学习】 - 关于图像质量评价IQA(Image Quality Assessment)

    图像质量评价(Image Quality Assessment,IQA)是图像处理中的基本技术之一,主要通过对图像进行特性分析研究,然后评估出图像优劣(图像失真程度). 主要的目的是使用合适的评价指标 ...

  3. 如何使用Sentinel做流量控制?此文将附代码详细介绍Sentinel几种限流模式

    前言:大家好,我是小威,24届毕业生,在一家满意的公司实习.本篇文章将详细介绍Sentinel的两种限流模式,由于篇幅原因,后续文章将详细介绍Sentinel的其他三种. 如果文章有什么需要改进的地方 ...

  4. 用php做论坛头像代码,详细介绍PHP针对多用户实现头像更换代码示例

    一个网站,其实说白了就是某几个特定功能的组合,而更换用户头像就在这些功能之中.今天就来做个测试,针对不同的用户,实现头像上传功能. 成品图 思路针对不同的用户上传头像,我们要为每一个已登录的用户创建一 ...

  5. php伪静态教程,PHP伪静态的图文代码详细介绍

    前言 关于伪静态的话题,众说纷纭.我不是很在意这些讨论,但是有一些大牛给出的看法确实是很有味道的, 而且也是比较的公正.使用了伪静态的话,会耗费CPU资源,但是对于SEO什么的更加有益: 不适用伪静态 ...

  6. python 相关性分析原理及代码详细介绍

    一.相关性分析简介 相关性分析(correlation analysis)是指对两个或多个具备相关关系的变量进行线性相关分析,从而衡量变量间的相关程度或密切程度.相关性程度即为相关性系数R,R的取值范 ...

  7. php加入购物车怎样实现_php 实现简单加入购物车的图文代码详细介绍

    以下是本站的两部视频教程,欢迎观看 课程简介:<JS和jQuery开发购物车教程>通过JavaScript和jQuery两种方式实现购物车功能. 课程简介:<JavaScript实现 ...

  8. php用什么工具调试代码,详细介绍利用开源的DebugBar工具调试PHP代码(图文)

    DebugBar 是一个免费和开源的应用,能够集成至任何PHP项目中,并收集和展示分析数据. 它有没有任何依赖,支持Ajax请求,包括常用开发库的通用数据采集器和收集器. 相信用过Laravel的调试 ...

  9. 传奇脚本显示服务器开区时间代码,上百种开区脚本代码详细介绍以及脚本示例...

    变量名必须大写: 通用变量: ------------------------- $SERVERNAME //服务器名称 $SERVERIP //服务器IP $WEBSITE //网站 在String ...

最新文章

  1. Git安装配置(Linux)
  2. 真相了 | 敲代码时,程序员戴耳机究竟在听什么?
  3. SQL总结(三)其他查询
  4. OleCommand的SqlText占位符的问题
  5. PostgreSQL源码分析
  6. 学习Python3:201701030
  7. ubuntu16.4中创建帐户
  8. 笔记-项目沟通管理-沟通基本原则
  9. linux gcc 简单使用记录01
  10. 中国数学会副理事长田刚委员:建议从四个方面加强教师队伍建设
  11. redhat 7 防火墙配置
  12. 制作加密狗程序_【火腿DIY】用于SDR应用程序的自定义热键键盘 | 视障人士的选择...
  13. latex表格自动换行
  14. Linux 最小化安装后的主机名与域名的修改
  15. Golang publish module
  16. 第一部分 知己知彼
  17. 计算机组装维护教学工作总结,计算机组装与维护教师工作总结_2
  18. Android通过反射EthernetManager Api设置以太网为静态IP地址或者动态获取IP
  19. # Android实习周记-9.29
  20. numpy 学习汇总5-数组运算 tcy

热门文章

  1. LDO的PSRR测量
  2. 微信h5录音上传到自己的服务器,微信js-sdk 录音功能的示例代码
  3. 【Python自动化测试5】列表与元组知识讲解
  4. MySQL基础知识及其基本相关操作
  5. C语言之运算符 (笔记)
  6. 浅谈网络营销 | 什么是网络营销?
  7. flutter 自定义日历
  8. RTMP协议深度解析:从原理到实践,掌握实时流媒体传输技术
  9. React基础知识点
  10. 射雕英雄传入选北京朝阳区小学图书馆基本书目