工作和学习中设计一个神经网络中经常需要设计一个数据载入器。首先第一件事我们要根据我们的任务要求确定一个数据提供的方法。如我们是一个分类任务,我们就需要读取数据和数据本身对应的标签。

   

1                                                   2

除了分类任务之外当然还有一些图像到图像的任务,如超分辨率重建,图像去噪等任务那么对应的标签就是一张高分辨率的图像或清晰的无噪声图像。

第二件事就是根据我们的数据格式来确定数据的读取方式,以分类为例,每个文件夹下面的图像对应的为一个类别的图像的时候我们可以依次读取每个文件,并将每个文件编码成对应的0到n个类别。可以根据opencv,PIL等库读取图像opencv读取的是BGR格式的numpy数组,而PIL读取的是Image的对象。

import cv2
import PIL.Image as Im
import numpy as npim=cv2.imread('./data_dir')
#转换成rgb
im=cv2.cvtColor(im,cv2.COLOR_BGR2RGB)
#将数据转换成Image对象
im=Im.fromarray(im).convert('RGB')
#Image 直接读取图片
im=Im.open('./data_dir','rgb')
#将Image的对象转换成numpy数组
im=np.asarray(im)

当然你的文件也可能是mat文件或者npy件或者h5py文件:

import scipy.io as si
import h5py
import numpy as np
#读取npy文件
data=np.load('test.npy')
#保存npy文件
np.save('./test.npy',data)
#读取h5py文件
f=h5py.File('./test.h5','r')#以读的方式打开文件可以根据字典的键值获取数据
data=f['data']
#保存h5文件
f=h5py.File('./test.h5','w')
f['data']=im
f['label']=label
f.cloase()
#读取mat文件mat和h5类似都是字典格式
data=si.loadmat('test.mat')
im=data['x']
label=data['y']
#保存mat文件
si.savemat('test.mat',{'x':im,'y':label})

不论是哪种数据格式我们都要考虑一个问题我们的数据量是一个怎样的数量级,如果数据集过大我们没有那么多的内存就会遇到超内存的问题。如果是小数据集我们可以直接一次性读取。大数据一般按照分批次读取或者特殊的数据格式来读取。

import os
import cv2
import numpy as np
#有时候我们需要将图片随机裁剪
def random_crop(image_ref,image_dis,num_output,size):h,w=image_ref.shape[:2]random_h=np.random.randint(h-size,size=num_output)random_w=np.random.randint(w-size,size=num_output) patches_dis=[]patches_ref=[]for i in range(num_output):patch_dis=image_dis[random_h[i]:random_h[i]+size,random_w[i]:random_w[i]+size]patch_ref=image_ref[random_h[i]:random_h[i]+size,random_w[i]:random_w[i]+size]patches_ref.append(patch_ref)patches_dis.append(patch_dis)return patches_ref,patches_disdef read_data(path):file_name=os.listdir(path)#获取所有文件的文件名称data=[]labels=[]for idx,fn in enumerate(file_name):#以idx作为标签如果标签是图片则以另外的函数读取im_dirs=path+'/'+fnim_path=os.listdir(im_dirs)#读取每个文件夹下所有图像的名称for n in im_path:im=cv2.imread(im_dirs+'/'+n)data.append(im)labels.append(idx)return np.asarray(data),np.asarray(labels)
#一次性读取所有的数据
data,labels= read_data(data_dir)
#将数据集乱序num_example=data.shape[0]
arr=np.arange(num_example)
np.random.shuffle(arr)
data=data[arr]
label=label[arr]#将数据集的80%划分为训练集
s=int(num_example*0.8)
x_train=data[:s]
y_train=label[:s]
x_val=data[s:]
y_val=label[s:]#按照批次将数据送入模型中
def minibatches(inputs=None, targets=None, batch_size=None, shuffle=False):assert len(inputs) == len(targets)if shuffle:indices = np.arange(len(inputs))np.random.shuffle(indices)for start_idx in range(0, len(inputs) - batch_size + 1, batch_size):if shuffle:excerpt = indices[start_idx:start_idx + batch_size]else:excerpt = slice(start_idx, start_idx + batch_size)yield inputs[excerpt], targets[excerpt]
for x,y in minibatches(x_train,y_train,128,shuffle=False):feed_dict={x1:x,y1:y}

上面的方法是一次性读取所有数据的,我们有时处理大数据的问题时就需要按照批次来读取了,这里推荐两种方法一种是基于tensorflow的tfrecords文件或者pytorch的Imagefolder两种方法:这里我们以这个数据集为例:http://download.tensorflow.org/example_images/flower_photos.tgz

