How to show the class distribution in Dataset object in Tensorflow

I am working on a multi-class classification task using my own images.

import tensorflow as tf

filenames = [] # a list of filenames
labels = [] # a list of labels corresponding to the filenames
full_ds = tf.data.Dataset.from_tensor_slices((filenames, labels))

This full dataset is then shuffled and split into train, validation and test datasets:

full_ds_size = len(filenames)
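# seed is used for reproducibility; reshuffle_each_iteration=False keeps the
# shuffle order fixed, so the take/skip splits below stay disjoint across passes
full_ds = full_ds.shuffle(buffer_size=full_ds_size*2, seed=128,
                          reshuffle_each_iteration=False)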

train_ds_size = int(0.64 * full_ds_size)
valid_ds_size = int(0.16 * full_ds_size)

train_ds = full_ds.take(train_ds_size)
remaining = full_ds.skip(train_ds_size)  
valid_ds = remaining.take(valid_ds_size)
test_ds = remaining.skip(valid_ds_size)
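By default `shuffle` uses `reshuffle_each_iteration=True`, so every pass over the data sees a new order. Since `take` and `skip` each start a fresh pass, the three splits can then overlap and change from one call to the next, which would also make any per-split class counts inconsistent. A minimal sketch, assuming a toy dataset of ten integers in place of the real `(filename, label)` pairs, checking that with `reshuffle_each_iteration=False` the splits are disjoint and cover everything:

```python
import tensorflow as tf

# Toy stand-in for the (filename, label) dataset: just the integers 0..9.
full_ds = tf.data.Dataset.range(10)

# Fix the shuffle order so every pass over a split sees the same permutation.
full_ds = full_ds.shuffle(buffer_size=10, seed=128, reshuffle_each_iteration=False)

train_ds = full_ds.take(6)
remaining = full_ds.skip(6)
valid_ds = remaining.take(2)
test_ds = remaining.skip(2)

def as_set(ds):
    """Collect every element of a dataset into a Python set of ints."""
    return {int(x) for x in ds.as_numpy_iterator()}

train, valid, test = as_set(train_ds), as_set(valid_ds), as_set(test_ds)

# The splits should not overlap and together should cover all ten elements.
print(train & valid, train & test, valid & test)
print(sorted(train | valid | test))
```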

Now I am struggling to understand how each class is distributed in train_ds, valid_ds and test_ds. An ugly solution is to iterate over all the elements in the dataset and count the occurrences of each class. Is there a better way to do this?

My ugly solution:

import collections

def get_class_distribution(dataset):
    class_distribution = {}
    for element in dataset.as_numpy_iterator():
        label = element[1]  # element is a (filename, label) tuple

        if label in class_distribution:
            class_distribution[label] += 1
        else:
            class_distribution[label] = 1  # first occurrence counts as 1

    # sort dict by key
    class_distribution = collections.OrderedDict(sorted(class_distribution.items()))
    return class_distribution


train_ds_class_dist = get_class_distribution(train_ds)
valid_ds_class_dist = get_class_distribution(valid_ds)
test_ds_class_dist = get_class_distribution(test_ds)

print(train_ds_class_dist)
print(valid_ds_class_dist)
print(test_ds_class_dist)
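The same per-class count can be written more compactly with `collections.Counter`. A minimal sketch, assuming each element is a `(feature, label)` pair with a hashable label, as `dataset.as_numpy_iterator()` yields for the dataset above; `count_labels` is a hypothetical helper name, and the example uses a plain list so it runs without TensorFlow:

```python
import collections

def count_labels(pairs):
    """Count how often each label occurs in an iterable of (feature, label) pairs.

    Works on dataset.as_numpy_iterator() for a dataset of (filename, label)
    tuples, or on any plain Python iterable with the same shape.
    """
    counter = collections.Counter(label for _, label in pairs)
    # Sort by label so the output is easy to compare across splits.
    return dict(sorted(counter.items()))

# Plain-Python usage example; with TensorFlow this would be
# count_labels(train_ds.as_numpy_iterator()).
example = [("a.png", 2), ("b.png", 0), ("c.png", 2), ("d.png", 1), ("e.png", 2)]
print(count_labels(example))  # {0: 1, 1: 1, 2: 3}
```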
Khalilahkhalin answered 26/3, 2020 at 21:56

The answer below assumes:

  • there are five classes.
  • labels are integers from 0 to 4.

It can be modified to suit your needs.

Define a counter function:

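import tensorflow as tf

def count_class(counts, batch, num_classes=5):
    # Each element of the question's dataset is a (filename, label) tuple;
    # for a dict-style dataset (e.g. from TFDS) use batch['label'] instead.
    labels = batch[1]
    for i in range(num_classes):
        cc = tf.cast(labels == i, tf.int32)
        counts[i] += tf.reduce_sum(cc)
    return counts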

Use the reduce operation:

initial_state = dict((i, 0) for i in range(5))
counts = train_ds.reduce(initial_state=initial_state,
                         reduce_func=count_class)

print([(k, v.numpy()) for k, v in counts.items()])
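A self-contained sketch of the same `reduce` pattern, assuming a toy dataset of `(filename, label)` tuples shaped like the one in the question (the filenames and labels are made up for illustration). `reduce` calls the function as `reduce_func(state, element)`, and since each element here is a tuple, the label is read with `element[1]`:

```python
import tensorflow as tf

NUM_CLASSES = 5

# Made-up (filename, label) pairs standing in for real image files.
filenames = ["img0.png", "img1.png", "img2.png", "img3.png", "img4.png", "img5.png"]
labels = [0, 3, 3, 1, 4, 3]
ds = tf.data.Dataset.from_tensor_slices((filenames, labels))

def count_class(counts, element):
    label = element[1]  # element is a (filename, label) tuple
    for i in range(NUM_CLASSES):
        # Adds 1 to counts[i] when the label equals i, otherwise 0.
        counts[i] += tf.cast(label == i, tf.int32)
    return counts

initial_state = {i: 0 for i in range(NUM_CLASSES)}
counts = ds.reduce(initial_state=initial_state, reduce_func=count_class)
result = {k: int(v.numpy()) for k, v in counts.items()}
print(result)  # {0: 1, 1: 1, 2: 0, 3: 3, 4: 1}
```

Classes that never occur (class 2 here) still show up with a count of 0, because the initial state already has a slot for every class.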
Permit answered 26/3, 2020 at 23:22

A solution inspired by user650654's answer that uses only TensorFlow primitives (tf.unique_with_counts instead of a for loop):

In theory, this should perform better and scale better to large datasets, large batches, or many classes.

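import tensorflow as tf

num_classes = 5

@tf.function
def count_class(counts, batch):
    # Flatten the labels so this works for batched datasets and for unbatched
    # ones, where each label is a scalar (tf.unique_with_counts needs 1-D input).
    y, _, c = tf.unique_with_counts(tf.reshape(batch[1], [-1]))
    # Add each distinct label's count at index y of the running totals.
    return tf.tensor_scatter_nd_add(counts, tf.expand_dims(y, axis=1), c)

counts = train_ds.reduce(
    initial_state=tf.zeros(num_classes, tf.int32),
    reduce_func=count_class)

print(counts.numpy())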
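To see what each step of `count_class` contributes, here is a sketch on a single made-up batch: `tf.unique_with_counts` returns the distinct labels (in order of first appearance) and how often each occurs, and `tf.tensor_scatter_nd_add` adds those counts at the matching positions of the running totals. The batch values and earlier totals are illustrative only:

```python
import tensorflow as tf

running = tf.constant([1, 0, 2, 0, 0], tf.int32)  # totals from earlier batches (made up)
batch_labels = tf.constant([3, 0, 3, 4, 3], tf.int32)

# y holds the distinct labels in order of first appearance, c their counts.
y, _, c = tf.unique_with_counts(batch_labels)
print(y.numpy(), c.numpy())  # [3 0 4] [3 1 1]

# Each distinct label becomes a one-element index: [[3], [0], [4]].
updated = tf.tensor_scatter_nd_add(running, tf.expand_dims(y, axis=1), c)
print(updated.numpy())  # [2 0 2 3 1]
```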

A similar, simpler version using NumPy, which actually performed better in my simple use case:

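import numpy as np

count = np.zeros(num_classes, dtype=np.int32)
for _, labels in train_ds:
    # reshape to 1-D so scalar labels from an unbatched dataset also work
    y, _, c = tf.unique_with_counts(tf.reshape(labels, [-1]))
    count[y.numpy()] += c.numpy()
print(count)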
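When all the labels of a split fit in memory, another NumPy option (not from the answer above) is to collect them into one array and call `np.bincount` with `minlength` set to the number of classes, so classes that never occur still get a 0. A sketch, with made-up labels standing in for the ones gathered from a split, e.g. via `np.array([y for _, y in train_ds.as_numpy_iterator()])`:

```python
import numpy as np

num_classes = 5

# Made-up integer labels standing in for the labels of one split.
labels = np.array([0, 3, 3, 1, 4, 3, 0])

# bincount counts occurrences of each non-negative integer; minlength pads
# with zeros so the result always has one entry per class.
counts = np.bincount(labels, minlength=num_classes)
print(counts)  # [2 1 0 3 1]
```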
Bleeder answered 26/5, 2022 at 0:56
what is df_train? - Pentose
@Pentose corrected to train_ds to match the question - Bleeder

I couldn't get the other answers here to work for me, but I came up with another solution that worked for my use case. Posting it here in case it's of any use to others.

My use case was a tf.data.Dataset of (x, y) tensor pairs (x = training data, y = one-hot label).

I've unravelled and indented the rather long one-liner to hopefully aid readability:

import numpy as np

# Get unique labels and associated counts
labels, counts = \
  np.unique(
    np.vstack(
      list(
        train_ds.map(lambda x, y: y).as_numpy_iterator()
      )
    ),
    axis=0,
    return_counts=True
  )

# Inspect labels and counts
for label, count in zip(labels, counts):
  print("Class label: " + str(label) + " count: " + str(count))

Output in my use case (3 classes):

Class label: [0 0 1] count: 24
Class label: [0 1 0] count: 99
Class label: [1 0 0] count: 134
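With one-hot labels, an alternative sketch (not part of the answer above) is to turn each row into its class index with `argmax` and count the indices with `np.bincount`. Note that `np.unique(..., axis=0)` sorts the one-hot rows lexicographically, which lists class 2 first for three classes, while this version reports counts in class-index order. The `one_hot` array is made up for illustration; in practice it would be the `np.vstack(...)` result from the answer:

```python
import numpy as np

# Made-up one-hot labels for 3 classes (one row per example).
one_hot = np.array([
    [0, 0, 1],
    [1, 0, 0],
    [0, 1, 0],
    [1, 0, 0],
    [1, 0, 0],
])

# argmax along each row recovers the integer class index of a one-hot label.
class_idx = one_hot.argmax(axis=1)

# minlength keeps a slot for every class even if some never appear.
counts = np.bincount(class_idx, minlength=one_hot.shape[1])
for label, count in enumerate(counts):
    print("Class index: " + str(label) + " count: " + str(count))
```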
Coleman answered 7/6 at 17:8