LangChain Adapter Tutorial
from __future__ import annotations

import os
from getpass import getpass

# Ask for the key only when it is not already set in the environment, and read
# it with getpass() so the secret is not echoed back while it is typed.
if not os.environ.get("OPENROUTER_API_KEY"):
    os.environ["OPENROUTER_API_KEY"] = getpass("OPENROUTER_API_KEY: ")
import math
import random
import re
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from gepa import optimize
from gepa.adapters.langchain_adapter import (
LangChainAdapter,
last_message_text,
make_reflection_lm,
)
from langchain.chat_models import init_chat_model
# Seed system prompt: used as the "system_prompt" of the starting candidate for
# both the baseline evaluation and the optimizer run. It states the task (add
# adjacent pairs, multiply the sums), gives two worked examples, and asks for
# the final answer inside <answer> tags, which is what evaluate_response()
# parses.
SEED_SYSTEM_PROMPT = """Add adjacent pairs of numbers, then multiply the results.
If only one pair numbers, just add the numbers
Pairs Example: 1,2
- Add pairs: 1+2=3
- multiply: Assume 1, 3*1=3
- Answer: 3
Example: 3, 5, 2, 4
- Add pairs: 3+5=8, 2+4=6
- Multiply: 8*6=48
- Answer: 48
Now solve the problem below. Put your final answer in <answer> tags. Be very concise"""
Generate the data¶
A synthetic math task: add adjacent pairs of numbers, then multiply the resulting sums. The environment is copied from Jonathan Whitaker (https://www.youtube.com/watch?v=yId2PE5Qmqo).
def generate_problem(rng: random.Random, difficulty: str) -> dict:
    """Build one "add adjacent pairs, then multiply the sums" problem.

    Args:
        rng: Random source. Every draw comes from it, so a generator seeded
            the same way reproduces the same problem.
        difficulty: "easy" yields 4 numbers and "medium" yields 6; any other
            value falls back to the hard setting of 12 numbers.

    Returns:
        A dict with three keys:
            "input": the problem text, formatted as "Numbers: a, b, c, ...".
            "answer": the product of the pair sums, as a string.
            "additional_context": {"nums": the drawn numbers,
                "sums": the sum of each adjacent, non-overlapping pair}.
    """
    # Every count is even, so the numbers always split cleanly into pairs.
    counts = {"easy": 4, "medium": 6}
    n = counts.get(difficulty, 12)

    # Per number: one rng.random() draw picks the range (70% single digit,
    # otherwise 10-19), then one randint() draw picks the value. The order of
    # these draws determines the sequence produced for a given seed.
    nums = [rng.randint(1, 9) if rng.random() < 0.7 else rng.randint(10, 19) for _ in range(n)]

    # Even-indexed numbers are paired with the odd-indexed number that follows.
    sums = [first + second for first, second in zip(nums[::2], nums[1::2])]
    answer = math.prod(sums)

    return {
        "input": f"Numbers: {', '.join(map(str, nums))}",
        "answer": str(answer),
        "additional_context": {"nums": nums, "sums": sums},
    }
def generate_dataset(num_examples: int, difficulty: str, seed: int) -> list[dict]:
    """Generate a list of problems that all share one difficulty.

    Args:
        num_examples: How many problems to generate.
        difficulty: Forwarded unchanged to generate_problem().
        seed: Seed for the single random.Random instance shared by all
            problems, so the same arguments give the same dataset.

    Returns:
        The generated problem dicts, in generation order.
    """
    seeded_rng = random.Random(seed)
    problems: list[dict] = []
    for _ in range(num_examples):
        problems.append(generate_problem(seeded_rng, difficulty))
    return problems
# One reproducible pool of hard problems, cut into train / val / test splits.
total, train_size, val_size = 200, 100, 50
all_data = generate_dataset(num_examples=total, difficulty="hard", seed=42)

# Consecutive, non-overlapping slices: train first, then val, and the test
# split is whatever remains after those two.
val_end = train_size + val_size
train_set = all_data[:train_size]
val_set = all_data[train_size:val_end]
test_set = all_data[val_end:]
print(f"Train: {len(train_set)}, Val: {len(val_set)}, Test: {len(test_set)}")
Train: 100, Val: 50, Test: 50
Load the chat models¶
This tutorial uses OpenRouter models, but feel free to use any model you have access to via LangChain!
# Task model: the chat model that rollout() invokes with the candidate system
# prompt and the problem text.
task_llm = init_chat_model(
"openrouter:openai/gpt-4.1-nano",
)
# Reflection model: wrapped by make_reflection_lm() and handed to optimize()
# as reflection_lm.
# NOTE(review): the "reasoning" keyword is passed through to the model
# integration; confirm the provider accepts this shape.
reflection_llm = init_chat_model("openrouter:openai/gpt-5-mini", reasoning={"effort": "medium"})
# Smoke test: send a trivial prompt to the task model and print its reply.
resp = task_llm.invoke("hi")
print(resp.content)
Hello! How can I assist you today?
# Smoke test: send a trivial prompt to the reflection model and print its reply.
resp = reflection_llm.invoke("hi")
print(resp.content)
Hi — how can I help you today?
Create rollout_fn¶
The rollout_fn takes a candidate from the optimizer and an example from the dataset, runs the LangChain model/graph, and returns a state object. For a simple single-turn chat model invocation, the state object is just the updated messages list.
def rollout(candidate: dict[str, str], example: dict, llm: BaseChatModel) -> dict:
    """Run one single-turn chat for an example under the candidate's prompt.

    Args:
        candidate: Candidate being evaluated; its "system_prompt" entry
            becomes the system message.
        example: Dataset example; its "input" entry becomes the human message.
        llm: Chat model that is invoked with the two messages.

    Returns:
        {"messages": [system message, human message, AI reply]}.
    """
    system_turn = SystemMessage(content=candidate["system_prompt"])
    user_turn = HumanMessage(content=example["input"])
    conversation = [system_turn, user_turn]

    reply = llm.invoke(conversation)
    if isinstance(reply, AIMessage):
        ai_reply = reply
    else:
        # Coerce anything else into an AIMessage, using its .content when it
        # has one and its string form otherwise.
        ai_reply = AIMessage(content=getattr(reply, "content", str(reply)))

    return {"messages": [*conversation, ai_reply]}
Create eval_fn¶
The eval_fn takes the dataset example and the state object (dict) returned by the rollout, and returns a score plus textual feedback for the optimizer.
# Compiled once at import time. Captures the single whitespace-free token found
# between the <answer> tags.
_ANSWER_RE = re.compile(r"<answer>\s*(\S+)\s*</answer>")


def evaluate_response(data: dict, state: dict) -> tuple[float, str]:
    """Score one rollout against the expected answer and explain any miss.

    Args:
        data: Dataset example. Reads "answer" plus the "nums" and "sums"
            lists stored under "additional_context".
        state: Rollout state; passed to last_message_text() to obtain the
            text of the model's reply.

    Returns:
        A (score, feedback) tuple. The score is 1.0 when the token inside the
        <answer> tags equals the expected answer string exactly, otherwise
        0.0. The feedback is "Correct.", a note that the tags were missing, or
        the worked steps leading to the correct answer.
    """
    correct_answer = data["answer"]
    context = data["additional_context"]
    nums, sums = context["nums"], context["sums"]

    response = last_message_text(state)
    match = _ANSWER_RE.search(response)
    if match is None:
        return 0.0, (f"Missing <answer> tags. Could not parse answer. The correct answer is {correct_answer}.")

    # \S+ cannot capture whitespace, so the token needs no further stripping.
    parsed = match.group(1)
    if parsed == correct_answer.strip():
        return 1.0, "Correct."

    # zip() lines each pair up with its sum without indexing past the end of
    # the list, so an odd-length "nums" cannot raise IndexError here.
    pairs_str = ", ".join(
        f"{first}+{second}={pair_sum}" for first, second, pair_sum in zip(nums[::2], nums[1::2], sums)
    )
    return 0.0, (
        f"Wrong answer: you said {parsed}, correct is {correct_answer}. "
        f"Steps: add pairs [{pairs_str}], then multiply sums to get {correct_answer}."
    )
Create the Adapter¶
The adapter follows the standard GEPA adapter interface: it exposes the evaluate() and make_reflective_dataset() functions to the GEPA optimizer.
# Wrap the reflection chat model; the result is passed to optimize() as
# reflection_lm.
reflection_lm = make_reflection_lm(reflection_llm)
# The lambda binds task_llm, so the adapter's rollout_fn only needs
# (candidate, example).
# NOTE(review): num_threads presumably sets how many rollouts run
# concurrently; confirm against the LangChainAdapter documentation.
adapter = LangChainAdapter(
rollout_fn=lambda candidate, example: rollout(candidate, example, task_llm),
eval_fn=evaluate_response,
num_threads=2,
)
Understand your starting point/baseline¶
With the seed prompt, how do we perform on our test dataset?
# Evaluate the unmodified seed prompt on the held-out test split.
# capture_traces=True because the trajectories are inspected further down.
print("\nBaseline evaluation on test set...")
baseline_batch = adapter.evaluate(
batch=test_set,
candidate={"system_prompt": SEED_SYSTEM_PROMPT},
capture_traces=True,
)
# n is reused later when the optimized prompt is scored on the same split.
n = len(test_set)
# evaluate_response() only returns 1.0 or 0.0, so counting exact 1.0 scores
# counts the correct answers.
baseline_correct = sum(1 for s in baseline_batch.scores if s == 1.0)
baseline_acc = baseline_correct / n * 100
print(f"\nBaseline: {baseline_correct}/{n} ({baseline_acc:.1f}%)")
/Users/bkuchars/Documents/gepa/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html from .autonotebook import tqdm as notebook_tqdm
Baseline evaluation on test set...
Baseline: 11/50 (22.0%)
adapter.evaluate() returns a gepa.core.adapter.EvaluationBatch, which lets you inspect the scores, outputs, and trajectories of all items in the batch.
# Index of the test example inspected in this and the following cells.
EXAMPLE_ID = 1
# Score assigned to that example under the seed prompt.
baseline_batch.scores[EXAMPLE_ID]
0.0
# Content of the last message in the rollout state for that example, i.e. the
# model's reply.
baseline_batch.outputs[EXAMPLE_ID]["state"]["messages"][-1].content
'Pairs: 4+18=22, 15+14=29, 9+4=13, 8+1=9, 2+19=21, 3+18=21\n\nMultiply: 22*29*13*9*21*21\n\nCalculations:\n22*29=638\n638*13=8294\n8294*9=74646\n74646*21=15656526\n15656526*21=328687506\n\n<answer>328687506</answer>'
# List the fields recorded on a trajectory (shown for the first item), then
# print the evaluator feedback for the inspected example.
print(baseline_batch.trajectories[0].keys())
print(baseline_batch.trajectories[EXAMPLE_ID]["feedback"])
dict_keys(['data', 'state', 'score', 'feedback']) Wrong answer: you said 328687506, correct is 32918886. Steps: add pairs [4+18=22, 15+14=29, 9+4=13, 8+1=9, 2+19=21, 3+18=21], then multiply sums to get 32918886.
Run the Optimization¶
# Run the GEPA optimizer starting from the seed prompt. Rollouts and scoring go
# through the adapter; reflection_lm is the wrapped reflection model.
# NOTE(review): the exact meaning of max_metric_calls,
# reflection_minibatch_size, candidate_selection_strategy and use_merge is
# defined by gepa.optimize; check its documentation before tuning them.
result = optimize(
seed_candidate={"system_prompt": SEED_SYSTEM_PROMPT},
trainset=train_set,
valset=val_set,
adapter=adapter,
reflection_lm=reflection_lm,
max_metric_calls=500,
reflection_minibatch_size=3,
candidate_selection_strategy="pareto",
use_merge=True,
display_progress_bar=True,
seed=0,
)
# Report the aggregate validation score of the best candidate, then print the
# seed prompt followed by the optimized prompt for comparison.
print(f"\nBest val score: {result.val_aggregate_scores[result.best_idx]}")
print("\nStarting system prompt:")
print("=" * 80)
print(SEED_SYSTEM_PROMPT)
print("=" * 80)
print("\nOptimized system prompt:")
print("=" * 80)
print(result.best_candidate["system_prompt"])
print("=" * 80)
Best val score: 0.7 Starting system prompt: ================================================================================ Add adjacent pairs of numbers, then multiply the results. If only one pair numbers, just add the numbers Pairs Example: 1,2 - Add pairs: 1+2=3 - multiply: Assume 1, 3*1=3 - Answer: 3 Example: 3, 5, 2, 4 - Add pairs: 3+5=8, 2+4=6 - Multiply: 8*6=48 - Answer: 48 Now solve the problem below. Put your final answer in <answer> tags. Be very concise ================================================================================ Optimized system prompt: ================================================================================ You are given inputs in the form "Numbers: a, b, c, ..." (commas and spacing may vary). Your job is to compute a single integer result and output it exactly on one line enclosed in <answer>...</answer> tags and nothing else. Precise task and rules: - Parse the input in the given order and extract all integers (allow positive, negative, and zero; integers may have optional + or - sign). Ignore any non-numeric text. You may assume at least one integer is present. - Form adjacent, non-overlapping pairs from the list in order: (n1,n2), (n3,n4), (n5,n6), ... . - For each pair compute pair_sum = ni + n(i+1). - If there are multiple pairs, compute the final result as the product of all pair_sums (multiply every pair_sum together). - If there is exactly one pair (exactly two numbers), the final result is that pair's sum. - If the list has an odd number of integers, treat the final unpaired last integer as a standalone multiplicative factor (multiply the product of pair_sums by that last integer). - If the input has exactly one integer, return that integer (it is the final product). - Use exact integer arithmetic (arbitrary-precision if necessary). Do not perform any floating-point rounding. - Do NOT include thousands separators, commas, spaces, or any extra text in the numeric result. 
- Output must be a single line containing nothing but the final integer enclosed in <answer> and </answer> tags. Example: <answer>12345</answer> Common pitfalls to avoid (learned from examples): - Always pair sequentially and non-overlapping until you run out of numbers; do not leave two numbers unpaired at the end when they should form a final pair. - Do not treat multiple trailing numbers incorrectly—only one trailing number can exist (if the count is odd). - Multiply the pair sums together exactly; do not drop, combine, or reorder pair-sums incorrectly. - Support negative and zero values correctly (they affect sums and product signs). Examples (for your reference only; do not output these in responses): - Input: "Numbers: 3,5,2,4" -> pair_sums = [8,6] -> result = 8*6 = 48 -> output: <answer>48</answer> - Input: "Numbers: 4,5,6" -> pair_sums = [9], leftover 6 -> result = 9*6 = 54 -> output: <answer>54</answer> - Input: "Numbers: 7" -> result = 7 -> output: <answer>7</answer> ================================================================================
Evaluate New System Prompt on Test Set¶
# Score the best candidate on the same held-out test split as the baseline.
# capture_traces=False because only the scores are used here.
print("\nOptimized evaluation on test set...")
optimized_batch = adapter.evaluate(
batch=test_set,
candidate=result.best_candidate,
capture_traces=False,
)
optimized_correct = sum(1 for s in optimized_batch.scores if s == 1.0)
# n, baseline_correct and baseline_acc come from the baseline evaluation above.
optimized_acc = optimized_correct / n * 100
print(f"\nBaseline: {baseline_correct}/{n} ({baseline_acc:.1f}%)")
print(f"Optimized: {optimized_correct}/{n} ({optimized_acc:.1f}%)")
# The delta is a difference between two percentages (percentage points).
print(f"Delta: {optimized_acc - baseline_acc:+.1f}%")
Optimized evaluation on test set...
Baseline: 11/50 (22.0%) Optimized: 29/50 (58.0%) Delta: +36.0%