机器学习训练营——机器学习爱好者的自由交流空间(入群联系qq:2279055353)

这个例子显示scikit-learn怎样进行OOC(out-of-core)分类。所谓核外方法(OOC approach), 指的是从未经内存的数据学习。在这里,我们利用一个支持partial_fit方法的在线分类器学习。为了确保特征空间在不同的时刻仍是相同的,我们利用HashingVectorizer, 它把每个例子投射到相同的特征空间。这在文本分类,当新特征出现在每一批次的例子里时,是特别有用的。本例使用的数据集Reuters-21578来自UCI机器学习数据库(UCI ML repository).

图结果反映了分类器的学习曲线:分类准确率随着最小批量的变化情况。这里的分类准确率是在前1000个样本组成的验证集上测量的。为了限制内存消耗,在把例子送进学习器之前,我们把它们排队成固定的数量。

数据集介绍

Reuters-21578文本类别数据集,收集自1987年的路透社(英国最大的通讯社)文档,这些文档被按类别编目。该数据集整理后包括21,578个实例,5个属性。这5个属性分别表示5份文件,列出了每个文档里所有合法目录的名字。

实例详解

首先,加载必需的库。

# Authors: Eustache Diemert <eustache@diemert.fr>
#          @FedericoV <https://github.com/FedericoV/>
# License: BSD 3 clausefrom __future__ import print_function
from glob import glob
import itertools
import os.path
import re
import tarfile
import time
import sysimport numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParamsfrom sklearn.externals.six.moves import html_parser
from sklearn.externals.six.moves.urllib.request import urlretrieve
from sklearn.datasets import get_data_home
from sklearn.feature_extraction.text import HashingVectorizer
from sklearn.linear_model import SGDClassifier
from sklearn.linear_model import PassiveAggressiveClassifier
from sklearn.linear_model import Perceptron
from sklearn.naive_bayes import MultinomialNBdef _not_in_sphinx():# Hack to detect whether we are running by the sphinx builderreturn '__file__' in globals()

Reuters数据集有关程序

类 ReutersParser

定义一个ReutersParser类,它是一个工具类,作用是从语法上分割SGML(标准通用标记语言,定义独立于平台和应用的文本文档格式)文件,一次产生一个文档。

class ReutersParser(html_parser.HTMLParser):"""Utility class to parse a SGML file and yield documents one at a time."""def __init__(self, encoding='latin-1'):html_parser.HTMLParser.__init__(self)self._reset()self.encoding = encodingdef handle_starttag(self, tag, attrs):method = 'start_' + taggetattr(self, method, lambda x: None)(attrs)def handle_endtag(self, tag):method = 'end_' + taggetattr(self, method, lambda: None)()def _reset(self):self.in_title = 0self.in_body = 0self.in_topics = 0self.in_topic_d = 0self.title = ""self.body = ""self.topics = []self.topic_d = ""def parse(self, fd):self.docs = []for chunk in fd:self.feed(chunk.decode(self.encoding))for doc in self.docs:yield docself.docs = []self.close()def handle_data(self, data):if self.in_body:self.body += dataelif self.in_title:self.title += dataelif self.in_topic_d:self.topic_d += datadef start_reuters(self, attributes):passdef end_reuters(self):self.body = re.sub(r'\s+', r' ', self.body)self.docs.append({'title': self.title,'body': self.body,'topics': self.topics})self._reset()def start_title(self, attributes):self.in_title = 1def end_title(self):self.in_title = 0def start_body(self, attributes):self.in_body = 1def end_body(self):self.in_body = 0def start_topics(self, attributes):self.in_topics = 1def end_topics(self):self.in_topics = 0def start_d(self, attributes):self.in_topic_d = 1def end_d(self):self.in_topic_d = 0self.topics.append(self.topic_d)self.topic_d = ""

函数 stream_reuters_documents

