RNN(LSTM&GRU)文本分类(PaddlePaddle2.0)

一、RNN简介

循环神经网络(Recurrent Neural Network, RNN)是一类以序列(sequence)数据为输入,在序列的演进方向进行递归(recursion)且所有节点(循环单元)按链式连接的递归神经网络(recursive neural network)。

它与DNN,CNN不同的是: 它不仅考虑前一时刻的输入,而且赋予了网络对前面的内容的一种’记忆’功能.

RNN之所以称为循环神经网路,即一个序列当前的输出与前面的输出也有关。具体的表现形式为网络会对前面的信息进行记忆并应用于当前输出的计算中,即隐藏层之间的节点不再无连接而是有连接的,并且隐藏层的输入不仅包括输入层的输出还包括上一时刻隐藏层的输出。

对循环神经网络的研究始于二十世纪80-90年代,并在二十一世纪初发展为深度学习(deep learning)算法之一,其中双向循环神经网络(Bidirectional RNN, Bi-RNN)和长短期记忆网络(Long Short-Term Memory networks,LSTM)是常见的循环神经网络。

循环神经网络具有记忆性、参数共享并且图灵完备(Turing completeness),因此在对序列的非线性特征进行学习时具有一定优势。循环神经网络在自然语言处理(Natural Language Processing, NLP),例如语音识别、语言建模、机器翻译等领域有应用,也被用于各类时间序列预报。引入了卷积神经网络(Convoutional Neural Network,CNN)构筑的循环神经网络可以处理包含序列输入的计算机视觉问题。

最简单的RNN网络

其展开可以表示为:

那么数学表示的公式为:
h∗t=Whxxt+Whhht−1+bhht=σ(h∗t)o∗t=Wohht+boot=θ(o∗t)h^{t}_{*} = W_{hx}x^{t} + W_{hh}h^{t-1} + b_{h} \\ h^{t} = \sigma(h^{t}_{*}) \\ o^{t}_{*} = W_{oh} h^{t} + b_{o}\\ o^{t} = \theta (o^{t}_{*}) ht=Whxxt+Whhht1+bhht=σ(ht)ot=Wohht+boot=θ(ot)

其中,xtx^{t}xt表示t时刻的输入,oto^{t}ot表示t时刻的输出,hth^{t}ht表示t时刻隐藏层的状态。

由于每一步的输出不仅仅依赖当前步的网络,并且还需要前若干步网络的状态,那么这种BP改版的算法叫做Backpropagation Through Time(BPTT) , 也就是将输出端的误差值反向传递,运用梯度下降法进行更新.

RNN的问题和改进

较为严重的是容易出现梯度消失(时间过长而造成记忆值较小)或者梯度爆炸的问题(BP算法和长时间依赖造成的)

因此, 就出现了一系列的改进的算法, 最基础的两种算法是LSTM 和 GRU.

这两种方法在面对梯度消失或者梯度爆炸的问题时,由于有特殊的方式存储”记忆”,那么以前梯度比较大的”记忆”不会像简单的RNN一样马上被抹除,因此可以一定程度上克服梯度消失问题;而针对梯度爆炸则设置阈值,超过阈值直接限制梯度。

LSTM算法(Long Short Term Memory, 长短期记忆网络 )

LSTM(Long short-term memory,长短期记忆)是一种特殊的RNN,主要是为了解决长序列训练过程中的梯度消失问题。

LSTM是有4个全连接层进行计算的,LSTM的内部结构如下图所示。

其中符号含义如下:

接下来看一下内部的具体内容:

LSTM的核心是细胞状态——最上层的横穿整个细胞的水平线,它通过门来控制信息的增加或者删除。
STM共有三个门,分别是遗忘门,输入门和输出门。

  • 遗忘门:遗忘门决定丢弃哪些信息,输入是上一个神经元细胞的计算结果ht-1以及当前的输入向量xt,二者联接并通过遗忘门后(sigmoid会决定哪些信息留下,哪些信息丢弃),会生成一个0-1向量Γft(维度与上一个神经元细胞的输出向量Ct-1相同),Γft与Ct-1进行点乘操作后,就会获取上一个神经元细胞经过计算后保留的信息。
  • 输入门:表示要保存的信息或者待更新的信息,如上图所示是ht-1与xt的连接向量,经过sigmoid层后得到的结果Γit,这就是输入门的输出结果了。
  • 输出门:输出门决定当前神经原细胞输出的隐向量ht,ht与Ct不同,ht要稍微复杂一点,它是Ct进过tanh计算后与输出门的计算结果进行点乘操作后的结果,用公式描述是:ht = tanh(ct) · Γot

