Theano switch row-wise efficiently

I have the following code

output = T.switch(cond, a, b)

where cond is an (N, 1) bool tensor, while a and b are (N, M) numeric tensors with M being quite large. The condition applies row-wise: each entry of cond selects an entire row of either a or b.

I can easily make the switch work by running T.repeat() on cond, but this is quite slow. Is there a way I can efficiently make the bools in cond decide whether a or b should be returned?

Pamela answered 17/5, 2017 at 2:23 Comment(0)

Is there a way I can efficiently make the bools in cond decide whether a or b should be returned?

Yes, you could do

cond * a + (1-cond) * b

cond will be broadcast to (N, M) shape.
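
The broadcasting behaviour can be checked with a small NumPy sketch. NumPy stands in for Theano here, and the sizes and values are made up for illustration. It casts cond to the dtype of a and b, as suggested further down, and compares the arithmetic form against an element-wise select:

```python
import numpy as np

# Illustrative sizes only; the real N and M are much larger.
N, M = 4, 6
rng = np.random.default_rng(0)

a = rng.integers(0, 100, size=(N, M)).astype(np.int32)
b = rng.integers(0, 100, size=(N, M)).astype(np.int32)
cond = np.array([[True], [False], [False], [True]])  # shape (N, 1)

# Cast the mask to the data dtype so the arithmetic stays in int32.
mask = cond.astype(a.dtype)

# The (N, 1) mask broadcasts across the M columns of a and b.
blended = mask * a + (1 - mask) * b

# Reference result: whole rows from a where cond is True, else from b.
expected = np.where(cond, a, b)

print(np.array_equal(blended, expected))
```

Two caveats. In Theano the broadcast only happens if the second dimension of cond is marked broadcastable (for example when it is declared as a T.col); how cond was declared is an assumption here. Also, with floating-point data a NaN or inf in the unselected array leaks into the result, because 0 * nan is nan; that is not a concern for integer data.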

This should be close to the theoretical limit, which is set by memory bandwidth: an ideal switch needs to read about N*M elements and write N*M.

The expression above instead reads 2*N*M elements (all of a and all of b), but it removes the conditional logic.
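
As a rough worked example of that estimate, the snippet below plugs in the sizes given in the comments (N = 320, M = 10K, int32), reading 10K as 10,000. It counts only the bytes moved for a, b and the output; the small cond array and any caching effects are ignored, so treat the numbers as a back-of-the-envelope sketch.

```python
import numpy as np

N, M = 320, 10_000  # sizes from the comments; "10K" read as 10,000
itemsize = np.dtype(np.int32).itemsize  # bytes per int32 element

array_bytes = N * M * itemsize

# Ideal select as described above: read N*M elements, write N*M.
ideal_bytes = 2 * array_bytes

# cond * a + (1 - cond) * b: read both a and b, write the result.
blend_bytes = 3 * array_bytes

print(array_bytes / 1e6, "MB per array")
print(ideal_bytes / 1e6, "MB moved in the ideal case")
print(blend_bytes / 1e6, "MB moved by the arithmetic form")
```

Under these assumptions the arithmetic form moves about 1.5 times the data of the ideal case.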

(I don't have Theano in front of me, so I am not sure if it's faster than T.switch, but it should be about as good as it gets. Also, I'd try casting cond to the same dtype as a and b)


If you want to update a in-place, you can do it using T.set_subtensor:

import numpy as np
import theano
import theano.tensor as T

N, M = 320, 10000  # example sizes

a = np.random.uniform(size=(N, M)).astype(np.float32)
b = np.random.uniform(size=(N, M)).astype(np.float32)

a = theano.shared(a)
b = theano.shared(b)

c = T.vector() # 1-D mask of length N, mostly 0; nonzero marks rows to take from b, presumably (1-cond)

nz = T.nonzero(c)

s = T.set_subtensor(a[nz], b[nz])
fn = theano.function([c], [], updates=[(a, s)])

...

fn(1 - cond) # here cond is a 1-D numeric array of length N, not the symbolic (N, 1) tensor

It may or may not be faster than the first approach, depending on N, M and other factors.
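
To make the effect of the update concrete, here is a NumPy analogue of the same idea. It is a sketch of the semantics only, not the Theano code above; the sizes and the chosen rows are made up. Only the rows flagged in c are copied from b into a, and the result matches a full row-wise select.

```python
import numpy as np

N, M = 8, 5  # illustrative sizes only
rng = np.random.default_rng(0)

a = rng.integers(0, 100, size=(N, M)).astype(np.int32)
b = rng.integers(0, 100, size=(N, M)).astype(np.int32)

# c plays the role of (1 - cond): nonzero marks rows to take from b.
c = np.zeros(N, dtype=np.int32)
c[[2, 5]] = 1

# Reference result, computed before a is modified.
expected = np.where(c[:, None] != 0, b, a)

# Overwrite only the selected rows of a; the other rows are untouched.
rows = np.nonzero(c)[0]
a[rows] = b[rows]

print(np.array_equal(a, expected))
```

The amount of data copied scales with the number of flagged rows rather than with N, which is why this can pay off when only a small fraction of rows should come from b.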

Epizootic answered 21/5, 2017 at 20:9 Comment(5)
Thanks for the answer, I'll try it out! Interesting thoughts about the theoretical limit. I guess I could avoid the large reads and writes by exploiting the fact that a is usually the right value to return, and it's fine for the method to modify a directly. Suppose b should be returned for only 5% of the rows: couldn't one get better performance by modifying a directly, only on the rows that need it? - Pamela
@Pamela Are you optimizing for CPU or GPU? What are the typical N, M and dtype? - Epizootic
@Pamela Also, is this part of a NN or something that needs the gradient? - Epizootic
Good questions! I'm optimizing for GPU and I do not need the gradient. - Pamela
Thanks! I forgot to mention that N = 320, M = 10K and dtype = int32. - Pamela
