文章目录

  • 1、准备用于训练的数据集
  • 2、处理数据集
  • 3、克隆代码
  • 4、运行代码
  • 5、将ckpt模型转为bin模型使其可在pytorch中运用

Bert官方仓库:https://github.com/google-research/bert

1、准备用于训练的数据集

此处准备的是BBC news的数据集,下载链接:https://www.kaggle.com/datasets/gpreda/bbc-news
原数据集格式(.csv):

2、处理数据集

训练Bert时需要预处理数据,将数据处理成https://github.com/google-research/bert/blob/master/sample_text.txt中所示格式,如下所示:

数据预处理代码参考:

import pandas as pd# 读取BBC-news数据集
df = pd.read_csv("../../bbc_news.csv")
# print(df['title'])
l1 = []
l2 = []
cnt = 0
for line in df['title']:l1.append(line)for line in df['description']:l2.append(line)
# cnt=0
f = open("test1.txt", 'w+', encoding='utf8')
for i in range(len(l1)):s = l1[i] + " " + l2[i] + '\n'f.write(s)# cnt+=1# if cnt>10: break
f.close()
# print(l1)

处理完后的BBC news数据集格式如下所示:

3、克隆代码

使用git克隆仓库代码
http:

git clone https://github.com/google-research/bert.git

或ssh:

git clone git@github.com:google-research/bert.git

4、运行代码

先下载Bert模型:BERT-Base, Uncased
该文件中有以下文件:

运行代码:
在Teminal中运行:

python create_pretraining_data.py \--input_file=./sample_text.txt(数据集地址) \--output_file=/tmp/tf_examples.tfrecord(处理后数据集保存的位置) \--vocab_file=$BERT_BASE_DIR/vocab.txt(vocab.txt文件位置) \--do_lower_case=True \--max_seq_length=128 \--max_predictions_per_seq=20 \--masked_lm_prob=0.15 \--random_seed=12345 \--dupe_factor=5

训练模型:

python run_pretraining.py \--input_file=/tmp/tf_examples.tfrecord(处理后数据集保存的位置) \--output_dir=/tmp/pretraining_output(训练后模型保存位置) \--do_train=True \--do_eval=True \--bert_config_file=$BERT_BASE_DIR/bert_config.json(bert_config.json文件位置) \--init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt(如果要从头开始的预训练,则去掉这行) \--train_batch_size=32 \--max_seq_length=128 \--max_predictions_per_seq=20 \--num_train_steps=20 \--num_warmup_steps=10 \--learning_rate=2e-5

训练完成后模型输出示例:

***** Eval results *****global_step = 20loss = 0.0979674masked_lm_accuracy = 0.985479masked_lm_loss = 0.0979328next_sentence_accuracy = 1.0next_sentence_loss = 3.45724e-05

要注意应该能够在至少具有 12GB RAM 的 GPU 上运行,不然会报错显存不足。
使用未标注数据训练BERT

5、将ckpt模型转为bin模型使其可在pytorch中运用

上一步训练好后准备好训练出来的model.ckpt-20.index文件和Bert模型中的bert_config.json文件

创建python文件convert_bert_original_tf_checkpoint_to_pytorch.py:

# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BERT checkpoint."""import argparseimport torchfrom transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
from transformers.utils import logginglogging.set_verbosity_info()def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):# Initialise PyTorch modelconfig = BertConfig.from_json_file(bert_config_file)print("Building PyTorch model from configuration: {}".format(str(config)))model = BertForPreTraining(config)# Load weights from tf checkpointload_tf_weights_in_bert(model, config, tf_checkpoint_path)# Save pytorch-modelprint("Save PyTorch model to {}".format(pytorch_dump_path))torch.save(model.state_dict(), pytorch_dump_path)if __name__ == "__main__":parser = argparse.ArgumentParser()# Required parametersparser.add_argument("--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path.")parser.add_argument("--bert_config_file",default=None,type=str,required=True,help="The config json file corresponding to the pre-trained BERT model. \n""This specifies the model architecture.",)parser.add_argument("--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model.")args = parser.parse_args()convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)

在Terminal中运行以下命令:

python convert_bert_original_tf_checkpoint_to_pytorch.py \
--tf_checkpoint_path Models/chinese_L-12_H-768_A-12/bert_model.ckpt.index(.ckpt.index文件位置) \
--bert_config_file Models/chinese_L-12_H-768_A-12/bert_config.json(bert_config.json文件位置)  \
--pytorch_dump_path  Models/chinese_L-12_H-768_A-12/pytorch_model.bin(输出的.bin模型文件位置)

以上命令最好在一行中运行:

python convert_bert_original_tf_checkpoint_to_pytorch.py --tf_checkpoint_path bert_model.ckpt.index --bert_config_file bert_config.json  --pytorch_dump_path  pytorch_model.bin

然后就可以得到bin文件了

【BERT for Tensorflow】本地ckpt文件的BERT使用

