课程来源:人工智能实践:Tensorflow笔记2

文章目录

  • 前言
  • 1、文件一览
  • 2、将load_data()函数替换掉
  • 2、调用generateds函数
  • 4、效果
  • 总结

前言

本讲目标:自制数据集,解决本领域应用
将我们手中的图片和标签信息制作为可以直接导入的npy文件。


1、文件一览

首先看看我们的文件长什么样:
路径:D:\python code\AI\class4\MNIST_FC\mnist_image_label\mnist_test_jpg_10000
图片文件:(黑底白字的灰度图,大小:28x28,每个像素点都是0~255之间的整数)

标签文件:(图片名和对应的标签,中间用空格隔开)

2、将load_data()函数替换掉

之前我们导入数据集的方式是(以mnist数据集为例):

fashion = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()

导入后变量的数据类型和形状:

x_train.shape (60000,28,28) ,3维数组,60000个28行28列的图片灰度值
y_train.shape (60000,) ,60000张图片对应的标签,是1维数组
x_test.shape (10000,28,28) ,3维数组,10000个28行28列的图片灰度值
y_test.shape (10000,) ,10000张图片对应的标签,是1维数组

我们需要自己写个函数generateds(图片路径,标签文件):
观察数据集:

我们需要做的:把图片灰度值数据拼接到图片列表,把标签数据拼接到标签列表。

函数代码如下:

def generateds(path, txt):f = open(txt, 'r')          #只读形式读取文本数据contents = f.readlines()  # 按行读取,读取所有行f.close()                #关闭文件x, y_ = [], []            #建立空列表for content in contents:    #逐行读出value = content.split()  # 以空格分开,存入数组   图片名为value0   标签为value1img_path = path + value[0] #图片路径+图片名->拼接出索引路径img = Image.open(img_path)   #读入图片img = np.array(img.convert('L'))img = img / 255.       #归一化数据x.append(img)         #将归一化的数据贴到列表xy_.append(value[1])        #标签贴到列表y_print('loading : ' + content)   #打印状态提示x = np.array(x)y_ = np.array(y_)y_ = y_.astype(np.int64)return x, y_

2、调用generateds函数

使用函数代码:

'''添加了:
训练集图片路径
训练集标签文件
训练集输入特征存储文件
训练集标签存储文件
测试集图片路径
测试集标签文件
测试集输入特征存储文件
测试集标签存储文件'''
train_path = 'D:/python code/AI/class4/FASHION_FC/fashion_image_label/fashion_train_jpg_60000/'
train_txt = 'D:/python code/AI/class4/FASHION_FC/fashion_image_label/fashion_train_jpg_60000.txt'
x_train_savepath = 'D:/python code/AI/class4/FASHION_FC/fashion_image_label/fashion_x_train.npy'
y_train_savepath = 'D:/python code/AI/class4/FASHION_FC/fashion_image_label/fahion_y_train.npy'test_path = 'D:/python code/AI/class4/FASHION_FC/fashion_image_label/fashion_test_jpg_10000/'
test_txt = 'D:/python code/AI/class4/FASHION_FC/fashion_image_label/fashion_test_jpg_10000.txt'
x_test_savepath = 'D:/python code/AI/class4/FASHION_FC/fashion_image_label/fashion_x_test.npy'
y_test_savepath = 'D:/python code/AI/class4/FASHION_FC/fashion_image_label/fashion_y_test.npy'
#观察测试集训练集文件是否存在,如果存在直接读取,如果不存在调用generate datasets函数
if os.path.exists(x_train_savepath) and os.path.exists(y_train_savepath) and os.path.exists(x_test_savepath) and os.path.exists(y_test_savepath):print('-------------Load Datasets-----------------')x_train_save = np.load(x_train_savepath)y_train = np.load(y_train_savepath)x_test_save = np.load(x_test_savepath)y_test = np.load(y_test_savepath)x_train = np.reshape(x_train_save, (len(x_train_save), 28, 28))x_test = np.reshape(x_test_save, (len(x_test_save), 28, 28))
else:print('-------------Generate Datasets-----------------')x_train, y_train = generateds(train_path, train_txt)x_test, y_test = generateds(test_path, test_txt)print('-------------Save Datasets-----------------')x_train_save = np.reshape(x_train, (len(x_train), -1))x_test_save = np.reshape(x_test, (len(x_test), -1))np.save(x_train_savepath, x_train_save)np.save(y_train_savepath, y_train)np.save(x_test_savepath, x_test_save)np.save(y_test_savepath, y_test)model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(10, activation='softmax')
])model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()

