How can I load a partial pretrained pytorch model?
I'm trying to get a PyTorch model running on a sentence classification task. Since I'm working with medical notes, I'm using ClinicalBERT (https://github.com/kexinhuang12345/clinicalBERT) and would like to use its pre-trained weights. Unfortunately, ClinicalBERT's classification head only predicts a single binary label, while I have 281 binary labels. I am therefore trying to adapt this code https://github.com/kaushaltrivedi/bert-toxic-comments-multilabel/blob/master/toxic-bert-multilabel-classification.ipynb, where the final classifier after BERT has 281 outputs.

How can I load the pre-trained BERT weights from the ClinicalBERT model without loading the classification weights?

Naively loading the pretrained ClinicalBERT weights, I get the following error:

size mismatch for classifier.weight: copying a param with shape torch.Size([2, 768]) from checkpoint, the shape in current model is torch.Size([281, 768]).
size mismatch for classifier.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([281]).

I then tried modifying the from_pretrained function from the pytorch_pretrained_bert package to pop the classifier weight and bias, like this:

def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs):
    ...
    if state_dict is None:
        weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
        state_dict = torch.load(weights_path, map_location='cpu')
    # Drop the 2-way classification head before the weights are copied
    # into the model, so only the BERT encoder weights remain.
    state_dict.pop('classifier.weight')
    state_dict.pop('classifier.bias')
    old_keys = []
    new_keys = []
    ...

With that change I get the following log message: INFO - modeling_diagnosis - Weights of BertForMultiLabelSequenceClassification not initialized from pretrained model: ['classifier.weight', 'classifier.bias']

In the end, I would like to load the BERT encoder weights from the pretrained ClinicalBERT checkpoint and have the top classifier weights initialized randomly.

Callaghan asked 14/4, 2020 at 15:42

Removing the keys from the state dict before loading is a good start. Assuming you're using nn.Module.load_state_dict to load the pretrained weights, you'll also need to pass the strict=False argument to avoid errors from unexpected or missing keys. This ignores entries in the state_dict that aren't present in the model (unexpected keys) and, more importantly for you, leaves the missing entries at their default initialization (missing keys). For safety, you can check the method's return value to verify that the weights in question are among the missing keys and that there aren't any unexpected keys.
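
For example, a minimal sketch: here `model` and `weights_path` are placeholders, where `model` is assumed to be an already-constructed 281-label BertForMultiLabelSequenceClassification from the linked notebook, and `weights_path` points at the ClinicalBERT checkpoint file.

import torch

# Load the full ClinicalBERT checkpoint onto the CPU.
state_dict = torch.load(weights_path, map_location='cpu')

# Drop the 2-way classification head; only the BERT encoder weights remain.
state_dict.pop('classifier.weight', None)
state_dict.pop('classifier.bias', None)

# strict=False leaves the missing 281-way classifier at its random
# initialization instead of raising an error.
result = model.load_state_dict(state_dict, strict=False)

# Sanity check: only the classifier head should be missing,
# and nothing should be unexpected.
print(result.missing_keys)     # expect ['classifier.weight', 'classifier.bias']
print(result.unexpected_keys)  # expect []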

Pincer answered 14/4, 2020 at 15:54
