Here is how you can do that:
import tensorflow as tf
# Written against the TF 1.x graph API (placeholders evaluated in a Session).
# Goal: softmax over only the `num_top` largest entries of each row of `a`,
# leaving zeros everywhere else in a tensor with the same shape as `a`.

# Input data: a 2-D float matrix and a scalar count of entries to keep per row
a = tf.placeholder(tf.float32, [None, None])
num_top = tf.placeholder(tf.int32, [])
# Find top elements (values and their column indices); the order within a
# row does not matter for the scatter below, so sorting is skipped
a_top, a_top_idx = tf.nn.top_k(a, num_top, sorted=False)
# Apply softmax across the selected values of each row (last axis by default)
a_top_sm = tf.nn.softmax(a_top)
# Reconstruct into original shape
a_shape = tf.shape(a)
# Row index of every selected element, shape [rows, num_top]
a_row_idx = tf.tile(tf.range(a_shape[0])[:, tf.newaxis], (1, num_top))
# Pair row and column indices into [rows, num_top, 2] coordinates into `a`
scatter_idx = tf.stack([a_row_idx, a_top_idx], axis=-1)
# Write the softmax values at those coordinates; all other positions stay 0
result = tf.scatter_nd(scatter_idx, a_top_sm, a_shape)
# Test
with tf.Session() as sess:
    result_val = sess.run(
        result,
        feed_dict={a: [[2, 5, 4, 7], [7, 5, 6, 8]], num_top: 2},
    )
    print(result_val)
Output:
[[0. 0.11920291 0. 0.880797 ]
[0.26894143 0. 0. 0.7310586 ]]
EDIT:
Actually, there is a function that more closely does what you intend, `tf.sparse.softmax`. However, it requires a `SparseTensor` as input, and I'm not sure it would be any faster, since it has to figure out which sparse values go together in the softmax. The good thing about this function is that you could have a different number of elements to softmax in each row, but in your case that does not seem to be important. Anyway, here is an implementation using it, in case you find it useful.
import tensorflow as tf
# Written against the TF 1.x graph API (placeholders evaluated in a Session).
# Same goal as the scatter-based version: softmax over only the `num_top`
# largest entries of each row of `a`, zeros elsewhere, but computed through a
# SparseTensor and tf.sparse.softmax.

# Input data: a 2-D float matrix and a scalar count of entries to keep per row
a = tf.placeholder(tf.float32, [None, None])
num_top = tf.placeholder(tf.int32, [])
# Find top elements (values and their column indices)
a_top, a_top_idx = tf.nn.top_k(a, num_top, sorted=False)
# Flatten values into the 1-D list of non-zero entries a SparseTensor expects
sparse_values = tf.reshape(a_top, [-1])
# Make sparse indices; SparseTensor needs int64 indices and dense shape,
# hence the casts
shape = tf.cast(tf.shape(a), tf.int64)
# Row index of every selected element, shape [rows, num_top]
a_row_idx = tf.tile(tf.range(shape[0])[:, tf.newaxis], (1, num_top))
# Pair row and column indices, then flatten to one (row, col) pair per value,
# in the same order as `sparse_values`
sparse_idx = tf.stack([a_row_idx, tf.cast(a_top_idx, tf.int64)], axis=-1)
sparse_idx = tf.reshape(sparse_idx, [-1, 2])
# Make sparse tensor
a_top_sparse = tf.SparseTensor(sparse_idx, sparse_values, shape)
# Reorder sparse tensor: top_k with sorted=False does not return column
# indices in increasing order, so put the entries in canonical row-major order
a_top_sparse = tf.sparse.reorder(a_top_sparse)
# Softmax over the stored (non-zero) entries of each row
result_sparse = tf.sparse.softmax(a_top_sparse)
# Convert back to dense (or you can keep working with the sparse tensor)
result = tf.sparse.to_dense(result_sparse)
# Test
with tf.Session() as sess:
    result_val = sess.run(
        result,
        feed_dict={a: [[2, 5, 4, 7], [7, 5, 6, 8]], num_top: 2},
    )
    print(result_val)
    # Same as before