Neural Turing Machines-NTM系列(三)ntm-lasagne源码分析

在NTM系列文章(二)中,我们已经成功运行了一个ntm工程的源代码。在这一章中,将对它的源码实现进行分析。

1.网络结构

1.1 模块结构图


在图中可以看到,输入的数据在经过NTM的处理之后,输出经过NTM操作后的,跟之前大小相同的数据块。来看下CopyTask的完整输出图:

图中右侧的Input是输入数据,Output是目标数据,Prediction是通过NTM网络预测出来的输出数据,可以看出预测数据与目标数据只在区域上大致相同,具体到每个白色的块差距较大。(这里只迭代训练了100次)
训练次数可以在这里调整(task-copy.py):

其中的参数max_iter就是训练时的迭代次数,size是输入的数据宽度(即上图中Input/Output小矩形的“高”-1,多出来的维度用作结束标记)
输入数据如下,从上到下对应上图中的从左到右,最后一行是结束标志,只有最后一个元素为1:
array( [[
[ 0., 1., 1., 0., 1., 1., 1., 1., 0.],
[ 0., 1., 1., 0., 0., 1., 0., 0., 0.],
[ 0., 0., 1., 0., 1., 1., 1., 0., 0.],
[ 1., 1., 1., 1., 1., 1., 1., 0., 0.],
[ 0., 0., 0., 0., 0, 0., 0., 0., 1.],
[ 0., 0., 0., 0., 0, 0., 0., 0., 0.],
[ 0., 0., 0., 0., 0, 0., 0., 0., 0.],
[ 0., 0., 0., 0., 0, 0., 0., 0., 0.],
[ 0., 0., 0., 0., 0, 0., 0., 0., 0.],
[ 0., 0., 0., 0., 0, 0., 0., 0., 0.],
]]
目标数据和预测数据的格式相似,就不详细介绍了。需要注意的是,由于输出层使用的是sigmoid函数,所以预测数据的范围在0和1之间。

1.2 Head对象内部的计算流


上图对应的实现在ntm-lasagne/ntm/heads.py中的Head基类中的get_output_for函数

    def get_output_for(self, h_t, w_tm1, M_t, **kwargs):if self.sign is not None:sign_t = self.sign.get_output_for(h_t, **kwargs)else:sign_t = 1.k_t = self.key.get_output_for(h_t, **kwargs)beta_t = self.beta.get_output_for(h_t, **kwargs)g_t = self.gate.get_output_for(h_t, **kwargs)s_t = self.shift.get_output_for(h_t, **kwargs)gamma_t = self.gamma.get_output_for(h_t, **kwargs)# Content Adressing (3.3.1)beta_t = T.addbroadcast(beta_t, 1)betaK = beta_t * similarities.cosine_similarity(sign_t * k_t, M_t)w_c = lasagne.nonlinearities.softmax(betaK)# Interpolation (3.3.2)g_t = T.addbroadcast(g_t, 1)w_g = g_t * w_c + (1. - g_t) * w_tm1# Convolutional Shift (3.3.2)w_g_padded = w_g.dimshuffle(0, 'x', 'x', 1)conv_filter = s_t.dimshuffle(0, 'x', 'x', 1)pad = (self.num_shifts // 2, (self.num_shifts - 1) // 2)w_g_padded = padding.pad(w_g_padded, [pad], batch_ndim=3)convolution = T.nnet.conv2d(w_g_padded, conv_filter,input_shape=(self.input_shape[0], 1, 1, self.memory_shape[0] + pad[0] + pad[1]),filter_shape=(self.input_shape[0], 1, 1, self.num_shifts),subsample=(1, 1),border_mode='valid')w_tilde = convolution[:, 0, 0, :]# Sharpening (3.3.2)gamma_t = T.addbroadcast(gamma_t, 1)w = T.pow(w_tilde + 1e-6, gamma_t)w /= T.sum(w)return w

其中的传入参数解释如下:
h_t:controller的隐层输出;
w_tm1:前一时刻的输出值,即 wt−1 w_{t-1};
M_t:Memory矩阵

1.3 NTMLayer结构图


NTM层的数据处理实现在ntm-lasagne/ntm/layers.py中的NTMLayer.get_output_for函数中:

注意到其中还有一个内部函数step,这个函数中实现了每一次数据输入后NTM网络要进行的操作逻辑。
其中的参数解释如下:
x_t:当前的网络输入,即1.1中输入矩阵中的一行;
M_tm1:前一时刻的Memory矩阵,即 Mt−1 M_{t-1}
h_tm1:前一时刻的controller隐层输出
state_tm1:前一时刻的controller隐层状态,当controller为前馈网络时,等于前一时刻的输出
params:存放write heads和read heads上一时刻的输出即 wt−1 w_{t-1},顺序如下:
[write_head1_w,write_head2_w,…,write_headn1_w,read_head1_w,read_head2_w,…,read_headn2_w]
1.每次网络接收到输入后,会进入step迭代函数,先走write(erase+add)流程,更新Memory,然后再执行read操作,生成 rt r_t向量。这部分代码如下:

最后的r_t就是读取出来的 rt r_t向量,注意这里有个比较特殊的参数W_hid_to_sign_add,这是一个开关参数,类似于LSTM中的“门”。这个参数默认为None。
2.read vector生成后,将作为输入参数被传入Controller:

3.step函数结束,返回值为一list,代码如下:

list中的元素依次为:[M_t, h_t, state_t + write_weights_t + read_weights_t]
step函数通过 theano.scan来进行迭代调用,每次的输入即为当前的input及上一时刻的list值
4.最后NTMLayer.get_out_for函数的返回值为:
hid_out = hids[1],正好对应了Controller隐层最近一次的输出值。

1.4 NTM网络结构图

2.公式及主要Class说明

αt=σalpha(htWalpha+balpha) \alpha_{t} = \sigma_{alpha}(h_{t} W_{alpha} + b_{alpha})
kt=σkey(htWkey+bkey) k_{t} = \sigma_{key}(h_{t} W_{key} + b_{key})
βt=σbeta(htWbeta+bbeta) \beta_{t} = \sigma_{beta}(h_{t} W_{beta} + b_{beta})
gt=σgate(htWgate+bgate) g_{t} = \sigma_{gate}(h_{t} W_{gate} + b_{gate})
st=σshift(htWshift+bshift) s_{t} = \sigma_{shift}(h_{t} W_{shift} + b_{shift})
γt=σgamma(htWgamma+bgamma) \gamma_{t} = \sigma_{gamma}(h_{t} W_{gamma} + b_{gamma})

wct=softmax(βt∗K(αt∗kt,Mt)) w_{t}^{c} = softmax(\beta_{t} * K(\alpha_{t} * k_{t}, M_{t}))
wgt=gt∗wct+(1−gt)∗wt−1 w_{t}^{g} = g_{t} * w_{t}^{c} + (1 - g_{t}) * w_{t-1}
w̃ t=st∗wgt \tilde{w}_{t} = s_{t} \ast w_{t}^{g}
wt∝w̃ γtt w_{t} \propto \tilde{w}_{t}^{\gamma_{t}}

NTMLayer:父类为 lasagne.layers.Layer
功能:Neural Turing Machine的框架层
字段:memory:即Memory
controller:控制器,父类为Layer,默认100个节点
controller.hid_init:隐层的状态集合,大小为:(1,100)
heads:读写取Head集合
write_heads:写入Head集合
read_heads:读取Head集合
函数:get_output_for:在给定的输入input下,返回对应的输出值

Head:父类为lasagne.layers.Layer
功能:读写头的基类
字段:sign:DenseLayer(全连接网络),输出为 αt \alpha_{t},激活函数为ClippedLinear(-1,1),节点数:20;
key:DenseLayer,输出为 kt k_{t},激活函数为ClippedLinear(0,1),节点数:20,输入层为controller;
beta:DenseLayer,输出为 βt \beta_{t},激活函数为rectify,节点数:1,输入层为controller;
gate:DenseLayer,输出为 gt g_{t},激活函数为hard_sigmoid,节点数:1,输入层为controller;
shift:DenseLayer,输出为 st s_{t},激活函数为softmax,节点数:3(等于num_shifts,默认为3),输入层为controller,最终将输出3个概率值,分别对应 st(−1),st(0),st(1) s_{t}(-1),s_{t}(0),s_{t}(1),s_{t}长度为N,除softmax输出的3个位置非0之外,其余位置为0;
gamma:DenseLayer,输出为 γt \gamma_{t},激活函数为1+rectify,节点数:1,输入层为controller;
num_shifts:卷积shifts的操作宽度(奇数),当宽度为n时,移位向量为:[-n/2,…,-1,0,1,…,n/2],比如,当n=3时,为:[-1,0,1]
weights_init:输出为OneHot 1×128 1\times 128的权值向量,其初始值为除第一个元素为1之外,其余元素为0.
gate:DenseLayer,输出为 eraset erase_{t},激活函数为hard_sigmoid,节点数:20,输入层为controller;
add:DenseLayer,输出为 addt add_{t},激活函数为ClippedLinear(0,1),节点数:20,输入层为controller;
rectify: f(x)=max(0,x) f(x)=max(0, x)
sign_add:DenseLayer,输出为 signAddt signAdd_{t},激活函数为ClippedLinear(-1,1),节点数:20,输入层为controller;
rectify: f(x)=max(0,x) f(x)=max(0, x)
softmax: f(x)=exj∑Kk=1exk f(x)=\frac{e^{\mathbf{x}_j}}{\sum_{k=1}^K e^{\mathbf{x}_k}}

hard_sigmoid:

f(x)=⎧⎩⎨⎪⎪x=0,x<0x=0.2x+0.5,x∈[0,1]x=1,x>1

f(x)=\left\{ \begin{aligned} x=0,x1\\ \end{aligned} \right.
ClippedLinear(a,b):

f(x)={x=a,x<ax=b,x>b

f(x)=\left\{ \begin{aligned} x = a , xb\\ \end{aligned} \right.

3.copy-task实验

(待续)
参考文章:
http://blog.csdn.net/niuwei22007/article/details/49208643
https://medium.com/snips-ai/ntm-lasagne-a-library-for-neural-turing-machines-in-lasagne-2cdce6837315
http://lasagne.readthedocs.org/en/latest/user/tutorial.html
http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.imshow

Neural Turing Machines-NTM系列(三)ntm-lasagne源码分析相关推荐

  1. 【深入浅出MyBatis系列十一】缓存源码分析

    为什么80%的码农都做不了架构师?>>>    #0 系列目录# 深入浅出MyBatis系列 [深入浅出MyBatis系列一]MyBatis入门 [深入浅出MyBatis系列二]配置 ...

  2. Dubbo系列(二)源码分析之SPI机制

    Dubbo系列(二)源码分析之SPI机制 在阅读Dubbo源码时,常常看到 ExtensionLoader.getExtensionLoader(*.class).getAdaptiveExtensi ...

  3. jieba tfidf_【NLP】【三】jieba源码分析之关键字提取(TF-IDF/TextRank)

    [一]综述 利用jieba进行关键字提取时,有两种接口.一个基于TF-IDF算法,一个基于TextRank算法.TF-IDF算法,完全基于词频统计来计算词的权重,然后排序,在返回TopK个词作为关键字 ...

  4. 我的架构梦:(三)MyBatis源码分析

    mybatis的源码分析 一.传统方式源码分析 二.Mapper代理方式源码分析 三.MyBatis源码中涉及到的设计模式 一.传统方式源码分析 分析之前我们来回顾下传统方式的写法: /*** 传统方 ...

  5. Java源码详解三:Hashtable源码分析--openjdk java 11源码

    文章目录 注释 哈希算法与映射 线程安全的实现方法 put 操作 get操作 本系列是Java详解,专栏地址:Java源码分析 Hashtable官方文档:Hashtable (Java Platfo ...

  6. 集合框架知识系列05 HashMap的源码分析和使用示例

    一.HashMap简介 HashMap是基于"拉链法"实现的散列表.一般用于单线程程序中,JDK 1.8对HashMap进行了比较大的优化,底层实现由之前的"数组+链表& ...

  7. 查询已有链表的hashmap_源码分析系列1:HashMap源码分析(基于JDK1.8)

    1.HashMap的底层实现图示 如上图所示: HashMap底层是由  数组+(链表)=(红黑树) 组成,每个存储在HashMap中的键值对都存放在一个Node节点之中,其中包含了Key-Value ...

  8. 源码分析系列1:HashMap源码分析(基于JDK1.8)

    1.HashMap的底层实现图示 如上图所示: HashMap底层是由  数组+(链表)+(红黑树) 组成,每个存储在HashMap中的键值对都存放在一个Node节点之中,其中包含了Key-Value ...

  9. Linux线程同步(三)---互斥锁源码分析

    先给自己打个广告,本人的微信公众号:嵌入式Linux江湖,主要关注嵌入式软件开发,股票基金定投,足球等等,希望大家多多关注,有问题可以直接留言给我,一定尽心尽力回答大家的问题. 一 源码分析 1.li ...

最新文章

  1. 室内设计木地板材质合集包 Arroway – Design Craft Vol.4
  2. 使用Ext Form自动绑定Html中的Form元素
  3. lightoj 1037 - Agent 47(状压dp)
  4. db2 本地db 到实例_如何登录到FreeCodeCamp的本地实例
  5. 关于类模版迭代器提出时的错误
  6. 电机学(1) - 绪论
  7. python3安装MySQLdb
  8. WorkTool(一)企业微信群管理机器人实现
  9. iOS视频播放器开发
  10. windows 10远程桌面连接报错解决办法
  11. cosine similarity 余弦相似度
  12. 顺序表练习(三):对称矩阵的压缩储存
  13. 微信小程序weui在线入门教程-WeUi操作反馈-actionsheet弹出式菜单
  14. 新媒体运营教程:实现用户增长5个步骤,5个基础方法
  15. 掌控板教程 | 搞定 Siri 语音控制,只要半小时!
  16. 高考成绩等位分查询2021,干货│如何查询等位分?精确填报志愿必备......
  17. 大学生的福音,学习 Java 最强书单推荐,附学习方法
  18. Matlab工业检测之尘埃统计
  19. 解决wine 1.35 无法发声问题
  20. [荐]没羽箭张清到底连打梁山多少好汉?

热门文章

  1. JavaSE入门0基础笔记 第二章Java基础语法
  2. 马尔可夫链的定义、举例和应用
  3. App开发者必备的运营、原型、UI设计工具整理
  4. 如何进行APP界面设计
  5. SystemInfo 类
  6. 学习c语言必备的书籍推荐
  7. 【洛谷】P3386 【模板】二分图最大匹配
  8. 0基础学前端开发,CSS盒子模型居中方法
  9. 【LeetCode-SQL每日一练】—— 181. 超过经理收入的员工
  10. Visual Studio 2017下载地址和安装教程(图解版)