神经网络学习小记录37——Keras实现GRU与GRU参数量详解

  • 学习前言
  • 什么是GRU
    • 1、GRU单元的输入与输出
    • 2、GRU的门结构
    • 3、GRU的参数量计算
      • a、更新门
      • b、重置门
      • c、全部参数量
  • 在Keras中实现GRU
    • 实现代码

学习前言

我死了我死了我死了!

什么是GRU

GRU是LSTM的一个变种。

传承了LSTM的门结构,但是将LSTM的三个门转化成两个门,分别是更新门和重置门。

1、GRU单元的输入与输出

下图是每个GRU单元的结构。

在n时刻,每个GRU单元的输入有两个:

  • 当前时刻网络的输入值Xt
  • 上一时刻GRU的输出值ht-1

输出有一个:

  • 当前时刻GRU输出值ht

2、GRU的门结构

GRU含有两个门结构,分别是:

更新门zt和重置门rt

更新门用于控制前一时刻的状态信息被代入到当前状态的程度,更新门的值越大说明前一时刻的状态信息带入越少,这一时刻的状态信息带入越多。

重置门用于控制忽略前一时刻的状态信息的程度,重置门的值越小说明忽略得越多。

3、GRU的参数量计算

a、更新门


更新门在图中的标号为zt,需要结合ht-1和Xt来决定上一时刻的输出ht-1有多少得到保留,更新门的值越大说明前一时刻的状态信息保留越少,这一时刻的状态信息保留越多。

结合公式我们可以知道:
zt由ht-1和Xt来决定。

当更新门zt的值较大的时候,上一时刻的输出ht-1保留较少,而这一时刻的状态信息保留较多。

Wz的参数量=(xdim+hdim)∗hdimW_z的参数量 = (x_{dim} + h_{dim}) * h_{dim} Wz​的参数量=(xdim​+hdim​)∗hdim​
bz的参数量=hdimb_z的参数量 = h_{dim} bz​的参数量=hdim​
更新门的总参数量为:
总参数量=((xdim+hdim)∗hdim+hdim)总参数量 = ((x_{dim} + h_{dim}) * h_{dim} + h_{dim}) 总参数量=((xdim​+hdim​)∗hdim​+hdim​)

b、重置门


重置门在图中的标号为rt,需要结合ht-1和Xt来控制忽略前一时刻的状态信息的程度,重置门的值越小说明忽略得越多。

结合公式我们可以知道:

rt由ht-1和Xt来决定。

当重置门rt的值较小的时候,上一时刻的输出ht-1保留较少,说明忽略得越多。

Wt的参数量=(xdim+hdim)∗hdimW_t的参数量 = (x_{dim} + h_{dim}) * h_{dim} Wt​的参数量=(xdim​+hdim​)∗hdim​
bt的参数量=hdimb_t的参数量 = h_{dim} bt​的参数量=hdim​
W的参数量=(xdim+hdim)∗hdimW的参数量 = (x_{dim} + h_{dim}) * h_{dim} W的参数量=(xdim​+hdim​)∗hdim​
b的参数量=hdimb的参数量 = h_{dim} b的参数量=hdim​
重置门的总参数量为:
总参数量=2∗((xdim+hdim)∗hdim+hdim)总参数量 = 2*((x_{dim} + h_{dim}) * h_{dim} + h_{dim}) 总参数量=2∗((xdim​+hdim​)∗hdim​+hdim​)

c、全部参数量

所以所有的门总参数量为:
总参数量=3∗((xdim+hdim)∗hdim+hdim)总参数量 = 3*((x_{dim} + h_{dim}) * h_{dim} + h_{dim}) 总参数量=3∗((xdim​+hdim​)∗hdim​+hdim​)

在Keras中实现GRU

GRU一般需要输入两个参数。
一个是unit、一个是input_shape。

LSTM(CELL_SIZE, input_shape = (TIME_STEPS,INPUT_SIZE))

unit用于指定神经元的数量。
input_shape用于指定输入的shape,分别指定TIME_STEPS和INPUT_SIZE。

实现代码

import numpy as np
from keras.models import Sequential
from keras.layers import Input,Activation,Dense
from keras.models import Model
from keras.datasets import mnist
from keras.layers.recurrent import GRU
from keras.utils import np_utils
from keras.optimizers import AdamTIME_STEPS = 28
INPUT_SIZE = 28
BATCH_SIZE = 50
index_start = 0
OUTPUT_SIZE = 10
CELL_SIZE = 75
LR = 1e-3(X_train,Y_train),(X_test,Y_test) = mnist.load_data()X_train = X_train.reshape(-1,28,28)/255
X_test = X_test.reshape(-1,28,28)/255Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)inputs = Input(shape=[TIME_STEPS,INPUT_SIZE])x = GRU(CELL_SIZE, input_shape = (TIME_STEPS,INPUT_SIZE))(inputs)
x = Dense(OUTPUT_SIZE)(x)
x = Activation("softmax")(x)model = Model(inputs,x)
adam = Adam(LR)
model.summary()
model.compile(loss = 'categorical_crossentropy',optimizer = adam,metrics = ['accuracy'])for i in range(50000):X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]index_start += BATCH_SIZEcost = model.train_on_batch(X_batch,Y_batch)if index_start >= X_train.shape[0]:index_start = 0if i%100 == 0:cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)print("accuracy:",accuracy)

