文章目录

  • Fasttext
  • 一、文件目录
  • 二、语料集下载地址(本文选择AG)
  • 三、数据处理(AG_Dataset.py)
  • 四、模型(Fasttext.py)
  • 五、训练和测试
  • 实验结果

Fasttext

一、文件目录

二、语料集下载地址(本文选择AG)

AG News: https://s3.amazonaws.com/fast-ai-nlp/ag_news_csv.tgz
DBPedia: https://s3.amazonaws.com/fast-ai-nlp/dbpedia_csv.tgz
Sogou news: https://s3.amazonaws.com/fast-ai-nlp/sogou_news_csv.tgz
Yelp Review Polarity: https://s3.amazonaws.com/fast-ai-nlp/yelp_review_polarity_csv.tgz
Yelp Review Full: https://s3.amazonaws.com/fast-ai-nlp/yelp_review_full_csv.tgz
Yahoo! Answers: https://s3.amazonaws.com/fast-ai-nlp/yahoo_answers_csv.tgz
Amazon Review Full: https://s3.amazonaws.com/fast-ai-nlp/amazon_review_full_csv.tgz
Amazon Review Polarity: https://s3.amazonaws.com/fast-ai-nlp/amazon_review_polarity_csv.tgz

三、数据处理(AG_Dataset.py)

1.数据集加载
2.读取标签和数据
3.创建word2id
   3.1统计词频
   3.2加入 pad:0,unk:1创建word2id
4.将数据转化成id

from torch.utils import data
import os
import csv
import nltk
import numpy as np
class AG_Data(data.DataLoader):def __init__(self,data_path,min_count,max_length,n_gram=1,word2id = None,uniwords_num=0):self.path = os.path.abspath(".")if "data" not in self.path:self.path += "/data"self.n_gram = n_gramself.load(data_path)# 数据集加载,读取标签和数据if word2id==None:self.get_word2id(self.data,min_count)# 得到word2idelse:self.word2id = word2idself.uniwords_num = uniwords_numself.data = self.convert_data2id(self.data,max_length)# 将文本中的词都转化成idself.data = np.array(self.data)self.y = np.array(self.y)# 数据集加载,读取标签和数据def load(self, data_path,lowercase=True):self.label = []self.data = []with open(self.path+data_path,"r") as f:datas = list(csv.reader(f,delimiter=',', quotechar='"'))for row in datas:self.label.append(int(row[0]) - 1)txt = " ".join(row[1:])if lowercase:txt = txt.lower()txt = nltk.word_tokenize(txt)  # 将句子转化成词new_txt = []for i in range(0, len(txt)):for j in range(self.n_gram):  # 添加n-gram词if j <= i:new_txt.append(" ".join(txt[i - j:i + 1]))self.data.append(new_txt)self.y = self.label# 得到word2iddef get_word2id(self, datas, min_count=3):word_freq = {}for data in datas:  # 首先统计词频,后续通过词频过滤低频词for word in data:if word_freq.get(word) != None:word_freq[word] += 1else:word_freq[word] = 1word2id = {"<pad>": 0, "<unk>": 1}for word in word_freq:  # 首先构建uni-gram词,因为不需要hashif word_freq[word] < min_count or " " in word:continueword2id[word] = len(word2id)self.uniwords_num = len(word2id)for word in word_freq:  # 构建2-gram以上的词,需要hashif word_freq[word] < min_count or " " not in word:continueword2id[word] = len(word2id)self.word2id = word2id# 将文本中的词都转化成iddef convert_data2id(self, datas, max_length):for i, data in enumerate(datas):for j, word in enumerate(data):if " " not in word:datas[i][j] = self.word2id.get(word, 1)else:datas[i][j] = self.word2id.get(word, 1) % 100000 + self.uniwords_num  # hash函数datas[i] = datas[i][0:max_length] + [0] * (max_length - len(datas[i]))return datasdef __getitem__(self, idx):X = self.data[idx]y = self.y[idx]return X, ydef __len__(self):return len(self.label)
if __name__=="__main__":ag_data = AG_Data("/AG/train.csv",3,100)print (ag_data.data.shape)print (ag_data.data[-20:])print (ag_data.y.shape)print (len(ag_data.word2id))

