在构建tensorflow模型过程中,可谓是曲折颇多,一些教程上教会了我们如何使用下载的现成数据集,但却没有提及如何构建自己的数据集。我自己在学习过程中也走了不少弯路,希望这一系列的博客能解决大家的一些困惑。

我们本地构建数据集主要是以下几个步骤

1.数据处理

2.数据增强

3.数据导入

4.构建模型

5.训练模型

这篇先讲一下数据处理的一些操作,后面的步骤会慢慢发出来。

1.导入第三方库

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import math
import pathlib
import random
import matplotlib.pyplot as plt
import numpy as np

这里会注意到,我在导入os库时,在后面加了

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

这句话的作用是避免报错:This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN)

2.导入数据路径

data_root = pathlib.Path('./image')
all_image_paths = list(data_root.glob('*/*'))
all_image_paths = [str(path) for path in all_image_paths]

我这里的./image是我本地图片集所在的文件夹,image文件夹下是两个分别保存不同种类图片的文件夹,因为我这里是做二分类,所以只有两个不同种类的文件夹,如果大家需要构建识别多种图片的模型,可以添加其他文件夹。

3.随机打乱图片,这一步的目的是为了让图片集去特殊化,提高模型的准确率,因为如果你的图片中有比较相近的,而且数量比较多,会影响模型的学习。这一步是调用了random的shuffle,传入图片集列表,随机打乱。

random.shuffle(all_image_paths)

4.构建标签及索引

其实是构建了一个字典

#列出标签
label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())
#为标签分配索引
label_to_index = dict((name, index) for index, name in enumerate(label_names))
#创建列表,存放标签和索引
all_image_labels = [label_to_index[pathlib.Path(path).parent.name]for path in all_image_paths]

5.加载和格式化图片

我们可以看到,tf.image.decode_jpeg(image,channels=3,这句话的作用是把图片变成三通道图,即RGB式图片。需要强调一下,tf.image.resize()这个小东西好用的很,可以把你的图片统一大小,这在后面我们训练模型是必须的,统一大小的图片更有利于我们的模型学习。而image/255.0是为了使图像进行归一化,得到的数值范围为[0, 1],彩色图片会变成灰图。

load_and_prepro_image()这个函数就是读取传入路径的图片集,然后返回值是经过了preprocess_image 这个函数的调用,将返回的图片处理为灰度图,比较简单暴力。

#加载和格式化图片
def preprocess_image(image):image = tf.image.decode_jpeg(image,channels=3)image = tf.image.resize(image,[192,192])image /= 255.0return imagedef load_and_prepro_image(path):image = tf.io.read_file(path)return preprocess_image(image)
for i in range(len(all_image_paths)):image_path = all_image_paths[i]label = all_image_labels[i]plt.imshow(load_and_prepro_image(image_path))plt.grid(False)plt.xlabel(image_path)plt.title(label_names[label].title())#plt.show()

然后关于这个for循环,其实不是必须的,只是为了方便我们检查图片的处理效果,调用的库是matplotlib,python比较有名的绘图库。

就先到这,后会有期。

下面是全部源码,tensorflow版本是2.5,py版本3.7,cuda11.6。

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import math
import pathlib
import random
import matplotlib.pyplot as plt
import numpy as np#数据处理
#导入数据路径
data_root = pathlib.Path('./image')
all_image_paths = list(data_root.glob('*/*'))
all_image_paths = [str(path) for path in all_image_paths]
#随机打乱图片
random.shuffle(all_image_paths)
#列出标签
label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())
#为标签分配索引
label_to_index = dict((name, index) for index, name in enumerate(label_names))
#创建列表,存放标签和索引
all_image_labels = [label_to_index[pathlib.Path(path).parent.name]for path in all_image_paths]
#加载和格式化图片
def preprocess_image(image):image = tf.image.decode_jpeg(image,channels=3)image = tf.image.resize(image,[192,192])image = image/255.0return imagedef load_and_prepro_image(path):image = tf.io.read_file(path)return preprocess_image(image)
for i in range(len(all_image_paths)):image_path = all_image_paths[i]label = all_image_labels[i]plt.imshow(load_and_prepro_image(image_path))plt.grid(False)plt.xlabel(image_path)plt.title(label_names[label].title())#plt.show()

