Questions about training LLMs on large text datasets for text generation from scratch
Asked Answered
O

1

-1

I built a fully custom GPT in JAX (with Keras 3), using TensorFlow for the data pipeline.

I've trained the model on the Shakespeare dataset and got good results (so there is no problem with the model). Now I want to train a GPT with 15M parameters on the TinyStories dataset, which is pretty big.

Here is the code for loading the data:

def get_dataset_lists(ds_path:str):
    """Read a text corpus, split it into stories, shuffle them and tokenize each one.

    Relies on module-level names defined outside this function: `r` (used for
    `r.shuffle`, presumably the `random` module), `spm` (llama's SentencePiece
    encoder) and `tf` (TensorFlow).

    Args:
        ds_path: Path to a UTF-8 text file in which stories are separated by
            the literal marker "<|endoftext|>".

    Returns:
        A list with one list of token ids per story, each encoded with a
        leading BOS token and no EOS token.
    """
    # Use a context manager so the file handle is closed as soon as the text
    # has been read, instead of being left open until garbage collection.
    with open(ds_path, "r", encoding="utf-8") as f:
        text = f.read() # [...]
    stories = text.split("<|endoftext|>")
    r.shuffle(stories)  # in-place shuffle of the story order
    dataset:list = spm.Encode( # llama's sentence piece encoder
            tf.strings.strip(stories).numpy().tolist(), 
            add_bos=True,
            add_eos=False
        ) # [[SOS story], ..., [SOS story]]
    print("\tNumber of stories:", len(dataset))
    return dataset

