数据集:SST-2
论文地址:https://arxiv.org/abs/1810.04805
github(pytorch): https://github.com/huggingface/pytorch-transformers
github(tensorflow): https://github.com/google-research/bert

Step 1 模型下载

下载pretrained Tensorflow model https://github.com/google-research/bert#pre-trained-models

Step 2 模型转化

将 tensorflow model 转换为 pytorch

python3 convert_tf_checkpoint_to_pytorch.py--tf_checkpoint_path $BERT_BASE_DIR/bert_model.ckpt \--bert_config_file $BERT_BASE_DIR/bert_config.json \--pytorch_dump_path $BERT_BASE_DIR/pytorch_model.bin

Step 3 代码详解

1. DataProcessor

# 读取文件的基本类
class DataProcessor(object):"""Base class for data converters for sequence classification data sets. """def get_train_examples(self, data_dir):"""Gets a collection of `InputExample`s for the train set."""raise NotImplementedError()def get_dev_examples(self, data_dir):"""Gets a collection of `InputExample`s for the dev set."""raise NotImplementedError()def get_labels(self):"""Gets the list of labels for the train set."""raise NotImplementedError()@classmethoddef _read_tsv(cls, input_file, quotechar=None):'''read a seqarated value file'''with open(input_file, 'r', encoding='utf-8-sig') as f:reader = csv.reader(f, delimiter='\t', quotechar=quotechar)lines = []for line in reader:lines.append(line)return lines
class SstProcess(DataProcessor):''' processer for SST-2 dataset '''def get_train_examples(self, data_dir):"""Gets a collection of `InputExample`s for the train set."""return self._create_examples(self._read_tsv(os.path.join(data_dir, 'train.tsv')), 'train')def get_dev_examples(self, data_dir):"""Gets a collection of `InputExample`s for the dev set."""return self._create_examples(self._read_tsv(os.path.join(data_dir, 'dev.tsv')), 'dev')def get_labels(self):"""Gets the list of labels for the train set.""""""SST-2"""return ['0','1']def _create_examples(self, lines, set_type):''' Create examples for the training and dev sets'''examples = []for i, line in enumerate(lines):if i == 0:continueguid = '%s-%s' % (set_type, i)text_a = line[0]label = line[1]examples.append(InputExample(guid=guid, text_a=text_a, label=label))return examples

2. convert_examples_to_features

