神经网络学习小记录37——Keras实现GRU与GRU参数量详解
神经网络学习小记录37——Keras实现GRU与GRU参数量详解
- 学习前言
- 什么是GRU
- 1、GRU单元的输入与输出
- 2、GRU的门结构
- 3、GRU的参数量计算
- a、更新门
- b、重置门
- c、全部参数量
- 在Keras中实现GRU
- 实现代码
学习前言
我死了我死了我死了!
什么是GRU
GRU是LSTM的一个变种。
传承了LSTM的门结构,但是将LSTM的三个门转化成两个门,分别是更新门和重置门。
1、GRU单元的输入与输出
下图是每个GRU单元的结构。
在n时刻,每个GRU单元的输入有两个:
- 当前时刻网络的输入值Xt;
- 上一时刻GRU的输出值ht-1;
输出有一个:
- 当前时刻GRU输出值ht;
2、GRU的门结构
GRU含有两个门结构,分别是:
更新门zt和重置门rt:
更新门用于控制前一时刻的状态信息被代入到当前状态的程度,更新门的值越大说明前一时刻的状态信息带入越少,这一时刻的状态信息带入越多。
重置门用于控制忽略前一时刻的状态信息的程度,重置门的值越小说明忽略得越多。
3、GRU的参数量计算
a、更新门
更新门在图中的标号为zt,需要结合ht-1和Xt来决定上一时刻的输出ht-1有多少得到保留,更新门的值越大说明前一时刻的状态信息保留越少,这一时刻的状态信息保留越多。
结合公式我们可以知道:
zt由ht-1和Xt来决定。
当更新门zt的值较大的时候,上一时刻的输出ht-1保留较少,而这一时刻的状态信息保留较多。
Wz的参数量=(xdim+hdim)∗hdimW_z的参数量 = (x_{dim} + h_{dim}) * h_{dim} Wz的参数量=(xdim+hdim)∗hdim
bz的参数量=hdimb_z的参数量 = h_{dim} bz的参数量=hdim
更新门的总参数量为:
总参数量=((xdim+hdim)∗hdim+hdim)总参数量 = ((x_{dim} + h_{dim}) * h_{dim} + h_{dim}) 总参数量=((xdim+hdim)∗hdim+hdim)
b、重置门
重置门在图中的标号为rt,需要结合ht-1和Xt来控制忽略前一时刻的状态信息的程度,重置门的值越小说明忽略得越多。
结合公式我们可以知道:
rt由ht-1和Xt来决定。
当重置门rt的值较小的时候,上一时刻的输出ht-1保留较少,说明忽略得越多。
Wt的参数量=(xdim+hdim)∗hdimW_t的参数量 = (x_{dim} + h_{dim}) * h_{dim} Wt的参数量=(xdim+hdim)∗hdim
bt的参数量=hdimb_t的参数量 = h_{dim} bt的参数量=hdim
W的参数量=(xdim+hdim)∗hdimW的参数量 = (x_{dim} + h_{dim}) * h_{dim} W的参数量=(xdim+hdim)∗hdim
b的参数量=hdimb的参数量 = h_{dim} b的参数量=hdim
重置门的总参数量为:
总参数量=2∗((xdim+hdim)∗hdim+hdim)总参数量 = 2*((x_{dim} + h_{dim}) * h_{dim} + h_{dim}) 总参数量=2∗((xdim+hdim)∗hdim+hdim)
c、全部参数量
所以所有的门总参数量为:
总参数量=3∗((xdim+hdim)∗hdim+hdim)总参数量 = 3*((x_{dim} + h_{dim}) * h_{dim} + h_{dim}) 总参数量=3∗((xdim+hdim)∗hdim+hdim)
在Keras中实现GRU
GRU一般需要输入两个参数。
一个是unit、一个是input_shape。
LSTM(CELL_SIZE, input_shape = (TIME_STEPS,INPUT_SIZE))
unit用于指定神经元的数量。
input_shape用于指定输入的shape,分别指定TIME_STEPS和INPUT_SIZE。
实现代码
import numpy as np
from keras.models import Sequential
from keras.layers import Input,Activation,Dense
from keras.models import Model
from keras.datasets import mnist
from keras.layers.recurrent import GRU
from keras.utils import np_utils
from keras.optimizers import AdamTIME_STEPS = 28
INPUT_SIZE = 28
BATCH_SIZE = 50
index_start = 0
OUTPUT_SIZE = 10
CELL_SIZE = 75
LR = 1e-3(X_train,Y_train),(X_test,Y_test) = mnist.load_data()X_train = X_train.reshape(-1,28,28)/255
X_test = X_test.reshape(-1,28,28)/255Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)inputs = Input(shape=[TIME_STEPS,INPUT_SIZE])x = GRU(CELL_SIZE, input_shape = (TIME_STEPS,INPUT_SIZE))(inputs)
x = Dense(OUTPUT_SIZE)(x)
x = Activation("softmax")(x)model = Model(inputs,x)
adam = Adam(LR)
model.summary()
model.compile(loss = 'categorical_crossentropy',optimizer = adam,metrics = ['accuracy'])for i in range(50000):X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]index_start += BATCH_SIZEcost = model.train_on_batch(X_batch,Y_batch)if index_start >= X_train.shape[0]:index_start = 0if i%100 == 0:cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)print("accuracy:",accuracy)
实现效果:
10000/10000 [==============================] - 2s 231us/step
accuracy: 0.16749999986961484
10000/10000 [==============================] - 2s 206us/step
accuracy: 0.6134000015258789
10000/10000 [==============================] - 2s 214us/step
accuracy: 0.7058000019192696
10000/10000 [==============================] - 2s 209us/step
accuracy: 0.797899999320507
神经网络学习小记录37——Keras实现GRU与GRU参数量详解相关推荐
- 神经网络学习小记录39——MobileNetV3(small)模型的复现详解
神经网络学习小记录39--MobileNetV3(small)模型的复现详解 学习前言 什么是MobileNetV3 代码下载 large与small的区别 MobileNetV3(small)的网络 ...
- 神经网络学习小记录45——Keras常用学习率下降方式汇总
神经网络学习小记录45--Keras常用学习率下降方式汇总 2020年5月19日更新 前言 为什么要调控学习率 下降方式汇总 1.阶层性下降 2.指数型下降 3.余弦退火衰减 4.余弦退火衰减更新版 ...
- 神经网络学习小记录58——Keras GhostNet模型的复现详解
神经网络学习小记录58--Keras GhostNet模型的复现详解 学习前言 什么是GhostNet模型 源码下载 GhostNet模型的实现思路 1.Ghost Module 2.Ghost Bo ...
- 神经网络学习小记录26——Keras 利用efficientnet系列模型搭建yolov3目标检测平台
神经网络学习小记录26--Keras 利用efficientnet系列模型搭建efficientnet-yolov3目标检测平台 学习前言 什么是EfficientNet模型 源码下载 Efficie ...
- 神经网络学习小记录17——使用AlexNet分类模型训练自己的数据(猫狗数据集)
神经网络学习小记录17--使用AlexNet分类模型训练自己的数据(猫狗数据集) 学习前言 什么是AlexNet模型 训练前准备 1.数据集处理 2.创建Keras的AlexNet模型 开始训练 1. ...
- 神经网络学习小记录19——微调VGG分类模型训练自己的数据(猫狗数据集)
神经网络学习小记录19--微调VGG分类模型训练自己的数据(猫狗数据集) 注意事项 学习前言 什么是VGG16模型 VGG模型的复杂程度 训练前准备 1.数据集处理 2.创建Keras的VGG模型 3 ...
- 神经网络学习小记录-番外篇——常见问题汇总
神经网络学习小记录-番外篇--常见问题汇总 前言 问题汇总 1.下载问题 a.代码下载 b. 权值下载 c. 数据集下载 2.环境配置问题 a.20系列所用的环境 b.30系列显卡环境配置 c.CPU ...
- 神经网络学习小记录40——春节到了,用LSTM写古诗不?
神经网络学习小记录40--春节到了,用LSTM写古诗不? 学习前言 整体实现思路 github下载地址与B站连接 代码实现 1.数据处理 a.读取古诗并转化为id b.将读取到的所有古诗转化为6to1 ...
- 神经网络学习小记录68——Tensorflow2版 Vision Transformer(VIT)模型的复现详解
神经网络学习小记录68--Tensorflow2版 Vision Transformer(VIT)模型的复现详解 学习前言 什么是Vision Transformer(VIT) 代码下载 Vision ...
最新文章
- android同时使用多个library时的问题
- 千万级饿了么交易系统架构 5 年演化史!
- Java---线程多(工作内存)和内存模型(主内存)分析
- 1.1 一个简单的脚本
- js进阶 14-8 表单序列化函数serializeArray()和serialize()的区别是什么
- kaggle房价预测特征意思_Kaggle之预测房价
- Java Spring全家桶详解——Spring简介
- ppi 各代iphone_iphone型号对比
- win7电脑误删鼠标键盘驱动_Win7系统鼠标键盘驱动检测不到的三种解决方法
- 如何把HTML背景图片变透明,photoshop怎样把图片背景变透明
- CSS4.2.3 参考手册.CHM
- 软件项目经理应具备的素质和条件_软件项目经理的素质能力要求
- WiFi相关知识介绍
- 零基础编程入门先学什么
- 纽约州立石溪分校计算机科学排名,纽约州立大学石溪分校计算机科学专业排名第40(2020年USNEWS美国排名)...
- ————博客永久废止————转到http://1su.net/nsB
- JAVA8用哪个版本的MYSQL_MySQL用哪个版本,5.7还是8.0?
- 数据处理之数据类型转换
- 《Linux操作系统 - RK3288开发笔记》第2章 G-3288-02开发环境搭建
- webpack中处理css文件