4、效果

制作完数据集之后开始用神经网络训练:

可以发现原本的文件夹中出现了你所需要的npy文件。

完整代码:

import tensorflow as tf
from PIL import Image
import numpy as np
import ostrain_path = 'D:/python code/AI/class4/FASHION_FC/fashion_image_label/fashion_train_jpg_60000/'
train_txt = 'D:/python code/AI/class4/FASHION_FC/fashion_image_label/fashion_train_jpg_60000.txt'
x_train_savepath = 'D:/python code/AI/class4/FASHION_FC/fashion_image_label/fashion_x_train.npy'
y_train_savepath = 'D:/python code/AI/class4/FASHION_FC/fashion_image_label/fahion_y_train.npy'test_path = 'D:/python code/AI/class4/FASHION_FC/fashion_image_label/fashion_test_jpg_10000/'
test_txt = 'D:/python code/AI/class4/FASHION_FC/fashion_image_label/fashion_test_jpg_10000.txt'
x_test_savepath = 'D:/python code/AI/class4/FASHION_FC/fashion_image_label/fashion_x_test.npy'
y_test_savepath = 'D:/python code/AI/class4/FASHION_FC/fashion_image_label/fashion_y_test.npy'def generateds(path, txt):f = open(txt, 'r')contents = f.readlines()  # 按行读取f.close()x, y_ = [], []for content in contents:value = content.split()  # 以空格分开,存入数组img_path = path + value[0]img = Image.open(img_path)img = np.array(img.convert('L'))img = img / 255.x.append(img)y_.append(value[1])print('loading : ' + content)x = np.array(x)y_ = np.array(y_)y_ = y_.astype(np.int64)return x, y_if os.path.exists(x_train_savepath) and os.path.exists(y_train_savepath) and os.path.exists(x_test_savepath) and os.path.exists(y_test_savepath):print('-------------Load Datasets-----------------')x_train_save = np.load(x_train_savepath)y_train = np.load(y_train_savepath)x_test_save = np.load(x_test_savepath)y_test = np.load(y_test_savepath)x_train = np.reshape(x_train_save, (len(x_train_save), 28, 28))x_test = np.reshape(x_test_save, (len(x_test_save), 28, 28))
else:print('-------------Generate Datasets-----------------')x_train, y_train = generateds(train_path, train_txt)x_test, y_test = generateds(test_path, test_txt)print('-------------Save Datasets-----------------')x_train_save = np.reshape(x_train, (len(x_train), -1))x_test_save = np.reshape(x_test, (len(x_test), -1))np.save(x_train_savepath, x_train_save)np.save(y_train_savepath, y_train)np.save(x_test_savepath, x_test_save)np.save(y_test_savepath, y_test)model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(10, activation='softmax')
])model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()

总结

课程链接:MOOC人工智能实践:TensorFlow笔记2

