本文作者:合肥工业大学 管理学院 钱洋 email:1563178220@qq.com 内容可能有不到之处,欢迎交流。

未经本人允许禁止转载。

训练文档向量

在上一小节中,本人介绍了使用DeepLearning4J训练得到词向量(https://blog.csdn.net/qy20115549/article/details/82152462)。本篇主要介绍给定任意文本数据(分词后的数据),如何使用DeepLearning4J训练得到文档的向量。

在平时的使用中,我们可以将文档转化成向量形式,进而进行聚类分类等其他操作。常用的将文档转化成向量形式的方法有one-hot编码、TF-IDF编码、主题模型(LDA)以及本篇要介绍的Doc2Vec操作。如下为笔者使用的文本数据:


对应的操作程序如下:

package org.deeplearning4j.examples.nlp.paragraphvectors;import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
import org.deeplearning4j.text.documentiterator.LabelsSource;
import org.deeplearning4j.text.sentenceiterator.LineSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;public class Doc2VecTest {private static Logger log = LoggerFactory.getLogger(Doc2VecTest.class);//文档向量输出路径private static String outputPath = "E:/doc2vec.txt";public static void main(String[] args) throws Exception {//输入文本文件的目录File inputTxt = new File("E:/raw_sentences.txt");log.info("开始加载数据...."+inputTxt.getName());//加载数据SentenceIterator iter = new LineSentenceIterator(inputTxt);//切词操作TokenizerFactory token = new DefaultTokenizerFactory();//去除特殊符号及大小写转换操作token.setTokenPreProcessor(new CommonPreprocessor());AbstractCache<VocabWord> cache=new AbstractCache<>();//添加文档标签,这个一般从文件读取,为了方面我这里使用了数字List<String> labelList = new ArrayList<String>();for (int i = 0; i < 97162; i++) {labelList.add("doc"+i);}//设置文档标签LabelsSource source = new LabelsSource(labelList);log.info("训练模型....");ParagraphVectors vec = new ParagraphVectors.Builder().minWordFrequency(1).iterations(5).epochs(1).layerSize(100).learningRate(0.025).labelsSource(source).windowSize(5).iterate(iter).trainWordVectors(false).vocabCache(cache).tokenizerFactory(token).sampling(0).build();vec.fit();log.info("相似的句子:");Collection<String> lst = vec.wordsNearest("doc0", 10);System.out.println(lst);log.info("输出文档向量....");WordVectorSerializer.writeWordVectors(vec, outputPath);//获取某词对应的向量log.info("向量获取:");double[] docVector = vec.getWordVector("doc0");System.out.println(Arrays.toString(docVector));}
}

程序在控制台输出的结果为:


另外,我们也成功的将每篇文档对应的向量输出到本地文件中,如下图所示为每篇文档对应的向量结果:


改写输出方式

另外,我们也可以写一个操作方法,是的输出结果按照每个人的需求来,比如我个人的需求是:

文档内容 Tab键分割 向量内容

那么,上述的程序可以重写为:

package com.qian;import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
import org.deeplearning4j.text.documentiterator.LabelsSource;
import org.deeplearning4j.text.sentenceiterator.LineSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;public class Doc2VecTest {private static Logger log = LoggerFactory.getLogger(Doc2VecTest.class);//文档向量输出路径private static String outputPath = "data/doc2vec.txt";private static String inputPath = "data/raw_sentences.txt";public static void main(String[] args) throws Exception {//输入文本文件的目录File inputTxt = new File(inputPath);log.info("开始加载数据...." + inputTxt.getName());//加载数据SentenceIterator iter = new LineSentenceIterator(inputTxt);//切词操作TokenizerFactory token = new DefaultTokenizerFactory();//去除特殊符号及大小写转换操作token.setTokenPreProcessor(new CommonPreprocessor());AbstractCache<VocabWord> cache=new AbstractCache<>();//添加文档标签,这个一般从文件读取,为了方面我这里使用了数字List<String> labelList = new ArrayList<String>();for (int i = 1; i < 97163; i++) {labelList.add("doc"+i);}//设置文档标签LabelsSource source = new LabelsSource(labelList);log.info("训练模型....");ParagraphVectors vec = new ParagraphVectors.Builder().minWordFrequency(1).iterations(5).epochs(1).layerSize(50).learningRate(0.025).labelsSource(source).windowSize(5).iterate(iter).trainWordVectors(false).vocabCache(cache).tokenizerFactory(token).sampling(0).build();vec.fit();log.info("相似的句子:");Collection<String> lst = vec.wordsNearest("doc1", 10);System.out.println(lst);log.info("输出文档向量....");writeDocVectors(vec,outputPath);//获取某词对应的向量log.info("向量获取:");double[] docVector = vec.getWordVector("doc1");System.out.println(Arrays.toString(docVector));}public static void writeDocVectors(ParagraphVectors vectors, String outpath) throws IOException {//写操作BufferedWriter bufferedWriter = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(new File(outpath)),"gbk"));//读操作BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(new FileInputStream(new File(inputPath)), "gbk"));String line = null;int i = 1;Map<String, String> keyToDoc = new HashMap<>();while ((line = bufferedReader.readLine())!=null) {keyToDoc.put("doc" + i, line);i++;}VocabCache<VocabWord> vocabCache = vectors.getVocab();for (VocabWord word : vocabCache.vocabWords()) {StringBuilder builder = new StringBuilder();//获取每个文档对应的标签INDArray vector = vectors.getWordVectorMatrix(word.getLabel());//向量添加for (int j = 0; j < vector.length(); j++) {builder.append(vector.getDouble(j));if (j < vector.length() - 1) {builder.append(" ");}}//写入指定文件bufferedWriter.write(keyToDoc.get(word.getLabel()) + "\t" + builder.append("\n").toString());}bufferedWriter.close();bufferedReader.close();}
}

程序的输出结果如下图所示,前面是文档后面是文档对应的向量:


使用deeplearning4j训练Doc2Vec(文档向量)相关推荐

  1. doc2vec 文档向量

    目录 1 目的和思想 2 模型原理 2.1 PV-DM(段落向量的分布式存储模型) 2.2 PV-DBOW (段落向量的分布式单词包版本) 3 doc2vec 总结 4 应用任务 1 目的和思想 do ...

  2. 【Python自然语言处理】文本向量化的六种常见模型讲解(独热编码、词袋模型、词频-逆文档频率模型、N元模型、单词-向量模型、文档-向量模型)

    觉得有帮助请点赞关注收藏~~~ 一.文本向量化 文本向量化:将文本信息表示成能够表达文本语义的向量,是用数值向量来表示文本的语义. 词嵌入(Word Embedding):一种将文本中的词转换成数字向 ...

  3. 基于gensim的Deep learning with paragraph2vec 官方models.doc2vec文档解释

    ♦版权声明:转载时请注明出处URL,谢谢大家~ ♦文章声明:博主为在校生,基于学习兴趣作此文章,与大家分享.水平有限,恳请大家批评指正~ gensim介绍 [官网] gensim是一款强大的自然语言处 ...

  4. Keras深度学习实战(26)——文档向量详解

    Keras深度学习实战(26)--文档向量详解 0. 前言 1. 文档向量基本概念 2. 神经网络模型与数据集分析 2.1 模型分析 2.2 数据集介绍 3. 利用 Keras 构建神经网络模型生成文 ...

  5. 中文自然语言处理--文档向量Doc2Vec

    Doc2Vec 是 Mikolov 在 Word2Vec 基础上提出的另一个用于计算长文本向量的工具,Doc2Vec 将文档语料通过一个固定长度的向量表达. from gensim.models.do ...

  6. torch dataloader 数据并行_PyTorch Parallel Training(单机多卡并行、混合精度、同步BN训练指南文档)

    0 写在前面 这篇文章是我做实验室组会汇报的时候顺带整理的文档,在1-3部分参考了很多知乎文章,感谢这些大佬们的工作,所以先贴出Reference,本篇文章结合了这些内容,加上了我的一些理解,不足之处 ...

  7. WMD:基于词向量的文档相似度计算

    EMD算法简介 该部分引用自[1] Earth Mover's Distance (EMD),和欧氏距离一样,他们都是一种距离度量的定义,可以用来测量某分布之间的距离.EMD主要应用在图像处理和语音信 ...

  8. NLP︱句子级、词语级以及句子-词语之间相似性(相关名称:文档特征、词特征、词权重)

    每每以为攀得众山小,可.每每又切实来到起点,大牛们,缓缓脚步来俺笔记葩分享一下吧,please~ --------------------------- 关于相似性以及文档特征.词特征有太多种说法.弄 ...

  9. 多标签文本分类数据集_标签感知的文档表示用于多标签文本分类(EMNLP 2019)...

    原文: Label-Specific Document Representation for Multi-Label Text Classification(EMNLP 2019) 多标签文本分类 摘要: ...

最新文章

  1. linux上使用crontab任务调度
  2. 学python要多少钱-Python培训一般要多少钱?
  3. JMeter学习笔记--JMeter监听器
  4. 2.11 总结-深度学习第二课《改善深层神经网络》-Stanford吴恩达教授
  5. if else 简写_15+ JS简写骚操作,让你的代码“秀”起来??
  6. php 中如何重载父类的方法_PHP中子类重载父类的方法【parent::方法名】
  7. MyBatis中ThreadLocal
  8. 天津计算机本科学校有哪些专业吗,天津哪些大学有人工智能专业
  9. Fragment懒加载(三)
  10. 使用SimpleDateFormat出现时差
  11. wxpython 基本的控件 (文本)
  12. Java 8新特性探究(十一)Base64详解
  13. 新手如何在CSDN上写博客
  14. 认同和确定性矩阵(Ralph Stacey's Agreement and Certainty Matrix)-译
  15. 使用最新TexLive2020+VsCode来编写Latex论文(假设使用:CVPR2018)
  16. 基于javaee的养老保险管理系统
  17. 金蝶K/3产品各版本引入/引出Excel文件时出现意外错误的提示,或未正确安装Excel的提示。微软补丁解决方案!(转)
  18. Unity根据文字内容自动滚动显示最新文字
  19. Codeforces Global Round 7 E. Bombs(线段树)
  20. Python:OpenCV4识别一个蓝色的圆并估算到相机的距离

热门文章

  1. template 模板是怎样通过 Compile 编译的
  2. 面试官系统精讲Java源码及大厂真题 - 47 工作实战:Socket 结合线程池的使用
  3. Docker安装Kafka(docker-compose.yml)
  4. ZooKeeper配额指南
  5. 解决硬盘文件目录损坏且无法读取
  6. 个推成为首家支持统一推送接口标准的第三方推送服务商!
  7. Windows环境下多个tomcat启动,CATALINA_HOME配置(大坑)
  8. 函数式编程 -- 函子(Functor)
  9. 【Python】GUI编程(Tkinter)教程
  10. 【Python】Python库之数据可视化