pytorch中使用DataParallel保存模型参数时遇到的问题记录

之前使用Transformers库中的Bert模型在自己的文本分类任务上使用Transformers库里的Trainer方式进行了Fine-tune。今天尝试加载保存好的checkpoint到程序中来直接进行evaluate,结果参数匹配不上,故在此记录问题原因和解决过程。


更新

新的问题

我再某台只有0,1两张显卡的机器上对模型进行了若干次训练,模型直接保存成了一个文件。之后,在重新加载这个模型到另一台拥有0,1,2,3四张显卡的机器上继续进行训练时,发现模型还是只占用0,1两张卡。直接使用

model = nn.DataParallel(model, device_ids=[0,1,2,3])

会报错

RuntimeError: Input, output and indices must be on the current device

查阅相关问题后,找到一种更方便的方式,即使用如下代码:

if isinstance(model,torch.nn.DataParallel):model = model.module
# 之后加载就不会报错了
model = nn.DataParallel(model, device_ids=[0,1,2,3])

之后模型就可以正常使用加载了,也可以放到四张卡上。


直接使用AutoModelForSequenceClassification从checkpoint目录加载

Fine tune阶段的模型保存和使用

在Fine tune阶段,我手动将模型设置为了DataParallel模式,然后再使用Trainer进行Fine tune,而模型在保存时也是Trainer自动去保存的。

所设置的Training Args如下:

training_args = TrainingArguments(output_dir='./results_0415_32',         # output directorynum_train_epochs=3,             # total number of training epochsper_device_train_batch_size=128, # batch size per divice during trainingper_device_eval_batch_size=128,  # batch size for evaluationwarmup_steps=500,               # number of warmup steps for learning rate schedulerweight_decay=0.01,              # strength of weight decay 用于指定权值衰减率,相当于L2正则化中的λ参数logging_dir='./logs',           # directory for storing logslogging_steps=10000,
)

然后从一个初始的BertForSequenceClassification模型开始Fine tune,并手动将模型设置了DataParallel模式。

model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=len(afid2nor))device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
model = nn.DataParallel(model, device_ids=[0,1,2,3])
model.to(device)

此时打印模型即:

