一、前言

1、Skip-Thought-Vector论文

2、本文假设读者已了解Skip-Gram-Vector和RNN相关基础,以下文章可做参考:

(1)RNN古诗词生成

(2)Skip-Gram-Vector

(3)LSTM/GRU门控机制

二、实战

1、数据处理

(1)网络小说《神墓》,基于版权原因,请自行寻找数据源

(2)先对特殊符号进行处理,将整本小说按行分割成一个列表

    def _process_words(file_list):words = ''.join(file_list)vocab = sorted(set(words))mask = vocab[:110]+vocab[-57:]mark = ['!', ',', ':', ';', '?', '~', '…', '、', '。', '.', '?', ';', ':', '.', ',', '!']for m in mask:words = words.replace(m, '\\') if m in mark else words.replace(m, '')return words

(3)分割后的句子可能出现较多重复且意义不大的句子(如,啊,哈哈,等),对模型产生噪音。这里把高频句子剔除,用以下公式计算删除该句子的概率:

其中f(w)代表该句子出现的概率,t为一个阈值。

    def _process_sentence_list(sentence_list, t=1e-5, threshold=0.5):sentence_count = Counter(sentence_list)total_count = len(sentence_list)# 计算句子频率sentence_freqs = {w: c / total_count for w, c in sentence_count.items()}# 计算被删除的概率prob_drop = {w: 1 - np.sqrt(t / sentence_freqs[w]) for w in sentence_count}# 剔除出现频率太高的句子sentence_list = [w for w in sentence_list if prob_drop[w] < threshold]return sentence_list

上述代码基于概率进行了采样,减少了训练样本中的噪音。
(4)生成包含所有字的字典,添加特殊字符‘<PAD>’作为占位符,‘<UNK>’代替未在字典中出现的字,‘<GO>’代表句子的开始,'<EOS>'作为句子的结束。

    def _get_vocab(self):# 生成词字典special_words = ['<PAD>', '<UNK>', '<GO>', '<EOS>']words = ''.join(self.sentence_list)vocab = sorted(set(words))+special_wordsword_to_int = {w: i for i, w in enumerate(vocab)}int_to_word = {i: w for i, w in enumerate(vocab)}return vocab, word_to_int, int_to_word

基于上述代码可以将每一个句子转为数字向量。

(5)Skip-Thought-Vector借鉴了Skip-Gram-Vector的思想,这里选取窗口的大小都规定为1,所以其实是取句子的上一句及下一句

    def _get_target(sentences, index, window_size=1):# 获取句子相邻句子start = index - window_size if (index - window_size) > 0 else 0end = index + 2*window_sizetargets = set(sentences[start:index] + sentences[index+1:end])return list(targets)

(6)构造一个生成器,按照batch_size将文本列表分割为大小相等的训练batch。由于每个batch中的句子字数不一定相等,这里还需要将句子缺失部分进行padding,具体代码在我的github上可以看到

2、Skip-Thought-Vector

首先看Skip-Thought-Vector的示意图,

模型分为两个部分,encoder对句子进行encode,将final state传递到decoder,deocoder分别对当前句子的上一句及下一句进行decode。这是一个经典的encode-decode框架,原论文每个encoder、decoder使用了GRU-RNN,这里我们使用简单的LSTM来实现,两者的不同可参考前言提供链接。

3、模型输入定义

    def build_inputs():with tf.variable_scope('inputs'):# 句子encode = tf.placeholder(tf.int32, shape=[None, None], name='encode')encode_length = tf.placeholder(tf.int32, shape=[None, ], name='encode_length')# 句子的前一句decode_pre_x = tf.placeholder(tf.int32, shape=[None, None], name='decode_pre_x')decode_pre_y = tf.placeholder(tf.int32, shape=[None, None], name='decode_pre_y')decode_pre_length = tf.placeholder(tf.int32, shape=[None, ], name='decode_pre_length')# 句子的后一句decode_post_x = tf.placeholder(tf.int32, shape=[None, None], name='decode_post_x')decode_post_y = tf.placeholder(tf.int32, shape=[None, None], name='decode_post_y')decode_post_length = tf.placeholder(tf.int32, shape=[None, ], name='decode_post_length')return encode, decode_pre_x, decode_pre_y, decode_post_x, decode_post_y, encode_length, decode_pre_length, decode_post_length

由于我们每个batch中句子都进行了padding,为了防止padding对训练的影响,这里需要传递掩码给到RNN网络--每个句子各自的原始长度。

