I have trained a tf.estimator.LinearClassifier. Training and evaluating the model take a reasonable amount of time for my data size (~60 sec), but predicting takes far longer (~1 hour).
The prediction code is as follows:
predictionResult = estimator.predict(input_fn=lambda: my_input_fn2(predictionValidationFile, False, 1))
predictionList = [prediction for prediction in predictionResult]
with:
def my_input_fn2(file_path, perform_shuffle=False, repeat_count=1):
    def _parse_function(example_proto):
        keys_to_features = {"xslm": tf.FixedLenFeature([10000], tf.float32),
                            "xrnn": tf.FixedLenFeature([10000], tf.float32),
                            "target": tf.FixedLenFeature([10000], tf.float32)}
        parsed_features = tf.parse_single_example(example_proto, keys_to_features)
        myfeatures = {'xrnn': parsed_features['xrnn'], 'xslm': parsed_features['xslm']}
        return myfeatures, parsed_features['target']

    dataset = (tf.data.TFRecordDataset(file_path)
               .map(_parse_function))
    dataset = dataset.repeat(repeat_count)
    dataset = dataset.batch(1)
    iterator = dataset.make_one_shot_iterator()
    batch_feature, batch_labels = iterator.get_next()
    xs = tf.reshape(batch_feature['xslm'], [-1, 1])
    xr = tf.reshape(batch_feature['xrnn'], [-1, 1])
    x = {'xrnn': xr, 'xslm': xs}
    y = tf.reshape(batch_labels, [-1, 1])
    return x, y
The second line of the prediction code (the list comprehension) takes 0.8 sec to execute when run on 10 000 samples (corresponding to one batch). With 50 000 000 samples, prediction takes more than one hour.
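As a sanity check, the two timings agree with each other if the cost grows linearly with the number of batches. That linear scaling is an assumption, and the variable names below are only for illustration:

```python
# Measured values taken from the question.
seconds_per_batch = 0.8       # time to predict one batch
samples_per_batch = 10_000    # samples in one batch
total_samples = 50_000_000    # size of the full prediction set

# Assumption: total time scales linearly with the number of batches.
n_batches = total_samples // samples_per_batch
estimated_seconds = n_batches * seconds_per_batch
estimated_minutes = estimated_seconds / 60

print(n_batches)                    # 5000
print(round(estimated_minutes, 1))  # 66.7
```

So about 5 000 batches at 0.8 sec each comes to roughly 67 minutes, which matches the "more than one hour" observed.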
My guess at this stage is that the slow performance is caused by the estimator's predict() function returning a Python generator instead of the actual prediction results. For each batch, the generator ends up being called 10 000 times to produce the 10 000 prediction results. This seems inefficient.
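To make this guess concrete, here is a minimal pure-Python sketch of the difference between a generator that yields one result per example and one that yields a whole batch per step. It is not TensorFlow's implementation; the function names and the fake scores are hypothetical stand-ins:

```python
def per_example_predictions(batches):
    """Mimic a generator that yields one prediction per example."""
    for batch in batches:
        for row in batch:
            yield row


def per_batch_predictions(batches):
    """Mimic a generator that yields one whole batch per step."""
    for batch in batches:
        yield batch


# Hypothetical stand-in data: 3 batches of 10 000 fake scores each.
batches = [[0.5] * 10_000 for _ in range(3)]

# Count how many times Python has to resume each generator.
n_example_steps = sum(1 for _ in per_example_predictions(batches))
n_batch_steps = sum(1 for _ in per_batch_predictions(batches))

print(n_example_steps, n_batch_steps)  # 30000 3
```

Later TensorFlow 1.x releases document a yield_single_examples argument on Estimator.predict() which, when set to False, yields whole batches rather than single examples. Whether that removes the bottleneck in this case is untested here.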
Are there any options to speed things up?
input_fn? Also, a mini-batch of 10,000 seems pretty high regardless of how big your model might be. – Ximenes
input = my_input_fn2(filename) with tf.Session() as sess: sess.run(input) then experiment with batch = 16 or something, as well as without the reshape – Ximenes
estimator.predict does not support distributed evaluation. So that isn't an option to speed it up either... – Yacov