GRU(门控循环单元)

GRU是LSTM的变种,它也是一种RNN,因此是循环结构,相比LSTM而言,它的计算要简单一些,计算量也降低。

GRU 有两个有两个门,即一个重置门(reset gate)和一个更新门(update gate)。从直观上来说,重置门决定了如何将新的输入信息与前面的记忆相结合,更新门定义了前面记忆保存到当前时间步的量。如果我们将重置门设置为 1,更新门设置为 0,那么我们将再次获得标准 RNN 模型。使用门控机制学习长期依赖关系的基本思想和 LSTM 一致,但还是有一些关键区别:

  • GRU 有两个门(重置门与更新门),而 LSTM 有三个门(输入门、遗忘门和输出门)。

  • GRU 并不会控制并保留内部记忆(c_t),且没有 LSTM 中的输出门。

  • LSTM 中的输入与遗忘门对应于 GRU 的更新门,重置门直接作用于前面的隐藏状态。

  • 重置门:用来决定需要丢弃哪些上一个神经元细胞的信息,它的计算过程是将Ct-1与当前输入向量xt进行连接后,输入sigmoid层进行计算,结果为S1,再将S1与Ct-1进行点乘计算,则结果为保存的上个神经元细胞信息,用C’t-1表示。公式表示为:C’t-1 = Ct-1 · S1,S1 = sigmoid(concat(Ct-1,Xt))

  • 更新门:更新门类似于LSTM的遗忘门和输入门,它决定哪些信息会丢弃,以及哪些新信息会增加。

完整公式描述为:

二、数据简介

本次使用的分类数据是从新浪微博不实信息举报平台抓取的中文谣言数据,数据集中共包含1538条谣言和1849条非谣言。 更多数据集介绍请参考https://github.com/thunlp/Chinese_Rumor_Dataset

三、数据处理

加载数据集

import pandas as pd
all_data = pd.read_csv("data/data69671/all_data.tsv", sep="\t")
all_data.head()
label text
0 0 #广州#【广州游行打砸抢罪犯资料公布!居然是日本间谍!】 ​
1 0 【政协委员提议恢复大清王朝】康熙十世孙、广州政协委员金复新表示,他准备走遍中国收集100万人...
2 1 有木有人和我一样。睡觉时头总爱靠在枕头的一角。据说这样的孩纸,都没安全感。
3 1 据说,看到这张图的人,许个愿,在十秒内转发的,就能美梦成真!!我们也试试!!!
4 0 【老小子走了!李登辉今天凌晨心脏病复发身亡】台北消息:原国民党、台联党主席,有“台独教父”之...

生成词典


all_str = all_data["text"].values.tolist()
dict_set = set() # 保证每个字符只有唯一的对应数字
for content in all_str:for s in content:dict_set.add(s)
# 添加未知字符
dict_set.add("<unk>")
# 把元组转换成字典,一个字对应一个数字
dict_list = []
i = 0
for s in dict_set:dict_list.append([s, i])i += 1
dict_txt = dict(dict_list)
# 字典保存到本地
with open("dict.txt", 'w', encoding='utf-8') as f:f.write(str(dict_txt))
# 获取字典的长度
def get_dict_len(dict_path):with open(dict_path, 'r', encoding='utf-8') as f:line = eval(f.readlines()[0])return len(line.keys())
print(get_dict_len("dict.txt"))
4410

划分训练集、验证集以及测试集

all_data_list = all_data.values.tolist()
train_length = len(all_data) // 10 * 7
dev_length = len(all_data) // 10 * 2train_data = []
dev_data = []
test_data = []
for i in range(train_length):text = ""for s in all_data_list[i][1]:text = text + str(dict_txt[s]) + ","text = text[:-1]train_data.append([text, all_data_list[i][0]])for i in range(train_length, train_length+dev_length):text = ""for s in all_data_list[i][1]:text = text + str(dict_txt[s]) + ","text = text[:-1]dev_data.append([text, all_data_list[i][0]])for i in range(train_length+dev_length, len(all_data)):text = ""for s in all_data_list[i][1]:text = text + str(dict_txt[s]) + ","text = text[:-1]test_data.append([text, all_data_list[i][0]])print(len(train_data))
print(len(dev_data))
print(len(test_data))
df_train = pd.DataFrame(columns=["text", "label"], data=train_data)
df_dev = pd.DataFrame(columns=["text", "label"], data=dev_data)
df_test = pd.DataFrame(columns=["text", "label"], data=test_data)
df_train.to_csv("train_data.csv", index=False)
df_dev.to_csv("dev_data.csv", index=False)
df_test.to_csv("test_data.csv", index=False)
2366
676
345

