Java Tensorflow + Keras Equivalent of model.predict()

In Python you can simply pass a NumPy array to predict() to get predictions from your model. What is the equivalent in Java with a SavedModelBundle?

Python

model = tf.keras.models.Sequential([
  # layers go here
])
model.compile(...)
model.fit(x_train, y_train)

predictions = model.predict(x_test_maxabs) # <= This line 

Java

SavedModelBundle model = SavedModelBundle.load(path, "serve");
model.predict() // ????? What does it take as input? A Tensor?
Arianearianie answered 7/6, 2020 at 6:8 Comment(0)
P
11

TensorFlow Python automatically converts your NumPy array to a tf.Tensor. In TensorFlow Java, you manipulate tensors directly.

The SavedModelBundle does not have a predict method. You need to obtain its session and run it, using the Session.Runner and feeding it your input tensors.

For example, based on the next generation of TF Java (https://github.com/tensorflow/java), your code ends up looking like this (note that I'm making a lot of assumptions here about x_test_maxabs, since your code sample does not explain clearly where it comes from):

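try (SavedModelBundle model = SavedModelBundle.load(path, "serve")) {
    // "input_name" and "output_name" are placeholders for the tensor names in your graph.
    try (Tensor<TFloat32> input = TFloat32.tensorOf(...);
        Tensor<TFloat32> output = model.session()
            .runner()
            .feed("input_name", input)
            .fetch("output_name")
            .run()
            .get(0) // run() returns one tensor per fetch; take the first
            .expect(TFloat32.class)) {

        // Reads a single float, so this assumes the output is a scalar.
        float prediction = output.data().getFloat();
        System.out.println("prediction = " + prediction);
    }
}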
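
The TFloat32.tensorOf(...) call above is left elided because it depends on how your data is stored. As a rough sketch of the preparation step only, assuming x_test_maxabs is available in Java as a rectangular float[][] (that is an assumption, and BatchFlattener, flatten and shapeOf are hypothetical names), this plain-Java helper produces the shape and the row-major values that a tensor of that shape would hold. It deliberately makes no TensorFlow calls, since the tensor factory methods differ between TF Java versions.

```java
import java.util.Arrays;

// Hypothetical helper, not part of TensorFlow Java.
final class BatchFlattener {

    /** Copies a rectangular [rows][cols] batch into one row-major float array. */
    public static float[] flatten(float[][] batch) {
        int rows = batch.length;
        int cols = rows == 0 ? 0 : batch[0].length;
        float[] flat = new float[rows * cols];
        for (int r = 0; r < rows; r++) {
            if (batch[r].length != cols) {
                throw new IllegalArgumentException(
                    "Row " + r + " has length " + batch[r].length + ", expected " + cols);
            }
            // Row r occupies positions [r * cols, (r + 1) * cols) in the flat array.
            System.arraycopy(batch[r], 0, flat, r * cols, cols);
        }
        return flat;
    }

    /** Returns the tensor shape {rows, cols} for the batch. */
    public static long[] shapeOf(float[][] batch) {
        return new long[] { batch.length, batch.length == 0 ? 0 : batch[0].length };
    }

    public static void main(String[] args) {
        // Made-up sample data standing in for x_test_maxabs.
        float[][] xTest = { {0.1f, 0.2f, 0.3f}, {0.4f, 0.5f, 0.6f} };
        System.out.println("shape  = " + Arrays.toString(shapeOf(xTest)));
        System.out.println("values = " + Arrays.toString(flatten(xTest)));
    }
}
```

The shape and flat values are then what you would hand to whichever tensor-creation method your TF Java version provides.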

If you are not sure of the names of the input and output tensors in your graph, you can obtain them programmatically by looking at the signature definition:

model.metaGraphDef().getSignatureDefMap().get("serving_default")
Propitious answered 10/6, 2020 at 3:5 Comment(1)
I had to upgrade to TF 2.x to get model.metaGraphDef().getSignatureDefMap().get("serving_default"), but other than that it worked! Arianearianie
R
4

You can try Deep Java Library (DJL).

DJL uses TensorFlow Java internally and provides a high-level API that makes inference easy:

Criteria<Image, Classifications> criteria =
    Criteria.builder()
        .setTypes(Image.class, Classifications.class)
        .optModelUrls("https://example.com/squeezenet.zip")
        .optTranslator(ImageClassificationTranslator
               .builder().addTransform(new ToTensor()).build())
        .build();

try (ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);
        Predictor<Image, Classifications> predictor = model.newPredictor()) {
    Image image = ImageFactory.getInstance().fromUrl("https://myimage.jpg");
    Classifications result = predictor.predict(image);
}


Check out the GitHub repo: https://github.com/awslabs/djl

There is also a blog post: https://towardsdatascience.com/detecting-pneumonia-from-chest-x-ray-images-e02bcf705dd6

The demo project can be found at: https://github.com/aws-samples/djl-demo/blob/master/pneumonia-detection/README.md

Reflation answered 16/6, 2020 at 23:31 Comment(0)
L
0

With the TF Java 0.3.1 API:

val model: SavedModelBundle = SavedModelBundle.load("path/to/model", "serve")

val inputTensor = TFloat32.tensorOf(..)

val function: ConcreteFunction = model.function(Signature.DEFAULT_KEY)
// You can cast the result to the type you expect. The type of the returned tensor
// can be checked via the signature: model.function("serving_default").signature().toString()
val result: Tensor = function.call(inputTensor)

Once you have the result Tensor (of whatever subtype), you can iterate over its values. In my case the result was a TFloat32 with shape (1, 56), so I found the max value by reading result.get(0, idx) for each idx.
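
As a small illustration of that last step, here is a sketch in plain Java of picking the highest-scoring index. It assumes the 56 scores of the single output row have already been copied into a float[] (how you copy them depends on the TF Java version and is not shown), and ArgMax and argMax are hypothetical names, not TensorFlow API.

```java
// Hypothetical helper, not part of TensorFlow Java.
final class ArgMax {

    /** Returns the index of the largest value, or -1 if the array is empty. */
    public static int argMax(float[] scores) {
        int best = -1;
        for (int i = 0; i < scores.length; i++) {
            // Strict '>' keeps the first index when several values tie.
            if (best == -1 || scores[i] > scores[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Made-up scores standing in for one row of the model output.
        float[] scores = {0.05f, 0.80f, 0.15f};
        System.out.println("predicted class = " + argMax(scores));
    }
}
```

With a (1, 56) output, the returned index is the position along the second dimension, i.e. the idx in the description above.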

Lynnalynne answered 27/4, 2021 at 19:32 Comment(0)
