文章目录

  • HuggingFace Transformers
    • Tokenizer
    • Model
    • 下游任务

HuggingFace Transformers

使用BERT和其他各类Transformer模型,绕不开HuggingFace提供的Transformers生态。HuggingFace提供了各类BERT的API(transformers库)、训练好的模型(HuggingFace Hub)还有数据集(datasets)。最初,HuggingFace用PyTorch实现了BERT,并提供了预训练的模型,后来。越来越多的人直接使用HuggingFace提供好的模型进行微调,将自己的模型共享到HuggingFace社区。HuggingFace的社区越来越庞大,不仅覆盖了PyTorch版,还提供TensorFlow版,主流的预训练模型都会提交到HuggingFace社区,供其他人使用。

使用transformers库进行微调,主要包括:

  • Tokenizer:使用提供好的Tokenizer对原始文本处理,得到Token序列;
  • 构建模型:在提供好的模型结构上,增加下游任务所需预测接口,构建所需模型;
  • 微调:将Token序列送入构建的模型,进行训练。

Tokenizer

下面两行代码会创建 BertTokenizer,并将所需的词表加载进来。首次使用这个模型时,transformers 会帮我们将模型从HuggingFace Hub下载到本地。

from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

用得到的tokenizer进行分词:

encoded_input = tokenizer("我是一句话")
print(encoded_input)

输出

{'input_ids': [101, 2769, 3221, 671, 1368, 6413, 102],
'token_type_ids': [0, 0, 0, 0, 0, 0, 0],
'attention_mask': [1, 1, 1, 1, 1, 1, 1]}

得到的一个Python dict。其中,input_ids最容易理解,它表示的是句子中的每个Token在词表中的索引数字。词表(Vocabulary)是一个Token到索引数字的映射。可以使用decode()方法,将索引数字转换为Token。

 tokenizer.decode(encoded_input["input_ids"])

输出

'[CLS] 我 是 一 句 话 [SEP]'

可以看到,BertTokenizer在给原始文本处理时,自动给文本加上了[CLS]和[SEP]这两个符号,分别对应在词表中的索引数字为101和102。decode()之后,也将这两个符号反向解析出来了。

token_type_ids主要用于句子对,比如下面的例子,两个句子通过[SEP]分割,0表示Token对应的input_ids属于第一个句子,1表示Token对应的input_ids属于第二个句子。不是所有的模型和场景都用得上token_type_ids。

encoded_input = tokenizer("您贵姓?", "免贵姓李")
print(encoded_input)
{'input_ids': [101, 2644, 6586, 1998, 136, 102, 1048, 6586, 1998, 3330, 102],
'token_type_ids': [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
text = '[CLS]'
print('Original Text : ', text)
print('Tokenized Text: ', tokenizer.tokenize(text))
print('Token IDs     : ', tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)))
print('\n')text = '[SEP]'
print('Original Text : ', text)
print('Tokenized Text: ', tokenizer.tokenize(text))
print('Token IDs     : ', tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)))
print('\n')text = '[PAD]'
print('Original Text : ', text)
print('Tokenized Text: ', tokenizer.tokenize(text))
print('Token IDs     : ', tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)))
Original Text :  [CLS]
Tokenized Text:  ['[CLS]']
Token IDs     :  [101]Original Text :  [SEP]
Tokenized Text:  ['[SEP]']
Token IDs     :  [102]Original Text :  [PAD]
Tokenized Text:  ['[PAD]']
Token IDs     :  [0]

句子通常是变长的,多个句子组成一个Batch时,attention_mask就起了至关重要的作用。

batch_sentences = ["我是一句话", "我是另一句话", "我是最后一句话"]
batch = tokenizer(batch_sentences, padding=True, return_tensors="pt")
print(batch)
{'input_ids': tensor([[ 101, 2769, 3221,  671, 1368, 6413,  102,    0,    0],[ 101, 2769, 3221, 1369,  671, 1368, 6413,  102,    0],[ 101, 2769, 3221, 3297, 1400,  671, 1368, 6413,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0],[1, 1, 1, 1, 1, 1, 1, 1, 0],[1, 1, 1, 1, 1, 1, 1, 1, 1]])}

对于这种batch_size = 3的场景,不同句子的长度是不同的,padding=True表示短句子的结尾会被填充[PAD]符号,return_tensors="pt"表示返回PyTorch格式的Tensor。attention_mask告诉模型,哪些Token需要被模型关注而加入到模型训练中,哪些Token是被填充进去的无意义的符号,模型无需关注。

Model

下面两行代码会创建BertModel,并将所需的模型参数加载进来。

from transformers import BertModel
model = BertModel.from_pretrained("bert-base-chinese")

BertModel是一个PyTorch中用来包裹网络结构的torch.nn.Module,BertModel里有forward()方法,forward()方法中实现了将Token转化为词向量,再将词向量进行多层的Transformer Encoder的复杂变换。
forward()方法的入参有input_ids、attention_mask、token_type_ids等等,这些参数基本上是刚才Tokenizer部分的输出。

bert_output = model(input_ids=batch['input_ids'])

forward()方法返回模型预测的结果,返回结果是一个tuple(torch.FloatTensor),即多个Tensor组成的tuple。tuple默认返回两个重要的Tensor:

len(bert_output)
2
  • last_hidden_state:输出序列每个位置的语义向量,形状为:(batch_size, sequence_length, hidden_size)。
  • pooler_output:[CLS]符号对应的语义向量,经过了全连接层和tanh激活;该向量可用于下游分类任务。

下游任务

BERT可以进行很多下游任务,transformers库中实现了一些下游任务,我们也可以参考transformers中的实现,来做自己想做的任务。比如单文本分类,transformers库提供了BertForSequenceClassification类。

