数据集地址:

链接:#####https://pan.baidu.com/s/1ZXVb7M5p0JtS1edYRCJG2w
#####提取码:1xhz

代码如下:

#coding=utf-8
import os
#图像读取库
from PIL import Image
#矩阵运算库
import numpy as np
import tensorflow as tf#数据集文件夹
train_dir = r"D:/Modeling_2/train/"
test_dir = r"D:/Modeling_2/test/"#是否训练
train = True
#模型文件路径
model_path = "D:/Modeling_2/models/image_model"def read_data(train_dir):file_paths = []datas = []labels = []# 读取 dir路径下的 文件的名称for file_name in os.listdir(train_dir):# 将该路径下的文件 及其路径一同打印file_path = os.path.join(train_dir, file_name)file_paths.append(file_path)# 返回(图片类型JPEG,大小32*32,RGB)image = Image.open(file_path)# 归一化处理data = np.array(image) / 255.0# 取出标签label = int(file_name.split("_")[0])datas.append(data)labels.append(label)datas = np.array(datas) labels = np.array(labels) print("shape of datas: {}\tshape of labels: {}".format(datas.shape, labels.shape)) return file_paths, datas, labels#定义权重
def weight_variable(shape):initial = tf.truncated_normal(shape, stddev=0.1)return tf.Variable(initial)#定义偏置
def bias_variable(shape):initial = tf.constant(0.1, shape=shape)return tf.Variable(initial)#卷积层
def conv2d(x, W):return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding="SAME")#池化层 2x2池化层 步长为2
def max_2x2_pool(x):return tf.nn.max_pool(x, ksize=[1,2,2,1], strides=[1,2,2,1], padding="SAME")if __name__ == "__main__":#返回文件路径,数据,标签file_paths, datas, labels = read_data(train_dir)# 分类数num_classes = len(set(labels))# 数据与标签x = tf.placeholder(tf.float32, [None, 32, 32, 3])y = tf.placeholder(tf.int32, [None])# 5x5x3 32个卷积核W_conv1 = weight_variable([5, 5, 3, 32])b_conv1 = bias_variable([32])# relu函数,卷积,池化h_conv1 = tf.nn.relu(conv2d(x, W_conv1)+ b_conv1)h_pool1 = max_2x2_pool(h_conv1)# 4x4x32 64个卷积核W_conv2 = weight_variable([5,5,32,64])b_conv2 = bias_variable([64])h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2)+ b_conv2)h_pool2 = max_2x2_pool(h_conv2)# 5x5x3 32个卷积核W_conv3 = weight_variable([5, 5, 64, 128])b_conv3 = bias_variable([128])# relu函数,卷积,池化h_conv3 = tf.nn.relu(conv2d(h_pool2, W_conv3) + b_conv3)h_pool3 = max_2x2_pool(h_conv3)# 4x4x32 64个卷积核W_conv4 = weight_variable([5, 5, 128, 256])b_conv4 = bias_variable([256])h_conv4 = tf.nn.relu(conv2d(h_pool3, W_conv4) + b_conv4)h_pool4 = max_2x2_pool(h_conv4)# 4次卷积池化过后 32/2/2/2/2 = 2  2*2*256# 通过计算得到 神经元个数 2*2*256# 扁平化处理h_pool2_flat = tf.reshape(h_pool4, [-1, 2 * 2 * 256])W_fc1 = weight_variable([2 * 2 * 256, num_classes])b_fc1 = bias_variable([num_classes])h_fc1 = tf.matmul(h_pool2_flat, W_fc1) + b_fc1# dropout优化keep_prob = tf.placeholder(tf.float32)h_fc1_dropout = tf.nn.dropout(h_fc1, keep_prob)logits = h_fc1# 返回最大值的标签prediction_label = tf.argmax(logits, 1)# 交叉熵函数 tf.one_hot(输入(一维),深度)losses = tf.nn.softmax_cross_entropy_with_logits(labels=tf.one_hot(y, num_classes), logits=logits)mean_loss =tf.reduce_mean(losses)# 最速梯度下降法train_step = tf.train.AdamOptimizer(1e-4).minimize(losses)#交叉熵代价函数correct_predition = tf.equal(tf.argmax(tf.one_hot(y, num_classes), -1), tf.argmax(logits, -1))# 求准确率accuracy = tf.reduce_mean(tf.cast(correct_predition, tf.float32))Saver = tf.train.Saver()with tf.Session() as sess:if train:print("训练模式!")sess.run(tf.global_variables_initializer())train_feed_dict = {x:datas, y:labels, keep_prob:0.5}for step in range(1001):_, mean_loss_val  = sess.run([train_step, mean_loss], feed_dict=train_feed_dict)if step%20 ==0:print("step = {}\tmean loss = {}".format(step, mean_loss_val))train_acc =  sess.run(accuracy, feed_dict={x:datas, y:labels})print("准确率:", train_acc)Saver.save(sess, model_path)print("训练结束,保存模型到{}".format(model_path))else:print("测试模式")Saver.restore(sess, model_path)print("从{}载入模型".format(model_path))label_name_dict = {0:"百合花",1:"白玉兰",2:"茉莉花",3:"栀子花"}test_feed_dict ={x:datas, y:labels, keep_prob:0}prediction_val = sess.run(prediction_label, feed_dict=test_feed_dict)for file_paths, real_label, predicted_label in zip(file_paths, labels, prediction_val):# 将label id转换为label名real_label_name = label_name_dict[real_label]predicted_label_name = label_name_dict[predicted_label]print("{}\t{} => {}".format(file_paths, real_label_name, predicted_label_name))

