简单步骤

加载一张照片
维度转换为224*224
转换为array

from keras.preprocessing.image import load_img,img_to_arrayimg_path = '1.jpg'
img = load_img(img_path,target_size=(224,224))
img = img_to_array(img)
type(img)

numpy.ndarray

获取vgg16 主要卷积层 不要后面的全连接层 (自己写)
改变输入数据的维度
图像预处理

from keras.applications.vgg16 import VGG16
from keras.applications.vgg16 import preprocess_input
import numpy as np
model_vgg = VGG16(weights='imagenet',include_top=False)
x = np.expand_dims(img,axis=0)
x = preprocess_input(x)
print(x.shape)

(1, 224, 224, 3)

#特征提取
features = model_vgg.predict(x)`在这里插入代码片`
print(features.shape)

(1, 7, 7, 512)

#全连接层准备   改变数据维度
features = features.reshape(1,7*7*512)
print(features.shape)

(1, 25088)

整体处理

from keras.preprocessing.image import img_to_array,load_img
from keras.applications.vgg16 import VGG16
from keras.applications.vgg16 import preprocess_input
import numpy as np
# 数据预处理和模型加载
model_vgg = VGG16(weights='imagenet', include_top=False)
#define a method to load and preprocess the image
def modelProcess(img_path,model):img = load_img(img_path, target_size=(224, 224))img = img_to_array(img)x = np.expand_dims(img,axis=0)x = preprocess_input(x)x_vgg = model.predict(x)x_vgg = x_vgg.reshape(1,25088)return x_vgg
#list file names of the training datasets
import os
folder = "dataset/data_vgg/cats"
dirs = os.listdir(folder)
#generate path for the images
img_path = []
for i in dirs:                             if os.path.splitext(i)[1] == ".jpg":   img_path.append(i)
img_path = [folder+"//"+i for i in img_path]#preprocess multiple images
features1 = np.zeros([len(img_path),25088])
for i in range(len(img_path)):feature_i = modelProcess(img_path[i],model_vgg)print('preprocessed:',img_path[i])features1[i] = feature_ifolder = "dataset/data_vgg/dogs"
dirs = os.listdir(folder)
img_path = []
for i in dirs:                             if os.path.splitext(i)[1] == ".jpg":   img_path.append(i)
img_path = [folder+"//"+i for i in img_path]
features2 = np.zeros([len(img_path),25088])
for i in range(len(img_path)):feature_i = modelProcess(img_path[i],model_vgg)print('preprocessed:',img_path[i])features2[i] = feature_i#label the results
print(features1.shape,features2.shape)
y1 = np.zeros(300)
y2 = np.ones(300)#generate the training data
X = np.concatenate((features1,features2),axis=0)
y = np.concatenate((y1,y2),axis=0)
y = y.reshape(-1,1)
print(X.shape,y.shape)from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.3,random_state=50)
print(X_train.shape,X_test.shape,X.shape)
from keras.models import Sequential
from keras.layers import Dense
# 构建模型 两个全连接层 25088-10 10-1 (1就是二分类)
model = Sequential()
model.add(Dense(units=10,activation='relu',input_dim=25088))
model.add(Dense(units=1,activation='sigmoid'))
model.summary()
#定义loss opt acc基本参数
odel.compile(optimizer='adam',loss='binary_crossentropy',metrics=['accuracy'])
#train the model
model.fit(X_train,y_train,epochs=50)
#测试准确率
from sklearn.metrics import accuracy_score
y_train_predict = model.predict_classes(X_train)
accuracy_train = accuracy_score(y_train,y_train_predict)
print(accuracy_train)
y_test_predict = model.predict_classes(X_test)
accuracy_test = accuracy_score(y_test,y_test_predict)
print(accuracy_test)
# 找个网图测试
img_path = 'cat1.jpg'
img = load_img(img_path,target_size=(224,224))
img = img_to_array(img)
x = np.expand_dims(img,axis=0)
x = preprocess_input(x)
features = model_vgg.predict(x)
features = features.reshape(1,7*7*512)
result = model.predict_classes(features)
print(result)
可视化测试  找10个图片名字是1-10
import matplotlib as mlp
font2 = {'family' : 'SimHei',
'weight' : 'normal',
'size'   : 20,
}
mlp.rcParams['font.family'] = 'SimHei'
mlp.rcParams['axes.unicode_minus'] = False
from matplotlib import pyplot as plt
from matplotlib.image import imread
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array
from keras.models import load_model
#from cv2 import load_img
a = [i for i in range(1,10)]
fig = plt.figure(figsize=(10,10))
for i in a:img_name = str(i)+'.jpg'img_path = img_nameimg = load_img(img_path, target_size=(224, 224))img = img_to_array(img)x = np.expand_dims(img,axis=0)x = preprocess_input(x)x_vgg = model_vgg.predict(x)x_vgg = x_vgg.reshape(1,25088)result = model.predict_classes(x_vgg)img_ori = load_img(img_name, target_size=(250, 250))plt.subplot(3,3,i)plt.imshow(img_ori)plt.title('预测为:狗狗' if result[0][0] == 1 else '预测为:猫咪')
plt.show()