函数stream_reuters_documents迭代Reuters数据集文档。它有一个参数data_path, 表示数据集的本地位置路径,默认值是None. 如果该参数取默认值,那么将自动从UCI数据库下载并解压缩。文档用字典型对象表示,有三个键:‘body’, ‘title’ and ‘topics’.

def stream_reuters_documents(data_path=None):"""Iterate over documents of the Reuters dataset.The Reuters archive will automatically be downloaded and uncompressed ifthe `data_path` directory does not exist.Documents are represented as dictionaries with 'body' (str),'title' (str), 'topics' (list(str)) keys."""DOWNLOAD_URL = ('http://archive.ics.uci.edu/ml/machine-learning-databases/''reuters21578-mld/reuters21578.tar.gz')ARCHIVE_FILENAME = 'reuters21578.tar.gz'if data_path is None:data_path = os.path.join(get_data_home(), "reuters")if not os.path.exists(data_path):"""Download the dataset."""print("downloading dataset (once and for all) into %s" %data_path)os.mkdir(data_path)def progress(blocknum, bs, size):total_sz_mb = '%.2f MB' % (size / 1e6)current_sz_mb = '%.2f MB' % ((blocknum * bs) / 1e6)if _not_in_sphinx():sys.stdout.write('\rdownloaded %s / %s' % (current_sz_mb, total_sz_mb))archive_path = os.path.join(data_path, ARCHIVE_FILENAME)urlretrieve(DOWNLOAD_URL, filename=archive_path,reporthook=progress)if _not_in_sphinx():sys.stdout.write('\r')print("untarring Reuters dataset...")tarfile.open(archive_path, 'r:gz').extractall(data_path)print("done.")parser = ReutersParser()for filename in glob(os.path.join(data_path, "*.sgm")):for doc in parser.parse(open(filename, 'rb')):yield doc

主程序

主程序产生向量对象,限制特征数的上限到一个合理的值。

vectorizer 对象

vectorizer = HashingVectorizer(decode_error='ignore', n_features=2 ** 18,alternate_sign=False)# Iterator over parsed Reuters SGML files.
data_stream = stream_reuters_documents()# We learn a binary classification between the "acq" class and all the others.
# "acq" was chosen as it is more or less evenly distributed in the Reuters
# files. For other datasets, one should take care of creating a test set with
# a realistic portion of positive instances.
all_classes = np.array([0, 1])
positive_class = 'acq'# Here are some classifiers that support the `partial_fit` method
partial_fit_classifiers = {'SGD': SGDClassifier(max_iter=5),'Perceptron': Perceptron(tol=1e-3),'NB Multinomial': MultinomialNB(alpha=0.01),'Passive-Aggressive': PassiveAggressiveClassifier(tol=1e-3),
}

函数 get_minibatch

函数get_minibatch规定提取实例的最小批次数量,返回一个元组对象X_text, y.

def get_minibatch(doc_iter, size, pos_class=positive_class):"""Extract a minibatch of examples, return a tuple X_text, y.Note: size is before excluding invalid docs with no topics assigned."""data = [(u'{title}\n\n{body}'.format(**doc), pos_class in doc['topics'])for doc in itertools.islice(doc_iter, size)if doc['topics']]if not len(data):return np.asarray([], dtype=int), np.asarray([], dtype=int)X_text, y = zip(*data)return X_text, np.asarray(y, dtype=int)

函数 iter_minibatches

函数iter_minibatches作为一个最小批次生成器。

def iter_minibatches(doc_iter, minibatch_size):"""Generator of minibatches."""X_text, y = get_minibatch(doc_iter, minibatch_size)while len(X_text):yield X_text, yX_text, y = get_minibatch(doc_iter, minibatch_size)

检验统计量

产生检验统计量,检验1000个文档,估计准确率。

