本文逻辑:

我从网上下载了十几张猫和狗的图片,用于检验我们训练好的模型。

处理我们下载的图片

加载模型

将图片输入模型进行检验

代码如下:

#coding=utf-8

import tensorflow as tf

from PIL import Image

import matplotlib.pyplot as plt

import input_data

import numpy as np

import model

import os

#从指定目录中选取一张图片

def get_one_image(train):

files = os.listdir(train)

n = len(files)

ind = np.random.randint(0,n)

img_dir = os.path.join(train,files[ind])

image = Image.open(img_dir)

plt.imshow(image)

plt.show()

image = image.resize([208, 208])

image = np.array(image)

return image

def evaluate_one_image():

#存放的是我从百度下载的猫狗图片路径

train = '/Users/yangyibo/GitWork/pythonLean/AI/猫狗识别/testImg/'

image_array = get_one_image(train)

with tf.Graph().as_default():

BATCH_SIZE = 1 # 因为只读取一副图片 所以batch 设置为1

N_CLASSES = 2 # 2个输出神经元,[1,0] 或者 [0,1]猫和狗的概率

# 转化图片格式

image = tf.cast(image_array, tf.float32)

# 图片标准化

image = tf.image.per_image_standardization(image)

# 图片原来是三维的 [208, 208, 3] 重新定义图片形状 改为一个4D 四维的 tensor

image = tf.reshape(image, [1, 208, 208, 3])

logit = model.inference(image, BATCH_SIZE, N_CLASSES)

# 因为 inference 的返回没有用激活函数,所以在这里对结果用softmax 激活

logit = tf.nn.softmax(logit)

# 用最原始的输入数据的方式向模型输入数据 placeholder

x = tf.placeholder(tf.float32, shape=[208, 208, 3])

# 我门存放模型的路径

logs_train_dir = '/Users/yangyibo/GitWork/pythonLean/AI/猫狗识别/saveNet/'

# 定义saver

saver = tf.train.Saver()

with tf.Session() as sess:

print("从指定的路径中加载模型。。。。")

# 将模型加载到sess 中

ckpt = tf.train.get_checkpoint_state(logs_train_dir)

if ckpt and ckpt.model_checkpoint_path:

global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]

saver.restore(sess, ckpt.model_checkpoint_path)

print('模型加载成功, 训练的步数为 %s' % global_step)

else:

print('模型加载失败,,,文件没有找到')

# 将图片输入到模型计算

prediction = sess.run(logit, feed_dict={x: image_array})

# 获取输出结果中最大概率的索引

max_index = np.argmax(prediction)

if max_index==0:

print('猫的概率 %.6f' %prediction[:, 0])

else:

print('狗的概率 %.6f' %prediction[:, 1])

# 测试

evaluate_one_image()

/Users/yangyibo/GitWork/pythonLean/AI/猫狗识别/testImg/ 存放的是我从百度下载的猫狗图片

执行结果:

因为从testimg 中选取图片是随机的,所以每次执行的结果不同

从指定的路径中加载模型。。。。

模型加载成功, 训练的步数为 11999

狗的概率 0.964047

[Finished in 6.8s]

欢迎star。

总结

以上就是这篇文章的全部内容了,希望本文的内容对大家的学习或者工作具有一定的参考学习价值,谢谢大家对脚本之家的支持。如果你想了解更多相关内容请查看下面相关链接