def convert_examples_to_features(examples, label_list, max_seq_length,tokenizer, output_mode,cls_token_at_end=False, pad_on_left=False,cls_token='[CLS]', sep_token='[SEP]', pad_token=0,sequence_a_segment_id=0, sequence_b_segment_id=1,cls_token_segment_id=1, pad_token_segment_id=0,mask_padding_with_zero=True):"""Loads a data file into a list of `InputBatch`s.Args:examples: InputExample, 表示样本集label_list: 标签列表max_seq_length: 句子最大长度tokenizer: 分词器Returns:features: InputFeatures, 表示样本转化后信息"""label_map = {label:i for i, label in enumerate(label_list)}features = []for (ex_index, example) in enumerate(examples):if ex_index % 10000 == 0:logger.info("Writing example %d of %d" % (ex_index, len(examples)))tokens_a = tokenizer.tokenize(example.text_a)tokens_b = Noneif example.text_b:tokens_b = tokenizer.tokenize(example.text_b)# Modifies `tokens_a` and `tokens_b` in place so that the total# length is less than the specified length.# Account for [CLS], [SEP], [SEP] with "- 3"_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)else:# Account for [CLS] and [SEP] with "- 2"# 此处因为只有CLS 和SEP 即token_a & label 没有 token_b 所以-2if len(tokens_a) > max_seq_length - 2:tokens_a = tokens_a[:(max_seq_length - 2)]# [CLS] 可以视作是保存句子全局向量信息# [SEP] 用于区分句子,使得模型能够更好的把握句子信息tokens = tokens_a + [sep_token]segment_ids = [sequence_a_segment_id] * len(tokens)if tokens_b:tokens += tokens_b + [sep_token]segment_ids += [sequence_b_segment_id] * (len(tokens_b) + 1)# CLS在句子的前面还是后面 bert 在前面 xlnet在后面if cls_token_at_end:tokens = tokens + [cls_token]segment_ids = segment_ids + [cls_token_segment_id]else:tokens = [cls_token] + tokenssegment_ids = [cls_token_segment_id] + segment_idsinput_ids = tokenizer.convert_tokens_to_ids(tokens)# The mask has 1 for real tokens and 0 for padding tokens. Only real# tokens are attended to.input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)# Zero-pad up to the sequence length.padding_length = max_seq_length - len(input_ids)# PAD在句子的左边还是右边 bert的都在右边 xlnet在左边if pad_on_left:input_ids = ([pad_token] * padding_length) + input_idsinput_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_masksegment_ids = ([pad_token_segment_id] * padding_length) + segment_idselse:input_ids = input_ids + ([pad_token] * padding_length)input_mask = input_mask + ([0 if mask_padding_with_zero else 1] * padding_length)segment_ids = segment_ids + ([pad_token_segment_id] * padding_length)assert len(input_ids) == max_seq_lengthassert len(input_mask) == max_seq_lengthassert len(segment_ids) == max_seq_lengthif output_mode == "classification":label_id = label_map[example.label]elif output_mode == "regression":label_id = float(example.label)else:raise KeyError(output_mode)features.append(InputFeatures(input_ids=input_ids,input_mask=input_mask,segment_ids=segment_ids,label_id=label_id))return featuresdef _truncate_seq_pair(tokens_a, tokens_b, max_length):# """截断句子a和句子b,使得二者之和不超过 max_length"""# 此处可以改进 25% 75% 效果更好"""Truncates a sequence pair in place to the maximum length."""while True:total_length = len(tokens_a) + len(tokens_b)if total_length <= max_length:breakif len(tokens_a) > len(tokens_b):tokens_a.pop()else:tokens_b.pop()

3. bert family

BertConfig

config = BertConfig.from_pretrained(bert_config_path, num_label=2, finetuning_task='sst-2')

BertTokenizer

tokenizer = BertTokenizer.from_pretrained(bert_model_path, do_lower_case=True) #如果使用uncase的模型 选择True 否则选择False

BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained(bert_model_path, config=config)

AdamW & WarmupLinearSchedule

no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=args.adam_epsilon)scheduler = WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps, t_total=t_total)

此处注意 t_total 这个值

t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

gradient_accumulation_steps 为 做梯度累积的值 一般为3-8 相当于过这么多次进行一次清零 等于将batch_size扩大了n倍 可以节约显存
num_train_epochs 就是一共需要训练的epochs次数

