How does `enforce_stop_tokens` work in LangChain with Huggingface models?

Looking at how LangChain uses Hugging Face models, there is a part where the author admits to not knowing a better way to stop the generation, https://github.com/hwchase17/langchain/blob/master/langchain/llms/huggingface_pipeline.py#L182:

class HuggingFacePipeline(LLM):
        ...
    def _call(
        ...
        if stop is not None:
            # This is a bit hacky, but I can't figure out a better way to enforce
            # stop tokens when making calls to huggingface_hub.
            text = enforce_stop_tokens(text, stop)
        return text

What should I use to add the stop token to the end of the template?


If we look at https://github.com/hwchase17/langchain/blob/master/langchain/llms/utils.py, it's simply a regex split that splits the input string on a list of stop words and then takes the first partition returned by re.split:

re.split("|".join(stop), text)[0]
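One consequence of joining the stop strings with "|" is that they are interpreted as a regular expression, so a stop string containing metacharacters such as "." or "?" is not matched literally. A minimal sketch of the difference, where enforce_stop_tokens_escaped is a hypothetical variant (not part of LangChain) that escapes each stop string first:

```python
import re

def enforce_stop_tokens(text, stop):
    """Same logic as the one-liner above: split on the joined pattern, keep the first part."""
    return re.split("|".join(stop), text)[0]

def enforce_stop_tokens_escaped(text, stop):
    """Hypothetical variant: escape each stop string so it is matched literally."""
    pattern = "|".join(re.escape(s) for s in stop)
    return re.split(pattern, text)[0]

text = "Hello world. More text"
unescaped = enforce_stop_tokens(text, ["."])        # "." is a regex wildcard here
escaped = enforce_stop_tokens_escaped(text, ["."])  # "." is a literal period here
print(repr(unescaped))
print(repr(escaped))
```

With the unescaped pattern every character counts as a stop match, so nothing is left before the first split point; the escaped variant keeps the text up to the first literal period.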

Let's try to get a generation output from a Hugging Face model, e.g.

from transformers import pipeline
from transformers import GPT2LMHeadModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')

generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
output = generator("Hey Pizza! ")
output

[out]:

[{'generated_text': 'Hey Pizza! 」\n\n「Hurry up, leave the place! 」\n\n「Oi! 」\n\nWhile eating pizza and then, Yuigahama came in contact with Ruriko in the middle of the'}]

If we apply the re.split:

import re
def enforce_stop_tokens(text, stop):
    """Cut off the text as soon as any stop words occur."""
    return re.split("|".join(stop), text)[0]

stop = ["up", "then"]
text = output[0]['generated_text']

re.split("|".join(stop), text)

[out]:

['Hey Pizza! 」\n\n「Hurry ',
 ', leave the place! 」\n\n「Oi! 」\n\nWhile eating pizza and ',
 ', Yuigahama came in contact with Ruriko in the middle of the']

But that isn't useful; I want to split at the point where the generation ends. What tokens should I use with enforce_stop_tokens?

Spotlight answered 14/6, 2023 at 16:4 Comment(0)

You could do this by setting eos_token_id to your stop term(s); in my testing it seemed to work with a list. See below: the regex cuts the text off before the stop word, while eos_token_id cuts it off just after the stop word ("Once upon a" vs. "Once upon a time").


from transformers import GPT2LMHeadModel, GPT2Tokenizer
import re

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')

# Define your custom stop terms
stop_terms = ["right", "time"]

# Ensure the stop terms are in the tokenizer's vocabulary
for term in stop_terms:
    if term not in tokenizer.get_vocab():
        tokenizer.add_tokens([term])
        model.resize_token_embeddings(len(tokenizer))

def enforce_stop_tokens(text, stop):
    """Cut off the text as soon as any stop words occur."""
    return re.split("|".join(stop), text)[0]

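# Get the token IDs for your custom stop terms.
# Only the first sub-token of each term is kept, so this assumes every
# stop term encodes to a single token.
eos_token_ids_custom = [tokenizer.encode(term, add_prefix_space=True)[0] for term in stop_terms]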

# Generate text
input_text = "Once upon "
input_ids = tokenizer.encode(input_text, return_tensors='pt')
output_ids = model.generate(input_ids, eos_token_id=eos_token_ids_custom, max_length=50)

# Decode the output IDs to text
generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

print(generated_text) # Once upon a time

print("ENFORCE STOP TOKENS")

truncated_text = enforce_stop_tokens(generated_text, stop_terms)

print(truncated_text) # Once upon a 
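
The "Once upon a time" vs. "Once upon a" difference can be reproduced on a plain string, without loading a model. The sketch below assumes the generated text is already available as a string; truncate_after_stop is a hypothetical helper (not a transformers or LangChain function) that only mimics, at the string level, generation halting right after a stop term is emitted.

```python
import re

def enforce_stop_tokens(text, stop):
    """Cut the text just before the first stop word (regex split, as in LangChain)."""
    return re.split("|".join(stop), text)[0]

def truncate_after_stop(text, stop):
    """Hypothetical helper: keep the text up to and including the first stop word."""
    match = re.search("|".join(re.escape(s) for s in stop), text)
    return text if match is None else text[: match.end()]

stop_terms = ["right", "time"]
generated_text = "Once upon a time, there was a pizza"  # stand-in for model output

before = enforce_stop_tokens(generated_text, stop_terms)
after = truncate_after_stop(generated_text, stop_terms)
print(repr(before))
print(repr(after))
```

The first result drops the stop word itself, the second keeps it, which is the same distinction as between the regex post-processing and stopping generation via eos_token_id.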

Haug answered 8/8, 2023 at 23:6 Comment(2)
Wouldn't that always end the generation in 1 sentence? - Spotlight
@Spotlight I don't think so. In my colab (colab.research.google.com/drive/…), with the input text "I am" and no stop token enforcement, the output is: "I am not a fan of the idea of a "big-budget" movie. I think it's a waste of money. I think it's a waste of time...". With the code above and the stop words ["money", "time"], it ends on the second sentence. See the hf documentation. - Haug
