利用GPT2模型训练中文闲聊模型

最近看了一下GPT2模型,看到很多博主都用来写诗歌,做问答等,小编突然萌生一个想法,利用GPT2来训练一个闲聊模型!!(小说生成器模型已经破产,写出来的东西狗屁不通,懒得再弄了,有兴趣的小伙伴可以继续尝试。。),闲聊模型在小编的Github上,欢迎star和fork,谢谢!!


文章目录

  • 利用GPT2模型训练中文闲聊模型
  • 前言
  • 一、数据结构
  • 二、模型搭建

前言

本来想用清源CPM预训练模型来进行的,可惜硬件条件不够,只能用中GPT模型来训练一个中文闲聊模型了。硬件条件好的小伙伴,推荐你们使用CPM来做,看别人做的效果还是不错的,清源CPM的模型有4.47G,小伙伴们量力而行,模型小编已经撸下来了:pytorch版CPM密码:k2zh; 百度飞桨版CPM密码:nb67;苏神版bert4keras,需要的小伙伴可自信下载。

一、数据结构

数据是在网上爬取的一些对话数据,数据格式如下所示:

你吃了吗?
吃过了,吃的糖醋排骨,你呢
我吃的是麻辣小龙虾

手机欠费了怎么办?
交话费啊
去哪里才能交话费呢
去相应的营业厅啊

数据格式就是一段对话,不同的对话间使用空行隔开。

二、模型搭建

(1) GPT2模型的搭建(小编自己写的,可能不是很合适!!有需要的可以直接调用Transformers库中的GPT2Model)
首先是Multi_Heads_Attention的代码,这个也没什么可说的,如下所示:

class Attention(nn.Module):def __init__(self, embedding_size, num_attention_heads, attention_dropout, residual_dropout):super(Attention, self).__init__()self.num_attention_heads = num_attention_headsself.size_per_head = embedding_size // num_attention_headsself.embedding_size = embedding_sizeself.query_key_value = nn.Linear(embedding_size, embedding_size * 3)self.attn_drop = nn.Dropout(attention_dropout)self.resid_drop = nn.Dropout(residual_dropout)self.dense = nn.Linear(embedding_size, embedding_size)def split_heads(self, x):"return shape [`batch`, `head`, `sequence`, `features`]"new_shape = x.size()[:-1] + (self.num_attention_heads, self.size_per_head)x = x.view(*new_shape)return x.permute(0, 2, 1, 3)def merge_heads(self, x):x = x.permute(0, 2, 1, 3).contiguous()new_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)return x.view(*new_shape)def forward(self, x, kv_cache=None):self.seq_len = x.size(1)# self_attentionx = self.query_key_value(x)q, k, v = x.split(self.embedding_size, dim=2)# 多头q = self.split_heads(q)k = self.split_heads(k)v = self.split_heads(v)cached_kv = torch.stack([k, v], dim=1)scores = torch.matmul(q, k.transpose(-2, -1))scores = scores / math.sqrt(self.size_per_head)attention_mask = torch.tril(torch.ones([self.seq_len, self.seq_len], dtype=torch.float32))# print("attention", attention_mask)attention_mask = attention_mask.reshape([1, 1, self.seq_len, self.seq_len])# print(1.0 - attention_mask)# print(scores * attention_mask)scores = scores * attention_mask - 10000.0 * (1.0 - attention_mask)# print(scores)scores = nn.Softmax(dim=-1)(scores)scores = self.attn_drop(scores)y = torch.matmul(scores, v)y = self.merge_heads(y)y = self.resid_drop(self.dense(y))return y, cached_kv

然后就是MLP部分的代码,也就是线性层:

# 构建线性转换层
class MLP(nn.Module):def __init__(self, embedding_size):super(MLP, self).__init__()self.dense_h_to_4h = nn.Linear(embedding_size, embedding_size * 4)self.dense_4h_to_h = nn.Linear(embedding_size * 4, embedding_size)self.act = nn.functional.geludef forward(self, x):h = self.act(self.dense_h_to_4h(x))h2 = self.dense_4h_to_h(h)return h2# 线性层测试
"""
layer = Linear(768, 768*3)
x = torch.rand(1,4,768) [batch_size, seq_len, dim]
y = layer(x)
print(y, y.shape) [1, 4, 2304]
"""

然后是写一个Block类把Multi_Heads_Attention和MLP进行整合:

class Block(nn.Module):def __init__(self, embedding_size, num_attention_heads, attention_dropout, residual_dropout):super(Block, self).__init__()self.input_layernorm = nn.LayerNorm(embedding_size, eps=1e-5)self.attention = Attention(embedding_size, num_attention_heads, attention_dropout, residual_dropout)self.post_attention_layernorm = nn.LayerNorm(embedding_size, eps=1e-5)self.mlp = MLP(embedding_size)def forward(self, x, kv_cache=None):# Attention + 前后的LayerNorm + 中间残差连接attn, cached_kv = self.attention(self.input_layernorm(x), kv_cache=kv_cache)x = x + attnz = self.post_attention_layernorm(x)# MLPz = self.mlp(z)# 残差连接x = x + zreturn x, cached_kv