使用无标注的数据训练Bert相关推荐

  1. 用物理学和域知识训练“无标注样本的”神经网络( Outstanding Paper Award 优秀论文奖)

    2017的优秀论文奖:Label-Free Supervision of Neural Networks with Physics and Domain Knowledge. 这篇论文可以运用到自动驾 ...

  2. 99.99%准确率!AI数据训练工具No.1来自中国

    萧箫 发自 凹非寺 量子位 报道 | 公众号 QbitAI 这年头,真是什么样的数据集都有了. IBM的5亿行代码(bug)数据集.清华&阿里的460万少样本NER数据集.还有假货数据集.&q ...

  3. 高效利用无标注数据:自监督学习简述

    一只小狐狸带你解锁 炼丹术&NLP 秘籍 作者:huyber 来源:https://zhuanlan.zhihu.com/p/108906502 BERT的大热让自监督学习成为了大家讨论的热点 ...

  4. 无标注数据是鸡肋还是宝藏?阿里工程师这样用它

    阿里妹导读:针对业务场景中标注数据不足.大量的无标注数据又难以有效利用的问题,我们提出了一种面向行为序列数据的深度学习风控算法 Auto Risk,提出通过代理任务从无标注数据中学习通用的特征表示.这 ...

  5. 无标注数据是鸡肋还是宝藏?阿里工程师这样用它​

    阿里妹导读:针对业务场景中标注数据不足.大量的无标注数据又难以有效利用的问题,我们提出了一种面向行为序列数据的深度学习风控算法 Auto Risk,提出通过代理任务从无标注数据中学习通用的特征表示.这 ...

  6. 新思路!商汤开源利用无标注数据大幅提高精度的人脸识别算法

    出处"来自微信公众号:我爱计算机视觉" 新思路!商汤开源利用无标注数据大幅提高精度的人脸识别算法 这篇论文解决的问题与现实中的人脸识别应用场景密切相关,其假设已经有了少量已经标注的 ...

  7. 13亿参数,无标注预训练实现SOTA:Facebook提出自监督CV新模型

    作者|陈萍 来源|机器之心 Facebook AI 用 10 亿张来自Instagram的随机.未标注图像预训练了一个参数量达 13 亿的自监督模型 SEER,该模型取得了自监督视觉模型的新 SOTA ...

  8. 实操教程|用不需要手工标注分割的训练数据来进行图像分割

    点上方计算机视觉联盟获取更多干货 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:作者丨Siddhartha Chandra 来源丨AI公园 AI博士笔记系列推荐 周志华<机器学习> ...

  9. Question | 标注下数据、训练个模型,商用的智能鉴黄有这么简单吗?

    "Question"为网易易盾的问答栏目,将会解答和呈现安全领域大家常见的问题和困惑.如果你有什么疑惑,也欢迎通过邮件(zhangyong02@corp.netease.com)提 ...

最新文章

  1. 一年春事,桃花红了谁……
  2. php 自动签到源码,我也来发个自动签到脚本,PHP版
  3. [小米OJ] 4. 最长连续数列
  4. 如何获取网站的ico图标呢?
  5. 华为不可参与 IEEE 审稿但可继续提供赞助;谷歌限制 Chrome 接口惹非议;Mozilla 号召用户换火狐 | 开发者周刊...
  6. highlightjs 详解
  7. python编程语言-为什么我不建议你将python作为入门编程语言
  8. Vue.js 使用 Swiper.js 在 iOS 11 时出现错误
  9. C# UrlEncoding
  10. excel 连接 mysql_Excel 数据库连接
  11. 爬虫实战1:爬取哔哩哔哩主播的头像以昵称命名保存到本地文件
  12. python 文件格式转换_如何把txt文件转换成py文件
  13. 归纳偏执_防御性编程:足够偏执
  14. 大华平台linux密码,大华DSS平台低权限账户越权直接修改system密码
  15. android 极光推送解绑,app集成极光推送笔记(angular js)
  16. IC、FPGA验证学习
  17. UV杀菌灯芯片-DLT8P65SA-杰力科创
  18. MySQL复制表结构,表数据。
  19. LED照明灯具与传感器技术
  20. 临平职高计算机专业高职考大学,临平职高的她考了全省第一!还有很多省前10都在这里!...

热门文章

  1. 发的很好就就飞突然他
  2. 为Linux安置红旗紫光输入法
  3. VeryCD的名言集锦
  4. java 使用Guava的RateLimiter做接口限流+redis的lua脚本做IP防刷
  5. 让CPU占用率曲线听你指挥
  6. 数据库开发工程师-招聘技能要求
  7. 高效过滤器是如何过滤的
  8. 小程序直播如何做好直播带货
  9. 周期矩形波的傅里叶级数展开(Matlab代码实现)
  10. Substrate 技术及生态6月大事记 | Polkadot Decoded 圆满落幕,黑客松获胜项目为生态注入新生力量