本部分基于paper:Convolutional Neural Networks for Sentence Classification

模型介绍

目的

模型的目的是对输入的句子进行分类。论文中的模型图片如下:

输入

首先将输入的句子的各个单词的词向量叠起来,符号⊕是将向量上下堆叠的意思(如上图最左边的部分):

卷积

然后使用一个hhh×\times×kkk的卷积核对其进行卷积,kkk为词向量的维数,这里实质上是一个一维的卷积,得到一个(n−h+1)(n-h+1)(nh+1)维的向量。再变化卷积核的高度hhh来获得多组卷积结果,如上图的第二部分所示,记卷积结果为CiCiCi。假设共有mmm组卷积结果。

池化

对每个卷积结果CiCiCi进行max-over-time pooling处理,即取每一组卷积结果的最大值。每一组卷积结果最终处理得到一个单个数字。将池化处理得到的结果进行堆叠,获得一个111×\times×mmm的向量。

全连接层

最后通过全连接层。全连接层有softmax单元对输入进行分类。

代码实现

pytorch代码及详细注释如下:(源代码为github中nlp-tutorial项目,项目地址:nlp-tutorial)

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as Fdtype = torch.FloatTensor# Text-CNN Parameter
embedding_size = 2    # 词向量是二维的
sequence_length = 3   # 句子的长度
num_classes = 2       # 最终对句子分类,共有两类
filter_sizes = [2, 2, 2]  # 最大filter的形状
num_filters = 3       # filter的个数# 3 words sentences (=sequence_length is 3)
sentences = ["i love you", "he loves me", "she likes baseball", "i hate you", "sorry for that", "this is awful"]
labels = [1, 1, 1, 0, 0, 0]  # 1 is good, 0 is not good.
word_list = " ".join(sentences).split() # 先用" ".join(),以空格为分隔,将sentences中的句子连接起来,再用split()以空格为分割点,将每个词分出来
word_list = list(set(word_list))        # 先用set合并重复的单词,再用list创建单词列表
word_dict = {w: i for i, w in enumerate(word_list)}     # 建立单词到序号的索引
vocab_size = len(word_dict)                             # 词典中单词的数量inputs = []               # 创建输入空列表
for sen in sentences:     # 将每句话的单词的序号组合成一个数组,加入到输入列表inputs中inputs.append(np.asarray([word_dict[n] for n in sen.split()]))
targets = []              # 创建标签空列表
for out in labels:        # 将每句话的标签加入标签空列表中targets.append(out)   # To using Torch Softmax Loss functioninput_batch = Variable(torch.LongTensor(inputs))         #转换成variable形式
target_batch = Variable(torch.LongTensor(targets))class TextCNN(nn.Module):def __init__(self):super(TextCNN, self).__init__()self.num_filters_total = num_filters * len(filter_sizes)    # 一个维数,方便后面的权重矩阵确定维数self.W = nn.Parameter(torch.empty(vocab_size, embedding_size).uniform_(-1, 1)).type(dtype)  # 词向量self.Weight = nn.Parameter(torch.empty(self.num_filters_total, num_classes).uniform_(-1, 1)).type(dtype) # 全连接层权重self.Bias = nn.Parameter(0.1 * torch.ones([num_classes])).type(dtype)  # 全连接层偏置值def forward(self, X):embedded_chars = self.W[X] # 形状是[batch_size, sequence_length, embedding_size]embedded_chars = embedded_chars.unsqueeze(1) # add channel(=1) [batch, channel(=1), sequence_length, embedding_size]pooled_outputs = []for filter_size in filter_sizes:     # 用不同大小的卷积核进行卷积计算,这里不同卷积核的宽都是2,就height变化# 对输入的句子矩阵embedded_chars卷积# conv2d的参数 : [input_channel(=1), output_channel(=3), (filter_height, filter_width), bias_option]conv = nn.Conv2d(1, num_filters, (filter_size, embedding_size), bias=True)(embedded_chars) # 3个filter,故输出通道数为3h = F.relu(conv)    # 激活函数relu# mp : ((filter_height, filter_width))mp = nn.MaxPool2d((sequence_length - filter_size + 1, 1))   # 最大池化# pooled : [batch_size(=6), output_height(=1), output_width(=1), output_channel(=3)]pooled = mp(h).permute(0, 3, 2, 1)       # 重新排列pooled_outputs.append(pooled)            # 将使用某个大小的卷积核计算出的结果添加到outputs中# 使用torch.cat函数,在len(filter_sizes)的维度上将outputs中的张量进行堆叠h_pool = torch.cat(pooled_outputs, len(filter_sizes)) # [batch_size(=6), output_height(=1), output_width(=1), output_channel(=3) * 3]# 重新排列h_pool_flat = torch.reshape(h_pool, [-1, self.num_filters_total]) # [batch_size(=6), output_height * output_width * (output_channel * 3)]# 经过全连接层线性计算model = torch.mm(h_pool_flat, self.Weight) + self.Bias # [batch_size, num_classes]return modelmodel = TextCNN()
criterion = nn.CrossEntropyLoss()                     # 损失函数使用交叉熵损失
optimizer = optim.Adam(model.parameters(), lr=0.001)  # 优化方法采用Adam# Training
for epoch in range(5000):optimizer.zero_grad()              # 每次训练前清零梯度缓存output = model(input_batch)        # 输入input_batch,从模型中获得输出# output : [batch_size, num_classes], target_batch : [batch_size] (LongTensor, not one-hot)loss = criterion(output, target_batch)  # 计算lossif (epoch + 1) % 1000 == 0:             # 每1000次打印一次训练情况print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))loss.backward()    # 反向传播optimizer.step()   # 优化、更新参数# Test
test_text = 'sorry hate you'
tests = [np.asarray([word_dict[n] for n in test_text.split()])]
test_batch = Variable(torch.LongTensor(tests))
# Predict
predict = model(test_batch).data.max(1, keepdim=True)[1]
if predict[0][0] == 0:print(test_text,"is Bad Mean...")
else:print(test_text,"is Good Mean!!")