# test data statistics
test_stats = {'n_test': 0, 'n_test_pos': 0}# First we hold out a number of examples to estimate accuracy
n_test_documents = 1000
tick = time.time()
X_test_text, y_test = get_minibatch(data_stream, 1000)
parsing_time = time.time() - tick
tick = time.time()
X_test = vectorizer.transform(X_test_text)
vectorizing_time = time.time() - tick
test_stats['n_test'] += len(y_test)
test_stats['n_test_pos'] += sum(y_test)
print("Test set is %d documents (%d positive)" % (len(y_test), sum(y_test)))

函数 progress

函数progress报告检验的过程信息,返回一个字符串。

def progress(cls_name, stats):"""Report progress information, return a string."""duration = time.time() - stats['t0']s = "%20s classifier : \t" % cls_names += "%(n_train)6d train docs (%(n_train_pos)6d positive) " % statss += "%(n_test)6d test docs (%(n_test_pos)6d positive) " % test_statss += "accuracy: %(accuracy).3f " % statss += "in %.2fs (%5d docs/s)" % (duration, stats['n_train'] / duration)return s
cls_stats = {}for cls_name in partial_fit_classifiers:stats = {'n_train': 0, 'n_train_pos': 0,'accuracy': 0.0, 'accuracy_history': [(0, 0)], 't0': time.time(),'runtime_history': [(0, 0)], 'total_fit_time': 0.0}cls_stats[cls_name] = stats
get_minibatch(data_stream, n_test_documents)

丢弃检验集

我们将1000个文档的最小批次送进分类器学习,这意味着,在任何时候内存至多有1000个文档。文档批次规模越小,偏拟合方法的相对间接消耗就越大。

# We will feed the classifier with mini-batches of 1000 documents; this means
# we have at most 1000 docs in memory at any time.  The smaller the document
# batch, the bigger the relative overhead of the partial fit methods.
minibatch_size = 1000# Create the data_stream that parses Reuters SGML files and iterates on
# documents as a stream.
minibatch_iterators = iter_minibatches(data_stream, minibatch_size)
total_vect_time = 0.0

主循环

在主循环里,迭代分类一个最小批次的例子。

# Main loop : iterate on mini-batches of examples
for i, (X_train_text, y_train) in enumerate(minibatch_iterators):tick = time.time()X_train = vectorizer.transform(X_train_text)total_vect_time += time.time() - tickfor cls_name, cls in partial_fit_classifiers.items():tick = time.time()# update estimator with examples in the current mini-batchcls.partial_fit(X_train, y_train, classes=all_classes)# accumulate test accuracy statscls_stats[cls_name]['total_fit_time'] += time.time() - tickcls_stats[cls_name]['n_train'] += X_train.shape[0]cls_stats[cls_name]['n_train_pos'] += sum(y_train)tick = time.time()cls_stats[cls_name]['accuracy'] = cls.score(X_test, y_test)cls_stats[cls_name]['prediction_time'] = time.time() - tickacc_history = (cls_stats[cls_name]['accuracy'],cls_stats[cls_name]['n_train'])cls_stats[cls_name]['accuracy_history'].append(acc_history)run_history = (cls_stats[cls_name]['accuracy'],total_vect_time + cls_stats[cls_name]['total_fit_time'])cls_stats[cls_name]['runtime_history'].append(run_history)if i % 3 == 0:print(progress(cls_name, cls_stats[cls_name]))if i % 3 == 0:print('\n')

程序部分输出:

可视化结果