4、对输入句子进行embedding

    def build_word_embedding(self, encode, decode_pre_x, decode_post_x):# embeddingwith tf.variable_scope('embedding'):embedding = tf.get_variable(name='embedding', shape=[len(self.vocab), self.embedding_dim],initializer=tf.random_uniform_initializer(-0.1, 0.1))encode_emb = tf.nn.embedding_lookup(embedding, encode, name='encode_emb')decode_pre_emb = tf.nn.embedding_lookup(embedding, decode_pre_x, name='decode_pre_emb')decode_post_emb = tf.nn.embedding_lookup(embedding, decode_post_x, name='decode_post_emb')return encode_emb, decode_pre_emb, decode_post_emb

将句子中的每一个字都转化为vocab size长度的向量。

5、构建encoder

encoder对句子进行encode,得到最终的hidden state,这里采用了单层的LSTM网络,传递sequence_length作为掩码,去除padding的干扰,提高训练速度

    def build_encoder(self, encode_emb, length, train=True):batch_size = self.batch_size if train else 1with tf.variable_scope('encoder'):cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=self.num_units)initial_state = cell.zero_state(batch_size, tf.float32)_, final_state = tf.nn.dynamic_rnn(cell, encode_emb, initial_state=initial_state, sequence_length=length)return initial_state, final_state

6、构建decoder

需要分别建立两个decoder,代码是一样的,也采用了单层的LSTM网络,然后对输出进行一次全连接,得到logits,再进行softmax分类。需要注意这里w,b两个deocoder是共享的,得到预测输出

    def build_decoder(self, decode_emb, length, state, scope='decoder', reuse=False):with tf.variable_scope(scope):cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=self.num_units)outputs, final_state = tf.nn.dynamic_rnn(cell, decode_emb, initial_state=state, sequence_length=length)x = tf.reshape(outputs, [-1, self.num_units])w, b = self.soft_max_variable(self.num_units, len(self.vocab), reuse=reuse)logits = tf.matmul(x, w) + bprediction = tf.nn.softmax(logits, name='predictions')return logits, prediction, final_state

7、构建损失网络

这里用soft_max_entropy_with_logits进行交叉熵计算并进行softmax操作

    def build_loss(self, logits, targets, scope='loss'):with tf.variable_scope(scope):y_one_hot = tf.one_hot(targets, len(self.vocab))y_reshaped = tf.reshape(y_one_hot, [-1, len(self.vocab)])loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_reshaped))return loss

8、构建优化网络

加上梯度剪切防止梯度爆炸,进行最小化损失优化

    def build_optimizer(self, loss, scope='optimizer'):with tf.variable_scope(scope):grad_clip = 5# 使用clipping gradientstvars = tf.trainable_variables()grads, _ = tf.clip_by_global_norm(tf.gradients(loss, tvars), grad_clip)train_op = tf.train.AdamOptimizer(self.learning_rate)optimizer = train_op.apply_gradients(zip(grads, tvars))return optimizer

9、训练

encoder-decoder模型中,encoder的initial state应该为上一个docoder的final state,这里用post deocoder的final state作为输入,进行训练,具体代码可以github上看到,这里就不贴了

10.生成结果

辰南与无尽星光闪耀
他的身体在刹那间变大
在刹那间他们感觉到了一股强大的窒息感
一声大喝
在这片区域后
他们似乎已经不是他的一个世界
这是一片奇异的世界
这里是葬天之所
他们知道这些人都知道
但是却没有人敢与我一战

这里训练了25个循环,耗时5个小时,效果并不是特别好,增加训练循环次数或者将decoder的网络层数适量增加或许会有更好的效果

三、其他

具体代码可以在我的github上找到:https://github.com/lpty/tensorflow_tutorial