实现效果:

10000/10000 [==============================] - 2s 231us/step
accuracy: 0.16749999986961484
10000/10000 [==============================] - 2s 206us/step
accuracy: 0.6134000015258789
10000/10000 [==============================] - 2s 214us/step
accuracy: 0.7058000019192696
10000/10000 [==============================] - 2s 209us/step
accuracy: 0.797899999320507

神经网络学习小记录37——Keras实现GRU与GRU参数量详解相关推荐

  1. 神经网络学习小记录39——MobileNetV3(small)模型的复现详解

    神经网络学习小记录39--MobileNetV3(small)模型的复现详解 学习前言 什么是MobileNetV3 代码下载 large与small的区别 MobileNetV3(small)的网络 ...

  2. 神经网络学习小记录45——Keras常用学习率下降方式汇总

    神经网络学习小记录45--Keras常用学习率下降方式汇总 2020年5月19日更新 前言 为什么要调控学习率 下降方式汇总 1.阶层性下降 2.指数型下降 3.余弦退火衰减 4.余弦退火衰减更新版 ...

  3. 神经网络学习小记录58——Keras GhostNet模型的复现详解

    神经网络学习小记录58--Keras GhostNet模型的复现详解 学习前言 什么是GhostNet模型 源码下载 GhostNet模型的实现思路 1.Ghost Module 2.Ghost Bo ...

  4. 神经网络学习小记录26——Keras 利用efficientnet系列模型搭建yolov3目标检测平台

    神经网络学习小记录26--Keras 利用efficientnet系列模型搭建efficientnet-yolov3目标检测平台 学习前言 什么是EfficientNet模型 源码下载 Efficie ...

  5. 神经网络学习小记录17——使用AlexNet分类模型训练自己的数据(猫狗数据集)

    神经网络学习小记录17--使用AlexNet分类模型训练自己的数据(猫狗数据集) 学习前言 什么是AlexNet模型 训练前准备 1.数据集处理 2.创建Keras的AlexNet模型 开始训练 1. ...

  6. 神经网络学习小记录19——微调VGG分类模型训练自己的数据(猫狗数据集)

    神经网络学习小记录19--微调VGG分类模型训练自己的数据(猫狗数据集) 注意事项 学习前言 什么是VGG16模型 VGG模型的复杂程度 训练前准备 1.数据集处理 2.创建Keras的VGG模型 3 ...

  7. 神经网络学习小记录-番外篇——常见问题汇总

    神经网络学习小记录-番外篇--常见问题汇总 前言 问题汇总 1.下载问题 a.代码下载 b. 权值下载 c. 数据集下载 2.环境配置问题 a.20系列所用的环境 b.30系列显卡环境配置 c.CPU ...

  8. 神经网络学习小记录40——春节到了,用LSTM写古诗不?

    神经网络学习小记录40--春节到了,用LSTM写古诗不? 学习前言 整体实现思路 github下载地址与B站连接 代码实现 1.数据处理 a.读取古诗并转化为id b.将读取到的所有古诗转化为6to1 ...

  9. 神经网络学习小记录68——Tensorflow2版 Vision Transformer(VIT)模型的复现详解

    神经网络学习小记录68--Tensorflow2版 Vision Transformer(VIT)模型的复现详解 学习前言 什么是Vision Transformer(VIT) 代码下载 Vision ...

最新文章

  1. android同时使用多个library时的问题
  2. 千万级饿了么交易系统架构 5 年演化史!
  3. Java---线程多(工作内存)和内存模型(主内存)分析
  4. 1.1 一个简单的脚本
  5. js进阶 14-8 表单序列化函数serializeArray()和serialize()的区别是什么
  6. kaggle房价预测特征意思_Kaggle之预测房价
  7. Java Spring全家桶详解——Spring简介
  8. ppi 各代iphone_iphone型号对比
  9. win7电脑误删鼠标键盘驱动_Win7系统鼠标键盘驱动检测不到的三种解决方法
  10. 如何把HTML背景图片变透明,photoshop怎样把图片背景变透明
  11. CSS4.2.3 参考手册.CHM
  12. 软件项目经理应具备的素质和条件_软件项目经理的素质能力要求
  13. WiFi相关知识介绍
  14. 零基础编程入门先学什么
  15. 纽约州立石溪分校计算机科学排名,纽约州立大学石溪分校计算机科学专业排名第40(2020年USNEWS美国排名)...
  16. ————博客永久废止————转到http://1su.net/nsB
  17. JAVA8用哪个版本的MYSQL_MySQL用哪个版本,5.7还是8.0?
  18. 数据处理之数据类型转换
  19. 《Linux操作系统 - RK3288开发笔记》第2章 G-3288-02开发环境搭建
  20. webpack中处理css文件

热门文章

  1. 微软命令行设定/取消定时关机
  2. 深度学习 + 众包重现历史街景,在线体验“时间旅行”
  3. java图形验证码_java图形验证码实现
  4. 节日不要随便发祝福短信--有感于11.1.11
  5. 定义一个用来和老师打招呼的方法。
  6. HTTP协议详解(超级详细)
  7. Spring开发Service层
  8. 移动端适配之视觉窗口view-port的详细设置
  9. linux内核的一些知识点(上)
  10. MySQL导出数据库、数据库表结构、存储过程及函数【用】