这个项目我们要做一个识别猫狗的模型,这和上次的数字识别一样,也是运用深度学习,不过这次模型较为复杂,我们会用到迁移学习,站在巨人的肩膀上,借用大佬们已经训练好的模型来搭建我们自己的模型并让它做我们想做的事。

安装要求Python3

Numpy

Scipy

matplotlib

tensorflow

keras

opencv

数据预处理

def make_label(file_name):

label = file_name.split('.')[0]

##one-hot-encoding

if label == 'cat':

return [0]

elif label == 'dog':

return [1]

def make_data(img_path,img_size):

path_length = len(os.listdir(img_path))

images = np.zeros((path_length,img_size,img_size, 3), dtype=np.uint8)

labels = np.zeros((path_length,1),dtype=np.float32)

count = 0

for file_name in os.listdir(img_path):

labels[count] = make_label(file_name)

images[count] = cv2.resize(cv2.imread(img_path+'/'+file_name),(img_size,img_size))

b,g,r = cv2.split(images[count]) # get b,g,r

images[count] = cv2.merge([r,g,b]) # switch it to rgb

count+=1

##shuffle

p = np.random.permutation(path_length)

images,labels = images[p],labels[p]

return images,labels

(猫的标签为0.,狗的标签为1.)

模型基于VGG16的模型

input = Input(shape=(img_size, img_size, 3))

base_model = VGG16(weights='imagenet', input_tensor=input,include_top=False)

x = Flatten()(base_model.output)

x = Dense(2048, activation='relu')(x)

x = Dense(1024, activation='relu')(x)

x = Dropout(0.7)(x)

output = Dense(1, activation='sigmoid')(x)

model = Model(input=input, output=output)训练

from keras.callbacks import TensorBoard

from keras.optimizers import SGD

for layer in model.layers[:19]:

layer.trainable = False

opt = SGD(lr=0.0001, momentum=0.9)

model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])

model.fit(train_img,train_label,validation_split=0.2, callbacks=[TensorBoard(log_dir='./log')])

model.save('model.h5')

基于Xception的模型

train_img, train_label = make_data(train_path,299) ## Xception要求的shape为299*299

from keras.applications.xception import Xception

from keras.callbacks import TensorBoard

from keras.optimizers import SGD

input = Input(shape=(img_size, img_size, 3))

base_model_2 = Xception(weights='imagenet', input_tensor=input,include_top=False)

x = Flatten()(base_model_2.output)

#x = Dense(2048, activation='relu')(x)

x = Dense(512, activation='relu')(x)

x = Dropout(0.85)(x)

output = Dense(1, activation='sigmoid')(x)

model_2 = Model(input=input, output=output)

for layer in model_2.layers[:132]: ## Xception除了top的全连接还有132层

layer.trainable = False

opt = SGD(lr=0.0001, momentum=0.9)

model_2.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])

model_2.fit(train_img,train_label,validation_split=0.2, batch_size=10, callbacks=[TensorBoard(log_dir='./log')])

model_2.save('model_2.h5')

上面是Xception模型结构,较为复杂,我只更改了它的全连接层。

Xception相比VGG16更为庞大和复杂,当然效果也更好。预测可视化

数据提升

from scipy.ndimage.interpolation import shift

def img_change_brightness(img):

# Convert the image to HSV

img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)

# Compute a random brightness value and apply to the image

brightness = np.random.uniform(0.25,1) ##调整范围在0.25到1之间

img[:, :, 2] = img[:, :, 2] * brightness

# Convert back to RGB

return cv2.cvtColor(img, cv2.COLOR_HSV2BGR)

index = random.randint(0,len(train_images))

for i in range(4):

plt.subplot(2,2,i+1)

img = [train_images[index], #原图

np.flipud(train_images[index]), #上下翻转

np.fliplr(train_images[index]), #左右翻转

img_change_brightness(train_images[index])] #亮度调整

plt.imshow(img[i])

plt.axis('off')

#随机将25%的训练数据进行亮度调整

for i in range(int(len(train_images)*0.25)):

index = random.randint(0,len(train_images))

train_images[index] = img_change_brightness(train_images[index])

#随机将25%的训练数据进行左右翻转

for i in range(int(len(train_images)*0.25)):

index = random.randint(0,len(train_images))

train_images[index] = np.fliplr(train_images[index])

#随机将25%的训练数据进行上下翻转

for i in range(int(len(train_images)*0.25)):

index = random.randint(0,len(train_images))

train_images[index] = np.flipud(train_images[index])

训练第n次

from keras.callbacks import TensorBoard

from keras.models import load_model

model_3 = load_model('model_2.h5')

model_3.fit(train_images,train_labels,validation_data=(valid_images,valid_labels), batch_size=16, callbacks=[TensorBoard(log_dir='./log')])

model_3.save('model_3.h5')

我们可以直接用上面已经构建好的模型,直接载入新数据来训练

模型效果比较

VGG16

Xception

数据提示后的Xception

结尾

文章代码/数据地址ciozhang

