Caching a computed value as a constant in TensorFlow

Suppose I want to compute the least squares coefficients in TensorFlow using the closed-form solution. Normally, I would do this like so:

beta_hat = tf.matmul(
           tf.matmul(tf.matrix_inverse(tf.matmul(tf.transpose(X), X)), tf.transpose(X)), y
)

Here, X and y are TensorFlow placeholders for the covariates and the target variable, respectively.
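For reference, the nested tf.matmul calls compute the textbook normal-equations estimate. In standard notation, with X the n × k covariate matrix and y the n × 1 target vector, the estimate and the resulting predictions are:

```latex
\hat{\beta} = (X^\top X)^{-1} X^\top y, \qquad \hat{y} = X \hat{\beta}
```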

To make predictions, I would then do something like:

y_pred = tf.matmul(X, beta_hat)

If I were to execute:

sess.run(y_pred, feed_dict={X: data_X})

I would, of course, get an error saying that I had not provided a value for the placeholder y. I would like the flexibility to treat beta_hat as a constant once it has been computed (so that I do not need to define a new placeholder for the covariate matrix used for prediction). One way to accomplish this is:

# Make it constant.
beta_hat = sess.run(beta_hat, feed_dict={X: data_X, y: data_y})
y_pred = tf.matmul(X, beta_hat)

Is there a more elegant way to treat the tensor as a constant, so that I need neither to run the session to obtain a constant nor to create a separate placeholder for the incoming data used for prediction?

Here is some sample code that demonstrates the circumstance I'm describing.

import numpy as np
import tensorflow as tf


n, k = 100, 5
X = tf.placeholder(dtype=tf.float32, shape=[None, k])
y = tf.placeholder(dtype=tf.float32, shape=[None, 1])

beta = np.random.normal(size=(k, ))
data_X = np.random.normal(size=(n, k))

data_y = data_X.dot(beta)
data_y += np.random.normal(size=data_y.shape) / 3.0
data_y = np.atleast_2d(data_y).T

# Convert to 32-bit precision.
data_X, data_y = np.float32(data_X), np.float32(data_y)

# Compute the least squares solution.
beta_hat = tf.matmul(
    tf.matmul(tf.matrix_inverse(tf.matmul(tf.transpose(X), X)),
              tf.transpose(X)), y
)

# Launch the graph
sess = tf.Session()
sess.run(tf.global_variables_initializer())  # The graph has no variables, so this is a no-op.

print("True beta: {}".format(beta))
print("Est. beta: {}".format(
    sess.run(beta_hat, feed_dict={X: data_X, y: data_y}).ravel()
))

# # This would raise an error, because no value is fed for y.
# y_pred = tf.matmul(X, beta_hat)
# print("Predictions:")
# print(sess.run(y_pred, feed_dict={X: data_X}))

# Make it constant.
beta_hat = sess.run(beta_hat, feed_dict={X: data_X, y: data_y})

# This will no longer error.
y_pred = tf.matmul(X, beta_hat)
print("Predictions:")
print(sess.run(y_pred, feed_dict={X: data_X}))
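As a sanity check outside TensorFlow, the same closed-form estimate can be computed directly in NumPy and compared against np.linalg.lstsq. This is a minimal sketch with its own randomly generated data (the seed and the new_X matrix are illustrative assumptions, not from the post). It also shows that once beta_hat is a concrete array, predicting for new covariates needs no y at all, which is essentially what the "Make it constant" step achieves.

```python
import numpy as np

# Illustrative data (hypothetical seed), mirroring the setup in the question.
rng = np.random.default_rng(0)
n, k = 100, 5
beta = rng.normal(size=(k, 1))
data_X = rng.normal(size=(n, k))
data_y = data_X @ beta + rng.normal(size=(n, 1)) / 3.0

# Closed form: (X^T X)^{-1} X^T y, evaluated left to right.
beta_hat = np.linalg.inv(data_X.T @ data_X) @ data_X.T @ data_y

# Reference solution from NumPy's least squares solver.
beta_ref, *_ = np.linalg.lstsq(data_X, data_y, rcond=None)

# beta_hat is now an ordinary array: predictions only need new covariates.
new_X = rng.normal(size=(3, k))
y_pred = new_X @ beta_hat

print(np.allclose(beta_hat, beta_ref))
print(y_pred.shape)
```

Forming an explicit inverse mirrors the question's formula; for ill-conditioned X, np.linalg.lstsq (or np.linalg.solve on the normal equations) is generally the more numerically robust choice.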
Shick answered 22/11, 2015 at 18:54

Perhaps counterintuitively, the simplest way to reuse beta_hat as a constant in subsequent steps is to assign it to a tf.Variable:

import numpy as np
import tensorflow as tf

n, k = 100, 5
X = tf.placeholder(dtype=tf.float32, shape=[None, k])
y = tf.placeholder(dtype=tf.float32, shape=[None, 1])

beta = np.random.normal(size=(k, ))
data_X = np.random.normal(size=(n, k))

data_y = data_X.dot(beta)
data_y += np.random.normal(size=data_y.shape) / 3.0
data_y = np.atleast_2d(data_y).T

# Convert to 32-bit precision.
data_X, data_y = np.float32(data_X), np.float32(data_y)

# Compute the least squares solution.
beta_hat = tf.matmul(
    tf.matmul(tf.matrix_inverse(tf.matmul(tf.transpose(X), X)),
              tf.transpose(X)), y
)

beta_hat_cached = tf.Variable(beta_hat)

# Launch the graph
sess = tf.Session()

print("True beta: {}".format(beta))
# Run the initializer, which computes `beta_hat` once:
sess.run(beta_hat_cached.initializer, feed_dict={X: data_X, y: data_y})
# To access the value of `beta_hat`, "run" the variable to read its contents.
print("Est. beta: {}".format(sess.run(beta_hat_cached).ravel()))

# Use the cached version to compute predictions; no value for y is needed.
y_pred = tf.matmul(X, beta_hat_cached)
print("Predictions:")
print(sess.run(y_pred, feed_dict={X: data_X}))
Pestilent answered 28/11, 2016 at 16:9
Comment (Jacobine): ravel == eval? (Or is this function documented?)

mrry has indeed presented an elegant solution. You should consider accepting his answer if it is what you want.

However, I think this is a good place to clear up what I perceive to be a common source of confusion about placeholders. This isn't necessarily directed at the person who asked the question, but I believe it will be relevant to many beginners who stumble on this question.


Placeholders are best thought of as function inputs. So first, let's review how function inputs work in Python, and then I will show the equivalent form in TensorFlow.

If I want a function that computes an output from various inputs x and y, I could write it like this:

def f(x,y):
    # For example... 
    return x * y 

Specifically, I can call this function with various values for x and y:

f(1,3)  # returns 3
f(1,4)  # returns 4
f(2,3)  # returns 6
f(2,4)  # returns 8

However, in my particular case, y may have a fixed value, so it doesn't make sense to pass y as an argument. Instead, I want to bake my value of y into the function and vary only x. To do that, I can simply capture the outer value of y:

y = 3
def g(x):
    return x * y

Now whenever I call g, y will have the fixed value of 3:

g(1)  # returns 3
g(2)  # returns 6
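Capturing y from the enclosing scope is one way to fix an argument. A minimal alternative sketch uses the standard library's functools.partial (a different technique from the closure above; the name g_partial is illustrative) to bind y explicitly:

```python
from functools import partial

def f(x, y):
    return x * y

# Bind y=3 once; g_partial now only needs x.
g_partial = partial(f, y=3)

print(g_partial(1))
print(g_partial(2))
```

One design difference: the closure g looks up the global y each time it is called, so reassigning y later changes g's result, whereas partial stores the value 3 at the moment it is created.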

Similarly, if I also know that x is fixed, I could capture the outer value of x:

x = 2
def h():
    return g(x)

Now, calling h() is implicitly the same as calling g(2), which is f(2,3).

That's great, but the problem is that every time I call h, it redoes the multiplication, because it is equivalent to calling f(2,3). So, to improve performance, I can evaluate the expression once and then have a function that simply returns this precomputed value:

val = h()
def h23():
    return val

No matter how many times I call h23, the multiplication is only performed once (on the line val = h()).
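To make the difference concrete, here is a small sketch that counts multiplications. The counter is an illustrative addition, and g is rewritten to call f (rather than multiplying directly) so every multiplication is counted: calling h repeatedly multiplies each time, while h23 returns the value computed once.

```python
calls = 0  # number of multiplications performed so far

def f(x, y):
    global calls
    calls += 1
    return x * y

y = 3
def g(x):
    return f(x, y)

x = 2
def h():
    return g(x)

# h recomputes on every call: three calls, three multiplications.
for _ in range(3):
    h()
print(calls)

# Precompute once, then just return the cached value.
val = h()
def h23():
    return val

# These calls perform no further multiplications.
for _ in range(3):
    h23()
print(calls, h23())
```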

TensorFlow has analogous concepts.

If you want a function where you can vary both inputs, then you should create placeholders for both inputs and pass their values in a feed dictionary when running the graph in a session:

import tensorflow as tf

dtype = tf.float64
shape = ()
x = tf.placeholder( dtype, shape )
y = tf.placeholder( dtype, shape )
fxy = f(x,y)
with tf.Session() as sess: 
    print( sess.run( fxy, {x:1,y:3} ) )
    print( sess.run( fxy, {x:1,y:4} ) )
    print( sess.run( fxy, {x:2,y:3} ) )
    print( sess.run( fxy, {x:2,y:4} ) )
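The feed dictionary plays the role of the argument list. In plain Python, the analogy can be made almost literal by calling f with a dictionary of keyword arguments. This is only an analogy: in TensorFlow the feed dictionary keys are the placeholder tensors themselves, not strings.

```python
def f(x, y):
    return x * y

# Each dict plays the role of one feed_dict: it supplies every input.
feeds = [{"x": 1, "y": 3}, {"x": 1, "y": 4}, {"x": 2, "y": 3}, {"x": 2, "y": 4}]
results = [f(**feed) for feed in feeds]
print(results)
```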

However, if one of my values does not change, then I can directly initialize it as a constant and create a new function with this value "baked into it":

# Match the placeholder's dtype: TensorFlow will not multiply a float64 tensor by an int32 one.
y = tf.constant( 3, dtype=dtype )
gx = f( x, y )
with tf.Session() as sess:
    print( sess.run( gx, {x:1} ) )
    print( sess.run( gx, {x:2} ) )

The key point is that I no longer need to pass a value for y in my feed dictionary: it is a constant captured in the expression gx.
Similarly, if x is also a constant, then I should declare it as one:

x = tf.constant( 2, dtype=dtype )
h = f(x,y)
with tf.Session() as sess:
    print( sess.run( h ) )

As you can see, since all of my inputs are constants, I don't need a feed dictionary at all. This is the TensorFlow equivalent of calling a function without arguments, like h().

However, just as before, every time I run h, TensorFlow may re-evaluate the graph. So I have two options:

  1. I can compute the result in NumPy (or plain Python), then wrap that value in a TensorFlow constant.
  2. I can compute the output in TensorFlow, run it in a session to get the NumPy value, and then wrap it in a constant.

For the first option, I would do something like this:

fxy = tf.constant( f(2,3) ) 

Now I have precomputed the value of the function outside of TensorFlow and wrapped that value in a constant, so that I can use it in other TensorFlow functions.

You would only consider option 2 if your function uses complicated TensorFlow operations, or if it takes a long time to run and you expect it to be faster to compute in TensorFlow:

with tf.Session() as sess: 
    fxy = tf.constant( sess.run( h ) )

To understand what is going on here, recall that

h = f( tf.constant(2), tf.constant(3) )

So I don't need to pass a feed dict. The snippet sess.run( h ) runs that multiplication inside TensorFlow and returns the result as a NumPy value. Finally, I wrap that value in a tf.constant so that I can use it in other TensorFlow functions.

Parang answered 5/2, 2018 at 13:2
