卷积神经网络实现图像分类
卷积神经网络的基本结构和原理我们已有介绍,在这一节里将使用 TensorFlow 搭建一个简单的卷积神经网络,实现图像分类。
这里我们要解决的任务是来自于 Kaggle 上的一道赛题:在加拿大的东海岸经常会有漂流的冰山,对航行在该海域的船舶造成了很大的威胁。挪威国家石油公司(Statoil)是一家在全球运营的国际能源公司,该公司曾与 C-CORE 等公司合作,C-CORE 基于其卫星数据和计算机视觉技术建立了一个监控系统。
Statoil 发布该赛题的目的是希望利用机器学习的技术,更准确地及早发现和识别出威胁船舶航行的冰山。
1. 数据介绍
赛题提供了两个数据文件“train.json”和“test.json”,其中“test.json”是比赛中用来对模型进行评分的,没有类标,这里我们只需要使用“train.json”文件。该数据集中有 1604 个打标过的训练数据,单个样本的数据格式如表 1 所示。
字段名 | 字段说明 |
---|---|
id | 图像的 id |
band_1 ,band_2 | 卫星图像数据,band_1 和 band_2 是以特定入射角下不同极化方式产生的雷达后向散射为特征的信号, 分别对应 HH (水平发射/水平接收)和 HV (水平发射/垂直接收)两种极化方式的数据,其大小均为 75x75 |
inc_angle | 获得该数据时的入射角度。该字段部分缺少数据,标记为“na” |
is_iceberg |
类标,0 代表船只,1 代表冰山 |
我们将数据可视化后进行观察,如图 1 所示。
图 1:训练数据可视化效果(易观察)
图像上方是冰山图像的可视化效果,三幅图分别对应“HH”极化方式、“HV”极化方式,以及两者结合后的数据。图像下方是船只图像的可视化效果。
图 1 中的冰山和船只,通过观察可以较为容易地区分出来,但是还有很多如图 2 所示的数据,即使仔细观察也很难区分开来。
图 2:训练数据可视化效果(不易观察)
2. 数据预处理
首先导入需要的包:
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers
接下来定义一个数据预处理的函数:
def data_preprocess(path, more_data):#读取数据data_frame = pd.read_json(path)#获取图像数据images =[]for _, row in data_frame.iterrows():#将一维数据转换另75x75的二维数据band_1 = np.array(row['band_1']).reshape(75, 75)band_2 = np.array(row['band_2 ']).reshape(75f 75)band_3 = band_1 + band_2images.append(np.dstack((band_1, band_2, band_3)))if more_data:#扩充数据集images = create_more_data(np.array(images))#获取类标labels = np.array(data_frame['is_iceberg'])if more_data:#扩充数据集后,类标也需要相应扩充labels = np.concatenate((labels, labels, labels, labels, labels, labels))return np.array(images), labels
“data_preprocess”函数接收两个参数:“path”为训练数据“train.json”的文件路径;“more_data”为布尔类型,当其为True 时,会调用函数“create_more_data”进行训练数据的扩充(即数据增强)。
第 7 行到第 13 行代码对样本数据进行处理,除了原有的“band_1”和“band_2”,我们增加了“band_3”,band_3=band_1+band_2。最后使用 NumPy 的“dstack”函数将三种数据进行堆叠,因此我们单个样本的数据维度为75×75×3。
第 16 行代码调用“create_more_data”函数对训练数据进行扩充,第 22 行代码对训练集的类标数据进行扩充,因为“create_more_data”函数将训练数据扩充为了原来的 6 倍,所以这里也要对应地将类标扩充为原来的 6 倍。
“create_more_data”函数的实现如下:
def create_more_data(images):#通过旋转、翻转扩充数据image_rot90 = []image_rot180 = []image_rot270 = []img_lr = []img_ud = []
在“create_more_data”函数中,我们通过对图片进行旋转和翻转来扩充数据集,虽然旋转前后的图片是同一张,但是由于特征的位置发生了变化,因此对于模型来说就是不同的数据,旋转或翻转操作是扩充图像数据集的一个简单有效的方法。
在第 3 行至第 7 行代码中,定义了 5 个列表,用来保存扩充的数据集,对应的操作分别是逆时针旋转 90°、逆时针旋转 180°、逆时针旋转 270°、左右翻转和上下翻转。具体实现如下:
for i in range(0, images.shape[0]):band_1 = images[i, :, :, 0 ]band_2 = images[i, :, :, 1 ]band_3 = images[i, :, :, 2]#旋转90°band_1_rot90 = np.rot90(band_1)band_2_rot90 = np.rot90(band_2)band_3_rot90 = np.rot90(band_3)image_rot90.append(np.dstack((band_1_rot90, band_2_rot90, band_3_rot90)))#旋转180°band_1_rot180=np.rot90(band_1_rot90)band_2_rot180=np.rot90(band_2_rot90)band_3_rot180 = np.rot90(band_3_rot90)image_rot180.append(np.dstack((band_1_rot180, band_2_rot180, band_3_rot180)))#旋转270°band_1_rot270 = np.rot90(band_1_rot180)band_2_rot270 = np.rot90(band_2_rot180) band_3_rot270 = np.rot90(band_3_rot180)image_rot270.append(np.dstack((band_1_rot270, band_2_rot270, band_3_rot270)))#左右翻转lr1 = np.flip(band_1, 0)lr2 = np.flip(band_2, 0)lr3 = np.flip(band_3, 0)img_lr.append(np.dstack((lrlf lr2, lr3)))#上下翻转ud1 = np.flip(band_1, 1)ud2 = np.flip(band_2, 1)ud3 = np.flip(band_3, 1)img_ud.append(np.dstack((ud1, ud2, ud3)))
上面的代码中,我们使用 NumPy 的“rot90”和“flip”函数对图片进行旋转和翻转操作。“flip”函数的第二个参数控制翻转的方式,“0”为左右翻转,“1”为上下翻转。
rot90 = np.array(image_rot90)
rot180 = np.array(image_rot180)
rot270 = np.array(image_rot270)
lr = np.array(img_lr)
ud = np.array(img_ud)
images = np.concatenate((image,rot90, rot180, rot270, lr, ud))return images
第 6 行代码使用 NumPy 的“concatenate”函数将扩充的数据与原数据进行拼接。
3. 模型搭建
接下来使用 TensorFlow 的高级 API 来搭建模型。
#定义模型
def get_model():#建立一个序贯模型model = tf.keras.Sequential()#第一个卷积块model.add(layers.Conv2D(128, kernel_size=(3, 3), activation= 'relu', input_shape=(75, 75, 3)))model.add(layers.MaxPooling2D(pool_size=(3, 3), strides=(2, 2)))model.add(layers.Dropout(0.2))#第二个卷积块model.add(layers.Conv2D(128, kernel_size=(3, 3), activation= 'relu'))model.add(layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))model.add(layers.Dropout(0.2))#第三个卷积块model.add(layers.Conv2D(64, kernel_size=(2, 2), activation='relu'))model.add(layers.MaxPooling2D(pool_size=(3, 3), strides=(2, 2)))model.add(layers.Dropout(0.2))#第四个卷积块model.add(layers.Conv2D(64, kernel_size={2, 2), activation= 'relu'))model.add(layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))model.add(layers.Dropout(0.2))#将上一层的输出特征映射转化为一维数据,以便进行全连接操作model.add(layers.Flatten())#第一个全连接层model.add(layers.Dense(256))model.add(layers.Activation('relu'))model.add(layers.Dropout(0.2))#第二个全连接层model.add(layers.Dense(128))model.add(layers.Activation('relu'))model.add(layers.Dropout(0.2))#第三个全连接层model.add(layers.Dense(1))model.add(layers.Activation('sigmoid'))#编译模型model.compile(loss= 'binary_crossentropy', optimizer=tf.keras.optimizers.Adam(0.0001), metrics=['accuracy'])#打印出模型的概况信息model.summary()return model
在第 4 行代码中,使用“tf.keras.Sequential()”创建一个序贯模型,序贯模型是多个网络层的线性堆叠,使用“tf.keras.Sequential().add()”方法逐层添加网络结构。第 7 行到第 9 行代码是第一个卷积块,这里使用了 128 个大小为 3×3 的卷积核,以 ReLU 为激活函数。
在卷积层后面是一个池化层,采用最大池化,池化窗口的大小为 3×3,横向和纵向的步长都为 2。在池化层的后面进行 Dropout 操作,丢弃了 20% 的神经元,防止参数过多导致过拟合。接下来是三个类似的卷积块。
第 27 行代码使用“Flatten()”将前一层网络的输出转换为了一维的数据,这是为接下来的全连接操作做准备。第 30 行代码是第一个全连接层,有 256 个神经元,全连接层后面接 ReLU 激活函数,同样进行 Dropout 操作。第 35 行至第 37 行代码是类似的全连接层部分。
由于是二分类问题,第 40 行至第 41 行代码使用了一个只有一个神经元的全连接层,并使用了 Sigmoid 激活函数,得到最终的输出。
第 44 行代码使用“compile”编译模型,其中“loss='binary_crossentropy'”指明使用的是对数损失函数,通过“optimizers”参数设置使用Adam 优化器,设置学习率为 0.0001。“metrics” 列表包含评估模型在训练和测试时的性能指标,这里设置了“metrics= ['accuracy']”,则在训练的过程中,训练集和验证集上的准确率都会打印出来。
第 47 行代码使用了“summary()”函数,训练开始后终端会打印出模型的概况信息,如图 3 所示,其中包含了网络的结构,以及每层的参数数量等信息,最后一行显示出,总的训练数据为 7699 条,验证集的数据量为 1925 条。
图 3:模型的概况信息
4. 结果分析
接下来读取数据,并训练模型:
#数据预处理
train_x, train_y = data_preprocess('./data/train.json', more_data=True)
#初始化模型
cnn_model = get_model()
#模型训练
cnn_model. fit (train_x, train_y, batch_size=25, epochs=100, verbose=1, validation_split=0.2)
第 2 行代码调用“data_preprocess”函数获取预处理后的训练数据,将“more_data”设置为“True”进行数据扩充。第 8 行代码调用“fit”方法开始模型的训练,通过“batch_size”设置每个批次训练 25 条数据,通过“epochs”设置训练的总回合数为“100”。
通过设置“verbose”为“1”,在终端上显示训练的进度。通过设置“validation_split”为“0.2”,将训练集一分为二,其中 80% 作为训练集,20% 作为验证集。
模型的训练过程和结果如图 4 所示。
图 4:模型的训练过程和结果
卷积神经网络实现图像分类相关推荐
- pytorch1.7教程实验——迁移学习训练卷积神经网络进行图像分类
只是贴上跑通的代码以供参考学习 参考网址:迁移学习训练卷积神经网络进行图像分类 需要用到的数据集下载网址: https://download.pytorch.org/tutorial/hymenopt ...
- 图像处理神经网络python_深度学习使用Python进行卷积神经网络的图像分类教程
深度学习使用Python进行卷积神经网络的图像分类教程 好的,这次我将使用python编写如何使用卷积神经网络(CNN)进行图像分类.我希望你事先已经阅读并理解了卷积神经网络(CNN)的基本概念,这里 ...
- 卷积神经网络和图像分类识别
Andrew Kirillov 著 Conmajia 译 2019 年 1 月 15 日 原文发表于 CodeProject(2018 年 10 月 28 日). 中文版有小幅修改,已获作者本人授权. ...
- 【神经网络与深度学习】CIFAR10数据集介绍,并使用卷积神经网络训练图像分类模型——[附完整训练代码]
[神经网络与深度学习]CIFAR-10数据集介绍,并使用卷积神经网络训练模型--[附完整代码] 一.CIFAR-10数据集介绍 1.1 CIFAR-10数据集的内容 1.2 CIFAR-10数据集的结 ...
- 轻松学Pytorch-使用卷积神经网络实现图像分类
点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达本文转自|人工智能与算法学习 大家好,本篇教程的贡献者来自社区投稿作 ...
- Bag of Tricks for Image Classification with Convolutional Neural Networks(卷积神经网络在图像分类中的技巧)
来源:Tong He Zhi Zhang Hang Zhang Zhongyue Zhang Junyuan Xie Mu L Amazon Web Services fhtong,zhiz,hzaw ...
- 经典卷积神经网络(CNN)图像分类算法详解
本文原创,转载请引用 https://blog.csdn.net/dan_teng/article/details/87192430 CNN图像分类网络 一点废话:CNN网络主要特点是使用卷积层,这其 ...
- 在PyTorch中使用卷积神经网络建立图像分类模型
概述 在PyTorch中构建自己的卷积神经网络(CNN)的实践教程 我们将研究一个图像分类问题--CNN的一个经典和广泛使用的应用 我们将以实用的格式介绍深度学习概念 介绍 我被神经网络的力量和能力所 ...
- PyTorch实战福利从入门到精通之四——卷积神经网络CIFAR-10图像分类
在本教程中,我们将使用CIFAR10数据集.它有类别:"飞机"."汽车"."鸟"."猫"."鹿".& ...
最新文章
- IT人面试必看!25个雷区和11个必问句!
- STL之vector,数组线性容器array,list容器,算法find,find_if,bind1st,仿函数
- mysql集群安装指南
- python整数池_对Python中小整数对象池和大整数对象池的使用详解
- 基于JAVA+SpringMVC+MYSQL的排班管理系统
- Server 2016DC查看五种AD角色的方法
- ×××S 2012 高级图表类型 -- 小面积扇形处理
- Ubuntu 11.10为何值得我们期待?
- java的inputbox,解释InputBox函数(输入函数)的用途
- 【网易云音乐】浏览器控制台抓包(纯前端)
- nifty_Nifty JUnit:在方法和类级别上使用规则
- 二、C语言基本数据类型全解和基本的数据存储方式
- 台式机通过网线连接笔记本上网
- 2016第1篇--Python查看微信被删好友
- 为什么任何数的0次方都是1
- iptables -j MARK --set-xmark 解析
- 一个int类型到底占多少个字节?
- cout保留两位小数位
- memcmp性能测试
- Mac查看端口占用情况
热门文章
- MAC免费解压软件——解压RAR、7Zip等五六十种格式
- linux centos rar解压,Centos解压rar压缩文件
- 但总觉得明白了一点点什么
- 网络摄像头无插件直播H265编码视频播放器EasyPlayer网页播放器不能播放怎么处理?
- 互联网无插件直播流媒体服务器方案EasyNVR下载新的软件执行程序,出现“invalid license”字样是什么意思?
- 沧海一声笑(最好版)
- allegro如何删除没有网络的走线,查还没有连的网络线
- ie显示服务器拒接链接,IE浏览器拒接访问是怎么回事 IE浏览器显示拒接访问的有效解决方法...
- 乘风破浪的码农——仿佛身体被掏空
- java开发规划_java开发程序员职业发展规划路线