【神经网络八股扩展】:自制数据集相关推荐

  1. Tensorflow2.x框架-神经网络八股扩展-自制数据集

    自制数据集,解决本领域应用 目录 摘要 一.Sequential() 实现自制数据集 二.Class() 实现自制数据集 摘要 mnist_image_label 文件夹: mnist_train_j ...

  2. 【神经网络八股扩展】:数据增强

    课程来源:人工智能实践:Tensorflow笔记2 文章目录 前言 TensorFlow2数据增强函数 数据增强+网络八股代码: 总结 前言 本讲目标:数据增强,增大数据量 关于我们为何要使用数据增强 ...

  3. 【深度学习21天学习挑战赛】3、使用自制数据集——卷积神经网络(CNN)天气识别

    活动地址:CSDN21天学习挑战赛 通过前两课的学习,加上私底下恶补基础,照猫画虎的基本算是掌握了卷积神经网络-CNN搭建模型的基本方法. 之前使用的,都是使用的现成的数据集,想想,如果今后真的需要应 ...

  4. 曹健老师 TensorFlow2.1 —— 第四章 网络八股扩展

    第一章 第二章 第三章 本章目的:扩展六步法功能,并实现应用. 4.1 搭建网络八股总览 利用自制数据集,解决本领域应用 利用数据增强,解决数据量过少问题,扩展数据,提高泛化力 利用断点续训,实时保存 ...

  5. 自制数据集之labelme软件的使用,深度学习入门(1)

    自制数据集之labelme软件的使用,深度学习入门(1) 说明 一.安装labelme 二.使用labelme标注 三 解析json文件 四.批处理json文件夹 说明 因为之前做语义分割项目需要自己 ...

  6. YOLO v5 实现目标检测(参考数据集自制数据集)

    YOLO v5 实现目标检测(参考数据集&自制数据集) Author: Labyrinthine Leo   Init_time: 2020.10.26 GitHub: https://git ...

  7. mask-rcnn训练测试自制数据集

    mask-rcnn训练测试自制数据集 本项目简介 本项目用于口腔模型分割,数据类型有7种,本文主要用于介绍如何使用自制数据集训练自己的模型 训练环境配置 操作系统:win10 GPU: GTX 108 ...

  8. yolov5s 预训练模型_YOLO v5 实现目标检测(参考数据集自制数据集)

    YOLO v5 实现目标检测(参考数据集&自制数据集) Author: Labyrinthine Leo   Init_time: 2020.10.26 GitHub: https://git ...

  9. tensorflow2自制数据集实线猫狗分类

    自制数据集实现猫狗分类 使用自制数据集训练神经网络模型实现猫狗分类器.使用框架为tensorflow2.网络结构为ResNet,网络结构,4个ResNetBlock,每个结构块4层卷积层,非跳连网络层 ...

最新文章

  1. java io流读写文件换行_java基础io流——OutputStream和InputStream的故事(温故知新)...
  2. 如何通过DBLINK取REMOTE DB的DDL
  3. JAVA微信开发:[17]如何获取所有关注用户
  4. bash特性之四、五
  5. 小程序开发(7)-之获取手机号、用户信息
  6. python mysql实例_Python 操作MySQL详解及实例
  7. micropython编程软件下载_MicroPython可视化拼插编辑器:让硬件编程更智能!
  8. nginx lua获取客户端ip
  9. 记号(notation)的学习
  10. linux应用编程之进程间同步
  11. php框架启动过程,框架启动方式 - CrossPHP 框架文档
  12. 深度神经网络 分布式训练 动手学深度学习v2
  13. windows编译opencv+opencv_contrib 以及解决cmake下载boostdesc_bgm等文件失败问题
  14. Ubuntu 16.04 安装搜狗输入法
  15. 零基础入行IC,选模拟版图还是数字后端?
  16. 如何调整基准电压提高ADC精度
  17. 百度地图绘制3D棱柱
  18. 取消改写模式(python)
  19. 如何加入家庭组计算机打印机,解决方案:Win7系统设置家庭组计算机设置共享打印机...
  20. 阿里云大数据平台的实操:ODPS的SQL语句

热门文章

  1. H5工程师跨页面取值的几种方法
  2. 有var d = new Date(‘20xx-m-09‘),可以设置为m+1月份的操作是?
  3. java properties 保存_Java 读写Properties配置文件
  4. linux安装程序过程,linux 应用程序安装过程
  5. rnn神经网络 层次_精讲深度学习RNN三大核心点,三分钟掌握循环神经网络
  6. 游戏大厅 从基础开始(6)--绕回来细说聊天室(中)之女仆编年史1
  7. Unity3D入门其实很简单
  8. ptmalloc内存分配和回收详解(文字版)
  9. Zabbix全方位告警接入-电话/微信/短信都支持
  10. Spring MVC-集成(Integration)-集成LOG4J示例(转载实践)