python 图片比较 猫_TensorFlow卷积神经网络之使用训练好的模型识别猫狗图片相关推荐

  1. python狗图像识别_TensorFlow卷积神经网络之使用训练好的模型识别猫狗图片

    本文是Python通过TensorFlow卷积神经网络实现猫狗识别的姊妹篇,是加载上一篇训练好的模型,进行猫狗识别 本文逻辑: 我从网上下载了十几张猫和狗的图片,用于检验我们训练好的模型. 处理我们下 ...

  2. Python深度学习实例--基于卷积神经网络的小型数据处理(猫狗分类)

    Python深度学习实例--基于卷积神经网络的小型数据处理(猫狗分类) 1.卷积神经网络 1.1卷积神经网络简介 1.2卷积运算 1.3 深度学习与小数据问题的相关性 2.下载数据 2.1下载原始数据 ...

  3. 【深度学习】基于Torch的Python开源机器学习库PyTorch卷积神经网络

    [深度学习]基于Torch的Python开源机器学习库PyTorch卷积神经网络 文章目录 1 CNN概述 2 PyTorch实现步骤2.1 加载数据2.2 CNN模型2.3 训练2.4 可视化训练 ...

  4. 论文解析:人脸检测中级联卷积神经网络的联合训练

    论文解析:人脸检测中级联卷积神经网络的联合训练 商汤科技解析CVPR2016论文:人脸检测中级联卷积神经网络的联合训练 width="250" height="250&q ...

  5. python卷积神经网络cnn的训练算法_【深度学习系列】卷积神经网络CNN原理详解(一)——基本原理...

    上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...

  6. 商品识别系统Python,基于深度学习卷积神经网络

    介绍 商品识别系统采用了Python.TensorFlow.ResNet50算法以及Django等技术栈.其中,Python作为主要的编程语言,它的清晰简洁的语法使得代码易于阅读和编写.TensorF ...

  7. Python交通标志识别基于卷积神经网络的保姆级教程(Tensorflow)

    项目介绍 TensorFlow2.X 搭建卷积神经网络(CNN),实现交通标志识别.搭建的卷积神经网络是类似VGG的结构(卷积层与池化层反复堆叠,然后经过全连接层,最后用softmax映射为每个类别的 ...

  8. python机器学习库keras——CNN卷积神经网络人脸识别

    全栈工程师开发手册 (作者:栾鹏) python教程全解 github地址:https://github.com/626626cdllp/kears/tree/master/Face_Recognit ...

  9. Keras 搭建图片分类 CNN (卷积神经网络)

    1. 导入keras from keras.models import Sequential from keras.layers import Conv2D, MaxPooling2D, Flatte ...

最新文章

  1. ICCV 2021 Oral | NerfingMVS:引导优化神经辐射场实现室内多视角三维重建
  2. jquery中not方法失效的解决方案
  3. 运维基础(13)日志切割工具 Logrotate
  4. 云服务器系统租赁费用,云服务器创建租赁费用
  5. 擴展PictureBox的一個組件
  6. 循环链表解决约瑟夫环问题
  7. android 保存文件_Android 数据库操作框架LitePal使用介绍(一)
  8. 栈的基本操作(数组/链表)
  9. eclipse android 第一个程序,Eclipse 开发 Android,第一个 HelloWord 程序(学习1)-Fun言
  10. java泛型约束_java泛型
  11. IE与FF的常见兼容问题及总结
  12. Mysql查看编码方式专题
  13. 思科三层交换机开启ipv6路由功能_三层交换机实现路由功能配置示例与详解 (Cisco Packer Tracer 模拟器)...
  14. 安卓 多条通知_【安卓+苹果】石头阅读,全网小说、漫画免费看,最好用的追书神器!...
  15. vrp问题的java_VRP(车辆路径问题)的两种简单算法
  16. Berkeley CS 61B 学习笔记 - 1
  17. python界面设计实例qt_Python GUI教程(六):使用Qt设计师进行窗口布局
  18. ad19原理图标注_AD19原理图ID复位
  19. 03 Python安装 - 编辑器安装
  20. regulatory domain

热门文章

  1. 图的顺序存储及其深度优先遍历和广度优先遍历
  2. Python登录qq邮箱发送邮件(附件)
  3. 微信小程序输入框聚焦获取键盘安全高度
  4. Cronolog切割tomcat日志
  5. Rasa课程、Rasa培训、Rasa面试、Rasa实战系列之Understanding Word Embeddings CBOW and Skip Gram
  6. mysql分组后占比、累计占比和排序计算方法
  7. 虚拟主播软件有哪些?哪家的虚拟软件比较好用?
  8. Win32 OpenGL编程(4) 2D图形基础(颜色及坐标体系进阶知识)
  9. 如何用C语言完成水仙花数的搜索
  10. springboot项目导入Redis依赖后在测试类中无法使用(RedisTemplate),报空指针