卷积神经网络的基本结构和原理我们已有介绍,在这一节里将使用 TensorFlow 搭建一个简单的卷积神经网络,实现图像分类。

这里我们要解决的任务是来自于 Kaggle 上的一道赛题:在加拿大的东海岸经常会有漂流的冰山,对航行在该海域的船舶造成了很大的威胁。挪威国家石油公司(Statoil)是一家在全球运营的国际能源公司,该公司曾与 C-CORE 等公司合作,C-CORE 基于其卫星数据和计算机视觉技术建立了一个监控系统。

Statoil 发布该赛题的目的是希望利用机器学习的技术,更准确地及早发现和识别出威胁船舶航行的冰山。

1. 数据介绍

赛题提供了两个数据文件“train.json”和“test.json”,其中“test.json”是比赛中用来对模型进行评分的,没有类标,这里我们只需要使用“train.json”文件。该数据集中有 1604 个打标过的训练数据,单个样本的数据格式如表 1 所示。

表 1:单个样本的数据格式
字段名 字段说明
id 图像的 id
band_1 ,band_2 卫星图像数据,band_1 和 band_2 是以特定入射角下不同极化方式产生的雷达后向散射为特征的信号, 分别对应 HH (水平发射/水平接收)和 HV (水平发射/垂直接收)两种极化方式的数据,其大小均为 75x75
inc_angle 获得该数据时的入射角度。该字段部分缺少数据,标记为“na”
is_iceberg

类标,0 代表船只,1 代表冰山

我们将数据可视化后进行观察,如图 1 所示。

图 1:训练数据可视化效果(易观察)

图像上方是冰山图像的可视化效果,三幅图分别对应“HH”极化方式、“HV”极化方式,以及两者结合后的数据。图像下方是船只图像的可视化效果。

图 1 中的冰山和船只,通过观察可以较为容易地区分出来,但是还有很多如图 2 所示的数据,即使仔细观察也很难区分开来。

图 2:训练数据可视化效果(不易观察)

2. 数据预处理

首先导入需要的包:

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers

接下来定义一个数据预处理的函数:

def data_preprocess(path, more_data):#读取数据data_frame = pd.read_json(path)#获取图像数据images =[]for _, row in data_frame.iterrows():#将一维数据转换另75x75的二维数据band_1 = np.array(row['band_1']).reshape(75, 75)band_2 = np.array(row['band_2 ']).reshape(75f 75)band_3 = band_1 + band_2images.append(np.dstack((band_1, band_2, band_3)))if more_data:#扩充数据集images = create_more_data(np.array(images))#获取类标labels = np.array(data_frame['is_iceberg'])if more_data:#扩充数据集后,类标也需要相应扩充labels = np.concatenate((labels, labels, labels, labels, labels, labels))return np.array(images), labels

“data_preprocess”函数接收两个参数:“path”为训练数据“train.json”的文件路径;“more_data”为布尔类型,当其为True 时,会调用函数“create_more_data”进行训练数据的扩充(即数据增强)。

第 7 行到第 13 行代码对样本数据进行处理,除了原有的“band_1”和“band_2”,我们增加了“band_3”,band_3=band_1+band_2。最后使用 NumPy 的“dstack”函数将三种数据进行堆叠,因此我们单个样本的数据维度为75×75×3。

第 16 行代码调用“create_more_data”函数对训练数据进行扩充,第 22 行代码对训练集的类标数据进行扩充,因为“create_more_data”函数将训练数据扩充为了原来的 6 倍,所以这里也要对应地将类标扩充为原来的 6 倍。

“create_more_data”函数的实现如下:

def create_more_data(images):#通过旋转、翻转扩充数据image_rot90 = []image_rot180 = []image_rot270 = []img_lr = []img_ud = []

在“create_more_data”函数中,我们通过对图片进行旋转和翻转来扩充数据集,虽然旋转前后的图片是同一张,但是由于特征的位置发生了变化,因此对于模型来说就是不同的数据,旋转或翻转操作是扩充图像数据集的一个简单有效的方法。

在第 3 行至第 7 行代码中,定义了 5 个列表,用来保存扩充的数据集,对应的操作分别是逆时针旋转 90°、逆时针旋转 180°、逆时针旋转 270°、左右翻转和上下翻转。具体实现如下:

for i in range(0, images.shape[0]):band_1 = images[i, :, :, 0 ]band_2 = images[i, :, :, 1 ]band_3 = images[i, :, :, 2]#旋转90°band_1_rot90 = np.rot90(band_1)band_2_rot90 = np.rot90(band_2)band_3_rot90 = np.rot90(band_3)image_rot90.append(np.dstack((band_1_rot90, band_2_rot90, band_3_rot90)))#旋转180°band_1_rot180=np.rot90(band_1_rot90)band_2_rot180=np.rot90(band_2_rot90)band_3_rot180 = np.rot90(band_3_rot90)image_rot180.append(np.dstack((band_1_rot180, band_2_rot180, band_3_rot180)))#旋转270°band_1_rot270 = np.rot90(band_1_rot180)band_2_rot270 = np.rot90(band_2_rot180)   band_3_rot270 = np.rot90(band_3_rot180)image_rot270.append(np.dstack((band_1_rot270, band_2_rot270, band_3_rot270)))#左右翻转lr1 = np.flip(band_1, 0)lr2 = np.flip(band_2, 0)lr3 = np.flip(band_3, 0)img_lr.append(np.dstack((lrlf lr2, lr3)))#上下翻转ud1 = np.flip(band_1, 1)ud2 = np.flip(band_2, 1)ud3 = np.flip(band_3, 1)img_ud.append(np.dstack((ud1, ud2, ud3)))

