Understanding data augmentation in the object detection API

I am using the Object Detection API to train on a different dataset, and I would like to know whether it is possible to get sample images of what actually reaches the network during training.

I ask because I am trying to find a good combination of data augmentation options (here are the options), but adding them has made the results worse. Seeing what reaches the network during training would be very helpful.

Another question: is it possible to get the API to help balance the classes when the dataset is imbalanced?

Thank you!

Surveying asked 14/11, 2017 at 20:21
Is your question about understanding what each augmentation does, or about keeping track of what your model learns to detect during training? If the former, I would recommend looking through preprocessor.py (github.com/tensorflow/models/blob/master/research/…); if the latter, consider running eval.py alongside the TensorBoard "images" pane, which shows your current evaluation results. - Fiche

Yes, it is possible. In short, you need to get an instance of tf.data.Dataset. You can then iterate over it and get the network input data as NumPy arrays, which are easy to save as image files with PIL or OpenCV.
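If you would rather not pull in PIL or OpenCV just to look at a few samples, binary PPM (P6) is simple enough to write with the standard library, and most image viewers open it. This is a swapped-in, dependency-free alternative, not something the Object Detection API provides. It is a minimal sketch assuming the image has already been converted to 0..255 integers in RGB order; the function names are illustrative:

```python
from pathlib import Path


def encode_ppm(pixels):
    """Encode an RGB image as binary PPM (P6) bytes.

    `pixels` is a list of rows; each row is a list of (r, g, b) integer tuples in 0..255.
    """
    height = len(pixels)
    width = len(pixels[0]) if height else 0
    header = f"P6\n{width} {height}\n255\n".encode("ascii")
    body = bytearray()
    for row in pixels:
        if len(row) != width:
            raise ValueError("all rows must have the same width")
        for r, g, b in row:
            # bytearray.extend raises ValueError for values outside 0..255
            body.extend((r, g, b))
    return header + bytes(body)


def write_ppm(path, pixels):
    """Write the PPM encoding of `pixels` to `path`."""
    Path(path).write_bytes(encode_ppm(pixels))
```

For example, write_ppm("sample_0000.ppm", [[(255, 0, 0), (0, 0, 255)]]) writes a 2x1 image with one red and one blue pixel. For a uint8 NumPy array `image` of shape [H, W, 3], `image.tolist()` fits the expected layout, and the header followed by `image.tobytes()` produces the same bytes much faster.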

Assuming you use TF2, the code looks roughly like this:

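import cv2
import numpy as np
from object_detection.core import standard_fields as fields

MAX_SAMPLES = 50  # number of images to dump
ds = ...  # get the tf.data.Dataset object somehow (see below)

sample_num = 0
for features, _ in ds:
    images = features[fields.InputDataFields.image]  # a [batch_size, H, W, C] float32 tensor with preprocessed images
    for i in range(images.shape[0]):
        # assuming input data is only scaled to [0..1]; adjust for your model's preprocessing
        image = np.clip(images[i].numpy() * 255, 0, 255).astype(np.uint8)
        # the tensor is RGB, while cv2.imwrite expects BGR
        cv2.imwrite(f'sample_{sample_num:04d}.png', cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
        sample_num += 1  # count images, not batches
        if sample_num >= MAX_SAMPLES:
            break
    if sample_num >= MAX_SAMPLES:
        break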
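Two details in that loop deserve care. The `* 255` step is only right if preprocessing scaled pixels to [0, 1]; the model's preprocess function decides this, and for example the SSD MobileNet feature extractors map pixels to [-1, 1] instead. And the sample limit is easiest to reason about when it counts individual images rather than batches. The helpers below sketch both in plain Python; the names `to_uint8` and `iter_samples` are illustrative, and the input range is an assumption you have to check against your own model's preprocessing:

```python
from itertools import islice


def to_uint8(value, in_min=0.0, in_max=1.0):
    """Map one preprocessed float back to a 0..255 pixel value.

    in_min/in_max describe the range the preprocessing produced (an assumption
    to verify per model), e.g. (0.0, 1.0) for x / 255 or (-1.0, 1.0) for x * 2 / 255 - 1.
    """
    scaled = (value - in_min) / (in_max - in_min) * 255.0
    # clip defensively in case values fall outside the assumed range
    return max(0, min(255, round(scaled)))


def iter_samples(batches, max_samples):
    """Return an iterator over single images taken from an iterable of batches.

    Stops after exactly max_samples images (or earlier if the batches run out),
    and pulls batches lazily, so it also works on an endless dataset.
    """
    images = (image for batch in batches for image in batch)
    return islice(images, max_samples)
```

With the Object Detection API dataset, `batches` could be something like `(features[fields.InputDataFields.image].numpy() for features, _ in ds)`. For whole arrays, the NumPy counterpart of `to_uint8` is `np.clip(np.round((img - in_min) / (in_max - in_min) * 255), 0, 255).astype(np.uint8)`.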

The tricky part is getting the Dataset instance. The Google Object Detection API is quite complex, but I would start by calling the train_input function here: https://github.com/tensorflow/models/blob/3c8b6f1e17e230b68519fd8d58c4dd9e9570d789/research/object_detection/inputs.py#L763

It requires the parts of the pipeline config that describe training (train_config), the training input (train_input_config) and the model (model_config).

You can find code snippets showing how to work with the pipeline config in this question: Dynamically Editing Pipeline Config for Tensorflow Object Detection

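import argparse

import tensorflow as tf
from google.protobuf import text_format
from object_detection import inputs
from object_detection.protos import pipeline_pb2


def parse_arguments():
    parser = argparse.ArgumentParser(description='Dump preprocessed training images')
    parser.add_argument('pipeline')  # path to the pipeline.config file
    parser.add_argument('output')  # directory to write sample images to
    return parser.parse_args()


def main():
    args = parse_arguments()
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()

    # tf.gfile is TF1-only; in TF2 the same class lives at tf.io.gfile.GFile
    with tf.io.gfile.GFile(args.pipeline, "r") as f:
        proto_str = f.read()
        text_format.Merge(proto_str, pipeline_config)

    # train_input(train_config, train_input_config, model_config) builds the
    # training tf.data.Dataset, including the configured data augmentation
    ds = inputs.train_input(pipeline_config.train_config,
                            pipeline_config.train_input_reader,
                            pipeline_config.model)
    # iterate over ds as in the loop above, writing images into args.output


if __name__ == '__main__':
    main()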
Brusquerie answered 24/2, 2021 at 12:0