def tf_dataload(
    dataset:list,
    batch_size:int,
    maxlen:int,
    shift:int,
    reshuffle_each_iteration:bool=True,
):
    """Turn tokenized stories into a batched (inputs, targets) `tf.data.Dataset`.

    The stories are concatenated into one token stream, cut into windows of
    `maxlen + 1` tokens that start every `shift` tokens, shuffled, batched,
    and finally split into inputs (all but the last token) and targets (all
    but the first token).

    Args:
        dataset: List of token-id lists, one per story.
        batch_size: Number of windows per batch; incomplete batches are dropped.
        maxlen: Sequence length of the inputs and targets.
        shift: Stride, in tokens, between the starts of consecutive windows.
        reshuffle_each_iteration: Forwarded to the window-level
            `Dataset.shuffle`. Callers already pass this keyword, but it was
            missing from the signature, so it is now an explicit parameter.

    Returns:
        A `tf.data.Dataset` yielding `(inputs, targets)` pairs, each of shape
        `(batch_size, maxlen)`.
    """
    import itertools

    # Flatten [[SOS story], ..., [SOS story]] into a single token stream.
    dataset = list(itertools.chain.from_iterable(dataset))
    num_tokens = len(dataset)
    print("\tNumber of tokens in the dataset is", num_tokens)
    unique_tok = set(dataset)
    print("\tNumber of unique tokens in the dataset is", len(unique_tok))
    # [SOS story ... SOS story]
    dataset = tf.data.Dataset.from_tensor_slices(dataset)
    dataset = dataset.window(maxlen+1, shift=shift, drop_remainder=True)
    # [[...], [...], [...], ...] shape(m, maxlen+1)
    # window() yields nested datasets; batching each one by its full length
    # turns it back into a single tensor.
    dataset = dataset.flat_map(lambda window: window.batch(maxlen+1))
    dataset = dataset.shuffle(10_000*batch_size, reshuffle_each_iteration=reshuffle_each_iteration)
    # [ [ [...], [...], [...], ...], ...] shape(m//B, B, maxlen+1)
    dataset = dataset.batch(batch_size, drop_remainder=True, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.shuffle(batch_size*100)
    # Inputs drop the last token, targets drop the first (next-token prediction).
    dataset = dataset.map(lambda window: (window[:, :-1], window[:, 1:]), num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)
    return dataset # (shape(m//B, B, maxlen) shape(m//B, B, maxlen))

def load_data(
    train_ds_path:str,
    val_ds_path:str,
    batch_size:int,
    maxlen:int,
    shift:int,
):
    """Build the training and validation pipelines from two text files.

    Args:
        train_ds_path: Path to the training text file.
        val_ds_path: Path to the validation text file.
        batch_size: Batch size forwarded to `tf_dataload`.
        maxlen: Sequence length forwarded to `tf_dataload`.
        shift: Window stride forwarded to `tf_dataload`.

    Returns:
        A dict with key "train" (the training dataset, repeated indefinitely)
        and key "val" (the validation dataset, not repeated).
    """
    def _build(path:str):
        # Tokenize the file, then window/shuffle/batch it.
        stories = get_dataset_lists(path)
        return tf_dataload(stories, batch_size, maxlen, shift, reshuffle_each_iteration=True)

    print("Training Dataset:")
    train_ds = _build(train_ds_path)

    print("Validation Dataset:")
    val_ds = _build(val_ds_path)

    print(f"\n{train_ds}\n{val_ds}")
    # Only the training split is repeated.
    return {"train": train_ds.repeat(), "val": val_ds}
  • I have some questions regarding the value of shift.
    First I set it equal to 1, but the training was very slow: even after 100,000 steps it hadn't converged, although the loss was still decreasing slowly. (I don't think there's a problem with the learning rate, as I plotted loss vs. learning rate, selected the maximum learning rate possible, and used cosine decay with warmup.)

enter image description here

So I looked into Karpathy's llama2 repo, where the shift is equal to maxlen. I set it equal to maxlen and trained for 100,000 steps, but the model is still learning very slowly and didn't reach a loss even close to what Karpathy got (I don't know what the problem is, as I've closely followed Karpathy's llama2 repo). What is shift generally equal to when pre-training an LLM on language modelling? Shouldn't it be 1, because the transformer model is not positionally invariant, so a shift other than 1 would affect model performance? But then the number of samples would be very large...?

  • And for how many steps should an LLM be trained, given the number of tokens in the dataset?

You may find the below helpful...

@dataclass
class GPTArgs:
    """GPT Configuration"""
    d_model:int = 288
    num_layers:int = 6
    num_heads:int = 6
    max_context_length:int = 256
    vocab_size:int = VOCAB_SIZE # 32K
    output_units:int = None # equal to vocab_size if None in model init  
    dropout_rate:float = 0.1

    def __post_init__(self):
        """Validate the values of this instance.

        These checks used to be class-body `assert`s, which ran once at class
        definition against the defaults only; a config built with other
        values was never checked.

        Raises:
            ValueError: If `d_model` is odd or not divisible by `num_heads`.
        """
        if self.d_model % 2 != 0:
            raise ValueError(f"d_model must be even, got {self.d_model}")
        if self.d_model % self.num_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by num_heads ({self.num_heads})"
            )

@dataclass
class TArgs:
    """Training configuration: learning-rate schedule, optimizer and training-loop settings."""
    # lr scheduler
    init_lr:float = 1e-7
    max_lr:float = 6.5e-4
    # NOTE: computed once in the class body from the default max_lr above;
    # passing a different max_lr to the constructor does not rescale min_lr.
    min_lr:float = 0.1*max_lr # The factor is usually 0.1 or 0.0
    num_steps:int = 100_000
    warmup_steps:int = 1000 # 1000, to make training more stable instead of 2000
    # NOTE: likewise bound to the default num_steps at class-definition time;
    # overriding num_steps on an instance leaves decay_steps at 100_000.
    decay_steps:int = num_steps

    # optimizer
    beta1:float = 0.9
    beta2:float = 0.95
    weight_decay:float = 1e-1
    clipvalue:float = 1e0
    # The field name keeps its original spelling ("accumalation"), since callers refer to it by that name.
    num_grad_accumalation_steps:int = 4
    # num_tok_per_update = batch_size * maxlen * gradient_accumulation = 128 * 256 * 4 = 131_072

    # training
    # '{epoch}' is left as a literal placeholder here; presumably it is filled in when a checkpoint is written.
    checkpoint:str = 'weights/GPTstories/Epoch{epoch}.weights.h5'
    train_ds_path:str = "TinyStoriesDataset/TinyStories-train.txt"
    val_ds_path:str = "TinyStoriesDataset/TinyStories-valid.txt"
    # NOTE: no type annotation, so these two are plain class attributes, not
    # dataclass fields: they cannot be passed to the constructor.
    steps_per_epoch = eval_freq = 2000
    eval_steps:int = 200
    batch_size:int = 128 
    patience:int = 10 # early stopping with restore best weights

Update 1:

I thought that the model wasn't getting the training samples uniformly, so I modified the data pipeline and also increased the number of steps to 200,000. But there were no significant improvements. The training was still very slow by the end, with the loss decreasing by only 0.01 every epoch (of 2000 steps). I got a loss of 1.67 on the validation set.

def pretokenize_and_save_dataset(dataset_path:str, num_shards:int, shard_dir:str):
    """Tokenize a text corpus once and save it as `.npy` shards.

    The corpus is split into stories on "<|endoftext|>", shuffled, encoded
    with `spm` (BOS added, no EOS), flattened into one token stream and split
    into `num_shards` roughly equal arrays.

    Relies on module-level names defined outside this function: `r`, `spm`,
    `tf`, `np`, `os`, `functools` and `operator`.

    Args:
        dataset_path: Path to the UTF-8 text file to tokenize.
        num_shards: Number of shard files to write.
        shard_dir: Existing directory in which the shards are written as
            "shard1.npy", "shard2.npy", ...

    Returns:
        The list of shard file paths, in shard order.
    """
    # Use a context manager so the file handle is closed right after reading
    # instead of being left open until garbage collection.
    with open(dataset_path, "r", encoding="utf-8") as f:
        text = f.read() # [...]
    stories = text.split("<|endoftext|>")
    r.shuffle(stories)
    dataset:list = spm.Encode(
            tf.strings.strip(stories).numpy().tolist(), 
            add_bos=True,
            add_eos=False
        ) # [[SOS story], ..., [SOS story]]
    print("Dataset:")
    print("\tNumber of stories:", len(dataset))

    # flatten
    dataset = functools.reduce(operator.iconcat, dataset, [])
    num_tokens = len(dataset)
    print("\tNumber of tokens in the dataset:", num_tokens)
    print("\tNumber of unique tokens in the dataset:", len(set(dataset)))
    
    # uint16 holds ids up to 65535, which covers the 32K vocabulary.
    dataset = np.asarray(dataset, dtype=np.uint16) # [SOS story ... SOS story]
    # Counts tokens equal to 1 as story starts -- assumes the BOS id is 1; TODO confirm.
    print("\tAvg length of story:", num_tokens/((dataset==1).sum()))

    # shard and save dataset
    sharded_datasets_list = np.array_split(dataset, num_shards) # [[SOS story...], [...], [...], ...]
    filenames = [os.path.join(shard_dir, f"shard{i+1}.npy") for i in range(num_shards)]
    
    for filename, sharded_ds in zip(filenames, sharded_datasets_list):
        with open(filename, "wb") as f:
            np.save(f, sharded_ds)
    return filenames

def load_data_as_tfds(
    dataset:np.ndarray,
    maxlen:int,
    shift:int,
    shuffle_buffer_size:int=10_000*128,
):
    """Turn one shard of token ids into a shuffled dataset of fixed-length windows.

    Args:
        dataset: 1-D array of integer token ids ([SOS story ... SOS story]).
        maxlen: Sequence length; each window holds `maxlen + 1` tokens so it
            can later be split into inputs and targets.
        shift: Stride, in tokens, between the starts of consecutive windows.
        shuffle_buffer_size: Size of the window-level shuffle buffer. The
            default is the previously hard-coded value `10_000*128`.

    Returns:
        An unbatched `tf.data.Dataset` of int32 windows of shape `(maxlen + 1,)`.
    """
    # Hand TensorFlow a compact int32 array instead of `dataset.tolist()`.
    # A Python list of ints becomes an int32 tensor as well, but only after
    # materialising one boxed Python int per token, which is very costly for
    # a shard with millions of tokens.
    tokens = np.asarray(dataset, dtype=np.int32)
    # [SOS story ... SOS story]
    windows = tf.data.Dataset.from_tensor_slices(tokens)
    windows = windows.window(maxlen+1, shift=shift, drop_remainder=True)
    # [[...], [...], [...], ...] shape(m, maxlen+1)
    # window() yields nested datasets; batching each by its full length turns
    # it back into a single tensor.
    windows = windows.flat_map(lambda window: window.batch(maxlen+1))
    windows = windows.shuffle(shuffle_buffer_size)
    return windows

def batch_tfds(
        dataset:tf.data.Dataset,
        batch_size:int,
):
    """Batch windows, shuffle the batches and split them into (inputs, targets).

    Args:
        dataset: Dataset of token windows, each with one more token than the
            model's sequence length.
        batch_size: Number of windows per batch; incomplete batches are dropped.

    Returns:
        An endlessly repeating, prefetched dataset of `(inputs, targets)`
        pairs, where targets are the inputs shifted left by one token.
    """
    autotune = tf.data.AUTOTUNE

    def _split_inputs_targets(window):
        # Inputs drop the last token, targets drop the first.
        return window[:, :-1], window[:, 1:]

    batched = dataset.batch(batch_size, drop_remainder=True, num_parallel_calls=autotune)
    return (
        batched
        .shuffle(batch_size*1000)
        .map(_split_inputs_targets, num_parallel_calls=autotune)
        .repeat()
        .prefetch(autotune)
    )

def load_data(
    dataset_path:str,
    batch_size:int,
    maxlen:int,
    shift:int,
    num_shards:int,
    shard_dir:str
):  
    """Build training and validation pipelines from pre-tokenized shards.

    If `shard_dir` already contains files, its `*.npy` shards are reused;
    otherwise the corpus at `dataset_path` is tokenized and sharded first.
    The shard list is shuffled and split roughly 96.5% / 3.5% into training
    and validation shards, and each split is sampled uniformly across its
    shards.

    Args:
        dataset_path: Text corpus to tokenize when no shards exist yet.
        batch_size: Number of windows per batch.
        maxlen: Sequence length of the inputs and targets.
        shift: Stride, in tokens, between consecutive windows.
        num_shards: Total number of shards; also used to size the split.
        shard_dir: Directory that holds (or will hold) the shard files.

    Returns:
        A dict with keys "train" and "val", each a repeating dataset of
        `(inputs, targets)` batches.

    Raises:
        ValueError: If `num_shards` is too small to leave at least one shard
            for training and one for validation.
    """
    if os.path.exists(shard_dir) and os.listdir(shard_dir):
        filenames = glob.glob(os.path.join(shard_dir, "*.npy"))
    else:
        # exist_ok=True: this branch is also reached when the directory exists
        # but is empty, where a plain makedirs() raises FileExistsError.
        os.makedirs(shard_dir, exist_ok=True)
        filenames = pretokenize_and_save_dataset(dataset_path, num_shards=num_shards, shard_dir=shard_dir)
    # NOTE(review): unless `r` is seeded, a restart reshuffles the shard list
    # and so changes which shards are train and which are validation -- confirm.
    r.shuffle(filenames)

    num_train_shards = round(0.9651*num_shards)
    num_val_shards = num_shards-num_train_shards
    # Fail early with a clear message instead of a ZeroDivisionError in the
    # uniform sampling weights below.
    if num_train_shards < 1 or num_val_shards < 1:
        raise ValueError(
            f"num_shards={num_shards} gives {num_train_shards} training and "
            f"{num_val_shards} validation shards; both must be at least 1"
        )

    def _mix_shards(shard_files:list):
        # One windowed dataset per shard, sampled with equal probability.
        per_shard = [
            load_data_as_tfds(np.load(filename), maxlen=maxlen, shift=shift)
            for filename in shard_files
        ]
        uniform = [1/len(per_shard)]*len(per_shard)
        mixed = tf.data.Dataset.sample_from_datasets(per_shard, weights=uniform)
        # [ [ [...], [...], [...], ...], ...] shape(m//B, B, maxlen+1)
        return batch_tfds(mixed, batch_size=batch_size)

    print("Training Dataset:")
    print(f"\tNumber of files taken for training: {num_train_shards}/{num_shards}")
    train_ds = _mix_shards(filenames[:num_train_shards])

    print("Validation Dataset:")
    print(f"\tNumber of files taken for validation: {num_val_shards}/{num_shards}")
    val_ds = _mix_shards(filenames[num_train_shards:])

    print(f"\n{train_ds}\n{val_ds}")
    datasets = {"train": train_ds, "val":val_ds}
    return datasets

Update 2

  • Gradient accumulation resulted in slow training, so I changed it to 1 and trained for 200,000 epochs; this gave a validation loss of 1.60.
Osanna answered 22/2 at 8:33 Comment(0)
O
1

I replaced Keras's gradient accumulation argument in AdamW with a custom implementation like the one in Karpathy's repo, and now the loss is decreasing faster. (I will report further details after training fully.)

Osanna answered 9/3 at 12:42 Comment(0)

© 2022 - 2024 — McMap. All rights reserved.