tensorflow花朵分类相关推荐

  1. [Python人工智能] 六.TensorFlow实现分类学习及MNIST手写体识别案例

    从本专栏开始,作者正式开始研究Python深度学习.神经网络及人工智能相关知识.前一篇文章讲解了Tensorboard可视化的基本用法,并绘制整个神经网络及训练.学习的参数变化情况:本篇文章将通过Te ...

  2. Python深度学习实战:多类花朵分类

    Python深度学习实战:多类花朵分类 鸢尾花分类数据集 导入库和函数 指定随机数种子 导入数据 输出变量编码 设计神经网络 用K折交叉检验测试模型 总结 本章我们使用Keras为多类分类开发并验证一 ...

  3. 第10章 项目:多类花朵分类

    第10章 项目:多类花朵分类 本章我们使用Keras为多类分类开发并验证一个神经网络.本章包括: 将CSV导入Keras 为Keras预处理数据 使用scikit-learn验证Keras模型 我们开 ...

  4. tensorflow迁移学习:VGG16花朵分类

    转自:https://blog.csdn.net/weixin_41770169/article/details/80330581 github地址:https://github.com/126921 ...

  5. 热文 | 卷积神经网络入门案例,轻松实现花朵分类

    作者 | 黎国溥 责编 | 寇雪芹 出品 | AI 科技大本营(ID:rgznai100) 前言 本文介绍卷积神经网络的入门案例,通过搭建和训练一个模型,来对几种常见的花朵进行识别分类:使用到TF的花 ...

  6. tensorflow: 花卉分类

    本文主要通过CNN进行花卉的分类,训练结束保存模型,最后通过调用模型,输入花卉的图片通过模型来进行类别的预测. 测试平台:win 10+tensorflow 1.2 数据集:http://downlo ...

  7. “花朵分类“ 手把手搭建【卷积神经网络】

    前言 本文介绍卷积神经网络的入门案例,通过搭建和训练一个模型,来对几种常见的花朵进行识别分类: 使用到TF的花朵数据集,它包含5类,即:"雏菊","蒲公英",& ...

  8. 人工智能学习07--pytorch11--分类网络:使用pytorch和tensorflow计算分类模型的混淆矩阵

    师兄说学目标检测之前先学分类 坏了,内容好多!学学学 感谢up主,好人一生平安 混淆矩阵 什么是混淆矩阵: 横坐标:每一列属于该类的所有验证样本.每一列所有元素对应真实类别. 纵坐标:网络的预测类别. ...

  9. TensorFlow 实现分类操作的函数学习

    函数:tf.nn.sigmoid_cross_entropy_with_logits(logits, targets, name=None) 说明:此函数是计算logits经过sigmod函数后的交叉 ...

最新文章

  1. 送你9个快速使用Pytorch训练解决神经网络的技巧(附代码)
  2. JavaScript——定时器(setTimeout/setInterval)
  3. 浅谈块级元素和内联元素的嵌套规则
  4. php标准库string,PHP中的一些标准库
  5. COM, COM+ and .NET 程序集的区别
  6. CentOS系统VMware克隆后 重新设置成eth0
  7. OC代码调用C++代码的回调函数步骤
  8. 消息队列(MQ):ZeroMQ 中间件设计【译文】
  9. 浏览器首页被360恶意篡改,解决方法
  10. 启用mysql系统找不到指定的文件类型_net start mysql 发生系统错误2 系统找不到指定的文件...
  11. AutoCAD .Net 不同文档间复制对象
  12. 8253工作方式区别、计数初值及应用
  13. 黑客入门(超级详细版)
  14. java通过framer生成word_DSO Framer Control Object 操作word文件
  15. Activiti实现会签功能程序Demo
  16. python csv
  17. docker-compose(部署微服务+MySQL)
  18. 香港特区银行怎么开帐户?
  19. CockroachDB的raft优化
  20. heic格式怎么转换?其实转换很简单

热门文章

  1. 雄智手机进销存系统 绿色
  2. Illustrator中纸杯扇形做法
  3. 相约QCon北京2013大会,图灵全程为您准备好图书
  4. (产品贴)百度手机卫士竞品分析报告
  5. android 获取蓝牙信号强度,连接后获取蓝牙RSSI信号强度
  6. 【读书笔记】关于写读书笔记的阶段性总结
  7. layui中layui-form与select2下拉多选样式冲突的解决方案
  8. [C#学习] BindingNavigator控件
  9. 两个世界2城堡防御攻略
  10. Nginx之Location