上面的代码中,我们使用 NumPy 的“rot90”和“flip”函数对图片进行旋转和翻转操作。“flip”函数的第二个参数控制翻转的方式,“0”为左右翻转,“1”为上下翻转。

rot90 = np.array(image_rot90)
rot180 = np.array(image_rot180)
rot270 = np.array(image_rot270)
lr = np.array(img_lr)
ud = np.array(img_ud)
images = np.concatenate((image,rot90, rot180, rot270, lr, ud))return images

第 6 行代码使用 NumPy 的“concatenate”函数将扩充的数据与原数据进行拼接。

3. 模型搭建

接下来使用 TensorFlow 的高级 API 来搭建模型。

#定义模型
def get_model():#建立一个序贯模型model = tf.keras.Sequential()#第一个卷积块model.add(layers.Conv2D(128, kernel_size=(3, 3), activation= 'relu', input_shape=(75, 75, 3)))model.add(layers.MaxPooling2D(pool_size=(3, 3), strides=(2, 2)))model.add(layers.Dropout(0.2))#第二个卷积块model.add(layers.Conv2D(128, kernel_size=(3, 3), activation= 'relu'))model.add(layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))model.add(layers.Dropout(0.2))#第三个卷积块model.add(layers.Conv2D(64, kernel_size=(2, 2), activation='relu'))model.add(layers.MaxPooling2D(pool_size=(3, 3), strides=(2, 2)))model.add(layers.Dropout(0.2))#第四个卷积块model.add(layers.Conv2D(64, kernel_size={2, 2), activation= 'relu'))model.add(layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))model.add(layers.Dropout(0.2))#将上一层的输出特征映射转化为一维数据,以便进行全连接操作model.add(layers.Flatten())#第一个全连接层model.add(layers.Dense(256))model.add(layers.Activation('relu'))model.add(layers.Dropout(0.2))#第二个全连接层model.add(layers.Dense(128))model.add(layers.Activation('relu'))model.add(layers.Dropout(0.2))#第三个全连接层model.add(layers.Dense(1))model.add(layers.Activation('sigmoid'))#编译模型model.compile(loss= 'binary_crossentropy', optimizer=tf.keras.optimizers.Adam(0.0001), metrics=['accuracy'])#打印出模型的概况信息model.summary()return model

在第 4 行代码中,使用“tf.keras.Sequential()”创建一个序贯模型,序贯模型是多个网络层的线性堆叠,使用“tf.keras.Sequential().add()”方法逐层添加网络结构。第 7 行到第 9 行代码是第一个卷积块,这里使用了 128 个大小为 3×3 的卷积核,以 ReLU 为激活函数。

在卷积层后面是一个池化层,采用最大池化,池化窗口的大小为 3×3,横向和纵向的步长都为 2。在池化层的后面进行 Dropout 操作,丢弃了 20% 的神经元,防止参数过多导致过拟合。接下来是三个类似的卷积块。

第 27 行代码使用“Flatten()”将前一层网络的输出转换为了一维的数据,这是为接下来的全连接操作做准备。第 30 行代码是第一个全连接层,有 256 个神经元,全连接层后面接 ReLU 激活函数,同样进行 Dropout 操作。第 35 行至第 37 行代码是类似的全连接层部分。

由于是二分类问题,第 40 行至第 41 行代码使用了一个只有一个神经元的全连接层,并使用了 Sigmoid 激活函数,得到最终的输出。

第 44 行代码使用“compile”编译模型,其中“loss='binary_crossentropy'”指明使用的是对数损失函数,通过“optimizers”参数设置使用Adam 优化器,设置学习率为 0.0001。“metrics” 列表包含评估模型在训练和测试时的性能指标,这里设置了“metrics= ['accuracy']”,则在训练的过程中,训练集和验证集上的准确率都会打印出来。

第 47 行代码使用了“summary()”函数,训练开始后终端会打印出模型的概况信息,如图 3 所示,其中包含了网络的结构,以及每层的参数数量等信息,最后一行显示出,总的训练数据为 7699 条,验证集的数据量为 1925 条。

图 3:模型的概况信息

4. 结果分析

接下来读取数据,并训练模型:

#数据预处理
train_x, train_y = data_preprocess('./data/train.json', more_data=True)
#初始化模型
cnn_model = get_model()
#模型训练
cnn_model. fit (train_x, train_y, batch_size=25, epochs=100, verbose=1, validation_split=0.2)