def plot_accuracy(x, y, x_legend):"""Plot accuracy as a function of x."""x = np.array(x)y = np.array(y)plt.title('Classification accuracy as a function of %s' % x_legend)plt.xlabel('%s' % x_legend)plt.ylabel('Accuracy')plt.grid(True)plt.plot(x, y)rcParams['legend.fontsize'] = 10
cls_names = list(sorted(cls_stats.keys()))# Plot accuracy evolution
plt.figure()
for _, stats in sorted(cls_stats.items()):# Plot accuracy evolution with #examplesaccuracy, n_examples = zip(*stats['accuracy_history'])plot_accuracy(n_examples, accuracy, "training examples (#)")ax = plt.gca()ax.set_ylim((0.8, 1))
plt.legend(cls_names, loc='best')plt.figure()
for _, stats in sorted(cls_stats.items()):# Plot accuracy evolution with runtimeaccuracy, runtime = zip(*stats['runtime_history'])plot_accuracy(runtime, accuracy, 'runtime (s)')ax = plt.gca()ax.set_ylim((0.8, 1))
plt.legend(cls_names, loc='best')# Plot fitting times
plt.figure()
fig = plt.gcf()
cls_runtime = []
for cls_name, stats in sorted(cls_stats.items()):cls_runtime.append(stats['total_fit_time'])cls_runtime.append(total_vect_time)
cls_names.append('Vectorization')
bar_colors = ['b', 'g', 'r', 'c', 'm', 'y']ax = plt.subplot(111)
rectangles = plt.bar(range(len(cls_names)), cls_runtime, width=0.5,color=bar_colors)ax.set_xticks(np.linspace(0.25, len(cls_names) - 0.75, len(cls_names)))
ax.set_xticklabels(cls_names, fontsize=10)
ymax = max(cls_runtime) * 1.2
ax.set_ylim((0, ymax))
ax.set_ylabel('runtime (s)')
ax.set_title('Training Times')def autolabel(rectangles):"""attach some text vi autolabel on rectangles."""for rect in rectangles:height = rect.get_height()ax.text(rect.get_x() + rect.get_width() / 2.,1.05 * height, '%.4f' % height,ha='center', va='bottom')autolabel(rectangles)
plt.show()# Plot prediction times
plt.figure()
cls_runtime = []
cls_names = list(sorted(cls_stats.keys()))
for cls_name, stats in sorted(cls_stats.items()):cls_runtime.append(stats['prediction_time'])
cls_runtime.append(parsing_time)
cls_names.append('Read/Parse\n+Feat.Extr.')
cls_runtime.append(vectorizing_time)
cls_names.append('Hashing\n+Vect.')ax = plt.subplot(111)
rectangles = plt.bar(range(len(cls_names)), cls_runtime, width=0.5,color=bar_colors)ax.set_xticks(np.linspace(0.25, len(cls_names) - 0.75, len(cls_names)))
ax.set_xticklabels(cls_names, fontsize=8)
plt.setp(plt.xticks()[1], rotation=30)
ymax = max(cls_runtime) * 1.2
ax.set_ylim((0, ymax))
ax.set_ylabel('runtime (s)')
ax.set_title('Prediction Times (%d instances)' % n_test_documents)
autolabel(rectangles)
plt.show()

精彩内容,请关注微信公众号:统计学习与大数据

