

我们将使用 CNN(卷积神经网络)和 LSTM(长短期记忆网络)来实现图像描述(image caption)生成器。图像特征由 Xception 提取(Xception 是在 ImageNet 数据集上训练的 CNN 模型),然后将这些特征输入 LSTM 模型,由它负责生成图像描述。


这里使用 Flickr_8K 数据集。虽然还有 Flickr_30K 和 MSCOCO 等更大的数据集,但仅训练网络就可能需要数周时间,所以我们使用较小的 Flickr8k 数据集。当然,大型数据集的优势在于可以训练出更好的模型。

Flickr_8K 数据集为基于句子的图像描述和检索提供了一个新的基准集合。它由 8,000 张图像组成,每张图像都配有五个不同的标题,这些标题清晰地描述了图像中的显著实体和事件。







Model: "model"
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_2 (InputLayer)           [(None, 32)]         0           []                               
 input_1 (InputLayer)           [(None, 2048)]       0           []                               
 embedding (Embedding)          (None, 32, 256)      1939712     ['input_2[0][0]']                
 dropout (Dropout)              (None, 2048)         0           ['input_1[0][0]']                
 dropout_1 (Dropout)            (None, 32, 256)      0           ['embedding[0][0]']              
 dense (Dense)                  (None, 256)          524544      ['dropout[0][0]']                
 lstm (LSTM)                    (None, 256)          525312      ['dropout_1[0][0]']              
 add (Add)                      (None, 256)          0           ['dense[0][0]',                  
                                                                  'lstm[0][0]']                   
 dense_1 (Dense)                (None, 256)          65792       ['add[0][0]']                    
 dense_2 (Dense)                (None, 7577)         1947289     ['dense_1[0][0]']                
Total params: 5,002,649
Trainable params: 5,002,649
Non-trainable params: 0




import os
import string
from pickle import dump, load

import numpy as np
import tensorflow as tf
from PIL import Image
from tensorflow.keras.applications.xception import Xception, preprocess_input
from tensorflow.keras.layers import add
from tensorflow.keras.layers import Input, Dense, LSTM, Embedding, Dropout
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.utils import plot_model
from tensorflow.keras.utils import to_categorical

# small library for seeing the progress of loops.
from tqdm import tqdm_notebook as tqdm
#tqdm().pandas()# Loading a text file into memory
def load_doc(filename):
    """Read a whole text file into memory and return its contents.

    Args:
        filename: Path of the text file to read.

    Returns:
        The complete file contents as a single string.
    """
    # `with` closes the handle even if read() raises. An explicit encoding
    # avoids depending on the platform's default locale encoding.
    with open(filename, 'r', encoding='utf-8') as file:
        return file.read()
def all_img_captions(filename):
    """Group every caption in a Flickr8k token file under its image name.

    Each non-empty line has the form ``<image>.jpg#<n>\\t<caption>``. The
    ``#<n>`` caption index is stripped, so all captions of one image end up
    in the same list.

    Args:
        filename: Path to the token file (e.g. ``Flickr8k.token.txt``).

    Returns:
        dict mapping image file name -> list of raw caption strings, in the
        order they appear in the file.
    """
    descriptions = {}
    for line in load_doc(filename).split('\n'):
        # Skip blank lines (such as the one after a trailing newline) instead
        # of assuming the file ends with exactly one "\n".
        if not line.strip():
            continue
        img, caption = line.split('\t', 1)
        # "xxx.jpg#3" -> "xxx.jpg". rpartition works for caption indices of any
        # width; a name without "#" is kept as-is.
        img_name = img.rpartition('#')[0] or img
        descriptions.setdefault(img_name, []).append(caption)
    return descriptions
def cleaning_text(captions):
    """Clean every caption in place and return the same dict.

    Cleaning steps:
    - split hyphenated words;
    - lower-case;
    - strip ASCII punctuation;
    - drop one-letter tokens (hanging "s" and "a");
    - drop tokens containing non-letters such as digits.

    Args:
        captions: dict mapping image name -> list of caption strings. It is
            mutated: each caption is replaced by its cleaned form.

    Returns:
        The same ``captions`` dict object.
    """
    table = str.maketrans('', '', string.punctuation)
    for img, caps in captions.items():
        for i, img_caption in enumerate(caps):
            # str.replace returns a new string, so the result must be kept;
            # otherwise "black-and-white" collapses to "blackandwhite".
            img_caption = img_caption.replace("-", " ")
            # lower-case, then remove punctuation from each token
            desc = [word.lower().translate(table) for word in img_caption.split()]
            # remove hanging 's and a, and tokens with numbers in them
            desc = [word for word in desc if len(word) > 1 and word.isalpha()]
            captions[img][i] = ' '.join(desc)
    return captions


