TensorFlow Python automatically converts your NumPy array to a tf.Tensor. In TensorFlow Java, you manipulate tensors directly.
Also, SavedModelBundle does not have a predict method. You need to obtain the session from the bundle and run it yourself, using its Session.Runner and feeding it the input tensors.
For example, based on the next generation of TF Java (https://github.com/tensorflow/java), your code would end up looking like this (note that I'm making a lot of assumptions here about x_test_maxabs, since your code sample does not clearly explain where it comes from):
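// Load the SavedModel that was exported with the "serve" tag.
try (SavedModelBundle model = SavedModelBundle.load(path, "serve")) {
    // Both tensors are closed automatically when the inner block exits.
    try (Tensor<TFloat32> input = TFloat32.tensorOf(...);
         Tensor<TFloat32> output = model.session()
             .runner()
             .feed("input_name", input)
             .fetch("output_name")
             .run()
             .get(0) // run() returns one tensor per fetch; take the first
             .expect(TFloat32.class)) {
        float prediction = output.data().getFloat();
        System.out.println("prediction = " + prediction);
    }
}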
If you are not sure what the names of the input/output tensors in your graph are, you can obtain them programmatically by looking at the signature definition:
model.metaGraphDef().getSignatureDefMap().get("serving_default")
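To make that lookup a bit more concrete, here is a minimal sketch that prints the tensor names behind each signature key. The class name SignatureNames and its two helper methods are made up for illustration, and the imports assume the protobuf classes live in org.tensorflow.proto.framework, which may differ depending on your TF Java version. It also assumes your model was exported with the default "serving_default" signature.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import org.tensorflow.SavedModelBundle;
// Assumption: package of the generated protobuf classes; adjust for your TF Java version.
import org.tensorflow.proto.framework.SignatureDef;
import org.tensorflow.proto.framework.TensorInfo;

public class SignatureNames {

    /** Maps each signature input key to the name of the graph tensor it refers to. */
    static Map<String, String> inputNames(SignatureDef signature) {
        Map<String, String> names = new LinkedHashMap<>();
        for (Map.Entry<String, TensorInfo> e : signature.getInputsMap().entrySet()) {
            names.put(e.getKey(), e.getValue().getName());
        }
        return names;
    }

    /** Maps each signature output key to the name of the graph tensor it refers to. */
    static Map<String, String> outputNames(SignatureDef signature) {
        Map<String, String> names = new LinkedHashMap<>();
        for (Map.Entry<String, TensorInfo> e : signature.getOutputsMap().entrySet()) {
            names.put(e.getKey(), e.getValue().getName());
        }
        return names;
    }

    public static void main(String[] args) {
        // args[0]: path to the SavedModel directory, supplied by the caller.
        try (SavedModelBundle model = SavedModelBundle.load(args[0], "serve")) {
            SignatureDef signature =
                model.metaGraphDef().getSignatureDefMap().get("serving_default");
            if (signature == null) {
                System.out.println("No serving_default signature in this model");
                return;
            }
            System.out.println("inputs:  " + inputNames(signature));
            System.out.println("outputs: " + outputNames(signature));
        }
    }
}
```

The values printed are graph tensor names, typically of the form operation_name:output_index, and they can be passed as-is to feed(...) and fetch(...) in place of "input_name" and "output_name".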
but other than that it worked! – Arianearianie