VGG16的猫狗识别
简单步骤
加载一张照片
维度转换为224*224
转换为array
from keras.preprocessing.image import load_img,img_to_arrayimg_path = '1.jpg'
img = load_img(img_path,target_size=(224,224))
img = img_to_array(img)
type(img)
numpy.ndarray
获取vgg16 主要卷积层 不要后面的全连接层 (自己写)
改变输入数据的维度
图像预处理
from keras.applications.vgg16 import VGG16
from keras.applications.vgg16 import preprocess_input
import numpy as np
model_vgg = VGG16(weights='imagenet',include_top=False)
x = np.expand_dims(img,axis=0)
x = preprocess_input(x)
print(x.shape)
(1, 224, 224, 3)
#特征提取
features = model_vgg.predict(x)`在这里插入代码片`
print(features.shape)
(1, 7, 7, 512)
#全连接层准备 改变数据维度
features = features.reshape(1,7*7*512)
print(features.shape)
(1, 25088)
整体处理
from keras.preprocessing.image import img_to_array,load_img
from keras.applications.vgg16 import VGG16
from keras.applications.vgg16 import preprocess_input
import numpy as np
# 数据预处理和模型加载
model_vgg = VGG16(weights='imagenet', include_top=False)
#define a method to load and preprocess the image
def modelProcess(img_path,model):img = load_img(img_path, target_size=(224, 224))img = img_to_array(img)x = np.expand_dims(img,axis=0)x = preprocess_input(x)x_vgg = model.predict(x)x_vgg = x_vgg.reshape(1,25088)return x_vgg
#list file names of the training datasets
import os
folder = "dataset/data_vgg/cats"
dirs = os.listdir(folder)
#generate path for the images
img_path = []
for i in dirs: if os.path.splitext(i)[1] == ".jpg": img_path.append(i)
img_path = [folder+"//"+i for i in img_path]#preprocess multiple images
features1 = np.zeros([len(img_path),25088])
for i in range(len(img_path)):feature_i = modelProcess(img_path[i],model_vgg)print('preprocessed:',img_path[i])features1[i] = feature_ifolder = "dataset/data_vgg/dogs"
dirs = os.listdir(folder)
img_path = []
for i in dirs: if os.path.splitext(i)[1] == ".jpg": img_path.append(i)
img_path = [folder+"//"+i for i in img_path]
features2 = np.zeros([len(img_path),25088])
for i in range(len(img_path)):feature_i = modelProcess(img_path[i],model_vgg)print('preprocessed:',img_path[i])features2[i] = feature_i#label the results
print(features1.shape,features2.shape)
y1 = np.zeros(300)
y2 = np.ones(300)#generate the training data
X = np.concatenate((features1,features2),axis=0)
y = np.concatenate((y1,y2),axis=0)
y = y.reshape(-1,1)
print(X.shape,y.shape)from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.3,random_state=50)
print(X_train.shape,X_test.shape,X.shape)
from keras.models import Sequential
from keras.layers import Dense
# 构建模型 两个全连接层 25088-10 10-1 (1就是二分类)
model = Sequential()
model.add(Dense(units=10,activation='relu',input_dim=25088))
model.add(Dense(units=1,activation='sigmoid'))
model.summary()
#定义loss opt acc基本参数
odel.compile(optimizer='adam',loss='binary_crossentropy',metrics=['accuracy'])
#train the model
model.fit(X_train,y_train,epochs=50)
#测试准确率
from sklearn.metrics import accuracy_score
y_train_predict = model.predict_classes(X_train)
accuracy_train = accuracy_score(y_train,y_train_predict)
print(accuracy_train)
y_test_predict = model.predict_classes(X_test)
accuracy_test = accuracy_score(y_test,y_test_predict)
print(accuracy_test)
# 找个网图测试
img_path = 'cat1.jpg'
img = load_img(img_path,target_size=(224,224))
img = img_to_array(img)
x = np.expand_dims(img,axis=0)
x = preprocess_input(x)
features = model_vgg.predict(x)
features = features.reshape(1,7*7*512)
result = model.predict_classes(features)
print(result)
可视化测试 找10个图片名字是1-10
import matplotlib as mlp
font2 = {'family' : 'SimHei',
'weight' : 'normal',
'size' : 20,
}
mlp.rcParams['font.family'] = 'SimHei'
mlp.rcParams['axes.unicode_minus'] = False
from matplotlib import pyplot as plt
from matplotlib.image import imread
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array
from keras.models import load_model
#from cv2 import load_img
a = [i for i in range(1,10)]
fig = plt.figure(figsize=(10,10))
for i in a:img_name = str(i)+'.jpg'img_path = img_nameimg = load_img(img_path, target_size=(224, 224))img = img_to_array(img)x = np.expand_dims(img,axis=0)x = preprocess_input(x)x_vgg = model_vgg.predict(x)x_vgg = x_vgg.reshape(1,25088)result = model.predict_classes(x_vgg)img_ori = load_img(img_name, target_size=(250, 250))plt.subplot(3,3,i)plt.imshow(img_ori)plt.title('预测为:狗狗' if result[0][0] == 1 else '预测为:猫咪')
plt.show()
结果如图
VGG16的猫狗识别相关推荐
- 深度学习100例-卷积神经网络(VGG-16)猫狗识别 | 第21天
最近更新有点慢,后台收到不少小伙伴的催更,先说声抱歉哈.最近在参加一个目标检测的比赛,时间比较紧张.这段时间我也打算调整一下思路,试着将目标检测中涉及的内容拆开来,将这些拆分的内容一点点融入到在后续的 ...
- 小试牛刀:猫狗识别 Cat VS Dog
小试牛刀:用猫狗识别来练练手(只用了 10% 数据来训练) ps: 猫狗识别拿来做新算/想法的尝试,也是一个挺不错的选择 怎么能少了实现代码:https://github.com/Azure-Sky- ...
- 【卷积神经网络】CNN详解以及猫狗识别实例
文章目录 一.卷积神经网络(CNN)介绍 1.1 整体结构 1.2 说明 1.3 特点 1.4 应用领域 二.配置实验环境 三.猫狗识别实例 3.1 准备数据集 3.2 图片分类 3.3 网络模型搭建 ...
- 华为云深度学习kaggle猫狗识别
使用华为云深度学习服务完成kaggle猫狗识别竞赛 参考: kaggle猫狗竞赛kernel第一名的代码 Tensorflow官网代码 华为云DLS服务github代码 1. 环境配置与数据集处理 首 ...
- 详解pytorch实现猫狗识别98%附代码
详解pytorch实现猫狗识别98%附代码 前言 一.为什么选用pytorch这个框架? 二.实现效果 三.神经网络从头到尾 1.来源:仿照人为处理图片的流程,模拟人们的神经元处理信息的方式 2.总览 ...
- tensorflow2.3.0迁移学习案例分析(以猫狗识别为例)
我对迁移学习的简单理解就是,将别人训练好的模型用到自己的程序中,同时根据实际情况,重新训练模型中的部分参数.我认为这样有2个好处,一是由于使用了已知的模型,那么节省训练的时间,二是在充分利用已知的成果 ...
- Tensorflow实现kaggle猫狗识别(循序渐进进行网络设计)
这篇是tensorflow版本,pytorch版本会在下一篇博客给出 友情提示:尽量上GPU,博主CPU上跑一个VGG16花了1.5h... Tensorflow实现kaggle猫狗识别 数据集获取 ...
- 基于卷积神经网络VGG的猫狗识别
!有需要本项目的实验源码的可以私信博主! 摘要:随着大数据时代的到来,深度学习.数据挖掘.图像处理等已经成为了一个热门研究方向.深度学习是一个复杂的机器学习算法,在语音和图像识别方面取得的效果,远远超 ...
- 猫狗识别——PyTorch
猫狗识别 数据集下载: 网盘链接:https://pan.baidu.com/s/1SlNAPf3NbgPyf93XluM7Fg 提取密码:hpn4 1. 要导入的包 import os import ...
最新文章
- 字符串算法--KMP--Java实现
- 数据解读:资本追逐的14个人工智能细分领域
- 1044 拦截导弹——http://codevs.cn/problem/1044/
- 计算机网络期中考试题周静,期中考试优秀作文
- JQuery学习使用笔记 -- JQuery插件开发
- centos7 + VMware Workstation Pro
- JavaBean 持久化
- console对象的方法log、info、warn、error的区别及几个实用的方法
- python自动测试v_python下selenium自动化测试自我实践
- Task Parallel Library
- 筛选数据库_网络药理学(2)| 使用TCMSP数据库检索中药成分并基于ADME参数进行成分筛选...
- 最强的linux命令总结.pdf
- 用python处理mp4与gif格式互转,简单到爆!
- 在HTTPS网站安装百度分享按钮代码及其SEO外链作用
- 「视频直播技术详解」系列之六:现代播放器原理
- php7.4新特性 多线程,PHP7新特性WhatwillbeinPHP7/PHPNG
- 背景图全屏水平垂直居中的方法
- checkstyle + gradle + git pre-commit 实现代码提交前对代码规范的检查
- 关于全球变暖,你应该知道的事实
- 【数据分析】基于matlab GUI学生成绩管理系统【含Matlab源码 1981期】
热门文章
- 大数据工程师岗位工作内容是什么
- 34岁程序员面试美团被拒绝:只招30岁以下,卖力能加班工资又少的
- js获取当前时间并转换为一定的格式
- IDEA中maven项目的language level 修改后自动重置问题
- sim_com AT
- 【游戏设计模式】之三 状态模式、有限状态机
- 赏析角度有哪些_从“视听语言”角度,主要从哪些方面进行影视赏析(? ?)。...
- 什么是zigbee?
- 谷歌网站错误服务器连接,谷歌浏览器 您的链接存在安全隐患 此网站是用的安全配置已过时...---服务器 TLS1.0 1TLS.2配置方法...
- RPA应用场景-交通违章查询