def text_vocabulary(descriptions):
    """Return the set of distinct words used across all captions."""
    vocab = set()
    for caps in descriptions.values():
        for caption in caps:
            vocab.update(caption.split())
    return vocab
def save_descriptions(descriptions, filename):
    """Write all preprocessed captions to a single text file.

    Each caption becomes one ``<image>\\t<caption>`` line. Lines are joined
    with "\\n" and no trailing newline is written.

    Args:
        descriptions: dict mapping image name -> list of caption strings.
        filename: Destination path; an existing file is overwritten.
    """
    lines = [key + '\t' + desc
             for key, desc_list in descriptions.items()
             for desc in desc_list]
    # `with` closes the file even if the write fails.
    with open(filename, "w", encoding="utf-8") as file:
        file.write("\n".join(lines))


# all_train_captions = []
# for key, val in descriptions.items():
#     for cap in val:
#         all_train_captions.append(cap)
#
# # Consider only words which occur at least 8 times in the corpus
# word_count_threshold = 8
# word_counts = {}
# nsents = 0
# for sent in all_train_captions:
#     nsents += 1
#     for w in sent.split(' '):
#         word_counts[w] = word_counts.get(w, 0) + 1
#
# vocab = [w for w in word_counts if word_counts[w] >= word_count_threshold]
# print('preprocessed words %d ' % len(vocab))

# Dataset locations. This assignment must be real code: everything below
# reads `dataset_text`.
dataset_text = "Flickr8k_text"
dataset_images = "Flicker8k_Dataset"

# Prepare the text data.
filename = dataset_text + "/" + "Flickr8k.token.txt"
# Load the token file and map every image name to its captions.
descriptions = all_img_captions(filename)
print("Length of descriptions =", len(descriptions))

# Clean the captions. cleaning_text mutates and returns its argument, so
# `descriptions` and `clean_descriptions` are the same dict afterwards.
clean_descriptions = cleaning_text(descriptions)

# Build the vocabulary of distinct words.
vocabulary = text_vocabulary(clean_descriptions)
print("Length of vocabulary = ", len(vocabulary))
# Save each cleaned description to file.
save_descriptions(clean_descriptions, "descriptions.txt")


def extract_features(directory):
    """Compute an Xception feature vector for every file in ``directory``.

    Args:
        directory: Folder containing the images. Every entry is opened as an
            image.

    Returns:
        dict mapping file name -> the array returned by ``model.predict`` for
        that single image. It is a batch of one, so callers index it with [0].
    """
    model = Xception(include_top=False, pooling='avg')
    features = {}
    for img in tqdm(os.listdir(directory)):
        filename = directory + "/" + img
        with Image.open(filename) as src:
            # Force 3-channel RGB. Grayscale, palette or RGBA images would
            # otherwise produce an array the network cannot accept.
            pixels = np.asarray(src.convert('RGB').resize((299, 299)))
        image = np.expand_dims(pixels, axis=0)
        # Scale pixel values from [0, 255] to [-1, 1].
        image = image / 127.5
        image = image - 1.0
        # tqdm already reports progress; silence per-image predict logging.
        features[img] = model.predict(image, verbose=0)
    return features
# Extract the 2048-d feature vector of every image and cache it in features.p.
# If features.p has already been generated, the extraction and dump below
# can be commented out.
features = extract_features(dataset_images)
with open("features.p", "wb") as fh:
    dump(features, fh)

# Reload the cached features (all that is needed on later runs).
with open("features.p", "rb") as fh:
    features = load(fh)
# Load the data.
def load_photos(filename):
    """Return the image file names listed one per line in ``filename``.

    Blank lines are skipped, so the result does not depend on whether the
    file ends with a newline.
    """
    return [line for line in load_doc(filename).split("\n") if line.strip()]


