Colossal-AI Reproduction Walkthrough

  • 1 Environment Setup
    • 1.1 CUDA Environment
    • 1.2 Python Environment
    • 1.3 Python Package Environment
  • 2 Downloading the Code
  • 3 Model Training
    • 3.1 SFT (Supervised Fine-Tuning)
      • 3.1.1 Command
      • 3.1.2 Log
    • 3.2 Training the Reward Model
      • 3.2.1 Command
      • 3.2.2 Log
    • 3.3 RL (Training the Model on Prompts with RL)
      • 3.3.1 Command
      • 3.3.2 Log
    • 3.4 Inference
      • 3.4.1 Code
      • 3.4.2 Demo
  • 4 References

1 Environment Setup

1.1 CUDA Environment

root@LAPTOP-3SUHS40U:/home/work/ColossalAI# lsb_release -a
LSB Version:    core-11.1.0ubuntu2-noarch:security-11.1.0ubuntu2-noarch
Distributor ID: Ubuntu
Description:    Ubuntu 20.04.6 LTS
Release:        20.04
Codename:       focal
root@LAPTOP-3SUHS40U:/home/work/ColossalAI# nvidia-smi
Sun May 21 05:56:52 2023
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.50                 Driver Version: 531.79       CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 2060         On | 00000000:01:00.0  On |                  N/A |
| N/A   47C    P8                7W /  N/A|   5760MiB /  6144MiB |     10%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A        23      G   /Xwayland                                 N/A      |
|    0   N/A  N/A    306013      C   /python3.10                               N/A      |
+---------------------------------------------------------------------------------------+

1.2 Python Environment

root@LAPTOP-3SUHS40U:/home/work/ColossalAI# python --version
Python 3.10.10

1.3 Python Package Environment

root@LAPTOP-3SUHS40U:/home/work/ColossalAI# pip list
Package                       Version          Editable project location
----------------------------- ---------------- ------------------------------------------------
absl-py                       1.4.0
aiofiles                      23.1.0
aiohttp                       3.8.4
aiosignal                     1.3.1
alabaster                     0.7.13
altair                        5.0.0
anyio                         3.6.2
appdirs                       1.4.4
arrow                         1.2.3
astunparse                    1.6.3
async-timeout                 4.0.2
attrs                         23.1.0
Babel                         2.12.1
bcrypt                        4.0.1
benepar                       0.2.0
blessed                       1.20.0
blis                          0.7.9
boltons                       23.0.0
brotlipy                      0.7.0
cachetools                    5.3.0
catalogue                     2.0.8
certifi                       2022.12.7
cffi                          1.15.1
cfgv                          3.3.1
charset-normalizer            2.0.4
click                         8.1.3
cmake                         3.26.3
coati                         1.0.0
colossalai                    0.2.8
conda                         23.3.1
conda-content-trust           0.1.3
conda-package-handling        2.0.2
conda_package_streaming       0.7.0
confection                    0.0.4
contexttimer                  0.3.3
coverage                      7.2.5
cryptography                  39.0.1
cycler                        0.11.0
cymem                         2.0.7
Cython                        0.29.34
dataclasses-json              0.5.7
datasets                      2.12.0
diffusers                     0.16.1
dill                          0.3.6
distlib                       0.3.6
distro                        1.8.0
dnspython                     2.3.0
docker                        6.0.1
docker-pycreds                0.4.0
docstring-parser              0.8.1
docutils                      0.18.1
einops                        0.6.1
et-xmlfile                    1.1.0
exceptiongroup                1.1.1
expecttest                    0.1.4
fabric                        3.0.1
fastapi                       0.95.1
fbgemm-gpu                    0.4.1
ffmpy                         0.3.0
filelock                      3.12.0
flash-attn                    0.1
fonttools                     4.39.3
frozenlist                    1.3.3
fsspec                        2023.4.0
gitdb                         4.0.10
GitPython                     3.1.31
google-auth                   2.17.3
google-auth-oauthlib          0.4.6
gpustat                       1.1
gradio                        3.30.0
gradio_client                 0.2.4
greenlet                      2.0.2
grpcio                        1.54.0
h11                           0.14.0
httpcore                      0.17.0
httpx                         0.24.0
huggingface-hub               0.14.1
hypothesis                    6.75.1
identify                      2.5.23
idna                          3.4
imagesize                     1.4.1
importlib-metadata            6.6.0
iniconfig                     2.0.0
invoke                        2.1.1
iopath                        0.1.10
Jinja2                        3.1.2
joblib                        1.2.0
jsonpatch                     1.32
jsonpointer                   2.1
jsonschema                    4.17.3
kiwisolver                    1.4.4
langchain                     0.0.161
langcodes                     3.3.0
libcst                        0.4.9
linkify-it-py                 2.0.2
loralib                       0.1.1
Markdown                      3.4.3
markdown-it-py                2.2.0
MarkupSafe                    2.1.2
marshmallow                   3.19.0
marshmallow-enum              1.5.1
matplotlib                    3.5.3
mdit-py-plugins               0.3.3
mdurl                         0.1.2
moreorless                    0.4.0
mpmath                        1.3.0
multidict                     6.0.4
multiprocess                  0.70.14
murmurhash                    1.0.9
mypy-extensions               1.0.0
myst-parser                   0.18.1
networkx                      3.1
ninja                         1.11.1
nltk                          3.8.1
nodeenv                       1.7.0
numexpr                       2.8.4
numpy                         1.24.3
nvidia-cublas-cu11            11.10.3.66
nvidia-cuda-nvrtc-cu11        11.7.99
nvidia-cuda-runtime-cu11      11.7.99
nvidia-cudnn-cu11             8.5.0.96
nvidia-ml-py                  11.525.112
oauthlib                      3.2.2
openai                        0.24.0
openapi-schema-pydantic       1.2.4
openpyxl                      3.1.2
orjson                        3.8.12
packaging                     23.0
pandas                        2.0.1
pandas-stubs                  2.0.1.230501
paramiko                      3.1.0
pathspec                      0.11.1
pathtools                     0.1.2
pathy                         0.10.1
Pillow                        9.5.0
pip                           23.0.1
platformdirs                  3.5.0
plotly                        5.14.1
pluggy                        1.0.0
portalocker                   2.7.0
pre-commit                    3.3.1
preshed                       3.0.8
protobuf                      3.19.6
psutil                        5.9.5
pyarrow                       12.0.0
pyasn1                        0.5.0
pyasn1-modules                0.3.0
pycosat                       0.6.4
pycparser                     2.21
pydantic                      1.10.7
pyDeprecate                   0.3.2
pydub                         0.25.1
Pygments                      2.15.1
PyNaCl                        1.5.0
pyOpenSSL                     23.0.0
pyparsing                     3.0.9
pyre-extensions               0.0.27
pyrsistent                    0.19.3
PySocks                       1.7.1
pytest                        7.3.1
pytest-cov                    4.0.0
python-dateutil               2.8.2
python-etcd                   0.4.5
python-multipart              0.0.6
pytorch-sphinx-theme          0.0.24           /home/work/pytorch/docs/src/pytorch-sphinx-theme
pytorch-triton                2.1.0+7d1a95b046
pytz                          2023.3
PyYAML                        6.0
regex                         2023.5.4
requests                      2.27.1
requests-oauthlib             1.3.1
responses                     0.18.0
rich                          13.3.5
rouge-score                   0.1.2
rsa                           4.9
ruamel.yaml                   0.17.21
ruamel.yaml.clib              0.2.6
safetensors                   0.3.1
scikit-build                  0.17.3
semantic-version              2.10.0
sentencepiece                 0.1.99
sentry-sdk                    1.22.1
setproctitle                  1.3.2
setuptools                    65.6.3
six                           1.16.0
smart-open                    6.3.0
smmap                         5.0.0
sniffio                       1.3.0
snowballstemmer               2.2.0
sortedcontainers              2.4.0
spacy                         3.5.2
spacy-legacy                  3.0.12
spacy-loggers                 1.0.4
Sphinx                        5.0.0
sphinx-copybutton             0.5.0
sphinx-panels                 0.4.1
sphinxcontrib-applehelp       1.0.4
sphinxcontrib-devhelp         1.0.2
sphinxcontrib-htmlhelp        2.0.1
sphinxcontrib-jsmath          1.0.1
sphinxcontrib-katex           0.8.6
sphinxcontrib-qthelp          1.0.3
sphinxcontrib-serializinghtml 1.1.5
SQLAlchemy                    2.0.12
srsly                         2.4.6
sse-starlette                 1.5.0
starlette                     0.26.1
stdlibs                       2022.10.9
sympy                         1.11.1
tabulate                      0.9.0
tenacity                      8.2.2
tensorboard                   2.10.0
tensorboard-data-server       0.6.1
tensorboard-plugin-wit        1.8.1
thinc                         8.1.10
timm                          0.6.13
titans                        0.0.7
tokenizers                    0.13.3
toml                          0.10.2
tomli                         2.0.1
toolz                         0.12.0
torch                         1.13.1           /root/miniconda3/lib/python3.10/site-packages
torch-struct                  0.5
torchaudio                    0.13.1
torchmetrics                  0.11.4
torchrec                      0.2.0
torchvision                   0.14.1
torchx-nightly                2023.5.3
tqdm                          4.65.0
trailrunner                   1.4.0
transformers                  4.28.0.dev0
typer                         0.7.0
types-dataclasses             0.6.6
types-pytz                    2023.3.0.0
typing_extensions             4.5.0
typing-inspect                0.8.0
tzdata                        2023.3
uc-micro-py                   1.0.2
urllib3                       1.26.15
usort                         1.0.6
uvicorn                       0.22.0
virtualenv                    20.23.0
wandb                         0.15.2
wasabi                        1.1.1
wcwidth                       0.2.6
websocket-client              1.5.1
websockets                    11.0.3
Werkzeug                      2.3.3
wheel                         0.38.4
xxhash                        3.2.0
yarl                          1.9.2
zipp                          3.15.0
zstandard                     0.19.0

2 Downloading the Code

git clone https://github.com/hpcaitech/ColossalAI.git
cd ColossalAI
CUDA_EXT=1 pip install .

3 Model Training

3.1 SFT (Supervised Fine-Tuning)

3.1.1 Command

torchrun --standalone --nproc_per_node=1 train_sft.py \
    --pretrain "/mnt/f/kangpengtao/study/ColossalAI/bigscience/bloom-560m/" \
    --model 'bloom' \
    --strategy naive \
    --log_interval 10 \
    --save_path /mnt/f/kangpengtao/study/ColossalAI/Coati-7B \
    --dataset /mnt/f/kangpengtao/study/ColossalAI/InstructionWild/data/instinwild.json \
    --batch_size 1 \
    --accumulation_steps 8 \
    --lr 2e-5 \
    --max_datasets_size 16384 \
    --max_epochs 1 \
    --lora_rank 16
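The `--lora_rank 16` flag means the pretrained weights stay frozen and only a rank-16 update B·A is trained on top of each adapted projection matrix W. As a back-of-the-envelope sketch (illustrative arithmetic, not Coati's actual code), the parameter savings for one square projection at bloom-560m's hidden size of 1024:

```python
# Parameter count: full fine-tuning vs. a rank-16 LoRA update
# on a single d x d linear layer (bloom-560m hidden size is 1024).
d_in = d_out = 1024
rank = 16  # --lora_rank 16

full_params = d_out * d_in                 # train the whole matrix W
lora_params = rank * d_in + d_out * rank   # train A (r x d_in) and B (d_out x r)

print(full_params)                 # 1048576
print(lora_params)                 # 32768
print(full_params / lora_params)   # 32.0x fewer trainable parameters
```

This is what makes SFT of the 560m model feasible on a single 6 GB RTX 2060.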

3.1.2 Log

root@LAPTOP-3SUHS40U:/home/work/ColossalAI/applications/Chat/examples# ./train_sft_bloom_kpt.sh
[05/16/23 15:54:24] INFO     colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/sft_dataset.py:121 __init__
                    INFO     colossalai - colossalai - INFO: Loading data...
[05/16/23 15:54:25] INFO     colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/sft_dataset.py:123 __init__
                    INFO     colossalai - colossalai - INFO: Loaded 103695 examples.
                    INFO     colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/sft_dataset.py:126 __init__
                    INFO     colossalai - colossalai - INFO: Limiting dataset to 16384 examples.
                    INFO     colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/sft_dataset.py:129 __init__
                    INFO     colossalai - colossalai - INFO: Formatting inputs...
                    INFO     colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/sft_dataset.py:137 __init__
                    INFO     colossalai - colossalai - INFO: Tokenizing inputs... This may take some time...
steps:   0%|                                                                              | 0/2048 [00:00<?, ?it/s]
[05/16/23 15:54:40] WARNING  colossalai - colossalai - WARNING: /root/miniconda3/lib/python3.10/site-packages/coati/trainer/sft.py:86 fit
                    WARNING  colossalai - colossalai - WARNING: batch_id:7, abnormal loss: 2.74609375
steps: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2048/2048 [36:00<00:00,  1.05s/it]
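The 2048 steps in the progress bar follow directly from the arguments above: 16384 examples, batch size 1, and 8 gradient-accumulation steps give 16384 / (1 × 8) = 2048 optimizer updates per epoch.

```python
# Effective-batch arithmetic behind the "2048 steps" in the SFT log.
max_datasets_size = 16384   # --max_datasets_size
batch_size = 1              # --batch_size
accumulation_steps = 8      # --accumulation_steps

effective_batch = batch_size * accumulation_steps
steps_per_epoch = max_datasets_size // effective_batch
print(steps_per_epoch)  # 2048
```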

3.2 Training the Reward Model

3.2.1 Command

set_n_least_used_CUDA_VISIBLE_DEVICES() {
    local n=${1:-"9999"}
    echo "GPU Memory Usage:"
    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
        | tail -n +2 \
        | nl -v 0 \
        | tee /dev/tty \
        | sort -g -k 2 \
        | awk '{print $1}' \
        | head -n $n)
    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
    echo "Now CUDA_VISIBLE_DEVICES is set to:"
    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}

set_n_least_used_CUDA_VISIBLE_DEVICES 1

torchrun --standalone --nproc_per_node=1 train_reward_model.py \
    --pretrain '/mnt/f/kangpengtao/study/ColossalAI/Coati-7B' \
    --model 'bloom' \
    --strategy naive \
    --loss_fn 'log_sig' \
    --dataset 'Anthropic/hh-rlhf' \
    --save_path '/mnt/f/kangpengtao/study/ColossalAI/rm-static.pt' \
    --lora_rank 16 \
    --batch_size 1 \
    --max_len 512
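`--loss_fn 'log_sig'` selects the standard pairwise ranking loss for reward models: -log σ(r_chosen − r_rejected), which is small when the reward model scores the human-preferred response higher than the rejected one. A minimal sketch of the formula (plain Python for illustration, not the coati implementation):

```python
import math

def log_sig_loss(r_chosen, r_rejected):
    """Pairwise reward-model loss: -log(sigmoid(r_chosen - r_rejected)).
    Approaches 0 as the chosen response outscores the rejected one;
    equals ln(2) when the model expresses no preference."""
    diff = r_chosen - r_rejected
    return -math.log(1.0 / (1.0 + math.exp(-diff)))

print(round(log_sig_loss(2.0, 0.0), 4))  # 0.1269 (chosen clearly preferred)
print(round(log_sig_loss(0.0, 0.0), 4))  # 0.6931 (no preference, ln 2)
```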

3.2.2 Log

root@LAPTOP-3SUHS40U:/home/work/ColossalAI/applications/Chat/examples# ./train_rm_bloom_kpt.sh
GPU Memory Usage:
0  228 MiB
Now CUDA_VISIBLE_DEVICES is set to:
CUDA_VISIBLE_DEVICES=0
Some weights of the model checkpoint at /mnt/f/kangpengtao/study/ColossalAI/Coati-7B were not used when initializing BloomModel: ['transformer.h.23.mlp.dense_4h_to_h.lora_B', 'transformer.h.11.mlp.dense_4h_to_h.lora_B', 'transformer.h.18.self_attention.dense.lora_B', 'transformer.h.6.mlp.dense_4h_to_h.lora_B', 'transformer.h.11.mlp.dense_4h_to_h.lora_A', 'transformer.h.11.self_attention.query_key_value.lora_B', 'transformer.h.21.self_attention.query_key_value.lora_B', 'transformer.h.22.mlp.dense_h_to_4h.lora_A', 'transformer.h.20.mlp.dense_h_to_4h.lora_A', 'lm_head.lora_A', 'transformer.h.22.self_attention.dense.lora_A', 'transformer.h.6.self_attention.query_key_value.lora_A', 'transformer.h.14.mlp.dense_h_to_4h.lora_A', 'transformer.h.5.mlp.dense_h_to_4h.lora_B', 'transformer.h.9.mlp.dense_h_to_4h.lora_A', 'transformer.h.20.self_attention.query_key_value.lora_A', 'transformer.h.6.mlp.dense_h_to_4h.lora_B', 'transformer.h.20.mlp.dense_4h_to_h.lora_A', 'transformer.h.19.self_attention.dense.lora_A', 'transformer.h.17.self_attention.query_key_value.lora_B', 'transformer.h.0.mlp.dense_4h_to_h.lora_B', 'transformer.h.13.self_attention.query_key_value.lora_A', 'transformer.h.23.mlp.dense_h_to_4h.lora_B', 'transformer.h.23.self_attention.query_key_value.lora_B', 'transformer.h.9.self_attention.dense.lora_B', 'transformer.h.23.mlp.dense_4h_to_h.lora_A', 'transformer.h.9.self_attention.query_key_value.lora_B', 'transformer.h.0.mlp.dense_h_to_4h.lora_A', 'transformer.h.19.mlp.dense_4h_to_h.lora_A', 'transformer.h.4.self_attention.dense.lora_B', 'transformer.h.16.mlp.dense_h_to_4h.lora_B', 'transformer.h.14.self_attention.dense.lora_B', 'transformer.h.4.mlp.dense_4h_to_h.lora_B', 'transformer.h.16.self_attention.query_key_value.lora_B', 'transformer.h.5.mlp.dense_4h_to_h.lora_B', 'transformer.h.10.mlp.dense_4h_to_h.lora_A', 'transformer.h.18.self_attention.query_key_value.lora_A', 'transformer.h.12.self_attention.query_key_value.lora_A', 
'transformer.h.22.self_attention.dense.lora_B', 'transformer.h.1.mlp.dense_4h_to_h.lora_A', 'transformer.h.6.self_attention.dense.lora_A', 'transformer.h.13.mlp.dense_h_to_4h.lora_B', 'transformer.h.2.self_attention.dense.lora_A', 'transformer.h.1.self_attention.query_key_value.lora_A', 'transformer.h.12.mlp.dense_h_to_4h.lora_A', 'transformer.h.4.self_attention.query_key_value.lora_B', 'transformer.h.13.self_attention.query_key_value.lora_B', 'lm_head.lora_B', 'transformer.h.17.mlp.dense_4h_to_h.lora_B', 'transformer.h.12.mlp.dense_h_to_4h.lora_B', 'transformer.h.23.self_attention.query_key_value.lora_A', 'transformer.h.21.mlp.dense_4h_to_h.lora_B', 'transformer.h.20.mlp.dense_h_to_4h.lora_B', 'transformer.h.15.self_attention.dense.lora_B', 'transformer.h.11.self_attention.dense.lora_B', 'transformer.h.2.self_attention.query_key_value.lora_B', 'transformer.h.16.mlp.dense_h_to_4h.lora_A', 'transformer.h.5.self_attention.query_key_value.lora_A', 'transformer.h.16.self_attention.query_key_value.lora_A', 'transformer.h.23.self_attention.dense.lora_B', 'transformer.h.13.self_attention.dense.lora_B', 'transformer.h.6.self_attention.dense.lora_B', 'transformer.h.13.self_attention.dense.lora_A', 'transformer.h.1.self_attention.query_key_value.lora_B', 'transformer.h.22.self_attention.query_key_value.lora_A', 'transformer.h.15.mlp.dense_4h_to_h.lora_B', 'transformer.h.12.self_attention.dense.lora_B', 'transformer.h.2.self_attention.dense.lora_B', 'transformer.h.15.mlp.dense_4h_to_h.lora_A', 'transformer.h.1.mlp.dense_h_to_4h.lora_B', 'transformer.h.11.self_attention.dense.lora_A', 'transformer.h.19.self_attention.query_key_value.lora_A', 'transformer.h.4.self_attention.dense.lora_A', 'transformer.h.14.self_attention.query_key_value.lora_B', 'transformer.h.9.self_attention.dense.lora_A', 'transformer.h.22.mlp.dense_h_to_4h.lora_B', 'transformer.h.15.self_attention.dense.lora_A', 'transformer.h.0.mlp.dense_4h_to_h.lora_A', 
'transformer.h.3.self_attention.query_key_value.lora_B', 'transformer.h.17.mlp.dense_4h_to_h.lora_A', 'transformer.h.22.self_attention.query_key_value.lora_B', 'transformer.h.7.self_attention.dense.lora_B', 'transformer.h.5.mlp.dense_4h_to_h.lora_A', 'transformer.h.10.self_attention.query_key_value.lora_A', 'transformer.h.22.mlp.dense_4h_to_h.lora_B', 'transformer.h.7.self_attention.query_key_value.lora_A', 'transformer.h.2.mlp.dense_h_to_4h.lora_A', 'transformer.h.20.self_attention.dense.lora_A', 'transformer.h.15.mlp.dense_h_to_4h.lora_B', 'transformer.h.11.mlp.dense_h_to_4h.lora_A', 'transformer.h.0.self_attention.dense.lora_A', 'transformer.h.3.mlp.dense_h_to_4h.lora_A', 'transformer.h.19.mlp.dense_h_to_4h.lora_B', 'transformer.h.5.mlp.dense_h_to_4h.lora_A', 'transformer.h.3.self_attention.dense.lora_A', 'transformer.h.10.mlp.dense_4h_to_h.lora_B', 'transformer.h.9.self_attention.query_key_value.lora_A', 'transformer.h.8.self_attention.dense.lora_B', 'transformer.h.12.mlp.dense_4h_to_h.lora_A', 'transformer.h.19.mlp.dense_4h_to_h.lora_B', 'transformer.h.6.self_attention.query_key_value.lora_B', 'transformer.h.9.mlp.dense_4h_to_h.lora_B', 'transformer.h.13.mlp.dense_4h_to_h.lora_B', 'transformer.h.12.self_attention.query_key_value.lora_B', 'transformer.h.16.mlp.dense_4h_to_h.lora_A', 'transformer.h.8.mlp.dense_h_to_4h.lora_A', 'transformer.h.5.self_attention.dense.lora_B', 'transformer.h.17.self_attention.query_key_value.lora_A', 'transformer.h.9.mlp.dense_4h_to_h.lora_A', 'transformer.h.10.self_attention.dense.lora_A', 'transformer.h.1.mlp.dense_h_to_4h.lora_A', 'transformer.h.21.self_attention.query_key_value.lora_A', 'transformer.h.10.mlp.dense_h_to_4h.lora_A', 'transformer.h.15.mlp.dense_h_to_4h.lora_A', 'transformer.h.8.mlp.dense_h_to_4h.lora_B', 'transformer.h.21.mlp.dense_h_to_4h.lora_A', 'transformer.h.7.self_attention.dense.lora_A', 'transformer.h.16.self_attention.dense.lora_B', 'transformer.h.17.self_attention.dense.lora_A', 
'transformer.h.20.mlp.dense_4h_to_h.lora_B', 'transformer.h.15.self_attention.query_key_value.lora_B', 'transformer.h.22.mlp.dense_4h_to_h.lora_A', 'transformer.h.18.self_attention.query_key_value.lora_B', 'transformer.h.13.mlp.dense_h_to_4h.lora_A', 'transformer.h.4.mlp.dense_4h_to_h.lora_A', 'transformer.h.1.mlp.dense_4h_to_h.lora_B', 'transformer.h.15.self_attention.query_key_value.lora_A', 'transformer.h.11.self_attention.query_key_value.lora_A', 'transformer.h.3.self_attention.dense.lora_B', 'transformer.h.4.self_attention.query_key_value.lora_A', 'transformer.h.0.self_attention.dense.lora_B', 'transformer.h.13.mlp.dense_4h_to_h.lora_A', 'transformer.h.3.mlp.dense_4h_to_h.lora_B', 'transformer.h.5.self_attention.dense.lora_A', 'transformer.h.10.self_attention.dense.lora_B', 'transformer.h.23.mlp.dense_h_to_4h.lora_A', 'transformer.h.3.self_attention.query_key_value.lora_A', 'transformer.h.21.self_attention.dense.lora_B', 'transformer.h.17.mlp.dense_h_to_4h.lora_B', 'transformer.h.2.mlp.dense_4h_to_h.lora_B', 'transformer.h.8.self_attention.query_key_value.lora_A', 'transformer.h.11.mlp.dense_h_to_4h.lora_B', 'transformer.h.3.mlp.dense_4h_to_h.lora_A', 'transformer.h.18.mlp.dense_4h_to_h.lora_B', 'transformer.h.16.mlp.dense_4h_to_h.lora_B', 'transformer.h.21.mlp.dense_h_to_4h.lora_B', 'transformer.h.2.mlp.dense_4h_to_h.lora_A', 'transformer.h.7.mlp.dense_h_to_4h.lora_B', 'transformer.h.18.mlp.dense_4h_to_h.lora_A', 'transformer.h.8.mlp.dense_4h_to_h.lora_A', 'transformer.h.4.mlp.dense_h_to_4h.lora_B', 'transformer.h.7.self_attention.query_key_value.lora_B', 'transformer.h.9.mlp.dense_h_to_4h.lora_B', 'transformer.h.21.self_attention.dense.lora_A', 'transformer.h.20.self_attention.dense.lora_B', 'transformer.h.8.self_attention.query_key_value.lora_B', 'transformer.h.8.self_attention.dense.lora_A', 'transformer.h.2.mlp.dense_h_to_4h.lora_B', 'transformer.h.14.mlp.dense_4h_to_h.lora_A', 'transformer.h.8.mlp.dense_4h_to_h.lora_B', 
'transformer.h.6.mlp.dense_4h_to_h.lora_A', 'transformer.h.7.mlp.dense_4h_to_h.lora_B', 'transformer.h.18.mlp.dense_h_to_4h.lora_B', 'transformer.h.17.self_attention.dense.lora_B', 'transformer.h.0.self_attention.query_key_value.lora_A', 'transformer.h.7.mlp.dense_4h_to_h.lora_A', 'transformer.h.12.mlp.dense_4h_to_h.lora_B', 'transformer.h.10.self_attention.query_key_value.lora_B', 'transformer.h.19.self_attention.query_key_value.lora_B', 'transformer.h.21.mlp.dense_4h_to_h.lora_A', 'transformer.h.14.mlp.dense_h_to_4h.lora_B', 'transformer.h.14.self_attention.dense.lora_A', 'transformer.h.4.mlp.dense_h_to_4h.lora_A', 'transformer.h.18.self_attention.dense.lora_A', 'transformer.h.7.mlp.dense_h_to_4h.lora_A', 'transformer.h.19.self_attention.dense.lora_B', 'transformer.h.19.mlp.dense_h_to_4h.lora_A', 'transformer.h.17.mlp.dense_h_to_4h.lora_A', 'transformer.h.18.mlp.dense_h_to_4h.lora_A', 'transformer.h.0.self_attention.query_key_value.lora_B', 'transformer.h.14.mlp.dense_4h_to_h.lora_B', 'transformer.h.16.self_attention.dense.lora_A', 'transformer.h.10.mlp.dense_h_to_4h.lora_B', 'lm_head.weight', 'transformer.h.1.self_attention.dense.lora_B', 'transformer.h.3.mlp.dense_h_to_4h.lora_B', 'transformer.h.23.self_attention.dense.lora_A', 'transformer.h.5.self_attention.query_key_value.lora_B', 'transformer.h.0.mlp.dense_h_to_4h.lora_B', 'transformer.h.2.self_attention.query_key_value.lora_A', 'transformer.h.14.self_attention.query_key_value.lora_A', 'transformer.h.1.self_attention.dense.lora_A', 'transformer.h.12.self_attention.dense.lora_A', 'transformer.h.6.mlp.dense_h_to_4h.lora_A', 'transformer.h.20.self_attention.query_key_value.lora_B']
- This IS expected if you are initializing BloomModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BloomModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Found cached dataset json (/root/.cache/huggingface/datasets/json/hh-rlhf-226e0526113c616f/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 383.92it/s]
Parameter 'indices'=<generator object train.<locals>.<genexpr> at 0x7fdf71767ca0> of the transform datasets.arrow_dataset.Dataset.select couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. This warning is only showed once. Subsequent hashing failures won't be showed.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 160800/160800 [05:40<00:00, 472.16it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1710/1710 [00:04<00:00, 422.55it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8552/8552 [00:19<00:00, 433.97it/s]
Train step of epoch 0: 100%|███████████████████████████████████████████████████████████████████| 160800/160800 [100:45:04<00:00,  2.26s/it, dist=nan, acc=0]
Train epoch: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [100:45:04<00:00, 362704.33s/it]
root@LAPTOP-3SUHS40U:/home/work/ColossalAI/applications/Chat/examples#
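The wall-clock figure in the log is self-consistent: 160800 pairwise steps at the displayed average of about 2.26 s/it is roughly 100 hours on this single RTX 2060, matching the 100:45:04 total.

```python
# Sanity check on the reward-model training time reported above.
steps = 160800       # one step per hh-rlhf training pair
sec_per_it = 2.26    # average from the progress bar
hours = steps * sec_per_it / 3600
print(round(hours, 1))  # ~100.9, in line with the logged 100:45:04
```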

3.3 RL (Training the Model on Prompts with RL)

3.3.1 Command

set_n_least_used_CUDA_VISIBLE_DEVICES() {
    local n=${1:-"9999"}
    echo "GPU Memory Usage:"
    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
        | tail -n +2 \
        | nl -v 0 \
        | tee /dev/tty \
        | sort -g -k 2 \
        | awk '{print $1}' \
        | head -n $n)
    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
    echo "Now CUDA_VISIBLE_DEVICES is set to:"
    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}

set_n_least_used_CUDA_VISIBLE_DEVICES 1

# torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy colossalai_zero2
# torchrun --standalone --nproc_per_node=2 train_prompts.py --prompt_dataset /path/to/data.json --strategy colossalai_zero2

torchrun --standalone --nproc_per_node=1 train_prompts.py \
    --prompt_dataset '/mnt/f/kangpengtao/study/ColossalAI/InstructionWild/data/prompts.json' \
    --pretrain_dataset '/mnt/f/kangpengtao/study/ColossalAI/InstructionWild/data/instinwild.json' \
    --strategy naive \
    --model bloom \
    --pretrain '/mnt/f/kangpengtao/study/ColossalAI/bigscience/bloom-560m' \
    --rm_path '/mnt/f/kangpengtao/study/ColossalAI/rm-static.pt' \
    --rm_pretrain '/mnt/f/kangpengtao/study/ColossalAI/Coati-7B' \
    --save_path '/mnt/f/kangpengtao/study/ColossalAI/prompts-static.pt' \
    --max_epochs 1 \
    --num_episodes 1 \
    --train_batch_size 1 \
    --ptx_batch_size 1 \
    --experience_batch_size 1 \
    --lora_rank 16
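In this stage, train_prompts.py runs PPO: the actor generates responses to the prompt dataset, the frozen reward model from step 3.2 scores them, and the actor is updated with PPO's clipped surrogate objective. A minimal sketch of that objective for a single action (illustrative plain Python, not ColossalAI's implementation):

```python
import math

def ppo_clip_objective(log_prob_new, log_prob_old, advantage, clip_eps=0.2):
    """PPO clipped surrogate for one token/action (maximized during training).
    ratio = pi_new(a|s) / pi_old(a|s); clipping keeps each update close to
    the policy that generated the experience batch."""
    ratio = math.exp(log_prob_new - log_prob_old)
    clipped = max(min(ratio, 1.0 + clip_eps), 1.0 - clip_eps)
    return min(ratio * advantage, clipped * advantage)

# With a positive advantage, a probability ratio of 1.5 is clipped to 1.2,
# so the objective stops rewarding further movement in that direction:
print(round(ppo_clip_objective(math.log(1.5), 0.0, 1.0), 2))  # 1.2
```

The `--ptx_batch_size` argument mixes in pretraining-style loss on `--pretrain_dataset` alongside this objective, which helps keep the policy from drifting too far from the SFT model.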

3.3.2 Log

root@LAPTOP-3SUHS40U:/home/work/ColossalAI/applications/Chat/examples# ./train_prompts_bloom_kpt.sh
GPU Memory Usage:
0  580 MiB
Now CUDA_VISIBLE_DEVICES is set to:
CUDA_VISIBLE_DEVICES=0
Some weights of the model checkpoint at /mnt/f/kangpengtao/study/ColossalAI/Coati-7B were not used when initializing BloomModel: ['transformer.h.5.self_attention.query_key_value.lora_A', 'transformer.h.7.self_attention.query_key_value.lora_A', 'transformer.h.15.self_attention.query_key_value.lora_B', 'transformer.h.6.self_attention.query_key_value.lora_A', 'transformer.h.16.mlp.dense_h_to_4h.lora_A', 'transformer.h.2.mlp.dense_h_to_4h.lora_A', 'transformer.h.22.self_attention.query_key_value.lora_B', 'transformer.h.14.mlp.dense_4h_to_h.lora_A', 'transformer.h.12.self_attention.dense.lora_A', 'transformer.h.15.mlp.dense_h_to_4h.lora_A', 'transformer.h.3.self_attention.dense.lora_A', 'transformer.h.10.self_attention.query_key_value.lora_B', 'transformer.h.1.self_attention.dense.lora_B', 'transformer.h.2.self_attention.dense.lora_A', 'transformer.h.21.mlp.dense_h_to_4h.lora_A', 'transformer.h.12.self_attention.query_key_value.lora_A', 'transformer.h.9.self_attention.query_key_value.lora_B', 'transformer.h.13.mlp.dense_4h_to_h.lora_B', 'transformer.h.3.self_attention.query_key_value.lora_A', 'transformer.h.11.mlp.dense_4h_to_h.lora_B', 'transformer.h.7.self_attention.dense.lora_A', 'transformer.h.3.self_attention.dense.lora_B', 'transformer.h.21.self_attention.dense.lora_A', 'transformer.h.7.mlp.dense_h_to_4h.lora_A', 'transformer.h.12.mlp.dense_4h_to_h.lora_B', 'transformer.h.13.self_attention.query_key_value.lora_A', 'transformer.h.19.mlp.dense_4h_to_h.lora_A', 'transformer.h.0.self_attention.dense.lora_A', 'transformer.h.13.mlp.dense_h_to_4h.lora_A', 'transformer.h.16.mlp.dense_4h_to_h.lora_A', 'transformer.h.14.mlp.dense_h_to_4h.lora_B', 'transformer.h.17.self_attention.dense.lora_A', 'transformer.h.3.mlp.dense_h_to_4h.lora_A', 'transformer.h.15.mlp.dense_h_to_4h.lora_B', 'transformer.h.14.mlp.dense_h_to_4h.lora_A', 'transformer.h.4.mlp.dense_4h_to_h.lora_B', 'transformer.h.2.self_attention.query_key_value.lora_A', 'transformer.h.18.mlp.dense_4h_to_h.lora_A', 
'transformer.h.18.self_attention.dense.lora_A', 'transformer.h.21.mlp.dense_4h_to_h.lora_B', 'transformer.h.1.mlp.dense_h_to_4h.lora_A', 'transformer.h.18.self_attention.dense.lora_B', 'transformer.h.17.self_attention.query_key_value.lora_A', 'transformer.h.11.mlp.dense_4h_to_h.lora_A', 'transformer.h.0.self_attention.dense.lora_B', 'transformer.h.11.self_attention.query_key_value.lora_A', 'transformer.h.5.self_attention.query_key_value.lora_B', 'transformer.h.20.mlp.dense_4h_to_h.lora_A', 'transformer.h.16.mlp.dense_4h_to_h.lora_B', 'transformer.h.9.mlp.dense_h_to_4h.lora_B', 'transformer.h.17.mlp.dense_h_to_4h.lora_A', 'transformer.h.21.self_attention.dense.lora_B', 'transformer.h.18.self_attention.query_key_value.lora_B', 'transformer.h.2.self_attention.dense.lora_B', 'transformer.h.5.mlp.dense_h_to_4h.lora_B', 'transformer.h.10.mlp.dense_4h_to_h.lora_B', 'transformer.h.13.self_attention.dense.lora_A', 'transformer.h.3.mlp.dense_4h_to_h.lora_B', 'transformer.h.19.self_attention.query_key_value.lora_A', 'transformer.h.19.self_attention.dense.lora_A', 'transformer.h.13.mlp.dense_h_to_4h.lora_B', 'transformer.h.15.mlp.dense_4h_to_h.lora_B', 'transformer.h.9.self_attention.dense.lora_A', 'lm_head.lora_A', 'transformer.h.5.mlp.dense_4h_to_h.lora_A', 'transformer.h.16.self_attention.query_key_value.lora_B', 'transformer.h.7.self_attention.dense.lora_B', 'transformer.h.18.mlp.dense_h_to_4h.lora_A', 'transformer.h.11.self_attention.query_key_value.lora_B', 'transformer.h.1.mlp.dense_4h_to_h.lora_A', 'transformer.h.10.mlp.dense_4h_to_h.lora_A', 'transformer.h.21.mlp.dense_h_to_4h.lora_B', 'transformer.h.17.self_attention.query_key_value.lora_B', 'transformer.h.19.mlp.dense_h_to_4h.lora_A', 'transformer.h.6.self_attention.query_key_value.lora_B', 'transformer.h.12.self_attention.query_key_value.lora_B', 'transformer.h.9.self_attention.dense.lora_B', 'transformer.h.22.mlp.dense_4h_to_h.lora_A', 'transformer.h.19.mlp.dense_4h_to_h.lora_B', 
'transformer.h.9.mlp.dense_4h_to_h.lora_A', 'transformer.h.0.mlp.dense_4h_to_h.lora_B', 'transformer.h.6.mlp.dense_h_to_4h.lora_A', 'transformer.h.10.self_attention.query_key_value.lora_A', 'transformer.h.16.self_attention.dense.lora_A', 'transformer.h.4.mlp.dense_h_to_4h.lora_A', 'transformer.h.0.mlp.dense_4h_to_h.lora_A', 'transformer.h.8.mlp.dense_4h_to_h.lora_A', 'transformer.h.1.mlp.dense_h_to_4h.lora_B', 'transformer.h.6.self_attention.dense.lora_B', 'transformer.h.15.self_attention.query_key_value.lora_A', 'transformer.h.0.self_attention.query_key_value.lora_A', 'transformer.h.0.mlp.dense_h_to_4h.lora_B', 'transformer.h.8.mlp.dense_4h_to_h.lora_B', 'transformer.h.10.mlp.dense_h_to_4h.lora_A', 'transformer.h.6.self_attention.dense.lora_A', 'transformer.h.11.self_attention.dense.lora_B', 'transformer.h.17.mlp.dense_4h_to_h.lora_A', 'transformer.h.23.mlp.dense_h_to_4h.lora_A', 'transformer.h.12.mlp.dense_h_to_4h.lora_B', 'transformer.h.8.mlp.dense_h_to_4h.lora_B', 'transformer.h.16.self_attention.dense.lora_B', 'transformer.h.15.self_attention.dense.lora_B', 'transformer.h.10.mlp.dense_h_to_4h.lora_B', 'transformer.h.20.mlp.dense_h_to_4h.lora_A', 'transformer.h.1.self_attention.dense.lora_A', 'transformer.h.2.mlp.dense_4h_to_h.lora_A', 'transformer.h.12.mlp.dense_4h_to_h.lora_A', 'transformer.h.3.mlp.dense_4h_to_h.lora_A', 'transformer.h.2.self_attention.query_key_value.lora_B', 'transformer.h.12.self_attention.dense.lora_B', 'transformer.h.18.mlp.dense_h_to_4h.lora_B', 'transformer.h.23.mlp.dense_4h_to_h.lora_B', 'transformer.h.17.self_attention.dense.lora_B', 'transformer.h.16.self_attention.query_key_value.lora_A', 'transformer.h.23.self_attention.query_key_value.lora_B', 'transformer.h.7.mlp.dense_4h_to_h.lora_A', 'transformer.h.23.self_attention.dense.lora_B', 'transformer.h.14.mlp.dense_4h_to_h.lora_B', 'transformer.h.7.mlp.dense_h_to_4h.lora_B', 'transformer.h.8.self_attention.dense.lora_B', 'transformer.h.9.mlp.dense_h_to_4h.lora_A', 
'transformer.h.14.self_attention.dense.lora_B', 'transformer.h.1.mlp.dense_4h_to_h.lora_B', 'transformer.h.1.self_attention.query_key_value.lora_B', 'transformer.h.8.self_attention.query_key_value.lora_B', 'transformer.h.8.mlp.dense_h_to_4h.lora_A', 'transformer.h.2.mlp.dense_4h_to_h.lora_B', 'transformer.h.21.self_attention.query_key_value.lora_B', 'transformer.h.20.mlp.dense_4h_to_h.lora_B', 'transformer.h.0.self_attention.query_key_value.lora_B', 'transformer.h.5.self_attention.dense.lora_A', 'transformer.h.2.mlp.dense_h_to_4h.lora_B', 'transformer.h.10.self_attention.dense.lora_A', 'transformer.h.4.self_attention.query_key_value.lora_A', 'transformer.h.14.self_attention.query_key_value.lora_B', 'transformer.h.8.self_attention.query_key_value.lora_A', 'transformer.h.18.self_attention.query_key_value.lora_A', 'transformer.h.4.mlp.dense_4h_to_h.lora_A', 'transformer.h.4.self_attention.query_key_value.lora_B', 'transformer.h.4.self_attention.dense.lora_B', 'transformer.h.21.self_attention.query_key_value.lora_A', 'transformer.h.7.self_attention.query_key_value.lora_B', 'transformer.h.20.self_attention.dense.lora_A', 'transformer.h.18.mlp.dense_4h_to_h.lora_B', 'transformer.h.22.mlp.dense_h_to_4h.lora_B', 'transformer.h.22.mlp.dense_h_to_4h.lora_A', 'transformer.h.9.self_attention.query_key_value.lora_A', 'transformer.h.11.mlp.dense_h_to_4h.lora_B', 'transformer.h.5.self_attention.dense.lora_B', 'transformer.h.3.self_attention.query_key_value.lora_B', 'transformer.h.11.self_attention.dense.lora_A', 'transformer.h.23.self_attention.query_key_value.lora_A', 'transformer.h.23.mlp.dense_h_to_4h.lora_B', 'transformer.h.20.self_attention.dense.lora_B', 'transformer.h.3.mlp.dense_h_to_4h.lora_B', 'transformer.h.23.mlp.dense_4h_to_h.lora_A', 'transformer.h.13.mlp.dense_4h_to_h.lora_A', 'transformer.h.6.mlp.dense_4h_to_h.lora_A', 'transformer.h.6.mlp.dense_4h_to_h.lora_B', 'transformer.h.1.self_attention.query_key_value.lora_A', 'transformer.h.5.mlp.dense_h_to_4h.lora_A', 
'transformer.h.13.self_attention.query_key_value.lora_B', 'lm_head.lora_B', 'transformer.h.0.mlp.dense_h_to_4h.lora_A', 'transformer.h.7.mlp.dense_4h_to_h.lora_B', 'transformer.h.22.self_attention.dense.lora_B', 'transformer.h.19.mlp.dense_h_to_4h.lora_B', 'transformer.h.14.self_attention.dense.lora_A', 'transformer.h.16.mlp.dense_h_to_4h.lora_B', 'transformer.h.6.mlp.dense_h_to_4h.lora_B', 'transformer.h.22.self_attention.query_key_value.lora_A', 'lm_head.weight', 'transformer.h.22.self_attention.dense.lora_A', 'transformer.h.15.mlp.dense_4h_to_h.lora_A', 'transformer.h.15.self_attention.dense.lora_A', 'transformer.h.10.self_attention.dense.lora_B', 'transformer.h.12.mlp.dense_h_to_4h.lora_A', 'transformer.h.20.self_attention.query_key_value.lora_B', 'transformer.h.23.self_attention.dense.lora_A', 'transformer.h.22.mlp.dense_4h_to_h.lora_B', 'transformer.h.11.mlp.dense_h_to_4h.lora_A', 'transformer.h.13.self_attention.dense.lora_B', 'transformer.h.20.self_attention.query_key_value.lora_A', 'transformer.h.5.mlp.dense_4h_to_h.lora_B', 'transformer.h.4.mlp.dense_h_to_4h.lora_B', 'transformer.h.19.self_attention.query_key_value.lora_B', 'transformer.h.20.mlp.dense_h_to_4h.lora_B', 'transformer.h.4.self_attention.dense.lora_A', 'transformer.h.19.self_attention.dense.lora_B', 'transformer.h.9.mlp.dense_4h_to_h.lora_B', 'transformer.h.17.mlp.dense_h_to_4h.lora_B', 'transformer.h.14.self_attention.query_key_value.lora_A', 'transformer.h.21.mlp.dense_4h_to_h.lora_A', 'transformer.h.8.self_attention.dense.lora_A', 'transformer.h.17.mlp.dense_4h_to_h.lora_B']
- This IS expected if you are initializing BloomModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BloomModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
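These warnings are expected here: the Coati checkpoint was saved with its LoRA adapter matrices (`lora_A`/`lora_B`) still separate, and a plain `BloomModel` has no parameters with those names, so they are skipped. Conceptually, folding such an adapter back into a base weight is just a low-rank update `W + scaling * (B @ A)`; the sketch below illustrates only the arithmetic (the function name and scaling convention are assumptions, not Coati's actual merge code):

```python
import torch

def merge_lora_weight(base_weight, lora_A, lora_B, scaling=1.0):
    """Fold a LoRA adapter back into the base weight.

    base_weight: (out_features, in_features)
    lora_A:      (rank, in_features)  -- down-projection
    lora_B:      (out_features, rank) -- up-projection
    """
    return base_weight + scaling * (lora_B @ lora_A)

# Toy rank-1 example on a 2x3 weight initialized to zero.
W = torch.zeros(2, 3)
A = torch.tensor([[1., 0., 1.]])   # (1, 3)
B = torch.tensor([[2.], [3.]])     # (2, 1)
merged = merge_lora_weight(W, A, B)
assert torch.equal(merged, torch.tensor([[2., 0., 2.], [3., 0., 3.]]))
```

Once merged this way, a checkpoint loads into the base architecture without any leftover `lora_*` keys.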
[05/21/23 00:32:36] INFO     colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/prompt_dataset.py:30 __init__
                    INFO     colossalai - colossalai - INFO: Loading data...
                    INFO     colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/prompt_dataset.py:32 __init__
                    INFO     colossalai - colossalai - INFO: Loaded 858 examples.
                    INFO     colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/prompt_dataset.py:35 __init__
                    INFO     colossalai - colossalai - INFO: Limiting dataset to 16384 examples.
                    INFO     colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/sft_dataset.py:121 __init__
                    INFO     colossalai - colossalai - INFO: Loading data...
[05/21/23 00:32:37] INFO     colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/sft_dataset.py:123 __init__
                    INFO     colossalai - colossalai - INFO: Loaded 103695 examples.
                    INFO     colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/sft_dataset.py:126 __init__
                    INFO     colossalai - colossalai - INFO: Limiting dataset to 16384 examples.
                    INFO     colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/sft_dataset.py:129 __init__
                    INFO     colossalai - colossalai - INFO: Formatting inputs...
                    INFO     colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/sft_dataset.py:137 __init__
                    INFO     colossalai - colossalai - INFO: Tokenizing inputs... This may take some time...
Train epoch [1/1]: 100%|████████████████████████████████████████████████████████████████████████████| 10/10 [00:04<00:00,  2.45it/s, reward=nan]
Episode [1/1]: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:20<00:00,  2.03s/it]
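Note the `reward=nan` in the progress bar: with this tiny setup the reward model evidently produced NaN scores, so the PPO episode above optimized against a meaningless signal. A cheap guard, purely illustrative and not part of the Coati trainer, is to fail fast when rewards go non-finite instead of letting them silently poison the loss:

```python
import math

def validate_rewards(rewards):
    # Raise instead of letting NaN/inf rewards average into the PPO loss.
    bad = [r for r in rewards if not math.isfinite(r)]
    if bad:
        raise ValueError(f"{len(bad)} non-finite reward(s), e.g. {bad[0]}; "
                         "check the RM checkpoint and reward normalization")
    return rewards

validate_rewards([0.1, -0.3, 0.7])        # passes through unchanged
try:
    validate_rewards([0.2, float('nan')])
except ValueError as e:
    print("caught:", e)
```

Dropping a check like this into the experience-collection step surfaces a broken reward model on the first batch rather than at the end of training.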

3.4 Inference

3.4.1 Code

import torch
from coati.models.bloom import BLOOMActor
from coati.models.gpt import GPTActor
from coati.models.opt import OPTActor
from coati.models.roberta import RoBERTaActor
from transformers import AutoTokenizer, RobertaTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

import gradio as gr

MAX_TURNS = 20
MAX_BOXES = MAX_TURNS * 2

# Replace these with your own checkpoint paths
model_path_dict = {
    'SFT': '/mnt/f/kangpengtao/study/ColossalAI/Coati-7B/pytorch_model.bin',
    'RM': '/mnt/f/kangpengtao/study/ColossalAI/rm-static.pt',
    'RL': '/mnt/f/kangpengtao/study/ColossalAI/prompts-static.pt',
}
pretrain_dict = {
    'bloom': '/mnt/f/kangpengtao/study/ColossalAI/bigscience/bloom-560m'
}


def predict(model, dict, input, max_length, history):
    pretrain = pretrain_dict[model]
    updates = []
    # configure model
    if model == 'gpt2':
        actor = GPTActor(pretrained=pretrain).to(torch.cuda.current_device())
    elif model == 'bloom':
        actor = BLOOMActor(pretrained=pretrain).to(torch.cuda.current_device())
    elif model == 'opt':
        actor = OPTActor(pretrained=pretrain).to(torch.cuda.current_device())
    elif model == 'roberta':
        actor = RoBERTaActor(pretrained=pretrain).to(torch.cuda.current_device())
    else:
        raise ValueError(f'Unsupported model "{model}"')
    # load the selected fine-tuned weights (SFT / RM / RL) on top of the pretrained actor
    state_dict = torch.load(model_path_dict[dict])
    actor.model.load_state_dict(state_dict, strict=False)
    # configure tokenizer
    if model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        tokenizer.pad_token = tokenizer.eos_token
    elif model == 'bloom':
        tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m')
        tokenizer.pad_token = tokenizer.eos_token
    elif model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
    elif model == 'roberta':
        tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
    else:
        raise ValueError(f'Unsupported model "{model}"')
    actor.eval()
    question = f'Question: {input} ? Answer:'
    input_ids = tokenizer.encode(question, return_tensors='pt').to(torch.cuda.current_device())
    outputs = actor.generate(input_ids,
                             max_length=max_length,
                             do_sample=True,
                             top_k=50,
                             top_p=0.95,
                             num_return_sequences=1)
    output = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
    # re-render the visible history, then append the new question/answer turn
    for i in history:
        if not i.get('visible'):
            continue
        print(i)
        value = i.get('value')
        updates.append(gr.update(visible=True, value=value))
    updates.append(gr.update(visible=True, value="提问:" + input))
    updates.append(gr.update(visible=True,
                             value=f"{dict}:" + output[0].replace(question, '').replace(question.replace(' ', ''), '')))
    if len(updates) < MAX_BOXES:
        updates = updates + [gr.Textbox.update(visible=False)] * (MAX_BOXES - len(updates))
    history.extend(updates)
    return [history] + updates


if __name__ == '__main__':
    with gr.Blocks() as demo:
        state = gr.State([])
        text_boxes = []
        with gr.Row():
            with gr.Column(scale=1):
                model = gr.Radio(['gpt2', 'bloom', 'opt', 'roberta'], label="model", interactive=False, value='bloom')
                dict = gr.Radio(["SFT", "RM", "RL"], label="dict", interactive=True, value='RL')
                max_length = gr.Slider(0, 512, value=100, step=1.0, label="max_length", interactive=True)
                button = gr.Button("Generate")
            with gr.Column(scale=4):
                for i in range(MAX_BOXES):
                    if i % 2 == 0:
                        text_boxes += [gr.Markdown(visible=False, label="提问:")]
                    else:
                        text_boxes += [gr.Markdown(visible=False, label="回复:")]
        input = gr.Textbox(show_label=True, placeholder="input", lines=5, label='input').style(container=False)
        button.click(predict, [model, dict, input, max_length, state], [state] + text_boxes)
    demo.queue().launch(share=False, inbrowser=True, server_name='0.0.0.0')
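One fragile spot in the script above: it strips the echoed prompt from the decoded text with string `replace`, which silently fails whenever the tokenizer re-spaces the text (the demo output below shows exactly that, with "Question: ... Answer:" surviving in the reply). A more robust pattern, shown here independently of Coati with plain token-id lists, is to slice off the prompt by token count before decoding, since decoder-only `generate` returns prompt + continuation:

```python
def new_token_ids(output_ids, prompt_ids):
    # generate() returns the prompt followed by the continuation;
    # slicing by token count avoids brittle string replacement.
    return output_ids[len(prompt_ids):]

prompt = [101, 2054, 2003]                 # hypothetical token ids for the question
generated = prompt + [3437, 1024, 2307]    # model output echoes the prompt first
assert new_token_ids(generated, prompt) == [3437, 1024, 2307]
```

In the real script this would mean decoding `outputs[0][input_ids.shape[1]:]` instead of replacing the question string in the decoded output.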

3.4.2 Demo

提问:了解一个无家可归或社会边缘人的故事,并考虑自己的角色和关怀方式。
RL:Question: 了解一个无家可归或社会边缘人的故事,并考虑自己的角色和关怀方式。? Answer: 无论你是单身、独居或已婚,不管你和谁在一起,或者无论在哪里工作,都必须向世界发出一个邀请:去帮助他人!? Excluded Subjects: 婚姻和性关系、情感、女性、单身和女性伴侣、同性恋和异性伴侣。 作者简介 作者:(英国)戴维·巴兰森(David Baker) 译

提问:在机器学习中,什么是监督学习和无监督学习?你如何选择适当的算法?
RL:Question: 在机器学习中,什么是监督学习和无监督学习?你如何选择适当的算法?? Answer: 机器学习的深度学习方法无监督学习(Deep Learning of No-Layered):它比传统深度学习(深度学习)更具灵活性,例如,您可以通过训练网络学习某种类型(非学习任务)的模式(比如,网络训练了某条路径),然后在网络中生成新的特征,以便获得更高的模型效果(对于非学习任务,模型可能非常糟糕

提问:给定一组JSON对象,写一个程序来计算其中所有数字值的平均值。
RL:Question: 给定一组JSON对象,写一个程序来计算其中所有数字值的平均值。? Answer: 返回每一个数字的平均值。? Answer: 如果你有一个数组中的每一列都是一个数字,那么可以使用这种方法计算每一个数字的平均值。? Answer: 默认情况下,JSON会返回一个返回一个数组。? Question: 你需要一个JSON对象来存储一组数据(一个字符串和一个集合)。例如,如果你给一个字符串定义了8个属性,则应该使用这个对象来存储这些属性(这些属性被抽象到一个类下面)。? Answer: 如果该字符串的每个值都一个数字,那么JSON会返回一个数组,因为数字比较准确。但是如果你给一个字符串定义了12个属性,那么JSON会返回一个返回一个数组。这个问题就在这里。? Question: 你需要一个JSON对象来存储一组数据(一个字符串和一个集合)。例如,如果你给一个字符串定义了8个属性,那么应该使用这个对象来存储这些属性(这些属性被抽象到一个类下面)。? Answer: 如果该字符串的每个值都一个数字,那么JSON会返回一个数组,因为数字比较准确。但是如果你给一个字符串定义了12个属性,那么JSON会返回一个返回一个数组。这个问题就在这里。? Question: 你需要一个JSON对象来存储一组数据(一个字符串和一个集合)。例如,如果你给一个字符串定义了8个属性,那么应该使用这个对象来存储这些属性(这些属性被抽象到一个类下面)。? Answer: 如果该字符串的每个值都一个数字,那么JSON会返回一个数组,因为数字比较准确。但是如果你给一个字符串定义了12个属性,那么JSON会返回一个返回一个数组。这个问题就在这里。? Question: 你需要一个JSON对象来存储一组数据(一个字符串和一个集合)。例如,如果你给一个字符串定义了8个属性,那么应该使用这个对象来存储这些属性(这些属性被抽象到一个类下面)。? Answer: 如果该字符串的每个值都一个数字,那么JSON会返回一个数组,因为数字比较准确。但是如果你给一个字符串定义了12个属性,那么JSON会返回一个返回一个数组。这个问题就在这里。? Question: 你需要一个JSON对象来存储一组数据(一个字符串和一个集合)。例如,如果你给一个字符串定义了8个属性,那么应该使用这个对象来存储这些属性(这些属性被抽象到一个类下面)。? Answer

4 References

https://blog.csdn.net/chen_hao_181/article/details/130172096