结果如图

VGG16的猫狗识别相关推荐

  1. 深度学习100例-卷积神经网络(VGG-16)猫狗识别 | 第21天

    最近更新有点慢,后台收到不少小伙伴的催更,先说声抱歉哈.最近在参加一个目标检测的比赛,时间比较紧张.这段时间我也打算调整一下思路,试着将目标检测中涉及的内容拆开来,将这些拆分的内容一点点融入到在后续的 ...

  2. 小试牛刀:猫狗识别 Cat VS Dog

    小试牛刀:用猫狗识别来练练手(只用了 10% 数据来训练) ps: 猫狗识别拿来做新算/想法的尝试,也是一个挺不错的选择 怎么能少了实现代码:https://github.com/Azure-Sky- ...

  3. 【卷积神经网络】CNN详解以及猫狗识别实例

    文章目录 一.卷积神经网络(CNN)介绍 1.1 整体结构 1.2 说明 1.3 特点 1.4 应用领域 二.配置实验环境 三.猫狗识别实例 3.1 准备数据集 3.2 图片分类 3.3 网络模型搭建 ...

  4. 华为云深度学习kaggle猫狗识别

    使用华为云深度学习服务完成kaggle猫狗识别竞赛 参考: kaggle猫狗竞赛kernel第一名的代码 Tensorflow官网代码 华为云DLS服务github代码 1. 环境配置与数据集处理 首 ...

  5. 详解pytorch实现猫狗识别98%附代码

    详解pytorch实现猫狗识别98%附代码 前言 一.为什么选用pytorch这个框架? 二.实现效果 三.神经网络从头到尾 1.来源:仿照人为处理图片的流程,模拟人们的神经元处理信息的方式 2.总览 ...

  6. tensorflow2.3.0迁移学习案例分析(以猫狗识别为例)

    我对迁移学习的简单理解就是,将别人训练好的模型用到自己的程序中,同时根据实际情况,重新训练模型中的部分参数.我认为这样有2个好处,一是由于使用了已知的模型,那么节省训练的时间,二是在充分利用已知的成果 ...

  7. Tensorflow实现kaggle猫狗识别(循序渐进进行网络设计)

    这篇是tensorflow版本,pytorch版本会在下一篇博客给出 友情提示:尽量上GPU,博主CPU上跑一个VGG16花了1.5h... Tensorflow实现kaggle猫狗识别 数据集获取 ...

  8. 基于卷积神经网络VGG的猫狗识别

    !有需要本项目的实验源码的可以私信博主! 摘要:随着大数据时代的到来,深度学习.数据挖掘.图像处理等已经成为了一个热门研究方向.深度学习是一个复杂的机器学习算法,在语音和图像识别方面取得的效果,远远超 ...

  9. 猫狗识别——PyTorch

    猫狗识别 数据集下载: 网盘链接:https://pan.baidu.com/s/1SlNAPf3NbgPyf93XluM7Fg 提取密码:hpn4 1. 要导入的包 import os import ...

最新文章

  1. 字符串算法--KMP--Java实现
  2. 数据解读:资本追逐的14个人工智能细分领域
  3. 1044 拦截导弹——http://codevs.cn/problem/1044/
  4. 计算机网络期中考试题周静,期中考试优秀作文
  5. JQuery学习使用笔记 -- JQuery插件开发
  6. centos7 + VMware Workstation Pro
  7. JavaBean 持久化
  8. console对象的方法log、info、warn、error的区别及几个实用的方法
  9. python自动测试v_python下selenium自动化测试自我实践
  10. Task Parallel Library
  11. 筛选数据库_网络药理学(2)| 使用TCMSP数据库检索中药成分并基于ADME参数进行成分筛选...
  12. 最强的linux命令总结.pdf
  13. 用python处理mp4与gif格式互转,简单到爆!
  14. 在HTTPS网站安装百度分享按钮代码及其SEO外链作用
  15. 「视频直播技术详解」系列之六:现代播放器原理
  16. php7.4新特性 多线程,PHP7新特性WhatwillbeinPHP7/PHPNG
  17. 背景图全屏水平垂直居中的方法
  18. checkstyle + gradle + git pre-commit 实现代码提交前对代码规范的检查
  19. 关于全球变暖,你应该知道的事实
  20. 【数据分析】基于matlab GUI学生成绩管理系统【含Matlab源码 1981期】

热门文章

  1. 大数据工程师岗位工作内容是什么
  2. 34岁程序员面试美团被拒绝:只招30岁以下,卖力能加班工资又少的
  3. js获取当前时间并转换为一定的格式
  4. IDEA中maven项目的language level 修改后自动重置问题
  5. sim_com AT
  6. 【游戏设计模式】之三 状态模式、有限状态机
  7. 赏析角度有哪些_从“视听语言”角度,主要从哪些方面进行影视赏析(? ?)。...
  8. 什么是zigbee?
  9. 谷歌网站错误服务器连接,谷歌浏览器 您的链接存在安全隐患 此网站是用的安全配置已过时...---服务器 TLS1.0 1TLS.2配置方法...
  10. RPA应用场景-交通违章查询