Bert—SST-2相关推荐

  1. NLP专栏|图解 BERT 预训练模型!

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:张贤,哈尔滨工程大学,Datawhale原创作者 本文约7000字 ...

  2. BERT中的黑暗秘密

    2020-01-30 17:00:34 作者:Anna Rogers 编译:ronghuaiyang 导读 在finetune BERT的时候发生了什么? 这篇博客文章总结了我们EMNLP 2019年 ...

  3. 图解当前最强语言模型BERT:NLP是如何攻克迁移学习的?

    选自jalammar.github.io 作者:Jay Alammar 机器之心编译 参与:Panda 前段时间,谷歌发布了基于双向 Transformer 的大规模预训练语言模型 BERT,该预训练 ...

  4. 迁移学习NLP:BERT、ELMo等直观图解

    2018年是自然语言处理的转折点,能捕捉潜在意义和关系的方式表达单词和句子的概念性理解正在迅速发展.此外,NLP社区已经出现了非常强大的组件,你可以在自己的模型和管道中自由下载和使用(它被称为NLP的 ...

  5. 【NLP】图解 BERT 预训练模型!

    作者:张贤,哈尔滨工程大学,Datawhale原创作者 本文约7000字,NLP专栏文章,建议收藏阅读 审稿人:Jepson,Datawhale成员,毕业于中国科学院,目前在腾讯从事推荐算法工作. 结 ...

  6. AI基础:一文看懂BERT

    0.导语 自google在2018年10月底公布BERT在11项nlp任务中的卓越表现后,BERT(Bidirectional Encoder Representation from Transfor ...

  7. 论文浅尝 | 使用孪生BERT网络生成句子的嵌入表示

    论文笔记整理:吴杨,浙江大学计算机学院,知识图谱.NLP方向. https://www.ctolib.com/https://arxiv.org/abs/1908.10084 动机 谷歌的 BERT ...

  8. 基于Transformers库的BERT模型:一个文本情感分类的实例解析

    简介 本文来讲述BERT应用的一个例子,采用预训练好的BERT模型来进行演示.BERT的库来源于Transformers,这是一个由PyTorch编写的库,其集成了多个NLP领域SOTA的模型,比如b ...

  9. Bert入门:使用Bert运行MRPC的demo成功案例

    一.tensorflow版本必须是2.0以下 我的版本 import sys import numpy as np import tensorflow as tf print('python版本是:' ...

  10. BERT: 理解上下文的语言模型

    BERT 全名为 Bidrectional Encoder Representations from Transformers, 是 Google 以无监督的方式利用大量无标注文本生成的语言代表模型, ...

最新文章

  1. Codeforces Round #649 (Div. 2)C. Ehab and Prefix MEXs[排列的构造]
  2. 内存溢出_容易造成单片机内存溢出的几个陷阱
  3. B/S,C/S简单介绍
  4. Android开发中adb命令的常用方法
  5. 今日推荐:如何设计一个支撑数亿用户的系统
  6. 一文带你看懂分布式软总线在家庭场景的应用
  7. react-redux中的持久化数据存储redux-persist
  8. js获取页面的各种高度与宽度
  9. python运维监控脚本_Python实现数通设备端口使用情况监控实例
  10. ANR 问题一般解决思路
  11. BZOJ4012[HNOI2015]开店——树链剖分+可持久化线段树/动态点分治+vector
  12. Keli Proteus 8 的使用教程
  13. logback 自定义PatternLayout
  14. 如何用算法预测世界杯?
  15. “人生就像滚雪球,重要的是发现很湿的雪和很长的坡。”+复利的力量
  16. Unity小游戏教程系列 | 创建小型太空射击游戏(4)
  17. 2019年云计算发展趋势,今年十大云计算趋势
  18. Java:基础 :集合和迭代器
  19. (转)Django新手需要注意的10个要点
  20. labelmx条码打印软件如何批量制作服装吊牌

热门文章

  1. 用 Java 生成和识别二维码就这么简单
  2. python电影爬虫背景介绍_python爬虫-爬虫电影八佰词云
  3. CSAPP期末复习(更新ing)
  4. 项目整体管理(6个过程:制定项目章程,制定项目管理计划,指导与管理项目工作,实施整体变更控制,结束项目或阶段)
  5. 西门子200smart模拟量滤波防抖PLC程序,能实现电流电压和热电阻模拟量信号的采集
  6. 基于android的电子词典设计_基于Android的电子词典的设计
  7. CynosDB 与传统数据库有什么不同?CynosDB 的兼容性怎么样?
  8. 无线临ftp服务器1.3,Cerberus FTP Server Enterprise(FTP服务器管理工具)V11.3.1.1 最新版
  9. yum 安装mysql 后 which is not functionally dependent on columns in GROUP BY clause; this is incompatibl
  10. 开学季准备什么蓝牙耳机好?五款性价比高的蓝牙耳机品牌推荐