公众号:深度学习视觉

前言

该工具追求着这样的一个目标,几行代码调用最先进的模型,加载训练好的模型参数,来完成自然语言项目,比如机器翻译、文本摘要、问答系统等。Transformers 同时支持 PyTorch 和TensorFlow2.0,用户可以将这些工具放在一起使用。

支持模型

transformers目前提供以下NLU / NLG体系结构:BERT、GPT、GPT-2、Transformer-XL、XLNet、XLM、RoBERTa、DistilBERT、CTRL、CamemBERT、ALBERT、T5、XLM-RoBERTa、MMBT、FlauBERT、其他社区的模型

安装PyTorch-Transformers

pip install pytorch-transformers

使用GPT-2预测下一个单词

GPT-2是一种于基于transformer的生成语言模型,其语言生成能力优秀到被讨论禁止开源。该模型是在40GB的文本下进行无监督训练。

# 导入必要的库

import torch

from pytorch_transformers import GPT2Tokenizer, GPT2LMHeadModel

# 加载预训练模型tokenizer (vocabulary)

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# 对文本输入进行编码

text = "What is the fastest car in the"

indexed_tokens = tokenizer.encode(text)

# 在PyTorch张量中转换indexed_tokens

tokens_tensor = torch.tensor([indexed_tokens])

# 加载预训练模型 (weights)

model = GPT2LMHeadModel.from_pretrained('gpt2')

#将模型设置为evaluation模式,关闭DropOut模块

model.eval()

# 如果你有GPU,把所有东西都放在cuda上

tokens_tensor = tokens_tensor.to('cuda')

model.to('cuda')

# 预测所有的tokens

with torch.no_grad():

outputs = model(tokens_tensor)

predictions = outputs[0]

# 得到预测的单词

predicted_index = torch.argmax(predictions[0, -1, :]).item()

predicted_text = tokenizer.decode(indexed_tokens + [predicted_index])

# 打印预测单词

print(predicted_text)

预测长文本

!git clone https://github.com/huggingface/pytorch-transformers.git

# 启动模型

!python pytorch-transformers/examples/run_generation.py \

--model_type=gpt2 \

--length=100 \

--model_name_or_path=gpt2 \

输入(本来是英文)

在一个令人震惊的发现中,科学家发现了一群独角兽,它们生活在安第斯山脉一个偏远的,以前未被开发的山谷中。对于研究人员而言,更令人惊讶的是,独角兽会说完美的英语。

输出(本来是英文)

独角兽似乎和普通人一样了解彼此。该研究于5月6日发表在《科学转化医学》上。此外,研究人员发现,百分之五的独角兽彼此之间具有很好的识别性。研究团队认为,这可能会转化为未来,使人类能够与称为超级独角兽的人进行更清晰的交流。如果我们要朝着那个未来前进,我们至少必须做到

除了GPT-2以外,还有诸如XLNet,一个在包括问答、自然语言推理、情感分析和文档排序等18项任务上取得了最先进结果的模型。

!python pytorch-transformers/examples/run_generation.py \

--model_type=xlnet \

--length=50 \

--model_name_or_path=xlnet-base-cased \

还有能够学习长期依赖的Transformer-XL,比标准Transformer快1800倍。

!python pytorch-transformers/examples/run_generation.py \

--model_type=transfo-xl \

--length=100 \

--model_name_or_path=transfo-xl-wt103 \

Transformers API调用示例代码(收藏)

import torch

from transformers import *

# transformer有一个统一的API

# 有10个Transformer结构和30个预训练权重模型。

#模型|分词|预训练权重

MODELS = [(BertModel, BertTokenizer, 'bert-base-uncased'),

(OpenAIGPTModel, OpenAIGPTTokenizer, 'openai-gpt'),

(GPT2Model, GPT2Tokenizer, 'gpt2'),

(CTRLModel, CTRLTokenizer, 'ctrl'),

(TransfoXLModel, TransfoXLTokenizer, 'transfo-xl-wt103'),

(XLNetModel, XLNetTokenizer, 'xlnet-base-cased'),

(XLMModel, XLMTokenizer, 'xlm-mlm-enfr-1024'),

(DistilBertModel, DistilBertTokenizer, 'distilbert-base-cased'),

(RobertaModel, RobertaTokenizer, 'roberta-base'),

(XLMRobertaModel, XLMRobertaTokenizer, 'xlm-roberta-base'),

]

# 要使用TensorFlow 2.0版本的模型,只需在类名前面加上“TF”,例如。“TFRobertaModel”是TF2.0版本的PyTorch模型“RobertaModel”

# 让我们用每个模型将一些文本编码成隐藏状态序列:

