一、任务描述

看到一张图像,你的大脑可以很容易地分辨出图像是关于什么的,但是计算机能分辨出图像所代表的内容吗?随着深度学习技术的进步、庞大数据集的可用性和计算机能力,我们可以构建可以为图像生成说明的模型。

我们将使用CNN(卷积神经网络) 和 LSTM(长期短期记忆)来实现字幕生成器。图像特征将从 Xception 中提取,Xception 是在 imagenet 数据集上训练的 CNN 模型,然后我们将特征输入 LSTM 模型,该模型将负责生成图像说明。

二、数据集说明

这里使用 Flickr_8K 数据集。虽然还有其他大型数据集,如 Flickr_30K 和 MSCOCO 数据集,但仅训练网络可能需要几周时间,所以我们将使用小型 Flickr8k 数据集。但是庞大数据集的优势在于我们可以构建更好的模型。

Flickr_8K 数据集为基于句子的图像描述和搜索引入了一个新的基准集合,由 8,000 张图像组成,每张图像都与五个不同的标题配对,这些标题提供了对显着实体和事件的清晰描述。

数据集下载地址

链接:https://pan.baidu.com/s/1aG3CYioORpPdXC89_F_s3A 
提取码:q9el

数据集内的Flickr8k_Dataset.zip是图像文件压缩包。

Flickr8k_text.zip是英文描述等文件。

flickr8kcn文件夹是对应的中文描述等文件,下面代码是根据英语等训练的,可以自行修改,以中文进行训练。

三、模型结构

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_2 (InputLayer)           [(None, 32)]         0           []                               
                                                                                                  
 input_1 (InputLayer)           [(None, 2048)]       0           []                               
                                                                                                  
 embedding (Embedding)          (None, 32, 256)      1939712     ['input_2[0][0]']                
                                                                                                  
 dropout (Dropout)              (None, 2048)         0           ['input_1[0][0]']                
                                                                                                  
 dropout_1 (Dropout)            (None, 32, 256)      0           ['embedding[0][0]']              
                                                                                                  
 dense (Dense)                  (None, 256)          524544      ['dropout[0][0]']                
                                                                                                  
 lstm (LSTM)                    (None, 256)          525312      ['dropout_1[0][0]']              
                                                                                                  
 add (Add)                      (None, 256)          0           ['dense[0][0]',                  
                                                                  'lstm[0][0]']                   
                                                                                                  
 dense_1 (Dense)                (None, 256)          65792       ['add[0][0]']                    
                                                                                                  
 dense_2 (Dense)                (None, 7577)         1947289     ['dense_1[0][0]']                
                                                                                                  
==================================================================================================
Total params: 5,002,649
Trainable params: 5,002,649
Non-trainable params: 0
__________________________________________________________________________________________________

四、训练模型

1、参考代码

这里在图像预处理的时候使用了Xception进行特征抽取,并保存到features.p文件。

