Python3.0导入CIFAR-10数据集

  • 前言
  • load_CIFAR10函数
  • 调用cifar-10的数据集的数据
  • 结果

前言

在搞完本科毕业设计之后终于有时间可以重新开始年前的关于深度学习的一些知识的学习了,年前其实一直因为python的版本问题导致无法很好的解决CIFAR-10数据集的导入问题(windows10)。在这段时间的学习后,终于可以攻克这一难题了,现将经验总结如下。

本次笔记采用参考书是《深度学习实战》,杨云,杜飞著-北京:清华大学出版社,2018版本

实现的目的:成功在python3.0以上版本中实现对CIFAR-10数据集的导入工作。

load_CIFAR10函数

​ 将load_CIFAR10(root)函数封装在名为data_utils.py的模块库中,置于python的默认路径之下:

首先得下好了imageio,numpy等包,此项操作在pycharm中较好实现

将以下代码封装至data_utils.py文件中作为模块文件

import pickle
import numpy as np
import os
from imageio import imreaddef load_CIFAR_batch(filename):with open(filename, 'rb') as f:datadict = pickle.load(f,encoding='iso-8859-1')X = datadict['data']Y = datadict['labels']X = X.reshape(10000, 3, 32, 32).transpose(0,2,3,1).astype("float")Y = np.array(Y)return X, Ydef load_CIFAR10(ROOT):xs = []ys = []for b in range(1,6):f = os.path.join(ROOT, 'data_batch_%d' % (b, ))X, Y = load_CIFAR_batch(f)xs.append(X)ys.append(Y)    Xtr = np.concatenate(xs)Ytr = np.concatenate(ys)del X, YXte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))return Xtr, Ytr, Xte, Ytedef get_CIFAR10_data(num_training=49000, num_validation=1000, num_test=1000):cifar10_dir = 'datasets/cifar-10-batches-py'X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)mask = range(num_training, num_training + num_validation)X_val = X_train[mask]y_val = y_train[mask]mask = range(num_training)X_train = X_train[mask]y_train = y_train[mask]mask = range(num_test)X_test = X_test[mask]y_test = y_test[mask]mean_image = np.mean(X_train, axis=0)X_train -= mean_imageX_val -= mean_imageX_test -= mean_imageX_train = X_train.transpose(0, 3, 1, 2).copy()X_val = X_val.transpose(0, 3, 1, 2).copy()X_test = X_test.transpose(0, 3, 1, 2).copy()return {'X_train': X_train, 'y_train': y_train,'X_val': X_val, 'y_val': y_val,'X_test': X_test, 'y_test': y_test,}def load_tiny_imagenet(path, dtype=np.float32):with open(os.path.join(path, 'wnids.txt'), 'r') as f:wnids = [x.strip() for x in f]wnid_to_label = {wnid: i for i, wnid in enumerate(wnids)}with open(os.path.join(path, 'words.txt'), 'r') as f:wnid_to_words = dict(line.split('\t') for line in f)for wnid, words in wnid_to_words.iteritems():wnid_to_words[wnid] = [w.strip() for w in words.split(',')]class_names = [wnid_to_words[wnid] for wnid in wnids]X_train = []y_train = []for i, wnid in enumerate(wnids):if (i + 1) % 20 == 0:print ('loading training data for synset %d / %d' % (i + 1, len(wnids)))boxes_file = os.path.join(path, 'train', wnid, '%s_boxes.txt' % wnid)with open(boxes_file, 'r') as f:filenames = [x.split('\t')[0] for x in f]num_images = len(filenames)X_train_block = np.zeros((num_images, 3, 64, 64), dtype=dtype)y_train_block = wnid_to_label[wnid] * np.ones(num_images, dtype=np.int64)for j, img_file in enumerate(filenames):img_file = os.path.join(path, 'train', wnid, 'images', img_file)img = imread(img_file)if img.ndim == 2:img.shape = (64, 64, 1)X_train_block[j] = img.transpose(2, 0, 1)X_train.append(X_train_block)y_train.append(y_train_block)X_train = np.concatenate(X_train, axis=0)y_train = np.concatenate(y_train, axis=0)with open(os.path.join(path, 'val', 'val_annotations.txt'), 'r') as f:img_files = []val_wnids = []for line in f:img_file, wnid = line.split('\t')[:2]img_files.append(img_file)val_wnids.append(wnid)num_val = len(img_files)y_val = np.array([wnid_to_label[wnid] for wnid in val_wnids])X_val = np.zeros((num_val, 3, 64, 64), dtype=dtype)for i, img_file in enumerate(img_files):img_file = os.path.join(path, 'val', 'images', img_file)img = imread(img_file)if img.ndim == 2:img.shape = (64, 64, 1)X_val[i] = img.transpose(2, 0, 1)img_files = os.listdir(os.path.join(path, 'test', 'images'))X_test = np.zeros((len(img_files), 3, 64, 64), dtype=dtype)for i, img_file in enumerate(img_files):img_file = os.path.join(path, 'test', 'images', img_file)img = imread(img_file)if img.ndim == 2:img.shape = (64, 64, 1)X_test[i] = img.transpose(2, 0, 1)y_test = Noney_test_file = os.path.join(path, 'test', 'test_annotations.txt')if os.path.isfile(y_test_file):with open(y_test_file, 'r') as f:img_file_to_wnid = {}for line in f:line = line.split('\t')img_file_to_wnid[line[0]] = line[1]y_test = [wnid_to_label[img_file_to_wnid[img_file]] for img_file in img_files]y_test = np.array(y_test)return class_names, X_train, y_train, X_val, y_val, X_test, y_testdef load_models(models_dir):models = {}for model_file in os.listdir(models_dir):with open(os.path.join(models_dir, model_file), 'rb') as f:try:models[model_file] = pickle.load(f)['model']except pickle.UnpicklingError:continuereturn models