for model_class, tokenizer_class, pretrained_weights in MODELS:

# 加载pretrained模型/分词器

tokenizer = tokenizer_class.from_pretrained(pretrained_weights)

model = model_class.from_pretrained(pretrained_weights)

# 编码文本

input_ids = torch.tensor([tokenizer.encode("Here is some text to encode", add_special_tokens=True)]) # 添加特殊标记

with torch.no_grad():

last_hidden_states = model(input_ids)[0] # 模型输出是元组

# 每个架构都提供了几个类,用于对下游任务进行调优,例如。

BERT_MODEL_CLASSES = [BertModel, BertForPreTraining, BertForMaskedLM, BertForNextSentencePrediction,

BertForSequenceClassification, BertForTokenClassification, BertForQuestionAnswering]

# 体系结构的所有类都可以从该体系结构的预训练权重开始

#注意,为微调添加的额外权重只在需要接受下游任务的训练时初始化

pretrained_weights = 'bert-base-uncased'

tokenizer = BertTokenizer.from_pretrained(pretrained_weights)

for model_class in BERT_MODEL_CLASSES:

# 载入模型/分词器

model = model_class.from_pretrained(pretrained_weights)

# 模型可以在每一层返回隐藏状态和带有注意力机制的权值

model = model_class.from_pretrained(pretrained_weights,

output_hidden_states=True,

output_attentions=True)

input_ids = torch.tensor([tokenizer.encode("Let's see all hidden-states and attentions on this text")])

all_hidden_states, all_attentions = model(input_ids)[-2:]

#模型与Torchscript兼容

model = model_class.from_pretrained(pretrained_weights, torchscript=True)

traced_model = torch.jit.trace(model, (input_ids,))

# 模型和分词的简单序列化

model.save_pretrained('./directory/to/save/') # 保存

model = model_class.from_pretrained('./directory/to/save/') # 重载

tokenizer.save_pretrained('./directory/to/save/') # 保存

tokenizer = BertTokenizer.from_pretrained('./directory/to/save/') # 重载

import tensorflow as tf

import tensorflow_datasets

from transformers import *

# 从预训练模型/词汇表中加载数据集、分词器、模型

tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

model = TFBertForSequenceClassification.from_pretrained('bert-base-cased')

data = tensorflow_datasets.load('glue/mrpc')

# 准备数据集作为tf.data.Dataset的实例

train_dataset = glue_convert_examples_to_features(data['train'], tokenizer, max_length=128, task='mrpc')

valid_dataset = glue_convert_examples_to_features(data['validation'], tokenizer, max_length=128, task='mrpc')

train_dataset = train_dataset.shuffle(100).batch(32).repeat(2)

valid_dataset = valid_dataset.batch(64)

# 准备训练:编写tf.keras模型与优化,损失和学习率调度

optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)

loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')

model.compile(optimizer=optimizer, loss=loss, metrics=[metric])

# 用tf.keras.Model.fit进行测试和评估

history = model.fit(train_dataset, epochs=2, steps_per_epoch=115,

validation_data=valid_dataset, validation_steps=7)

# 在PyTorch中加载TensorFlow模型进行检查

model.save_pretrained('./save/')

pytorch_model = BertForSequenceClassification.from_pretrained('./save/', from_tf=True)

#让我们看看我们的模型是否学会了这个任务

sentence_0 = "This research was consistent with his findings."

sentence_1 = "His findings were compatible with this research."

sentence_2 = "His findings were not compatible with this research."

inputs_1 = tokenizer.encode_plus(sentence_0, sentence_1, add_special_tokens=True, return_tensors='pt')

inputs_2 = tokenizer.encode_plus(sentence_0, sentence_2, add_special_tokens=True, return_tensors='pt')

pred_1 = pytorch_model(inputs_1['input_ids'], token_type_ids=inputs_1['token_type_ids'])[0].argmax().item()

pred_2 = pytorch_model(inputs_2['input_ids'], token_type_ids=inputs_2['token_type_ids'])[0].argmax().item()

print("sentence_1 is", "a paraphrase" if pred_1 else "not a paraphrase", "of sentence_0")

print("sentence_2 is", "a paraphrase" if pred_2 else "not a paraphrase", "of sentence_0")