class BertForSequenceClassification(BertPreTrainedModel):def __init__(self, config):super().__init__(config)self.num_labels = config.num_labelsself.config = configself.bert = BertModel(config)classifier_dropout = ...self.dropout = nn.Dropout(classifier_dropout)self.classifier = nn.Linear(config.hidden_size, config.num_labels)...def forward(...):...outputs = self.bert(...)pooled_output = outputs[1]pooled_output = self.dropout(pooled_output)logits = self.classifier(pooled_output)...

在这段代码中,BertForSequenceClassification在BertModel基础上,增加了nn.Dropout和nn.Linear层,在预测时,将BertModel的输出放入nn.Linear,完成一个分类任务。除了BertForSequenceClassification,还有BertForQuestionAnswering用于问答,BertForTokenClassification用于序列标注,比如命名实体识别。

transformers 中的各个API还有很多其他参数设置,比如得到每一层Transformer Encoder的输出等等,可以访问他们的文档查看使用方法。

如何使用HuggingFace训练Transformer相关推荐

  1. BERT(预训练Transformer模型)

    目录 一.前言 二.随机遮挡,进行预测 三.两句话是否原文相邻 四.两者结合起来 五.总结 六.参考链接 一.前言 Bert在18年提出,19年发表,Bert的目的是为了预训练Transformer模 ...

  2. Achuan读论文:用于远程监督关系抽取的微调预训练transformer语言模型

    Fine-tuning Pre-Trained Transformer Language Models to Distantly Supervised Relation Extraction 用于远程 ...

  3. 文本生成与自动摘要:基于生成式预训练Transformer的实现与优化

    作者:禅与计算机程序设计艺术 1.简介 文本生成是自然语言处理领域中非常重要的问题之一.在不断地探索学习新知识和技能的同时,越来越多的人也需要通过自己创造或整合的手段,将自己的想法.观点和信息转化成语 ...

  4. 【NLP】第 18 章从零开始训练 Transformer

  5. 赠书 | 新手指南——如何通过HuggingFace Transformer整合表格数据

    作者 | Ken Gu 翻译| 火火酱~,责编 | 晋兆雨 出品 | AI科技大本营 头图 | 付费下载于视觉中国 *文末有赠书福利 不可否认,Transformer-based模型彻底改变了处理非结 ...

  6. 基于Huggingface的预训练语言模型分类体系及实战

    一.预训练语言模型体系初印象 1.预训练模型体系 随着预训练模型被提出,自然语言处理领域有了突飞猛进的发展,通过在大规模文本中训练通用的语言表示,并用微调的方式进行下游任务的领域适应,绝大多数的自然语 ...

  7. 预训练图像处理Transformer

    机器之心发布 机器之心编辑部 作为自然语言处理领域的主流模型,Transformer 近期频频出现在计算机视觉领域的研究中.例如 OpenAI 的 iGPT.Facebook 提出的 DETR 等,这 ...

  8. 预训练图像处理Transformer:刷榜多项底层视觉任务

    来源|机器之心 作为自然语言处理领域的主流模型,Transformer 近期频频出现在计算机视觉领域的研究中.例如 OpenAI 的 iGPT.Facebook 提出的 DETR 等,这些跨界模型多应 ...

  9. ICML2020 | 伯克利提出大模型提升Transformer的训练和推理效率

    点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 仅作分享,不代表本公众号立场,侵权联系删除 转载于:专知 AI博士笔记系列推荐 周志华<机器学习>手推笔 ...

最新文章

  1. 各方评论《面向儿童的人工智能北京共识》:这是中国人工智能发展轨迹中的一份关键文件...
  2. 刚获20亿投资的通用无人车Cruise估值达300亿美元,叫板Waymo!
  3. V3S中默认时区设置(笔记)
  4. mysql内部_使用mysql中的内部加入
  5. J2EE软件开发视频教程
  6. 漫步最优化三十——非精确线搜索
  7. 缓存类java_用Java写一个简单的缓存操作类
  8. JavaScript in Action
  9. axure后台示例_【Axure电商案例】如何设计和真的后台一样给客户看
  10. (大数据应用考察)全国水资源分析可视化
  11. 项目Beta冲刺(4/7)(追光的人)(2019.5.26)
  12. Comparator.comparing排序使用示例
  13. 别被别人抢走了你的时间
  14. 固定资产拆分比例怎么计算_资产拆分
  15. matlab函数anova,MATLAB进行单因素方差分析——ANOVA
  16. 计算机错误651是什么故障,宽带连接错误显示代码651是什么原因 宽带连接错误651的解决方法...
  17. 科技爱好者周刊(第 114 期):U 盘化生存和 Uber-job
  18. 吞下西甲英超中超成体育大胃王,PPTV还有哪些大招?
  19. 【原创】从头开始,使用安卓系统WebView做一个功能强大的Epub阅读器(二)
  20. setTimeout原理

热门文章

  1. 解决springboot maven多模块项目打包的时候某个被依赖的模块报错找不到main class
  2. 警惕! ”黑帽子“的社会工程学攻击
  3. 如何把DOC文档以网页的形式打开
  4. [codeforces 1327E] Count The Blocks 打表找规律+根据规律找公式+优化公式
  5. (爱思创题解)小书童——刷题大军
  6. DevSuite如何助力企业实施GJB5000A
  7. 阿里OSS自定义域名无法访问。404报错
  8. 漫天风口,一地泡沫,消费机器人四年跌宕史
  9. android极光推送声音,解决极光推送后台接收到推送消息,无提示音不震动
  10. 五线谱里震音记号是什么?