文章目录

  • 前言
  • 一、cifar100数据集简介
  • 二、实战目的

前言

这是一个真正能跑起来的程序,前提是安装好TensorFlow框架。
不明白的地方欢迎留言~


提示:以下是本篇文章正文内容,下面案例可供参考

一、cifar100数据集简介

CIFAR-100数据集由100个类的60000个32x32彩色图像组成,每个类有6000个图像。有50000个训练图像和10000个测试图像。
数据集分为五个训练批次和一个测试批次,每个批次有10000个图像。测试批次包含来自每个类别的恰好1000个随机选择的图像。训练批次以随机顺序包含剩余图像,但一些训练批次可能包含来自一个类别的图像比另一个更多。总体来说,五个训练集之和包含来自每个类的正好5000张图像。[ 百度百科]
以下是数据集中的类,以及来自每个类的10个随机图像:

二、实战目的

通过tensorflow2.0下载cifar100数据集,建立13层的神经网络(10个卷积层和3个全连接层)对训练集数据进行训练,将训练好的网络用来预测测试集上的图片,准确率高达50%以上。

代码如下(示例):

import  tensorflow as tf
from    tensorflow.keras import layers, optimizers, datasets, Sequential
import  osos.environ['TF_CPP_MIN_LOG_LEVEL']='2'
tf.random.set_seed(2345)conv_layers = [ # 5 units of conv + max pooling# unit 1layers.Conv2D(64, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),layers.Conv2D(64, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),layers.MaxPool2D(pool_size=[2, 2], strides=2, padding='same'),# unit 2layers.Conv2D(128, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),layers.Conv2D(128, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),layers.MaxPool2D(pool_size=[2, 2], strides=2, padding='same'),# unit 3layers.Conv2D(256, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),layers.Conv2D(256, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),layers.MaxPool2D(pool_size=[2, 2], strides=2, padding='same'),# unit 4layers.Conv2D(512, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),layers.Conv2D(512, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),layers.MaxPool2D(pool_size=[2, 2], strides=2, padding='same'),# unit 5layers.Conv2D(512, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),layers.Conv2D(512, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),layers.MaxPool2D(pool_size=[2, 2], strides=2, padding='same')]def preprocess(x, y):# [0~1]x = tf.cast(x, dtype=tf.float32) / 255.y = tf.cast(y, dtype=tf.int32)return x,y(x,y), (x_test, y_test) = datasets.cifar100.load_data()
y = tf.squeeze(y, axis=1)
y_test = tf.squeeze(y_test, axis=1)
print(x.shape, y.shape, x_test.shape, y_test.shape)train_db = tf.data.Dataset.from_tensor_slices((x,y))
train_db = train_db.shuffle(1000).map(preprocess).batch(128)test_db = tf.data.Dataset.from_tensor_slices((x_test,y_test))
test_db = test_db.map(preprocess).batch(64)sample = next(iter(train_db))
print('sample:', sample[0].shape, sample[1].shape,tf.reduce_min(sample[0]), tf.reduce_max(sample[0]))def main():# [b, 32, 32, 3] => [b, 1, 1, 512]conv_net = Sequential(conv_layers)fc_net = Sequential([layers.Dense(256, activation=tf.nn.relu),layers.Dense(128, activation=tf.nn.relu),layers.Dense(100, activation=None),])conv_net.build(input_shape=[None, 32, 32, 3])fc_net.build(input_shape=[None, 512])optimizer = optimizers.Adam(lr=1e-4)# [1, 2] + [3, 4] => [1, 2, 3, 4]variables = conv_net.trainable_variables + fc_net.trainable_variablesfor epoch in range(50):for step, (x,y) in enumerate(train_db):with tf.GradientTape() as tape:# [b, 32, 32, 3] => [b, 1, 1, 512]out = conv_net(x)# flatten, => [b, 512]out = tf.reshape(out, [-1, 512])# [b, 512] => [b, 100]logits = fc_net(out)# [b] => [b, 100]y_onehot = tf.one_hot(y, depth=100)# compute lossloss = tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True)loss = tf.reduce_mean(loss)grads = tape.gradient(loss, variables)optimizer.apply_gradients(zip(grads, variables))if step %100 == 0:print(epoch, step, 'loss:', float(loss))total_num = 0total_correct = 0for x,y in test_db:out = conv_net(x)out = tf.reshape(out, [-1, 512])logits = fc_net(out)prob = tf.nn.softmax(logits, axis=1)pred = tf.argmax(prob, axis=1)pred = tf.cast(pred, dtype=tf.int32)correct = tf.cast(tf.equal(pred, y), dtype=tf.int32)correct = tf.reduce_sum(correct)total_num += x.shape[0]total_correct += int(correct)acc = total_correct / total_numprint(epoch, 'acc:', acc)if __name__ == '__main__':main()

基于Tensorflow针对cifar数据集运用卷积神经网络解决100类图片的分类问题。相关推荐

  1. 通过 Tensorflow 的基础类,构建卷积神经网络,用于花朵图片的分类

    实验目的 通过 Tensorflow 的基础类,构建卷积神经网络,用于花朵图片的分类. 实验环境 import tensorflow as tfprint(tf.__version__) output ...

  2. 用卷积神经网络实现猫狗图片分类

    该例程使用数据集来源于 kaggle cat_VS _dog 数据集中的一部分, 用卷积神经网络实现猫狗图片二分类,例程序比较简单,就不多解释了,代码中会有相应的注释,直接上代码: import nu ...

  3. python神经网络库识别验证码_基于TensorFlow 使用卷积神经网络识别字符型图片验证码...

    本项目使用卷积神经网络识别字符型图片验证码,其基于TensorFlow 框架.它封装了非常通用的校验.训练.验证.识别和调用 API,极大地减低了识别字符型验证码花费的时间和精力. 项目地址:http ...

  4. 论文翻译:基于深度卷积神经网络的肉鸡粪便识别与分类

    Recognition and Classification of Broiler Droppings Based on Deep Convolutional Neural Network 基于深度卷 ...

  5. 基于多源信息的深度卷积神经网络预测CircRNA疾病关联的有效方法

    An Efficient Approach based on Multi-sources Information to Predict CircRNA-disease Associations Usi ...

  6. 基于卷积神经网络和投票机制的三维模型分类与检索 2019 论文笔记

    作者:白静 计算机辅助设计与图形学学报 1.解决的问题 由于三维模型投影得到的视图是由不同视点得到,具有相对独立性,这种像素级的融合运算并没有直接的物理或者几何意义,更有可能造成图像有益信息淹没和混淆 ...

  7. 基于卷积神经网络的不良地质体识别与分类

    在泛函分析中,卷积.旋积或摺积(英语:Convolution)是通过两个函数f 和g 生成第三个函数的一种数学算子,表征函数f 与g经过翻转和平移的重叠部分的面积. 如果将参加卷积的一个函数看作区间的 ...

  8. 基于Pytorch再次解读NiN现代卷积神经网络和批量归一化

    个人简介:CSDN百万访问量博主,普普通通男大学生,深度学习算法.医学图像处理专攻,偶尔也搞全栈开发,没事就写文章,you feel me? 博客地址:lixiang.blog.csdn.net 基于 ...

  9. 基于Pytorch再次解读DenseNet现代卷积神经网络

    个人简介:CSDN百万访问量博主,普普通通男大学生,深度学习算法.医学图像处理专攻,偶尔也搞全栈开发,没事就写文章,you feel me? 博客地址:lixiang.blog.csdn.net 基于 ...

  10. 基于Pytorch再次解读LeNet-5现代卷积神经网络

    个人简介:CSDN百万访问量博主,普普通通男大学生,深度学习算法.医学图像处理专攻,偶尔也搞全栈开发,没事就写文章,you feel me? 博客地址:lixiang.blog.csdn.net 基于 ...

最新文章

  1. java多线程(三)
  2. 图片的奇怪Cache_MISS原因!
  3. Linux系统检测命令有哪些
  4. python 均方误差_一个很随意的Python智能优化库,一个文件就是一个库-- PySwarm
  5. 老罗学习MVC之旅:MVC组件分析
  6. C#怎么测试静态方法?我给出了2种方案
  7. happens-before规则和as-if-serial语义
  8. leetcode121买卖股票的最佳时机
  9. Oracle函数的信息,Oracle中获取会话信息的两个函数分享
  10. 数据分析常用的python包_量化投资数据分析之常用的python包(附代码)
  11. Oracle中表pagesize,Oracle 解决显示凌乱串行问题时column、pagesize、linesize的设定
  12. linux配置java环境变量(详细)(转)
  13. java filter param_Java过滤器Filter使用详解
  14. 计算机硬盘 安装,电脑新硬盘如何安装系统
  15. MODIS数据下载及图像处理教程
  16. 关于小米路由器的局域网内相互ping设备IP的解决方法
  17. 扫二维码 下载app
  18. 数据库空间管理-学习笔记
  19. Redis的初步使用教程
  20. Nexperia |超低电容 ESD 保护二极管保护汽车数据接口基础半导体器件

热门文章

  1. 日期相关的小函数汇总
  2. junper srx配置思路
  3. 基于OHCI的USB主机 —— UFI命令概述
  4. 06.SpringBoot的webjars和静态资源映射
  5. [Publish AAR To Maven] 使用 Gradle 发布 AAR 到 Maven 仓库
  6. [WebApi] 捣鼓一个资源管理器--多文件上传
  7. Win7下 安全、彻底删除Orcale数据库
  8. 数组模拟栈解决括号匹配
  9. 怎么判断二阶导数是否异号_「高等数学」给出函数的二阶导函数图形,求该曲线图形拐点的个数...
  10. 异常重试_面试题:如何基于 dubbo 进行服务治理、服务降级、失败重试?