Faster way to do multiple embeddings in PyTorch?
I'm working on a torch-based library for building autoencoders with tabular datasets.

One big feature is learning embeddings for categorical features.

In practice, however, training many embedding layers simultaneously creates some slowdowns. I'm using for-loops to do this, and running the for-loop on each iteration is (I think) what's causing them.

When building the model, I associate embedding layers with each categorical feature in the user's dataset:

        for ft in self.categorical_fts:
            feature = self.categorical_fts[ft]
            n_cats = len(feature['cats']) + 1
            embed_dim = compute_embedding_size(n_cats)
            embed_layer = torch.nn.Embedding(n_cats, embed_dim)
            feature['embedding'] = embed_layer
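
For context, in the full module these per-feature layers also get registered as submodules so their parameters are trained. Here's a rough, simplified sketch of the kind of wiring I mean (the compute_embedding_size here is just a placeholder heuristic, not the library's real one):

    import torch

    def compute_embedding_size(n_cats):
        # placeholder heuristic, not the library's actual rule
        return min(50, (n_cats + 1) // 2)

    class TabularAutoencoder(torch.nn.Module):
        def __init__(self, categorical_fts):
            super().__init__()
            self.categorical_fts = categorical_fts
            # ModuleDict registers each per-feature Embedding as a submodule
            self.embeddings = torch.nn.ModuleDict()
            for ft, feature in categorical_fts.items():
                n_cats = len(feature['cats']) + 1
                embed_dim = compute_embedding_size(n_cats)
                self.embeddings[ft] = torch.nn.Embedding(n_cats, embed_dim)
                # keep the reference used in forward() as well
                feature['embedding'] = self.embeddings[ft]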

Then, with a call to .forward():

        embeddings = []
        for i, ft in enumerate(self.categorical_fts):
            feature = self.categorical_fts[ft]
            emb = feature['embedding'](codes[i])
            embeddings.append(emb)

        # num and bin are lists of numeric and binary feature tensors
        x = torch.cat(num + bin + embeddings, dim=1)

Then x goes into dense layers.

This gets the job done, but running this for-loop during each forward pass really slows down training, especially when a dataset has tens or hundreds of categorical columns.

Does anybody know of a way of vectorizing something like this? Thanks!

UPDATE: For more clarity, I made this sketch of how I'm feeding categorical features into the network. Each categorical column has its own embedding matrix, and the numeric features are concatenated directly to the embeddings' outputs before being passed into the feed-forward network.

[diagram: each categorical column passes through its own embedding matrix; numeric features are concatenated with the embedding outputs before the feed-forward layers]

Can we do this without iterating through each embedding matrix?

Selwin answered 24/6, 2019 at 2:15 Comment(2)
Would love to hear if there's a more efficient / cleaner way to do this too. – Bolme
Me too, but I haven't found anything so far :( In this case, tensorflow's compiled computation graph is more efficient, but I'm disappointed because I like pytorch better. – Selwin

Just use simple indexing, though I'm not sure whether it is fast enough.

Here is a simplified version where all features share the same vocab_size and embedding dim, but the same trick should apply to heterogeneous categorical features:

import numpy as np
import torch

xdim = 240        # number of categorical columns
embed_dim = 8
vocab_size = 64
# one (vocab_size, embed_dim) embedding matrix per column, stacked into a single tensor
embedding_table = torch.randn(size=(xdim, vocab_size, embed_dim))

batch_size = 32
x = torch.randint(vocab_size, size=(batch_size, xdim))

# advanced indexing: column j's matrix is indexed by x[:, j], for all columns at once
out = embedding_table[torch.arange(xdim), x]
out.shape  # (batch_size, xdim, embed_dim)

# unit test
i = np.random.randint(batch_size)
j = np.random.randint(xdim)

x_index = x[i][j]
w = embedding_table[j]

torch.allclose(w[x_index], out[i, j])  # True
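
And a sketch of the heterogeneous case under one simplifying assumption: a shared embed_dim, with every column's table widened out to the largest vocab size, so the extra rows are simply never indexed (the vocab sizes below are made-up example values):

import torch

embed_dim = 8
vocab_sizes = [3, 10, 7, 64]   # example per-column vocab sizes
xdim = len(vocab_sizes)
max_vocab = max(vocab_sizes)

# one padded table; wrap it in torch.nn.Parameter to make it trainable
embedding_table = torch.randn(xdim, max_vocab, embed_dim)

batch_size = 32
# draw each column's codes from its own vocab so the padded rows are never selected
x = torch.stack(
    [torch.randint(v, size=(batch_size,)) for v in vocab_sizes], dim=1
)  # (batch_size, xdim)

out = embedding_table[torch.arange(xdim), x]  # (batch_size, xdim, embed_dim)

Columns with different embedding dims would need extra padding or masking before the concatenation, so past some point a few grouped tables (one per distinct dim) may be simpler than a single lookup.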
Timbering answered 6/5, 2022 at 9:54 Comment(0)
