前面一篇blog finetune一个GPT3,借助openai的api finetune了一个GPT3,使用下来确实太贵,生成了1w条数据,花掉了60多美刀。肉痛,所以穷人只能想想穷人的办法,脑子就浮现出好朋友EleutherAI的GPT-Neo来。github上有两个项目GPT-Neo和GPT-NeoX,下图来自高仿也赛高?GPT-Neo真好用

之前一个blog,我是基于openai davinci做的微调,从这个角度来说,根据github上对gpt-neox的描述,得采用GPT-NeoX是最好的选择

We aim to make this repo a centralized and accessible place to gather techniques for training large-scale autoregressive language models, and accelerate research into large-scale training. Additionally, we hope to train and open source a 175B parameter GPT-3 replication along the way. Please note, however, that this is a research codebase that is primarily designed for performance over ease of use. We endeavour to make it as easy to use as is feasible, but if there's anything in the readme that is unclear or you think you've found a bug, please open an issue.

不过这个实在太大,恐怕我的3090没能力撑起这个场子,GPT-Neo也是可以将就将就的。

剩余的也就是准备数据(同之前的json),再写几行简单的finetuning code,网上东拼西凑,再debug一下,获得如下代码,仅供参考。gpt-neo-1.3B 1000条数据finetune 10个周期效果也不够好(train_data_prepared.jsonl),最终我采用openai获得的9000多条数据训练一个周期后,效果就接近openai davinci finetune的结果了。

import sys
import json
import pandas as pd
import torch
from torch.utils.data import Dataset, random_split
from transformers import GPT2Tokenizer, TrainingArguments, Trainer, AutoModelForCausalLM#, GPTNeoXTokenizerFasttorch.manual_seed(42)
modelname = "EleutherAI/gpt-neo-1.3B"
tokenizer = GPT2Tokenizer.from_pretrained(f"{modelname}", bos_token='<|startoftext|>',eos_token='<|endoftext|>', pad_token='<|pad|>')
#special_tokens_dict = {'sep_token': '\n##\n', 'unk_token': '\n%%\n'}
#num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
#model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b")
#tokenizer = GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b")
model = AutoModelForCausalLM.from_pretrained(f"{modelname}").cuda()
for name, param in model.transformer.wte.named_parameters():param.requires_grad = False
for name, param in model.transformer.wpe.named_parameters():param.requires_grad = False
for name, param in model.transformer.h[:16].named_parameters():param.requires_grad = False
model.resize_token_embeddings(len(tokenizer))
prompts = []
completions = []'''
f = open("train_data_prepared.jsonl", "r")
lines = [line.strip() for line in f.readlines()]
f.close()
for line in lines:data = json.loads(line)if len(data['prompt'].split(' ')) + len(data['completion'].split(' ')) > 45:continueprompts.append(data['prompt'])completions.append(data['completion'])
'''
f = open("prompt_instrcut_completion.txt", "r")
lines = [line.strip() for line in f.readlines()]
f.close()
for line in lines:items = line.split('|')if len(items) != 3:continueprompt,completion = items[0]+"\n##\n",items[1]+"\n%%\n"+items[2]+"\nEND"if len(prompt.split(' ')) + len(completion.split(' ')) > 45:continueprompts.append(prompt)completions.append(completion)max_length = 50
print("prompts length", len(prompts), "max length", max_length)class InstructDataset(Dataset):def __init__(self, prompts, completions, tokenizer, max_length):self.input_ids = []self.attn_masks = []for prompt, completion in zip(prompts, completions):encodings_dict = tokenizer(f'<|startoftext|>prompt:{prompt}\ncompletion:{completion}<|endoftext|>', truncation=True,max_length=max_length, padding="max_length")self.input_ids.append(torch.tensor(encodings_dict['input_ids']))self.attn_masks.append(torch.tensor(encodings_dict['attention_mask']))def __len__(self):return len(self.input_ids)def __getitem__(self, idx):return self.input_ids[idx], self.attn_masks[idx]dataset = InstructDataset(prompts, completions, tokenizer, max_length=max_length)
train_size = int(0.9 * len(dataset))
train_dataset, val_dataset = random_split(dataset, [train_size, len(dataset) - train_size])
training_args = TrainingArguments(output_dir='./results', num_train_epochs=1, logging_steps=400, save_steps=500,per_device_train_batch_size=2, per_device_eval_batch_size=2,warmup_steps=100, weight_decay=0.01, logging_dir='./logs')
Trainer(model=model, args=training_args, train_dataset=train_dataset,eval_dataset=val_dataset, data_collator=lambda data: {'input_ids': torch.stack([f[0] for f in data]),'attention_mask': torch.stack([f[1] for f in data]),'labels': torch.stack([f[0] for f in data])}).train()
generated = tokenizer("<|startoftext|>a dog\n##\n", return_tensors="pt").input_ids.cuda()
sample_outputs = model.generate(generated, do_sample=True, top_k=50, max_length=50, top_p=0.95, temperature=0.7, frequency_penalty=0.1, num_return_sequences=20)
for i, sample_output in enumerate(sample_outputs):print("{}: {}".format(i, tokenizer.decode(sample_output, skip_special_tokens=True)))

参考:

Guide to fine-tuning Text Generation models: GPT-2, GPT-Neo and T5