【Python实例第11讲】文本的核外分类相关推荐

  1. 【Python基础】11、文本处理与IO深入理解

    1.有一个文件,单词之间使用空格.分号.逗号.或者句号分隔,请提取全部单词. 解决方案: 使用\w匹配并提取单词,但是存在误判 使用str.split分隔字符字符串,但是需要多次分隔 使用re.spl ...

  2. Python实例:11~20例

    例11:打印出所有的"水仙花数",所谓"水仙花数"是指一个三位数,其各位数字立方和等于该数本身.例如:153是一个"水仙花数",因为153= ...

  3. shell实例第11讲:取出系统IP地址,并判断属于哪个网段

    取出系统IP地址,并判断属于哪个网段 #!/bin/bash #作者:魏波 #时间:2017.02.04ip=`ifconfig -a | grep inet | grep -v 127.0.0.1 ...

  4. Python实例讲解 -- wxpython 基本的控件 (文本)

    使用基本的控件工作 wxPython 工具包提供了多种不同的窗口部件,包括了本章所提到的基本控件.我们涉及静态文本.可编辑的文本.按钮.微调.滑块.复选框.单选按钮.选择器.列表框.组合框和标尺.对于 ...

  5. 【NLP】Python实例:基于文本相似度对申报项目进行查重设计

    Python实例:申报项目查重系统设计与实现 作者:白宁超 2017年5月18日17:51:37 摘要:关于查重系统很多人并不陌生,无论本科还是硕博毕业都不可避免涉及论文查重问题,这也对学术不正之风起 ...

  6. 如何用Python和BERT做中文文本二元分类?| 程序员硬核评测

    点击上方↑↑↑蓝字关注我们~ 「2019 Python开发者日」全日程揭晓,请扫码咨询 ↑↑↑ 作者 | 王树义 来源 | 王树芝兰(ID:nkwangshuyi) 兴奋 去年, Google 的 B ...

  7. Python零基础速成班-第11讲-Python日志Logging,小游戏设计game of life

    Python零基础速成班-第11讲-Python日志Logging,小游戏设计game of life 学习目标 Python日志Logging 小游戏设计game of life 课后作业(2必做) ...

  8. Python实例10:文本词频统计

    Python实例10:文本词频统计 6.6.1 问题分析 在英文中文中,出现哪些词,出现多少次? 6.6.2 hamlet英文词频统计 CalHamletV1.py 6.6.3 三国演义人物出场统计 ...

  9. 【数据分析师-python基础】python基础语法精讲

    python基础语法精讲 1 从数字开始 1.1 理解整数.浮点数.复数几种类型对象 1.2 掌握运算及其相关的常用函数 2 变量.表达式和语句 2.1 变量作用及定义的方法 2.2 变量命名原则和习 ...

  10. Python实战之字符串和文本处理

    写在前面 博文为<Python Cookbook>读书后笔记整理 涉及内容包括: 使用多个界定符分割字符串 字符串开头或结尾匹配,用Shell通配符匹配字符串 字符串匹配和搜索和替换(忽略 ...

最新文章

  1. vue-router2路由参数注意问题
  2. 部署laravel项目
  3. Linux下的grep命令
  4. linux防火墙查看被动模式,Centos7搭建vsftpd及被动模式下的防火墙设置
  5. Nacos 发布0.3.0版本,迄今为止最好看的版本
  6. Java8新特性总结 - 4.方法引用
  7. SQL Server 2008 评估期已过解决方法
  8. Java连接程序数据源
  9. iis7.5 php 404.17,部署IISHTTP 404.17无法由静态文件处理程序来处理
  10. SCVMM 2012 部署测试之五向SCVMM中添加Hyper-V主机
  11. 研磨设计模式笔记之简单工厂模式
  12. canoco5主成分分析步骤_主成分分析(PCA)统计与MATLAB函数实现
  13. 虚拟服务器共享文件设置,虚拟机共享文件夹设置流程
  14. 几种开源的网络流量监控软件
  15. [Learn Android Studio 汉化教程]Reminders实验:第一部分(续)
  16. 线下盛会|欢迎关注 Pulsar Summit 2022 旧金山峰会
  17. php 公众号 欢迎,关注公众号的欢迎语怎么设置?公众号欢迎语怎么加链接?
  18. APE文件学习——文件头(1)
  19. VOS3000怎样给对接网关设置按主叫号码计费
  20. 语音识别英语_英语语音识别_英语 语音识别 - 云+社区 - 腾讯云

热门文章

  1. 每日英语:A Chinese Soccer Club Has Won Something!
  2. 数据库优化常用的途径(方法)
  3. Security+ 学习笔记41 安全网络技术
  4. KVM详解(一)——KVM基础知识
  5. C++程序设计(三:可视化)
  6. 运维之Linux秋招重点(根据面经和常见笔试题总结,持续更新)
  7. naivcat 破解安装教程(永久)
  8. 程序执行的过程分析--【sky原创】
  9. 转载:世上最全的百物妙用窍门-绝对不能错过,不断更新中
  10. eWebEditor浏览器兼容 ie8 ie7