Root-Cause
This is a warning about using the API in an outdated manner (soon to be unsupported). As of now, the library fixes this on its own, hence only a warning and not a breaking error.
See these lines in the source code.
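For illustration, here is a minimal sketch of the pattern that triggers the warning (the exact warning text may vary by transformers version); the remedy below avoids it:
from transformers import AutoTokenizer, BartForConditionalGeneration
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
inputs = tokenizer(["Some text to summarize."], return_tensors="pt")
model.config.max_new_tokens = 10  # deprecated: generation parameter set on the model config
model.generate(inputs["input_ids"])  # emits the UserWarning, then still generates correctly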
Remedy
The transformers library encourages the use of dedicated config objects. In this case, we need to pass a GenerationConfig object to generate(), rather than setting the attributes on the model config directly.
I will first share a clean, simple example:
from transformers import AutoTokenizer, BartForConditionalGeneration
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
ARTICLE_TO_SUMMARIZE = (
    "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
    "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
    "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
)
inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt")
# change config and generate summary
from transformers.generation import GenerationConfig
# deprecated pattern (this is what triggers the warning):
# model.config.max_new_tokens = 10
# model.config.min_length = 1
# recommended pattern: build a GenerationConfig and pass it to generate()
gen_cfg = GenerationConfig.from_model_config(model.config)
gen_cfg.max_new_tokens = 10
gen_cfg.min_length = 1
summary_ids = model.generate(inputs["input_ids"], generation_config=gen_cfg)
tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
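As a side note (a sketch assuming write access to a local directory), the tuned generation settings can also be persisted with the config's own save/load methods instead of being rebuilt in code each time:
gen_cfg.save_pretrained("./bart-cnn-gen-config")  # writes generation_config.json
gen_cfg = GenerationConfig.from_pretrained("./bart-cnn-gen-config")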
If you manipulate the config attributes directly and pass no GenerationConfig, you get the warning. If you pass a GenerationConfig, you are all good. This example is reproducible as a Colab notebook here.
Now, to the original question. In general, changing the architecture config of a pretrained model is not recommended, because the pretrained weights may no longer be compatible; it is sometimes possible with extra effort. However, certain config changes can be made upon initialization:
model = BartForConditionalGeneration.from_pretrained(
    "facebook/bart-large-cnn",
    attention_dropout=0.123,  # config override applied while loading the pretrained weights
)
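A quick sanity check (the full script below also asserts this) confirms that the keyword override landed on the loaded config:
assert model.config.attention_dropout == 0.123  # the kwarg overrides whatever the hub checkpoint's config specifies
print(model.config.attention_dropout)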
Here is the fully working code, corrected for reproducibility (see also this notebook):
from transformers import AutoTokenizer, BartForConditionalGeneration
from transformers.generation import GenerationConfig
from transformers import Trainer, TrainingArguments
from transformers.models.bart.modeling_bart import shift_tokens_right
from transformers import DataCollatorForSeq2Seq
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn", attention_dropout=0.123)
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
seq2seq_data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
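# `dataset` below comes from the original question and is not shown there; as a hedged,
# self-contained stand-in, build a toy DatasetDict with the expected "text"/"summary" columns
# (the rows are made up purely for illustration):
from datasets import Dataset, DatasetDict
_toy_split = Dataset.from_dict({
    "text": ["PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions."],
    "summary": ["PG&E scheduled blackouts because of forecast high winds and dry conditions."],
})
dataset = DatasetDict({"train": _toy_split, "test": _toy_split})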
def get_features(batch):
    input_encodings = tokenizer(batch["text"], max_length=1024, truncation=True)
    with tokenizer.as_target_tokenizer():
        target_encodings = tokenizer(batch["summary"], max_length=256, truncation=True)
    return {"input_ids": input_encodings["input_ids"],
            "attention_mask": input_encodings["attention_mask"],
            "labels": target_encodings["input_ids"]}
dataset_ftrs = dataset.map(get_features, batched=True)
columns = ['input_ids', 'attention_mask', 'labels']
dataset_ftrs.set_format(type='torch', columns=columns)
model.config.output_attentions = True
model.config.output_hidden_states = True
training_args = TrainingArguments(
    output_dir='./models/bart-summarizer',
    num_train_epochs=1,
    warmup_steps=500,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    push_to_hub=False,
    evaluation_strategy='steps',
    eval_steps=500,
    save_steps=1_000_000,  # effectively disables checkpoint saving for this short run
    gradient_accumulation_steps=16,
)
trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    data_collator=seq2seq_data_collator,
    train_dataset=dataset_ftrs["train"],
    eval_dataset=dataset_ftrs["test"],
)
assert model.config.attention_dropout == 0.123  # the override from from_pretrained is still in place
# trainer.train()  # uncomment to launch fine-tuning
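After (or instead of) training, generation should again go through a GenerationConfig rather than model.config, as in the first example; a minimal sketch reusing the objects defined above:
gen_cfg = GenerationConfig.from_model_config(model.config)
gen_cfg.max_new_tokens = 60
gen_cfg.min_length = 1
sample = tokenizer(["PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions."],
                   max_length=1024, truncation=True, return_tensors="pt")
print(tokenizer.batch_decode(model.generate(sample["input_ids"], generation_config=gen_cfg),
                             skip_special_tokens=True)[0])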