通过调用Block类,构建Transformer模型:

class Transformer(nn.Module):def __init__(self,layer_size,embedding_size,num_attention_heads,attention_dropout,residual_dropout):super(Transformer, self).__init__()self.layers = nn.ModuleList([Block(embedding_size,num_attention_heads,attention_dropout,residual_dropout)for _ in range(layer_size)])self.final_layernorm = nn.LayerNorm(embedding_size, eps=1e-5)def forward(self, x, kv_cache=None):# 多层 Blockcached_kvs = []for i, layer in enumerate(self.layers):x, cached_kv = layer(x, kv_cache=kv_cache[i] if kv_cache is not None else None)cached_kvs.append(cached_kv)# 最终的 LayerNormx = self.final_layernorm(x)return x, torch.stack(cached_kvs)

最后就是GPT2Model了,在这个类中除了调用了Transformer类外,还定义了attention_mask矩阵:

class GPT2Model(nn.Module):def __init__(self, config):super(GPT2Model, self).__init__()# 定义字符嵌入层self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size)# 定义位置嵌入层self.position_embeddings = nn.Embedding(config.block_size, config.embedding_size)# 定义嵌入随机丢弃层self.emb_drop = nn.Dropout(config.embedding_dropout)# 定义 Transformer Encoderself.transformer = Transformer(config.layer_size,config.embedding_size,config.num_attention_heads,config.attention_dropout,config.residual_dropout)def forward(self, x, kv_cache=None, use_cache=False):# 根据缓存确定历史输入长度if kv_cache is None:past_length = 0else:past_length = kv_cache[0][0].shape[-2]# 生成位置编码position_ids = torch.arange(past_length, x.shape[-1] + past_length, dtype=torch.int64)position_ids = position_ids.unsqueeze(0).expand_as(x)# 计算嵌入层输出x = self.word_embeddings(x)x = self.emb_drop(x + self.position_embeddings(position_ids))# 计算 Transformer Encoder 输出x, cached_kvs = self.transformer(x, kv_cache)# 计算解码输出# 解码使用的参数为字符嵌入层参数的转置# 相当于做一个逆运算或者可以理解为使用相同的参数进行编码和解码x = torch.matmul(x, self.word_embeddings.weight.transpose(-1, -2))# 如果使用缓存则返回输出和缓存if use_cache:return x, cached_kvs# 否则只返回输出return x

(2) 数据处理,利用了torch自带的TensorDataset,把形成的Dataset转为Tensor形式。

# -*- coding: utf-8 -*-
"""
@Time    : 2021/4/17 16:28
@Author  : SinGaln
"""
import os
import torch
import logging
from tqdm import tqdm
from torch.utils.data import TensorDatasetlogger = logging.getLogger(__name__)class DataProcess(object):def __init__(self, args):self.args = argsself.data_file = "./data/train.txt"@classmethoddef _read_data_file(cls, input_file):logger.info("tokenizing raw data,raw data path:{}".format(input_file))with open(input_file, 'rb') as f:data = f.read().decode("utf-8")if "\r\n" in data:train_data = data.split("\r\n\r\n")else:train_data = data.split("\n\n")logger.info("there are {} dialogue in raw dataset".format(len(train_data)))return train_datadef get_examples(self, tokenizer):context = []train_data = self._read_data_file(self.data_file)for dialogue_index, dialogue in enumerate(tqdm(train_data)):utterances = dialogue.split("\n")dialogue_ids = [tokenizer.cls_token_id]  # 每个dialogue以[CLS]开头for utterance in utterances:dialogue_ids.extend([tokenizer.convert_tokens_to_ids(word) for word in utterance])dialogue_ids.append(tokenizer.sep_token_id)  # 每个utterance之后添加[SEP],表示utterance结束# 对超过n_ctx的长度进行截断,否则GPT2模型会报错if len(dialogue_ids) > self.args.max_seq_len:dialogue_ids = dialogue_ids[:self.args.max_seq_len]print("dialogue", len(dialogue_ids))else:dialogue_ids = dialogue_ids + ([0] * (self.args.max_seq_len - len(dialogue_ids)))context.append(dialogue_ids)logger.info("finish processing for raw data!")return contextprocessors = {"chat": DataProcess
}
def load_and_cache_examples(args, tokenizer):processor = processors[args.task](args)# Load data features from cache or dataset filecached_features_file = os.path.join(args.data_dir,'cached_train_{}_{}'.format(args.task,args.max_seq_len))if os.path.exists(cached_features_file):logger.info("Loading features from cached file %s", cached_features_file)features = torch.load(cached_features_file)else:# Load data features from dataset filelogger.info("Creating features from dataset file at %s", args.data_dir)features = processor.get_examples(tokenizer)logger.info("Saving features into cached file %s", cached_features_file)torch.save(features, cached_features_file)# Convert to Tensors and build datasetall_input_ids = torch.tensor(features, dtype=torch.long)dataset = TensorDataset(all_input_ids)return dataset

其他的文件可以到小编的Github上查看,注意代码在master分支上。

