I have a custom dataset with custom table entries and want to process it with a custom collate function. However, it fails when I pass a collate function I wrote (one that DOES work with an individual DataLoader, e.g., see "How does one create a pytorch data loader with a custom hugging face data set without having errors?" or "How does one create a pytorch data loader using an interleaved hugging face dataset?"). It just doesn't work with the HF Trainer.
Code
from pathlib import Path
# token = open(Path('~/data/hf_token.txt').expanduser()).read().strip()
token = None
batch_size = 8
# -- AF now
from datasets import load_dataset
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained("gpt2")
device = torch.device(f"cuda:{0}" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# -- Get batch from dataset
from datasets import load_dataset
# path, name = 'brando/debug1_af', 'debug1_af'
path, name = 'brando/debug0_af', 'debug0_af'
# train_dataset = load_dataset(path, name, streaming=True, split="train", token=token).with_format(type="torch")
# eval_dataset = load_dataset(path, name, streaming=True, split="test", token=token).with_format(type="torch")
# batch = dataset.take(1)
# column_names = next(iterbatch).keys()
# print(f'{column_names=}')
# -- Compute max steps (for real experiments we should do this so that the number of tokens is the same across all training runs for fair comparison; todo: ask Sudharsan or ask online; for now just set streaming=False)
train_dataset = load_dataset(path, name, streaming=False, split="train", token=token).with_format(type="torch") # hack to get dataset size
eval_dataset = load_dataset(path, name, streaming=False, split="test", token=token).with_format(type="torch") # hack to get dataset size
print(f'{len(train_dataset)=}')
print(f'{len(eval_dataset)=}')
per_device_train_batch_size = batch_size
num_epochs = 1
max_steps = (len(train_dataset) // per_device_train_batch_size) * num_epochs
print(f'{max_steps=}')
# -- Get trainer
def collate_tokenize(data):
    text_batch = [f'informal statement {example["generated informal statement"]} formal statement {example["formal statement"]}' for example in data]
    tokenized = tokenizer(text_batch, padding='longest', max_length=128, truncation=True, return_tensors='pt')
    return tokenized
from transformers import Trainer, TrainingArguments
training_args = TrainingArguments(
    output_dir=Path('./results').expanduser(),                  # output directory
    max_steps=max_steps,                                        # total number of training steps
    per_device_train_batch_size=per_device_train_batch_size,   # batch size per device during training
    per_device_eval_batch_size=batch_size,                      # batch size for evaluation
    warmup_steps=500,                                           # number of warmup steps for the learning rate scheduler
    weight_decay=0.01,                                          # strength of weight decay
    logging_dir=Path('./logs').expanduser(),                    # directory for storing logs
    logging_steps=10,
    report_to='none',
)
trainer = Trainer(
    model=model,                    # the instantiated 🤗 Transformers model to be trained
    args=training_args,             # training arguments, defined above
    train_dataset=train_dataset,    # training dataset
    eval_dataset=eval_dataset,      # evaluation dataset
    data_collator=collate_tokenize,
)
trainer.train()
print('Done!\a')
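For reference, the same collate function works fine when I use it directly with a plain PyTorch DataLoader on the (non-streaming) dataset, which is why I suspect the problem is specific to the Trainer. A minimal sketch of that standalone sanity check (column names are the ones from my dataset above):

from torch.utils.data import DataLoader

# Standalone check: the same collate_tokenize works with a plain DataLoader
loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_tokenize)
batch = next(iter(loader))
print(batch['input_ids'].shape)  # e.g. torch.Size([8, longest_seq_len_in_batch])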
Error:
len(train_dataset)=14
len(eval_dataset)=13
max_steps=1
/usr/local/lib/python3.10/dist-packages/transformers/optimization.py:411: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
warnings.warn(
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-2-4403554fc52d> in <cell line: 63>()
61 data_collator = collate_tokenize,
62 )
---> 63 trainer.train()
64 print('Done!\a')
11 frames
/usr/local/lib/python3.10/dist-packages/datasets/formatting/formatting.py in _check_valid_index_key(key, size)
524 if isinstance(key, int):
525 if (key < 0 and key + size < 0) or (key >= size):
--> 526 raise IndexError(f"Invalid key: {key} is out of bounds for size {size}")
527 return
528 elif isinstance(key, slice):
IndexError: Invalid key: 12 is out of bounds for size 0
Why does this happen, and how do I fix it?