猫狗大战 python_猫狗大战相关推荐

  1. Pytorch系列(四):猫狗大战1-训练和测试自己的数据集

    Pytorch猫狗大战系列: 猫狗大战1-训练和测试自己的数据集 猫狗大战2-AlexNet 猫狗大战3-MobileNet_V1&V2 猫狗大战3-MobileNet_V3 TensorFl ...

  2. 03_基于CNN的猫狗大战实现

    文章目录 猫狗大战背景介绍 代码示例 step1 对模型的修改 step2 数据的输入 step3 模型的重新训练与存储 step4 模型的复用 猫狗大战背景介绍 猫狗大战数据集来源于Kaggle上的 ...

  3. 深度学习之基于AlexNet实现猫狗大战

    这次实验的主角并不是猫狗大战,而是AlexNet网络,只不过数据集为猫狗大战数据集.本次实验利用自己搭建的AlexNet网络实现猫狗大战,测试一下AlexNet网络的性能. AlexNet网络作为Le ...

  4. 基于tensorflow2.0实现猫狗大战(搭建网络迁移学习)

    猫狗大战是kaggle平台上的一个比赛,用于实现猫和狗的二分类问题.最近在学卷积神经网络,所以自己动手搭建了几层网络进行训练,然后利用迁移学习把别人训练好的模型直接应用于猫狗分类这个数据集,比较一下实 ...

  5. 使用猫狗大战数据集进行一次完整的TensorFlow训练

    1.简介 一直想将图片制作成tfrecords文件,然后在模型中运行一下.最初想用的数据集是mnist,但是跑的过程中一直出现问题.找到这一篇知乎上的博客,写的非常不错. 原博客地址:https:// ...

  6. 第四次作业:猫狗大战挑战赛

    文章目录 1. 导入需要的包,检查使用设备 2. 导入数据集并修改数据集目录结构 3. 数据处理 4. 创建VGG Model 5. 修改最后一层,冻结前面层的参数 6. 训练并测试全连接层 7.可视 ...

  7. python猫狗大战pytorch_深度学习实战---猫狗大战(pytorch实现)

    数据准备 猫狗大战数据集下载链接 微软的数据集已经分好类,直接使用就行, 数据划分 我们将猫和狗的图片分别移动到训练集和验证集中,其中90%的数据作为训练集,10%的图片作为验证集,使用shutil. ...

  8. vijos1153猫狗大战

    新一年度的猫狗大战通过SC(星际争霸)这款经典的游戏来较量,野猫和飞狗这对冤家为此已经准备好久了,为了使战争更有难度和戏剧性,双方约定只能选择Terran(人族)并且只能造机枪兵. 比赛开始了,很快, ...

  9. 猫狗大战——基于TensorFlow的猫狗识别(2)

    微信公众号:龙跃十二 我是小玉,一个平平无奇的小天才! 上篇文章我们说了关于猫狗大战这个项目的一些准备工作,接下来,我们看看具体的代码详解. 猫狗大战--基于TensorFlow的猫狗识别(1) 文件 ...

最新文章

  1. Groovy学习专栏
  2. 3句话概括 PUT/POST 的区别
  3. 2.4 嵌入矩阵-深度学习第五课《序列模型》-Stanford吴恩达教授
  4. angularJS 表单验证
  5. 奥拉星插件flash下载手机版下载安装_终于等到你!安卓微信7.0.13内测版发布 支持夜间模式 附下载地址!...
  6. 利用云数据库 MongoDB ,为你的业务创建单节点实例
  7. 微服务的好处与弊端_《微服务架构设计模式》-学习总结07
  8. c++数据结构中 顺序队列的队首队尾_用队列实现栈,用栈实现队列,听起来有点绕,都搞懂了就掌握了精髓
  9. 「leetcode」700. 二叉搜索树中的搜索:【递归法】【迭代法】详解
  10. 5.7 C和C++的关系
  11. 更改windows 2003 密钥
  12. mysql命令创库创表_创库+表的操作
  13. 互联网晚报 | 10月24日 星期日 | 华为鸿蒙生态建设投入已超500亿;瑞幸门店端扭亏为盈;文旅部要求暂停经营旅游专列业务...
  14. python Django音乐推荐系统
  15. KM算法matlab实现
  16. poj 1205 :Water Treatment Plants (DP+高精度)
  17. 【Akka】Akka Actor生命周期
  18. webgl-画三角形
  19. PCL点云库必备知识点4——pointcloud2消息格式的转换
  20. 考研英语作文押题---垃圾分类

热门文章

  1. 2012年10月3日
  2. 系统分析技术简单介绍
  3. 简单工厂模式(代码实现)
  4. pta-7-2 最大公约数与最小公倍数 (15 分)
  5. opencv3.4.1: ippicv_2017u3_lnx_intel64_20170822.tgz下载包
  6. SpringBoot整合:Druid、MyBatis、MyBatis-Plus、多数据源、knife4j、日志、Redis,Redis的Java操作工具类、封装发送电子邮件等等
  7. android控制台没有报出错误,小弟我有意制造了一个错误,但是它却不在Console控制台显示啊100分...
  8. 发布使用Windows Media Format 9 Series SDK 开发的程序
  9. Idea——Tomcat和Maven使用 报错——Warning: No artifacts configured 解决方法
  10. 我的windows学习心得