随后再将其封装到文件夹utils中,存放与python的路径之下:

调用cifar-10的数据集的数据

下载链接:

cifar-10数据集下载网站

下载解压完的数据保存至python的路径如下:

'D:\Anaconda3\envs\PythonExamples\Lib\cifar-10-batches-py'

以下是显示数据信息的脚本:

import numpy as np
import random
from utils.data_utils import load_CIFAR10
from classifiers.chapter2 import *
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = (10.0,8.0)
cifar10_dir = 'D:\Anaconda3\envs\PythonExamples\Lib\cifar-10-batches-py'
X_train,Y_train,X_test,Y_test = load_CIFAR10(cifar10_dir)
#以下是数据可视化操作
classes = ['plane','car','bird','cat','deer','dog','frog','horse','ship']
num_classes = len(classes)
sample_per_classes = 7
for y,cls in enumerate(classes):idxs = np.flatnonzero(Y_train == y)idxs = np.random.choice(idxs,sample_per_classes,replace=False)for i,idx in enumerate(idxs):plt_idx = i*num_classes+y+1plt.subplot(sample_per_classes,num_classes,plt_idx)plt.imshow(X_train[idx].astype('uint8'))plt.axis('off')if i == 0:plt.title(cls)
plt.show()

结果

tensorflow 入门笔记(二)相关推荐

  1. tensorflow学习笔记二——建立一个简单的神经网络拟合二次函数

    tensorflow学习笔记二--建立一个简单的神经网络 2016-09-23 16:04 2973人阅读 评论(2) 收藏 举报  分类: tensorflow(4)  目录(?)[+] 本笔记目的 ...

  2. MySql入门笔记二~悲催的用户

    这些是当年小弟的MySql学习笔记,木有多么复杂的结构操作,木有多炫丽的语句开发,木有...总之就是木有什么技术含量... 日复一日,彪悍的人生伴随着彪悍的健忘,运维操作为王,好记性不如烂笔头,山水有 ...

  3. (转)tensorflow入门教程(二十六)人脸识别(上)

    https://blog.csdn.net/rookie_wei/article/details/81676177 1.概述 查看全文 http://www.taodudu.cc/news/show- ...

  4. TensorFlow入门笔记

    TensorFlow 入门笔记 (1)(个人学习使用) 环境配置 Ubuntu16.04(VMware Workstation Pro 14) Python2.7 TensorFlow1.3.0 Te ...

  5. 机器学习入门 笔记(二) 机器学习基础概念

    第二章 机器学习基础概念 1.机器的数据 2.机器学习的主要任务 3.监督学习和非监督学习 4.批量.在线学习.参数.非参数学习 5.哲学思考 6.环境的搭建 1.机器的数据 我们以鸢尾花的数据为例. ...

  6. 海思入门笔记二:HiBurn工具实现镜像烧写

    这里写自定义目录标题 海思入门笔记:HiBurn工具实现镜像烧写 第一步:裸板可使用串口先烧录boot(速度慢) 第二步:已烧好boot的板子,可使用USB快速烧录 海思入门笔记:HiBurn工具实现 ...

  7. 区块链安全入门笔记(二) | 慢雾科普

    虽然有着越来越多的人参与到区块链的行业之中,然而由于很多人之前并没有接触过区块链,也没有相关的安全知识,安全意识薄弱,这就很容易让攻击者们有空可钻.面对区块链的众多安全问题,慢雾特推出区块链安全入门笔 ...

  8. tensorflow+入门笔记︱基本张量tensor理解与tensorflow运行结构与相关报错

    欢迎登陆官网(附https://tensorflow.google.cn/)了解更多 TensorFlow 内容,也可关注 TensorFlow 官方公众号获取更多资讯. Gokula Krishna ...

  9. TensorFlow入门之二:tensorflow手写数字识别

    一.基础知识 基础知识可以跳过,可以直接看后面的代码实现 MNIST数据集 MNIST数据集的官网是Yann LeCun's website.可以使用下面的python代码自动下载数据集. #已经下载 ...

最新文章

  1. 红米手机使用应用沙盒一键修改imsi信息
  2. python开发需要掌握哪些知识-Python的10个基础知识点,新手必须背下来!
  3. tensorflow常见问题
  4. MyBitis(iBitis)系列随笔之二:类型别名(typeAliases)与表-对象映射(ORM)
  5. app端微信支付(二) - 生成预付单
  6. 简明Vim练级攻略(初学者)
  7. Python成长之路【第七篇】:Python基础之装饰器
  8. 浅谈项目管理中的四要素
  9. 学成在线--22.课程营销
  10. 【python】1. 两数之和
  11. cannot use a string pattern on a bytes-like object(bytes与str互转)
  12. Outlook2010新建域内Exchang邮箱的另一种方法
  13. python read函数参数_最新Pandas.read_excel()全参数详解(案例实操,如何利用python导入excel)...
  14. 十大排序算法——堆排序(C语言)
  15. 请求支付宝渠道报错:40006,Insufficient Permissions,ISV权限不足
  16. 三菱q系列plc连接电脑步骤_三菱plc连接电脑步骤
  17. Matlab-图片上画线
  18. 谈谈java中封装的那点事
  19. 网间数据摆渡如何轻松实现数据安全交换
  20. 蚂蚁系统案例2【无标题】

热门文章

  1. 机器学习总结(一):线性回归、岭回归、Lasso回归
  2. 【EOS】2.3 深入理解ABI文件
  3. kA*与(kA)*的行列式计算
  4. onenote 不能同步的原因及解决方法(教训总结)
  5. numpy保存和读取dictionary字典
  6. Linux 基本命令(三)--histroy 常用命令详解
  7. 二叉树的遍历 《算法导论》10.4-1~10.4-3 10.4-5
  8. Hadoop Hive与Hbase关系 整合
  9. Ogre1.8.1 Basic Tutorial 6 - The Ogre Startup Sequence
  10. linux中断处理体系结构