nlp-tutorial代码注释2-1,CNN用于句子分类简介
本部分基于paper:Convolutional Neural Networks for Sentence Classification
模型介绍
目的
模型的目的是对输入的句子进行分类。论文中的模型图片如下:
输入
首先将输入的句子的各个单词的词向量叠起来,符号⊕是将向量上下堆叠的意思(如上图最左边的部分):
卷积
然后使用一个hhh×\times×kkk的卷积核对其进行卷积,kkk为词向量的维数,这里实质上是一个一维的卷积,得到一个(n−h+1)(n-h+1)(n−h+1)维的向量。再变化卷积核的高度hhh来获得多组卷积结果,如上图的第二部分所示,记卷积结果为CiCiCi。假设共有mmm组卷积结果。
池化
对每个卷积结果CiCiCi进行max-over-time pooling处理,即取每一组卷积结果的最大值。每一组卷积结果最终处理得到一个单个数字。将池化处理得到的结果进行堆叠,获得一个111×\times×mmm的向量。
全连接层
最后通过全连接层。全连接层有softmax单元对输入进行分类。
代码实现
pytorch代码及详细注释如下:(源代码为github中nlp-tutorial项目,项目地址:nlp-tutorial)
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as Fdtype = torch.FloatTensor# Text-CNN Parameter
embedding_size = 2 # 词向量是二维的
sequence_length = 3 # 句子的长度
num_classes = 2 # 最终对句子分类,共有两类
filter_sizes = [2, 2, 2] # 最大filter的形状
num_filters = 3 # filter的个数# 3 words sentences (=sequence_length is 3)
sentences = ["i love you", "he loves me", "she likes baseball", "i hate you", "sorry for that", "this is awful"]
labels = [1, 1, 1, 0, 0, 0] # 1 is good, 0 is not good.
word_list = " ".join(sentences).split() # 先用" ".join(),以空格为分隔,将sentences中的句子连接起来,再用split()以空格为分割点,将每个词分出来
word_list = list(set(word_list)) # 先用set合并重复的单词,再用list创建单词列表
word_dict = {w: i for i, w in enumerate(word_list)} # 建立单词到序号的索引
vocab_size = len(word_dict) # 词典中单词的数量inputs = [] # 创建输入空列表
for sen in sentences: # 将每句话的单词的序号组合成一个数组,加入到输入列表inputs中inputs.append(np.asarray([word_dict[n] for n in sen.split()]))
targets = [] # 创建标签空列表
for out in labels: # 将每句话的标签加入标签空列表中targets.append(out) # To using Torch Softmax Loss functioninput_batch = Variable(torch.LongTensor(inputs)) #转换成variable形式
target_batch = Variable(torch.LongTensor(targets))class TextCNN(nn.Module):def __init__(self):super(TextCNN, self).__init__()self.num_filters_total = num_filters * len(filter_sizes) # 一个维数,方便后面的权重矩阵确定维数self.W = nn.Parameter(torch.empty(vocab_size, embedding_size).uniform_(-1, 1)).type(dtype) # 词向量self.Weight = nn.Parameter(torch.empty(self.num_filters_total, num_classes).uniform_(-1, 1)).type(dtype) # 全连接层权重self.Bias = nn.Parameter(0.1 * torch.ones([num_classes])).type(dtype) # 全连接层偏置值def forward(self, X):embedded_chars = self.W[X] # 形状是[batch_size, sequence_length, embedding_size]embedded_chars = embedded_chars.unsqueeze(1) # add channel(=1) [batch, channel(=1), sequence_length, embedding_size]pooled_outputs = []for filter_size in filter_sizes: # 用不同大小的卷积核进行卷积计算,这里不同卷积核的宽都是2,就height变化# 对输入的句子矩阵embedded_chars卷积# conv2d的参数 : [input_channel(=1), output_channel(=3), (filter_height, filter_width), bias_option]conv = nn.Conv2d(1, num_filters, (filter_size, embedding_size), bias=True)(embedded_chars) # 3个filter,故输出通道数为3h = F.relu(conv) # 激活函数relu# mp : ((filter_height, filter_width))mp = nn.MaxPool2d((sequence_length - filter_size + 1, 1)) # 最大池化# pooled : [batch_size(=6), output_height(=1), output_width(=1), output_channel(=3)]pooled = mp(h).permute(0, 3, 2, 1) # 重新排列pooled_outputs.append(pooled) # 将使用某个大小的卷积核计算出的结果添加到outputs中# 使用torch.cat函数,在len(filter_sizes)的维度上将outputs中的张量进行堆叠h_pool = torch.cat(pooled_outputs, len(filter_sizes)) # [batch_size(=6), output_height(=1), output_width(=1), output_channel(=3) * 3]# 重新排列h_pool_flat = torch.reshape(h_pool, [-1, self.num_filters_total]) # [batch_size(=6), output_height * output_width * (output_channel * 3)]# 经过全连接层线性计算model = torch.mm(h_pool_flat, self.Weight) + self.Bias # [batch_size, num_classes]return modelmodel = TextCNN()
criterion = nn.CrossEntropyLoss() # 损失函数使用交叉熵损失
optimizer = optim.Adam(model.parameters(), lr=0.001) # 优化方法采用Adam# Training
for epoch in range(5000):optimizer.zero_grad() # 每次训练前清零梯度缓存output = model(input_batch) # 输入input_batch,从模型中获得输出# output : [batch_size, num_classes], target_batch : [batch_size] (LongTensor, not one-hot)loss = criterion(output, target_batch) # 计算lossif (epoch + 1) % 1000 == 0: # 每1000次打印一次训练情况print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))loss.backward() # 反向传播optimizer.step() # 优化、更新参数# Test
test_text = 'sorry hate you'
tests = [np.asarray([word_dict[n] for n in test_text.split()])]
test_batch = Variable(torch.LongTensor(tests))
# Predict
predict = model(test_batch).data.max(1, keepdim=True)[1]
if predict[0][0] == 0:print(test_text,"is Bad Mean...")
else:print(test_text,"is Good Mean!!")
nlp-tutorial代码注释2-1,CNN用于句子分类简介相关推荐
- CNN用于句子分类时的超参数分析
本文是"A Sensitivity Analysis of (and Practitioners' Guide to) Convolutional Neural Networks for S ...
- Convolutional Neural Networks for Sentence Classification用于句子分类的卷积神经网络
Convolutional Neural Networks for Sentence Classification 论文任务:用卷积神经网络(CNN)在预先训练好的词向量上进行句子级分类任务 论文借用 ...
- CNN对句子分类(tensorflow)
卷积神经网络是一种特殊的深层的神经网络模型,它的特殊性体现在两个方面,一方面它的神经元间的连接是非全连接的, 另一方面同一层中某些神经元之间的连接的权重是共享的(即相同的).它的非全连接和权值共享的网 ...
- 利用CNN进行句子分类的敏感性分析
原文标题 A Sensitivity Analysis of (and Practitioners' Guide to) Convolutional Neural Networks for Sente ...
- 【详细代码注释】基于CNN卷积神经网络实现随机森林算法
随机森林算法简介: 随机森林(Random Forest)是一种灵活性很高的机器学习算法. 它的底层是利用多棵树对样本进行训练并预测的一种分类器.在机器学习的许多领域都有广泛地应用. 例如构建医学疾病 ...
- CNN与句子分类之动态池化方法DCNN--模型介绍篇
本文是针对"A Convolutional Neural Network for Modelling Sentences"论文的阅读笔记和代码实现.这片论文的主要贡献在于其提出了一 ...
- nlp-tutorial代码注释笔记
系列语:本系列是nlp-tutorial代码注释系列,github上原项目地址为:nlp-tutorial,本系列每一篇文章的大纲是相关知识点介绍 + 详细代码注释. 传送门: nlp-tutoria ...
- 文本分类(下) | 卷积神经网络(CNN)在文本分类上的应用
正文共3758张图,4张图,预计阅读时间18分钟. 1.简介 原先写过两篇文章,分别介绍了传统机器学习方法在文本分类上的应用以及CNN原理,然后本篇文章结合两篇论文展开,主要讲述下CNN在文本分类上的 ...
- 【NLP】TensorFlow实现CNN用于中文文本分类
代码基于 dennybritz/cnn-text-classification-tf 及 clayandgithub/zh_cnn_text_classify 参考文章 了解用于NLP的卷积神经网络( ...
最新文章
- ExtJS之对话框及窗口篇
- Ubuntu18.04下安装RRStudio
- c语言如何发现错误在哪里,二个C语言例子,编译没通过.不知道错在哪里[求助]
- python简单超级马里奥游戏下载_python 实现超级玛丽游戏
- matlab平滑窗滤波,matlab实现平滑滤波
- AttributeError: type object ‘Image‘ has no attribute ‘open‘
- EZchip将推全球首款100核64位ARM A-53芯片
- java读取.properties文件及解决中文乱码问题
- linux centos ppp限速,Centos7限速和测速
- 深入理解ArrayList 和 LinkedList 区别
- 在Zephyr上使用malloc/new
- Python - 进程/线程相关整理
- 技术的理想该继续吗?
- 整数拆分 python_LeetCode 343. 整数拆分 | Python
- pythonturtle写人名_python turtle写名字
- nextTick介绍
- 安卓开发笔记(八)—— 王者荣耀英雄大全 数据库部分
- 第13届D2大会 - 参会感受和总结
- 网站用户行为日志采集和后台日志服务器搭建
- 【长截图】轻松简便、一步实现