import string
import numpy as np
from PIL import Image
import os
from pickle import dump, load
import numpy as npimport tensorflow as tf
from tensorflow.keras.applications.xception import Xception, preprocess_input
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.layers import add
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Input, Dense, LSTM, Embedding, Dropout
from tensorflow.keras.utils import plot_model# small library for seeing the progress of loops.
from tqdm import tqdm_notebook as tqdm
#tqdm().pandas()# Loading a text file into memory
def load_doc(filename):# Opening the file as read onlyfile = open(filename, 'r')text = file.read()file.close()return text# get all imgs with their captions
def all_img_captions(filename):file = load_doc(filename)captions = file.split('\n')descriptions ={}for caption in captions[:-1]:img, caption = caption.split('\t')if img[:-2] not in descriptions:descriptions[img[:-2]] = [caption]else:descriptions[img[:-2]].append(caption)return descriptions##Data cleaning- lower casing, removing puntuations and words containing numbers
#此函数获取所有描述并执行数据清理。这是我们处理文本数据时的重要一步,根据我们的目标,我们决定要对文本执行哪种类型的清理。在我们的例子中,我们将删除标点符号,将所有文本转换为小写,并删除包含数字的单词。
def cleaning_text(captions):table = str.maketrans('','',string.punctuation)for img,caps in captions.items():for i,img_caption in enumerate(caps):img_caption.replace("-"," ")desc = img_caption.split()#converts to lower casedesc = [word.lower() for word in desc]#remove punctuation from each tokendesc = [word.translate(table) for word in desc]#remove hanging 's and adesc = [word for word in desc if(len(word)>1)]#remove tokens with numbers in themdesc = [word for word in desc if(word.isalpha())]#convert back to stringimg_caption = ' '.join(desc)captions[img][i]= img_captionreturn captionsdef text_vocabulary(descriptions):# build vocabulary of all unique wordsvocab = set()for key in descriptions.keys():[vocab.update(d.split()) for d in descriptions[key]]return vocab#此函数将创建一个已预处理的所有描述的列表并将它们存储到一个文件中。
def save_descriptions(descriptions, filename):lines = list()for key, desc_list in descriptions.items():for desc in desc_list:lines.append(key + '\t' + desc )data = "\n".join(lines)file = open(filename,"w")file.write(data)file.close()# all_train_captions = []
# for key, val in descriptions.items():
#     for cap in val:
#         all_train_captions.append(cap)# # Consider only words which occur at least 8 times in the corpus
# word_count_threshold = 8
# word_counts = {}
# nsents = 0
# for sent in all_train_captions:
#     nsents += 1
#     for w in sent.split(' '):
#         word_counts[w] = word_counts.get(w, 0) + 1# vocab = [w for w in word_counts if word_counts[w] >= word_count_threshold]# print('preprocessed words %d ' % len(vocab))dataset_text = "Flickr8k_text"
dataset_images = "Flicker8k_Dataset"#we prepare our text data
filename = dataset_text + "/" + "Flickr8k.token.txt"
#loading the file that contains all data
#mapping them into descriptions dictionary img to 5 captions
descriptions = all_img_captions(filename)
print("Length of descriptions =" ,len(descriptions))#cleaning the descriptions
clean_descriptions = cleaning_text(descriptions)#building vocabulary
vocabulary = text_vocabulary(clean_descriptions)
print("Length of vocabulary = ", len(vocabulary))#saving each description to file
save_descriptions(clean_descriptions, "descriptions.txt")def extract_features(directory):model = Xception(include_top=False, pooling='avg')features = {}for img in tqdm(os.listdir(directory)):filename = directory + "/" + imgimage = Image.open(filename)image = image.resize((299, 299))image = np.expand_dims(image, axis=0)# image = preprocess_input(image)image = image / 127.5image = image - 1.0feature = model.predict(image)features[img] = featurereturn features# 提取特征向量 2048 feature vector
# 如果已经提取好了,可以注释掉下面两句
features = extract_features(dataset_images)
dump(features, open("features.p","wb"))features = load(open("features.p","rb"))# load the data
def load_photos(filename):file = load_doc(filename)photos = file.split("\n")[:-1]return photosdef load_clean_descriptions(filename, photos):# loading clean_descriptionsfile = load_doc(filename)descriptions = {}for line in file.split("\n"):words = line.split()if len(words) < 1:continueimage, image_caption = words[0], words[1:]if image in photos:if image not in descriptions:descriptions[image] = []desc = '<start> ' + " ".join(image_caption) + ' <end>'descriptions[image].append(desc)return descriptionsdef load_features(photos):# loading all featuresall_features = load(open("features.p", "rb"))# selecting only needed featuresfeatures = {k: all_features[k] for k in photos}return featuresfilename = dataset_text + "/" + "Flickr_8k.trainImages.txt"#train = loading_data(filename)
train_imgs = load_photos(filename)
train_descriptions = load_clean_descriptions("descriptions.txt", train_imgs)
train_features = load_features(train_imgs)#converting dictionary to clean list of descriptions
def dict_to_list(descriptions):all_desc = []for key in descriptions.keys():[all_desc.append(d) for d in descriptions[key]]return all_desc#creating tokenizer class
#this will vectorise text corpus
#each integer will represent token in dictionaryfrom keras.preprocessing.text import Tokenizerdef create_tokenizer(descriptions):desc_list = dict_to_list(descriptions)tokenizer = Tokenizer()tokenizer.fit_on_texts(desc_list)return tokenizer# give each word a index, and store that into tokenizer.p pickle file
tokenizer = create_tokenizer(train_descriptions)
dump(tokenizer, open('tokenizer.p', 'wb'))
vocab_size = len(tokenizer.word_index) + 1
vocab_size#calculate maximum length of descriptions
def max_length(descriptions):desc_list = dict_to_list(descriptions)return max(len(d.split()) for d in desc_list)max_length = max_length(descriptions)
max_lengthprint(features['1000268201_693b08cb0e.jpg'][0])# Define the model#1 Photo feature extractor - we extracted features from pretrained model Xception.
#2 Sequence processor - word embedding layer that handles text, followed by LSTM
#3 Decoder - Both 1 and 2 model produce fixed length vector. They are merged together and processed by dense layer to make final prediction#create input-output sequence pairs from the image description.#data generator, used by model.fit_generator()
def data_generator(descriptions, features, tokenizer, max_length):while 1:for key, description_list in descriptions.items():#retrieve photo featuresfeature = features[key][0]input_image, input_sequence, output_word = create_sequences(tokenizer, max_length, description_list, feature)yield [[input_image, input_sequence], output_word]def create_sequences(tokenizer, max_length, desc_list, feature):X1, X2, y = list(), list(), list()# walk through each description for the imagefor desc in desc_list:# encode the sequenceseq = tokenizer.texts_to_sequences([desc])[0]# split one sequence into multiple X,y pairsfor i in range(1, len(seq)):# split into input and output pairin_seq, out_seq = seq[:i], seq[i]# pad input sequencein_seq = pad_sequences([in_seq], maxlen=max_length)[0]# encode output sequenceout_seq = to_categorical([out_seq], num_classes=vocab_size)[0]# storeX1.append(feature)X2.append(in_seq)y.append(out_seq)return np.array(X1), np.array(X2), np.array(y)[a,b],c = next(data_generator(train_descriptions, features, tokenizer, max_length))
a.shape, b.shape, c.shape# 定义模型
def define_model(vocab_size, max_length):# features from the CNN model squeezed from 2048 to 256 nodesinputs1 = Input(shape=(2048,))fe1 = Dropout(0.5)(inputs1)fe2 = Dense(256, activation='relu')(fe1)# LSTM sequence modelinputs2 = Input(shape=(max_length,))se1 = Embedding(vocab_size, 256, mask_zero=True)(inputs2)se2 = Dropout(0.5)(se1)se3 = LSTM(256)(se2)# Merging both modelsdecoder1 = add([fe2, se3])decoder2 = Dense(256, activation='relu')(decoder1)outputs = Dense(vocab_size, activation='softmax')(decoder2)# tie it together [image, seq] [word]model = Model(inputs=[inputs1, inputs2], outputs=outputs)model.compile(loss='categorical_crossentropy', optimizer='adam')# summarize modelprint(model.summary())#plot_model(model, to_file='model.png', show_shapes=True)return model# 开始创建并训练模型
print('Dataset: ', len(train_imgs))
print('Descriptions: train=', len(train_descriptions))
print('Photos: train=', len(train_features))
print('Vocabulary Size:', vocab_size)
print('Description Length: ', max_length)model = define_model(vocab_size, max_length)
epochs = 10
steps = len(train_descriptions)# 创建保存model的文件夹
os.mkdir("models")
for i in range(epochs):generator = data_generator(train_descriptions, train_features, tokenizer, max_length)model.fit_generator(generator, epochs=1, steps_per_epoch= steps, verbose=1)model.save("models/model_" + str(i) + ".h5")