def load_clean_descriptions(filename, photos):
    """Load cleaned captions for the given images and wrap them in markers.

    Args:
        filename: File written by ``save_descriptions``. Each line holds an
            image name followed by the caption words.
        photos: Iterable of image names to keep.

    Returns:
        dict mapping image name -> list of ``'<start> ... <end>'`` captions.
    """
    # Build the set once: membership checks in a list cost O(len(photos))
    # for every line of the file.
    wanted = set(photos)
    descriptions = {}
    for line in load_doc(filename).split("\n"):
        words = line.split()
        if not words:
            continue
        image, image_caption = words[0], words[1:]
        if image in wanted:
            desc = '<start> ' + " ".join(image_caption) + ' <end>'
            descriptions.setdefault(image, []).append(desc)
    return descriptions


def load_features(photos):
    """Return the cached features (from features.p) for the given image names.

    Raises:
        KeyError: if an image in ``photos`` has no cached feature.
    """
    with open("features.p", "rb") as fh:
        all_features = load(fh)
    # Keep only the features that are needed.
    return {k: all_features[k] for k in photos}


filename = dataset_text + "/" + "Flickr_8k.trainImages.txt"
# train = loading_data(filename)
# Training split: image names, their <start>/<end>-wrapped captions, and
# their cached image features.
train_imgs = load_photos(filename)
train_descriptions = load_clean_descriptions(filename="descriptions.txt", photos=train_imgs)
train_features = load_features(photos=train_imgs)
# Convert the descriptions dictionary to a flat list of captions.
def dict_to_list(descriptions):
    """Flatten ``{image: [caption, ...]}`` into one list of captions.

    Captions are returned in dict iteration order, then per-image list order.
    """
    return [desc for desc_list in descriptions.values() for desc in desc_list]
# Create the tokenizer. It vectorises the text corpus so that each integer
# represents one token of the dictionary.
# Uses the tensorflow.keras Tokenizer imported at the top of the file, so the
# training pipeline does not mix tf.keras with the standalone `keras` package.
def create_tokenizer(descriptions):
    """Fit a Keras ``Tokenizer`` on every caption in ``descriptions`` and return it."""
    tokenizer = Tokenizer()
    tokenizer.fit_on_texts(dict_to_list(descriptions))
    return tokenizer
# Give each word an index, and store the tokenizer in the tokenizer.p pickle file.
tokenizer = create_tokenizer(train_descriptions)
with open('tokenizer.p', 'wb') as fh:
    dump(tokenizer, fh)
# +1 because Keras Tokenizer word indices start at 1; index 0 is left for
# padding, which the Embedding layer masks with mask_zero=True.
vocab_size = len(tokenizer.word_index) + 1
def get_max_length(descriptions):
    """Return the word count of the longest caption in ``descriptions``.

    Args:
        descriptions: dict mapping image name -> list of caption strings.

    Raises:
        ValueError: if ``descriptions`` contains no captions.
    """
    return max(len(d.split()) for d in dict_to_list(descriptions))


# Calculate the maximum description length. It becomes the padded length of
# the model's text input. The helper has its own name so that binding the
# result to `max_length` does not overwrite the function itself.
# NOTE(review): this measures the cleaned captions *without* the
# "<start>"/"<end>" markers that load_clean_descriptions adds for training.
# The longest training prefixes can therefore be longer than max_length and
# get truncated by pad_sequences. Confirm whether that was intended.
max_length = get_max_length(descriptions)
# Sanity check: show the cached feature vector of one sample image.
print(features['1000268201_693b08cb0e.jpg'][0])

# Define the model:
# 1. Photo feature extractor - features were extracted from the pretrained
#    Xception model.
# 2. Sequence processor - a word embedding layer that handles the text,
#    followed by an LSTM.
# 3. Decoder - both 1 and 2 produce a fixed-length vector; they are merged
#    and processed by dense layers to make the final prediction.
# Create input-output sequence pairs from the image descriptions.
# Data generator, consumed by model.fit().
def data_generator(descriptions, features, tokenizer, max_length):
    """Endlessly yield one training batch per image.

    Each item is ``((photo_features, padded_prefixes), one_hot_next_words)``
    built from all captions of a single image.
    """
    while True:
        for key, description_list in descriptions.items():
            # retrieve photo features
            feature = features[key][0]
            input_image, input_sequence, output_word = create_sequences(
                tokenizer, max_length, description_list, feature)
            # Keras expects generators to yield an (inputs, targets) tuple.
            yield (input_image, input_sequence), output_word


