目录

前言

先从分类说起 (run_classifeir.py文件)


前言

BERT(Bidirectional Encoder Representations from Transformers)近期提出之后,作为一个Word2Vec的替代者,其在NLP领域的11个方向大幅刷新了精度,可以说是近年来自残差网络最优突破性的一项技术了。论文的主要特点以下几点:

  1. 使用了Transformer 作为算法的主要框架,Trabsformer能更彻底的捕捉语句中的双向关系;
  2. 使用了Mask Language Model(MLM)  和 Next Sentence Prediction(NSP) 的多任务训练目标;
  3. 使用更强大的机器训练更大规模的数据,使BERT的结果达到了全新的高度,并且Google开源了BERT模型,用户可以直接使用BERT作为Word2Vec的转换矩阵并高效的将其应用到自己的任务中。

ERT的本质上是通过在海量的语料的基础上运行自监督学习方法为单词学习一个好的特征表示,所谓自监督学习是指在没有人工标注的数据上运行的监督学习。在以后特定的NLP任务中,我们可以直接使用BERT的特征表示作为该任务的词嵌入特征。所以BERT提供的是一个供其它任务迁移学习的模型,该模型可以根据任务微调或者固定之后作为特征提取器。

先扔出开源代码链接:pytorch、TensorFlow

附上关于BERT的资料汇总:BERT相关论文、文章和代码资源汇总

先从分类说起 (run_classifeir.py文件)

main函数完成了整个分类功能,它首先定义了 需要的参数,加载数据,要加载的bert预训练模型,输出路径等。。。

def main():parser = argparse.ArgumentParser()## Required parametersparser.add_argument("--data_dir",default=None,type=str,required=True,help="The input data dir. Should contain the .tsv files (or other data files) for the task.")parser.add_argument("--bert_model", default=None, type=str, required=True,help="Bert pre-trained model selected in the list: bert-base-uncased, ""bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, ""bert-base-multilingual-cased, bert-base-chinese.")parser.add_argument("--task_name",default=None,type=str,required=True,help="The name of the task to train.")parser.add_argument("--output_dir",default=None,type=str,required=True,help="The output directory where the model predictions and checkpoints will be written.")

接下来判断是否使用GPU环境

    if args.local_rank == -1 or args.no_cuda:device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")n_gpu = torch.cuda.device_count()else:torch.cuda.set_device(args.local_rank)device = torch.device("cuda", args.local_rank)n_gpu = 1# Initializes the distributed backend which will take care of sychronizing nodes/GPUstorch.distributed.init_process_group(backend='nccl')args.device = device

然后进行数据处理

    processor = processors[task_name]()output_mode = output_modes[task_name]label_list = processor.get_labels()num_labels = len(label_list)

processors方法是来自于run_classifier_dataset_utils.py这个模块的它通过task_name来索引到相应的数据集的处理函数中去,字典中 的键值就是 对应数据集的处理方法。

processors = {"cola": ColaProcessor,"mnli": MnliProcessor,"mnli-mm": MnliMismatchedProcessor,"mrpc": MrpcProcessor,"sst-2": Sst2Processor,"sts-b": StsbProcessor,"qqp": QqpProcessor,"qnli": QnliProcessor,"rte": RteProcessor,"wnli": WnliProcessor,
}

例如: ColaProcessor,它定义了一个ColaProcessor类(他是继承了它的父类DataProcessor),类中包含了 获取训练集的方法(get_train_examples), 获取测试集的方法(get_dev_examples),获取标签的方法(get_labels),和创建数据集的方法(_create_examples)。

