1.20 Newsgroup数据集介绍

20newsgroups数据集是用于文本分类、文本挖据和信息检索研究的国际标准数据集之一。数据集收集了大约20,000左右的新闻组文档,均匀分为20个不同主题的新闻组集合。一些新闻组的主题特别相似(e.g. comp.sys.ibm.pc.hardware/ comp.sys.mac.hardware),还有一些却完全不相关 (e.g misc.forsale /soc.religion.christian)。

comp.graphics

comp.os.ms-windows.misc

comp.sys.ibm.pc.hardware

comp.sys.mac.hardware

comp.windows.x

rec.autos

rec.motorcycles

rec.sport.baseball

rec.sport.hockey

sci.crypt

sci.electronics

sci.med

sci.space

misc.forsale

talk.politics.misc

talk.politics.guns

talk.politics.mideast

talk.religion.misc

alt.atheism

soc.religion.christian

20newsgroups数据集有三个版本。第一个版本19997是原始的并没有修改过的版本。第二个版本bydate是按时间顺序分为训练(60%)和测试(40%)两部分数据集,不包含重复文档和新闻组名(新闻组,路径,隶属于,日期)。第三个版本18828不包含重复文档,只有来源和主题。

  • 20news-19997.tar.gz –原始20 Newsgroups数据集
  • 20news-bydate.tar.gz –按时间分类; 不包含重复文档和新闻组名(18846 个文档)
  • 20news-18828.tar.gz–  不包含重复文档,只有来源和主题 (18828 个文档)

在sklearn中,该模型有两种装载方式,第一种是sklearn.datasets.fetch_20newsgroups,返回一个可以被文本特征提取器(如sklearn.feature_extraction.text.CountVectorizer)自定义参数提取特征的原始文本序列;第二种是sklearn.datasets.fetch_20newsgroups_vectorized,返回一个已提取特征的文本序列,即不需要使用特征提取器。

2.加载训练好的向量(Glove,100维)

BASE_DIR = './data'
GLOVE_DIR = BASE_DIR + '/glove.6B/'
TEXT_DATA_DIR = BASE_DIR + '/20_newsgroup/'
MAX_SEQUENCE_LENGTH = 1000
MAX_NB_WORDS = 20000
EMBEDDING_DIM = 100
VALIDATION_SPLIT = 0.2
batch_size = 32print('Indexing word vectors.')
embeddings_index = {}
f = open(os.path.join(GLOVE_DIR,'glove.6B.100d.txt'),encoding='utf-8')
for line in f:values = line.split()word = values[0]coefs = np.asarray(values[1:],dtype='float32')embeddings_index[word] = coefs
f.close()print('Found %s word vectors.'%len(embeddings_index) )

3.加载数据集,这里也包括标签,我们是根据文档所在的文件夹用数字进行分类

print('Processing text dataset')texts = []
labels_index = {}
labels = []for name in sorted(os.listdir(TEXT_DATA_DIR)):path = os.path.join(TEXT_DATA_DIR,name)if os.path.isdir(path):label_id = len(labels_index)labels_index[name] = label_id    #每个文件夹给一个IDfor fname in sorted(os.listdir(path)):if fname.isdigit():fpath = os.path.join(path,fname)if sys.version_info<(3,):f = open(fpath)else:f = open(fpath,encoding='latin-1')texts.append(f.read())f.close()labels.append(label_id)
print('Found %s texts.'%len(texts))

4.将文本数据向量化

tokenizer = Tokenizer(nb_words = MAX_NB_WORDS)
tokenizer.fit_on_texts(texts)
sequences = tokenizer.texts_to_sequences(texts)word_index = tokenizer.word_index
print('Found %s unique tokens.'%len(word_index))

5.构造训练集和测试集,这里我们对数据进行了清洗

data = pad_sequences(sequences,maxlen = MAX_SEQUENCE_LENGTH)labels = to_categorical(np.asarray(labels))print('Shape of data tensor:',data.shape)
print('Shape of label tensor:',labels.shape)indices = np.arange(data.shape[0])
np.random.shuffle(indices)
data = data[indices]
labels = labels[indices]
nb_validation_samples = int(VALIDATION_SPLIT*data.shape[0])x_train = data[:-nb_validation_samples]
y_train = labels[:-nb_validation_samples]
x_val = data[-nb_validation_samples:]
y_val = labels[-nb_validation_samples:]print('Preparing embedding matrix.')
print(nb_validation_samples)

6.我们构建了LSTM网络模型,并对其评估

nb_words = min(MAX_NB_WORDS,len(word_index))
embedding_matrix = np.zeros((nb_words +1,EMBEDDING_DIM))
for word,i in word_index.items():if i>MAX_NB_WORDS:continueembedding_vector = embeddings_index.get(word)if embedding_vector is not None:embedding_matrix[i] = embedding_vector
print(embedding_matrix.shape)embedding_layer = Embedding(nb_words+1,EMBEDDING_DIM,weights= [embedding_matrix],input_length = MAX_SEQUENCE_LENGTH,trainable=False,#trainable,由于我们的W是word2vec训练出来的,算作预训练模型,所以就无需训练了。dropout = 0.2
)
batch_size = 32
print('Build model...')model = Sequential()
model.add(embedding_layer)
model.add(LSTM(100,dropout_W = 0.2,dropout_U=0.2))#输出维度 :100
model.add(Dense(1))
model.add(Activation('sigmoid'))
model.add(Dense(len(labels_index),activation='softmax'))model.compile(loss='binary_crossentropy',optimizer='adam',metrics=['accuracy']
)
print('Train...')
model.fit(x_train,y_train,batch_size=batch_size,nb_epoch=5,validation_data=(x_val,y_val))score,acc = model.evaluate(x_val,y_val,batch_size=batch_size)print('Test score:',score)
print('Test sccuracy:',acc)

