transformers库中使用DataParallel保存模型参数时遇到的问题记录
pytorch中使用DataParallel保存模型参数时遇到的问题记录
之前使用Transformers库中的Bert模型在自己的文本分类任务上使用Transformers库里的Trainer方式进行了Fine-tune。今天尝试加载保存好的checkpoint到程序中来直接进行evaluate,结果参数匹配不上,故在此记录问题原因和解决过程。
更新
新的问题
我再某台只有0,1两张显卡的机器上对模型进行了若干次训练,模型直接保存成了一个文件。之后,在重新加载这个模型到另一台拥有0,1,2,3四张显卡的机器上继续进行训练时,发现模型还是只占用0,1两张卡。直接使用
model = nn.DataParallel(model, device_ids=[0,1,2,3])
会报错
RuntimeError: Input, output and indices must be on the current device
查阅相关问题后,找到一种更方便的方式,即使用如下代码:
if isinstance(model,torch.nn.DataParallel):model = model.module
# 之后加载就不会报错了
model = nn.DataParallel(model, device_ids=[0,1,2,3])
之后模型就可以正常使用加载了,也可以放到四张卡上。
直接使用AutoModelForSequenceClassification从checkpoint目录加载
Fine tune阶段的模型保存和使用
在Fine tune阶段,我手动将模型设置为了DataParallel模式,然后再使用Trainer进行Fine tune,而模型在保存时也是Trainer自动去保存的。
所设置的Training Args如下:
training_args = TrainingArguments(output_dir='./results_0415_32', # output directorynum_train_epochs=3, # total number of training epochsper_device_train_batch_size=128, # batch size per divice during trainingper_device_eval_batch_size=128, # batch size for evaluationwarmup_steps=500, # number of warmup steps for learning rate schedulerweight_decay=0.01, # strength of weight decay 用于指定权值衰减率,相当于L2正则化中的λ参数logging_dir='./logs', # directory for storing logslogging_steps=10000,
)
然后从一个初始的BertForSequenceClassification模型开始Fine tune,并手动将模型设置了DataParallel模式。
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=len(afid2nor))device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
model = nn.DataParallel(model, device_ids=[0,1,2,3])
model.to(device)
此时打印模型即:
DataParallel((module): BertForSequenceClassification((bert): BertModel((embeddings): BertEmbeddings((word_embeddings): Embedding(30522, 768, padding_idx=0)(position_embeddings): Embedding(512, 768)(token_type_embeddings): Embedding(2, 768)(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)(dropout): Dropout(p=0.1, inplace=False))(encoder): BertEncoder((layer): ModuleList((0): BertLayer((attention): BertAttention((self): BertSelfAttention((query): Linear(in_features=768, out_features=768, bias=True)(key): Linear(in_features=768, out_features=768, bias=True)(value): Linear(in_features=768, out_features=768, bias=True)(dropout): Dropout(p=0.1, inplace=False))(output): BertSelfOutput((dense): Linear(in_features=768, out_features=768, bias=True)(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)(dropout): Dropout(p=0.1, inplace=False)))(intermediate): BertIntermediate((dense): Linear(in_features=768, out_features=3072, bias=True))(output): BertOutput((dense): Linear(in_features=3072, out_features=768, bias=True)(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)(dropout): Dropout(p=0.1, inplace=False)))# 1-10层内容省略,基本一致(11): BertLayer((attention): BertAttention((self): BertSelfAttention((query): Linear(in_features=768, out_features=768, bias=True)(key): Linear(in_features=768, out_features=768, bias=True)(value): Linear(in_features=768, out_features=768, bias=True)(dropout): Dropout(p=0.1, inplace=False))(output): BertSelfOutput((dense): Linear(in_features=768, out_features=768, bias=True)(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)(dropout): Dropout(p=0.1, inplace=False)))(intermediate): BertIntermediate((dense): Linear(in_features=768, out_features=3072, bias=True))(output): BertOutput((dense): Linear(in_features=3072, out_features=768, bias=True)(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)(dropout): Dropout(p=0.1, inplace=False)))))(pooler): BertPooler((dense): Linear(in_features=768, out_features=768, bias=True)(activation): Tanh()))(dropout): Dropout(p=0.1, inplace=False)(classifier): Linear(in_features=768, out_features=25762, bias=True))
)
可以看到,这样的模型外面"套着"两层结构为
(module): BertForSequenceClassification((bert): BertModel(
这个对后续从checkpoint加载产生了影响。
模型的加载
加载的代码为:
model = AutoModelForSequenceClassification.from_pretrained('./results_0415_32/checkpoint-98000/')
提示了如下的Warning:
Some weights of the model checkpoint at ./results_0415_32/checkpoint-98000/ were not used when initializing BertForSequenceClassification: ['module.bert.embeddings.position_ids', 'module.bert.embeddings.word_embeddings.weight', 'module.bert.embeddings.position_embeddings.weight', 'module.bert.embeddings.token_type_embeddings.weight', 'module.bert.embeddings.LayerNorm.weight', 'module.bert.embeddings.LayerNorm.bias', 'module.bert.encoder.layer.0.attention.self.query.weight', 'module.bert.encoder.layer.0.attention.self.query.bias', 'module.bert.encoder.layer.0.attention.self.key.weight', 'module.bert.encoder.layer.0.attention.self.key.bias', 'module.bert.encoder.layer.0.attention.self.value.weight', 'module.bert.encoder.layer.0.attention.self.value.bias', 'module.bert.encoder.layer.0.attention.output.dense.weight', 'module.bert.encoder.layer.0.attention.output.dense.bias', 'module.bert.encoder.layer.0.attention.output.LayerNorm.weight', 'module.bert.encoder.layer.0.attention.output.LayerNorm.bias', 'module.bert.encoder.layer.0.intermediate.dense.weight', 'module.bert.encoder.layer.0.intermediate.dense.bias', 'module.bert.encoder.layer.0.output.dense.weight', 'module.bert.encoder.layer.0.output.dense.bias', 'module.bert.encoder.layer.0.output.LayerNorm.weight', 'module.bert.encoder.layer.0.output.LayerNorm.bias', 'module.bert.encoder.layer.1.attention.self.query.weight', 'module.bert.encoder.layer.1.attention.self.query.bias', 'module.bert.encoder.layer.1.attention.self.key.weight', 'module.bert.encoder.layer.1.attention.self.key.bias', 'module.bert.encoder.layer.1.attention.self.value.weight', 'module.bert.encoder.layer.1.attention.self.value.bias', 'module.bert.encoder.layer.1.attention.output.dense.weight', 'module.bert.encoder.layer.1.attention.output.dense.bias', 'module.bert.encoder.layer.1.attention.output.LayerNorm.weight', 'module.bert.encoder.layer.1.attention.output.LayerNorm.bias', 'module.bert.encoder.layer.1.intermediate.dense.weight', 'module.bert.encoder.layer.1.intermediate.dense.bias', 'module.bert.encoder.layer.1.output.dense.weight', 'module.bert.encoder.layer.1.output.dense.bias', 'module.bert.encoder.layer.1.output.LayerNorm.weight', 'module.bert.encoder.layer.1.output.LayerNorm.bias', 'module.bert.encoder.layer.2.attention.self.query.weight', 'module.bert.encoder.layer.2.attention.self.query.bias', 'module.bert.encoder.layer.2.attention.self.key.weight', 'module.bert.encoder.layer.2.attention.self.key.bias', 'module.bert.encoder.layer.2.attention.self.value.weight', 'module.bert.encoder.layer.2.attention.self.value.bias', 'module.bert.encoder.layer.2.attention.output.dense.weight', 'module.bert.encoder.layer.2.attention.output.dense.bias', 'module.bert.encoder.layer.2.attention.output.LayerNorm.weight', 'module.bert.encoder.layer.2.attention.output.LayerNorm.bias', 'module.bert.encoder.layer.2.intermediate.dense.weight', 'module.bert.encoder.layer.2.intermediate.dense.bias', 'module.bert.encoder.layer.2.output.dense.weight', 'module.bert.encoder.layer.2.output.dense.bias', 'module.bert.encoder.layer.2.output.LayerNorm.weight', 'module.bert.encoder.layer.2.output.LayerNorm.bias', 'module.bert.encoder.layer.3.attention.self.query.weight', 'module.bert.encoder.layer.3.attention.self.query.bias', 'module.bert.encoder.layer.3.attention.self.key.weight', 'module.bert.encoder.layer.3.attention.self.key.bias', 'module.bert.encoder.layer.3.attention.self.value.weight', 'module.bert.encoder.layer.3.attention.self.value.bias', 'module.bert.encoder.layer.3.attention.output.dense.weight', 'module.bert.encoder.layer.3.attention.output.dense.bias', 'module.bert.encoder.layer.3.attention.output.LayerNorm.weight', 'module.bert.encoder.layer.3.attention.output.LayerNorm.bias', 'module.bert.encoder.layer.3.intermediate.dense.weight', 'module.bert.encoder.layer.3.intermediate.dense.bias', 'module.bert.encoder.layer.3.output.dense.weight', 'module.bert.encoder.layer.3.output.dense.bias', 'module.bert.encoder.layer.3.output.LayerNorm.weight', 'module.bert.encoder.layer.3.output.LayerNorm.bias', 'module.bert.encoder.layer.4.attention.self.query.weight', 'module.bert.encoder.layer.4.attention.self.query.bias', 'module.bert.encoder.layer.4.attention.self.key.weight', 'module.bert.encoder.layer.4.attention.self.key.bias', 'module.bert.encoder.layer.4.attention.self.value.weight', 'module.bert.encoder.layer.4.attention.self.value.bias', 'module.bert.encoder.layer.4.attention.output.dense.weight', 'module.bert.encoder.layer.4.attention.output.dense.bias', 'module.bert.encoder.layer.4.attention.output.LayerNorm.weight', 'module.bert.encoder.layer.4.attention.output.LayerNorm.bias', 'module.bert.encoder.layer.4.intermediate.dense.weight', 'module.bert.encoder.layer.4.intermediate.dense.bias', 'module.bert.encoder.layer.4.output.dense.weight', 'module.bert.encoder.layer.4.output.dense.bias', 'module.bert.encoder.layer.4.output.LayerNorm.weight', 'module.bert.encoder.layer.4.output.LayerNorm.bias', 'module.bert.encoder.layer.5.attention.self.query.weight', 'module.bert.encoder.layer.5.attention.self.query.bias', 'module.bert.encoder.layer.5.attention.self.key.weight', 'module.bert.encoder.layer.5.attention.self.key.bias', 'module.bert.encoder.layer.5.attention.self.value.weight', 'module.bert.encoder.layer.5.attention.self.value.bias', 'module.bert.encoder.layer.5.attention.output.dense.weight', 'module.bert.encoder.layer.5.attention.output.dense.bias', 'module.bert.encoder.layer.5.attention.output.LayerNorm.weight', 'module.bert.encoder.layer.5.attention.output.LayerNorm.bias', 'module.bert.encoder.layer.5.intermediate.dense.weight', 'module.bert.encoder.layer.5.intermediate.dense.bias', 'module.bert.encoder.layer.5.output.dense.weight', 'module.bert.encoder.layer.5.output.dense.bias', 'module.bert.encoder.layer.5.output.LayerNorm.weight', 'module.bert.encoder.layer.5.output.LayerNorm.bias', 'module.bert.encoder.layer.6.attention.self.query.weight', 'module.bert.encoder.layer.6.attention.self.query.bias', 'module.bert.encoder.layer.6.attention.self.key.weight', 'module.bert.encoder.layer.6.attention.self.key.bias', 'module.bert.encoder.layer.6.attention.self.value.weight', 'module.bert.encoder.layer.6.attention.self.value.bias', 'module.bert.encoder.layer.6.attention.output.dense.weight', 'module.bert.encoder.layer.6.attention.output.dense.bias', 'module.bert.encoder.layer.6.attention.output.LayerNorm.weight', 'module.bert.encoder.layer.6.attention.output.LayerNorm.bias', 'module.bert.encoder.layer.6.intermediate.dense.weight', 'module.bert.encoder.layer.6.intermediate.dense.bias', 'module.bert.encoder.layer.6.output.dense.weight', 'module.bert.encoder.layer.6.output.dense.bias', 'module.bert.encoder.layer.6.output.LayerNorm.weight', 'module.bert.encoder.layer.6.output.LayerNorm.bias', 'module.bert.encoder.layer.7.attention.self.query.weight', 'module.bert.encoder.layer.7.attention.self.query.bias', 'module.bert.encoder.layer.7.attention.self.key.weight', 'module.bert.encoder.layer.7.attention.self.key.bias', 'module.bert.encoder.layer.7.attention.self.value.weight', 'module.bert.encoder.layer.7.attention.self.value.bias', 'module.bert.encoder.layer.7.attention.output.dense.weight', 'module.bert.encoder.layer.7.attention.output.dense.bias', 'module.bert.encoder.layer.7.attention.output.LayerNorm.weight', 'module.bert.encoder.layer.7.attention.output.LayerNorm.bias', 'module.bert.encoder.layer.7.intermediate.dense.weight', 'module.bert.encoder.layer.7.intermediate.dense.bias', 'module.bert.encoder.layer.7.output.dense.weight', 'module.bert.encoder.layer.7.output.dense.bias', 'module.bert.encoder.layer.7.output.LayerNorm.weight', 'module.bert.encoder.layer.7.output.LayerNorm.bias', 'module.bert.encoder.layer.8.attention.self.query.weight', 'module.bert.encoder.layer.8.attention.self.query.bias', 'module.bert.encoder.layer.8.attention.self.key.weight', 'module.bert.encoder.layer.8.attention.self.key.bias', 'module.bert.encoder.layer.8.attention.self.value.weight', 'module.bert.encoder.layer.8.attention.self.value.bias', 'module.bert.encoder.layer.8.attention.output.dense.weight', 'module.bert.encoder.layer.8.attention.output.dense.bias', 'module.bert.encoder.layer.8.attention.output.LayerNorm.weight', 'module.bert.encoder.layer.8.attention.output.LayerNorm.bias', 'module.bert.encoder.layer.8.intermediate.dense.weight', 'module.bert.encoder.layer.8.intermediate.dense.bias', 'module.bert.encoder.layer.8.output.dense.weight', 'module.bert.encoder.layer.8.output.dense.bias', 'module.bert.encoder.layer.8.output.LayerNorm.weight', 'module.bert.encoder.layer.8.output.LayerNorm.bias', 'module.bert.encoder.layer.9.attention.self.query.weight', 'module.bert.encoder.layer.9.attention.self.query.bias', 'module.bert.encoder.layer.9.attention.self.key.weight', 'module.bert.encoder.layer.9.attention.self.key.bias', 'module.bert.encoder.layer.9.attention.self.value.weight', 'module.bert.encoder.layer.9.attention.self.value.bias', 'module.bert.encoder.layer.9.attention.output.dense.weight', 'module.bert.encoder.layer.9.attention.output.dense.bias', 'module.bert.encoder.layer.9.attention.output.LayerNorm.weight', 'module.bert.encoder.layer.9.attention.output.LayerNorm.bias', 'module.bert.encoder.layer.9.intermediate.dense.weight', 'module.bert.encoder.layer.9.intermediate.dense.bias', 'module.bert.encoder.layer.9.output.dense.weight', 'module.bert.encoder.layer.9.output.dense.bias', 'module.bert.encoder.layer.9.output.LayerNorm.weight', 'module.bert.encoder.layer.9.output.LayerNorm.bias', 'module.bert.encoder.layer.10.attention.self.query.weight', 'module.bert.encoder.layer.10.attention.self.query.bias', 'module.bert.encoder.layer.10.attention.self.key.weight', 'module.bert.encoder.layer.10.attention.self.key.bias', 'module.bert.encoder.layer.10.attention.self.value.weight', 'module.bert.encoder.layer.10.attention.self.value.bias', 'module.bert.encoder.layer.10.attention.output.dense.weight', 'module.bert.encoder.layer.10.attention.output.dense.bias', 'module.bert.encoder.layer.10.attention.output.LayerNorm.weight', 'module.bert.encoder.layer.10.attention.output.LayerNorm.bias', 'module.bert.encoder.layer.10.intermediate.dense.weight', 'module.bert.encoder.layer.10.intermediate.dense.bias', 'module.bert.encoder.layer.10.output.dense.weight', 'module.bert.encoder.layer.10.output.dense.bias', 'module.bert.encoder.layer.10.output.LayerNorm.weight', 'module.bert.encoder.layer.10.output.LayerNorm.bias', 'module.bert.encoder.layer.11.attention.self.query.weight', 'module.bert.encoder.layer.11.attention.self.query.bias', 'module.bert.encoder.layer.11.attention.self.key.weight', 'module.bert.encoder.layer.11.attention.self.key.bias', 'module.bert.encoder.layer.11.attention.self.value.weight', 'module.bert.encoder.layer.11.attention.self.value.bias', 'module.bert.encoder.layer.11.attention.output.dense.weight', 'module.bert.encoder.layer.11.attention.output.dense.bias', 'module.bert.encoder.layer.11.attention.output.LayerNorm.weight', 'module.bert.encoder.layer.11.attention.output.LayerNorm.bias', 'module.bert.encoder.layer.11.intermediate.dense.weight', 'module.bert.encoder.layer.11.intermediate.dense.bias', 'module.bert.encoder.layer.11.output.dense.weight', 'module.bert.encoder.layer.11.output.dense.bias', 'module.bert.encoder.layer.11.output.LayerNorm.weight', 'module.bert.encoder.layer.11.output.LayerNorm.bias', 'module.bert.pooler.dense.weight', 'module.bert.pooler.dense.bias', 'module.classifier.weight', 'module.classifier.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ./results_0415_32/checkpoint-98000/ and are newly initialized: ['embeddings.word_embeddings.weight', 'embeddings.position_embeddings.weight', 'embeddings.token_type_embeddings.weight', 'embeddings.LayerNorm.weight', 'embeddings.LayerNorm.bias', 'encoder.layer.0.attention.self.query.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.layer.0.attention.output.LayerNorm.bias', 'encoder.layer.0.intermediate.dense.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.0.output.dense.weight', 'encoder.layer.0.output.dense.bias', 'encoder.layer.0.output.LayerNorm.weight', 'encoder.layer.0.output.LayerNorm.bias', 'encoder.layer.1.attention.self.query.weight', 'encoder.layer.1.attention.self.query.bias', 'encoder.layer.1.attention.self.key.weight', 'encoder.layer.1.attention.self.key.bias', 'encoder.layer.1.attention.self.value.weight', 'encoder.layer.1.attention.self.value.bias', 'encoder.layer.1.attention.output.dense.weight', 'encoder.layer.1.attention.output.dense.bias', 'encoder.layer.1.attention.output.LayerNorm.weight', 'encoder.layer.1.attention.output.LayerNorm.bias', 'encoder.layer.1.intermediate.dense.weight', 'encoder.layer.1.intermediate.dense.bias', 'encoder.layer.1.output.dense.weight', 'encoder.layer.1.output.dense.bias', 'encoder.layer.1.output.LayerNorm.weight', 'encoder.layer.1.output.LayerNorm.bias', 'encoder.layer.2.attention.self.query.weight', 'encoder.layer.2.attention.self.query.bias', 'encoder.layer.2.attention.self.key.weight', 'encoder.layer.2.attention.self.key.bias', 'encoder.layer.2.attention.self.value.weight', 'encoder.layer.2.attention.self.value.bias', 'encoder.layer.2.attention.output.dense.weight', 'encoder.layer.2.attention.output.dense.bias', 'encoder.layer.2.attention.output.LayerNorm.weight', 'encoder.layer.2.attention.output.LayerNorm.bias', 'encoder.layer.2.intermediate.dense.weight', 'encoder.layer.2.intermediate.dense.bias', 'encoder.layer.2.output.dense.weight', 'encoder.layer.2.output.dense.bias', 'encoder.layer.2.output.LayerNorm.weight', 'encoder.layer.2.output.LayerNorm.bias', 'encoder.layer.3.attention.self.query.weight', 'encoder.layer.3.attention.self.query.bias', 'encoder.layer.3.attention.self.key.weight', 'encoder.layer.3.attention.self.key.bias', 'encoder.layer.3.attention.self.value.weight', 'encoder.layer.3.attention.self.value.bias', 'encoder.layer.3.attention.output.dense.weight', 'encoder.layer.3.attention.output.dense.bias', 'encoder.layer.3.attention.output.LayerNorm.weight', 'encoder.layer.3.attention.output.LayerNorm.bias', 'encoder.layer.3.intermediate.dense.weight', 'encoder.layer.3.intermediate.dense.bias', 'encoder.layer.3.output.dense.weight', 'encoder.layer.3.output.dense.bias', 'encoder.layer.3.output.LayerNorm.weight', 'encoder.layer.3.output.LayerNorm.bias', 'encoder.layer.4.attention.self.query.weight', 'encoder.layer.4.attention.self.query.bias', 'encoder.layer.4.attention.self.key.weight', 'encoder.layer.4.attention.self.key.bias', 'encoder.layer.4.attention.self.value.weight', 'encoder.layer.4.attention.self.value.bias', 'encoder.layer.4.attention.output.dense.weight', 'encoder.layer.4.attention.output.dense.bias', 'encoder.layer.4.attention.output.LayerNorm.weight', 'encoder.layer.4.attention.output.LayerNorm.bias', 'encoder.layer.4.intermediate.dense.weight', 'encoder.layer.4.intermediate.dense.bias', 'encoder.layer.4.output.dense.weight', 'encoder.layer.4.output.dense.bias', 'encoder.layer.4.output.LayerNorm.weight', 'encoder.layer.4.output.LayerNorm.bias', 'encoder.layer.5.attention.self.query.weight', 'encoder.layer.5.attention.self.query.bias', 'encoder.layer.5.attention.self.key.weight', 'encoder.layer.5.attention.self.key.bias', 'encoder.layer.5.attention.self.value.weight', 'encoder.layer.5.attention.self.value.bias', 'encoder.layer.5.attention.output.dense.weight', 'encoder.layer.5.attention.output.dense.bias', 'encoder.layer.5.attention.output.LayerNorm.weight', 'encoder.layer.5.attention.output.LayerNorm.bias', 'encoder.layer.5.intermediate.dense.weight', 'encoder.layer.5.intermediate.dense.bias', 'encoder.layer.5.output.dense.weight', 'encoder.layer.5.output.dense.bias', 'encoder.layer.5.output.LayerNorm.weight', 'encoder.layer.5.output.LayerNorm.bias', 'encoder.layer.6.attention.self.query.weight', 'encoder.layer.6.attention.self.query.bias', 'encoder.layer.6.attention.self.key.weight', 'encoder.layer.6.attention.self.key.bias', 'encoder.layer.6.attention.self.value.weight', 'encoder.layer.6.attention.self.value.bias', 'encoder.layer.6.attention.output.dense.weight', 'encoder.layer.6.attention.output.dense.bias', 'encoder.layer.6.attention.output.LayerNorm.weight', 'encoder.layer.6.attention.output.LayerNorm.bias', 'encoder.layer.6.intermediate.dense.weight', 'encoder.layer.6.intermediate.dense.bias', 'encoder.layer.6.output.dense.weight', 'encoder.layer.6.output.dense.bias', 'encoder.layer.6.output.LayerNorm.weight', 'encoder.layer.6.output.LayerNorm.bias', 'encoder.layer.7.attention.self.query.weight', 'encoder.layer.7.attention.self.query.bias', 'encoder.layer.7.attention.self.key.weight', 'encoder.layer.7.attention.self.key.bias', 'encoder.layer.7.attention.self.value.weight', 'encoder.layer.7.attention.self.value.bias', 'encoder.layer.7.attention.output.dense.weight', 'encoder.layer.7.attention.output.dense.bias', 'encoder.layer.7.attention.output.LayerNorm.weight', 'encoder.layer.7.attention.output.LayerNorm.bias', 'encoder.layer.7.intermediate.dense.weight', 'encoder.layer.7.intermediate.dense.bias', 'encoder.layer.7.output.dense.weight', 'encoder.layer.7.output.dense.bias', 'encoder.layer.7.output.LayerNorm.weight', 'encoder.layer.7.output.LayerNorm.bias', 'encoder.layer.8.attention.self.query.weight', 'encoder.layer.8.attention.self.query.bias', 'encoder.layer.8.attention.self.key.weight', 'encoder.layer.8.attention.self.key.bias', 'encoder.layer.8.attention.self.value.weight', 'encoder.layer.8.attention.self.value.bias', 'encoder.layer.8.attention.output.dense.weight', 'encoder.layer.8.attention.output.dense.bias', 'encoder.layer.8.attention.output.LayerNorm.weight', 'encoder.layer.8.attention.output.LayerNorm.bias', 'encoder.layer.8.intermediate.dense.weight', 'encoder.layer.8.intermediate.dense.bias', 'encoder.layer.8.output.dense.weight', 'encoder.layer.8.output.dense.bias', 'encoder.layer.8.output.LayerNorm.weight', 'encoder.layer.8.output.LayerNorm.bias', 'encoder.layer.9.attention.self.query.weight', 'encoder.layer.9.attention.self.query.bias', 'encoder.layer.9.attention.self.key.weight', 'encoder.layer.9.attention.self.key.bias', 'encoder.layer.9.attention.self.value.weight', 'encoder.layer.9.attention.self.value.bias', 'encoder.layer.9.attention.output.dense.weight', 'encoder.layer.9.attention.output.dense.bias', 'encoder.layer.9.attention.output.LayerNorm.weight', 'encoder.layer.9.attention.output.LayerNorm.bias', 'encoder.layer.9.intermediate.dense.weight', 'encoder.layer.9.intermediate.dense.bias', 'encoder.layer.9.output.dense.weight', 'encoder.layer.9.output.dense.bias', 'encoder.layer.9.output.LayerNorm.weight', 'encoder.layer.9.output.LayerNorm.bias', 'encoder.layer.10.attention.self.query.weight', 'encoder.layer.10.attention.self.query.bias', 'encoder.layer.10.attention.self.key.weight', 'encoder.layer.10.attention.self.key.bias', 'encoder.layer.10.attention.self.value.weight', 'encoder.layer.10.attention.self.value.bias', 'encoder.layer.10.attention.output.dense.weight', 'encoder.layer.10.attention.output.dense.bias', 'encoder.layer.10.attention.output.LayerNorm.weight', 'encoder.layer.10.attention.output.LayerNorm.bias', 'encoder.layer.10.intermediate.dense.weight', 'encoder.layer.10.intermediate.dense.bias', 'encoder.layer.10.output.dense.weight', 'encoder.layer.10.output.dense.bias', 'encoder.layer.10.output.LayerNorm.weight', 'encoder.layer.10.output.LayerNorm.bias', 'encoder.layer.11.attention.self.query.weight', 'encoder.layer.11.attention.self.query.bias', 'encoder.layer.11.attention.self.key.weight', 'encoder.layer.11.attention.self.key.bias', 'encoder.layer.11.attention.self.value.weight', 'encoder.layer.11.attention.self.value.bias', 'encoder.layer.11.attention.output.dense.weight', 'encoder.layer.11.attention.output.dense.bias', 'encoder.layer.11.attention.output.LayerNorm.weight', 'encoder.layer.11.attention.output.LayerNorm.bias', 'encoder.layer.11.intermediate.dense.weight', 'encoder.layer.11.intermediate.dense.bias', 'encoder.layer.11.output.dense.weight', 'encoder.layer.11.output.dense.bias', 'encoder.layer.11.output.LayerNorm.weight', 'encoder.layer.11.output.LayerNorm.bias', 'pooler.dense.weight', 'pooler.dense.bias', 'classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
可以看到,提示有很多的参数无法匹配上,原因是在训练时使用了DataParallel方式,导致参数的前面都多了一个module.bert.的前缀。
解决方法
将模型参数重新加载与输出
目前一种实验可行的方法较为繁琐,主要有以下几个步骤
- 首先加载checkpoint目录下的pytorch_model.bin文件,这里面存储了在DataParallel方式下模型的各个参数。
- 生成一个与原始Fine-tune模型相同结构的模型,并将其也使用DataParallel方式放到设备上。
- 使用model.load_state_dict方法把参数加载进来
- 将模型放到一张显卡上(即退出了DataParallel模式)
- 重新保存模型参数到一个本地文件中。
- 新建一个与原始Fine-tune模型相同结构的模型,并将其放到一张GPU设备上。
- 读入在步骤5中保存的参数。
相关代码如下:
# Step 1
pytorch_model_state_dict = torch.load('./results_0415_32/checkpoint-98000/pytorch_model.bin', map_location=torch.device('cpu'))# Step 2
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=CLASS_NUM)
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
model = nn.DataParallel(model, device_ids=[0,1,2,3])
model.to(device)# Step 3
model.load_state_dict(pytorch_model_state_dict)# Step 4
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
model.to(device)# Step 5
torch.save(model.module.state_dict(), 'static_dict.pkl')# Step 6
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=CLASS_NUM)
model.to(device)# Step 7
model.load_state_dict(torch.load('static_dict.pkl'))
使用Trainer类时不要认为设置模型为DataParallel模式
这种方式加载和保存都非常繁琐,在与同学讨论后,了解到其实Trainer这个类会默认采用DataParallel方式去训练,因而并不需要人为的将模型设置DataParallel,这样在加载时就不会导致参数前嵌套的两层结构造成参数不匹配。
实验测试有效,完全匹配而不会抱任何错误和警告。
但由于默认的DataParallel模式总是会将模型放到这个机器的所有GPU上来并行加速。实际中,由于多用户使用,可能四卡机中的cuda 0和cuda 2被占用,那么默认启动Trainer时它就会请求cuda 0给予GPU memory,但由于cuda 0被占用,程序将意外终止为cuda out of memory。
多卡的选择
实际上,我们可以通过设置哪些GPU让程序可见来解决上述问题。具体做法为在命令行中程序运行时,设置可见的GPU即可。例如:
CUDA_VISIBLE_DEVICES=1,3 python my_script.py
就可以让程序只看见cuda 1和cuda 3并且程序内部所看到的cuda编号为cuda 0和cuda 1,即该程序会认为这个机器上只有两个GPU。
参考
- (原)PyTorch中使用指定的GPU, https://www.cnblogs.com/darkknightzh/p/6836568.html
transformers库中使用DataParallel保存模型参数时遇到的问题记录相关推荐
- 【Pytorch神经网络理论篇】 39 Transformers库中的BERTology系列模型
同学你好!本文章于2021年末编写,获得广泛的好评! 故在2022年末对本系列进行填充与更新,欢迎大家订阅最新的专栏,获取基于Pytorch1.10版本的理论代码(2023版)实现, Pytorch深 ...
- 【待更新】GPU 保存模型参数,GPU 加载模型参数
GPU 保存模型参数,GPU 加载模型参数 保存 # 模型 device = torch.device('cuda') net = KGCN(num_user, num_entity, num_rel ...
- Transformers 库中的 Tokenizer 使用
文章目录 概述 基本使用方法 进阶 基本使用不能满足的情况 解决思路 问题一解决:(有两种思路) 问题二解决: Tokenizer 中的 Encoder vocab_base 部分 vocab_add ...
- numpy库中ndarray切片操作的参数意义
ndarray切片操作的规则总结出来叫做"三帽号规则" 三帽号规则即:[开始索引:结尾索引:步长],并且切片区间是左闭右开的,即"开始索引:结尾索引"表示的区间 ...
- 【Python机器学习】Sklearn库中Kmeans类、超参数K值确定、特征归一化的讲解(图文解释)
一.局部最优解 采用随机产生初始簇中心 的方法,可能会出现运行 结果不一致的情况.这是 因为不同的初始簇中心使 得算法可能收敛到不同的 局部极小值. 不能收敛到全局最小值,是最优化计算中常常遇到的问题 ...
- lightgbm保存模型参数
20210205 params = {'task': 'train', # 执行的任务类型'boosting_type': 'gbrt', # 基学习器'objective': 'lambdarank ...
- 08_04基于手写数据集_mat保存模型参数
import os import numpy as np import tensorflow as tf from scipy import io from tensorflow.examples.t ...
- pytorch保存模型参数
1.代码 有保存路径 PATH='my_model.pth' torch.save(model.state_dict(),PATH) 新模型 new_model=model GPU上运行 new_mo ...
- tensorflow保存模型参数文件pb查看
查看pb文件的节点参数: with tf.Session() as sess: with open(model, 'rb') as model_file: graph_def = tf.GraphDe ...
最新文章
- jmail反馈是否发送成功_如何在钉钉上自动发送定制消息或通知给同事?(10行代码搞定)...
- 【总结+计划】2015一月份总结+2015二月份计划——全栈实践
- CentOS 安装python3.6
- [原创]Paros工具培训介绍
- 题目1025:最大报销额
- css教程之列表属性
- 面试 4 个月,最终入职大厂经验分享!
- Spring和Struts2整合
- python字符串三种常用的方法或函数_python中字符串常用的函数
- 【华为云•云享专家•原创分享计划上线】原创文章征集,寻找与众不同的你
- HTTP 1.1 协议规范
- 当IDENTITY_INSERT设置为OFF时不能向表插入显示值。(源:MSSQLServer,错误码:544)
- java batik_batik详解2
- 【Matlab系列】Matlab语言基础知识汇总
- 联想a30微型计算机,联想A30测评,硬件部分。是电脑哦。
- bug - Nacos - Ignore the empty nacos configuration and get it based on dataId
- 民国歌曲 - 毛毛雨
- 计算机二级考试python考试大纲
- leetcode-460:LFU 缓存
- 使用IE浏览器,禁止访问,显示 Internet Explorer增强安全配置正在阻止来自下列网站的从应用程序中的内容
热门文章
- 苹果的http流媒体技术
- 剑指Offer——面试小提示(持续更新中)
- Android 自带描边颜色渐变炫酷进度条,面试必知必会
- 会员消费占比高达96%,孩子王究竟是怎么做到的?
- 深入浅出WPF(10)——“脚踩N条船”的多路Binding
- 脚踩脸书,锤趴Google,“国产巨头”疫情期间狂吸金!
- 试题 算法训练 强力党逗志芃
- 2018年04月14日_s芃成_新浪博客
- 【最新】原神:数据异常,完全卸载游戏并从官方渠道下载安装,错误码:31-4302
- 【opencv】opencv之Mat属性