四、模型(Fasttext.py)

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class Fasttext(nn.Module):def __init__(self,vocab_size,embedding_size,max_length,label_num):super(Fasttext,self).__init__()self.embedding =nn.Embedding(vocab_size,embedding_size)self.avg_pool = nn.AvgPool1d(kernel_size=max_length,stride=1)self.fc = nn.Linear(embedding_size, label_num)def forward(self, x):out = self.embedding(x) # batch_size*length*embedding_size bs*100*200out = out.transpose(1, 2).contiguous() # batch_size*embedding_size*length bs*200*100out = self.avg_pool(out).squeeze() # batch_size*embedding_size*1out = self.fc(out) # batch_size*label_numreturn out
if __name__=="__main__":fasttext = Fasttext(100,200,100,4)x = torch.Tensor(np.zeros([64,100])).long()out = fasttext(x)print (out.size())

五、训练和测试


import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
from model import Fasttext
from data import AG_Data
import numpy as np
from tqdm import tqdm
import config as argumentparser
config = argumentparser.ArgumentParser()
torch.manual_seed(config.seed)if config.cuda and torch.cuda.is_available():torch.cuda.set_device(config.gpu)
def get_test_result(data_iter,data_set):# 生成测试结果model.eval()true_sample_num = 0for data, label in data_iter:if config.cuda and torch.cuda.is_available():data = data.cuda()label = label.cuda()else:data = torch.autograd.Variable(data).long()out = model(data)true_sample_num += np.sum((torch.argmax(out, 1) == label).cpu().numpy())acc = true_sample_num / data_set.__len__()return acc
training_set = AG_Data("/AG/train.csv",min_count=config.min_count,max_length=config.max_length,n_gram=config.n_gram)
training_iter = torch.utils.data.DataLoader(dataset=training_set,batch_size=config.batch_size,shuffle=True,num_workers=0)
test_set = AG_Data(data_path="/AG/test.csv",min_count=config.min_count,max_length=config.max_length,n_gram=config.n_gram,word2id=training_set.word2id,uniwords_num=training_set.uniwords_num)
test_iter = torch.utils.data.DataLoader(dataset=test_set,batch_size=config.batch_size,shuffle=False,num_workers=0)
model = Fasttext(vocab_size=training_set.uniwords_num+100000,embedding_size=config.embed_size,max_length=config.max_length,label_num=config.label_num)
if config.cuda and torch.cuda.is_available():model.cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
loss = -1
for epoch in range(config.epoch):model.train()process_bar = tqdm(training_iter)for data, label in process_bar:if config.cuda and torch.cuda.is_available():data = data.cuda()label = label.cuda()else:data = torch.autograd.Variable(data).long()label = torch.autograd.Variable(label).squeeze()out = model(data)loss_now = criterion(out, autograd.Variable(label.long()))if loss == -1:loss = loss_now.data.item()else:loss = 0.95*loss+0.05*loss_now.data.item()process_bar.set_postfix(loss=loss_now.data.item())process_bar.update()optimizer.zero_grad()loss_now.backward()optimizer.step()test_acc = get_test_result(test_iter, test_set)print("The test acc is: %.5f" % test_acc)

实验结果

输出测试集准确率:

Fasttext(AG数据集---新闻主题分类)相关推荐

  1. 【深度学习】单标签多分类问题之新闻主题分类

    # -*- coding: utf-8 -*- """单标签多分类问题之新闻主题分类.ipynbAutomatically generated by Colaborato ...

  2. 新闻主题分类任务NLP

    关于新闻主题分类任务: 以一段新闻报道中的文本描述内容为输入, 使用模型帮助我们判断它最有可能属于哪一种类型的新闻, 这是典型的文本分类问题, 我们这里假定每种类型是互斥的, 即文本描述有且只有一种类 ...

  3. 【NLP】文本分类TorchText实战-AG_NEWS 新闻主题分类任务(PyTorch版)

    AG_NEWS 新闻主题分类任务(PyTorch版) 前言 1. 使用 N 元组加载数据 2. 安装 Torch-GPU&TorchText 3. 访问原始数据集迭代器 4. 准备数据处理管道 ...

  4. 朴素贝叶斯进行新闻主题分类,有代码和数据,可以跑通

    folder_path = '/Users/apple/Documents/七月在线/NLP/第2课/Lecture_2/Naive-Bayes-Text-Classifier/Database/So ...

  5. 自然语言处理项目之新闻主题分类Python实现

    ''' #2018-06-10 June Sunday the 23 week, the 161 day SZ 数据来源:链接:https://pan.baidu.com/s/1_w7wOzNkUEa ...

  6. 新闻主题分类任务——torchtext 库进行文本分类

    目录 简介 导入相关的torch工具包 访问原始数据集迭代器 使用原始训练数据集构建词汇表 生成数据批处理和迭代器 定义模型 定义函数来训练模型和评估结果 实例化并运行模型 使用测试数据集评估模型 测 ...

  7. 天池比赛——新闻文本分类比赛(零基础入门NLP)

    1 赛题理解 1.1 比赛内容 对新闻文本的类别进行预测.比赛提供了包含14个新闻类别的文本数据,分为训练集和测试集A,B.训练集包含20万条新闻文本数据,测试集A,B分别包含5万条新闻文本数据.所有 ...

  8. fasttext文本分类python实现_一个使用fasttext训练的新闻文本分类器/模型

    fastext是什么? Facebook AI Research Lab 发布的一个用于快速进行文本分类和单词表示的库.优点是很快,可以进行分钟级训练,这意味着你可以在几分钟时间内就训练好一个分类模型 ...

  9. ML之NB:基于NB朴素贝叶斯算法训练20类新闻文本数据集进行多分类预测

    ML之NB:基于NB朴素贝叶斯算法训练20类新闻文本数据集进行多分类预测 目录 输出结果 设计思路 核心代码 输出结果 设计思路 核心代码 vec = CountVectorizer() X_trai ...

最新文章

  1. 末学者笔记--openstack共享组件:rabbitmq(3)
  2. JS 验证表单不能为空
  3. AT2370 Piling Up
  4. 大规模异常滥用检测:基于局部敏感哈希算法——来自Uber Engineering的实践
  5. Python 中re.split()方法
  6. python按正则方式搜索文件
  7. CentOS中设置ip地址等信息
  8. 科大星云诗社动态20210827
  9. thinkphp三级分销小程序源码_山东谷道微信小程序商城源码带后台 公众号平台三级分销系统...
  10. python人脸识别防小偷_Python人脸识别
  11. 什么是Mybatis配置解析?(源码+图文)
  12. ES6的类Class基础知识点
  13. Android:使用GsonFormat插件遇到的坑
  14. EDA实践——基于VHDL的循环八路彩灯设计
  15. “向日葵”远程控制软件,方舟Q2硬件付费/免费功能全面评测,拔草向
  16. java 邮件批量发送邮件_利用Java实现电子邮件的批量发送
  17. vc 星号密码查看方法
  18. Spring Cloud如何可用于微服务架构
  19. [报表篇] (11)设置印刷尺寸
  20. 使用python批量解压7z格式压缩包

热门文章

  1. 纷享销客 java开发实习生面经
  2. Typora+Gitee+PicGo搭建图床
  3. php调用百度地图定位,php用百度地图API进行IP定位和GPS定位
  4. krpano360全景教程 - 全景场景实现自动旋转及循环浏览全部场景
  5. YII2.0 接口开发步骤
  6. 基于MATLAB的LBM代码: Rough jet model
  7. bootstrap实现 — 个人简介
  8. 面向对象五个基本原则(SOLID)
  9. DDR3L和LPDDR3区别
  10. 嵌入式Linux(十一)DDR3