1. Divide embeddings into two separate objects
One approach would be to use two separate embeddings: one for the pretrained tokens and another for the tokens to be trained.
The GloVe one should be frozen, while tokens for which there is no pretrained representation would be taken from the trainable layer.
This works if you arrange your vocabulary so that tokens with a pretrained representation occupy a lower index range than tokens without a GloVe representation. Let's say your pretrained indices are in the range [0, 300], while those without a representation are in [301, 500]. I would go with something along these lines:
import numpy as np
import torch


class YourNetwork(torch.nn.Module):
    def __init__(self, glove_embeddings: np.ndarray, how_many_tokens_not_present: int):
        super().__init__()
        # from_pretrained expects a float tensor and freezes the weights by default
        self.pretrained_embedding = torch.nn.Embedding.from_pretrained(
            torch.from_numpy(glove_embeddings).float()
        )
        self.trainable_embedding = torch.nn.Embedding(
            how_many_tokens_not_present, glove_embeddings.shape[1]
        )
        # Rest of your network setup

    def forward(self, batch):
        # Tokens without a pretrained representation must have indices at or above
        # the number of pretrained ones, adjust your data creating function accordingly
        mask = batch >= self.pretrained_embedding.num_embeddings

        # Work on a copy so the caller's batch is not modified; tokens outside the
        # pretrained range are pointed at placeholder index 0
        pretrained_batch = batch.clone()
        pretrained_batch[mask] = 0
        embedded_batch = self.pretrained_embedding(pretrained_batch)

        # Every token without representation has to be brought into appropriate range
        trainable_batch = batch - self.pretrained_embedding.num_embeddings
        # Zero out the ones which already have pretrained embedding
        trainable_batch[~mask] = 0
        non_pretrained_embedded_batch = self.trainable_embedding(trainable_batch)

        # And finally change appropriate tokens from placeholder embedding created by
        # pretrained into trainable embeddings.
        embedded_batch[mask] = non_pretrained_embedded_batch[mask]

        # Rest of your code
        ...
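The index layout this approach relies on has to come from your data pipeline, so the vocabulary needs to be built with the GloVe tokens first. A minimal sketch of one way to do that, assuming a hypothetical helper `build_vocab` and toy inputs (the names and the token lists are made up for illustration, they are not part of PyTorch or any GloVe tooling):

```python
def build_vocab(corpus_tokens, glove_tokens):
    """Map tokens to indices, placing tokens that have a GloVe vector first.

    Hypothetical inputs: `corpus_tokens` is an iterable of token strings,
    `glove_tokens` is the set of tokens for which a GloVe vector exists.
    """
    unique = sorted(set(corpus_tokens))
    pretrained = [t for t in unique if t in glove_tokens]
    missing = [t for t in unique if t not in glove_tokens]
    # Pretrained tokens get indices [0, len(pretrained)), the rest follow
    token_to_index = {t: i for i, t in enumerate(pretrained + missing)}
    return token_to_index, len(pretrained), len(missing)


corpus = ["the", "cat", "sat", "on", "the", "zorblax", "mat", "qwyjibo"]
glove_vocab = {"the", "cat", "sat", "on", "mat"}  # toy stand-in for the GloVe vocabulary

token_to_index, n_pretrained, n_missing = build_vocab(corpus, glove_vocab)
print(token_to_index)
print(n_pretrained, n_missing)
```

With a mapping like this, `n_missing` would be what you pass as `how_many_tokens_not_present`, and the rows of the GloVe matrix would have to be stacked in the same order as the pretrained part of `token_to_index`.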
2. Zero gradients for specified tokens.
This one is a bit tricky, but I think it's pretty concise and easy to implement. If you obtain the indices of the tokens which do have a GloVe representation, you can explicitly zero their gradient after backprop, so those rows will not get updated while the remaining rows are trained as usual.
import torch
embedding = torch.nn.Embedding(10, 3)
X = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
values = embedding(X)
# Use whatever loss you want
loss = values.mean()
loss.backward()
# Let's say those indices in your embedding are pretrained (have GloVe representation)
indices = torch.LongTensor([2, 4, 5])
print("Before zeroing out gradient")
print(embedding.weight.grad)
print("After zeroing out gradient")
embedding.weight.grad[indices] = 0
print(embedding.weight.grad)
And the output of the second approach:
Before zeroing out gradient
tensor([[0.0000, 0.0000, 0.0000],
[0.0417, 0.0417, 0.0417],
[0.0833, 0.0833, 0.0833],
[0.0417, 0.0417, 0.0417],
[0.0833, 0.0833, 0.0833],
[0.0417, 0.0417, 0.0417],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0417, 0.0417, 0.0417]])
After zeroing out gradient
tensor([[0.0000, 0.0000, 0.0000],
[0.0417, 0.0417, 0.0417],
[0.0000, 0.0000, 0.0000],
[0.0417, 0.0417, 0.0417],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0417, 0.0417, 0.0417]])
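If you would rather not zero the gradient by hand after every backward pass, the same idea can probably be wrapped in a tensor hook, which runs automatically during `backward()`. A minimal sketch under the same toy setup as above; `frozen_indices` and `zero_frozen_rows` are hypothetical names, and plain SGD is an assumption made only to keep the check simple:

```python
import torch

torch.manual_seed(0)
embedding = torch.nn.Embedding(10, 3)
# Hypothetical rows that hold pretrained (GloVe) vectors and should stay fixed
frozen_indices = torch.tensor([2, 4, 5])


def zero_frozen_rows(grad):
    # The value returned by a tensor hook replaces the gradient
    grad = grad.clone()
    grad[frozen_indices] = 0
    return grad


embedding.weight.register_hook(zero_frozen_rows)

optimizer = torch.optim.SGD(embedding.parameters(), lr=0.1)
before = embedding.weight.detach().clone()

X = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
loss = embedding(X).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

after = embedding.weight.detach()
frozen_unchanged = torch.equal(before[frozen_indices], after[frozen_indices])
trainable_changed = not torch.equal(before[1], after[1])
print(frozen_unchanged, trainable_changed)
```

One thing to keep in mind: a zero gradient only keeps a row fixed if the optimizer does nothing else to it. With weight decay enabled, the optimizer would still shrink the "frozen" rows, so in that case the first approach is the safer choice.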