笔者小白,初学VQA,如有不对之处还请指教。

mmf是什么?官方提供的README中是这么说的:

MMF is a modular framework for vision and language multimodal research from Facebook AI Research. MMF contains reference implementations of state-of-the-art vision and language models and has powered multiple research projects at Facebook AI Research. See full list of project inside or built on MMF here.

mmf中包含了许多vqa中基本模型的实现,通过学习这些模型的代码实现,可以快速地了解vqa的发展与技术基础。

今天首先从CNNLSTM这个最简单的模型出发,学习mmf构建模型的基本框架。

上图是该模型的基本思路。mmf中的代码仅对融合特征进行了Classfiy并没有进行RNN 的 decoder 。

# Copyright (c) Facebook, Inc. and its affiliates.
​
from copy import deepcopy
​
import torch
from mmf.common.registry import registry
from mmf.models.base_model import BaseModel
from mmf.modules.layers import ClassifierLayer, ConvNet, Flatten
from torch import nn
​
​
_TEMPLATES = {"question_vocab_size": "{}_text_vocab_size","number_of_answers": "{}_num_final_outputs",
}
​
_CONSTANTS = {"hidden_state_warning": "hidden state (final) should have 1st dim as 2"}
​
​
@registry.register_model("cnn_lstm")
class CNNLSTM(BaseModel):"""CNNLSTM is a simple model for vision and language tasks. CNNLSTM is supposedto acts as a baseline to test out your stuff without any complex functionality.Passes image through a CNN, and text through an LSTM and fuses them usingconcatenation. Then, it finally passes the fused representation from a MLP togenerate scores for each of the possible answers.
​Args:config (DictConfig): Configuration node containing all of the necessaryconfig required to initialize CNNLSTM.
​Inputs: sample_list (SampleList)- **sample_list** should contain image attribute for image, text forquestion split into word indices, targets for answer scores"""
​def __init__(self, config):super().__init__(config)self._global_config = registry.get("config")self._datasets = self._global_config.datasets.split(",")
​@classmethoddef config_path(cls):return "configs/models/cnn_lstm/defaults.yaml"
​def build(self):assert len(self._datasets) > 0num_question_choices = registry.get(_TEMPLATES["question_vocab_size"].format(self._datasets[0]))num_answer_choices = registry.get(_TEMPLATES["number_of_answers"].format(self._datasets[0]))
​self.text_embedding = nn.Embedding(num_question_choices, self.config.text_embedding.embedding_dim)self.lstm = nn.LSTM(**self.config.lstm)
​layers_config = self.config.cnn.layersconv_layers = []for i in range(len(layers_config.input_dims)):conv_layers.append(ConvNet(layers_config.input_dims[i],layers_config.output_dims[i],kernel_size=layers_config.kernel_sizes[i],))conv_layers.append(Flatten())self.cnn = nn.Sequential(*conv_layers)
​# As we generate output dim dynamically, we need to copy the config# to update itclassifier_config = deepcopy(self.config.classifier)classifier_config.params.out_dim = num_answer_choicesself.classifier = ClassifierLayer(classifier_config.type, **classifier_config.params)
​def forward(self, sample_list):self.lstm.flatten_parameters()
​question = sample_list.textimage = sample_list.image
​# Get (h_n, c_n), last hidden and cell state_, hidden = self.lstm(self.text_embedding(question))# X x B x H => B x X x H where X = num_layers * num_directionshidden = hidden[0].transpose(0, 1)
​# X should be 2 so we can merge in that dimensionassert hidden.size(1) == 2, _CONSTANTS["hidden_state_warning"]
​hidden = torch.cat([hidden[:, 0, :], hidden[:, 1, :]], dim=-1)image = self.cnn(image)
​# Fuse into single dimensionfused = torch.cat([hidden, image], dim=-1)scores = self.classifier(fused)
​return {"scores": scores}
​

以上类继承了BaseModel类。mmf中所有的model类都要继承自BaseModel。

在生成类时,对模型进行了注册。

@registry.register_model("cnn_lstm")