EleutherAI GPT-Neo: 穷人的希望相关推荐

  1. 莆田版GPT-3开源:同等复现预训练模型GPT Neo

    GPT-3开源了?Eleuther AI推出的名为GPT-Neo的开源项目:公开发布的GPT-3同等复现预训练模型(1.3B & 2.7B),可在Colab上完成微调. --当然此 GPT-3 ...

  2. linux命令管理GPT分区,Linux磁盘管理GPT分区教程

    Linux内核代码量大.逻辑关系复杂,因此对内核中的错误进行追溯和调试一直以来都是一件既耗费时间又耗费精力的事情.接下来是小编为大家收集的 Linux磁盘管理GPT分区教程,希望能帮到大家. Linu ...

  3. 穷人靠力,富人借力!看完你将明白一切!

    有个穷人,因为吃不饱穿不暖,而在佛祖面前痛哭流涕,诉说生活的艰苦,天天干活累的半死却挣不来几个钱. 哭了半晌他突然开始埋怨道:"这个社会太不公平了,为什么富人天天悠闲自在,而穷人就应该天天吃 ...

  4. 富人和穷人的八大差异

    ●自我认知 穷人:很少想到如何去赚钱和如何才能赚到钱,认为自己一辈子就该这样,不相信会有什么改变. 富人:骨子里就深信自己生下来不是要做穷人,而是要做富人,他有强烈的赚钱意识 ,这也是他血液里的东西, ...

  5. 决定你是富人还是穷人的12条定律

    1. 自我认知 穷人:很少想到如何去赚钱和如何才能赚到钱,认为自己一辈子就该这样,不相信会有什么改变. 富人:骨子里就深信自己生下来不是要做穷人,而是要做富人,他有强烈的赚钱意识,这也是他血液里的东西 ...

  6. 杂谈--穷人和富人的区别(觉得很有道理,需要反思自己的日常行为)

    1.自我认知 穷人:很少想到如何去赚钱和如何才能赚到钱,认为自己一辈子就该这样,不相信会有什么改变. 富人:骨子里就深信自己生下来不是要做穷人,而是要做富人,他有强烈的赚钱意识,这也是他血液里的东西, ...

  7. 穷人跟懒人 富人跟勤快人

    1.自我认知 穷人:很少想到如何去赚钱和如何才能赚到钱,认为自己一辈子就该这样,不相信会有什么改变. 富人:骨子里就深信自己生下来不是要做穷人,而是要做富人,他有强烈的赚钱意识,这也是他血液里的东西, ...

  8. 穷人 与 富人 思维 比较

    穷人与富人思维比较 1.自我认知 穷人:很少想到如何去赚钱和如何才能赚到钱,认为自己一辈子就该这样,不相信会有什么改变. 富人:骨子里就深信自己生下来不是要做穷人,而是要做富人,他有强烈的赚钱意识,这 ...

  9. 4t gpt索引 linux如何挂载,LINUX教学:Ubuntu 16.04通过GPT挂载硬盘

    <LINUX教学:Ubuntu 16.04通过GPT挂载硬盘>要点: 本文介绍了LINUX教学:Ubuntu 16.04通过GPT挂载硬盘,希望对您有用.如果有疑问,可以联系我们. 记录下 ...

最新文章

  1. 浏览器允许跨域设置(不用于生产环境,开发用)
  2. 多线程—Thread类及线程三种创建方式及对比
  3. c语言十六进制转换加H,c语言十六进制和十进制间的转换.docx
  4. Redis(七):Hash哈希数据类型详解
  5. Magento Helper简介
  6. Cracking the Coding Interview(Stacks and Queues)
  7. jQuery 的一个自动向上翻页的效果
  8. Spring Boot + Log4j2 日志框架配置 (Maven)
  9. table设置width无效
  10. [2019上海网络赛J题]Stone game
  11. idea导入一个工程后只显示pom文件_P1搭建第一个springboot应用
  12. HSPICE求导语句
  13. 计算机网络实验二 VLAN间路由
  14. 苹果台式机怎么设置我的电脑计算机,怎么让台式电脑用苹果手机的wifi上网
  15. Excel学习笔记(1)——数据类型,自动填充,数据有效性,美化
  16. Anaconda3\Scripts\activate.bat 不是内部或外部命令,也不是可运行的程序的问题处理方法
  17. 爬虫学习(14):selenium自动化测试(三):鼠标和键盘操作
  18. 学计算机有哪些推荐书籍?
  19. lisp语言画地物符号_LISP语言在CAD工程制图中的应用_谢威
  20. Leetcode-LCP 25. 古董键盘

热门文章

  1. 一味研究:岩石股份摘帽即收“两连板”,海银系要认真酿酒了吗?
  2. foxmail登入密码忘记怎么办?
  3. switch 注册哪个服务器,switch注册教程
  4. 五金模具设计常见的二十一块模板作用,一起学起来
  5. JMeter源码学习- 5.0版本源码本地构建
  6. Android studio 图片按钮
  7. win10搭FTP与单片机通信:配置+编程实现的完整流程
  8. mac菜单栏设置教程
  9. 视频虚化边框剪辑技巧分享
  10. jquery.countdown.js一个时间倒计时的插件