Is it possible to freeze only certain embedding weights in the embedding layer in pytorch?

When using GloVe embeddings in NLP tasks, some words from the dataset might not exist in GloVe. Therefore, we instantiate random weights for these unknown words.

Would it be possible to freeze the weights taken from GloVe and train only the newly instantiated weights?

I am only aware that we can set: model.embedding.weight.requires_grad = False

But this makes the new words untrainable as well.

Or are there better ways to extract the semantics of words?

Lengthen answered 28/2, 2019 at 11:23 Comment(0)

1. Divide embeddings into two separate objects

One approach would be to use two separate embeddings: one for the pretrained tokens, another for the tokens to be trained.

The GloVe embedding would be frozen, while tokens with no pretrained representation would be looked up in the trainable layer.

This works if you format your data so that tokens with a pretrained representation occupy a smaller index range than tokens without one. Let's say your pretrained indices are in the range [0, 300], while those without a representation are in [301, 500]. I would go with something along these lines (a short usage sketch follows the code):

import numpy as np
import torch


class YourNetwork(torch.nn.Module):
    def __init__(self, glove_embeddings: np.ndarray, how_many_tokens_not_present: int):
        super().__init__()
        # from_pretrained expects a tensor and freezes the weights by default
        self.pretrained_embedding = torch.nn.Embedding.from_pretrained(
            torch.from_numpy(glove_embeddings)
        )
        self.trainable_embedding = torch.nn.Embedding(
            how_many_tokens_not_present, glove_embeddings.shape[1]
        )
        # Rest of your network setup

    def forward(self, batch):
        # Tokens without a pretrained representation must have indices greater than
        # or equal to the number of pretrained ones; adjust your data creation accordingly
        mask = batch >= self.pretrained_embedding.num_embeddings

        # Clone so the original index tensor is not modified; tokens without a
        # pretrained vector temporarily point at placeholder index 0
        pretrained_batch = batch.clone()
        pretrained_batch[mask] = 0

        embedded_batch = self.pretrained_embedding(pretrained_batch)

        # Shift tokens without a pretrained representation into the trainable
        # embedding's index range
        trainable_batch = batch - self.pretrained_embedding.num_embeddings
        # Zero out the ones which already have a pretrained embedding
        trainable_batch[~mask] = 0
        non_pretrained_embedded_batch = self.trainable_embedding(trainable_batch)

        # And finally overwrite the placeholder rows coming from the pretrained
        # embedding with the trainable embeddings
        embedded_batch[mask] = non_pretrained_embedded_batch[mask]

        # Rest of your code
        ...

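For concreteness, here is a minimal usage sketch of the indexing scheme this approach assumes, reusing the YourNetwork class above. The word lists, embedding dimension and variable names are made up for illustration:

import numpy as np
import torch

# Hypothetical vocabulary split: words covered by GloVe get the lowest indices,
# words without a GloVe vector are appended after them
glove_vocab = ["the", "cat", "sat"]
unknown_vocab = ["blorptastic", "zxqv"]
word2idx = {w: i for i, w in enumerate(glove_vocab + unknown_vocab)}
# {'the': 0, 'cat': 1, 'sat': 2, 'blorptastic': 3, 'zxqv': 4}

# Stand-in for the real GloVe matrix, one row per covered word
glove_weights = np.random.randn(len(glove_vocab), 50).astype(np.float32)

net = YourNetwork(glove_weights, how_many_tokens_not_present=len(unknown_vocab))
batch = torch.LongTensor([[word2idx["the"], word2idx["blorptastic"]]])
# Inside net(batch), "the" is looked up in the frozen GloVe embedding and
# "blorptastic" in the trainable embedding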

2. Zero gradients for specified tokens

This one is a bit tricky, but I think it's pretty concise and easy to implement. So, if you obtain the indices of the tokens which do have a GloVe representation, you can explicitly zero their gradient after backprop, so those rows will not get updated.

import torch

embedding = torch.nn.Embedding(10, 3)
X = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])

values = embedding(X)
loss = values.mean()

# Use whatever loss you want
loss.backward()

# Let's say those indices in your embedding are pretrained (have GloVe representation)
indices = torch.LongTensor([2, 4, 5])

print("Before zeroing out gradient")
print(embedding.weight.grad)

print("After zeroing out gradient")
embedding.weight.grad[indices] = 0
print(embedding.weight.grad)

And the output of the second approach:

Before zeroing out gradient
tensor([[0.0000, 0.0000, 0.0000],
        [0.0417, 0.0417, 0.0417],
        [0.0833, 0.0833, 0.0833],
        [0.0417, 0.0417, 0.0417],
        [0.0833, 0.0833, 0.0833],
        [0.0417, 0.0417, 0.0417],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0417, 0.0417, 0.0417]])
After zeroing out gradient
tensor([[0.0000, 0.0000, 0.0000],
        [0.0417, 0.0417, 0.0417],
        [0.0000, 0.0000, 0.0000],
        [0.0417, 0.0417, 0.0417],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0417, 0.0417, 0.0417]])
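
If you'd rather not remember to zero the rows after every backward call, the same gradient-zeroing idea can be automated with a hook on the embedding weight. This is not part of the original answer, just a sketch using Tensor.register_hook, which PyTorch invokes on every backward pass; the indices are hypothetical:

import torch

embedding = torch.nn.Embedding(10, 3)
# Hypothetical set of rows that have a GloVe vector and should therefore stay frozen
pretrained_indices = torch.LongTensor([2, 4, 5])

def zero_pretrained_grad(grad):
    # Return a copy of the gradient with the pretrained rows zeroed,
    # so optimizer.step() never updates those rows
    grad = grad.clone()
    grad[pretrained_indices] = 0
    return grad

embedding.weight.register_hook(zero_pretrained_grad)

X = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
embedding(X).mean().backward()
print(embedding.weight.grad)  # rows 2, 4 and 5 are already zero

Keep in mind that zeroing the gradient does not disable weight decay, since the optimizer applies it based on the weight values themselves; this is the point raised in the comments below.
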
Fanni answered 1/3, 2019 at 21:55 Comment(7)
In the second approach, since you just set to zero some of the obtained gradients, wouldn't the computation still be just as heavy as training all the embeddings from the beginning?Ampersand
@AndreaRossi Yes, hence the first approach is much better. I just outlined another possibility.Fanni
Couple of minor tweaks to get this working for me. I had to use self.pretrained_embedding.num_embeddings instead of self.pretrained_embedding.shape[0]. I also think the line embedded_batch = self.pretrained_embedding[pretrained_batch] needs to be updated very slightly to use round brackets - embedded_batch = self.pretrained_embedding(pretrained_batch). Thank you for the detailed answer btw, this is the only solution I found that actually explains how to implement a multiple embedding approach.Carrew
@Carrew Thanks, updated the answer. Do not hesitate to edit my answers if you see something is incorrect next time though. :)Fanni
Didn't realize that was possible actually, but definitely will do next time!!Carrew
One detail needs clarification here. If you zero out the gradients, does that also stop weight decay from being applied?Lode
@menglin - no, weight decay is independent of the gradient and only relies on the weight magnitude (or so it should, I think it's different for Adam)Fanni