相关代码可以查看Registry类中的相关类函数。

 @classmethoddef register_model(cls, name):r"""Register a model to registry with key 'name'
​Args:name: Key with which the model will be registered.
​Usage::
​from mmf.common.registry import registryfrom mmf.models.base_model import BaseModel
​@registry.register_task("pythia")class Pythia(BaseModel):..."""
​def wrap(func):from mmf.models.base_model import BaseModel
​assert issubclass(func, BaseModel), "All models must inherit BaseModel class"cls.mapping["model_name_mapping"][name] = funcreturn func
​return wrap

模型的默认配置在configs/models/cnn_lstm/defaults.yaml中

model_config:cnn_lstm:losses:- type: logit_bcetext_embedding:embedding_dim: 20lstm:input_size: 20hidden_size: 50bidirectional: truebatch_first: truecnn:layers:input_dims: [3, 64, 128, 128, 64, 64]output_dims: [64, 128, 128, 64, 64, 10]kernel_sizes: [7, 5, 5, 5, 5, 1]classifier:type: mlpparams:in_dim: 450out_dim: 2

之后这些配置会详细的介绍。

首先生成一个20维的embedding

 self.text_embedding = nn.Embedding(num_question_choices, self.config.text_embedding.embedding_dim)

生成LSTM模块,隐藏层维度为50.

self.lstm = nn.LSTM(**self.config.lstm)

生成CNN模块,各层通道数和卷积核的大小由config中定义。共6层卷积。

layers_config = self.config.cnn.layersconv_layers = []for i in range(len(layers_config.input_dims)):conv_layers.append(ConvNet(layers_config.input_dims[i],layers_config.output_dims[i],kernel_size=layers_config.kernel_sizes[i],))conv_layers.append(Flatten())

在modules/layers.py中有对ConvNet的定义。

卷积后加池化加batchnorm构成一个ConvNet。最后在所有卷积层之后Flatten。

class ConvNet(nn.Module):def __init__(self,in_channels,out_channels,kernel_size,padding_size="same",pool_stride=2,batch_norm=True,):super().__init__()
​if padding_size == "same":padding_size = kernel_size // 2self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding_size)self.max_pool2d = nn.MaxPool2d(pool_stride, stride=pool_stride)self.batch_norm = batch_norm
​if self.batch_norm:self.batch_norm_2d = nn.BatchNorm2d(out_channels)
​def forward(self, x):x = self.max_pool2d(nn.functional.leaky_relu(self.conv(x)))
​if self.batch_norm:x = self.batch_norm_2d(x)
​return x​
class Flatten(nn.Module):def forward(self, input):if input.dim() > 1:input = input.view(input.size(0), -1)
​return input

在前向传播中

    def forward(self, sample_list):self.lstm.flatten_parameters()
​question = sample_list.textimage = sample_list.image
​# Get (h_n, c_n), last hidden and cell state_, hidden = self.lstm(self.text_embedding(question))# X x B x H => B x X x H where X = num_layers * num_directionshidden = hidden[0].transpose(0, 1)
​# X should be 2 so we can merge in that dimensionassert hidden.size(1) == 2, _CONSTANTS["hidden_state_warning"]
​hidden = torch.cat([hidden[:, 0, :], hidden[:, 1, :]], dim=-1)image = self.cnn(image)
​# Fuse into single dimensionfused = torch.cat([hidden, image], dim=-1)scores = self.classifier(fused)
​return {"scores": scores}

hidden与image的维度都为2,第一维为batch_size,第二维分别为lstm和cnn出来的特征向量

将lstm最后一个隐藏状态和cnn的输出特征进行拼接,输入全连接网络,输出两个评分。这两个评分,是对于预设答案的评分。

该网络是vqa中最简单的网络。然而,任何复杂的网络都需要从简单的网络中逐渐演变诞生而来。Rome wasn‘t built in a day !

