Embeddings are typically large enough that the only viable approach is to use them to initialize a tf.Variable in your graph. This lets you take advantage of parameter servers in distributed training, among other things.
For this (and anything else), I'd recommend you use the new "core" estimator, tf.estimator.Estimator, as it will make things much easier.
From the answer in the link you provided, and knowing that we want a variable, not a constant, we can take either approach (numbering follows the linked answer):
(2) Initialize the variable using a feed dict, or
(3) Load the variable from a checkpoint
I'll cover option (3) first, since it's easier and generally better:
In your model_fn, simply initialize a variable using the value returned by a tf.contrib.framework.load_variable call. This requires:
- That you have a valid TF checkpoint with your embeddings
- That you know the fully qualified name of the embeddings variable within the checkpoint.
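If you don't know the variable's name, you can list everything stored in the checkpoint first. A minimal sketch (reusing the checkpoint path from the example below):

import tensorflow as tf

# Print the name and shape of every variable in the checkpoint, so you
# can find the fully qualified name of the embeddings variable.
for name, shape in tf.contrib.framework.list_variables(
        'gs://my-bucket/word2vec_checkpoints/'):
    print(name, shape)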
The code is pretty simple:
def model_fn(features, labels, mode, params):
    # Note: tf.estimator.Estimator requires the model_fn arguments to be
    # named features, labels, mode, and (optionally) params/config.
    embeddings = tf.Variable(tf.contrib.framework.load_variable(
        'gs://my-bucket/word2vec_checkpoints/',
        'a/fully/qualified/scope/embeddings'))
    # ... build the rest of your model ...
    return tf.estimator.EstimatorSpec(...)
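Once the variable exists, you'd consume it like any other embedding matrix. For instance (a sketch; the 'token_ids' feature key is an assumption, not part of the original code):

# Inside model_fn: map the batch of integer token ids coming from the
# input pipeline to their embedding vectors.
embedded = tf.nn.embedding_lookup(embeddings, features['token_ids'])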
However, this approach won't work if your embeddings weren't produced by another TF model, hence option (2).
For (2), we need to use a tf.train.Scaffold, which is essentially a configuration object that holds all of the options for starting a tf.Session (which Estimator intentionally hides for lots of reasons). You may specify a Scaffold in the tf.estimator.EstimatorSpec you return from your model_fn.
We create a placeholder in our model_fn and use it as the initial value for our embedding variable (so the variable's initializer op is an assignment from the placeholder), then pass an init_feed_dict via the Scaffold. For example:
def model_fn(features, labels, mode, params):
    embed_ph = tf.placeholder(
        shape=[params.vocab_size, params.embedding_size],
        dtype=tf.float32)
    embeddings = tf.Variable(embed_ph)
    # Define your model
    return tf.estimator.EstimatorSpec(
        ...,  # normal EstimatorSpec args
        scaffold=tf.train.Scaffold(
            init_feed_dict={embed_ph: my_embedding_numpy_array}))
What's happening here is that the init_feed_dict populates the embed_ph placeholder at runtime, which in turn allows embeddings.initializer (the assignment from the placeholder) to run.
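To tie it together, here's a minimal sketch of wiring this model_fn into an Estimator and training it. The embeddings file, train_input_fn, and model_dir are assumptions for illustration; params is passed as an HParams object so the attribute-style access above works:

import numpy as np
import tensorflow as tf

# Assumed: your pre-trained embeddings, loaded into memory as a numpy
# array (the .npy file here is hypothetical).
my_embedding_numpy_array = np.load('embeddings.npy')

params = tf.contrib.training.HParams(
    vocab_size=my_embedding_numpy_array.shape[0],
    embedding_size=my_embedding_numpy_array.shape[1])

estimator = tf.estimator.Estimator(
    model_fn=model_fn,
    params=params,
    model_dir='/tmp/my_model')

# train_input_fn is a hypothetical input_fn returning (features, labels).
estimator.train(input_fn=train_input_fn, steps=1000)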