Broadcast and concatenate ragged tensors

I have a ragged tensor of dimensions [BATCH_SIZE, TIME_STEPS, EMBEDDING_DIM]. I want to augment the last axis with data from another tensor of shape [BATCH_SIZE, AUG_DIM]. Each time step of a given example gets augmented with the same value.

If the tensor weren't ragged (that is, if TIME_STEPS were the same for every example), I could simply repeat the second tensor along the time axis with tf.repeat and then use tf.concat:

import tensorflow as tf


# create data
# shape: [BATCH_SIZE, TIME_STEPS, EMBEDDING_DIM]
emb = tf.constant([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [0, 0, 0]]])
# shape: [BATCH_SIZE, 1, AUG_DIM]
aug = tf.constant([[[8]], [[9]]])

# concat
aug = tf.repeat(aug, emb.shape[1], axis=1)
emb_aug = tf.concat([emb, aug], axis=-1)

This approach doesn't work when emb is ragged, since emb.shape[1] is unknown (None) and the number of time steps varies across examples:

# make emb ragged, dropping the all-zero padding rows
emb = tf.RaggedTensor.from_tensor(emb, padding=(0, 0, 0))

# reshape for augmentation - this doesn't work
aug = tf.repeat(aug, emb.shape[1], axis=1)

ValueError: Attempt to convert a value (None) with an unsupported type (<class 'NoneType'>) to a Tensor.

The goal is to create a ragged tensor emb_aug which looks like this:

<tf.RaggedTensor [[[1, 2, 3, 8], [4, 5, 6, 8]], [[1, 2, 3, 9]]]>

Any ideas?

Leavenworth answered 12/3, 2021 at 18:55

The easiest way to do this is to convert your ragged tensor to a regular tensor with tf.RaggedTensor.to_tensor() and then proceed as in your dense solution. I'll assume, though, that you need the tensor to remain ragged. The key is to find the row length of each example in your ragged tensor with row_lengths(), and then use this information to make your augmentation tensor ragged as well.
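
For comparison, the dense route mentioned above might look like the following sketch. It assumes eager TensorFlow 2.x and reuses the question's toy values; the variable names are only illustrative. Note that the padded time step also receives the augmentation value, so the sketch restores raggedness afterwards from the original row lengths, using the lengths argument of tf.RaggedTensor.from_tensor.

```python
import tensorflow as tf

# Ragged embeddings: example 0 has 2 time steps, example 1 has 1.
emb_r = tf.ragged.constant(
    [[[1, 2, 3], [4, 5, 6]], [[1, 2, 3]]], ragged_rank=1)
# Augmentation data, shape [BATCH_SIZE, 1, AUG_DIM] as in the question's code.
aug = tf.constant([[[8]], [[9]]])

# Pad to a dense tensor (default padding value is 0): shape [2, 2, 3].
emb_dense = emb_r.to_tensor()

# Repeat aug along the time axis up to the padded length, then concatenate.
aug_dense = tf.repeat(aug, tf.shape(emb_dense)[1], axis=1)
dense_aug = tf.concat([emb_dense, aug_dense], axis=-1)
# The padded step is now [0, 0, 0, 9], so it is no longer all zeros.

# Restore raggedness from the original row lengths rather than a padding value.
emb_aug = tf.RaggedTensor.from_tensor(dense_aug, lengths=emb_r.row_lengths())
print(emb_aug.to_list())
# [[[1, 2, 3, 8], [4, 5, 6, 8]], [[1, 2, 3, 9]]]
```

This pads every example to the longest sequence in the batch, which can waste memory when the lengths vary a lot.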

Example:

import tensorflow as tf


# data
emb = tf.constant([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [0, 0, 0]]])
aug = tf.constant([[[8]], [[9]]])

# make embeddings ragged for testing
emb_r = tf.RaggedTensor.from_tensor(emb, padding=(0, 0, 0))

print(emb_r.shape)
# (2, None, 3)

Here we'll use a combination of row_lengths and sequence_mask to create a new ragged tensor.

# find the row lengths of the embeddings
rl = emb_r.row_lengths()

print(rl)
# tf.Tensor([2 1], shape=(2,), dtype=int64)

# find the biggest row length
max_rl = tf.math.reduce_max(rl)

print(max_rl)
# tf.Tensor(2, shape=(), dtype=int64)

# repeat the augmented data `max_rl` number of times
aug_t = tf.repeat(aug, repeats=max_rl, axis=1)

print(aug_t)
# tf.Tensor(
# [[[8]
#   [8]]
# 
#  [[9]
#   [9]]], shape=(2, 2, 1), dtype=int32)

# create a mask
msk = tf.sequence_mask(rl)

print(msk)
# tf.Tensor(
# [[ True  True]
#  [ True False]], shape=(2, 2), dtype=bool)

From here we can use tf.ragged.boolean_mask to make the augmented data ragged:

# make the augmented data a ragged tensor
aug_r = tf.ragged.boolean_mask(aug_t, msk)
print(aug_r)
# <tf.RaggedTensor [[[8], [8]], [[9]]]>

# concatenate!
output = tf.concat([emb_r, aug_r], 2)
print(output)
# <tf.RaggedTensor [[[1, 2, 3, 8], [4, 5, 6, 8]], [[1, 2, 3, 9]]]>

The TensorFlow documentation lists the ops that support ragged tensors.

Ciceronian answered 13/3, 2021 at 15:17

Ragged tensors can be constructed directly from row lengths with tf.RaggedTensor.from_row_lengths. Its values argument is flat only with respect to the future ragged dimension (the batch and time axes are merged into one leading axis; all other dimensions are kept). That tensor can be built with tf.repeat, again using row_lengths to get the number of repeats for each example.

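# emb is the ragged tensor; aug must have shape [BATCH_SIZE, AUG_DIM]
# (for the question's [BATCH_SIZE, 1, AUG_DIM] data, use tf.squeeze(aug, axis=1) first)
ragged_lengths = emb.row_lengths()
# repeat each example's row once per time step, then regroup the rows by example
aug = tf.RaggedTensor.from_row_lengths(
    values=tf.repeat(aug, ragged_lengths, axis=0),
    row_lengths=ragged_lengths)
emb_aug = tf.concat([emb, aug], axis=-1)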
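
A self-contained version of this approach, as a minimal sketch assuming eager TensorFlow 2.x, the question's toy values, and an aug tensor of shape [BATCH_SIZE, AUG_DIM] (the variable names flat and aug_r are only illustrative):

```python
import tensorflow as tf

# Ragged embeddings: example 0 has 2 time steps, example 1 has 1.
emb = tf.ragged.constant(
    [[[1, 2, 3], [4, 5, 6]], [[1, 2, 3]]], ragged_rank=1)
# Assumption: aug has shape [BATCH_SIZE, AUG_DIM].
aug = tf.constant([[8], [9]])

# Number of time steps per example.
ragged_lengths = emb.row_lengths()

# Repeat each example's row once per time step: shape [sum(lengths), AUG_DIM].
flat = tf.repeat(aug, ragged_lengths, axis=0)

# Regroup the flat rows by example, mirroring emb's ragged structure.
aug_r = tf.RaggedTensor.from_row_lengths(values=flat, row_lengths=ragged_lengths)

emb_aug = tf.concat([emb, aug_r], axis=-1)
print(emb_aug.to_list())
# [[[1, 2, 3, 8], [4, 5, 6, 8]], [[1, 2, 3, 9]]]
```

Unlike the masking approach, this never materializes a tensor padded to the longest sequence.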

Krummhorn answered 18/3, 2022 at 14:31

Comment (Brummett): Remember that Stack Overflow isn't just intended to solve the immediate problem, but also to help future readers find solutions to similar problems, which requires understanding the underlying code. This is especially important for members of our community who are beginners, and not familiar with the syntax. Given that, can you edit your answer to include an explanation of what you're doing and why you believe it is the best approach?