2、测试模型

from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from keras.applications.xception import Xception
from keras.models import load_model
from pickle import load
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import argparseap = argparse.ArgumentParser()
ap.add_argument('-i', '--image', required=False, help="Image Path", default='test/3291255271_a185eba408.jpg')
args = vars(ap.parse_args())
img_path = args['image']def extract_features(filename, model):try:image = Image.open(filename)except:print("ERROR: Couldn't open image! Make sure the image path and extension is correct")image = image.resize((299,299))image = np.array(image)# for images that has 4 channels, we convert them into 3 channelsif image.shape[2] == 4: image = image[..., :3]image = np.expand_dims(image, axis=0)image = image/127.5image = image - 1.0feature = model.predict(image)return featuredef word_for_id(integer, tokenizer):for word, index in tokenizer.word_index.items():if index == integer:return wordreturn Nonedef generate_desc(model, tokenizer, photo, max_length):in_text = 'start'for i in range(max_length):sequence = tokenizer.texts_to_sequences([in_text])[0]sequence = pad_sequences([sequence], maxlen=max_length)pred = model.predict([photo,sequence], verbose=0)pred = np.argmax(pred)word = word_for_id(pred, tokenizer)if word is None:breakin_text += ' ' + wordif word == 'end':breakreturn in_text#path = 'Flicker8k_Dataset/111537222_07e56d5a30.jpg'
max_length = 32
tokenizer = load(open("tokenizer.p","rb"))
model = load_model('models/model_9.h5')
xception_model = Xception(include_top=False, pooling="avg")photo = extract_features(img_path, xception_model)
img = Image.open(img_path)description = generate_desc(model, tokenizer, photo, max_length)
print("\n\n")
print(description)
plt.imshow(img)

