内容

import torch
import torch.nn as nn
import torch.nn.functional as F
from src.model.general.attention.additive import AdditiveAttentiondevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")class KCNN(torch.nn.Module):"""Knowledge-aware CNN (KCNN) based on Kim CNN.Input a news sentence (e.g. its title), produce its embedding vector."""def __init__(self, config, pretrained_word_embedding,pretrained_entity_embedding, pretrained_context_embedding):#前面是单纯的定义: 获取单词嵌入、实体嵌入和上下文嵌入的预训练参数(不只是历史点击新闻还有候选新闻的)super(KCNN, self).__init__()self.config = configif pretrained_word_embedding is None:  #如果预训练单词嵌入是空,那么就需要用集成在nn.Embedding()的函数了self.word_embedding = nn.Embedding(config.num_words,config.word_embedding_dim,padding_idx=0)else:self.word_embedding = nn.Embedding.from_pretrained(pretrained_word_embedding, freeze=False, padding_idx=0)if pretrained_entity_embedding is None:self.entity_embedding = nn.Embedding(config.num_entities,config.entity_embedding_dim,padding_idx=0)else:self.entity_embedding = nn.Embedding.from_pretrained(pretrained_entity_embedding, freeze=False, padding_idx=0)if config.use_context:if pretrained_context_embedding is None:self.context_embedding = nn.Embedding(config.num_entities,config.entity_embedding_dim,padding_idx=0)else:self.context_embedding = nn.Embedding.from_pretrained(pretrained_context_embedding, freeze=False, padding_idx=0)self.transform_matrix = nn.Parameter(torch.empty(self.config.entity_embedding_dim,self.config.word_embedding_dim).uniform_(-0.1, 0.1))self.transform_bias = nn.Parameter(torch.empty(self.config.word_embedding_dim).uniform_(-0.1, 0.1))self.conv_filters = nn.ModuleDict({str(x): nn.Conv2d(3 if self.config.use_context else 2,self.config.num_filters,(x, self.config.word_embedding_dim))for x in self.config.window_sizes})self.additive_attention = AdditiveAttention(self.config.query_vector_dim, self.config.num_filters)def forward(self, news):"""Args:news:{"title": batch_size * num_words_title,"title_entities": batch_size * num_words_title}Returns:final_vector: batch_size, len(window_sizes) * num_filters"""# batch_size, num_words_title, word_embedding_dimword_vector = self.word_embedding(news["title"].to(device))#获得单词向量  需要放到设备上的# batch_size, num_words_title, entity_embedding_dimentity_vector = self.entity_embedding(    #获得实体向量news["title_entities"].to(device))if self.config.use_context:    #用上下文的话就得获得上下文的向量# batch_size, num_words_title, entity_embedding_dimcontext_vector = self.context_embedding(news["title_entities"].to(device))# batch_size, num_words_title, word_embedding_dimtransformed_entity_vector = torch.tanh(   #转换矩阵是将其中某些词替换掉!torch.add(torch.matmul(entity_vector, self.transform_matrix),self.transform_bias))if self.config.use_context:        # batch_size, num_words_title, word_embedding_dimtransformed_context_vector = torch.tanh(torch.add(torch.matmul(context_vector, self.transform_matrix),self.transform_bias))# batch_size, 3, num_words_title, word_embedding_dimmulti_channel_vector = torch.stack([word_vector, transformed_entity_vector,transformed_context_vector], dim=1)   #获得最终的concat向量else:# batch_size, 2, num_words_title, word_embedding_dimmulti_channel_vector = torch.stack([word_vector, transformed_entity_vector], dim=1)pooled_vectors = []  #for x in self.config.window_sizes:    # window_size = 3# batch_size, num_filters, num_words_title + 1 - xconvoluted = self.conv_filters[str(x)](   #后面就是卷积常规操作! 分别进行3种window_size的卷积multi_channel_vector).squeeze(dim=3)# batch_size, num_filters, num_words_title + 1 - xactivated = F.relu(convoluted)# batch_size, num_filters# Here we use a additive attention module# instead of pooling in the paperpooled = self.additive_attention(activated.transpose(1, 2))# pooled = activated.max(dim=-1)[0]# # or# # pooled = F.max_pool1d(activated, activated.size(2)).squeeze(dim=2)pooled_vectors.append(pooled)# batch_size, len(window_sizes) * num_filtersfinal_vector = torch.cat(pooled_vectors, dim=1)return final_vector

说明

最后的卷积有必要说一下

【DKN】(六)KCNN.py相关推荐

  1. 第二十一:基于Python2+Selenium3+Pytest4+Pytest-Html的UI自动化框架

    一.环境配置: 1.Python2.7.10, selenium3.141.0, pytest4.6.6, pytest-html1.22.0, Windows-7-6.1.7601-SP1 二.特点 ...

  2. pytorch上分之路——视频补全算法(onion peel network)

    文章目录 前言 一.config.py 二.datalist.py 三.common.py 四.model.py 五.model_common.py 六.train.py 总结 前言 该算法是从git ...

  3. Django 学习记录

    学习使用Django有一段时间了,期间也做过一些记录,希望能帮到大家! 2017-08-16 #Django python manage.py startapp Users #生产新的应用 2017- ...

  4. 学习python第一天

    一.计算机简介 组成:输入设备.输出设备.存储器.运算器.控制器 键盘 .鼠标:向电脑输入有效信息[输入设备] 机箱: 主板:连接其他所有设备的载体CPU:中央处理单元[Central Proessi ...

  5. python分配问题_组队、路径分配问题建模案例 ✕ Gurobi 应用 | python3 实现

    1 前言 本文以两道经典建模题为例, 进一步介绍 Gurobi 与 Python 的交互, 以及其在建模中的应用. 阅读本文前, 建议读者先配置好 Gurobi 环境, 并且对数学建模有一定的认识 ( ...

  6. 汉字的首拼音字母生成

    生成助记码(取汉字的第一个字母) SET NOCOUNT ON GO IF EXISTS(SELECT name    FROM   sysobjects    WHERE  name = N'hzp ...

  7. httprunner3、pytest、allure资料整理合集

    文章目录 httprunner介绍 一.前言 二.什么是Httprunner 三.Httprunner2.x和3.x区别 四.环境搭建 五.HttpRunner快速上手 5.1.测试用例结构 5.2 ...

  8. 【 办公类-03】 VS Python 大8班“运动场地”的周次安排。

    办公需求: 开学了,周计划教案的模板又要调整了.其中"分散活动"的内容需要按一定排序规律,每周填入不同的场地. 为了避免每个班级每天用的场地产生重复,组长制定了场地安排规则. 以大 ...

  9. 知识图谱论文阅读(八)【转】推荐系统遇上深度学习(二十六)--知识图谱与推荐系统结合之DKN模型原理及实现

    学习的博客: 推荐系统遇上深度学习(二十六)–知识图谱与推荐系统结合之DKN模型原理及实现 知识图谱特征学习的模型分类汇总 知识图谱嵌入(KGE):方法和应用的综述 论文: Knowledge Gra ...

最新文章

  1. SpringMVC(三):使用 POJO 对象绑定请求参数值
  2. 【剑指offer-Java版】33把数组排成最小的数
  3. 郑州网络推广浅谈网站首页在优化时都需要注意哪些细节呢?
  4. RHCE笔记1-安裝
  5. openCV中waitKey函数介绍
  6. 检测和删除多余无用的css
  7. Swift的笔记和参考
  8. java总结第四次//常用类
  9. JavaScript 的数据类型 相关知识点
  10. 技术项目 - MySQL多从系统的主库选择
  11. 文件磁盘相关函数[11]-获取指定文件的版本号 GetFileVersion
  12. java程序dna,蓝桥杯——DNA(Java题解)
  13. Nginx面试题及基础
  14. 棋牌游戏开发教程系列:游戏服务器框架搭建
  15. eds能谱图分析实例_热场发射扫描电子显微镜与能谱分析仪
  16. 最最最简单从官方获取最新行政区划代码、区划拼音
  17. Kubernetes Downward API的介绍及使用
  18. CITA v0.15 Release
  19. 【​观察】纺织印花行业转型与升级提速 爱普生蒙娜丽莎掀起技术迭代革命...
  20. VMware9虚拟机和Fedora12安装-实现Windows和Linux文件共享

热门文章

  1. 使用eclipse以及Juint进行测试
  2. 外网访问FTP服务,解决只能以POST模式访问Filezilla的问题
  3. 【JavaScript代码实现四】获取和设置 cookie
  4. 回溯算法--8皇后问题
  5. 数据结构和算法,也就那么回事儿
  6. Android中GC的触发时机和条件
  7. 中的数组怎么转成结构体_PLC知识,什么是数组和结构体?
  8. java 原子类_小学妹教你并发编程的三大特性:原子性、可见性、有序性
  9. ubuntu c++检测usb口事件变化_拆解报告:美式双USBA口充电插座
  10. python的最受欢迎的库_2018年最受欢迎的15个Python库