基于Skip-Thought的Sentence2Vec神经网络实现相关推荐

  1. 基于skip_thoughts vectors 的sentence2vec神经网络实现

    1.论文摘要 我们描述了一种通用.分布式句子编码器的无监督学习方法.使用从书籍中提取的连续文本,我们训练了一个编码器-解码器模型,试图重建编码段落周围的句子.语义和语法属性一致的句子因此被映射到相似的 ...

  2. 基于PaddlePaddle构建ResNet18残差神经网络的食物图片分类问题

    基于PaddlePaddle构建ResNet18残差神经网络的食物图片分类问题 Introduction 本项目是在李宏毅机器学习课程的作业3进行的工作,任务是手动搭建一个CNN模型进行食物图片分类( ...

  3. 从LSTM到GRU基于门控的循环神经网络总结

    1.概述 为了改善基本RNN的长期依赖问题,一种方法是引入门控机制来控制信息的累积速度,包括有选择性地加入新的信息,并有选择性遗忘之前累积的信息.下面主要介绍两种基于门控的循环神经网络:长短时记忆网络 ...

  4. 【遗传优化BP网络】基于自适应遗传算法的BP神经网络的股票预测MATLAB仿真

    1.软件版本 MATLAB2021a 2.本算法理论知识 通过MATLAB对BP神经网络,基于遗传优化的BP神经网络,基于改进遗传优化的BP神经网络以及基于改进遗传优化的组合BP神经网络等多种算法的股 ...

  5. DL之DNN:基于自定义数据集利用深度神经网络(输入层(10个unit)→2个隐藏层(10个unit)→输出层1个unit)实现回归预测实现代码

    DL之DNN:基于自定义数据集利用深度神经网络(输入层(10个unit)→2个隐藏层(10个unit)→输出层1个unit)实现回归预测实现代码 目录 基于自定义数据集利用深度神经网络(输入层(10个 ...

  6. MAT之GRNN/PNN:基于GRNN、PNN两神经网络实现并比较鸢尾花(iris数据集)种类识别正确率、各个模型运行时间对比

    MAT之GRNN/PNN:基于GRNN.PNN两神经网络实现并比较鸢尾花(iris数据集)种类识别正确率.各个模型运行时间对比 目录 输出结果 实现代码 输出结果 实现代码 load iris_dat ...

  7. IDRLnet: 基于内嵌物理知识神经网络的开源求解框架

    " 点击蓝字 / 关注我们 " 编者按 为解决飞行器设计优化过程中物理场快速仿真问题和运行监测阶段物理场精确反演问题,国防科技创新研究院无人系统技术研究中心智能设计与鲁棒学习团队推 ...

  8. 基于S函数的BP神经网络PID控制器及simulink仿真

    基于S函数的BP神经网络PID控制器及simulink仿真 文章目录 文章来源和摘要 S函数的编写格式和运行步骤 simulink模型结构 S函数模型初始化部分代码理解 S函数模型更新部分 S函数模型 ...

  9. 卷积神经网络训练准确率突然下降_基于联邦学习和卷积神经网络的入侵检测方法...

    王蓉1,马春光2,武朋2 1. 哈尔滨工程大学计算机科学与技术学院,哈尔滨 150001:2. 山东科技大学计算机科学与工程学院,青岛 266590 doi :10.3969/j.issn.1671- ...

  10. GAPNet基于图注意力的点神经网络用于局域特征探索

    2篇GAPNet基于图注意力的点神经网络用于局域特征探索 转载文章 https://blog.csdn.net/u014636245/article/details/90478608 版权声明:本文为 ...

最新文章

  1. MongoDB(课时18 修改器)
  2. 全面介绍Windows内存管理机制及C++内存分配实例(一):进程空间
  3. wetool个人版_个人版wetool -公众号
  4. 思科智能交换机受多个严重漏洞影响
  5. discuz子导航下面的版块只有图标没有标题,什么原因?
  6. maven 本地仓库的配置
  7. 用ubantu14.04登录吉大校园网
  8. 甲骨文再传裁员,补偿N+6,昔日硅谷巨头缘何败走中国
  9. python-利用python写一个购物小程序
  10. wav文件隐写:Deepsound+TIFF图片PS处理( AntCTF x D^3CTF 2022 misc BadW3ter)
  11. 计算机网络提升培训心得体会,计算机网络培训心得体会.doc
  12. DAO设计模式之禅之数据库万能查询操作
  13. 2021年“上海区块链周”参会随感(二)2021-04-12
  14. dataframe排序 pd.rank()
  15. 软工网络15个人作业3——案例分析
  16. 【小程序】滚动到指定位置
  17. qlv转mp4格式工厂失败
  18. 利用 Amazon IoT Greengrass 在边缘 DIY 自动浇花系统
  19. Android 开发之漫漫长途 XIV——ListView
  20. 定调!深度解读央行DC/EP数字货币在28省市深化试点背后的逻辑

热门文章

  1. 用计算机用u盘怎么切换,更换电脑硬盘后如何用U盘重装系统?
  2. BurpWeb安全学院之不安全的反序列化
  3. Linux系统中的FTP服务配置与管理
  4. WebGoat (A5) Broken Access Control -- Missing Function Level Access Control (缺少功能级访问控制)
  5. 元月元日是哪一天_元日指的是哪一天?
  6. 实战绕过两层waf完成sql注入
  7. JetsonNano学习(二)环境配置
  8. Java中流的分类有哪些
  9. java企业微信发送语言_java微信企业号开发之发送消息(文本、图片、语音)
  10. 积分图实现快速均值滤波