是一个关于花分类的数据集:

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, modelsdata_dir = 'E:/PytorchData/flower_photos'def load_split_train_test(data_dir,valid_size = 0.2):train_trainsforms = transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor(),])test_trainsforms = transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor(),])train_data = datasets.ImageFolder(datadir,transform=train_trainsforms)test_data = datasets.ImageFolder(datadir,transform=test_trainsforms)num_train = len(train_data)                               # 训练集数量indices = list(range(num_train))                          # 训练集索引split = int(np.floor(valid_size * num_train))             # 获取20%数据作为验证集np.random.shuffle(indices)                                # 打乱数据集from torch.utils.data.sampler import SubsetRandomSamplertrain_idx, test_idx = indices[split:], indices[:split]    # 获取训练集,测试集train_sampler = SubsetRandomSampler(train_idx)            # 打乱训练集,测试集test_sampler  = SubsetRandomSampler(test_idx)#============数据加载器:加载训练集,测试集===================train_loader = DataLoader(train_data,sampler=train_sampler,batch_size=64)test_loader = DataLoader(test_data,sampler=test_sampler,batch_size=64)return train_loader,test_loadertrain_loader,test_loader = load_split_train_test(data_dir, 0.2)
for inputs,labels in train_loader:#这里inputs,和labels输出的Tensor我们想看到输出的结果需要转换成numpy数组inputs,labels=np.asarray(inputs),np.asarray(labels)print(inputs.shape)#在pytorch中我们经常将数据放入到GPU中我们直接打印出来数据时会报错因此,我们需要将数据放入cpu中转换成numpy数组

上述DataLoader中实际上还有很多参数,这里没有列举出来如当内存比较充足的时候可以将pin_memeroy设置成True,将num_worker设置成8等方法可以加速数据的加载。除了pytorch之外还有tensorflow也提供了专门的数据接口,如常用的tfrecords,首先我们需要将自己的数据集保存成tfrecords文件

import os
import tensorflow as tf
from PIL import Image  #注意Image,如果是cv2需要转换成Image对象
import matplotlib.pyplot as plt
import numpy as npdata_dir='E:/PytorchData/flower_photos/'
classes={'1','2','3','4','5'} #将花数据改成1到5个类别
writer= tf.python_io.TFRecordWriter("./flower_classfication.tfrecords") #要生成的文件for index,name in enumerate(classes):class_path=data_dir+name+'/'for img_name in os.listdir(class_path): img_path=class_path+img_name #每一个图片的地址img=Image.open(img_path)img= img.resize((128,128))img_raw=img.tobytes()#将图片转化为二进制格式example = tf.train.Example(features=tf.train.Features(feature={"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))})) #example对象对label和image数据进行封装writer.write(example.SerializeToString())  #序列化为字符串writer.close()

在制作完成我们的数据集后需要读取:

