How to use 'collate_fn' with dataloaders?

I am trying to train a pretrained roberta model using 3 inputs, 3 input_masks and a label as tensors of my training dataset.

I do this using the following code:

from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
batch_size = 32
# Create the DataLoader for our training set.
train_data = TensorDataset(train_AT, train_BT, train_CT, train_maskAT, train_maskBT, train_maskCT, labels_trainT)
train_dataloader = DataLoader(train_data, batch_size=batch_size)

# Create the Dataloader for our validation set.
validation_data = TensorDataset(val_AT, val_BT, val_CT, val_maskAT, val_maskBT, val_maskCT, labels_valT)
val_dataloader = DataLoader(validation_data, batch_size=batch_size)

# PyTorch training with the 🤗 Trainer
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir='C:/Users/samvd/Documents/Master/AppliedMachineLearning/FinalProject/results',          # output directory
    num_train_epochs=1,              # total # of training epochs
    per_device_train_batch_size=32,  # batch size per device during training
    per_device_eval_batch_size=32,   # batch size for evaluation
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='C:/Users/samvd/Documents/Master/AppliedMachineLearning/FinalProject/logs',            # directory for storing logs
)

trainer = Trainer(
    model=model,                          # the instantiated 🤗 Transformers model to be trained
    args=training_args,                   # training arguments, defined above
    train_dataset = train_data,           # training dataset
    eval_dataset = validation_data,       # evaluation dataset
)

trainer.train()

However this gives me the following error:

TypeError: vars() argument must have __dict__ attribute

Now I have found out that it is probably because I don't use a collate_fn with my DataLoader, but I can't really find a source that helps me define it correctly so the Trainer understands the different tensors I put in.

Can anyone point me in the right direction?

Rizzio answered 13/12, 2020 at 18:23 Comment(2)
You have posted three times on the same problem; I am not sure it will help you get an answer. I would recommend editing your original question. This will help readers answer your question. – Aton
Does this answer your question? Adding class objects to Pytorch Dataloader: batch must contain tensors. It shows how to use collate_fn. – Fallonfallout

Basically, collate_fn receives a list of tuples if the __getitem__ function of your Dataset subclass returns a tuple, or just a plain list if your Dataset subclass returns a single element. Its main purpose is to build your batch without you having to implement the batching logic by hand every time. Think of it as the glue that specifies how individual examples stick together into a batch. If you don't provide one, PyTorch simply puts batch_size examples together, roughly the way torch.stack would (not exactly, but it is about that simple).
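
To make that concrete, here is a minimal sketch (not from the original answer; ToyDataset and simple_collate are made-up names) showing what collate_fn actually receives and how a hand-written version can reproduce the default stacking behaviour for fixed-size tensors:

import torch
from torch.utils.data import Dataset, DataLoader

# Hypothetical toy dataset whose __getitem__ returns a (tensor, label) tuple.
class ToyDataset(Dataset):
    def __init__(self):
        self.samples = [(torch.randn(4), i % 2) for i in range(10)]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

# collate_fn receives a list containing batch_size (tensor, label) tuples.
def simple_collate(batch):
    tensors, labels = zip(*batch)
    # Roughly what the default collation does for same-sized tensors:
    return torch.stack(tensors), torch.tensor(labels)

loader = DataLoader(ToyDataset(), batch_size=5, collate_fn=simple_collate)
features, labels = next(iter(loader))  # features: (5, 4), labels: (5,)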

Suppose, for example, that you want to create batches from a list of tensors of varying length. The code below pads every sequence with 0 up to the maximum sequence length in the batch. That is exactly why we need a collate_fn here: the standard batching algorithm (essentially torch.stack) won't work in this case, so we have to manually pad the variable-length sequences to the same size before building the batch.

import torch

def collate_fn(data):
    """
       data: is a list of tuples with (example, label, length)
             where 'example' is a tensor of arbitrary shape
             and label/length are scalars
    """
    _, labels, lengths = zip(*data)
    max_len = max(lengths)
    n_ftrs = data[0][0].size(1)
    features = torch.zeros((len(data), max_len, n_ftrs))
    labels = torch.tensor(labels)
    lengths = torch.tensor(lengths)

    for i in range(len(data)):
        j, k = data[i][0].size(0), data[i][0].size(1)
        # pad each example with zeros up to max_len along the first dimension
        features[i] = torch.cat([data[i][0], torch.zeros((max_len - j, k))])

    return features.float(), labels.long(), lengths.long()

The function above is passed to the collate_fn parameter of the DataLoader, as in this example:

DataLoader(toy_dataset, collate_fn=collate_fn, batch_size=5)

With this collate_fn, you will always get a batch tensor in which all examples have the same size. So, when you feed this data to your forward() function, you need to use the lengths to recover the original sequences, so that the meaningless padded zeros do not affect your computation.
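
For instance (a sketch, not part of the original answer; length_mask is a made-up helper), one common way to use those lengths downstream is to build a boolean mask and compute statistics only over the real timesteps:

import torch

def length_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    # True for real timesteps, False for the padded positions.
    return torch.arange(max_len)[None, :] < lengths[:, None]

features = torch.zeros(3, 5, 2)   # batch of 3, padded to length 5, 2 features
lengths = torch.tensor([5, 3, 2])
mask = length_mask(lengths, features.size(1))         # shape (3, 5)

# e.g. a mean over only the valid timesteps, ignoring the padded zeros:
summed = (features * mask.unsqueeze(-1)).sum(dim=1)   # (3, 2)
mean = summed / lengths.unsqueeze(-1)                 # (3, 2)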

Source: PyTorch Forums

Osmund answered 24/1, 2021 at 20:7 Comment(2)
This answer is great, but there are more straightforward ways to do this. For the interested reader (for efficiency and clarity of the code) it is possible to use padding; two references, because the space in the comment is too small for more: the documentation pytorch.org/docs/stable/generated/… and a visual explanation #51031282 – Hieratic
Shouldn't k be replaced with n_ftrs? They are supposed to be constant across all samples. – Gonophore

You can use pad_sequence (as mentioned in the comments above by Marine Galantin) to simplify the collate_fn.

For example, suppose data contains a list of tuples where the first element is the input tensor and the second is the label. Note that batch_first may need to be adapted depending on your own problem/model.

import torch
from torch.nn.utils.rnn import pad_sequence
from typing import List, Tuple

def collate_fn(data: List[Tuple[torch.Tensor, torch.Tensor]]):
    tensors, targets = zip(*data)
    features = pad_sequence(tensors, batch_first=True)
    targets = torch.stack(targets)
    return features, targets

and then the DataLoader:

DataLoader(toy_dataset, collate_fn=collate_fn, batch_size=5)

Bud answered 8/9, 2023 at 12:38 Comment(0)