3、测试结果

由于训练有些慢,没有训练太多epochs,导致测试结果有些奇怪,但是为类似工作提供一个参考思路。

start man in red shirt is walking on the beach end

开始 两只狗在草地上玩球

4、完整代码

ml_toolset/案例59 使用LSTM生成图像描述 at main · bashendixie/ml_toolset · GitHubContribute to bashendixie/ml_toolset development by creating an account on GitHub.https://github.com/bashendixie/ml_toolset/tree/main/%E6%A1%88%E4%BE%8B59%20%E4%BD%BF%E7%94%A8LSTM%E7%94%9F%E6%88%90%E5%9B%BE%E5%83%8F%E6%8F%8F%E8%BF%B0

机器学习笔记 - 使用CNN和LSTM为图像生成文字描述相关推荐

  1. 从图像生成自动描述:对模型,数据集和评估方法的综述

    摘要 从自然图像生成自动描述是一个具有挑战性的问题,近来受到计算机视觉和自然语言处理社区的大量关注. 在本次调查中,我们根据他们如何将这个问题概念化的现有方法进行分类,即将描述作为生成问题或作为视觉或 ...

  2. 机器学习笔记 - 基于传统方法/深度学习的图像配准

    一.图像配准 图像配准是将 一个场景的不同图像变换到同一坐标系的过程.这些图像可以在不同的时间(多时间配准).由不同的传感器(多模态配准)和/或从不同的视点拍摄.这些图像之间的空间关系可以是 刚性的 ...

  3. 机器学习笔记:CNN卷积神经网络

    1,CNN概述 卷积神经网络由输入层.卷积层.池化层.全连接层和输出层组成. 通过增加卷积层和池化层,可以得到更深层次的网络. 与多层感知器相比,卷积神经网络的参数更少,不容易发生过拟合. 2, 为何 ...

  4. [论文笔记]基于 CNN+双向LSTM 实现服饰搭配的生成

    论文:<Learning Fashion Compatibility with Bidirectional LSTMs> 论文地址:https://arxiv.org/abs/1707.0 ...

  5. 机器学习笔记 invariance data augmentation

    1 Invariance vs. Sensitivity 无论是对于图像.文本还是视频,我们都希望找到好的向量表示 好的向量表示需要对我们任务所关心的特征敏感: 动物识别问题中,动物的品种就是一个值得 ...

  6. 机器学习笔记: Upsampling, U-Net, Pyramid Scene Parsing Net

    前言 在CNN-based 的 模型中,我们可能会用到downsampling 操作来减少模型参数,以及扩大感受野的效果. 下图是一个graph segmentation的例子,就先使用 downsa ...

  7. 从图像到语言:图像标题生成与描述

    从图像到语言:图像标题生成与描述 大家好,我是苏州程序大白,五一假都过去三天了.大家可以学习起来.今天我们讲讲图像到语言.欢迎大家一起讨论.还有请大家多多支持.关注我.谢谢!!! 1.图像简单标题生成 ...

  8. 机器学习笔记 :LSTM 变体 (conv-LSTM、Peephole LSTM、 coupled LSTM、conv-GRU)

    1 LSTM复习 机器学习笔记 RNN初探 & LSTM_UQI-LIUWJ的博客-CSDN博客 机器学习笔记:GRU_UQI-LIUWJ的博客-CSDN博客_gru 机器学习 2 Peeph ...

  9. 李弘毅机器学习笔记:第十三章—CNN

    李弘毅机器学习笔记:第十三章-CNN 为什么用CNN Small region Same Patterns Subsampling CNN架构 Convolution Propetry1 Propet ...