import tensorflow as tf
def read_and_decode(filename): # 读入flower_classfication.tfrecordsfilename_queue = tf.train.string_input_producer([filename])#生成一个queue队列reader = tf.TFRecordReader()_, serialized_example = reader.read(filename_queue)#返回文件名和文件features = tf.parse_single_example(serialized_example,features={'label': tf.FixedLenFeature([], tf.int64),'img_raw' : tf.FixedLenFeature([], tf.string),})#将image数据和label取出来img = tf.decode_raw(features['img_raw'], tf.uint8)img = tf.reshape(img, [128, 128, 3])  #reshape为128*128的3通道图片,必须和保存的分辨率一致 #否则出错,此外如果需要resize需要在下面调用tf.image.resize_images()img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 #在流中抛出img张量,并归一化减去0.5label = tf.cast(features['label'], tf.int32) #在流中抛出label张量return img, label
with tf.Session as sess:a,b=read_and_decode('./flower_classfication.tfrecords)for i inn range(100)img,labels=sess.run([a,b])

关于tf.data数据载入和pytorch类似。存在tf.data.Dataset和tf.data.Iterator这里给出一个简单的使用例子基于tf2.0:

import tensorflow as tf
import numpy as npfrom tensorflow.data import Dataset, Iterator
import os# tf 读取数据
def readim(name1, name2):im1 = tf.io.read_file(name1)im1 = tf.image.decode_png(im1, channels=3)im1 = tf.image.rgb_to_grayscale(im1)im2 = tf.io.read_file(name2)im2 = tf.image.decode_png(im2, channels=3)im2 = tf.image.rgb_to_grayscale(im2)return im1, im2target_dir = "E:/Datasets/hr/"
input_dir = "E:/Datasets/lr/"
#获取数据名构建输入数据和输出数据路径列表
im_list = os.listdir(target_dir)
xs = [input_dir+i for i in im_list]
ys = [target_dir+i for i in im_list]# 创建数据集(小数据集为例用from_tensor_slices)大数据集还得用tfrecord
dataset = Dataset.from_tensor_slices((xs, ys))
# 载入数据和数据增扩等方法用map实现, num_parallel_calls相当于设置线程数
dataset = dataset.map(readim, num_parallel_calls=4)
# 设置batch_size等参数,如需要加速可设置prefetch(GPU训练CPU载数据)等
dataset = dataset.batch(32)
dataset = dataset.prefetch(buffer_size=tf.data.expermental.AUTOTUNE)dataset = dataset.repeat(1)
iterator = iter(dataset)while True:x, y = next(iterator)x = np.asarray(x)print(x.shape, y.shape)

参考博客:PyTorch之—图像分类一(每个类对应一个文件夹)_SongpingWang的博客-CSDN博客

tf.data

tensorflow 1.0 学习:用CNN进行图像分类 - denny402 - 博客园

python读取图像数据的一些方法相关推荐

  1. python读取图像的几种方法

    方法一:利用PIL中的Image函数,这个函数读取出来不是array格式 这时候需要用 np.asarray(im) 或者np.array()函数 区别是 np.array() 是深拷贝,np.asa ...

  2. python 读取图像的几种方法

    方法一:利用PIL中的Image函数,这个函数读取出来不是array格式 这时候需要用 np.asarray(im) 或者np.array()函数 区别是 np.array() 是深拷贝,np.asa ...

  3. python读取 pcd 数据 三种方法

    代码在git import open3d as o3d import numpy as npdef read_pcd(file_path):pcd = o3d.io.read_point_cloud( ...

  4. halcon边缘提取颜色相近_初学者福利!三种用Python从图像数据中提取特征的技术...

    全文共4073字,预计学习时长8分钟 你之前是否使用过图像数据?也许你想建立自己的物体检测模型,或者仅仅是想统计走进某栋建筑物的人数,使用计算机视觉技术处理图像拥有无穷无尽的可能性. 但数据科学家最近 ...

  5. python批量读取grib_Windows下Python读取GRIB数据

    之前写了一篇<基于Python的GRIB数据可视化>的文章,好多博友在评论里问我Windows系统下如何读取GRIB数据,在这里我做一下说明. 一.在Windows下Python为什么无法 ...

  6. python读取grib文件_Windows下Python读取GRIB数据

    之前写了一篇<基于Python的GRIB数据可视化>的文章,好多博友在评论里问我Windows系统下如何读取GRIB数据,在这里我做一下说明. 一.在Windows下Python为什么无法 ...

  7. python读取图像数据流_浅谈TensorFlow中读取图像数据的三种方式

    本文面对三种常常遇到的情况,总结三种读取数据的方式,分别用于处理单张图片.大量图片,和TFRecorder读取方式.并且还补充了功能相近的tf函数. 1.处理单张图片 我们训练完模型之后,常常要用图片 ...

  8. python读取mat数据_Python几种读取mat格式数据的方法,python几种读取mat

    Python几种读取mat格式数据的方法,python几种读取mat matlab中使用的数据一般会以mat的格式存储,用python读取有以下几种方法 1.使用scipy,具体实现如下: impor ...

  9. Kinect V1读取图像数据(For Windows)

    Kinect V1读取图像数据(For Windows) 这篇博客 Kinect V1介绍 数据读取的基本流程 运行代码和注释 结尾 这篇博客  刚好有一台现成的Kinect V1相机,所以就拿过来学 ...

最新文章

  1. iOS设计模式 - 组合
  2. 【Python 必会技巧】对字典按照键(key)或者值(value)排序
  3. 加密芯片算法移植方案的优点
  4. PHP表单header post get
  5. showModalDialog和showModelessDialog中提交form不弹出新窗口
  6. ArcGIS 如何卸载再重装
  7. android怎么执行命令,Android程序中执行adb命令
  8. ios微信消息自动朗读_如何使您的iOS设备大声朗读文章,书籍和更多内容
  9. firefox插件开发和调试
  10. 服务器磁盘IO是什么意思?SATA和固态硬盘的性能差异
  11. 打字游戏之主界面实现
  12. java怎么没有jmf包_java JMF
  13. New UWP Community Toolkit - DeveloperTools
  14. 【练习记录】C语言实现正则表达式匹配
  15. Linux IPC:命名管道的使用
  16. linux c++ 守护 程序,supervisor守护进程 | C/C++程序员之家
  17. 计算机连接游戏手柄,电脑如何使用手柄_电脑怎么连手柄打游戏-系统城
  18. DSP中的EDMA是什么?
  19. 使用tabula-java解析pdf的表格生成csv,再用opencsv读取csv
  20. 显卡里面都有什么东西,看显卡好坏就看这些参数

热门文章

  1. python nonetype_Python NoneType类型
  2. 外星人在中国买苹果的故事
  3. jsp java方法调用_jsp怎么调用java方法
  4. git 安装包 最新 下载 快速 国内 镜像 地址
  5. p5.js 我的第一幅码绘——小丑
  6. Chrome 解决: 您目前无法访问 因为此网站使用了 HSTS。网络错误和攻击通常是暂时的,因此,此网页稍后可能会恢复正常。
  7. 任何事都不要指望别人
  8. 老卫带你学---动态语言和静态语言的区别
  9. 计算机删除网络对象,恢复故障转移群集中已删除的计算机对象 - Windows Server | Microsoft Docs...
  10. 支付宝公钥、私钥和沙箱环境的配置