When we look at HuggingFace model usage in langchain, there's a part where the author couldn't figure out how to stop the generation, https://github.com/hwchase17/langchain/blob/master/langchain/llms/huggingface_pipeline.py#L182:
class HuggingFacePipeline(LLM):
    ...
    def _call(
        ...
        if stop is not None:
            # This is a bit hacky, but I can't figure out a better way to enforce
            # stop tokens when making calls to huggingface_hub.
            text = enforce_stop_tokens(text, stop)
        return text
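For context, those stop words arrive through the stop argument when the wrapped LLM is invoked. A rough sketch of the call path (hedged: the exact constructor and call signatures vary across langchain versions):

from transformers import pipeline
from langchain.llms import HuggingFacePipeline

# Sketch only: wrap a transformers pipeline in langchain's HuggingFacePipeline.
hf_pipe = pipeline('text-generation', model='gpt2')
llm = HuggingFacePipeline(pipeline=hf_pipe)

# The stop list passed here is what eventually reaches enforce_stop_tokens.
print(llm("Hey Pizza! ", stop=["\n"]))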
What should I use to add the stop token to the end of the template?
If we look at https://github.com/hwchase17/langchain/blob/master/langchain/llms/utils.py, it's simply a regex split that splits the input string on a list of stop words, then takes the first partition of the re.split:
re.split("|".join(stop), text)[0]
Let's try to get a generation output from a Huggingface model, e.g.:
from transformers import pipeline, GPT2LMHeadModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')

generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
output = generator("Hey Pizza! ")
output
[out]:
[{'generated_text': 'Hey Pizza! 」\n\n「Hurry up, leave the place! 」\n\n「Oi! 」\n\nWhile eating pizza and then, Yuigahama came in contact with Ruriko in the middle of the'}]
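Note that the text-generation pipeline samples tokens for GPT-2 by default, so the output will differ from run to run. To make the example reproducible, transformers provides set_seed (this won't recover the exact output above, which came from an unseeded run):

from transformers import set_seed

set_seed(42)  # fix the sampling RNG so repeated runs give the same text
output = generator("Hey Pizza! ")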
If we apply the re.split:
import re

def enforce_stop_tokens(text, stop):
    """Cut off the text as soon as any stop words occur."""
    return re.split("|".join(stop), text)[0]

stop = ["up", "then"]
text = output[0]['generated_text']

re.split("|".join(stop), text)
[out]:
['Hey Pizza! 」\n\n「Hurry ',
', leave the place! 」\n\n「Oi! 」\n\nWhile eating pizza and ',
', Yuigahama came in contact with Ruriko in the middle of the']
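Taking the first partition, enforce_stop_tokens therefore returns just the text before the first stop word:

enforce_stop_tokens(text, stop)

[out]:

'Hey Pizza! 」\n\n「Hurry '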
But that isn't useful; I want to split at the point where the generation ends. What tokens do I use to "enforce_stop_tokens"?