最新文章

  1. shell 脚本简单入门
  2. Solr和lucene
  3. mysql数据库永久设置手动提交事务(InnoDB存储引擎禁止autocommit默认开启)
  4. unsigned int mysql_mysql 中int类型字段unsigned和signed的探索
  5. python弹窗输入_Python中使用tkinter弹窗获取输入文本
  6. VB讲课笔记12:文件管理
  7. mysql 无缓冲的查询_MySQL缓冲和无缓冲查询对比
  8. 计算机网络实验指导书 pdf,计算机网络实验指导书(新版).pdf
  9. 图像处理黑科技——弯曲矫正、去摩尔纹、切边增强、PS检测
  10. 电脑键盘出现计算机,电脑上出现了键盘怎么办
  11. SharePoint Online 触发的Automate工作流的调试
  12. SpringBoot整合WebSocket实现聊天室系统
  13. 大学生体育运动网页设计模板代码 DIV布局校园运动网页作业成品 HTML学校网页制作模板 学生简单体育运动网站设计成品
  14. 软件设计第一步——分离关注点和单一职责原则
  15. 性能测试脚本用例模版
  16. Linux编程和windows编程的区别
  17. 【pyqt5学习】—— 滑动条Qslider、计数器QSpinBox学习
  18. 魅族路由器(极速版)刷老毛子(padavad)固件-全网最详细教程
  19. u盘里的文件损坏了怎么修复?
  20. movielens数据集导入mysql_我来做数据--如何对数据进行处理以满足机器学习技术(一):MovieLens数据...

热门文章

  1. 计算机相关专业提升学历的解决方案(硕士研究生)
  2. Java程序员秋招三面蚂蚁金服,我总结了所有面试题,也不过如此
  3. 新手建站注意事项指南
  4. 蓝桥杯单片机比赛学习:11、频率测量的基本原理
  5. 登录用友显示java已被阻止_解决Spring Security 用户帐号已被锁定问题
  6. 苏州大学文正学院JAVA试卷_苏州大学文正学院试题库建设管理办法(试行)
  7. ffmpeg filter amix混音实现
  8. zimbra mysql stopping_Zimbra中的MySQL
  9. delphi function 与 procedure
  10. RT-thread基础移植//依据rtt实战学习记录