Data Normalization with TensorFlow tf-transform

I'm building a neural network for prediction on my own datasets using TensorFlow. I started with a model that works with a small dataset on my computer. I then changed the code slightly to use Google Cloud ML-Engine with bigger datasets, so that both training and prediction run in ML-Engine.

I am normalizing the features in the pandas DataFrame, but this introduces skew and I get poor prediction results.

What I would really like is to use the tf-transform library to normalize the data in the graph. To do this, I would like to create a preprocessing_fn function and use 'tft.scale_to_0_1'. https://github.com/tensorflow/transform/blob/master/getting_started.md

The main problem I found is at prediction time. I have searched the internet but can't find any example of an exported model where the data is normalized during training. In all the examples I found, the data is NOT normalized anywhere.

What I would like to know is: if I normalize the data during training and then send a new instance with new data for prediction, how is this data normalized?

Maybe in the TensorFlow data pipeline? Are the variables used for the normalization saved somewhere?

In summary: I'm looking for a way to normalize the inputs to my model so that new instances are also normalized in the same way.

Yacov answered 28/9, 2017 at 17:2 Comment(0)

First of all, you don't really need tf.transform for this. All you need to do is to write a function that you call from both the training/eval input_fn and from your serving input_fn.

For example, assuming that you've used Pandas on your whole dataset to figure out the min and max:

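def add_engineered(features):
  # min and max of feature 'x', computed once over the whole dataset
  min_x = 22
  max_x = 43
  # scale 'x' to the [0, 1] range; on tensors these operators become graph ops
  features['x'] = (features['x'] - min_x) / (max_x - min_x)
  return features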
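To make the arithmetic concrete, here is a minimal sketch that runs the same function on plain Python floats instead of tensors. The sample values are made up for illustration. It also shows one thing to watch for: a value outside the range seen in the dataset is scaled to something outside [0, 1].

```python
def add_engineered(features):
    # min and max are assumed to come from a one-off pass over the dataset
    min_x = 22
    max_x = 43
    features['x'] = (features['x'] - min_x) / (max_x - min_x)
    return features

# Hypothetical raw values; the endpoints map to 0 and 1.
low = add_engineered({'x': 22.0})['x']
mid = add_engineered({'x': 32.5})['x']
high = add_engineered({'x': 43.0})['x']

# A value above max_x is not clipped, so the result exceeds 1.
outside = add_engineered({'x': 50.0})['x']

print(low, mid, high, outside)
```

If out-of-range inputs are likely at prediction time, you may want to clip the scaled value inside the same function, so that training and serving still share one code path.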

Then, in your input_fn, wrap the features you return with a call to add_engineered:

def input_fn():
  features = ...
  label = ...
  return add_engineered(features), label

and in your serving_input_fn, make sure to similarly wrap the returned features (NOT the feature_placeholders) with a call to add_engineered:

def serving_input_fn():
  feature_placeholders = ...
  # copy the dict so the placeholders stay raw while the model sees scaled features
  features = feature_placeholders.copy()
  return tf.estimator.export.ServingInputReceiver(
      add_engineered(features), feature_placeholders)

Now, your JSON input at prediction time would only need to contain the original, unscaled values.
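As a rough illustration, a request body with raw values could be built as below. This is a sketch that assumes online prediction with a body of the form {"instances": [...]} and a single feature named 'x' as in the example above; the value 30.0 is made up.

```python
import json

# Hypothetical raw, unscaled value for feature 'x'; scaling happens inside
# the exported graph via add_engineered, not on the client.
instances = [{"x": 30.0}]

request_body = json.dumps({"instances": instances})
print(request_body)
```

The client never needs to know min_x or max_x, which is what removes the training/serving skew.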

Here's a complete working example of this approach.

https://github.com/GoogleCloudPlatform/training-data-analyst/blob/master/courses/machine_learning/feateng/taxifare/trainer/model.py#L130

tf.transform provides a two-phase process: an analysis step to compute the min and max, and a graph-modification step that inserts the scaling into your TensorFlow graph for you. So, to use tf.transform, you first need to write a Dataflow pipeline that does the analysis and then plug in calls to tft.scale_to_0_1 inside your TensorFlow code. Here's an example of doing this:

https://github.com/GoogleCloudPlatform/cloudml-samples/tree/master/criteo_tft
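The two phases can be pictured with a plain-Python analogy. This is not the tf.transform API; the names analyze and make_transform are hypothetical and the data is made up. It only sketches the idea of computing statistics once over the training data and then reusing them for every later input.

```python
def analyze(values):
    """Analysis phase: one full pass over the training data."""
    return {"min": min(values), "max": max(values)}

def make_transform(stats):
    """Transform phase: build a scaling function from saved statistics."""
    span = stats["max"] - stats["min"]
    if span == 0:
        raise ValueError("cannot scale a constant feature")

    def transform(x):
        return (x - stats["min"]) / span

    return transform

# Made-up training values for feature 'x'.
training_x = [22.0, 30.0, 43.0, 25.5]

stats = analyze(training_x)      # done once, then stored
scale = make_transform(stats)    # reused for training and serving

train_scaled = [scale(x) for x in training_x]
served = scale(32.5)             # a new instance at prediction time
print(stats, train_scaled, served)
```

In tf.transform the stored statistics end up as constants in the exported graph, so rerunning the analysis on fresh data updates the scaling without any code change.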

The add_engineered() approach is simpler and is what I would suggest. The tf.transform approach is needed if your data distributions will shift over time, and so you want to automate the entire pipeline (e.g. for continuous training).

Phonograph answered 29/9, 2017 at 0:40 Comment(4)
Sorry to answer so late, but the add_engineered() approach worked perfectly! - Yacov
@lak I just came across this post as I had a similar question. If you have a moment, does this work with tf.estimator.export.ServingInputReceiver? - Serpentiform
@Phonograph I have a similar problem and can't make it work. How can I use a simple Python function that processes a string feature/column in a serving_input_fn with ServingInputReceiver? InputFnOps seems deprecated; can you point us to an example? - Deneendenegation
I've edited my answer above to use ServingInputReceiver instead. Note that add_engineered() has to use TensorFlow operations (since it is part of a TensorFlow graph, you cannot use arbitrary Python functions such as calling out to a database). - Phonograph