def create_sequences(tokenizer, max_length, desc_list, feature, vocab_size=None):
    """Expand captions into (photo, word prefix) -> next-word training pairs.

    Args:
        tokenizer: Fitted Keras Tokenizer.
        max_length: Length each word prefix is padded or truncated to.
        desc_list: Captions of one image.
        feature: Feature vector of that image, repeated for every pair.
        vocab_size: Number of output classes. Defaults to
            ``len(tokenizer.word_index) + 1``, the same formula the script
            uses for its global ``vocab_size``, so the function no longer
            depends on that global.

    Returns:
        Three numpy arrays: stacked features, padded prefixes and one-hot
        next words.
    """
    if vocab_size is None:
        vocab_size = len(tokenizer.word_index) + 1
    X1, X2, y = [], [], []
    # walk through each description for the image
    for desc in desc_list:
        # encode the sequence
        seq = tokenizer.texts_to_sequences([desc])[0]
        # split one sequence into multiple X, y pairs
        for i in range(1, len(seq)):
            in_seq, out_seq = seq[:i], seq[i]
            in_seq = pad_sequences([in_seq], maxlen=max_length)[0]
            out_seq = to_categorical([out_seq], num_classes=vocab_size)[0]
            X1.append(feature)
            X2.append(in_seq)
            y.append(out_seq)
    return np.array(X1), np.array(X2), np.array(y)


# Smoke test: build one batch and show the array shapes.
(a, b), c = next(data_generator(train_descriptions, features, tokenizer, max_length))
print(a.shape, b.shape, c.shape)
# Define the model.
def define_model(vocab_size, max_length, *, feature_dim=2048, embedding_dim=256,
                 units=256, dropout=0.5):
    """Build and compile the merge-style captioning model.

    Args:
        vocab_size: Size of the output softmax and of the embedding table.
        max_length: Length of the padded word-index input.
        feature_dim: Length of the precomputed CNN feature vector.
        embedding_dim: Word embedding size.
        units: Width of the image projection, the LSTM and the decoder layer.
            The image and text branches must share this width for add().
        dropout: Dropout rate applied to both inputs.

    Returns:
        A compiled ``Model`` taking ``[features, sequence]`` and predicting
        the next word.
    """
    # Image branch: CNN features squeezed from feature_dim to `units` nodes.
    inputs1 = Input(shape=(feature_dim,))
    fe1 = Dropout(dropout)(inputs1)
    fe2 = Dense(units, activation='relu')(fe1)
    # Text branch: embedding followed by an LSTM. mask_zero lets padding
    # index 0 be ignored.
    inputs2 = Input(shape=(max_length,))
    se1 = Embedding(vocab_size, embedding_dim, mask_zero=True)(inputs2)
    se2 = Dropout(dropout)(se1)
    se3 = LSTM(units)(se2)
    # Merge both branches and decode to a next-word distribution.
    decoder1 = add([fe2, se3])
    decoder2 = Dense(units, activation='relu')(decoder1)
    outputs = Dense(vocab_size, activation='softmax')(decoder2)
    # Tie it together: [image, seq] -> [word].
    model = Model(inputs=[inputs1, inputs2], outputs=outputs)
    model.compile(loss='categorical_crossentropy', optimizer='adam')
    # summary() prints the table itself and returns None, so it is not
    # wrapped in print().
    model.summary()
    # plot_model(model, to_file='model.png', show_shapes=True)
    return model
# Create the model and report the dataset statistics it is built from.
for label, value in (
    ('Dataset: ', len(train_imgs)),
    ('Descriptions: train=', len(train_descriptions)),
    ('Photos: train=', len(train_features)),
    ('Vocabulary Size:', vocab_size),
    ('Description Length: ', max_length),
):
    print(label, value)