DataParallel((module): BertForSequenceClassification((bert): BertModel((embeddings): BertEmbeddings((word_embeddings): Embedding(30522, 768, padding_idx=0)(position_embeddings): Embedding(512, 768)(token_type_embeddings): Embedding(2, 768)(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)(dropout): Dropout(p=0.1, inplace=False))(encoder): BertEncoder((layer): ModuleList((0): BertLayer((attention): BertAttention((self): BertSelfAttention((query): Linear(in_features=768, out_features=768, bias=True)(key): Linear(in_features=768, out_features=768, bias=True)(value): Linear(in_features=768, out_features=768, bias=True)(dropout): Dropout(p=0.1, inplace=False))(output): BertSelfOutput((dense): Linear(in_features=768, out_features=768, bias=True)(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)(dropout): Dropout(p=0.1, inplace=False)))(intermediate): BertIntermediate((dense): Linear(in_features=768, out_features=3072, bias=True))(output): BertOutput((dense): Linear(in_features=3072, out_features=768, bias=True)(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)(dropout): Dropout(p=0.1, inplace=False)))# 1-10层内容省略,基本一致(11): BertLayer((attention): BertAttention((self): BertSelfAttention((query): Linear(in_features=768, out_features=768, bias=True)(key): Linear(in_features=768, out_features=768, bias=True)(value): Linear(in_features=768, out_features=768, bias=True)(dropout): Dropout(p=0.1, inplace=False))(output): BertSelfOutput((dense): Linear(in_features=768, out_features=768, bias=True)(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)(dropout): Dropout(p=0.1, inplace=False)))(intermediate): BertIntermediate((dense): Linear(in_features=768, out_features=3072, bias=True))(output): BertOutput((dense): Linear(in_features=3072, out_features=768, bias=True)(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)(dropout): Dropout(p=0.1, inplace=False)))))(pooler): BertPooler((dense): Linear(in_features=768, out_features=768, bias=True)(activation): Tanh()))(dropout): Dropout(p=0.1, inplace=False)(classifier): Linear(in_features=768, out_features=25762, bias=True))
)

可以看到,这样的模型外面"套着"两层结构为

  (module): BertForSequenceClassification((bert): BertModel(

这个对后续从checkpoint加载产生了影响。

模型的加载

加载的代码为:

model = AutoModelForSequenceClassification.from_pretrained('./results_0415_32/checkpoint-98000/')

提示了如下的Warning:

Some weights of the model checkpoint at ./results_0415_32/checkpoint-98000/ were not used when initializing BertForSequenceClassification: ['module.bert.embeddings.position_ids', 'module.bert.embeddings.word_embeddings.weight', 'module.bert.embeddings.position_embeddings.weight', 'module.bert.embeddings.token_type_embeddings.weight', 'module.bert.embeddings.LayerNorm.weight', 'module.bert.embeddings.LayerNorm.bias', 'module.bert.encoder.layer.0.attention.self.query.weight', 'module.bert.encoder.layer.0.attention.self.query.bias', 'module.bert.encoder.layer.0.attention.self.key.weight', 'module.bert.encoder.layer.0.attention.self.key.bias', 'module.bert.encoder.layer.0.attention.self.value.weight', 'module.bert.encoder.layer.0.attention.self.value.bias', 'module.bert.encoder.layer.0.attention.output.dense.weight', 'module.bert.encoder.layer.0.attention.output.dense.bias', 'module.bert.encoder.layer.0.attention.output.LayerNorm.weight', 'module.bert.encoder.layer.0.attention.output.LayerNorm.bias', 'module.bert.encoder.layer.0.intermediate.dense.weight', 'module.bert.encoder.layer.0.intermediate.dense.bias', 'module.bert.encoder.layer.0.output.dense.weight', 'module.bert.encoder.layer.0.output.dense.bias', 'module.bert.encoder.layer.0.output.LayerNorm.weight', 'module.bert.encoder.layer.0.output.LayerNorm.bias', 'module.bert.encoder.layer.1.attention.self.query.weight', 'module.bert.encoder.layer.1.attention.self.query.bias', 'module.bert.encoder.layer.1.attention.self.key.weight', 'module.bert.encoder.layer.1.attention.self.key.bias', 'module.bert.encoder.layer.1.attention.self.value.weight', 'module.bert.encoder.layer.1.attention.self.value.bias', 'module.bert.encoder.layer.1.attention.output.dense.weight', 'module.bert.encoder.layer.1.attention.output.dense.bias', 'module.bert.encoder.layer.1.attention.output.LayerNorm.weight', 'module.bert.encoder.layer.1.attention.output.LayerNorm.bias', 'module.bert.encoder.layer.1.intermediate.dense.weight', 'module.bert.encoder.layer.1.intermediate.dense.bias', 'module.bert.encoder.layer.1.output.dense.weight', 'module.bert.encoder.layer.1.output.dense.bias', 'module.bert.encoder.layer.1.output.LayerNorm.weight', 'module.bert.encoder.layer.1.output.LayerNorm.bias', 'module.bert.encoder.layer.2.attention.self.query.weight', 'module.bert.encoder.layer.2.attention.self.query.bias', 'module.bert.encoder.layer.2.attention.self.key.weight', 'module.bert.encoder.layer.2.attention.self.key.bias', 'module.bert.encoder.layer.2.attention.self.value.weight', 'module.bert.encoder.layer.2.attention.self.value.bias', 'module.bert.encoder.layer.2.attention.output.dense.weight', 'module.bert.encoder.layer.2.attention.output.dense.bias', 'module.bert.encoder.layer.2.attention.output.LayerNorm.weight', 'module.bert.encoder.layer.2.attention.output.LayerNorm.bias', 'module.bert.encoder.layer.2.intermediate.dense.weight', 'module.bert.encoder.layer.2.intermediate.dense.bias', 'module.bert.encoder.layer.2.output.dense.weight', 'module.bert.encoder.layer.2.output.dense.bias', 'module.bert.encoder.layer.2.output.LayerNorm.weight', 'module.bert.encoder.layer.2.output.LayerNorm.bias', 'module.bert.encoder.layer.3.attention.self.query.weight', 'module.bert.encoder.layer.3.attention.self.query.bias', 'module.bert.encoder.layer.3.attention.self.key.weight', 'module.bert.encoder.layer.3.attention.self.key.bias', 'module.bert.encoder.layer.3.attention.self.value.weight', 'module.bert.encoder.layer.3.attention.self.value.bias', 'module.bert.encoder.layer.3.attention.output.dense.weight', 'module.bert.encoder.layer.3.attention.output.dense.bias', 'module.bert.encoder.layer.3.attention.output.LayerNorm.weight', 'module.bert.encoder.layer.3.attention.output.LayerNorm.bias', 'module.bert.encoder.layer.3.intermediate.dense.weight', 'module.bert.encoder.layer.3.intermediate.dense.bias', 'module.bert.encoder.layer.3.output.dense.weight', 'module.bert.encoder.layer.3.output.dense.bias', 'module.bert.encoder.layer.3.output.LayerNorm.weight', 'module.bert.encoder.layer.3.output.LayerNorm.bias', 'module.bert.encoder.layer.4.attention.self.query.weight', 'module.bert.encoder.layer.4.attention.self.query.bias', 'module.bert.encoder.layer.4.attention.self.key.weight', 'module.bert.encoder.layer.4.attention.self.key.bias', 'module.bert.encoder.layer.4.attention.self.value.weight', 'module.bert.encoder.layer.4.attention.self.value.bias', 'module.bert.encoder.layer.4.attention.output.dense.weight', 'module.bert.encoder.layer.4.attention.output.dense.bias', 'module.bert.encoder.layer.4.attention.output.LayerNorm.weight', 'module.bert.encoder.layer.4.attention.output.LayerNorm.bias', 'module.bert.encoder.layer.4.intermediate.dense.weight', 'module.bert.encoder.layer.4.intermediate.dense.bias', 'module.bert.encoder.layer.4.output.dense.weight', 'module.bert.encoder.layer.4.output.dense.bias', 'module.bert.encoder.layer.4.output.LayerNorm.weight', 'module.bert.encoder.layer.4.output.LayerNorm.bias', 'module.bert.encoder.layer.5.attention.self.query.weight', 'module.bert.encoder.layer.5.attention.self.query.bias', 'module.bert.encoder.layer.5.attention.self.key.weight', 'module.bert.encoder.layer.5.attention.self.key.bias', 'module.bert.encoder.layer.5.attention.self.value.weight', 'module.bert.encoder.layer.5.attention.self.value.bias', 'module.bert.encoder.layer.5.attention.output.dense.weight', 'module.bert.encoder.layer.5.attention.output.dense.bias', 'module.bert.encoder.layer.5.attention.output.LayerNorm.weight', 'module.bert.encoder.layer.5.attention.output.LayerNorm.bias', 'module.bert.encoder.layer.5.intermediate.dense.weight', 'module.bert.encoder.layer.5.intermediate.dense.bias', 'module.bert.encoder.layer.5.output.dense.weight', 'module.bert.encoder.layer.5.output.dense.bias', 'module.bert.encoder.layer.5.output.LayerNorm.weight', 'module.bert.encoder.layer.5.output.LayerNorm.bias', 'module.bert.encoder.layer.6.attention.self.query.weight', 'module.bert.encoder.layer.6.attention.self.query.bias', 'module.bert.encoder.layer.6.attention.self.key.weight', 'module.bert.encoder.layer.6.attention.self.key.bias', 'module.bert.encoder.layer.6.attention.self.value.weight', 'module.bert.encoder.layer.6.attention.self.value.bias', 'module.bert.encoder.layer.6.attention.output.dense.weight', 'module.bert.encoder.layer.6.attention.output.dense.bias', 'module.bert.encoder.layer.6.attention.output.LayerNorm.weight', 'module.bert.encoder.layer.6.attention.output.LayerNorm.bias', 'module.bert.encoder.layer.6.intermediate.dense.weight', 'module.bert.encoder.layer.6.intermediate.dense.bias', 'module.bert.encoder.layer.6.output.dense.weight', 'module.bert.encoder.layer.6.output.dense.bias', 'module.bert.encoder.layer.6.output.LayerNorm.weight', 'module.bert.encoder.layer.6.output.LayerNorm.bias', 'module.bert.encoder.layer.7.attention.self.query.weight', 'module.bert.encoder.layer.7.attention.self.query.bias', 'module.bert.encoder.layer.7.attention.self.key.weight', 'module.bert.encoder.layer.7.attention.self.key.bias', 'module.bert.encoder.layer.7.attention.self.value.weight', 'module.bert.encoder.layer.7.attention.self.value.bias', 'module.bert.encoder.layer.7.attention.output.dense.weight', 'module.bert.encoder.layer.7.attention.output.dense.bias', 'module.bert.encoder.layer.7.attention.output.LayerNorm.weight', 'module.bert.encoder.layer.7.attention.output.LayerNorm.bias', 'module.bert.encoder.layer.7.intermediate.dense.weight', 'module.bert.encoder.layer.7.intermediate.dense.bias', 'module.bert.encoder.layer.7.output.dense.weight', 'module.bert.encoder.layer.7.output.dense.bias', 'module.bert.encoder.layer.7.output.LayerNorm.weight', 'module.bert.encoder.layer.7.output.LayerNorm.bias', 'module.bert.encoder.layer.8.attention.self.query.weight', 'module.bert.encoder.layer.8.attention.self.query.bias', 'module.bert.encoder.layer.8.attention.self.key.weight', 'module.bert.encoder.layer.8.attention.self.key.bias', 'module.bert.encoder.layer.8.attention.self.value.weight', 'module.bert.encoder.layer.8.attention.self.value.bias', 'module.bert.encoder.layer.8.attention.output.dense.weight', 'module.bert.encoder.layer.8.attention.output.dense.bias', 'module.bert.encoder.layer.8.attention.output.LayerNorm.weight', 'module.bert.encoder.layer.8.attention.output.LayerNorm.bias', 'module.bert.encoder.layer.8.intermediate.dense.weight', 'module.bert.encoder.layer.8.intermediate.dense.bias', 'module.bert.encoder.layer.8.output.dense.weight', 'module.bert.encoder.layer.8.output.dense.bias', 'module.bert.encoder.layer.8.output.LayerNorm.weight', 'module.bert.encoder.layer.8.output.LayerNorm.bias', 'module.bert.encoder.layer.9.attention.self.query.weight', 'module.bert.encoder.layer.9.attention.self.query.bias', 'module.bert.encoder.layer.9.attention.self.key.weight', 'module.bert.encoder.layer.9.attention.self.key.bias', 'module.bert.encoder.layer.9.attention.self.value.weight', 'module.bert.encoder.layer.9.attention.self.value.bias', 'module.bert.encoder.layer.9.attention.output.dense.weight', 'module.bert.encoder.layer.9.attention.output.dense.bias', 'module.bert.encoder.layer.9.attention.output.LayerNorm.weight', 'module.bert.encoder.layer.9.attention.output.LayerNorm.bias', 'module.bert.encoder.layer.9.intermediate.dense.weight', 'module.bert.encoder.layer.9.intermediate.dense.bias', 'module.bert.encoder.layer.9.output.dense.weight', 'module.bert.encoder.layer.9.output.dense.bias', 'module.bert.encoder.layer.9.output.LayerNorm.weight', 'module.bert.encoder.layer.9.output.LayerNorm.bias', 'module.bert.encoder.layer.10.attention.self.query.weight', 'module.bert.encoder.layer.10.attention.self.query.bias', 'module.bert.encoder.layer.10.attention.self.key.weight', 'module.bert.encoder.layer.10.attention.self.key.bias', 'module.bert.encoder.layer.10.attention.self.value.weight', 'module.bert.encoder.layer.10.attention.self.value.bias', 'module.bert.encoder.layer.10.attention.output.dense.weight', 'module.bert.encoder.layer.10.attention.output.dense.bias', 'module.bert.encoder.layer.10.attention.output.LayerNorm.weight', 'module.bert.encoder.layer.10.attention.output.LayerNorm.bias', 'module.bert.encoder.layer.10.intermediate.dense.weight', 'module.bert.encoder.layer.10.intermediate.dense.bias', 'module.bert.encoder.layer.10.output.dense.weight', 'module.bert.encoder.layer.10.output.dense.bias', 'module.bert.encoder.layer.10.output.LayerNorm.weight', 'module.bert.encoder.layer.10.output.LayerNorm.bias', 'module.bert.encoder.layer.11.attention.self.query.weight', 'module.bert.encoder.layer.11.attention.self.query.bias', 'module.bert.encoder.layer.11.attention.self.key.weight', 'module.bert.encoder.layer.11.attention.self.key.bias', 'module.bert.encoder.layer.11.attention.self.value.weight', 'module.bert.encoder.layer.11.attention.self.value.bias', 'module.bert.encoder.layer.11.attention.output.dense.weight', 'module.bert.encoder.layer.11.attention.output.dense.bias', 'module.bert.encoder.layer.11.attention.output.LayerNorm.weight', 'module.bert.encoder.layer.11.attention.output.LayerNorm.bias', 'module.bert.encoder.layer.11.intermediate.dense.weight', 'module.bert.encoder.layer.11.intermediate.dense.bias', 'module.bert.encoder.layer.11.output.dense.weight', 'module.bert.encoder.layer.11.output.dense.bias', 'module.bert.encoder.layer.11.output.LayerNorm.weight', 'module.bert.encoder.layer.11.output.LayerNorm.bias', 'module.bert.pooler.dense.weight', 'module.bert.pooler.dense.bias', 'module.classifier.weight', 'module.classifier.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ./results_0415_32/checkpoint-98000/ and are newly initialized: ['embeddings.word_embeddings.weight', 'embeddings.position_embeddings.weight', 'embeddings.token_type_embeddings.weight', 'embeddings.LayerNorm.weight', 'embeddings.LayerNorm.bias', 'encoder.layer.0.attention.self.query.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.layer.0.attention.output.LayerNorm.bias', 'encoder.layer.0.intermediate.dense.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.0.output.dense.weight', 'encoder.layer.0.output.dense.bias', 'encoder.layer.0.output.LayerNorm.weight', 'encoder.layer.0.output.LayerNorm.bias', 'encoder.layer.1.attention.self.query.weight', 'encoder.layer.1.attention.self.query.bias', 'encoder.layer.1.attention.self.key.weight', 'encoder.layer.1.attention.self.key.bias', 'encoder.layer.1.attention.self.value.weight', 'encoder.layer.1.attention.self.value.bias', 'encoder.layer.1.attention.output.dense.weight', 'encoder.layer.1.attention.output.dense.bias', 'encoder.layer.1.attention.output.LayerNorm.weight', 'encoder.layer.1.attention.output.LayerNorm.bias', 'encoder.layer.1.intermediate.dense.weight', 'encoder.layer.1.intermediate.dense.bias', 'encoder.layer.1.output.dense.weight', 'encoder.layer.1.output.dense.bias', 'encoder.layer.1.output.LayerNorm.weight', 'encoder.layer.1.output.LayerNorm.bias', 'encoder.layer.2.attention.self.query.weight', 'encoder.layer.2.attention.self.query.bias', 'encoder.layer.2.attention.self.key.weight', 'encoder.layer.2.attention.self.key.bias', 'encoder.layer.2.attention.self.value.weight', 'encoder.layer.2.attention.self.value.bias', 'encoder.layer.2.attention.output.dense.weight', 'encoder.layer.2.attention.output.dense.bias', 'encoder.layer.2.attention.output.LayerNorm.weight', 'encoder.layer.2.attention.output.LayerNorm.bias', 'encoder.layer.2.intermediate.dense.weight', 'encoder.layer.2.intermediate.dense.bias', 'encoder.layer.2.output.dense.weight', 'encoder.layer.2.output.dense.bias', 'encoder.layer.2.output.LayerNorm.weight', 'encoder.layer.2.output.LayerNorm.bias', 'encoder.layer.3.attention.self.query.weight', 'encoder.layer.3.attention.self.query.bias', 'encoder.layer.3.attention.self.key.weight', 'encoder.layer.3.attention.self.key.bias', 'encoder.layer.3.attention.self.value.weight', 'encoder.layer.3.attention.self.value.bias', 'encoder.layer.3.attention.output.dense.weight', 'encoder.layer.3.attention.output.dense.bias', 'encoder.layer.3.attention.output.LayerNorm.weight', 'encoder.layer.3.attention.output.LayerNorm.bias', 'encoder.layer.3.intermediate.dense.weight', 'encoder.layer.3.intermediate.dense.bias', 'encoder.layer.3.output.dense.weight', 'encoder.layer.3.output.dense.bias', 'encoder.layer.3.output.LayerNorm.weight', 'encoder.layer.3.output.LayerNorm.bias', 'encoder.layer.4.attention.self.query.weight', 'encoder.layer.4.attention.self.query.bias', 'encoder.layer.4.attention.self.key.weight', 'encoder.layer.4.attention.self.key.bias', 'encoder.layer.4.attention.self.value.weight', 'encoder.layer.4.attention.self.value.bias', 'encoder.layer.4.attention.output.dense.weight', 'encoder.layer.4.attention.output.dense.bias', 'encoder.layer.4.attention.output.LayerNorm.weight', 'encoder.layer.4.attention.output.LayerNorm.bias', 'encoder.layer.4.intermediate.dense.weight', 'encoder.layer.4.intermediate.dense.bias', 'encoder.layer.4.output.dense.weight', 'encoder.layer.4.output.dense.bias', 'encoder.layer.4.output.LayerNorm.weight', 'encoder.layer.4.output.LayerNorm.bias', 'encoder.layer.5.attention.self.query.weight', 'encoder.layer.5.attention.self.query.bias', 'encoder.layer.5.attention.self.key.weight', 'encoder.layer.5.attention.self.key.bias', 'encoder.layer.5.attention.self.value.weight', 'encoder.layer.5.attention.self.value.bias', 'encoder.layer.5.attention.output.dense.weight', 'encoder.layer.5.attention.output.dense.bias', 'encoder.layer.5.attention.output.LayerNorm.weight', 'encoder.layer.5.attention.output.LayerNorm.bias', 'encoder.layer.5.intermediate.dense.weight', 'encoder.layer.5.intermediate.dense.bias', 'encoder.layer.5.output.dense.weight', 'encoder.layer.5.output.dense.bias', 'encoder.layer.5.output.LayerNorm.weight', 'encoder.layer.5.output.LayerNorm.bias', 'encoder.layer.6.attention.self.query.weight', 'encoder.layer.6.attention.self.query.bias', 'encoder.layer.6.attention.self.key.weight', 'encoder.layer.6.attention.self.key.bias', 'encoder.layer.6.attention.self.value.weight', 'encoder.layer.6.attention.self.value.bias', 'encoder.layer.6.attention.output.dense.weight', 'encoder.layer.6.attention.output.dense.bias', 'encoder.layer.6.attention.output.LayerNorm.weight', 'encoder.layer.6.attention.output.LayerNorm.bias', 'encoder.layer.6.intermediate.dense.weight', 'encoder.layer.6.intermediate.dense.bias', 'encoder.layer.6.output.dense.weight', 'encoder.layer.6.output.dense.bias', 'encoder.layer.6.output.LayerNorm.weight', 'encoder.layer.6.output.LayerNorm.bias', 'encoder.layer.7.attention.self.query.weight', 'encoder.layer.7.attention.self.query.bias', 'encoder.layer.7.attention.self.key.weight', 'encoder.layer.7.attention.self.key.bias', 'encoder.layer.7.attention.self.value.weight', 'encoder.layer.7.attention.self.value.bias', 'encoder.layer.7.attention.output.dense.weight', 'encoder.layer.7.attention.output.dense.bias', 'encoder.layer.7.attention.output.LayerNorm.weight', 'encoder.layer.7.attention.output.LayerNorm.bias', 'encoder.layer.7.intermediate.dense.weight', 'encoder.layer.7.intermediate.dense.bias', 'encoder.layer.7.output.dense.weight', 'encoder.layer.7.output.dense.bias', 'encoder.layer.7.output.LayerNorm.weight', 'encoder.layer.7.output.LayerNorm.bias', 'encoder.layer.8.attention.self.query.weight', 'encoder.layer.8.attention.self.query.bias', 'encoder.layer.8.attention.self.key.weight', 'encoder.layer.8.attention.self.key.bias', 'encoder.layer.8.attention.self.value.weight', 'encoder.layer.8.attention.self.value.bias', 'encoder.layer.8.attention.output.dense.weight', 'encoder.layer.8.attention.output.dense.bias', 'encoder.layer.8.attention.output.LayerNorm.weight', 'encoder.layer.8.attention.output.LayerNorm.bias', 'encoder.layer.8.intermediate.dense.weight', 'encoder.layer.8.intermediate.dense.bias', 'encoder.layer.8.output.dense.weight', 'encoder.layer.8.output.dense.bias', 'encoder.layer.8.output.LayerNorm.weight', 'encoder.layer.8.output.LayerNorm.bias', 'encoder.layer.9.attention.self.query.weight', 'encoder.layer.9.attention.self.query.bias', 'encoder.layer.9.attention.self.key.weight', 'encoder.layer.9.attention.self.key.bias', 'encoder.layer.9.attention.self.value.weight', 'encoder.layer.9.attention.self.value.bias', 'encoder.layer.9.attention.output.dense.weight', 'encoder.layer.9.attention.output.dense.bias', 'encoder.layer.9.attention.output.LayerNorm.weight', 'encoder.layer.9.attention.output.LayerNorm.bias', 'encoder.layer.9.intermediate.dense.weight', 'encoder.layer.9.intermediate.dense.bias', 'encoder.layer.9.output.dense.weight', 'encoder.layer.9.output.dense.bias', 'encoder.layer.9.output.LayerNorm.weight', 'encoder.layer.9.output.LayerNorm.bias', 'encoder.layer.10.attention.self.query.weight', 'encoder.layer.10.attention.self.query.bias', 'encoder.layer.10.attention.self.key.weight', 'encoder.layer.10.attention.self.key.bias', 'encoder.layer.10.attention.self.value.weight', 'encoder.layer.10.attention.self.value.bias', 'encoder.layer.10.attention.output.dense.weight', 'encoder.layer.10.attention.output.dense.bias', 'encoder.layer.10.attention.output.LayerNorm.weight', 'encoder.layer.10.attention.output.LayerNorm.bias', 'encoder.layer.10.intermediate.dense.weight', 'encoder.layer.10.intermediate.dense.bias', 'encoder.layer.10.output.dense.weight', 'encoder.layer.10.output.dense.bias', 'encoder.layer.10.output.LayerNorm.weight', 'encoder.layer.10.output.LayerNorm.bias', 'encoder.layer.11.attention.self.query.weight', 'encoder.layer.11.attention.self.query.bias', 'encoder.layer.11.attention.self.key.weight', 'encoder.layer.11.attention.self.key.bias', 'encoder.layer.11.attention.self.value.weight', 'encoder.layer.11.attention.self.value.bias', 'encoder.layer.11.attention.output.dense.weight', 'encoder.layer.11.attention.output.dense.bias', 'encoder.layer.11.attention.output.LayerNorm.weight', 'encoder.layer.11.attention.output.LayerNorm.bias', 'encoder.layer.11.intermediate.dense.weight', 'encoder.layer.11.intermediate.dense.bias', 'encoder.layer.11.output.dense.weight', 'encoder.layer.11.output.dense.bias', 'encoder.layer.11.output.LayerNorm.weight', 'encoder.layer.11.output.LayerNorm.bias', 'pooler.dense.weight', 'pooler.dense.bias', 'classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

可以看到,提示有很多的参数无法匹配上,原因是在训练时使用了DataParallel方式,导致参数的前面都多了一个module.bert.的前缀。

解决方法

将模型参数重新加载与输出

目前一种实验可行的方法较为繁琐,主要有以下几个步骤

  1. 首先加载checkpoint目录下的pytorch_model.bin文件,这里面存储了在DataParallel方式下模型的各个参数。
  2. 生成一个与原始Fine-tune模型相同结构的模型,并将其也使用DataParallel方式放到设备上。
  3. 使用model.load_state_dict方法把参数加载进来
  4. 将模型放到一张显卡上(即退出了DataParallel模式)
  5. 重新保存模型参数到一个本地文件中。
  6. 新建一个与原始Fine-tune模型相同结构的模型,并将其放到一张GPU设备上。
  7. 读入在步骤5中保存的参数。

相关代码如下:

# Step 1
pytorch_model_state_dict = torch.load('./results_0415_32/checkpoint-98000/pytorch_model.bin', map_location=torch.device('cpu'))# Step 2
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=CLASS_NUM)
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
model = nn.DataParallel(model, device_ids=[0,1,2,3])
model.to(device)# Step 3
model.load_state_dict(pytorch_model_state_dict)# Step 4
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
model.to(device)# Step 5
torch.save(model.module.state_dict(), 'static_dict.pkl')# Step 6
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=CLASS_NUM)
model.to(device)# Step 7
model.load_state_dict(torch.load('static_dict.pkl'))

使用Trainer类时不要认为设置模型为DataParallel模式

这种方式加载和保存都非常繁琐,在与同学讨论后,了解到其实Trainer这个类会默认采用DataParallel方式去训练,因而并不需要人为的将模型设置DataParallel,这样在加载时就不会导致参数前嵌套的两层结构造成参数不匹配。

实验测试有效,完全匹配而不会抱任何错误和警告。

但由于默认的DataParallel模式总是会将模型放到这个机器的所有GPU上来并行加速。实际中,由于多用户使用,可能四卡机中的cuda 0和cuda 2被占用,那么默认启动Trainer时它就会请求cuda 0给予GPU memory,但由于cuda 0被占用,程序将意外终止为cuda out of memory。

多卡的选择

实际上,我们可以通过设置哪些GPU让程序可见来解决上述问题。具体做法为在命令行中程序运行时,设置可见的GPU即可。例如:

CUDA_VISIBLE_DEVICES=1,3 python my_script.py

就可以让程序只看见cuda 1和cuda 3并且程序内部所看到的cuda编号为cuda 0和cuda 1,即该程序会认为这个机器上只有两个GPU。

参考

  1. (原)PyTorch中使用指定的GPU, https://www.cnblogs.com/darkknightzh/p/6836568.html

transformers库中使用DataParallel保存模型参数时遇到的问题记录相关推荐

  1. 【Pytorch神经网络理论篇】 39 Transformers库中的BERTology系列模型

    同学你好!本文章于2021年末编写,获得广泛的好评! 故在2022年末对本系列进行填充与更新,欢迎大家订阅最新的专栏,获取基于Pytorch1.10版本的理论代码(2023版)实现, Pytorch深 ...

  2. 【待更新】GPU 保存模型参数,GPU 加载模型参数

    GPU 保存模型参数,GPU 加载模型参数 保存 # 模型 device = torch.device('cuda') net = KGCN(num_user, num_entity, num_rel ...

  3. Transformers 库中的 Tokenizer 使用

    文章目录 概述 基本使用方法 进阶 基本使用不能满足的情况 解决思路 问题一解决:(有两种思路) 问题二解决: Tokenizer 中的 Encoder vocab_base 部分 vocab_add ...

  4. numpy库中ndarray切片操作的参数意义

    ndarray切片操作的规则总结出来叫做"三帽号规则" 三帽号规则即:[开始索引:结尾索引:步长],并且切片区间是左闭右开的,即"开始索引:结尾索引"表示的区间 ...

  5. 【Python机器学习】Sklearn库中Kmeans类、超参数K值确定、特征归一化的讲解(图文解释)

    一.局部最优解 采用随机产生初始簇中心 的方法,可能会出现运行 结果不一致的情况.这是 因为不同的初始簇中心使 得算法可能收敛到不同的 局部极小值. 不能收敛到全局最小值,是最优化计算中常常遇到的问题 ...

  6. lightgbm保存模型参数

    20210205 params = {'task': 'train', # 执行的任务类型'boosting_type': 'gbrt', # 基学习器'objective': 'lambdarank ...

  7. 08_04基于手写数据集_mat保存模型参数

    import os import numpy as np import tensorflow as tf from scipy import io from tensorflow.examples.t ...

  8. pytorch保存模型参数

    1.代码 有保存路径 PATH='my_model.pth' torch.save(model.state_dict(),PATH) 新模型 new_model=model GPU上运行 new_mo ...

  9. tensorflow保存模型参数文件pb查看

    查看pb文件的节点参数: with tf.Session() as sess: with open(model, 'rb') as model_file: graph_def = tf.GraphDe ...

最新文章

  1. jmail反馈是否发送成功_如何在钉钉上自动发送定制消息或通知给同事?(10行代码搞定)...
  2. 【总结+计划】2015一月份总结+2015二月份计划——全栈实践
  3. CentOS 安装python3.6
  4. [原创]Paros工具培训介绍
  5. 题目1025:最大报销额
  6. css教程之列表属性
  7. 面试 4 个月,最终入职大厂经验分享!
  8. Spring和Struts2整合
  9. python字符串三种常用的方法或函数_python中字符串常用的函数
  10. 【华为云•云享专家•原创分享计划上线】原创文章征集,寻找与众不同的你
  11. HTTP 1.1 协议规范
  12. 当IDENTITY_INSERT设置为OFF时不能向表插入显示值。(源:MSSQLServer,错误码:544)
  13. java batik_batik详解2
  14. 【Matlab系列】Matlab语言基础知识汇总
  15. 联想a30微型计算机,联想A30测评,硬件部分。是电脑哦。
  16. bug - Nacos - Ignore the empty nacos configuration and get it based on dataId
  17. 民国歌曲 - 毛毛雨
  18. 计算机二级考试python考试大纲
  19. leetcode-460:LFU 缓存
  20. 使用IE浏览器,禁止访问,显示 Internet Explorer增强安全配置正在阻止来自下列网站的从应用程序中的内容

热门文章

  1. 苹果的http流媒体技术
  2. 剑指Offer——面试小提示(持续更新中)
  3. Android 自带描边颜色渐变炫酷进度条,面试必知必会
  4. 会员消费占比高达96%,孩子王究竟是怎么做到的?
  5. 深入浅出WPF(10)——“脚踩N条船”的多路Binding
  6. 脚踩脸书,锤趴Google,“国产巨头”疫情期间狂吸金!
  7. 试题 算法训练 强力党逗志芃
  8. 2018年04月14日_s芃成_新浪博客
  9. 【最新】原神:数据异常,完全卸载游戏并从官方渠道下载安装,错误码:31-4302
  10. 【opencv】opencv之Mat属性