im2rec.py解读

直接给代码,注释我写上了。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.from __future__ import print_function
import os
import syscurr_path = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(curr_path, "../python"))
import mxnet as mx
import random
import argparse
import cv2
import time
import tracebacktry:import multiprocessing
except ImportError:multiprocessing = None#一个生成器,生成(图片序号,路径,图片标签)
def list_image(root, recursive, exts):   #root图片根目录  recursive是否递归遍历(递归遍历根目录下的子目录并为每个文件夹中的图像指定一个唯一的标签)  exts支持图片类型i = 0if recursive:                   #判递归   是:递归根目录下每个文件夹,每个文件夹下的图片一个标签   否:不递归,同一标签0cat = {}for path, dirs, files in os.walk(root, followlinks=True):     #递归遍历图像目录 (os.walk,文件遍历器)dirs.sort()files.sort()for fname in files:        #遍历图像文件fpath = os.path.join(path, fname)suffix = os.path.splitext(fname)[1].lower()       #获取文件扩展名(先分离文件命和文件扩展名,然后拿扩展名)if os.path.isfile(fpath) and (suffix in exts):    #判:是否文件和是否为支持扩展名if path not in cat:                           #将不重复的路径添加到字典中cat[path] = len(cat)yield (i, os.path.relpath(fpath, root), cat[path])    #生成(图片序号,路径,图片标签)i += 1for k, v in sorted(cat.items(), key=lambda x: x[1]):print(os.path.relpath(k, root), v)else:for fname in sorted(os.listdir(root)):fpath = os.path.join(root, fname)suffix = os.path.splitext(fname)[1].lower()if os.path.isfile(fpath) and (suffix in exts):yield (i, os.path.relpath(fpath, root), 0)         #生成(图片序号,路径,图片标签0)i += 1#负责编写.lst的内容
def write_list(path_out, image_list):with open(path_out, 'w') as fout:for i, item in enumerate(image_list):line = '%d\t' % item[0]for j in item[2:]:line += '%f\t' % jline += '%s\n' % item[1]fout.write(line)#生成xxx.lst文件,制作.lst的主函数
def make_list(args):image_list = list_image(args.root, args.recursive, args.exts)image_list = list(image_list)if args.shuffle is True:          #是否打乱random.seed(100)random.shuffle(image_list)N = len(image_list)chunk_size = (N + args.chunks - 1) // args.chunks       #args.chunks(块数)  chunk_size(块大小)for i in range(args.chunks):chunk = image_list[i * chunk_size:(i + 1) * chunk_size]if args.chunks > 1:str_chunk = '_%d' % ielse:str_chunk = ''sep = int(chunk_size * args.train_ratio)sep_test = int(chunk_size * args.test_ratio)if args.train_ratio == 1.0:                        #全训练集的情况write_list(args.prefix + str_chunk + '.lst', chunk)else:                        #划分训练集,验证集,测试集if args.test_ratio:write_list(args.prefix + str_chunk + '_test.lst', chunk[:sep_test])if args.train_ratio + args.test_ratio < 1.0:write_list(args.prefix + str_chunk + '_val.lst', chunk[sep_test + sep:])write_list(args.prefix + str_chunk + '_train.lst', chunk[sep_test:sep_test + sep])#读取.lst文件
def read_list(path_in):with open(path_in) as fin:while True:line = fin.readline()if not line:breakline = [i.strip() for i in line.strip().split('\t')]line_len = len(line)if line_len < 3:print('lst should at least has three parts, but only has %s parts for %s' %(line_len, line))continuetry:item = [int(line[0])] + [line[-1]] + [float(i) for i in line[1:-1]]except Exception as e:print('Parsing lst met error for %s, detail: %s' %(line, e))continueyield item#对图像进行编码或裁剪
def image_encode(args, i, item, q_out):fullpath = os.path.join(args.root, item[1])if len(item) > 3 and args.pack_label:          #判断是否将多维标签保存header = mx.recordio.IRHeader(0, item[2:], item[0], 0)else:header = mx.recordio.IRHeader(0, item[2], item[0], 0)if args.pass_through:           #是否跳过转换并按原样保存图像try:with open(fullpath, 'rb') as fin:img = fin.read()s = mx.recordio.pack(header, img)      #打包q_out.put((i, s, item))     #输出except Exception as e:traceback.print_exc()print('pack_img error:', item[1], e)q_out.put((i, None, item))returntry:      #不跳过编码img = cv2.imread(fullpath, args.color)except:traceback.print_exc()print('imread error trying to load file: %s ' % fullpath)q_out.put((i, None, item))returnif img is None:print('imread read blank (None) image for file: %s' % fullpath)q_out.put((i, None, item))returnif args.center_crop:                    #是否裁剪中心图像以使其为矩形if img.shape[0] > img.shape[1]:margin = (img.shape[0] - img.shape[1]) // 2;img = img[margin:margin + img.shape[1], :]else:margin = (img.shape[1] - img.shape[0]) // 2;img = img[:, margin:margin + img.shape[0]]if args.resize:                         #调整图片大小(将图像的较短边缘调整为新大小)if img.shape[0] > img.shape[1]:newsize = (args.resize, img.shape[0] * args.resize // img.shape[1])else:newsize = (img.shape[1] * args.resize // img.shape[0], args.resize)img = cv2.resize(img, newsize)try:s = mx.recordio.pack_img(header, img, quality=args.quality, img_fmt=args.encoding)     #打包(img_fmt:图片编码,jpg或png) (quality:编码质量,jpge质量保存率,png压缩率)q_out.put((i, s, item))             #输出except Exception as e:traceback.print_exc()print('pack_img error on file: %s' % fullpath, e)q_out.put((i, None, item))return#读取图片数据函数,读取制作rec需要的图片(在使用多线程会用到)
def read_worker(args, q_in, q_out):while True:deq = q_in.get()if deq is None:breaki, item = deqimage_encode(args, i, item, q_out)#写函数,制作recIO(在使用多线程会用到)
def write_worker(q_out, fname, working_dir):pre_time = time.time()count = 0fname = os.path.basename(fname)fname_rec = os.path.splitext(fname)[0] + '.rec'fname_idx = os.path.splitext(fname)[0] + '.idx'record = mx.recordio.MXIndexedRecordIO(os.path.join(working_dir, fname_idx),os.path.join(working_dir, fname_rec), 'w')buf = {}more = Truewhile more:deq = q_out.get()if deq is not None:i, s, item = deqbuf[i] = (s, item)else:more = Falsewhile count in buf:s, item = buf[count]del buf[count]if s is not None:record.write_idx(item[0], s)if count % 1000 == 0:cur_time = time.time()print('time:', cur_time - pre_time, ' count:', count)pre_time = cur_timecount += 1#脚本参数定义
def parse_args():parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,description='Create an image list or \make a record database by reading from an image list')parser.add_argument('prefix', help='prefix of input/output lst and rec files.')      #输出文件(.rec或.lst)的前缀,如 xxx.rec  xxx.lstparser.add_argument('root', help='path to folder containing images.')        #root 图片文件夹cgroup = parser.add_argument_group('Options for creating image lists')cgroup.add_argument('--list', action='store_true',help='If this is set im2rec will create image list(s) by traversing root folder\and output to <prefix>.lst.\Otherwise im2rec will read <prefix>.lst and create a database at <prefix>.rec')cgroup.add_argument('--exts', nargs='+', default=['.jpeg', '.jpg', '.png'],           #可接受图像扩展名help='list of acceptable image extensions.')cgroup.add_argument('--chunks', type=int, default=1, help='number of chunks.')cgroup.add_argument('--train-ratio', type=float, default=1.0,help='Ratio of images to use for training.')cgroup.add_argument('--test-ratio', type=float, default=0,help='Ratio of images to use for testing.')cgroup.add_argument('--recursive', action='store_true',help='If true recursively walk through subdirs and assign an unique label\to images in each folder. Otherwise only include images in the root folder\and give them label 0.')cgroup.add_argument('--no-shuffle', dest='shuffle', action='store_false',help='If this is passed, \im2rec will not randomize the image order in <prefix>.lst')rgroup = parser.add_argument_group('Options for creating database')rgroup.add_argument('--pass-through', action='store_true',help='whether to skip transformation and save image as is')rgroup.add_argument('--resize', type=int, default=0,help='resize the shorter edge of image to the newsize, original images will\be packed by default.')rgroup.add_argument('--center-crop', action='store_true',help='specify whether to crop the center image to make it rectangular.')rgroup.add_argument('--quality', type=int, default=95,help='JPEG quality for encoding, 1-100; or PNG compression for encoding, 1-9')rgroup.add_argument('--num-thread', type=int, default=1,help='number of thread to use for encoding. order of images will be different\from the input list if >1. the input list will be modified to match the\resulting order.')rgroup.add_argument('--color', type=int, default=1, choices=[-1, 0, 1],help='specify the color mode of the loaded image.\1: Loads a color image. Any transparency of image will be neglected. It is the default flag.\0: Loads image in grayscale mode.\-1:Loads image as such including alpha channel.')rgroup.add_argument('--encoding', type=str, default='.jpg', choices=['.jpg', '.png'],help='specify the encoding of the images.')rgroup.add_argument('--pack-label', action='store_true',help='Whether to also pack multi dimensional label in the record file')args = parser.parse_args()args.prefix = os.path.abspath(args.prefix)args.root = os.path.abspath(args.root)return argsif __name__ == '__main__':args = parse_args()if args.list:              #制作.lst文件make_list(args)else:                      #制作.rec文件if os.path.isdir(args.prefix):      #获取工作目录working_dir = args.prefixelse:working_dir = os.path.dirname(args.prefix)files = [os.path.join(working_dir, fname) for fname in os.listdir(working_dir)            #列表表达式(文件路径)if os.path.isfile(os.path.join(working_dir, fname))]count = 0for fname in files:          #在工作目录下利用.lst文件生成.rec文件if fname.startswith(args.prefix) and fname.endswith('.lst'):print('Creating .rec file from', fname, 'in', working_dir)count += 1image_list = read_list(fname)              # 读取.lst文件# -- write_record -- #if args.num_thread > 1 and multiprocessing is not None:    # 多线程生成.rec文件q_in = [multiprocessing.Queue(1024) for i in range(args.num_thread)]q_out = multiprocessing.Queue(1024)read_process = [multiprocessing.Process(target=read_worker, args=(args, q_in[i], q_out)) \for i in range(args.num_thread)]for p in read_process:               # 启动读线程p.start()write_process = multiprocessing.Process(target=write_worker, args=(q_out, fname, working_dir)) # 以write_worker函数(方法)设置线程write_process.start()         # 启动写线程for i, item in enumerate(image_list):q_in[i % len(q_in)].put((i, item))for q in q_in:q.put(None)for p in read_process:p.join()q_out.put(None)write_process.join()else:print('multiprocessing not available, fall back to single threaded encoding')try:import Queue as queueexcept ImportError:import queueq_out = queue.Queue()             # 输出队列fname = os.path.basename(fname)fname_rec = os.path.splitext(fname)[0] + '.rec'fname_idx = os.path.splitext(fname)[0] + '.idx'record = mx.recordio.MXIndexedRecordIO(os.path.join(working_dir, fname_idx),os.path.join(working_dir, fname_rec), 'w')cnt = 0pre_time = time.time()for i, item in enumerate(image_list):image_encode(args, i, item, q_out)       # 根据参数,编码和裁剪图片if q_out.empty():continue_, s, _ = q_out.get()record.write_idx(item[0], s)if cnt % 1000 == 0:               # 执行进度打印cur_time = time.time()print('time:', cur_time - pre_time, ' count:', cnt)pre_time = cur_timecnt += 1if not count:print('Did not find and list file with prefix %s'%args.prefix)

