How to early-stop an autoregressive model with a list of stop words?

I am using the GPT-Neo model from transformers to generate text. Because my prompt starts with '{', I would like to stop generation once the matching '}' is generated. I found a StoppingCriteria class in the source code, but there are no instructions on how to use it. Has anyone found a way to stop generation early? Thanks!
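Stopping at the *matching* '}', rather than at the first '}' token, comes down to tracking brace depth across the prompt and the generated text. Below is a minimal pure-Python sketch of that check. It assumes the newly generated tokens have already been decoded back to a string. `closes_first_open_brace` is a hypothetical helper name, not part of transformers, and the sketch does not special-case braces inside quoted strings.

```python
def closes_first_open_brace(prompt: str, generated: str) -> bool:
    """Return True once `generated` closes every '{' still open at the end of `prompt`.

    Hypothetical helper: braces inside quoted strings are NOT special-cased.
    """
    # Count how many braces are still open after the prompt.
    depth = 0
    for ch in prompt:
        if ch == "{":
            depth += 1
        elif ch == "}" and depth > 0:
            depth -= 1
    if depth == 0:
        return False  # nothing is open in the prompt, so there is nothing to close

    # Walk the generated text; stop when the depth returns to zero.
    for ch in generated:
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return True
    return False


print(closes_first_open_brace("some text: {", ' "a": 1'))          # False
print(closes_first_open_brace("some text: {", ' "a": {"b": 2}'))    # False (nested brace only)
print(closes_first_open_brace("some text: {", ' "a": {"b": 2} }'))  # True
```

Inside a `StoppingCriteria.__call__`, one option is to decode only the new tokens (for example `tokenizer.decode(input_ids[0, prompt_len:])`, where `prompt_len` is the prompt length you recorded yourself) and return this function's result. Decoding on every step is slower than comparing token ids, but it avoids having to know how the tokenizer splits braces.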

Here is what I've tried:

import torch
from transformers import StoppingCriteria, AutoModelForCausalLM, AutoTokenizer
model_name = 'gpt2'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id).eval()

class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords_ids:list):
        self.keywords = keywords_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if input_ids in self.keywords:
            return True
        return False

stop_words = ['}', ' }', '\n']
stop_ids = [tokenizer.encode(w) for w in stop_words]
stop_ids.append(tokenizer.eos_token_id)
stop_criteria = KeywordsStoppingCriteria(stop_ids)

model.generate(
    text_inputs='some text:{', 
    StoppingCriteria=stop_criteria
)

Asked by Ferrell on 1/10, 2021 at 9:30
Comments:
Can you post a minimal reproducible example of your current code? - Halbert
If I had an example answer to this question, I wouldn't have needed to post it in the first place :p. But I'll post a snippet of what I've tried. - Ferrell

I've adapted your code so that it works. Also make sure you are using a recent version of transformers; you may need to upgrade.

import torch
from transformers import StoppingCriteria, AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList

model_name = 'gpt2'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id).eval()

class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords_ids: list):
        self.keywords = keywords_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Stop as soon as the most recent token is one of the stop token ids.
        # Note: only the first sequence in the batch (index 0) is checked.
        if input_ids[0][-1] in self.keywords:
            return True
        return False


stop_words = ['}', ' }', '\n']
# Keep the first token id of each stop word (assumes each stop word is a single token).
stop_ids = [tokenizer.encode(w)[0] for w in stop_words]
stop_criteria = KeywordsStoppingCriteria(stop_ids)

inputs = tokenizer.encode('some text: {', add_special_tokens=False, return_tensors='pt')

output = model.generate(
    inputs,
    do_sample=True,
    # generate() takes the criteria through the `stopping_criteria` keyword, wrapped in a StoppingCriteriaList
    stopping_criteria=StoppingCriteriaList([stop_criteria]),
)
print(tokenizer.decode(output[0]))
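The criterion above has two limitations. `tokenizer.encode(w)[0]` keeps only the first token of each stop word, so a stop word that tokenizes into several ids would trigger on its first piece. And `input_ids[0][-1]` inspects only the first sequence in the batch. Below is a minimal pure-Python sketch of a tail-matching check that handles multi-token stop sequences. `ends_with_stop_sequence` is a hypothetical helper, and the ids in the example are arbitrary, not real GPT-2 ids.

```python
from typing import List, Sequence


def ends_with_stop_sequence(token_ids: Sequence[int], stop_sequences: List[List[int]]) -> bool:
    """Return True if `token_ids` ends with any of the id sequences in `stop_sequences`."""
    for stop in stop_sequences:
        n = len(stop)
        # Compare the last n generated ids with the full stop sequence.
        if n and len(token_ids) >= n and list(token_ids[-n:]) == list(stop):
            return True
    return False


# Arbitrary example ids: one single-token stop word and one two-token stop word.
stop_sequences = [[92], [7, 8]]
print(ends_with_stop_sequence([1, 2, 3, 92], stop_sequences))  # True
print(ends_with_stop_sequence([1, 7, 8], stop_sequences))      # True
print(ends_with_stop_sequence([1, 8, 7], stop_sequences))      # False
```

To use it inside `__call__`, you could build `stop_sequences` with `tokenizer.encode(w, add_special_tokens=False)` for each stop word. Then evaluate it per row with `row.tolist()` for each `row` in `input_ids`, which gives one result per sequence instead of checking only row 0. Recent transformers releases expect a criterion to return one boolean per sequence in the batch. They also document a `stop_strings` argument to `generate` (passed together with `tokenizer=`), which may make a hand-written criterion unnecessary. Check the documentation for your installed version.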
Answered by Johann on 25/4, 2022 at 17:22
Comment:
This is a life saver _/\_ - Thanasi
