What does TensorFlow's flat_map + window.batch() do to a dataset/array?
I'm following one of the online courses about time series prediction using TensorFlow. The function that converts a NumPy array (TS) into a TensorFlow dataset for an LSTM-based model is already given (with my comment lines):

import tensorflow as tf

def windowed_dataset(series, window_size, batch_size, shuffle_buffer):
    # creating a tensor from an array
    dataset = tf.data.Dataset.from_tensor_slices(series)
    # cutting the tensor into fixed-size windows
    dataset = dataset.window(window_size + 1, shift=1, drop_remainder=True)
    # joining windows into a batch?
    dataset = dataset.flat_map(lambda window: window.batch(window_size + 1))
    # separating row into features/label
    dataset = dataset.shuffle(shuffle_buffer).map(lambda window: (window[:-1], window[-1]))
    dataset = dataset.batch(batch_size).prefetch(1)
    return dataset
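To see what this function produces, a small sketch (assuming TF 2.x eager execution; the toy series and parameter values are arbitrary choices for illustration) runs it on a short series and inspects the first batch. Because of the shuffle, the values differ between runs, but the shapes do not.

```python
import numpy as np
import tensorflow as tf

def windowed_dataset(series, window_size, batch_size, shuffle_buffer):
    # same pipeline as in the question
    dataset = tf.data.Dataset.from_tensor_slices(series)
    dataset = dataset.window(window_size + 1, shift=1, drop_remainder=True)
    dataset = dataset.flat_map(lambda window: window.batch(window_size + 1))
    dataset = dataset.shuffle(shuffle_buffer).map(lambda window: (window[:-1], window[-1]))
    dataset = dataset.batch(batch_size).prefetch(1)
    return dataset

series = np.arange(10, dtype=np.float32)  # toy series: 0.0 ... 9.0
ds = windowed_dataset(series, window_size=3, batch_size=2, shuffle_buffer=5)

X, y = next(iter(ds))
print(X.shape, y.shape)  # (2, 3) (2,)

# with this toy series, every label is the value right after its features
for features, label in zip(X.numpy(), y.numpy()):
    assert label == features[-1] + 1
```

Each element of the final dataset is a (features, label) pair, batched along the first axis; 10 values with windows of length 4 give 7 pairs in total.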

This code works fine, but I want to understand it better so I can modify/adapt it for my needs.

If I remove the dataset.flat_map(lambda window: window.batch(window_size + 1)) operation, I get TypeError: '_VariantDataset' object is not subscriptable, pointing to the line lambda window: (window[:-1], window[-1])

I managed to rewrite part of this code (skipping shuffling) as a NumPy-based version:

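import numpy as np
import tensorflow as tf
from numpy.lib.stride_tricks import sliding_window_view

def windowed_dataset_np(series, window_size):
    values = sliding_window_view(series, window_size)
    X = values[:, :-1]
    X = tf.convert_to_tensor(np.expand_dims(X, axis=-1))
    y = values[:, -1]
    return X, y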

The syntax for fitting the model looks a bit different, but it works fine.

My two questions are:

  1. What does dataset.flat_map(lambda window: window.batch(window_size + 1)) achieve?
  2. Is the second code really equivalent to the first three operations in the original function?
Centigram answered 21/2, 2022 at 18:5
I would break the operations down into smaller parts to really understand what is happening, since applying window to a dataset actually creates a dataset of windowed datasets, each containing a sequence of scalar tensors:

import tensorflow as tf

window_size = 2
dataset = tf.data.Dataset.range(7)
dataset = dataset.window(window_size + 1, shift=1, drop_remainder=True)  

for i, window in enumerate(dataset):
  print('{}. windowed dataset'.format(i + 1))
  for w in window:
    print(w)
1. windowed dataset
tf.Tensor(0, shape=(), dtype=int64)
tf.Tensor(1, shape=(), dtype=int64)
tf.Tensor(2, shape=(), dtype=int64)
2. windowed dataset
tf.Tensor(1, shape=(), dtype=int64)
tf.Tensor(2, shape=(), dtype=int64)
tf.Tensor(3, shape=(), dtype=int64)
3. windowed dataset
tf.Tensor(2, shape=(), dtype=int64)
tf.Tensor(3, shape=(), dtype=int64)
tf.Tensor(4, shape=(), dtype=int64)
4. windowed dataset
tf.Tensor(3, shape=(), dtype=int64)
tf.Tensor(4, shape=(), dtype=int64)
tf.Tensor(5, shape=(), dtype=int64)
5. windowed dataset
tf.Tensor(4, shape=(), dtype=int64)
tf.Tensor(5, shape=(), dtype=int64)
tf.Tensor(6, shape=(), dtype=int64)
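This nesting also explains the TypeError from the question: without flat_map, window[:-1] is applied to an inner dataset, and datasets do not support slicing. A minimal check (assuming TF 2.x, where tf.data.DatasetSpec is part of the public API) inspects element_spec:

```python
import tensorflow as tf

window_size = 2
windows = tf.data.Dataset.range(7).window(window_size + 1, shift=1, drop_remainder=True)

# the elements are described by a DatasetSpec, i.e. each element is itself a dataset
spec = windows.element_spec
print(spec)

# slicing such an element (window[:-1]) is not defined for datasets;
# it only works once window.batch(...) has turned each window into a tensor
```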

Notice how each window is shifted by one position because of the parameter shift=1. Now, flat_map is used here to flatten the dataset of datasets into a dataset of elements. However, you still want to keep the windowed sequences you created, so the function passed to flat_map first batches each window dataset with window.batch(window_size + 1), turning its window_size + 1 scalar elements into a single tensor:

dataset = dataset.flat_map(lambda window: window.batch(window_size + 1))
for w in dataset:
  print(w)
tf.Tensor([0 1 2], shape=(3,), dtype=int64)
tf.Tensor([1 2 3], shape=(3,), dtype=int64)
tf.Tensor([2 3 4], shape=(3,), dtype=int64)
tf.Tensor([3 4 5], shape=(3,), dtype=int64)
tf.Tensor([4 5 6], shape=(3,), dtype=int64)
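In general, flat_map expects a function that returns a dataset for each element and concatenates those datasets into one. A small sketch unrelated to windowing (the repeat-twice function is just an arbitrary example):

```python
import tensorflow as tf

# each element x is mapped to a dataset that yields x twice; flat_map chains them
ds = tf.data.Dataset.range(3).flat_map(
    lambda x: tf.data.Dataset.from_tensors(x).repeat(2))
print([int(v) for v in ds])  # [0, 0, 1, 1, 2, 2]
```

In the windowing pipeline, window.batch(window_size + 1) plays the role of that function: it returns a one-element dataset holding the whole window as one tensor, so flat_map ends up with one tensor per window.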

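You could also flatten the dataset of datasets first and then apply batch to recreate the windowed sequences. This gives the same result here only because drop_remainder=True guarantees that every window has exactly window_size + 1 elements: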

windows = tf.data.Dataset.range(7).window(window_size + 1, shift=1, drop_remainder=True)
dataset = windows.flat_map(lambda window: window).batch(window_size + 1)
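A quick sketch (same toy values as above) checking that both orderings yield identical windows. Batching the flattened stream cuts it back at the original window boundaries only because every window has the same length, window_size + 1.

```python
import tensorflow as tf

window_size = 2

def make_windows():
    # fresh windowed dataset: 7 values, windows of 3, shifted by 1
    return tf.data.Dataset.range(7).window(window_size + 1, shift=1, drop_remainder=True)

# batch inside flat_map: each window dataset becomes one tensor
inner = make_windows().flat_map(lambda window: window.batch(window_size + 1))
# flatten first, then batch the concatenated stream
outer = make_windows().flat_map(lambda window: window).batch(window_size + 1)

inner_list = [t.numpy().tolist() for t in inner]
outer_list = [t.numpy().tolist() for t in outer]
print(inner_list)  # [[0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 6]]
assert inner_list == outer_list
```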

Or only flatten the dataset of datasets:

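# start again from the windowed dataset and only remove the nesting
windows = tf.data.Dataset.range(7).window(window_size + 1, shift=1, drop_remainder=True)
flat = windows.flat_map(lambda window: window)
for w in flat:
  print(w)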
tf.Tensor(0, shape=(), dtype=int64)
tf.Tensor(1, shape=(), dtype=int64)
tf.Tensor(2, shape=(), dtype=int64)
tf.Tensor(1, shape=(), dtype=int64)
tf.Tensor(2, shape=(), dtype=int64)
tf.Tensor(3, shape=(), dtype=int64)
tf.Tensor(2, shape=(), dtype=int64)
tf.Tensor(3, shape=(), dtype=int64)
tf.Tensor(4, shape=(), dtype=int64)
tf.Tensor(3, shape=(), dtype=int64)
tf.Tensor(4, shape=(), dtype=int64)
tf.Tensor(5, shape=(), dtype=int64)
tf.Tensor(4, shape=(), dtype=int64)
tf.Tensor(5, shape=(), dtype=int64)
tf.Tensor(6, shape=(), dtype=int64)

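But that is probably not what you want. The last line you asked about, dataset = dataset.shuffle(shuffle_buffer).map(lambda window: (window[:-1], window[-1])), is straightforward: shuffle randomizes the order of the windows using a buffer of shuffle_buffer elements, and map splits each window into a sequence and a label, using the last element of the window as the label. Here it is applied with a buffer of 2 to the batched windows from above: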

dataset = dataset.shuffle(2).map(lambda window: (window[:-1], window[-1]))
for w in dataset:
  print(w)
(<tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 2])>, <tf.Tensor: shape=(), dtype=int64, numpy=3>)
(<tf.Tensor: shape=(2,), dtype=int64, numpy=array([2, 3])>, <tf.Tensor: shape=(), dtype=int64, numpy=4>)
(<tf.Tensor: shape=(2,), dtype=int64, numpy=array([3, 4])>, <tf.Tensor: shape=(), dtype=int64, numpy=5>)
(<tf.Tensor: shape=(2,), dtype=int64, numpy=array([4, 5])>, <tf.Tensor: shape=(), dtype=int64, numpy=6>)
(<tf.Tensor: shape=(2,), dtype=int64, numpy=array([0, 1])>, <tf.Tensor: shape=(), dtype=int64, numpy=2>)
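Regarding the NumPy version in the question: a sketch (assuming NumPy 1.20 or later for sliding_window_view, no shuffling; tf_windows and np_windows are hypothetical helper names) compares both approaches. To match the tf.data pipeline, the windows must have length window_size + 1; sliding_window_view(series, window_size) as written in the question yields only window_size - 1 features per row. The question's version also adds a trailing feature axis with np.expand_dims, which the tf.data pipeline does not.

```python
import numpy as np
import tensorflow as tf
from numpy.lib.stride_tricks import sliding_window_view  # NumPy >= 1.20

def tf_windows(series, window_size):
    # first three steps of windowed_dataset plus the feature/label split, no shuffle
    ds = tf.data.Dataset.from_tensor_slices(series)
    ds = ds.window(window_size + 1, shift=1, drop_remainder=True)
    ds = ds.flat_map(lambda window: window.batch(window_size + 1))
    ds = ds.map(lambda window: (window[:-1], window[-1]))
    pairs = [(x.numpy(), y.numpy()) for x, y in ds]
    return np.stack([p[0] for p in pairs]), np.array([p[1] for p in pairs])

def np_windows(series, window_size):
    # window length window_size + 1 to match the tf.data version
    values = sliding_window_view(series, window_size + 1)
    return values[:, :-1], values[:, -1]

series = np.arange(10, dtype=np.float32)
X_tf, y_tf = tf_windows(series, window_size=3)
X_np, y_np = np_windows(series, window_size=3)

print(X_np.shape, y_np.shape)  # (7, 3) (7,)
assert np.array_equal(X_tf, X_np) and np.array_equal(y_tf, y_np)
```

Apart from the shuffle, batching and prefetching, which the NumPy version leaves to model.fit, the two then produce the same features and labels.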
Bevvy answered 21/2, 2022 at 19:26
Can you please check whether my answer is correct? I tried to put your answer into a more visual format. – Bookcraft
Thanks, seems fine to me. – Bevvy
What is the difference between window and batch? – Skein
As far as I understand the data structures, the data nesting can be visualized as in the picture below.

So first, we have a dataset with a lot of windows (which are also datasets), where every window consists of tensors. Each tensor holds a single value from the original time series.

As AloneTogether has shown, if the lambda is "window: window", flat_map just removes the hierarchical structure, which I call "window" in my diagram. You end up with a dataset of tensors that each hold only one element.

But the lambda in the original code uses batch to combine the single-element tensors in each window dataset into one tensor with multiple elements.

This way, when you flatten (i.e. remove the hierarchical level I call "window"), you are left with a dataset of tensors with multiple elements.
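The two levels of nesting can also be seen in element_spec. A small sketch (same toy values as in the first answer) shows that flattening with an identity function leaves scalar elements, while batching each window first leaves one 1-D tensor per window:

```python
import tensorflow as tf

window_size = 2
windows = tf.data.Dataset.range(7).window(window_size + 1, shift=1, drop_remainder=True)

# identity: only removes the window level -> scalar tensors
flat = windows.flat_map(lambda window: window)
# batch inside flat_map: each window becomes one tensor with window_size + 1 values
grouped = windows.flat_map(lambda window: window.batch(window_size + 1))

print(flat.element_spec.shape.rank)     # 0
print(grouped.element_spec.shape.rank)  # 1
print(len(list(flat)), len(list(grouped)))  # 15 5
```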

In conclusion: when we say "batch" in this context, we mean the values from one window of the original dataset, grouped into a single tensor.
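The distinction between window and batch that this relies on: batch groups consecutive elements into tensors without overlap, while window produces nested datasets and, through shift, can overlap. A sketch on the same range(7):

```python
import tensorflow as tf

ds = tf.data.Dataset.range(7)

# batch: consecutive, non-overlapping tensors (last partial batch kept by default)
plain = [b.numpy().tolist() for b in ds.batch(3)]
print(plain)  # [[0, 1, 2], [3, 4, 5], [6]]

# window with the default shift (equal to size): same grouping, but as nested datasets
same = [b.numpy().tolist() for b in ds.window(3).flat_map(lambda w: w.batch(3))]
print(same)  # [[0, 1, 2], [3, 4, 5], [6]]

# window with shift=1: overlapping groups, as in the time-series pipeline
overlap = ds.window(3, shift=1, drop_remainder=True).flat_map(lambda w: w.batch(3))
overlap = [b.numpy().tolist() for b in overlap]
print(overlap)  # [[0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 6]]
```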

Edit: When I copied the orange box in the right part of the picture, I forgot to change the numbers within it. It should be Tensor 1 (0,1,2), Tensor 2 (1,2,3) and Tensor X with (?,?,?).

[Image: Visualization of what happens]

Bookcraft answered 3/11, 2023 at 15:5