How do I train an encoder-decoder model for a translation task using Hugging Face Transformers?

I would like to train an encoder-decoder model, configured as below, for a translation task. Could someone guide me on how to set up a training pipeline for such a model? Any links or code snippets that would help me understand would be appreciated.

from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel

# Initializing a BERT bert-base-uncased style configuration
config_encoder = BertConfig()
config_decoder = BertConfig()

config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)

# Initializing a Bert2Bert model from the bert-base-uncased style configurations
model = EncoderDecoderModel(config=config)
Sonjasonnet answered 18/6, 2020 at 9:31
Did you find anything about the subject? - Ob

Encoder-decoder models are used in the same way as any other model in Transformers. The model accepts batches of tokenized text as vocabulary indices, so you need a tokenizer that suits your sequence-to-sequence task (for translation, one that covers both the source and the target language). When you feed the model the input (input_ids) and the desired output (labels, plus decoder_input_ids, which recent versions derive from labels automatically), it returns a loss value that you can optimize during training. Note that if the sentences in a batch have different lengths, you also need masking: pass an attention_mask for the padded inputs and set padded label positions to -100 so the loss ignores them. This is a minimal example based on the EncoderDecoderModel documentation:

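from transformers import EncoderDecoderModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    'bert-base-uncased', 'bert-base-uncased')

# Required so the model can build decoder_input_ids by shifting labels right
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id

input_ids = tokenizer("Hello, my dog is cute", return_tensors="pt").input_ids
# Toy target: reconstruct the input. For translation, tokenize the
# target-language sentence here instead.
labels = input_ids.clone()

# In recent versions, also passing decoder_input_ids=labels would disable the
# shift and train the decoder to copy its input, so only labels are passed
outputs = model(input_ids=input_ids, labels=labels, return_dict=True)
loss = outputs.loss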
loss = outputs.loss

If you do not want to write the training loop yourself, you can use the data collation (DataCollatorForSeq2Seq) and training (Seq2SeqTrainer) utilities from Transformers. You can follow the Seq2Seq example on GitHub.

Ernestinaernestine answered 24/2, 2021 at 8:35
Can you please update the link to the example on GitHub? It isn't working. - Encephalon