nlp-tutorial代码注释2-1,CNN用于句子分类简介相关推荐

  1. CNN用于句子分类时的超参数分析

    本文是"A Sensitivity Analysis of (and Practitioners' Guide to) Convolutional Neural Networks for S ...

  2. Convolutional Neural Networks for Sentence Classification用于句子分类的卷积神经网络

    Convolutional Neural Networks for Sentence Classification 论文任务:用卷积神经网络(CNN)在预先训练好的词向量上进行句子级分类任务 论文借用 ...

  3. CNN对句子分类(tensorflow)

    卷积神经网络是一种特殊的深层的神经网络模型,它的特殊性体现在两个方面,一方面它的神经元间的连接是非全连接的, 另一方面同一层中某些神经元之间的连接的权重是共享的(即相同的).它的非全连接和权值共享的网 ...

  4. 利用CNN进行句子分类的敏感性分析

    原文标题 A Sensitivity Analysis of (and Practitioners' Guide to) Convolutional Neural Networks for Sente ...

  5. 【详细代码注释】基于CNN卷积神经网络实现随机森林算法

    随机森林算法简介: 随机森林(Random Forest)是一种灵活性很高的机器学习算法. 它的底层是利用多棵树对样本进行训练并预测的一种分类器.在机器学习的许多领域都有广泛地应用. 例如构建医学疾病 ...

  6. CNN与句子分类之动态池化方法DCNN--模型介绍篇

    本文是针对"A Convolutional Neural Network for Modelling Sentences"论文的阅读笔记和代码实现.这片论文的主要贡献在于其提出了一 ...

  7. nlp-tutorial代码注释笔记

    系列语:本系列是nlp-tutorial代码注释系列,github上原项目地址为:nlp-tutorial,本系列每一篇文章的大纲是相关知识点介绍 + 详细代码注释. 传送门: nlp-tutoria ...

  8. 文本分类(下) | 卷积神经网络(CNN)在文本分类上的应用

    正文共3758张图,4张图,预计阅读时间18分钟. 1.简介 原先写过两篇文章,分别介绍了传统机器学习方法在文本分类上的应用以及CNN原理,然后本篇文章结合两篇论文展开,主要讲述下CNN在文本分类上的 ...

  9. 【NLP】TensorFlow实现CNN用于中文文本分类

    代码基于 dennybritz/cnn-text-classification-tf 及 clayandgithub/zh_cnn_text_classify 参考文章 了解用于NLP的卷积神经网络( ...

最新文章

  1. ExtJS之对话框及窗口篇
  2. Ubuntu18.04下安装RRStudio
  3. c语言如何发现错误在哪里,二个C语言例子,编译没通过.不知道错在哪里[求助]
  4. python简单超级马里奥游戏下载_python 实现超级玛丽游戏
  5. matlab平滑窗滤波,matlab实现平滑滤波
  6. AttributeError: type object ‘Image‘ has no attribute ‘open‘
  7. EZchip将推全球首款100核64位ARM A-53芯片
  8. java读取.properties文件及解决中文乱码问题
  9. linux centos ppp限速,Centos7限速和测速
  10. 深入理解ArrayList 和 LinkedList 区别
  11. 在Zephyr上使用malloc/new
  12. Python - 进程/线程相关整理
  13. 技术的理想该继续吗?
  14. 整数拆分 python_LeetCode 343. 整数拆分 | Python
  15. pythonturtle写人名_python turtle写名字
  16. nextTick介绍
  17. 安卓开发笔记(八)—— 王者荣耀英雄大全 数据库部分
  18. 第13届D2大会 - 参会感受和总结
  19. 网站用户行为日志采集和后台日志服务器搭建
  20. 【长截图】轻松简便、一步实现

热门文章

  1. UVA 818 Cutting Chains 切断圆环链 (暴力dfs)
  2. Linux 搭建SVN server
  3. 如果不使用 SQL Mail,如何在 SQL Server 中发送电子邮件
  4. maven安装与配置等相关知识
  5. Windows Server 2008 R2 install Visual Studio 2015 failed
  6. 【UOJ#450】【集训队作业2018】复读机(生成函数,单位根反演)
  7. 小峰视频十四:面向对象和类的概念
  8. 洛谷P1273 有线电视网
  9. java多线程之wait和notify协作,生产者和消费者
  10. Opengl_9_复合变换