利用GPT2训练中文闲聊模型相关推荐

  1. PromptCLUE:大规模多任务Prompt预训练中文开源模型

    简介 PromptCLUE:大规模多任务Prompt预训练中文开源模型. 中文上的三大统一:统一模型框架,统一任务形式,统一应用方式.支持几十个不同类型的任务,具有较好的零样本学习能力和少样本学习能力 ...

  2. 基于GPT2的中文闲聊机器人/GPT2 for Chinese chitchat

    向AI转型的程序员都关注了这个号???????????? 机器学习AI算法工程   公众号:datayx 项目描述 本项目使用GPT2模型对中文闲聊语料进行训练,使用 HuggingFace的tran ...

  3. 利用PaddleOCR训练车牌识别模型

    目录 1--前言 2--生成车牌数据集 3--构建车牌数据集标签 4--自定义字典 5--训练模型 6--模型转换和推理 7--模型转换为onnx模型 8--参考 1--前言 ①系统:Ubuntu18 ...

  4. 如何调用 caffe 训练好的模型对输入图片进行测试

    如何调用 caffe 训练好的模型对输入图片进行测试 该部分包括两篇文章 win10 下 caffe 的第一个测试程序(附带详细讲解) 主要讲解如何利用 caffe 来训练模型. 如何调用 caffe ...

  5. GPT-2生成式多轮对话入门-----深入理解“用于中文闲聊的GPT2模型”项目

    UPDATE 2.28.2020 纠正之前文末的思维误区. 2.26.2020 增加了Jay Alammar The Illustrated GPT-2 博客的翻译 增加了关于Transformer你 ...

  6. DL之Attention-ED:基于TF NMT利用带有Attention的 ED模型训练、测试(中英文平行语料库)实现将英文翻译为中文的LSTM翻译模型过程全记录

    DL之Attention-ED:基于TF NMT利用带有Attention的 ED模型训练(中英文平行语料库)实现将英文翻译为中文的LSTM翻译模型过程全记录 目录 测试输出结果 模型监控 训练过程全 ...

  7. GPT模型介绍并且使用pytorch实现一个小型GPT中文闲聊系统

    文章目录 GPT模型介绍 无监督训练方式 模型结构 微调 下游任务输入形式 GPT-2 GPT-3 pytorch实现一个小型GPT中文闲聊系统 GPT模型介绍 GPT与BERT一样也是一种预训练模型 ...

  8. 《预训练周刊》第8期:首个千亿中文大模型「盘古」问世、谷歌等提出视频音频文本转换器VATT...

    No.08 智源社区 预训练组 预 训 练 研究 观点 资源 活动 关于周刊 超大规模预训练模型是当前人工智能领域研究的热点,为了帮助研究与工程人员了解这一领域的进展和资讯,智源社区整理了第8期< ...

  9. 把一个dataset的表放在另一个dataset里面_使用中文维基百科语料库训练一个word2vec模型并使用说明...

    ​本篇主要介绍如何通过中文维基百科语料库来训练一个word2vec模型. 相关资料下载: 中文维基百科下载地址:https://dumps.wikimedia.org/zhwiki/ WikiExtr ...

最新文章

  1. 死锁产生的原因和解锁的方法
  2. 设计模式(五)责任链模式
  3. Spark知识体系完整解读
  4. k8s pod MySQL环境变量_Kubernetes 配置Pod和容器(一)定义容器环境变量
  5. 显示画面 大华摄像头_大华乐橙智能锁荣获2020房地产首选供应商前十强
  6. 人工智能为什么要从本科生抓起?
  7. matlab 求解 Ax=B 时所用算法
  8. 【机器学习】输出层的设计
  9. Android 属性动画(三)
  10. 常见(MySQL)面试题(含答案)
  11. 批处理FOR 中的Delims和Tokens总结
  12. pdg文件格式 到 pdf文件格式 的转换
  13. 解决Windows照片查看器加载慢和颜色问题
  14. 《合作的进化》pdfmobiepub电子版
  15. ES stored fields作用
  16. python爬取下厨房每周最受欢迎菜谱
  17. 被顶级机构押注的6大新公链 公链之争谁更硬核?
  18. 微信小程序:开心锤锤超火动态表情包微信小程序源码下载自动采集
  19. Tecnomatiix PDPS数模数据格式转换方法
  20. 分享快速检测肖特基二极管的小窍门

热门文章

  1. 学java还是平面设计好_平面设计师精品课程
  2. raspberry pi(树莓派) + easycap d60 视频采集
  3. 第2章 带宽负担会降低人们的智商
  4. (带手机版数据同步)道路护栏交通设施类网站源码 城市基础设施设备网站织梦模板
  5. 从Linux源码看Socket(TCP)的bind
  6. python与bim_python的视图怎么调整?如何利用python进行BIM视图族类型的过滤
  7. zzw原创_非root用户启动apache的问题解决(非root用户启动apache的1024以下端口)
  8. Matlab下地形图绘图包m_map安装与使用
  9. oracle IO性能测试 -- calibrate_io
  10. 【Django | 开发】面试招聘信息网站(处理产品细节和权限美化页面样式)