0.charRNN基础介绍

  • charRNN 是N vs N的循环神经网络,要求输入序列长度等于输出序列长度。
    原理:用已经输入的字母去预测下一个字母的概率。一个句子是hello!,例如输入序列是hello,则输出序列是ello!
    预测时:首先选择一个x1当作起始的字符,然后用训练好的模型得到下一个字符出现的概率。根据这个概率选择一个字符输出,然后将此字符当作下一步的x2输入到模型中。依次递推,得到任意长度的文字。注意:输入的单个字母是以one-hot形式进行编码的!
  • 对中文进行建模时,每一步输入模型的是一个汉字,由于汉字的种类太多,导致模型太大,一般采用下面的方法进行优化:
    • 1.取最常用的N个汉字,将剩下的汉字变成单独的一类,用一个<unk>字符来进行标注
    • 2.在输入时,可以加入一个embedding层,将汉字的one-hot编码转为稠密的词嵌入表示。对单个字母不使用embedding是由于单个字母不具备任何的含义,只需要使用one-hot编码即可。单个汉字是具有一定的实际意义的,所以使用embedding层

1.实现RNN的基本单元RNNCell抽象类--------有两种直接使用的子类:BasicRNNCell(基本的RNN)和LSTMCell(基本的LSTM)

  • RNNCell有三个属性:

    • 1.类方法call:所有的子类都会实现一个__call__函数,可以实现RNN的单步计算,调用形式为(output,next_state) = __call__(input, state)
    • 2.类属性state_size:隐藏层的大小,输入数据是以batch_size的形式进行输入的即input=(batch_size, input_size),
      调用__call__函数时隐藏层的形状是(batch_size, state_size),输出层的形状是(batch_size, output_size)
    • 3.类属性output_size:输出向量的大小

2.定义一个基本的RNN单元

import tensorflow as tf
import numpy as nprnn_cell = tf.nn.rnn_cell.BasicRNNCell(num_units=128)
print("rnn_cell.state_size:", rnn_cell.state_size)

3.定义一个基本的LSTM的基本单元

lstm_cell = tf.nn.rnn_cell.LSTMCell(num_units=128)
print("lstm_cell.state_size:", lstm_cell.state_size)lstm_cell = tf.nn.rnn_cell.LSTMCell(num_units=128) # batch_size=32, input_size=100
inputs = tf.placeholder(np.float32, shape=(32, 100))
h0 = lstm_cell.zero_state(32, np.float32)  # 通过zero_state得到一个全0的初始状态
output, h1 = lstm_cell.__call__(inputs, h0)
print(h1.c)
print(h1.h)

4.对RNN进行堆叠:MultiRNNCell

# 每次调用这个函数返回一个BasicRNNCell
def get_a_cell():return tf.nn.rnn_cell.BasicRNNCell(num_units=128)# 使用MultiRNNCell创建3层RNN
cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell() for _ in range(3)])
# 得到的RNN也是RNNCell的子类,state_size=(128, 128, 128):三个隐层状态,每个隐层状态的大小是128
print(cell.state_size)# 32是batch_size, 100是input_size
inputs = tf.placeholder(np.float32, shape=(32, 100))
h0 = cell.zero_state(32, np.float32)
output, h1 = cell.__call__(inputs, h0)
print(h1)

5.使用tf.nn.dunamic_rnn按时间展开:相当于增加了一个时间维度time_steps,通过{h0,x1,x2…,xn}得到{h1,h2,h3,…hn}

inputs: shape=(batch_size, time_steps, input_size)  # 输入数据的格式是(batch_size, time_steps, input_size)
initial_state:  shape(batch_size,cell.state_size)  # 初始状态,一般可以取零矩阵
outputs, state = tf.nn.dynamic_rnn(cell,inputs,initial_state)
# outputs是time_steps中所有的输出,形状是(batch_size, time_steps, cell.output_size)
# state是最后一步的隐状态,形状是(batch_size,cell.state_size)
  • 注意:输入数据的形状是(time_steps,batch_size, input_size),可以调用tf.nn.dynamic_rnn()函数中设定参数time_major=True。此时,得到的outputs的形状是(time_steps, batch_size, cell.output_size);state的形状不变化