优秀的python库_一个优秀Python库,轻松吟诗作对写文章!相关推荐

  1. python算法工程师需要学什么_一个优秀的算法工程师必须具备哪些素质?

    导言 怎样成为一名优秀的算法工程师?这是很多从事人工智能学术研究和产品研发的同学都关心的一个问题.面对市场对人才的大量需求与供给的严重不足,以及高薪水的诱惑,越来越多的人开始学习这个方向的技术,或者打 ...

  2. python 均方误差_一个很随意的Python智能优化库,一个文件就是一个库-- PySwarm

    之前无聊做了个简单的Python智能算法库的小总结:Python智能优化算法库小汇总 .当时没注意到有一个库PySwarms是基于另外一个小库 PySwarm开发的. 这个库非常有意思,整个库只依赖N ...

  3. python 题库自动答题,自动匹配题库_如何用python写一个从题库自动匹配的答题脚本_淘题吧...

    A. web数据库题目:根据用户输入的用户名和密码于数据库中的记录是否匹配制作一个用户登录模块 http://blog.csdn.net/love_leve/article/details/43226 ...

  4. python自动控制库_一个可以自动化控制鼠标键盘的库:PyAUtoGUI

    PyAutoGUI 不知道你们有没有用过,它是一款用Python自动化控制键盘.鼠标的库.但凡是你不想手动重复操作的工作都可以用这个库来解决. 如果,我想半夜时候定时给发个微信,或者每天自动刷页面等操 ...

  5. python语言是一个优秀的面向对象语言_python是面向对象的语言吗

    Python从设计之初就已经是一门面向对象的语言,正因为如此,在Python中创建一个类和对象是很容易的. 面向对象技术简介(推荐学习:Python视频教程) 类(Class): 用来描述具有相同的属 ...

  6. 谷歌设计规范_一个优秀的UI设计师整理的设计规范汇总,收藏起来!

    写在前面很多新人在开始做移动端UI设计的时候,往往对界面的一些尺寸规范不是十分清楚,很多时候都是凭借自己的感觉和经验去绘制界面,心里并没有一个清晰的概念,导致做出来的页面总是不那么尽如人意.本文整理汇 ...

  7. python 人工智能库_人工智能与Python库的关系

    目前人工智能技术发展速度很快,也很吸引眼球.但是对于各种多如牛毛的方法,目前并米有一个可靠的精准的基准来衡量各项硬件在不同算法训练和推理的性能. 现在,不用愁了.国外的一个哥们, Andrey Ign ...

  8. python数据处理_时间序列数据处理python 库

    [注]本人原创,最初发表于CSDN,后发布于知乎.为避免误会为抄袭,特此说明 由于我热衷于机器学习在时间序列中的应用,特别是在医学检测和分类中,在尝试的过程中,一直在寻找优质的Python库(而不是从 ...

  9. python中的urllib库_七、urllib库(一)

    python2中,有urllib和urllib2两个库,在python3中统一为urllib库 它是python内置的HTTP请求库,包含了4个模块: request:最基本的HTTP请求模块,用来模 ...

最新文章

  1. win10管理员已阻止你运行此应用”解决方法
  2. arm交叉编译器gnueabi、none-eabi、arm-eabi、gnueabihf的区别
  3. sql server 2008学习2 文件和文件组
  4. pxe安装系统 ip获取错误_【图说】消防系统安装典型错误举例
  5. pcb过孔漏铜_为什么PCB板在生产中会铜线脱落?
  6. 阿里云服务器如何创建快照备份数据
  7. 物联网卡加持智能电网,发展更具优势
  8. oracle 全局搜索字符串,oracle操作字符串:拼接、替换、截取、查找 _ 学编程-免费技术教程分享平台...
  9. python3什么意思_python3是什么意思啊
  10. 摩拜前端周刊第15期
  11. 【09-06】数据结构学习笔记-图篇00
  12. Python使用requests发送post请求
  13. 中国电网计算机面试题目,国家电网面试经验
  14. 一键刷入twrp_小米红米如何正确TWRP卡刷MIUI12波兰版或者欧版等系统详细教程
  15. 常用的浏览器及其内核
  16. 基于OpenCV实现简单人脸面具、眼镜、胡须、鼻子特效(详细步骤 + 源码)
  17. 部署超级账本fabric区块可视化浏览器
  18. 如何 增删改查 XML文件中的元素
  19. OpenCV4.x图像处理实例-工地安全帽反光衣穿戴检测
  20. E4A易安卓计次循环和变量循环及数组

热门文章

  1. 敏之澳电商:拼多多打造爆款的具体步骤
  2. 什么是adsl动态拨号服务器?
  3. 无敌破坏王-高清在线观看
  4. 小米 红米5A 解BL锁教程 申请BootLoader解锁教程
  5. fbx模型导入unity,绑了骨骼加蒙皮法线就反
  6. 看看老牛是如何给陈彤写的信的
  7. [论文学习]Private traits and attributes are predictable from digital records of human behavior
  8. 输出问候语(PTA厦大慕课)
  9. 用原生js+html写一个像素鸟游戏
  10. 万人千题第一阶段报告【待继续总结】