im2rec.py代码解读相关推荐

  1. 图像分割套件PaddleSeg全面解析(一)train.py代码解读

    首先祝贺百度团队百度斩获NeurIPS2020挑战赛冠军,https://www.jiqizhixin.com/articles/2020-12-09-2. 在此次比赛中使用的是基于飞桨深度学习框架开 ...

  2. deepsort代码解读

    YOLOV5-DEEPSORT 代码解读 文件格式如下: 下面分别叙述每个文件的内容: detection.py 代码解读 # vim: expandtab:ts=4:sw=4 import nump ...

  3. 飞桨PP-HumanSeg本地实时视频推理代码解读

    文章同样发布在百度AIStudio,Fork后即可在线运行,请点击这里 本人希望基于PaddleSeg对视频实时进行图像分割,但在AiStudio中检索分割和实时两个关键词后并没有得到理想的结果,大部 ...

  4. Faceboxes pytorch代码解读(一) box_utils.py(上篇)

    Faceboxes pytorch代码解读(一) box_utils.py(上篇) 有幸读到Shifeng Zhang老师团队的人脸检测论文,感觉对自己的人脸学习论文十分有帮助.通过看别人的paper ...

  5. MAML-RL Pytorch 代码解读 (6) -- maml_rl/envs/bandit.py

    MAML-RL Pytorch 代码解读 (6) – maml_rl/envs/bandit.py 文章目录 MAML-RL Pytorch 代码解读 (6) -- maml_rl/envs/band ...

  6. 装逼一步到位!GauGAN代码解读来了

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:游璐颖,福州大学,Datawhale成员 AI神笔马良 如何装逼一 ...

  7. BERT:代码解读、实体关系抽取实战

    目录 前言 一.BERT的主要亮点 1. 双向Transformers 2.句子级别的应用 3.能够解决的任务 二.BERT代码解读 1. 数据预处理 1.1 InputExample类 1.2 In ...

  8. 基于SegNet和UNet的遥感图像分割代码解读

    基于SegNet和UNet的遥感图像分割代码解读 目录 基于SegNet和UNet的遥感图像分割代码解读 前言 概述 代码框架 代码细节分析 划分数据集gen_dataset.py UNet模型训练u ...

  9. VGAE(Variational graph auto-encoders)论文及代码解读

    一,论文来源 论文pdf Variational graph auto-encoders 论文代码 github代码 二,论文解读 理论部分参考: Variational Graph Auto-Enc ...