VQA学习笔记(一)CNN-LSTM相关推荐

  1. 唐宇迪之tensorflow学习笔记项目实战(LSTM情感分析)

    我们首先来看看RNN的网络结构,如下图所示 xt 表示第t,t=1,2,3-步(step)的输入 st 为隐藏层的第t步的状态,它是网络的记忆单元. st=f(u×xt+w×st−1) ,其中f一般是 ...

  2. 七月算法深度学习笔记4 -- CNN与常用框架

    这套笔记是跟着七月算法五月深度学习班的学习而记录的,主要记一下我再学习机器学习的时候一些概念比较模糊的地方,具体课程参考七月算法官网: http://www.julyedu.com/ 神经网络的结构 ...

  3. 【theano-windows】学习笔记二十——LSTM理论及实现

    前言 上一篇学习了RNN,也知道了在沿着时间线对上下文权重求梯度的时候,可能会导致梯度消失或者梯度爆炸,然后我们就得学习一波比较常见的优化方法之LSTM 国际惯例,参考网址: LSTM Network ...

  4. tensorflow 学习笔记使用CNN做英文文本分类任务

    使用CNN做英文文本分类任务 本文同时也是学习唐宇迪老师深度学习课程的一些理解与记录. 文中代码是实现在TensorFlow下使用卷积神经网络(CNN)做英文文本的分类任务(本次是垃圾邮件的二分类任务 ...

  5. 学习笔记:cnn 猫狗识别

    版权声明:本文为博主原创文章,未经博主允许不得转载. 1.数据获取 本次学习的数据为,kaggle 中的 Dogs vs Cats 数据集 如果不清楚,kaggle,可以看一下,我前面写的这篇文章:h ...

  6. 李宏毅学习笔记11.CNN(上)

    文章目录 前言 为什么要用CNN来处理图像? 为什么CNN可以去掉一些神经元后仍然可以工作? 第一个原因 第二个原因 第三个原因 CNN长什么样? Convolution 学生提问 彩色图片的处理 C ...

  7. 深度学习学习笔记——RNN(LSTM、GRU、双向RNN)

    目录 前置知识 循环神经网络(RNN) 文本向量化 RNN 建模 RNN 模型改进 LSTM(Long Short Term Memory) LSTM变形与数学表达式 门控循环单元GRU(Grated ...

  8. 李宏毅深度学习笔记(CNN)

    卷积神经网络(CNN) 为什么使用CNN? CNN可以很好用于图像的处理,这主要基于两个假设: 图像中同样的特征片段可能出现在不同的位置.图像上不同小片段,以及不同图像上的小片段的特征是类似的,也就是 ...

  9. 学习笔记:cnn 身份证数字识别

    版权声明:本文为博主原创文章,未经博主允许不得转载. 这篇文章跟cnn猫狗识别是差不了多少的,只是数据处理,与训练时做了稍微的调整,数据集和代码可以通过,https://github.com/zr94 ...

最新文章

  1. golang 基础知识4
  2. 进程间通信——自定义消息方式实现(SetWindowsHookEx)
  3. C++/C--vector初始化与赋值【转载】
  4. 零售连锁专卖信息化解决方案简介之一
  5. EVC获取当前工作路径
  6. 用简单的实例来实践TDD的核心思想
  7. 【SpringBoot_ANNOTATIONS】自动装配 02 @Resource @Inject
  8. 两平面平行方向向量关系_空间向量,如果一条直线与一平面平行,那么直线的方向向量与平面的法向量有什么关系??垂直呢?...
  9. python动态规划dp
  10. Ubuntu速配指南之软件参考
  11. 阳振坤:OceanBase 数据库七亿 tpmC 的关键技术
  12. html+css 设置select标签的宽高
  13. seacms海洋cms漏洞
  14. Centos删除乱码文件或文件夹
  15. 算法与数据结构实验题 5.18 小孩的游戏
  16. caffe与Python接口的配置(VC2013 Windows CUDA7.5 Python2.7.12)
  17. SD-Branch多分支组网解决方案
  18. 计算机黑屏跳横杠,电脑开机时黑屏左上角显示一个横杠是怎么回事
  19. Git submodule 采坑
  20. 一键清除本地缓存的所有无用的docker镜像命令

热门文章

  1. java调用迅雷_java jna调用迅雷接口下载
  2. 2022年数维杯国际大学生数学建模挑战赛D题三重拉尼娜事件下极端气候灾害损失评估与应对策略研究解题过程
  3. 关于iebook的应用和传播(破解)
  4. Linux资源监控命令/工具(综合)
  5. 机器 · 搜索 · 未来
  6. 基于ActionScript3.0的DoodleJump 游戏实现
  7. 关于2000版ISO 9001标准的新思考之四(转载)
  8. 关于ACCESS数据库的不可更新查询
  9. 先进先出页面置换算法详解
  10. 流氓金泰丰pctools.dll,不过Avast认为其为广告软体,杀