tensorflow导入自己的数据集相关推荐

  1. 基于tensorflow+RNN的MNIST数据集手写数字分类

    2018年9月25日笔记 tensorflow是谷歌google的深度学习框架,tensor中文叫做张量,flow叫做流. RNN是recurrent neural network的简称,中文叫做循环 ...

  2. sceneflow 数据集多少张图片_快速使用 Tensorflow 读取 7 万数据集!

    原标题:快速使用 Tensorflow 读取 7 万数据集! 作者 | 郭俊麟 责编 | 胡巍巍 Brief 概述 这篇文章中,我们使用知名的图片数据库「THE MNIST DATABASE」作为我们 ...

  3. TF之LSTM:基于Tensorflow框架采用PTB数据集建立LSTM网络的自然语言建模

    TF之LSTM:基于Tensorflow框架采用PTB数据集建立LSTM网络的自然语言建模 目录 关于PTB数据集 代码实现 关于PTB数据集 PTB (Penn Treebank Dataset)文 ...

  4. 我的AI之路(20)--用Tensorflow object_detection跑raccoon数据集

    Raccoon是一个小巧有趣的加标签了的数据集,总共200张图片,用来训练识别浣熊,我们用它来学习体验object_detection的训练测试过程是可以的. 到https://github.com/ ...

  5. 快速使用Tensorflow读取7万数据集!

    一.Brief概述 这篇文章中,我们使用知名的图片数据库[THE MNIST DATABASE]作为我们的图片来源,它的数据内容是一共七万张28x28像素的手写数组图片. 并被分成六万张训练集与一万张 ...

  6. 快速使用 Tensorflow 读取 7 万数据集!

    作者 | 郭俊麟 责编 | 胡巍巍 Brief 概述 这篇文章中,我们使用知名的图片数据库「THE MNIST DATABASE」作为我们的图片来源,它的数据内容是一共七万张28×28像素的手写数字图 ...

  7. Tensorflow初探之MNIST数据集学习

    官方文档传送门 MNIST数据集是手写数字0~9的数据集,一般被用作机器学习领域的测试,相当于HelloWorld级别. 本程序先从网上导入数据,再利用最小梯度法进行训练使得样本交叉熵最小,最后给出训 ...

  8. tensorflow 导入新的tensorflow实例

    因为涉及到同一台电脑多个GPU,在指定tensorflow图的时候,需要为不同的图指定不同的GPU,所以必须在导入tensorflow之前,指定可用的GPU def import_tf(device_ ...

  9. 【TensorFlow】——实现minist数据集分类的前向传播(常规神经网络非卷积神经网络)

    目录 一.常规神经网络模型 二.TensorFlow实现前向传播步骤 1.读取数据集 2.batch划分 3.根据神经网络每一层的神经元个数来初始化参数w,b 4.进行每一层输入输出的计算 5.对每一 ...

最新文章

  1. TensorFlow优化器及用法
  2. 指数基金日涨跌幅python_看懂巴菲特推荐的指数基金定投,Python验证
  3. openresty开发系列29--openresty中发起http请求
  4. 奖金16万!首届电子商务AI算法大赛ECAA报名开启
  5. 使用NSURLProtocol实现UIWebView的离线缓存的简单实现
  6. 浏览器兼容性问题解决方案· 总结
  7. Span中显示内容过长显示省略号---SpringCloud Alibaba_若依微服务框架改造_前端基于Vue的ElementUI---工作笔记011
  8. HUE与Oozie的集成
  9. 买服务器做网站 镜像选什么,云服务器做网站镜像类型选啥
  10. 远程计算机用户名win7,win7局域网远程控制的方法(图文)
  11. 免费的HTTP代理IP服务器地址
  12. 土拍熔断意味着什么_熔断意味着什么
  13. macbook air 重置mysql密码
  14. 插件77:获取Yahoo!股票新闻
  15. html页面中汉字上面显示拼音
  16. Windows 取证之ShellBags
  17. 关于“运放“这些知识点
  18. C语言经典例题-两个分数相加
  19. K8S 常见面试题总结
  20. 轮廓线扫描算法:Theo Pavlidis' Algorithm

热门文章

  1. 智利银行在勒索软件攻击后关闭了所有分行
  2. Google advertiser 开发
  3. Point-cloud based 3D object detection and classification methods for self-driving applications
  4. 记一次修改DiyBox的经历(openwrt固件解包与打包)
  5. 手机验证码的测试用例梳理
  6. canvas 画一幅画
  7. 智慧交通|沪宜公路智慧车列交通仿真研究
  8. mysql5.7.19winx64安装_mysql5.7.19winx64安装配置方法图文教程(win10)
  9. 编写一个java_Java入门篇(一)——如何编写一个简单的Java程序
  10. 达内学软件测试发证书吗,达内软件测试培训让我拥有了实际工作经验