最新文章

  1. 数据结构:最大子序列和
  2. 怎么样才算是精通 Python?
  3. iview图表_【技术博客】iview常用工具记录
  4. 短视频自研还是选择第三方?技术选型前必看的自检清单
  5. Python:常用模块简介(1)
  6. 论文浅尝 | TANDA: Transfer and Adapt Pre-Trained Transformer Models
  7. android图片适配到裁剪框,Android图片剪裁-调用系统实现,完美适配魅族等机型
  8. 统计订单:复选+全选+计算 的列表
  9. 使用NPOI——C#和WEB API导出到Excel
  10. 滞后问题_富锂正极材料的电压滞后问题
  11. c语言mppt例子,mppt太阳能控制器电路原理
  12. QT 跨平台 代码框架
  13. 28岁功能测试被辞,最后结局令人感慨...
  14. 声声慢 - 程序人生
  15. 2017第八届蓝桥杯决赛(大学B组)java试题 瓷砖样式
  16. html5 语音输入小话筒,HTML5语音输入方法
  17. [名词解释] PATA和SATA I
  18. RayScan漏扫工具
  19. 与计算机相关的创意网名,最好的网名昵称大全_好听又有创意的网名
  20. Android图片海报制作-自定义文字排版控件组件

热门文章

  1. 关于Kingfisher--备用
  2. substringToIndex substringFromIndex
  3. flex module不编译的问题
  4. MindManager: Draw your own MindMap!
  5. C# 调Win32 API SendMessage简单用法及wMsg常量
  6. Java与C#事件处理详细对比
  7. java中appletviewer是什么意思_Java开发网 - 请教,appletviewer的问题
  8. Django中一个项目使用多个数据库(原生sql 的使用,亲测)
  9. nginx集群tomcat,session共享问题
  10. 全排列的生成算法:字典序法