I am using the GPT-Neo model from transformers
to generate text. Because the prompt I use starts with '{',
I would like to stop generation once the matching '}'
is produced.
I found that there is a StoppingCriteria
class in the source code, but without further instructions on how to use it. Has anyone found a way to stop the model's generation early? Thanks!
Here is what I've tried:
import torch
from transformers import StoppingCriteria, StoppingCriteriaList, AutoModelForCausalLM, AutoTokenizer

model_name = 'gpt2'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id, torch_dtype=torch.float32).eval()
class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords_ids: list):
        self.keywords = keywords_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Stop as soon as the most recently generated token is one of the keyword ids.
        if input_ids[0][-1].item() in self.keywords:
            return True
        return False
stop_words = ['}', ' }', '\n']
# Each of these stop words happens to map to a single GPT-2 BPE token, so keep just that token id.
stop_ids = [tokenizer.encode(w)[0] for w in stop_words]
stop_ids.append(tokenizer.eos_token_id)
stop_criteria = KeywordsStoppingCriteria(stop_ids)
inputs = tokenizer('some text:{', return_tensors='pt')
output = model.generate(
    inputs['input_ids'],
    stopping_criteria=StoppingCriteriaList([stop_criteria]),
)
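For what it's worth, this is how I check the result, continuing from the snippet above. It is a minimal sketch: whether the stop token itself is kept in the returned sequence is my assumption about how the criteria interact with generate.

# Decode the generated ids back to text to see where generation stopped.
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print(generated_text)
# If the stopping criterion works as intended, the text should end at the
# first '}' (or newline) produced after the prompt 'some text:{'.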