第 2 行代码调用“data_preprocess”函数获取预处理后的训练数据,将“more_data”设置为“True”进行数据扩充。第 8 行代码调用“fit”方法开始模型的训练,通过“batch_size”设置每个批次训练 25 条数据,通过“epochs”设置训练的总回合数为“100”。

通过设置“verbose”为“1”,在终端上显示训练的进度。通过设置“validation_split”为“0.2”,将训练集一分为二,其中 80% 作为训练集,20% 作为验证集。

模型的训练过程和结果如图 4 所示。

图 4:模型的训练过程和结果

卷积神经网络实现图像分类相关推荐

  1. pytorch1.7教程实验——迁移学习训练卷积神经网络进行图像分类

    只是贴上跑通的代码以供参考学习 参考网址:迁移学习训练卷积神经网络进行图像分类 需要用到的数据集下载网址: https://download.pytorch.org/tutorial/hymenopt ...

  2. 图像处理神经网络python_深度学习使用Python进行卷积神经网络的图像分类教程

    深度学习使用Python进行卷积神经网络的图像分类教程 好的,这次我将使用python编写如何使用卷积神经网络(CNN)进行图像分类.我希望你事先已经阅读并理解了卷积神经网络(CNN)的基本概念,这里 ...

  3. 卷积神经网络和图像分类识别

    Andrew Kirillov 著 Conmajia 译 2019 年 1 月 15 日 原文发表于 CodeProject(2018 年 10 月 28 日). 中文版有小幅修改,已获作者本人授权. ...

  4. 【神经网络与深度学习】CIFAR10数据集介绍,并使用卷积神经网络训练图像分类模型——[附完整训练代码]

    [神经网络与深度学习]CIFAR-10数据集介绍,并使用卷积神经网络训练模型--[附完整代码] 一.CIFAR-10数据集介绍 1.1 CIFAR-10数据集的内容 1.2 CIFAR-10数据集的结 ...

  5. 轻松学Pytorch-使用卷积神经网络实现图像分类

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达本文转自|人工智能与算法学习 大家好,本篇教程的贡献者来自社区投稿作 ...

  6. Bag of Tricks for Image Classification with Convolutional Neural Networks(卷积神经网络在图像分类中的技巧)

    来源:Tong He Zhi Zhang Hang Zhang Zhongyue Zhang Junyuan Xie Mu L Amazon Web Services fhtong,zhiz,hzaw ...

  7. 经典卷积神经网络(CNN)图像分类算法详解

    本文原创,转载请引用 https://blog.csdn.net/dan_teng/article/details/87192430 CNN图像分类网络 一点废话:CNN网络主要特点是使用卷积层,这其 ...

  8. 在PyTorch中使用卷积神经网络建立图像分类模型

    概述 在PyTorch中构建自己的卷积神经网络(CNN)的实践教程 我们将研究一个图像分类问题--CNN的一个经典和广泛使用的应用 我们将以实用的格式介绍深度学习概念 介绍 我被神经网络的力量和能力所 ...

  9. PyTorch实战福利从入门到精通之四——卷积神经网络CIFAR-10图像分类

    在本教程中,我们将使用CIFAR10数据集.它有类别:"飞机"."汽车"."鸟"."猫"."鹿".& ...

最新文章

  1. IT人面试必看!25个雷区和11个必问句!
  2. STL之vector,数组线性容器array,list容器,算法find,find_if,bind1st,仿函数
  3. mysql集群安装指南
  4. python整数池_对Python中小整数对象池和大整数对象池的使用详解
  5. 基于JAVA+SpringMVC+MYSQL的排班管理系统
  6. Server 2016DC查看五种AD角色的方法
  7. ×××S 2012 高级图表类型 -- 小面积扇形处理
  8. Ubuntu 11.10为何值得我们期待?
  9. java的inputbox,解释InputBox函数(输入函数)的用途
  10. 【网易云音乐】浏览器控制台抓包(纯前端)
  11. nifty_Nifty JUnit:在方法和类级别上使用规则
  12. 二、C语言基本数据类型全解和基本的数据存储方式
  13. 台式机通过网线连接笔记本上网
  14. 2016第1篇--Python查看微信被删好友
  15. 为什么任何数的0次方都是1
  16. iptables -j MARK --set-xmark 解析
  17. 一个int类型到底占多少个字节?
  18. cout保留两位小数位
  19. memcmp性能测试
  20. Mac查看端口占用情况

热门文章

  1. MAC免费解压软件——解压RAR、7Zip等五六十种格式
  2. linux centos rar解压,Centos解压rar压缩文件
  3. 但总觉得明白了一点点什么
  4. 网络摄像头无插件直播H265编码视频播放器EasyPlayer网页播放器不能播放怎么处理?
  5. 互联网无插件直播流媒体服务器方案EasyNVR下载新的软件执行程序,出现“invalid license”字样是什么意思?
  6. 沧海一声笑(最好版)
  7. allegro如何删除没有网络的走线,查还没有连的网络线
  8. ie显示服务器拒接链接,IE浏览器拒接访问是怎么回事 IE浏览器显示拒接访问的有效解决方法...
  9. 乘风破浪的码农——仿佛身体被掏空
  10. java开发规划_java开发程序员职业发展规划路线