A code walkthrough of remote sensing image segmentation with SegNet and UNet

Contents

  • A code walkthrough of remote sensing image segmentation with SegNet and UNet
  • Preface
  • Overview
  • Code structure
  • Detailed code analysis
    • Building the dataset: gen_dataset.py
    • UNet model training: unet_train.py
    • Mask fusion: combind.py
    • UNet model prediction: unet_predict.py
    • Ensembling classification results: ensemble.py
    • SegNet model training: segnet_train.py

Preface

After a semester of classes, I had some free time over the winter break to read through past papers and the code of a few competitions, and I am now writing it up. This write-up is hands-on and focuses on analyzing the code.

Overview

To start, let me share a small project: a remote sensing image segmentation competition entry based on SegNet and UNet. The code comes from GitHub; what follows is a brief walkthrough of the project.

Code structure

The project contains four subdirectories: deprecated, ensemble, segnet, and unet. deprecated holds the author's code drafts; ensemble combines the different classification results; segnet and unet each contain the network architecture, the training code, the prediction code, and the code for splitting the training and test sets for the two networks. A sketch of the layout follows below.
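Sketched as a tree (file placement is inferred from the description above and from the scripts discussed below, so the exact layout in the repo may differ):

├── deprecated/          # the author's code drafts
├── ensemble/            # ensembling of the different classification results
│   └── ensemble.py
├── segnet/              # SegNet: architecture, training, prediction, dataset splitting
│   └── segnet_train.py
└── unet/                # UNet: architecture, training, prediction, dataset splitting
    ├── gen_dataset.py
    ├── unet_train.py
    ├── unet_predict.py
    └── combind.py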

Detailed code analysis

Building the dataset: gen_dataset.py

import cv2
import random
import numpy as np
from tqdm import tqdm

img_w = 256
img_h = 256

# The dataset consists of 5 images in total
image_sets = ['1.png', '2.png', '3.png', '4.png', '5.png']

def gamma_transform(img, gamma):
    gamma_table = [np.power(x / 255.0, gamma) * 255.0 for x in range(256)]
    gamma_table = np.round(np.array(gamma_table)).astype(np.uint8)
    # LUT: Look-Up Table; a LUT transform remaps pixel values, which changes the
    # image's exposure and color
    return cv2.LUT(img, gamma_table)

def random_gamma_transform(img, gamma_vari):
    log_gamma_vari = np.log(gamma_vari)
    alpha = np.random.uniform(-log_gamma_vari, log_gamma_vari)
    gamma = np.exp(alpha)
    return gamma_transform(img, gamma)

# Rotate an image (both source and label) by the given angle
def rotate(xb, yb, angle):
    M_rotate = cv2.getRotationMatrix2D((img_w / 2, img_h / 2), angle, 1)
    xb = cv2.warpAffine(xb, M_rotate, (img_w, img_h))
    yb = cv2.warpAffine(yb, M_rotate, (img_w, img_h))
    return xb, yb

def blur(img):
    # cv2.blur(img, (size, size)) smooths img with a size x size mean filter
    img = cv2.blur(img, (3, 3))
    return img

# Add noise
def add_noise(img):
    for i in range(200):  # set 200 random pixels to white (salt noise)
        temp_x = np.random.randint(0, img.shape[0])
        temp_y = np.random.randint(0, img.shape[1])
        img[temp_x][temp_y] = 255
    return img

# Data augmentation: rotation, flips, gamma transform, blur, added noise
def data_augment(xb, yb):
    if np.random.random() < 0.25:
        xb, yb = rotate(xb, yb, 90)
    if np.random.random() < 0.25:
        xb, yb = rotate(xb, yb, 180)
    if np.random.random() < 0.25:
        xb, yb = rotate(xb, yb, 270)
    if np.random.random() < 0.25:
        xb = cv2.flip(xb, 1)  # flipcode > 0: flip around the y axis
        yb = cv2.flip(yb, 1)
    if np.random.random() < 0.25:
        xb = random_gamma_transform(xb, 1.0)
    if np.random.random() < 0.25:
        xb = blur(xb)
    if np.random.random() < 0.2:
        xb = add_noise(xb)
    return xb, yb

# Build the dataset
def creat_dataset(image_num=50000, mode='original'):
    print('creating dataset...')
    # len(image_sets) = 5, so each source image contributes image_num / 5 crops
    image_each = image_num // len(image_sets)
    g_count = 0
    for i in tqdm(range(len(image_sets))):
        count = 0
        # Read the source image and the label image
        src_img = cv2.imread('./data/src/' + image_sets[i])  # 3 channels
        label_img = cv2.imread('./data/road_label/' + image_sets[i], cv2.IMREAD_GRAYSCALE)  # single channel
        X_height, X_width, _ = src_img.shape
        while count < image_each:
            # img_w = img_h = 256
            random_width = random.randint(0, X_width - img_w - 1)
            random_height = random.randint(0, X_height - img_h - 1)
            # Randomly crop an img_h x img_w patch
            src_roi = src_img[random_height: random_height + img_h, random_width: random_width + img_w, :]
            label_roi = label_img[random_height: random_height + img_h, random_width: random_width + img_w]
            # In augment mode, apply data augmentation to the source and label patches
            if mode == 'augment':
                src_roi, label_roi = data_augment(src_roi, label_roi)
            # Scale the labels so the classes are visible when viewed as an image
            visualize = label_roi * 50
            cv2.imwrite(('./unet_train/visualize/%d.png' % g_count), visualize)
            cv2.imwrite(('./unet_train/road/src/%d.png' % g_count), src_roi)
            cv2.imwrite(('./unet_train/road/label/%d.png' % g_count), label_roi)
            count += 1
            g_count += 1

if __name__ == '__main__':
    creat_dataset(mode='augment')
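As a quick illustration of what the LUT-based gamma transform does to pixel values, here is a minimal sketch (the sample pixels and the gamma value are made up for illustration):

import cv2
import numpy as np

img = np.array([[0, 64, 128, 255]], dtype=np.uint8)
# Same table construction as gamma_transform above, with gamma = 0.5
table = np.round(np.power(np.arange(256) / 255.0, 0.5) * 255.0).astype(np.uint8)
print(cv2.LUT(img, table))  # gamma < 1 brightens mid-tones: [[  0 128 181 255]]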

UNet model training: unet_train.py

#coding=utf-8
import matplotlib
# matplotlib.use("Agg") must come before "import matplotlib.pyplot as plt";
# it selects a non-interactive backend, so figures are only saved to disk
# rather than shown on screen
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import argparse
import numpy as np
from keras.models import Model
from keras.layers import Conv2D, MaxPooling2D, UpSampling2D, Input
from keras.layers.merge import concatenate
from keras.preprocessing.image import img_to_array
from keras.callbacks import ModelCheckpoint
from sklearn.preprocessing import LabelEncoder
import cv2
import random
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "4"

# Fix the random seed so runs are reproducible and results can be compared
# on the same data
seed = 7
np.random.seed(seed)  # data_shape = 360*480

img_w = 256
img_h = 256
# One of the classes is background
# n_label = 4+1
n_label = 1

# 5 classes in total
classes = [0., 1., 2., 3., 4.]
labelencoder = LabelEncoder()
labelencoder.fit(classes)

image_sets = ['1.png', '2.png', '3.png']

def load_img(path, grayscale=False):
    if grayscale:
        # cv2.IMREAD_GRAYSCALE reads the image as a single-channel grayscale image;
        # otherwise cv2.imread reads a 3-channel color image by default
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    else:
        img = cv2.imread(path)
    # Normalize to [0, 1]
    img = np.array(img, dtype="float") / 255.0
    return img

# Training data path
filepath = './unet_train/'

# Split into training and validation sets, with 25% of the data used for validation
def get_train_val(val_rate=0.25):
    train_url = []
    train_set = []
    val_set = []
    for pic in os.listdir(filepath + 'src'):
        train_url.append(pic)
    random.shuffle(train_url)
    total_num = len(train_url)
    val_num = int(val_rate * total_num)
    # After shuffling, the first 25% become the validation set,
    # the remaining 75% the training set
    for i in range(len(train_url)):
        if i < val_num:
            val_set.append(train_url[i])
        else:
            train_set.append(train_url[i])
    return train_set, val_set

# data for training
def generateData(batch_size, data=[]):
    while True:
        train_data = []
        train_label = []
        batch = 0
        for i in range(len(data)):
            url = data[i]
            batch += 1
            img = load_img(filepath + 'src/' + url)
            img = img_to_array(img)
            train_data.append(img)
            label = load_img(filepath + 'label/' + url, grayscale=True)
            label = img_to_array(label)
            train_label.append(label)
            if batch % batch_size == 0:
                train_data = np.array(train_data)
                train_label = np.array(train_label)
                yield (train_data, train_label)
                train_data = []
                train_label = []
                batch = 0

# data for validation
def generateValidData(batch_size, data=[]):
    while True:
        valid_data = []
        valid_label = []
        batch = 0
        for i in range(len(data)):
            url = data[i]
            batch += 1
            img = load_img(filepath + 'src/' + url)
            img = img_to_array(img)
            valid_data.append(img)
            label = load_img(filepath + 'label/' + url, grayscale=True)
            label = img_to_array(label)
            valid_label.append(label)
            if batch % batch_size == 0:
                valid_data = np.array(valid_data)
                valid_label = np.array(valid_label)
                yield (valid_data, valid_label)
                valid_data = []
                valid_label = []
                batch = 0

# Define the U-Net: overall a symmetric U-shaped structure (channels-first layout)
def unet():
    inputs = Input((3, img_w, img_h))
    conv1 = Conv2D(32, (3, 3), activation="relu", padding="same")(inputs)
    conv1 = Conv2D(32, (3, 3), activation="relu", padding="same")(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = Conv2D(64, (3, 3), activation="relu", padding="same")(pool1)
    conv2 = Conv2D(64, (3, 3), activation="relu", padding="same")(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = Conv2D(128, (3, 3), activation="relu", padding="same")(pool2)
    conv3 = Conv2D(128, (3, 3), activation="relu", padding="same")(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    conv4 = Conv2D(256, (3, 3), activation="relu", padding="same")(pool3)
    conv4 = Conv2D(256, (3, 3), activation="relu", padding="same")(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
    conv5 = Conv2D(512, (3, 3), activation="relu", padding="same")(pool4)
    conv5 = Conv2D(512, (3, 3), activation="relu", padding="same")(conv5)
    # Note: the original source pooled conv5 once more here, but that would make the
    # concatenation with conv4 below impossible (16x16 upsampled to 32x32 is needed);
    # the printed summary below confirms the working model has only four poolings.
    # Upsampling enlarges the feature map by simple repetition: UpSampling2D(size=size)(x)
    # repeats the rows and columns of x size[0] and size[1] times respectively.
    # E.g. with size = (2, 2), [[1,2],[3,4]] becomes [[1,1,2,2],[1,1,2,2],[3,3,4,4],[3,3,4,4]]
    up6 = concatenate([UpSampling2D(size=(2, 2))(conv5), conv4], axis=1)
    conv6 = Conv2D(256, (3, 3), activation="relu", padding="same")(up6)
    conv6 = Conv2D(256, (3, 3), activation="relu", padding="same")(conv6)
    up7 = concatenate([UpSampling2D(size=(2, 2))(conv6), conv3], axis=1)
    conv7 = Conv2D(128, (3, 3), activation="relu", padding="same")(up7)
    conv7 = Conv2D(128, (3, 3), activation="relu", padding="same")(conv7)
    up8 = concatenate([UpSampling2D(size=(2, 2))(conv7), conv2], axis=1)
    conv8 = Conv2D(64, (3, 3), activation="relu", padding="same")(up8)
    conv8 = Conv2D(64, (3, 3), activation="relu", padding="same")(conv8)
    up9 = concatenate([UpSampling2D(size=(2, 2))(conv8), conv1], axis=1)
    conv9 = Conv2D(32, (3, 3), activation="relu", padding="same")(up9)
    conv9 = Conv2D(32, (3, 3), activation="relu", padding="same")(conv9)
    conv10 = Conv2D(n_label, (1, 1), activation="sigmoid")(conv9)
    #conv10 = Conv2D(n_label, (1, 1), activation="softmax")(conv9)
    model = Model(inputs=inputs, outputs=conv10)
    # Binary cross-entropy for the binary (road / background) task; plain categorical
    # cross-entropy would also work, since the multi-class loss covers the binary case
    model.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

def train(args):
    EPOCHS = 10
    # batch size
    BS = 16
    #model = SegNet()
    model = unet()
    # In this Keras version the accuracy metric is logged as 'acc'/'val_acc'
    # (the original monitored 'val_accuracy', which does not exist here)
    modelcheck = ModelCheckpoint(args['model'], monitor='val_acc', save_best_only=True, mode='max')
    callable = [modelcheck]
    train_set, val_set = get_train_val()
    train_numb = len(train_set)
    valid_numb = len(val_set)
    print("the number of train data is", train_numb)
    print("the number of val data is", valid_numb)
    # max_q_size sets the maximum size of the internal training queue
    H = model.fit_generator(generator=generateData(BS, train_set), steps_per_epoch=train_numb // BS,
                            epochs=EPOCHS, verbose=1,
                            validation_data=generateValidData(BS, val_set), validation_steps=valid_numb // BS,
                            callbacks=callable, max_q_size=1)
    # plot the training loss and accuracy
    # plt.style.use('ggplot') applies the ggplot style to the figure;
    # the available styles (plt.style.available) are:
    # ['bmh', 'classic', 'dark_background', 'fast', 'fivethirtyeight', 'ggplot', 'grayscale',
    # 'seaborn-bright', 'seaborn-colorblind', 'seaborn-dark-palette', 'seaborn-dark',
    # 'seaborn-darkgrid', 'seaborn-deep', 'seaborn-muted', 'seaborn-notebook', 'seaborn-paper',
    # 'seaborn-pastel', 'seaborn-poster', 'seaborn-talk', 'seaborn-ticks', 'seaborn-white',
    # 'seaborn-whitegrid', 'seaborn', 'Solarize_Light2', 'tableau-colorblind10', '_classic_test']
    plt.style.use("ggplot")
    plt.figure()
    N = EPOCHS
    plt.plot(np.arange(0, N), H.history["loss"], label="train_loss")
    plt.plot(np.arange(0, N), H.history["val_loss"], label="val_loss")
    plt.plot(np.arange(0, N), H.history["acc"], label="train_acc")
    plt.plot(np.arange(0, N), H.history["val_acc"], label="val_acc")
    plt.title("Training Loss and Accuracy on U-Net Satellite Seg")
    plt.xlabel("Epoch #")
    plt.ylabel("Loss/Accuracy")
    # Place the legend in the lower-left corner
    plt.legend(loc="lower left")
    plt.savefig(args["plot"])

# Command-line argument help text and defaults
def args_parse():
    # construct the argument parse and parse the arguments
    ap = argparse.ArgumentParser()
    # default matches the filepath defined above (the original had default=True, a bug)
    ap.add_argument("-d", "--data", help="training data's path", default='./unet_train/')
    ap.add_argument("-m", "--model", required=True, help="path to output model")
    ap.add_argument("-p", "--plot", type=str, default="plot.png",
                    help="path to output accuracy/loss plot")
    args = vars(ap.parse_args())
    return args

if __name__ == '__main__':
    args = args_parse()
    filepath = args['data']
    train(args)
    #predict()
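With the argument parser above, training would be launched with something like `python unet_train.py -m unet_road.h5` (the model file name here is just an example); `-d` overrides the training data path and `-p` names the output loss/accuracy plot.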

To see clearly the shape of each layer's input and output tensors in the U-Net, we print the model summary:
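A minimal sketch that reproduces the printout, assuming a Keras 2 style backend configured for channels-first (on Keras 1 the equivalent call is K.set_image_dim_ordering('th')):

from keras import backend as K

K.set_image_data_format('channels_first')  # matches the (None, 3, 256, 256) shapes below
model = unet()
model.summary()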

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_7 (InputLayer)            (None, 3, 256, 256)  0
__________________________________________________________________________________________________
conv2d_79 (Conv2D)              (None, 32, 256, 256) 896         input_7[0][0]
__________________________________________________________________________________________________
conv2d_80 (Conv2D)              (None, 32, 256, 256) 9248        conv2d_79[0][0]
__________________________________________________________________________________________________
max_pooling2d_29 (MaxPooling2D) (None, 32, 128, 128) 0           conv2d_80[0][0]
__________________________________________________________________________________________________
conv2d_81 (Conv2D)              (None, 64, 128, 128) 18496       max_pooling2d_29[0][0]
__________________________________________________________________________________________________
conv2d_82 (Conv2D)              (None, 64, 128, 128) 36928       conv2d_81[0][0]
__________________________________________________________________________________________________
max_pooling2d_30 (MaxPooling2D) (None, 64, 64, 64)   0           conv2d_82[0][0]
__________________________________________________________________________________________________
conv2d_83 (Conv2D)              (None, 128, 64, 64)  73856       max_pooling2d_30[0][0]
__________________________________________________________________________________________________
conv2d_84 (Conv2D)              (None, 128, 64, 64)  147584      conv2d_83[0][0]
__________________________________________________________________________________________________
max_pooling2d_31 (MaxPooling2D) (None, 128, 32, 32)  0           conv2d_84[0][0]
__________________________________________________________________________________________________
conv2d_85 (Conv2D)              (None, 256, 32, 32)  295168      max_pooling2d_31[0][0]
__________________________________________________________________________________________________
conv2d_86 (Conv2D)              (None, 256, 32, 32)  590080      conv2d_85[0][0]
__________________________________________________________________________________________________
max_pooling2d_32 (MaxPooling2D) (None, 256, 16, 16)  0           conv2d_86[0][0]
__________________________________________________________________________________________________
conv2d_87 (Conv2D)              (None, 512, 16, 16)  1180160     max_pooling2d_32[0][0]
__________________________________________________________________________________________________
conv2d_88 (Conv2D)              (None, 512, 16, 16)  2359808     conv2d_87[0][0]
__________________________________________________________________________________________________
up_sampling2d_13 (UpSampling2D) (None, 512, 32, 32)  0           conv2d_88[0][0]
__________________________________________________________________________________________________
concatenate_13 (Concatenate)    (None, 768, 32, 32)  0           up_sampling2d_13[0][0]
                                                                 conv2d_86[0][0]
__________________________________________________________________________________________________
conv2d_89 (Conv2D)              (None, 256, 32, 32)  1769728     concatenate_13[0][0]
__________________________________________________________________________________________________
conv2d_90 (Conv2D)              (None, 256, 32, 32)  590080      conv2d_89[0][0]
__________________________________________________________________________________________________
up_sampling2d_14 (UpSampling2D) (None, 256, 64, 64)  0           conv2d_90[0][0]
__________________________________________________________________________________________________
concatenate_14 (Concatenate)    (None, 384, 64, 64)  0           up_sampling2d_14[0][0]
                                                                 conv2d_84[0][0]
__________________________________________________________________________________________________
conv2d_91 (Conv2D)              (None, 128, 64, 64)  442496      concatenate_14[0][0]
__________________________________________________________________________________________________
conv2d_92 (Conv2D)              (None, 128, 64, 64)  147584      conv2d_91[0][0]
__________________________________________________________________________________________________
up_sampling2d_15 (UpSampling2D) (None, 128, 128, 128 0           conv2d_92[0][0]
__________________________________________________________________________________________________
concatenate_15 (Concatenate)    (None, 192, 128, 128 0           up_sampling2d_15[0][0]
                                                                 conv2d_82[0][0]
__________________________________________________________________________________________________
conv2d_93 (Conv2D)              (None, 64, 128, 128) 110656      concatenate_15[0][0]
__________________________________________________________________________________________________
conv2d_94 (Conv2D)              (None, 64, 128, 128) 36928       conv2d_93[0][0]
__________________________________________________________________________________________________
up_sampling2d_16 (UpSampling2D) (None, 64, 256, 256) 0           conv2d_94[0][0]
__________________________________________________________________________________________________
concatenate_16 (Concatenate)    (None, 96, 256, 256) 0           up_sampling2d_16[0][0]
                                                                 conv2d_80[0][0]
__________________________________________________________________________________________________
conv2d_95 (Conv2D)              (None, 32, 256, 256) 27680       concatenate_16[0][0]
__________________________________________________________________________________________________
conv2d_96 (Conv2D)              (None, 32, 256, 256) 9248        conv2d_95[0][0]
__________________________________________________________________________________________________
conv2d_97 (Conv2D)              (None, 1, 256, 256)  33          conv2d_96[0][0]
==================================================================================================
Total params: 7,846,657
Trainable params: 7,846,657
Non-trainable params: 0
__________________________________________________________________________________________________
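As a quick sanity check on the parameter counts: conv2d_79 maps 3 input channels to 32 filters with 3x3 kernels, giving (3·3·3 + 1)·32 = 896 parameters, and the final 1x1 convolution conv2d_97 has (32 + 1)·1 = 33.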

Mask fusion: combind.py

#coding=utf-8
import numpy as np
import cv2
from tqdm import tqdm

# Per-class prediction masks for the three test images
mask1_pool = ['testing1_vegetation_predict.png', 'testing1_building_predict.png',
              'testing1_water_predict.png', 'testing1_road_predict.png']
mask2_pool = ['testing2_vegetation_predict.png', 'testing2_building_predict.png',
              'testing2_water_predict.png', 'testing2_road_predict.png']
mask3_pool = ['testing3_vegetation_predict.png', 'testing3_building_predict.png',
              'testing3_water_predict.png', 'testing3_road_predict.png']

# 0: none  1: vegetation  2: building  3: water  4: road
# after mask combind
img_sets = ['pre1.png', 'pre2.png', 'pre3.png']

def combind_all_mask():
    for mask_num in tqdm(range(3)):
        # Start from an all-black (all-zero) image with the same size as the original
        if mask_num == 0:
            final_mask = np.zeros((5142, 5664), np.uint8)
        elif mask_num == 1:
            final_mask = np.zeros((2470, 4011), np.uint8)
        elif mask_num == 2:
            final_mask = np.zeros((6116, 3356), np.uint8)
        #final_mask = cv2.imread('final_1_8bits_predict.png',0)
        if mask_num == 0:
            mask_pool = mask1_pool
        elif mask_num == 1:
            mask_pool = mask2_pool
        elif mask_num == 2:
            mask_pool = mask3_pool
        final_name = img_sets[mask_num]
        for idx, name in enumerate(mask_pool):
            img = cv2.imread('./predict_mask/' + name, 0)
            height, width = img.shape
            label_value = idx + 1  # corresponding label value
            for i in tqdm(range(height)):  # priority: building > water > road > vegetation
                for j in range(width):
                    # Merge the masks pixel by pixel
                    if img[i, j] == 255:
                        # If this pixel is white in more than one mask, which class does it
                        # belong to? Decide by priority: building > water > road > vegetation
                        if label_value == 2:
                            final_mask[i, j] = label_value
                        elif label_value == 3 and final_mask[i, j] != 2:
                            final_mask[i, j] = label_value
                        elif label_value == 4 and final_mask[i, j] != 2 and final_mask[i, j] != 3:
                            final_mask[i, j] = label_value
                        elif label_value == 1 and final_mask[i, j] == 0:
                            final_mask[i, j] = label_value
        cv2.imwrite('./final_result/' + final_name, final_mask)

print('combinding mask...')
combind_all_mask()
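The quadruple pixel loop above is very slow on images this large. Here is a vectorized sketch of the same priority merge for the first test image (assuming the same file layout as mask1_pool above): writing the classes from lowest to highest priority lets each later write overwrite the earlier ones, which reproduces the building > water > road > vegetation rule.

import cv2
import numpy as np

# label values as above: 1 vegetation, 2 building, 3 water, 4 road
names = {1: 'vegetation', 2: 'building', 3: 'water', 4: 'road'}
final_mask = np.zeros((5142, 5664), np.uint8)
for label_value in (1, 4, 3, 2):  # rising priority: vegetation, road, water, building
    img = cv2.imread('./predict_mask/testing1_%s_predict.png' % names[label_value], 0)
    final_mask[img == 255] = label_value
cv2.imwrite('./final_result/pre1.png', final_mask)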

UNet model prediction: unet_predict.py

import cv2
import numpy as np
import os
import argparse
from keras.preprocessing.image import img_to_array
from keras.models import load_model
from sklearn.preprocessing import LabelEncoder

# Use GPU number 1
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

TEST_SET = ['1.png', '2.png', '3.png']
image_size = 256

classes = [0., 1., 2., 3., 4.]
labelencoder = LabelEncoder()
labelencoder.fit(classes)

def args_parse():
    # construct the argument parse and parse the arguments
    ap = argparse.ArgumentParser()
    ap.add_argument("-m", "--model", required=True,
                    help="path to trained model model")
    ap.add_argument("-s", "--stride", required=False,
                    help="crop slide stride", type=int, default=image_size)
    args = vars(ap.parse_args())
    return args

def predict(args):
    # load the trained convolutional neural network
    print("[INFO] loading network...")
    model = load_model(args["model"])
    stride = args['stride']
    for n in range(len(TEST_SET)):
        path = TEST_SET[n]
        # load the test image
        image = cv2.imread('./test/' + path)
        h, w, _ = image.shape
        # How do we predict on an arbitrarily sized image? The model was trained on
        # 256x256 inputs, so at test time we also feed it 256x256 tiles: first zero-pad
        # the image so that the padded size is exactly divisible by 256
        padding_h = (h // stride + 1) * stride
        padding_w = (w // stride + 1) * stride
        padding_img = np.zeros((padding_h, padding_w, 3), dtype=np.uint8)
        # Zero-pad the missing part
        padding_img[0:h, 0:w, :] = image[:, :, :]
        # Note: training normalized inputs to [0, 1]; this line is commented out
        # in the original
        #padding_img = padding_img.astype("float") / 255.0
        padding_img = img_to_array(padding_img)
        print('src:', padding_img.shape)
        mask_whole = np.zeros((padding_h, padding_w), dtype=np.uint8)
        for i in range(padding_h // stride):
            for j in range(padding_w // stride):
                # Take the tile at the corresponding position in the padded image
                # (channels-first layout, as in training)
                crop = padding_img[:3, i * stride:i * stride + image_size, j * stride:j * stride + image_size]
                _, ch, cw = crop.shape
                if ch != 256 or cw != 256:
                    print('invalid size!')
                    continue
                crop = np.expand_dims(crop, axis=0)
                # For fit, verbose=0 logs nothing, verbose=1 shows a progress bar,
                # and verbose=2 prints one line per epoch; for evaluate,
                # verbose=0 logs nothing and verbose=1 shows a progress bar
                pred = model.predict(crop, verbose=2)
                #print(np.unique(pred))
                pred = pred.reshape((256, 256)).astype(np.uint8)
                #print('pred:', pred.shape)
                mask_whole[i * stride:i * stride + image_size, j * stride:j * stride + image_size] = pred[:, :]
        # Crop the stitched mask back to the original image size
        cv2.imwrite('./predict/pre' + str(n + 1) + '.png', mask_whole[0:h, 0:w])

if __name__ == '__main__':
    args = args_parse()
    predict(args)
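Two details worth flagging: the normalization line is commented out even though training divided by 255, and the sigmoid output lies in (0, 1), so casting it straight to uint8 floors every pixel to 0. A hedged fix (my assumption, not from the original repo) would be to binarize first:

# threshold the sigmoid probabilities before the uint8 cast (assumed fix, not in the original)
pred = (pred.reshape((256, 256)) > 0.5).astype(np.uint8)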

Ensembling classification results: ensemble.py

import numpy as np
import cv2

RESULT_PREFIXX = ['./result1/', './result2/', './result3/']

# each mask has 5 classes: 0~4
def vote_per_image(image_id):
    result_list = []
    for j in range(len(RESULT_PREFIXX)):
        im = cv2.imread(RESULT_PREFIXX[j] + str(image_id) + '.png', 0)
        result_list.append(im)
    # each pixel
    height, width = result_list[0].shape
    vote_mask = np.zeros((height, width))
    for h in range(height):
        for w in range(width):
            # Pixel level: count the votes for each of the 5 classes,
            # so the tally is a 1x5 record
            record = np.zeros((1, 5))
            # Loop over the models and add each one's prediction for
            # this pixel to the tally
            for n in range(len(result_list)):
                mask = result_list[n]
                pixel = mask[h, w]
                #print('pix:', pixel)
                record[0, pixel] += 1
            # Ensemble: the class with the most votes wins
            label = record.argmax()
            #print(label)
            vote_mask[h, w] = label
    cv2.imwrite('vote_mask' + str(image_id) + '.png', vote_mask)

# Three sets of results in total; here the vote is run for test image 3
vote_per_image(3)
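Because the voting is pixel-wise, the triple loop is again slow. The same majority vote can be written with NumPy broadcasting; a sketch, reusing result_list from the function above (ties resolve to the smallest label, just like record.argmax()):

import numpy as np

stack = np.stack(result_list)                            # (n_models, height, width), values 0..4
counts = (stack[..., None] == np.arange(5)).sum(axis=0)  # (height, width, 5) per-pixel vote tallies
vote_mask = counts.argmax(axis=-1).astype(np.uint8)      # majority class per pixel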

SegNet model training: segnet_train.py

#coding=utf-8
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import argparse
import numpy as np
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, UpSampling2D, BatchNormalization, Reshape, Permute, Activation
from keras.utils.np_utils import to_categorical
from keras.preprocessing.image import img_to_array
from keras.callbacks import ModelCheckpoint
from sklearn.preprocessing import LabelEncoder
import cv2
import random
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

seed = 7
np.random.seed(seed)  # data_shape = 360*480

img_w = 256
img_h = 256
# One of the classes is background
n_label = 4 + 1

classes = [0., 1., 2., 3., 4.]
labelencoder = LabelEncoder()
labelencoder.fit(classes)

image_sets = ['1.png', '2.png', '3.png']

def load_img(path, grayscale=False):
    if grayscale:
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    else:
        img = cv2.imread(path)
    img = np.array(img, dtype="float") / 255.0
    return img

filepath = './train/'

def get_train_val(val_rate=0.25):
    train_url = []
    train_set = []
    val_set = []
    for pic in os.listdir(filepath + 'src'):
        train_url.append(pic)
    random.shuffle(train_url)
    total_num = len(train_url)
    val_num = int(val_rate * total_num)
    for i in range(len(train_url)):
        if i < val_num:
            val_set.append(train_url[i])
        else:
            train_set.append(train_url[i])
    return train_set, val_set

# data for training
def generateData(batch_size, data=[]):
    while True:
        train_data = []
        train_label = []
        batch = 0
        for i in range(len(data)):
            url = data[i]
            batch += 1
            img = load_img(filepath + 'src/' + url)
            img = img_to_array(img)
            train_data.append(img)
            label = load_img(filepath + 'label/' + url, grayscale=True)
            label = img_to_array(label).reshape((img_w * img_h,))
            # print label.shape
            train_label.append(label)
            if batch % batch_size == 0:
                train_data = np.array(train_data)
                # Flatten the labels, encode them as class indices, then one-hot
                # encode, giving one n_label-way target per pixel
                train_label = np.array(train_label).flatten()
                train_label = labelencoder.transform(train_label)
                train_label = to_categorical(train_label, num_classes=n_label)
                train_label = train_label.reshape((batch_size, img_w * img_h, n_label))
                yield (train_data, train_label)
                train_data = []
                train_label = []
                batch = 0

# data for validation
def generateValidData(batch_size, data=[]):
    while True:
        valid_data = []
        valid_label = []
        batch = 0
        for i in range(len(data)):
            url = data[i]
            batch += 1
            img = load_img(filepath + 'src/' + url)
            img = img_to_array(img)
            valid_data.append(img)
            label = load_img(filepath + 'label/' + url, grayscale=True)
            label = img_to_array(label).reshape((img_w * img_h,))
            valid_label.append(label)
            if batch % batch_size == 0:
                valid_data = np.array(valid_data)
                valid_label = np.array(valid_label).flatten()
                valid_label = labelencoder.transform(valid_label)
                valid_label = to_categorical(valid_label, num_classes=n_label)
                valid_label = valid_label.reshape((batch_size, img_w * img_h, n_label))
                yield (valid_data, valid_label)
                valid_data = []
                valid_label = []
                batch = 0

def SegNet():
    model = Sequential()
    # encoder
    model.add(Conv2D(64, (3, 3), strides=(1, 1), input_shape=(3, img_w, img_h), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(64, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    # dim_ordering='th' is the legacy Keras 1 name for data_format='channels_first'
    model.add(MaxPooling2D(pool_size=(2, 2), dim_ordering='th'))
    # (128,128)
    model.add(Conv2D(128, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(128, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(MaxPooling2D(pool_size=(2, 2), dim_ordering='th'))
    # (64,64)
    model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(MaxPooling2D(pool_size=(2, 2), dim_ordering='th'))
    # (32,32)
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(MaxPooling2D(pool_size=(2, 2), dim_ordering='th'))
    # (16,16)
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(MaxPooling2D(pool_size=(2, 2), dim_ordering='th'))
    # (8,8)
    # decoder
    model.add(UpSampling2D(size=(2, 2)))
    # (16,16)
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(UpSampling2D(size=(2, 2)))
    # (32,32)
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(UpSampling2D(size=(2, 2)))
    # (64,64)
    model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(UpSampling2D(size=(2, 2)))
    # (128,128)
    model.add(Conv2D(128, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(128, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(UpSampling2D(size=(2, 2)))
    # (256,256)
    # (the original re-specified input_shape on the next layer; that has no effect
    # on a non-first layer, so it is dropped here)
    model.add(Conv2D(64, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(64, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(n_label, (1, 1), strides=(1, 1), padding='same'))
    model.add(Reshape((n_label, img_w * img_h)))
    # Swap axis 1 and axis 2, equivalent to np.swapaxes(layer, 1, 2)
    model.add(Permute((2, 1)))
    model.add(Activation('softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
    return model

def train(args):
    EPOCHS = 30
    BS = 16
    model = SegNet()
    modelcheck = ModelCheckpoint(args['model'], monitor='val_acc', save_best_only=True, mode='max')
    callable = [modelcheck]
    train_set, val_set = get_train_val()
    train_numb = len(train_set)
    valid_numb = len(val_set)
    print("the number of train data is", train_numb)
    print("the number of val data is", valid_numb)
    H = model.fit_generator(generator=generateData(BS, train_set), steps_per_epoch=train_numb // BS,
                            epochs=EPOCHS, verbose=1,
                            validation_data=generateValidData(BS, val_set), validation_steps=valid_numb // BS,
                            callbacks=callable, max_q_size=1)
    # plot the training loss and accuracy
    plt.style.use("ggplot")
    plt.figure()
    N = EPOCHS
    plt.plot(np.arange(0, N), H.history["loss"], label="train_loss")
    plt.plot(np.arange(0, N), H.history["val_loss"], label="val_loss")
    plt.plot(np.arange(0, N), H.history["acc"], label="train_acc")
    plt.plot(np.arange(0, N), H.history["val_acc"], label="val_acc")
    plt.title("Training Loss and Accuracy on SegNet Satellite Seg")
    plt.xlabel("Epoch #")
    plt.ylabel("Loss/Accuracy")
    plt.legend(loc="lower left")
    plt.savefig(args["plot"])

def args_parse():
    # construct the argument parse and parse the arguments
    ap = argparse.ArgumentParser()
    ap.add_argument("-a", "--augment", help="using data augment or not",
                    action="store_true", default=False)
    ap.add_argument("-m", "--model", required=True, help="path to output model")
    ap.add_argument("-p", "--plot", type=str, default="plot.png",
                    help="path to output accuracy/loss plot")
    args = vars(ap.parse_args())
    return args

if __name__ == '__main__':
    args = args_parse()
    if args['augment'] == True:
        filepath = './aug/train/'
    train(args)
    #predict()
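With the argument parser above, training is launched with something like `python segnet_train.py -m segnet.h5` (the model file name is just an example); adding `-a` switches the training data path to ./aug/train/, i.e. the augmented crops.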

Likewise, to see the shape of each layer's input and output tensors in SegNet, we print the model summary (judging by the final (None, 1, 256, 256) convolution and the (None, 1, 65536) reshape, this particular printout was produced with n_label = 1 rather than 5):

_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d_98 (Conv2D)           (None, 64, 256, 256)      1792
_________________________________________________________________
batch_normalization_1 (Batch (None, 64, 256, 256)      1024
_________________________________________________________________
conv2d_99 (Conv2D)           (None, 64, 256, 256)      36928
_________________________________________________________________
batch_normalization_2 (Batch (None, 64, 256, 256)      1024
_________________________________________________________________
max_pooling2d_33 (MaxPooling (None, 64, 128, 128)      0
_________________________________________________________________
conv2d_100 (Conv2D)          (None, 128, 128, 128)     73856
_________________________________________________________________
batch_normalization_3 (Batch (None, 128, 128, 128)     512
_________________________________________________________________
conv2d_101 (Conv2D)          (None, 128, 128, 128)     147584
_________________________________________________________________
batch_normalization_4 (Batch (None, 128, 128, 128)     512
_________________________________________________________________
max_pooling2d_34 (MaxPooling (None, 128, 64, 64)       0
_________________________________________________________________
conv2d_102 (Conv2D)          (None, 256, 64, 64)       295168
_________________________________________________________________
batch_normalization_5 (Batch (None, 256, 64, 64)       256
_________________________________________________________________
conv2d_103 (Conv2D)          (None, 256, 64, 64)       590080
_________________________________________________________________
batch_normalization_6 (Batch (None, 256, 64, 64)       256
_________________________________________________________________
conv2d_104 (Conv2D)          (None, 256, 64, 64)       590080
_________________________________________________________________
batch_normalization_7 (Batch (None, 256, 64, 64)       256
_________________________________________________________________
max_pooling2d_35 (MaxPooling (None, 256, 32, 32)       0
_________________________________________________________________
conv2d_105 (Conv2D)          (None, 512, 32, 32)       1180160
_________________________________________________________________
batch_normalization_8 (Batch (None, 512, 32, 32)       128
_________________________________________________________________
conv2d_106 (Conv2D)          (None, 512, 32, 32)       2359808
_________________________________________________________________
batch_normalization_9 (Batch (None, 512, 32, 32)       128
_________________________________________________________________
conv2d_107 (Conv2D)          (None, 512, 32, 32)       2359808
_________________________________________________________________
batch_normalization_10 (Batc (None, 512, 32, 32)       128
_________________________________________________________________
max_pooling2d_36 (MaxPooling (None, 512, 16, 16)       0
_________________________________________________________________
conv2d_108 (Conv2D)          (None, 512, 16, 16)       2359808
_________________________________________________________________
batch_normalization_11 (Batc (None, 512, 16, 16)       64
_________________________________________________________________
conv2d_109 (Conv2D)          (None, 512, 16, 16)       2359808
_________________________________________________________________
batch_normalization_12 (Batc (None, 512, 16, 16)       64
_________________________________________________________________
conv2d_110 (Conv2D)          (None, 512, 16, 16)       2359808
_________________________________________________________________
batch_normalization_13 (Batc (None, 512, 16, 16)       64
_________________________________________________________________
max_pooling2d_37 (MaxPooling (None, 512, 8, 8)         0
_________________________________________________________________
up_sampling2d_17 (UpSampling (None, 512, 16, 16)       0
_________________________________________________________________
conv2d_111 (Conv2D)          (None, 512, 16, 16)       2359808
_________________________________________________________________
batch_normalization_14 (Batc (None, 512, 16, 16)       64
_________________________________________________________________
conv2d_112 (Conv2D)          (None, 512, 16, 16)       2359808
_________________________________________________________________
batch_normalization_15 (Batc (None, 512, 16, 16)       64
_________________________________________________________________
conv2d_113 (Conv2D)          (None, 512, 16, 16)       2359808
_________________________________________________________________
batch_normalization_16 (Batc (None, 512, 16, 16)       64
_________________________________________________________________
up_sampling2d_18 (UpSampling (None, 512, 32, 32)       0
_________________________________________________________________
conv2d_114 (Conv2D)          (None, 512, 32, 32)       2359808
_________________________________________________________________
batch_normalization_17 (Batc (None, 512, 32, 32)       128
_________________________________________________________________
conv2d_115 (Conv2D)          (None, 512, 32, 32)       2359808
_________________________________________________________________
batch_normalization_18 (Batc (None, 512, 32, 32)       128
_________________________________________________________________
conv2d_116 (Conv2D)          (None, 512, 32, 32)       2359808
_________________________________________________________________
batch_normalization_19 (Batc (None, 512, 32, 32)       128
_________________________________________________________________
up_sampling2d_19 (UpSampling (None, 512, 64, 64)       0
_________________________________________________________________
conv2d_117 (Conv2D)          (None, 256, 64, 64)       1179904
_________________________________________________________________
batch_normalization_20 (Batc (None, 256, 64, 64)       256
_________________________________________________________________
conv2d_118 (Conv2D)          (None, 256, 64, 64)       590080
_________________________________________________________________
batch_normalization_21 (Batc (None, 256, 64, 64)       256
_________________________________________________________________
conv2d_119 (Conv2D)          (None, 256, 64, 64)       590080
_________________________________________________________________
batch_normalization_22 (Batc (None, 256, 64, 64)       256
_________________________________________________________________
up_sampling2d_20 (UpSampling (None, 256, 128, 128)     0
_________________________________________________________________
conv2d_120 (Conv2D)          (None, 128, 128, 128)     295040
_________________________________________________________________
batch_normalization_23 (Batc (None, 128, 128, 128)     512
_________________________________________________________________
conv2d_121 (Conv2D)          (None, 128, 128, 128)     147584
_________________________________________________________________
batch_normalization_24 (Batc (None, 128, 128, 128)     512
_________________________________________________________________
up_sampling2d_21 (UpSampling (None, 128, 256, 256)     0
_________________________________________________________________
conv2d_122 (Conv2D)          (None, 64, 256, 256)      73792
_________________________________________________________________
batch_normalization_25 (Batc (None, 64, 256, 256)      1024
_________________________________________________________________
conv2d_123 (Conv2D)          (None, 64, 256, 256)      36928
_________________________________________________________________
batch_normalization_26 (Batc (None, 64, 256, 256)      1024
_________________________________________________________________
conv2d_124 (Conv2D)          (None, 1, 256, 256)       65
_________________________________________________________________
reshape_1 (Reshape)          (None, 1, 65536)          0
_________________________________________________________________
permute_1 (Permute)          (None, 65536, 1)          0
_________________________________________________________________
activation_1 (Activation)    (None, 65536, 1)          0
=================================================================
Total params: 31,795,841
Trainable params: 31,791,425
Non-trainable params: 4,416
_________________________________________________________________
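A note on the totals: the 4,416 non-trainable parameters are the moving means and variances tracked by the BatchNormalization layers. Each BN layer stores four values per normalized feature (gamma, beta, moving mean, moving variance), of which the last two are not trained, which is exactly half of the 8,832 BN parameters in the table.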