class ColaProcessor(DataProcessor):"""Processor for the CoLA data set (GLUE version)."""def get_train_examples(self, data_dir):"""See base class."""return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")def get_dev_examples(self, data_dir):"""See base class."""return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")def get_labels(self):"""See base class."""return ["0", "1"]def _create_examples(self, lines, set_type):"""Creates examples for the training and dev sets."""examples = []for (i, line) in enumerate(lines):guid = "%s-%s" % (set_type, i)text_a = line[3]label = line[1]examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))return examples

接下来 看它的父类DataProcessor,它主要是定义了_read_tsv这个方法用于读取tsv文件(当然如果你的数据集格式是 CSV 或者其他格式的文件你也可以重写这个方法,或者添加读取其他文件的方法)。

class DataProcessor(object):"""Base class for data converters for sequence classification data sets."""def get_train_examples(self, data_dir):"""Gets a collection of `InputExample`s for the train set."""raise NotImplementedError()def get_dev_examples(self, data_dir):"""Gets a collection of `InputExample`s for the dev set."""raise NotImplementedError()def get_labels(self):"""Gets the list of labels for this data set."""raise NotImplementedError()@classmethoddef _read_tsv(cls, input_file, quotechar=None):"""Reads a tab separated value file."""with open(input_file, "r", encoding="utf-8") as f:reader = csv.reader(f, delimiter="\t", quotechar=quotechar)lines = []for line in reader:if sys.version_info[0] == 2:line = list(unicode(cell, 'utf-8') for cell in line)lines.append(line)return lines

数据加载完毕,要准备开始训练了,看最后一行代码convert examples to features

    if args.do_train:if args.local_rank in [-1, 0]:tb_writer = SummaryWriter()# Prepare data loadertrain_examples = processor.get_train_examples(args.data_dir)cached_train_features_file = os.path.join(args.data_dir, 'train_{0}_{1}_{2}'.format(list(filter(None, args.bert_model.split('/'))).pop(),str(args.max_seq_length),str(task_name)))try:with open(cached_train_features_file, "rb") as reader:train_features = pickle.load(reader)except:train_features = convert_examples_to_features(train_examples, label_list, args.max_seq_length, tokenizer, output_mode)

convert_examples_to_features这个方法主要是获取token并将token转换到字典中相对应的id

def convert_examples_to_features(examples, label_list, max_seq_length,tokenizer, output_mode):"""Loads a data file into a list of `InputBatch`s."""label_map = {label : i for i, label in enumerate(label_list)}features = []for (ex_index, example) in enumerate(examples):if ex_index % 10000 == 0:logger.info("Writing example %d of %d" % (ex_index, len(examples)))tokens_a = tokenizer.tokenize(example.text_a)tokens_b = Noneif example.text_b:tokens_b = tokenizer.tokenize(example.text_b)# Modifies `tokens_a` and `tokens_b` in place so that the total# length is less than the specified length.# Account for [CLS], [SEP], [SEP] with "- 3"_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)else:# Account for [CLS] and [SEP] with "- 2"if len(tokens_a) > max_seq_length - 2:tokens_a = tokens_a[:(max_seq_length - 2)]# The convention in BERT is:# (a) For sequence pairs:#  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]#  type_ids: 0   0  0    0    0     0       0 0    1  1  1  1   1 1# (b) For single sequences:#  tokens:   [CLS] the dog is hairy . [SEP]#  type_ids: 0   0   0   0  0     0 0## Where "type_ids" are used to indicate whether this is the first# sequence or the second sequence. The embedding vectors for `type=0` and# `type=1` were learned during pre-training and are added to the wordpiece# embedding vector (and position vector). This is not *strictly* necessary# since the [SEP] token unambiguously separates the sequences, but it makes# it easier for the model to learn the concept of sequences.## For classification tasks, the first vector (corresponding to [CLS]) is# used as as the "sentence vector". Note that this only makes sense because# the entire model is fine-tuned.tokens = ["[CLS]"] + tokens_a + ["[SEP]"]segment_ids = [0] * len(tokens)if tokens_b:tokens += tokens_b + ["[SEP]"]segment_ids += [1] * (len(tokens_b) + 1)input_ids = tokenizer.convert_tokens_to_ids(tokens)# The mask has 1 for real tokens and 0 for padding tokens. Only real# tokens are attended to.input_mask = [1] * len(input_ids)# Zero-pad up to the sequence length.padding = [0] * (max_seq_length - len(input_ids))input_ids += paddinginput_mask += paddingsegment_ids += paddingassert len(input_ids) == max_seq_lengthassert len(input_mask) == max_seq_lengthassert len(segment_ids) == max_seq_lengthif output_mode == "classification":label_id = label_map[example.label]elif output_mode == "regression":label_id = float(example.label)else:raise KeyError(output_mode)if ex_index < 5:logger.info("*** Example ***")logger.info("guid: %s" % (example.guid))logger.info("tokens: %s" % " ".join([str(x) for x in tokens]))logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))logger.info("label: %s (id = %d)" % (example.label, label_id))features.append(InputFeatures(input_ids=input_ids,input_mask=input_mask,segment_ids=segment_ids,label_id=label_id))return features

到目前为止数据加载完毕 ,接下来就要加载bert模型开始训练了(先更新到这,后续部分下次更新~~)

手撕Bert代码(torch版)相关推荐

  1. 前端面试那些你必须手撕的代码

    目录 1.call 2.apply 3.bind 4.promise.all 5.promise.race 6.promise.allsettled 7.new 8.数组扁平化 9.发布订阅模式 10 ...

  2. 数据与广告系列十一:从性别预测的CASE开始手撕机器学习代码

    作者|黄崇远(题图:ssyer.com,CCO协议)  公号,数据虫巢(ID: blogchong) " 说好的带你们手撕代码. 阅读本文预计需要...我哪知道要多久,反正有点长,看你的理解 ...

  3. asp手写签名代码2021版提供源码控件

    今天接了一个小事,一个朋友想实现货物在线签收,要收货人在线签名,并保存入库,让我帮忙写一个asp的手写签名功能,忙活一下午实现了,估计会有很多人有这种需求,放网上吧 function lineCanv ...

  4. 深度学习之手撕深度神经网络DNN代码(基于numpy)

    声明 1)本文仅供学术交流,非商用.所以每一部分具体的参考资料并没有详细对应.如果某部分不小心侵犯了大家的利益,还望海涵,并联系博主删除. 2)博主才疏学浅,文中如有不当之处,请各位指出,共同进步,谢 ...

  5. 数字IC手撕代码-泰凌微笔试真题

    前言: 本专栏旨在记录高频笔面试手撕代码题,以备数字前端秋招,本专栏所有文章提供原理分析.代码及波形,所有代码均经过本人验证. 目录如下: 1.数字IC手撕代码-分频器(任意偶数分频) 2.数字IC手 ...

  6. 【手撕算法】AC显著性检测算法

    [手撕算法]AC显著性检测算法 算法原理 论文名称: Salient Region Detection and Segmentation AC算法同样是计算每个像素的显著值,但却不是基于全局对比度,而 ...

  7. 手撕Alexnet卷积神经网络-pytorch-详细注释版(可以直接替换自己数据集)-直接放置自己的数据集就能直接跑。跑的代码有问题的可以在评论区指出,看到了会回复。训练代码和预测代码均有。

    Alexnet网络详解代码:手撕Alexnet卷积神经网络-pytorch-详细注释版(可以直接替换自己数据集)-直接放置自己的数据集就能直接跑.跑的代码有问题的可以在评论区指出,看到了会回复.训练代 ...

  8. 手撕VGG卷积神经网络-pytorch-详细注释版(可以直接替换自己数据集)-直接放置自己的数据集就能直接跑。跑的代码有问题的可以在评论区指出,看到了会回复。训练代码和预测代码均有。

    Alexnet网络详解代码:手撕Alexnet卷积神经网络-pytorch-详细注释版(可以直接替换自己数据集)-直接放置自己的数据集就能直接跑.跑的代码有问题的可以在评论区指出,看到了会回复.训练代 ...

  9. 手撕Resnet卷积神经网络-pytorch-详细注释版(可以直接替换自己数据集)-直接放置自己的数据集就能直接跑。跑的代码有问题的可以在评论区指出,看到了会回复。训练代码和预测代码均有。

    Alexnet网络详解代码:手撕Alexnet卷积神经网络-pytorch-详细注释版(可以直接替换自己数据集)-直接放置自己的数据集就能直接跑.跑的代码有问题的可以在评论区指出,看到了会回复.训练代 ...

最新文章

  1. 校园二手平台的开发和利用
  2. 亚洲游戏行业遭遇史上最大DDoS攻击,微软:我给扛下来了
  3. php 如何把u5fb,php如何将json中的unicode编码转为汉字?
  4. HDU2050 折线分割平面
  5. python中的单下划线和双下划线_python中的单下划线和双下划线
  6. IET Cyber-Systems Robotics线上研讨会:聚焦人工智能与机器人前沿
  7. GBDT和RF的区别
  8. finalshell连接失败解决方法_Windows 无法连接到SENS的解决方法
  9. 腾讯视频怎么删除收藏的内容
  10. springboot日志的实现方式(两种log4j2.properties和log4j2.yml)
  11. windows下手动安装pyinstaller(python2.7)
  12. 大数据分析方法管不管用
  13. ELK性能优化实战总结:java私塾初级模拟银源代码
  14. 用阿里云盘,找不到资源怎么办???
  15. 电脑打不开计算机考试模拟软件怎么回事,计算机等级考试模拟软件提示COMDLG32.OCX错误怎么办...
  16. 小知识:Windows XP优化全攻略(网吧型)
  17. 第二篇 界面开发 (Android学习笔记)
  18. 支付宝小程序开发+java服务
  19. 关于Git提交报warning解决方法(个人笔记)
  20. 2021年广东省安全员B证第三批(项目负责人)新版试题及广东省安全员B证第三批(项目负责人)作业模拟考试

热门文章

  1. 健身的基本知识(3)
  2. 搭档之家| 生活中的“破窗理论”
  3. 一个游戏制作的全过程
  4. 从java到女装暴走漫画_[Java教程]暴走漫画
  5. java3d点线面_3D游戏与计算机图形学中的数学方法-点线面
  6. Docker for Windows
  7. 多校训练1 A Alice and Bob 博弈
  8. 2022年吃瓜事件拆解,打造爆款,让你拥有顶级营销思维!
  9. 地铁大数据挖掘之数据预处理——从原始一卡通数据提取城市地铁客流(一)
  10. opacity - cocos js