自定义数据集

import numpy as np
import paddle
from paddle.io import Dataset, DataLoader
import pandas as pd
class MyDataset(Dataset):"""步骤一:继承paddle.io.Dataset类"""def __init__(self, mode='train'):"""步骤二:实现构造函数,定义数据读取方式,划分训练和测试数据集"""super(MyDataset, self).__init__()self.label = Trueif mode == 'train':text = pd.read_csv("train_data.csv")["text"].values.tolist()label = pd.read_csv("train_data.csv")["label"].values.tolist()self.data = []for i in range(len(text)):self.data.append([])self.data[-1].append(np.array([int(i) for i in text[i].split(",")]))self.data[-1][0] = self.data[-1][0][:256].astype('int64')if len(self.data[-1][0])>=256 else np.concatenate([self.data[-1][0], np.array([dict_txt["<unk>"]]*(256-len(self.data[-1][0])))]).astype('int64')self.data[-1].append(np.array(int(label[i])).astype('int64'))elif mode == 'dev':text = pd.read_csv("dev_data.csv")["text"].values.tolist()label = pd.read_csv("dev_data.csv")["label"].values.tolist()self.data = []for i in range(len(text)):self.data.append([])self.data[-1].append(np.array([int(i) for i in text[i].split(",")]))self.data[-1][0] = self.data[-1][0][:256].astype('int64')if len(self.data[-1][0])>=256 else np.concatenate([self.data[-1][0], np.array([dict_txt["<unk>"]]*(256-len(self.data[-1][0])))]).astype('int64')self.data[-1].append(np.array(int(label[i])).astype('int64'))else:text = pd.read_csv("test_data.csv")["text"].values.tolist()label = pd.read_csv("test_data.csv")["label"].values.tolist()self.data = []for i in range(len(text)):self.data.append([])self.data[-1].append(np.array([int(i) for i in text[i].split(",")]))self.data[-1][0] = self.data[-1][0][:256].astype('int64')if len(self.data[-1][0])>=256 else np.concatenate([self.data[-1][0], np.array([dict_txt["<unk>"]]*(256-len(self.data[-1][0])))]).astype('int64')self.data[-1].append(np.array(int(label[i])).astype('int64'))self.label = Falsedef __getitem__(self, index):"""步骤三:实现__getitem__方法,定义指定index时如何获取数据,并返回单条数据(训练数据,对应的标签)"""text_ =  self.data[index][0]label_ = self.data[index][1]if self.label:return text_, label_else:return text_def __len__(self):"""步骤四:实现__len__方法,返回数据集总数目"""return len(self.data)
train_data = MyDataset(mode="train")
dev_data = MyDataset(mode="dev")
test_data = MyDataset(mode="test")
BATCH_SIZE = 128
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
dev_loader = DataLoader(dev_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

四、配置网络

LSTM

import paddle.nn as nn
inputs_dim = get_dict_len("dict.txt")
class myLSTM(nn.Layer):def __init__(self):super(myLSTM, self).__init__()# num_embeddings (int) - 嵌入字典的大小, input中的id必须满足 0 =< id < num_embeddings 。 。# embedding_dim (int) - 每个嵌入向量的维度。# padding_idx (int|long|None) - padding_idx的配置区间为 [-weight.shape[0], weight.shape[0],如果配置了padding_idx,那么在训练过程中遇到此id时会被用# sparse (bool) - 是否使用稀疏更新,在词嵌入权重较大的情况下,使用稀疏更新能够获得更快的训练速度及更小的内存/显存占用。# weight_attr (ParamAttr|None) - 指定嵌入向量的配置,包括初始化方法,具体用法请参见 ParamAttr ,一般无需设置,默认值为None。self.embedding = nn.Embedding(inputs_dim, 256)# input_size (int) - 输入的大小。# hidden_size (int) - 隐藏状态大小。# num_layers (int,可选) - 网络层数。默认为1。# direction (str,可选) - 网络迭代方向,可设置为forward或bidirect(或bidirectional)。默认为forward。# time_major (bool,可选) - 指定input的第一个维度是否是time steps。默认为False。# dropout (float,可选) - dropout概率,指的是出第一层外每层输入时的dropout概率。默认为0。# weight_ih_attr (ParamAttr,可选) - weight_ih的参数。默认为None。# weight_hh_attr (ParamAttr,可选) - weight_hh的参数。默认为None。# bias_ih_attr (ParamAttr,可选) - bias_ih的参数。默认为None。# bias_hh_attr (ParamAttr,可选) - bias_hh的参数。默认为None。self.lstm = nn.LSTM(256, 256, num_layers=2, direction='bidirectional',dropout=0.5)# in_features (int) – 线性变换层输入单元的数目。# out_features (int) – 线性变换层输出单元的数目。# weight_attr (ParamAttr, 可选) – 指定权重参数的属性。默认值为None,表示使用默认的权重参数属性,将权重参数初始化为0。具体用法请参见 ParamAttr 。# bias_attr (ParamAttr|bool, 可选) – 指定偏置参数的属性。 bias_attr 为bool类型且设置为False时,表示不会为该层添加偏置。 bias_attr 如果设置为True或者None,则表示使用默认的偏置参数属性,将偏置参数初始化为0。具体用法请参见 ParamAttr 。默认值为None。# name (str,可选) – 具体用法请参见 Name ,一般无需设置,默认值为None。self.linear = nn.Linear(in_features=256*2, out_features=2)self.dropout = nn.Dropout(0.5)def forward(self, inputs):emb = self.dropout(self.embedding(inputs))print(emb)output, (hidden, _) = self.lstm(emb)print("output:", output)#output形状大小为[batch_size,seq_len,num_directions * hidden_size]#hidden形状大小为[num_layers * num_directions, batch_size, hidden_size]#把前向的hidden与后向的hidden合并在一起hidden = paddle.concat((hidden[-2,:,:], hidden[-1,:,:]), axis = 1)print(hidden)hidden = self.dropout(hidden)#hidden形状大小为[batch_size, hidden_size * num_directions]return self.linear(hidden) 

封装模型

lstm_model = paddle.Model(myLSTM())

配置优化器等参数

lstm_model.prepare(paddle.optimizer.Adam(learning_rate=0.001, parameters=lstm_model.parameters()),paddle.nn.CrossEntropyLoss(),paddle.metric.Accuracy())

模型训练

lstm_model.fit(train_loader,dev_loader,epochs=1,batch_size=BATCH_SIZE,verbose=1,save_dir="work/lstm")
The loss value printed in the log is the current step, and the metric is the average value of previous step.
Epoch 1/1/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:77: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop workingreturn (isinstance(seq, collections.Sequence) andTensor(shape=[128, 256, 256], dtype=float32, place=CPUPlace, stop_gradient=False,[[[-0.04651996,  0.        ,  0.        , ...,  0.03685956,  0.        ,  0.01133572],[ 0.06982405,  0.01399607,  0.        , ...,  0.        ,  0.04192550,  0.        ],[ 0.        ,  0.        , -0.07115781, ..., -0.04474840,  0.        , -0.00947660],...,[ 0.00036002,  0.02194860,  0.        , ...,  0.        ,  0.        , -0.06254490],[ 0.        ,  0.02194860,  0.        , ...,  0.        ,  0.        ,  0.        ],[ 0.        ,  0.02194860, -0.04798179, ..., -0.00808975,  0.        , -0.06254490]],[[-0.02243162,  0.        , -0.04848601, ...,  0.01318943,  0.05083232,  0.        ],[-0.03468777,  0.        ,  0.        , ...,  0.        ,  0.        ,  0.        ],[-0.00250597,  0.        ,  0.00956149, ..., -0.05242553,  0.04173937,  0.04689348],...,[ 0.00036002,  0.        ,  0.        , ..., -0.00808975, -0.05875421, -0.06254490],[ 0.00036002,  0.        , -0.04798179, ..., -0.00808975,  0.        , -0.06254490],[ 0.        ,  0.        , -0.04798179, ...,  0.        ,  0.        ,  0.        ]],[[-0.05900055,  0.03810406, -0.02360879, ...,  0.00572890,  0.02618536,  0.        ],[ 0.01131869,  0.        ,  0.        , ...,  0.        ,  0.        , -0.02667695],[ 0.03942792, -0.05557661,  0.02318244, ...,  0.04537995,  0.01851625,  0.        ],...,[ 0.        ,  0.02194860, -0.04798179, ..., -0.00808975, -0.05875421, -0.06254490],[ 0.        ,  0.        , -0.04798179, ...,  0.        ,  0.        ,  0.        ],[ 0.        ,  0.02194860, -0.04798179, ...,  0.        ,  0.        ,  0.        ]],...,[[-0.05900055,  0.03810406, -0.02360879, ...,  0.00572890,  0.        ,  0.        ],[ 0.        , -0.03945251,  0.        , ..., -0.03621368,  0.        ,  0.        ],[ 0.04446360,  0.01865993,  0.        , ...,  0.        ,  0.07054503, -0.05907314],...,[ 0.        ,  0.        , -0.04798179, ..., -0.00808975,  0.        , -0.06254490],[ 0.00036002,  0.02194860, -0.04798179, ...,  0.        ,  0.        ,  0.        ],[ 0.        ,  0.02194860, -0.04798179, ..., -0.00808975, -0.05875421,  0.        ]],[[ 0.        ,  0.        ,  0.        , ...,  0.03685956, -0.06564632,  0.        ],[ 0.        ,  0.02998513,  0.        , ...,  0.01318943,  0.05083232, -0.02697797],[-0.02399679,  0.        ,  0.        , ...,  0.03550680,  0.        ,  0.        ],...,[ 0.        ,  0.        , -0.04798179, ..., -0.00808975, -0.05875421,  0.        ],[ 0.00036002,  0.        , -0.04798179, ...,  0.        ,  0.        ,  0.        ],[ 0.        ,  0.        ,  0.        , ...,  0.        ,  0.        , -0.06254490]],[[-0.04651996,  0.        ,  0.        , ...,  0.        , -0.06564632,  0.01133572],[ 0.05890563, -0.04621770,  0.02617000, ...,  0.        ,  0.02645203,  0.        ],[ 0.        ,  0.        ,  0.        , ..., -0.01541087, -0.01680669,  0.06438880],...,[ 0.00036002,  0.02194860,  0.        , ..., -0.00808975, -0.05875421,  0.        ],[ 0.00036002,  0.02194860, -0.04798179, ..., -0.00808975,  0.        , -0.06254490],[ 0.00036002,  0.02194860, -0.04798179, ..., -0.00808975,  0.        , -0.06254490]]])
output: Tensor(shape=[128, 256, 512], dtype=float32, place=CPUPlace, stop_gradient=False,[[[ 0.00938810, -0.00546899,  0.00342875, ...,  0.00907911,  0.03100064,  0.00325108],[ 0.00690732,  0.00087167,  0.00982505, ...,  0.00243614,  0.02229443,  0.00180063],[ 0.00080585,  0.00150115,  0.01130185, ...,  0.01030400,  0.02393062,  0.00775731],...,[-0.00172433,  0.01737818,  0.01828881, ..., -0.01210868,  0.01913041,  0.01308940],[ 0.00401915,  0.01313572,  0.02785316, ..., -0.01023939,  0.01315692,  0.01490174],[ 0.00594976,  0.00380577,  0.02129058, ..., -0.00132644,  0.00966622,  0.00500632]],[[ 0.01139693,  0.00300420,  0.01319209, ...,  0.00867867,  0.02928928,  0.00532537],[ 0.00867943,  0.00134485,  0.00759263, ...,  0.00232504,  0.03548140,  0.01460487],[ 0.00545911,  0.00598109,  0.00665692, ..., -0.00179417,  0.02968132,  0.02142053],...,[ 0.00684091, -0.00332550,  0.01627226, ..., -0.00768097,  0.01906899,  0.00495052],[ 0.00714776,  0.00359129,  0.00969665, ..., -0.00518371,  0.01270603,  0.02092562],[ 0.00996211, -0.00300248,  0.01471507, ..., -0.00518170,  0.00491907,  0.01977420]],[[ 0.00454305,  0.00557371,  0.00644129, ...,  0.00742812,  0.03228268,  0.00123904],[ 0.01620130,  0.01076532,  0.00993367, ...,  0.01371160,  0.02940202,  0.00991743],[ 0.00499136,  0.01006329,  0.00912400, ...,  0.00395425,  0.02999862,  0.00863243],...,[ 0.01431614,  0.00472833,  0.02303233, ..., -0.01040955,  0.02376084,  0.00746928],[ 0.01137663,  0.00380168,  0.02216066, ..., -0.01175049,  0.01654375,  0.01299890],[ 0.00578821,  0.01187050,  0.01773241, ..., -0.01594530,  0.01655355,  0.01706639]],...,[[ 0.00859276, -0.00315960,  0.01166933, ..., -0.00881801,  0.03494525,  0.01678915],[ 0.01290557,  0.00317856,  0.00955159, ..., -0.01111973,  0.03810613,  0.01380979],[ 0.01253064,  0.00664895,  0.02108969, ..., -0.01780477,  0.03851430,  0.01267604],...,[-0.00838452,  0.01019156,  0.01459483, ..., -0.01348441,  0.02485748,  0.01388637],[-0.00674383,  0.00988053,  0.01526485, ..., -0.01920962,  0.02392912,  0.01231303],[ 0.00178657,  0.00858035,  0.01583224, ..., -0.01086871,  0.00722162,  0.01038192]],[[ 0.00699786,  0.00642271,  0.00745836, ..., -0.00313805,  0.03383606,  0.00034484],[ 0.02290456,  0.01042972,  0.01237193, ...,  0.00948680,  0.03684252,  0.00118934],[ 0.01437230,  0.01381475,  0.02222175, ...,  0.01054095,  0.03601860,  0.00561095],...,[-0.00898388,  0.00819177,  0.02560077, ..., -0.01066977,  0.02100313,  0.03179439],[-0.00742552,  0.01296771,  0.02097484, ..., -0.00363565,  0.02471528,  0.01478868],[-0.00114184,  0.00497478,  0.01264016, ..., -0.01004613,  0.00440474,  0.01142078]],[[ 0.00540706,  0.00030641,  0.00700968, ...,  0.00050059,  0.03791260, -0.00777113],[ 0.00696417, -0.00121238,  0.00928910, ...,  0.00291850,  0.03325276, -0.00500016],[ 0.00232116,  0.00674826,  0.01029603, ...,  0.00446389,  0.03645320, -0.00496576],...,[-0.00207149,  0.01562872,  0.02173612, ..., -0.01239564,  0.01776450,  0.01830550],[ 0.00535780,  0.00764524,  0.02233512, ..., -0.01778406,  0.01849146,  0.01427940],[ 0.00507328,  0.00707623,  0.02679622, ..., -0.00783817,  0.01783962,  0.01461937]]])
Tensor(shape=[128, 512], dtype=float32, place=CPUPlace, stop_gradient=False,[[ 0.00594976,  0.00380577,  0.02129058, ...,  0.00907911,  0.03100064,  0.00325108],[ 0.00996211, -0.00300248,  0.01471507, ...,  0.00867867,  0.02928928,  0.00532537],[ 0.00578821,  0.01187050,  0.01773241, ...,  0.00742812,  0.03228268,  0.00123904],...,[ 0.00178657,  0.00858035,  0.01583224, ..., -0.00881801,  0.03494525,  0.01678915],[-0.00114184,  0.00497478,  0.01264016, ..., -0.00313805,  0.03383606,  0.00034484],[ 0.00507328,  0.00707623,  0.02679622, ...,  0.00050059,  0.03791260, -0.00777113]])

模型预测

result = lstm_model.predict(test_loader)
Predict begin...
step 3/3 [==============================] - 38ms/step
Predict samples: 345

GRU

class myGRU(nn.Layer):def __init__(self):super(myGRU, self).__init__()# num_embeddings (int) - 嵌入字典的大小, input中的id必须满足 0 =< id < num_embeddings 。 。# embedding_dim (int) - 每个嵌入向量的维度。# padding_idx (int|long|None) - padding_idx的配置区间为 [-weight.shape[0], weight.shape[0],如果配置了padding_idx,那么在训练过程中遇到此id时会被用# sparse (bool) - 是否使用稀疏更新,在词嵌入权重较大的情况下,使用稀疏更新能够获得更快的训练速度及更小的内存/显存占用。# weight_attr (ParamAttr|None) - 指定嵌入向量的配置,包括初始化方法,具体用法请参见 ParamAttr ,一般无需设置,默认值为None。self.embedding = nn.Embedding(inputs_dim, 256)# input_size (int) - 输入的大小。# hidden_size (int) - 隐藏状态大小。# num_layers (int,可选) - 网络层数。默认为1。# direction (str,可选) - 网络迭代方向,可设置为forward或bidirect(或bidirectional)。默认为forward。# time_major (bool,可选) - 指定input的第一个维度是否是time steps。默认为False。# dropout (float,可选) - dropout概率,指的是出第一层外每层输入时的dropout概率。默认为0。# weight_ih_attr (ParamAttr,可选) - weight_ih的参数。默认为None。# weight_hh_attr (ParamAttr,可选) - weight_hh的参数。默认为None。# bias_ih_attr (ParamAttr,可选) - bias_ih的参数。默认为None。# bias_hh_attr (ParamAttr,可选) - bias_hh的参数。默认为None。self.gru = nn.GRU(256, 256, num_layers=2, direction='bidirectional',dropout=0.5)# in_features (int) – 线性变换层输入单元的数目。# out_features (int) – 线性变换层输出单元的数目。# weight_attr (ParamAttr, 可选) – 指定权重参数的属性。默认值为None,表示使用默认的权重参数属性,将权重参数初始化为0。具体用法请参见 ParamAttr 。# bias_attr (ParamAttr|bool, 可选) – 指定偏置参数的属性。 bias_attr 为bool类型且设置为False时,表示不会为该层添加偏置。 bias_attr 如果设置为True或者None,则表示使用默认的偏置参数属性,将偏置参数初始化为0。具体用法请参见 ParamAttr 。默认值为None。# name (str,可选) – 具体用法请参见 Name ,一般无需设置,默认值为None。self.linear = nn.Linear(in_features=256*2, out_features=2)self.dropout = nn.Dropout(0.5)def forward(self, inputs):emb = self.dropout(self.embedding(inputs))output, hidden = self.gru(emb)#output形状大小为[batch_size,seq_len,num_directions * hidden_size]#hidden形状大小为[num_layers * num_directions, batch_size, hidden_size]#把前向的hidden与后向的hidden合并在一起hidden = paddle.concat((hidden[-2,:,:], hidden[-1,:,:]), axis = 1)hidden = self.dropout(hidden)#hidden形状大小为[batch_size, hidden_size * num_directions]return self.linear(hidden) 

封装模型

GRU_model = paddle.Model(myGRU())

配置优化器等参数

GRU_model.prepare(paddle.optimizer.Adam(learning_rate=0.001, parameters=GRU_model.parameters()),paddle.nn.CrossEntropyLoss(),paddle.metric.Accuracy())

模型训练

GRU_model.fit(train_loader,dev_loader,epochs=10,batch_size=BATCH_SIZE,verbose=1,save_dir="work/GRU")

模型预测

result = GRU_model.predict(test_loader)
Predict begin...
step 3/3 [==============================] - 35ms/step
Predict samples: 345

五、总结

本文主要注重在于PaddlePaddle2.0在nlp基础任务的全流程如何实现,因此并未对两个模型的最终结果进行对比。

GRU和LSTM的性能在很多任务上效果相差不大,不过GRU 参数更少因此更容易收敛,而在数据集很大的情况下,LSTM表达性能更好。

在简单任务上,LSTM和GRU其实都是不错的选择,从完成代码来说,两者差别也不大,都可以简单方便的实现。

运行代码请点击:https://aistudio.baidu.com/aistudio/projectdetail/1491175?shared=1
欢迎三连!

RNN(LSTMGRU)文本分类(PaddlePaddle2.0)相关推荐

  1. TensorFlow2.0教程-使用RNN实现文本分类

    TensorFlow2.0教程-使用RNN实现文本分类 原文地址:https://blog.csdn.net/qq_31456593/article/details/89923645 Tensorfl ...

  2. 【论文复现】使用RNN进行文本分类

    写在前面 这是文本分类任务的第二个系列----基于RNN的文本分类实现(Text RNN) 复现的论文是2016年复旦大学IJCAI 上的发表的关于循环神经网络在多任务文本分类上的应用:Recurre ...

  3. 科大讯飞NLP算法赛baseline:文本分类实践+0.79

    比赛题目:学术论文分类挑战赛 比赛链接:https://challenge.xfyun.cn/topic/info?type=academic-paper-classification&ch= ...

  4. [Python人工智能] 二十.基于Keras+RNN的文本分类vs基于传统机器学习的文本分类

    从本专栏开始,作者正式研究Python深度学习.神经网络及人工智能相关知识.前一篇文章分享了循环神经网络RNN的原理知识,并采用Keras实现手写数字识别的RNN分类案例及可视化呈现.这篇文章作者将带 ...

  5. Tensorflow2.0入门教程22:RNN网络实现文本分类

    RNN实现文本分类 import tensorflow as tf 下载数据集 imdb=tf.keras.datasets.imdb (train_x, train_y), (test_x, tes ...

  6. 【NLP傻瓜式教程】手把手带你RNN文本分类(附代码)

    文章来源于NewBeeNLP,作者kaiyuan 写在前面 这是NLP傻瓜式教程的第二篇----基于RNN的文本分类实现(Text RNN) 参考的的论文是来自2016年复旦大学IJCAI上的发表的关 ...

  7. 【NLP保姆级教程】手把手带你RNN文本分类(附代码)

    写在前面 这是NLP保姆级教程的第二篇----基于RNN的文本分类实现(Text RNN) 参考的的论文是来自2016年复旦大学IJCAI上的发表的关于循环神经网络在多任务文本分类上的应用:Recur ...

  8. 文本分类:Keras+RNN vs传统机器学习

    摘要:本文通过Keras实现了一个RNN文本分类学习的案例,并详细介绍了循环神经网络原理知识及与机器学习对比. 本文分享自华为云社区<基于Keras+RNN的文本分类vs基于传统机器学习的文本分 ...

  9. 自然语言处理入门实战——基于循环神经网络RNN、LSTM、GRU的文本分类(超级详细,学不会找我!!!)

    1  一.实验过程 1.1  实验目的 通过这个课程项目大,期望达到以下目的: 1.了解如何对 自然语言处理 的数据集进行预处理操作. 2.初识自然语言数据集处理操作的步骤流程. 3.进一步学习RNN ...

  10. 第六课.NLP文本分类任务

    第六课目录 NLP文本分类简介 IMDB数据集准备 设置随机种子 下载IMDB并划分数据集 构建词汇表 使用torchtext生成batch WordAveraging 模型定义 加载由glove.6 ...

最新文章

  1. 针对连续动作的DQN
  2. PHP-cli 日志彩色玩法 echo \033[1;33m Hello World. \033[0m \n;
  3. 【bfs】神殿(jzoj 2296)
  4. WSS 3.0 和 sharepoint 2007 中文SDK
  5. C#实现中国天气网XML接口测试
  6. c语言强化训练作业整理1
  7. 关于java AudioInputStream播放短音频没声音的问题
  8. sql server 加密_SQL Server机密–第一部分–加密基础知识和SQL Server加密功能
  9. 列出搜索过的数据(类似京东顶部搜索框)
  10. ORB-SLAM3 yaml文件介绍
  11. 计算机组成原理学习笔记——数据通路
  12. 汇编语言和本地代码及通过编译器输出汇编语言的源代码
  13. Python实现线性函数的拟合算法
  14. 宁海元 mysql_每公斤约360元 宁海香榧可以品尝了
  15. 【工程源码】CYUSB3014芯片使用EEPROM无法下载固件说明
  16. 《Python 3网络爬虫开发实战 》崔庆才著 第三章笔记
  17. 干货 | Elasticsearch 索引生命周期管理 ILM 实战指南
  18. Vuforia-PocketCat丨1. 设计目标及效果展示
  19. MATLAB 林地郁闭度计算
  20. 12、TWS API和IB中的订单管理

热门文章

  1. MATLABR2018自学一本通笔记
  2. “出色”IT项目经理的5大关键能力
  3. Windows任务栏实现动态自动隐藏并透明
  4. 02333软件工程要点及考点
  5. 【PPT技巧】为PPT寻找好看的英文字体(English nice-looking font free)并安装到Windows
  6. 状态机编程思想及实例
  7. 电机驱动软件学习笔记——数据打包解包CRC校验
  8. 软件中断SWI的实现
  9. 【SSM】SSM框架介绍
  10. 一次完整的 Http 请求过程