TensorFlow中的RNNCell基本单元使用
0.charRNN基础介绍
- charRNN 是N vs N的循环神经网络,要求输入序列长度等于输出序列长度。
原理:用已经输入的字母去预测下一个字母的概率。一个句子是hello!,例如输入序列是hello,则输出序列是ello!
预测时:首先选择一个x1当作起始的字符,然后用训练好的模型得到下一个字符出现的概率。根据这个概率选择一个字符输出,然后将此字符当作下一步的x2输入到模型中。依次递推,得到任意长度的文字。注意:输入的单个字母是以one-hot形式进行编码的! - 对中文进行建模时,每一步输入模型的是一个汉字,由于汉字的种类太多,导致模型太大,一般采用下面的方法进行优化:
- 1.取最常用的N个汉字,将剩下的汉字变成单独的一类,用一个<unk>字符来进行标注
- 2.在输入时,可以加入一个embedding层,将汉字的one-hot编码转为稠密的词嵌入表示。对单个字母不使用embedding是由于单个字母不具备任何的含义,只需要使用one-hot编码即可。单个汉字是具有一定的实际意义的,所以使用embedding层
1.实现RNN的基本单元RNNCell抽象类--------有两种直接使用的子类:BasicRNNCell(基本的RNN)和LSTMCell(基本的LSTM)
- RNNCell有三个属性:
- 1.类方法call:所有的子类都会实现一个__call__函数,可以实现RNN的单步计算,调用形式为(output,next_state) = __call__(input, state)
- 2.类属性state_size:隐藏层的大小,输入数据是以batch_size的形式进行输入的即input=(batch_size, input_size),
调用__call__函数时隐藏层的形状是(batch_size, state_size),输出层的形状是(batch_size, output_size) - 3.类属性output_size:输出向量的大小
2.定义一个基本的RNN单元
import tensorflow as tf
import numpy as nprnn_cell = tf.nn.rnn_cell.BasicRNNCell(num_units=128)
print("rnn_cell.state_size:", rnn_cell.state_size)
3.定义一个基本的LSTM的基本单元
lstm_cell = tf.nn.rnn_cell.LSTMCell(num_units=128)
print("lstm_cell.state_size:", lstm_cell.state_size)lstm_cell = tf.nn.rnn_cell.LSTMCell(num_units=128) # batch_size=32, input_size=100
inputs = tf.placeholder(np.float32, shape=(32, 100))
h0 = lstm_cell.zero_state(32, np.float32) # 通过zero_state得到一个全0的初始状态
output, h1 = lstm_cell.__call__(inputs, h0)
print(h1.c)
print(h1.h)
4.对RNN进行堆叠:MultiRNNCell
# 每次调用这个函数返回一个BasicRNNCell
def get_a_cell():return tf.nn.rnn_cell.BasicRNNCell(num_units=128)# 使用MultiRNNCell创建3层RNN
cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell() for _ in range(3)])
# 得到的RNN也是RNNCell的子类,state_size=(128, 128, 128):三个隐层状态,每个隐层状态的大小是128
print(cell.state_size)# 32是batch_size, 100是input_size
inputs = tf.placeholder(np.float32, shape=(32, 100))
h0 = cell.zero_state(32, np.float32)
output, h1 = cell.__call__(inputs, h0)
print(h1)
5.使用tf.nn.dunamic_rnn按时间展开:相当于增加了一个时间维度time_steps,通过{h0,x1,x2…,xn}得到{h1,h2,h3,…hn}
inputs: shape=(batch_size, time_steps, input_size) # 输入数据的格式是(batch_size, time_steps, input_size)
initial_state: shape(batch_size,cell.state_size) # 初始状态,一般可以取零矩阵
outputs, state = tf.nn.dynamic_rnn(cell,inputs,initial_state)
# outputs是time_steps中所有的输出,形状是(batch_size, time_steps, cell.output_size)
# state是最后一步的隐状态,形状是(batch_size,cell.state_size)
- 注意:输入数据的形状是(time_steps,batch_size, input_size),可以调用tf.nn.dynamic_rnn()函数中设定参数time_major=True。此时,得到的outputs的形状是(time_steps, batch_size, cell.output_size);state的形状不变化
TensorFlow中的RNNCell基本单元使用相关推荐
- TensorFlow中RNN实现的正确打开方式
上周写的文章<完全图解RNN.RNN变体.Seq2Seq.Attention机制>介绍了一下RNN的几种结构,今天就来聊一聊如何在TensorFlow中实现这些结构,这篇文章的主要内容为: ...
- TensorFlow中RNN实现的正确打开方式(转)
上周写的文章<完全图解RNN.RNN变体.Seq2Seq.Attention机制>介绍了一下RNN的几种结构,今天就来聊一聊如何在TensorFlow中实现这些结构,这篇文章的主要内容为: ...
- 循序渐进的学习TensorFlow中RNN实现的方法
本文转载自:https://zhuanlan.zhihu.com/p/28196873 上周写了一篇文章介绍了一下RNN的几种结构,今天就来聊一聊如何在TensorFlow中实现这些结构,这篇文章的主 ...
- TensorFlow 中文文档 介绍
介绍 本章的目的是让你了解和运行 TensorFlow 在开始之前, 先看一段使用 Python API 撰写的 TensorFlow 示例代码, 对将要学习的内容有初步的印象. 这段很短的 Pyth ...
- struts实现分页_在TensorFlow中实现点Struts
struts实现分页 If you want to get started on 3D Object Detection and more specifically on Point Pillars, ...
- Python 3深度置信网络(DBN)在Tensorflow中的实现MNIST手写数字识别
任何程序错误,以及技术疑问或需要解答的,请扫码添加作者VX:1755337994 使用DBN识别手写体 传统的多层感知机或者神经网络的一个问题: 反向传播可能总是导致局部最小值. 当误差表面(erro ...
- SELU︱在keras、tensorflow中使用SELU激活函数
arXiv 上公开的一篇 NIPS 投稿论文<Self-Normalizing Neural Networks>引起了圈内极大的关注,它提出了缩放指数型线性单元(SELU)而引进了自归一化 ...
- 在Tensorflow中使用深度学习构建图像标题生成器
by Cole Murray 通过科尔·默里(Cole Murray) 在Tensorflow中使用深度学习构建图像标题生成器 (Building an image caption generator ...
- TensorFlow中张量,变量、常量、占位符概念
1.总结TensorFlow中的张量概念 张量:数据结构:多维数组 零阶张量表示标量(scalar),也就是一个数: 一阶张量为向量(vector),也就是一个数组: N阶张量可以理解为一个n维数组: ...
最新文章
- 涨姿势!北京地铁原来是16条旅游专线
- 使用CSS实现无滚动条滚动
- 惠普用的是微软服务器吗,惠普抛弃MediaSmart服务器 微软表示淡定
- 05-02 docker 安装与配置-CentOS
- 【二分图】【最大匹配】【匈牙利算法】CODEVS 2776 寻找代表元
- failed to accept an incoming connection: connection from 192.168.1.114 rejected, allowed hosts: 1
- bzoj 1691: [Usaco2007 Dec]挑剔的美食家(multiset贪心)
- 两个平面的位置关系和判定方程组解_2018年高考数学总复习第九章平面解析几何第2讲两直线的位置关系学案!...
- matlab三角形外接圆
- iOS系统录屏如何增加雷达波纹效果(从一个点向周围扩散)的简单实现
- 一条校招/社招潜规则~
- 物联网(IoT)及其未来应用方向
- 基于机器学习的电影票房分析与预测系统
- 《数据结构与算法之二叉平衡树(AVL)》
- 新手如何学习Java以及学习java的步骤
- 哈工大软件过程与工具
- simulink中MUX
- FMEA-MSR步骤六:优化
- 从零开始学AI(Python基础)
- 你的第一篇SCI是怎么发的呢?