model = define_model(vocab_size, max_length)
epochs = 10
# One generator step covers all caption sequences of one image, so a full
# pass over the training set takes len(train_descriptions) steps.
steps = len(train_descriptions)
# Create the folder the models are saved into; saving fails if it is missing.
os.makedirs("models", exist_ok=True)
for i in range(epochs):
    generator = data_generator(train_descriptions, train_features, tokenizer, max_length)
    # fit() accepts generators directly. fit_generator() has been deprecated
    # since TF 2.1 and is not available in Keras 3.
    model.fit(generator, epochs=1, steps_per_epoch=steps, verbose=1)
    model.save("models/model_" + str(i) + ".h5")


from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from keras.applications.xception import Xception
from keras.models import load_model
from pickle import load
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import argparse

# Command-line options: -i/--image selects the picture to caption.
ap = argparse.ArgumentParser()
ap.add_argument(
    '-i', '--image',
    default='test/3291255271_a185eba408.jpg',
    required=False,
    help="Image Path",
)
args = vars(ap.parse_args())
img_path = args['image']


def extract_features(filename, model):
    """Return the Xception feature output for a single image file.

    Raises:
        OSError: if the image cannot be opened or decoded. The error message
            is printed first.
    """
    try:
        image = Image.open(filename)
    except OSError:
        # Report, then re-raise: continuing without an image would only fail
        # later with a confusing NameError.
        print("ERROR: Couldn't open image! Make sure the image path and extension is correct")
        raise
    with image:
        # convert('RGB') turns grayscale, palette and RGBA images into
        # 3 channels (for RGBA this simply drops the alpha channel).
        pixels = np.array(image.convert('RGB').resize((299, 299)))
    pixels = np.expand_dims(pixels, axis=0)
    # Scale pixel values from [0, 255] to [-1, 1], as during training.
    pixels = pixels / 127.5
    pixels = pixels - 1.0
    return model.predict(pixels)


def word_for_id(integer, tokenizer):
    """Map a predicted word index back to its word, or None if it is unknown."""
    # index_word is the reverse of word_index, maintained by Tokenizer.
    # Fall back to a scan for tokenizers that lack it.
    index_word = getattr(tokenizer, 'index_word', None)
    if index_word is not None:
        return index_word.get(integer)
    return next((word for word, index in tokenizer.word_index.items()
                 if index == integer), None)


def generate_desc(model, tokenizer, photo, max_length):
    """Greedily decode a caption for ``photo``, one most-likely word at a time.

    Decoding starts from the 'start' token and stops after ``max_length``
    words, at 'end', or at an unknown index. The returned text includes the
    'start' (and, if reached, 'end') tokens.
    """
    # Training captions are wrapped in "<start>"/"<end>". The Tokenizer's
    # default filters strip "<" and ">", leaving the tokens 'start'/'end'.
    in_text = 'start'
    for _ in range(max_length):
        sequence = tokenizer.texts_to_sequences([in_text])[0]
        sequence = pad_sequences([sequence], maxlen=max_length)
        pred = model.predict([photo, sequence], verbose=0)
        word = word_for_id(np.argmax(pred), tokenizer)
        if word is None:
            break
        in_text += ' ' + word
        if word == 'end':
            break
    return in_text


# path = 'Flicker8k_Dataset/111537222_07e56d5a30.jpg'
with open("tokenizer.p", "rb") as fh:
    tokenizer = load(fh)
model = load_model('models/model_9.h5')
# Read the caption length from the model's second (text) input instead of
# hard-coding it, so it always matches the value used at training time.
max_length = int(model.inputs[1].shape[1])
xception_model = Xception(include_top=False, pooling="avg")

photo = extract_features(img_path, xception_model)
img = Image.open(img_path)
description = generate_desc(model, tokenizer, photo, max_length)

# Show the generated caption together with the image.
print(description)
plt.imshow(img)
plt.axis('off')
plt.show()



start man in red shirt is walking on the beach end

开始 两只狗在草地上玩球


完整代码:ml_toolset/案例59 使用LSTM生成图像描述 at main · bashendixie/ml_toolset · GitHub(Contribute to bashendixie/ml_toolset development by creating an account on GitHub.)
https://github.com/bashendixie/ml_toolset/tree/main/%E6%A1%88%E4%BE%8B59%20%E4%BD%BF%E7%94%A8LSTM%E7%94%9F%E6%88%90%E5%9B%BE%E5%83%8F%E6%8F%8F%E8%BF%B0

