I have an unbatched TensorFlow dataset that looks like this:
ds = ...
for record in ds.take(3):
    print('data shape={}'.format(record['data'].shape))
-> data shape=(512, 512, 87)
-> data shape=(512, 512, 277)
-> data shape=(512, 512, 133)
I want to feed the data to my network in chunks of depth 5. In the example above, the tensor of shape (512, 512, 87) would be divided into 17 tensors of shape (512, 512, 5), since 87 // 5 = 17. The final 2 slices along the depth axis (tensor[:, :, 85:87]) should be discarded.
For example:
chunked_ds = ...
for record in chunked_ds.take(1):
    print('chunked data shape={}'.format(record['data'].shape))
-> chunked data shape=(512, 512, 5)
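To make the chunking arithmetic concrete, here is a minimal sketch in plain Python of the index ranges I expect for each tensor. It does not touch tf.data at all, and chunk_bounds is a hypothetical helper name I made up for illustration:

```python
def chunk_bounds(depth, chunk_size=5):
    """Return (start, stop) index pairs for the full chunks along the depth axis.

    Hypothetical helper for illustration only; not part of TensorFlow.
    """
    # Integer division drops the incomplete trailing chunk.
    n_chunks = depth // chunk_size
    return [(i * chunk_size, (i + 1) * chunk_size) for i in range(n_chunks)]


bounds = chunk_bounds(87)
print(len(bounds))   # 17
print(bounds[0])     # (0, 5)
print(bounds[-1])    # (80, 85); indices 85 and 86 are discarded
```

Each (start, stop) pair corresponds to one output element tensor[:, :, start:stop] of shape (512, 512, 5).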
How can I get from ds to chunked_ds? tf.data.Dataset.window() looks like what I need, but I cannot get this working.