TensorFlow中的RNNCell基本单元使用相关推荐

  1. TensorFlow中RNN实现的正确打开方式

    上周写的文章<完全图解RNN.RNN变体.Seq2Seq.Attention机制>介绍了一下RNN的几种结构,今天就来聊一聊如何在TensorFlow中实现这些结构,这篇文章的主要内容为: ...

  2. TensorFlow中RNN实现的正确打开方式(转)

    上周写的文章<完全图解RNN.RNN变体.Seq2Seq.Attention机制>介绍了一下RNN的几种结构,今天就来聊一聊如何在TensorFlow中实现这些结构,这篇文章的主要内容为: ...

  3. 循序渐进的学习TensorFlow中RNN实现的方法

    本文转载自:https://zhuanlan.zhihu.com/p/28196873 上周写了一篇文章介绍了一下RNN的几种结构,今天就来聊一聊如何在TensorFlow中实现这些结构,这篇文章的主 ...

  4. TensorFlow 中文文档 介绍

    介绍 本章的目的是让你了解和运行 TensorFlow 在开始之前, 先看一段使用 Python API 撰写的 TensorFlow 示例代码, 对将要学习的内容有初步的印象. 这段很短的 Pyth ...

  5. struts实现分页_在TensorFlow中实现点Struts

    struts实现分页 If you want to get started on 3D Object Detection and more specifically on Point Pillars, ...

  6. Python 3深度置信网络(DBN)在Tensorflow中的实现MNIST手写数字识别

    任何程序错误,以及技术疑问或需要解答的,请扫码添加作者VX:1755337994 使用DBN识别手写体 传统的多层感知机或者神经网络的一个问题: 反向传播可能总是导致局部最小值. 当误差表面(erro ...

  7. SELU︱在keras、tensorflow中使用SELU激活函数

    arXiv 上公开的一篇 NIPS 投稿论文<Self-Normalizing Neural Networks>引起了圈内极大的关注,它提出了缩放指数型线性单元(SELU)而引进了自归一化 ...

  8. 在Tensorflow中使用深度学习构建图像标题生成器

    by Cole Murray 通过科尔·默里(Cole Murray) 在Tensorflow中使用深度学习构建图像标题生成器 (Building an image caption generator ...

  9. TensorFlow中张量,变量、常量、占位符概念

    1.总结TensorFlow中的张量概念 张量:数据结构:多维数组 零阶张量表示标量(scalar),也就是一个数: 一阶张量为向量(vector),也就是一个数组: N阶张量可以理解为一个n维数组: ...

最新文章

  1. 涨姿势!北京地铁原来是16条旅游专线
  2. 使用CSS实现无滚动条滚动
  3. 惠普用的是微软服务器吗,惠普抛弃MediaSmart服务器 微软表示淡定
  4. 05-02 docker 安装与配置-CentOS
  5. 【二分图】【最大匹配】【匈牙利算法】CODEVS 2776 寻找代表元
  6. failed to accept an incoming connection: connection from 192.168.1.114 rejected, allowed hosts: 1
  7. bzoj 1691: [Usaco2007 Dec]挑剔的美食家(multiset贪心)
  8. 两个平面的位置关系和判定方程组解_2018年高考数学总复习第九章平面解析几何第2讲两直线的位置关系学案!...
  9. matlab三角形外接圆
  10. iOS系统录屏如何增加雷达波纹效果(从一个点向周围扩散)的简单实现
  11. 一条校招/社招潜规则~
  12. 物联网(IoT)及其未来应用方向
  13. 基于机器学习的电影票房分析与预测系统
  14. 《数据结构与算法之二叉平衡树(AVL)》
  15. 新手如何学习Java以及学习java的步骤
  16. 哈工大软件过程与工具
  17. simulink中MUX
  18. FMEA-MSR步骤六:优化
  19. 从零开始学AI(Python基础)
  20. 你的第一篇SCI是怎么发的呢?

热门文章

  1. TOP10全球ICT技术发展趋势
  2. IntelliJ IDEA 2016.3.1 学习git 码云插件 学习笔记
  3. 聊聊在博客园写博客的这两年《Unity 3D脚本编程:使用C#语言开发跨平台游戏》正式出版...
  4. 进程间通信(三)—信号量
  5. log4cplus使用(二)-自定义日志等级
  6. 机器学习及其在信息检索中的应用
  7. ZedGraph在项目中的应用
  8. Access自动编号 违反并发性原因解析
  9. 不能上传图片和编辑内容很慢,望改进
  10. 开发必备快速定位排查日志 9 大类命令详解