7.训练结果(部分结果)

3936/3999 [============================>.] - ETA: 0s
3968/3999 [============================>.] - ETA: 0s
3999/3999 [==============================] - 52s 13ms/step
Test score: 0.18472591743048325
Test sccuracy: 0.9499999885411226

Keras LSTM对20 Newsgroups数据集进行分类相关推荐

  1. 基于Jupyter Notebook---卷积神经网络的图像分类(keras对猫狗图像数据集进行分类)

    keras对猫狗图像数据集进行分类 一.安装Keras 二.keras对猫狗图像数据集进行分类 三.全部代码 四.结果显示 一.安装Keras 在windows下,首先添加中科大源 命令行中直接使用以 ...

  2. 20 Newsgroups数据集介绍

    源自如http://qwone.com/~jason/20Newsgroups/. 20newsgroups数据集是用于文本分类.文本挖据和信息检索研究的国际标准数据集之一.数据集收集了大约20,00 ...

  3. keras对猫、狗数据集进行分类(三)

    目录 keras对猫.狗数据集进行分类(一) keras对猫.狗数据集进行分类(二) keras对猫.狗数据集进行分类(三) 数据集:关注下方公众号,发送 猫狗数据集 进行获取

  4. 20 Newsgroups数据集

    原文: The 20 Newsgroups data set is a collection of approximately 20,000 newsgroup documents, partitio ...

  5. Newsgroups数据集介绍

    源自如http://qwone.com/~jason/20Newsgroups/. 20newsgroups数据集是用于文本分类.文本挖据和信息检索研究的国际标准数据集之一.数据集收集了大约20,00 ...

  6. ML之SVM:利用SVM算法(超参数组合进行多线程网格搜索+3fCrVa)对20类新闻文本数据集进行分类预测、评估

    ML之SVM:利用SVM算法(超参数组合进行多线程网格搜索+3fCrVa)对20类新闻文本数据集进行分类预测.评估 目录 输出结果 设计思路 核心代码 输出结果 Fitting 3 folds for ...

  7. ML之SVM:利用SVM算法(超参数组合进行单线程网格搜索+3fCrVa)对20类新闻文本数据集进行分类预测、评估

    ML之SVM:利用SVM算法(超参数组合进行单线程网格搜索+3fCrVa)对20类新闻文本数据集进行分类预测.评估 目录 输出结果 设计思路 核心代码 输出结果 Fitting 3 folds for ...

  8. ML之NB:利用朴素贝叶斯NB算法(TfidfVectorizer+不去除停用词)对20类新闻文本数据集进行分类预测、评估

    ML之NB:利用朴素贝叶斯NB算法(TfidfVectorizer+不去除停用词)对20类新闻文本数据集进行分类预测.评估 目录 输出结果 设计思路 核心代码 输出结果 设计思路 核心代码 class ...

  9. 使用SVM对分泌效应蛋白数据集进行分类预测

    1.SVM简介 支持向量机(Support Vector Machine, SVM)是一类按监督学习(supervised learning)方式对数据进行二元分类的广义线性分类器(generaliz ...

最新文章

  1. Stanford UFLDL教程 卷积特征提取
  2. varnish缓存服务器构建疑问
  3. 国科大prml--SVM
  4. 使用CMake生成sln项目和VS工程遇到的问题
  5. 技本功丨用短平快的方式告诉你:Flink-SQL的扩展实现
  6. 用python写网络爬虫 -从零开始 4 用正则表达式 编写链接爬虫
  7. Pytorch搭建自己的模型
  8. 通过注册表修改我的文档等系统文件夹默认位置
  9. Alexa交叉编译(avs-device-sdk)
  10. python3网络爬虫--爬取华为应用市场app数据(附源码)
  11. 【重点】心田花开:三年级教材知识点汇总
  12. 基于TBSS的DTI数据处理流程
  13. oracle 京东,【京东工资】oracle dba待遇-看准网
  14. SQL Server 2012 安装包
  15. Ubuntu下安装Matlab并破解
  16. F功能键必须按Fn才管用,如何设置不按Fn就直接使用F键功能
  17. WebApi测试工具:PostMan
  18. iOS TabBar中间凸起实践
  19. 阅读APP开发的发展现状
  20. 【Cocos Creator】 摄像机移动碰到的一些问题

热门文章

  1. java接口自动化-post请求获取不到cookie问题解决
  2. BIO NIO AIO 介绍与差别
  3. 飞行堡垒fx80g拆卸电源_华硕飞行堡垒第五代FX80拆机加装内存条教程(整盖翻转拆机)...
  4. php最大的优点,_____是PHP的最大优点
  5. 2017年的端午节祝福语
  6. 写一个判断素数的函数,在主函数输入一个整数,输出是否为素数的信息。
  7. ArcGis配色心得
  8. Eclipse中properties配置文件的中文乱码
  9. 琐碎的知识点(xly)
  10. cocos2d-x基本面试题