Using Chain and Parser together in langchain

The langchain docs include this example for configuring and invoking a PydanticOutputParser

from langchain.llms import OpenAI
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts import PromptTemplate
from pydantic import BaseModel, Field, validator

# Define your desired data structure.
class Joke(BaseModel):
    setup: str = Field(description="question to set up a joke")
    punchline: str = Field(description="answer to resolve the joke")
    
    # You can add custom validation logic easily with Pydantic.
    @validator('setup')
    def question_ends_with_question_mark(cls, field):
        if field[-1] != '?':
            raise ValueError("Badly formed question!")
        return field

# And a query intended to prompt a language model to populate the data structure.
joke_query = "Tell me a joke."

# Set up a parser + inject instructions into the prompt template.
parser = PydanticOutputParser(pydantic_object=Joke)

prompt = PromptTemplate(
    template="Answer the user query.\n{format_instructions}\n{query}\n",
    input_variables=["query"],
    partial_variables={"format_instructions": parser.get_format_instructions()}
)

However, the code for actually making the API call is a bit weird:

model_name = 'text-davinci-003'
temperature = 0.0
my_llm = OpenAI(model_name=model_name, temperature=temperature)

_input = prompt.format_prompt(query=joke_query)
output = my_llm(_input.to_string())

parser.parse(output)

This returns exactly what we want: Joke(setup='Why did the chicken cross the road?', punchline='To get to the other side!')
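(Note that the parse step itself makes no API call; conceptually it just strips whitespace, decodes the JSON, and validates it against the model. A rough stdlib sketch, with a dataclass standing in for the pydantic model:)

```python
import json
from dataclasses import dataclass

# Stand-in for the pydantic Joke model above.
@dataclass
class Joke:
    setup: str
    punchline: str

# The kind of raw completion the LLM returns.
raw_output = '\n{"setup": "Why did the chicken cross the road?", "punchline": "To get to the other side!"}'

# parser.parse is roughly: strip, JSON-decode, validate into the model.
joke = Joke(**json.loads(raw_output.strip()))
print(joke)
```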

However, it seems odd not to use Chains for this.

I can get kind of close, as follows:

chain = LLMChain(llm=my_llm, prompt=prompt)
chain.run(query=joke_query)

But this returns raw, unparsed text: '\n{"setup": "Why did the chicken cross the road?", "punchline": "To get to the other side!"}'

Is there a preferred method to get the Chain class to make full use of a Parser, and return the parsed object? I could subclass and extend LLMChain, but I'd be surprised if this functionality doesn't already exist.

Linoleum answered 2/4, 2023 at 5:41 Comment(1)
Totally agree, this doesn't feel very LangChainic. Furthermore, for some reason the PromptTemplate class takes an optional output parser argument... – Nathalie

You can use a TransformChain for this!

from langchain.chat_models import ChatOpenAI
from langchain.chains import LLMChain, SequentialChain, TransformChain


llm = ChatOpenAI(temperature=0.5)
llm_chain = LLMChain(
    prompt=prompt,
    llm=llm,
    output_key="json_string",
)

def parse_output(inputs: dict) -> dict:
    text = inputs["json_string"]
    return {"result": parser.parse(text)}

transform_chain = TransformChain(
    input_variables=["json_string"],
    output_variables=["result"],
    transform=parse_output
)

chain = SequentialChain(
    input_variables=["query"],
    output_variables=["result"],
    chains=[llm_chain, transform_chain],
)

chain.run(query="Tell me a joke.")
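(To see the wiring without any API calls, here is a stdlib-only sketch of what the SequentialChain does with these two links; the first step is a fake stand-in for the LLM call:)

```python
import json

def run_pipeline(inputs: dict, steps) -> dict:
    # SequentialChain, conceptually: each step reads from and adds to a
    # shared dict of named variables.
    data = dict(inputs)
    for step in steps:
        data.update(step(data))
    return data

# Stand-in for llm_chain: pretends the model returned a JSON string
# under the "json_string" output key.
fake_llm_step = lambda d: {
    "json_string": '{"setup": "Why?", "punchline": "Because."}'
}

# Stand-in for transform_chain: parses the string into a structure
# under the "result" output key.
parse_step = lambda d: {"result": json.loads(d["json_string"])}

out = run_pipeline({"query": "Tell me a joke."}, [fake_llm_step, parse_step])
print(out["result"])
```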
Lysippus answered 8/4, 2023 at 20:0 Comment(1)
Will this work with map_reduce load_qa_chain? – Merell

You should be able to use the parser to parse the output of the chain. No need to subclass:

output = chain.run(query=joke_query)
bad_joke = parser.parse(output)

Not positive on the syntax because I use langchainjs, but that should get you close.

Winze answered 26/4, 2023 at 3:4 Comment(0)
