How to implement `stopping_criteria` parameter in transformers library?

I am using the Python Hugging Face transformers library for a text-generation model. I need to know how to implement the stopping_criteria parameter in the generator() call I am using.

I found the stopping_criteria parameter in this documentation: https://huggingface.co/transformers/main_classes/pipelines.html#transformers.TextGenerationPipeline

The problem is, I just don't know how to implement it.

My Code:

from transformers import pipeline
generator = pipeline('text-generation', model='EleutherAI/gpt-neo-125M')
stl = StoppingCriteria(['###'])
res = generator(prompt, do_sample=True, stopping_criteria=stl)
Ramulose answered 6/7, 2021 at 21:41 Comment(2)
I was also curious about this. I'd like to be able to provide a particular stopping token (other than the EOS token). Did you work this out? – Concerned
How do we resolve "StoppingCriteria() takes no arguments"? It says it's an abstract class; so do we need to define a new class, and if so, which methods do we override? – Dormant

These two approaches worked for me. In both, your_condition should evaluate to True when you want generation to stop.

import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class CustomStoppingCriteria(StoppingCriteria):
    def __init__(self):
        pass

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Return True to stop generation, False to keep going.
        return your_condition

stopping_criteria = StoppingCriteriaList([CustomStoppingCriteria()])

OR

def custom_stopping_criteria(input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
    # Any callable with this signature can be placed in a StoppingCriteriaList.
    return your_condition

stopping_criteria = StoppingCriteriaList([custom_stopping_criteria])
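
As a concrete illustration of the original question (stopping at "###"), here is a minimal sketch of a criterion that decodes the running output and stops once it ends with the stop string. StopOnSubstring and the example prompt are made-up names for illustration, and the sketch assumes the pipeline forwards stopping_criteria through to model.generate(), which may vary across transformers versions:

import torch
from transformers import pipeline, StoppingCriteria, StoppingCriteriaList

generator = pipeline('text-generation', model='EleutherAI/gpt-neo-125M')

class StopOnSubstring(StoppingCriteria):
    # Hypothetical helper: stop once the decoded text ends with stop_string.
    def __init__(self, tokenizer, stop_string):
        self.tokenizer = tokenizer
        self.stop_string = stop_string

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Decode the first (and here only) sequence generated so far.
        text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
        return text.endswith(self.stop_string)

stopping_criteria = StoppingCriteriaList([StopOnSubstring(generator.tokenizer, '###')])

prompt = "Q: What is the capital of France?\nA:"  # illustrative prompt
res = generator(prompt, do_sample=True, stopping_criteria=stopping_criteria)
print(res[0]['generated_text'])

Note that the stop string itself still appears at the end of the generated text; if you don't want it, strip it from res[0]['generated_text'] after generation.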
Snooperscope answered 16/3, 2023 at 4:35 Comment(0)
