ConfidenceAdapter AG News Tutorial
ConfidenceAdapter Tutorial: AG News Classification
This notebook is an executable, AG News–only walkthrough that compares DefaultAdapter and ConfidenceAdapter in GEPA.
Setup¶
The full benchmark script in this repository already covers three datasets. This notebook intentionally limits its scope to AG News so it can run end-to-end within a practical teaching budget.
API keys: set OPENAI_API_KEY and ANTHROPIC_API_KEY in your environment (or a .env file in this directory).
If you use an OpenAI-compatible proxy (e.g. a LiteLLM gateway) that routes all models through one endpoint, set USE_LITELLM_PROXY = True in the config cell below — only OPENAI_API_KEY and OPENAI_API_BASE are needed in that case.
The first code cell after configuration runs fail-fast checks on both the task and reflection models, so you will know immediately if something is misconfigured.
import json
import os
import random
import time
from typing import Dict, List
import litellm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datasets import load_dataset
from dotenv import load_dotenv
from llm_structured_confidence import extract_confidence
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import gepa
from gepa.adapters.confidence_adapter import ConfidenceAdapter
from gepa.adapters.default_adapter.default_adapter import DefaultAdapter
from gepa.adapters.default_adapter.default_adapter import EvaluationResult
load_dotenv()
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
# ── Model configuration ─────────────────────────────────────────────
TASK_MODEL = "openai/gpt-4.1-mini"
REFLECTION_MODEL = "anthropic/claude-sonnet-4-6"
# Set to True if you use an OpenAI-compatible proxy (e.g. LiteLLM gateway)
# that routes all models — including Anthropic — through OPENAI_API_BASE.
# When True, only OPENAI_API_KEY is required.
# When False (default), you need OPENAI_API_KEY + ANTHROPIC_API_KEY.
USE_LITELLM_PROXY = False
PROVIDER_KWARGS = {"custom_llm_provider": "openai"} if USE_LITELLM_PROXY else {}
# ── Experiment parameters ────────────────────────────────────────────
AG_NEWS_LABELS = ["World", "Sports", "Business", "Sci/Tech"]
TRAIN_PER_CLASS = 120
VAL_PER_CLASS = 40
TEST_PER_CLASS = 500
MINIBATCH_PER_CLASS = 20
BATCH_WORKERS = 40
BATCH_CHUNK = 80
MAX_ITERATIONS = 10
SEED_PROMPT = "Classify the following text into one of these four categories."
Part 1 — Fail-fast and structured-output sanity check¶
Verify both models respond, test structured JSON output with logprobs, and confirm confidence extraction works — all before spending optimization budget.
def make_response_format(categories: List[str]) -> Dict[str, object]:
return {
"type" : "json_schema",
"json_schema": {
"name" : "news_category",
"strict": True,
"schema": {
"type" : "object",
"properties" : {"category": {"type": "string", "enum": categories}},
"required" : ["category"],
"additionalProperties": False,
},
},
}
RESPONSE_FORMAT = make_response_format(AG_NEWS_LABELS)
def parse_category(raw: str) -> str:
try:
return str(json.loads(raw).get("category", "")).strip()
except (TypeError, json.JSONDecodeError):
return ""
def parse_confidence(response) -> Dict[str, object]:
    """Normalize extract_confidence() output into a fixed three-key dict,
    accepting both old and new payload field names."""
    payload = extract_confidence(
        response=response,
        response_schema=RESPONSE_FORMAT["json_schema"]["schema"],
        field_path="category",
    )

    def _first_truthy(*keys, default):
        # Mirrors `payload.get(a) or payload.get(b) or default` semantics.
        for key in keys:
            value = payload.get(key)
            if value:
                return value
        return default

    return {
        "joint_probability": _first_truthy(
            "joint_probability", "mean_nonzero_probability", default=0.0),
        "top_alternative": _first_truthy(
            "top_alternative_resolved", "top_alternative", default=""),
        "top_alternative_probability": _first_truthy(
            "top_alternative_probability", "top_alternative_prob", default=0.0),
    }
# ── Failfast: verify both models respond before spending budget ──────
print("Testing task model...")
# Single structured-output call with logprobs: exercises the JSON schema,
# deterministic decoding (temperature=0, fixed seed), and the top-5
# logprobs that confidence extraction depends on.
_test_resp = litellm.completion(
    model=TASK_MODEL,
    messages=[
        {"role": "system", "content": SEED_PROMPT},
        {"role": "user", "content": "Oil prices rose after new energy policies in OPEC countries."},
    ],
    response_format=RESPONSE_FORMAT,
    temperature=0,
    seed=SEED,
    max_tokens=32,
    logprobs=True,
    top_logprobs=5,
    **PROVIDER_KWARGS,
)
_test_conf = parse_confidence(_test_resp)
print(f" Task model OK: {parse_category(_test_resp.choices[0].message.content)}"
      f" (prob={_test_conf['joint_probability']:.6f},"
      f" alt={_test_conf['top_alternative']})")
print("Testing reflection model...")
# Providers differ in whether a "thinking" payload is accepted; probe the
# variants in order and remember the first one that works so the GEPA
# reflection LM can reuse it later.
_THINKING_PAYLOADS = [
    ("no_thinking", None),
    ("thinking_disabled", {"thinking": {"type": "disabled"}}),
    ("thinking_enabled", {"thinking": {"type": "enabled", "budget_tokens": 1024}}),
]
_reflection_ok = False
for _name, _extra in _THINKING_PAYLOADS:
    _kwargs = {"model": REFLECTION_MODEL, "messages": [{"role": "user", "content": "Say OK."}],
               "max_tokens": 16, **PROVIDER_KWARGS}
    if _extra is not None:
        _kwargs["extra_body"] = _extra
    try:
        _r = litellm.completion(**_kwargs)
        print(f" Reflection model OK ({_name}): {_r.choices[0].message.content[:40]}")
        _reflection_ok = True
        # Consumed later by make_reflection_lm().
        WORKING_THINKING_PAYLOAD = _extra
        break
    except Exception as _e:
        # Keep only the first line of the error to avoid dumping stack noise.
        print(f" {_name} failed: {str(_e).splitlines()[0][:100]}")
if not _reflection_ok:
    raise RuntimeError("Reflection model unreachable — fix API keys / proxy config before continuing.")
print("\nAll models responding. Ready to run.")
Part 2 — Load AG News and prepare splits¶
Train = 120/class, validation = 40/class, test = 500/class.
def load_ag_news_splits() -> Dict[str, List[Dict[str, object]]]:
    """Download AG News and build class-balanced train/val/test splits.

    Train and val are disjoint slices of per-class shuffled buckets from the
    HF train split (TRAIN_PER_CLASS / VAL_PER_CLASS rows per label); test
    takes TEST_PER_CLASS rows per label from the HF test split.
    """
    ds = load_dataset("ag_news")
    train_rows = list(ds['train'])
    test_rows = list(ds['test'])
    # AG News integer labels → human-readable category names.
    label_index = {
        0: 'World',
        1: 'Sports',
        2: 'Business',
        3: 'Sci/Tech',
    }

    def project(rows):
        # Normalize HF rows into the dict shape the adapters expect
        # ('expected' and 'answer' carry the same gold label).
        return [
            {
                'input': row['text'],
                'expected': label_index[row['label']],
                'answer': label_index[row['label']],
                'additional_context': {},
            }
            for row in rows
        ]

    train_projected = project(train_rows)
    test_projected = project(test_rows)
    # Bucket rows by gold label so each split can be balanced per class.
    by_train = {label: [] for label in AG_NEWS_LABELS}
    by_test = {label: [] for label in AG_NEWS_LABELS}
    for row in train_projected:
        by_train[row['expected']].append(row)
    for row in test_projected:
        if row['expected'] in by_test:
            by_test[row['expected']].append(row)
    # Re-seed so the bucket shuffles (and hence the splits) are reproducible
    # regardless of any earlier random.* usage in the notebook.
    random.seed(SEED)
    for bucket in by_train.values():
        random.shuffle(bucket)
    for bucket in by_test.values():
        random.shuffle(bucket)
    train = []
    val = []
    for label, bucket in by_train.items():
        # Val rows come directly after the train slice — never overlapping.
        train.extend(bucket[:TRAIN_PER_CLASS])
        val.extend(bucket[TRAIN_PER_CLASS:TRAIN_PER_CLASS + VAL_PER_CLASS])
    test = []
    for label, bucket in by_test.items():
        test.extend(bucket[:TEST_PER_CLASS])
    # Shuffle the final splits so classes are interleaved, not grouped.
    for split in (train, val, test):
        random.shuffle(split)
    return {'train': train, 'val': val, 'test': test}
# Materialize the splits and sanity-print per-class counts.
ag_news = load_ag_news_splits()
print('Train counts:', {k: sum(1 for r in ag_news['train'] if r['expected'] == k) for k in AG_NEWS_LABELS})
print('Val counts:', {k: sum(1 for r in ag_news['val'] if r['expected'] == k) for k in AG_NEWS_LABELS})
print('Test counts:', {k: sum(1 for r in ag_news['test'] if r['expected'] == k) for k in AG_NEWS_LABELS})
Part 3 — Batch evaluation helper¶
This helper is used for quick baseline/test evaluation and gives top-alternative context.
def call_batch(messages_batch: List[List[Dict[str, object]]]) -> List[object]:
    """Run litellm.batch_completion over *messages_batch* in fixed-size
    chunks, retrying each chunk up to three times.

    Successful chunks append their responses in order; a chunk that fails
    all attempts appends the triggering exception object once per message,
    so callers can detect per-row failures positionally.
    """
    collected: List[object] = []
    for offset in range(0, len(messages_batch), BATCH_CHUNK):
        chunk = messages_batch[offset:offset + BATCH_CHUNK]
        attempt = 0
        while attempt < 3:
            try:
                collected.extend(litellm.batch_completion(
                    model=TASK_MODEL,
                    messages=chunk,
                    max_workers=BATCH_WORKERS,
                    response_format=RESPONSE_FORMAT,
                    logprobs=True,
                    top_logprobs=5,
                    max_tokens=64,
                    seed=SEED,
                    temperature=0,
                    **PROVIDER_KWARGS,
                ))
                break
            except Exception as exc:
                if attempt >= 2:
                    # Out of retries: record the failure once per message.
                    collected.extend([exc] * len(chunk))
                else:
                    # Exponential backoff: 1s, then 2s.
                    time.sleep(2 ** attempt)
                attempt += 1
    return collected
def evaluate_on_dataset(prompt: str, data: List[Dict[str, object]]) -> pd.DataFrame:
    """Classify every row in *data* under *prompt* and return a per-row
    DataFrame with the prediction, confidence score, and top alternative."""
    batch = [
        [{"role": "system", "content": prompt},
         {"role": "user", "content": row['input']}]
        for row in data
    ]
    records = []
    for row, response in zip(data, call_batch(batch)):
        expected = row['expected']
        failed = isinstance(response, Exception)
        if failed:
            # Failed calls count as wrong with zero confidence.
            predicted, score, top_alt, top_alt_score = '', 0.0, '', 0.0
        else:
            predicted = parse_category(response.choices[0].message.content)
            conf = parse_confidence(response)
            score = float(conf['joint_probability'])
            top_alt = str(conf['top_alternative'])
            top_alt_score = float(conf['top_alternative_probability'])
        records.append({
            'text': row['input'],
            'expected': expected,
            'predicted': predicted,
            'score': score,
            'top_alternative': top_alt,
            'top_alternative_score': top_alt_score,
            'correct': bool(not failed and predicted == expected),
        })
    return pd.DataFrame(records)
def build_summary(df: pd.DataFrame) -> Dict[str, float]:
    """Aggregate a per-row evaluation frame into weighted classification
    metrics plus the mean confidence score."""
    precision, recall, f1, _ = precision_recall_fscore_support(
        df['expected'],
        df['predicted'],
        labels=AG_NEWS_LABELS,
        average='weighted',
        zero_division=0.0,
    )
    accuracy = accuracy_score(df['expected'], df['predicted'])
    return {
        'accuracy': float(accuracy),
        'precision': float(precision),
        'recall': float(recall),
        'f1': float(f1),
        'mean_score': float(df['score'].mean()),
    }
# Guard: everything from here on makes real API calls.
if not os.getenv('OPENAI_API_KEY'):
    raise RuntimeError('Set OPENAI_API_KEY to run the tutorial end-to-end.')
# Baseline: the unoptimized seed prompt evaluated on the full test split.
baseline_df = evaluate_on_dataset(SEED_PROMPT, ag_news['test'])
print('Baseline weighted F1:', build_summary(baseline_df)['f1'])
Part 4 — Run GEPA optimization¶
Optimize with both adapters using the same train/validation protocol.
def make_reflection_lm():
    """Return the callable GEPA uses for its reflection step.

    The callable accepts either a raw prompt string or a pre-built message
    list, forwards the thinking payload discovered during the fail-fast
    probe, and retries transient failures up to three times.
    """
    def _reflect(prompt):
        if isinstance(prompt, str):
            chat = [{'role': 'user', 'content': prompt}]
        else:
            chat = prompt
        call_kwargs = {
            'model': REFLECTION_MODEL,
            'messages': chat,
            'max_tokens': 2048,
            **PROVIDER_KWARGS,
        }
        if WORKING_THINKING_PAYLOAD is not None:
            call_kwargs['extra_body'] = WORKING_THINKING_PAYLOAD
        for attempt in range(3):
            try:
                return litellm.completion(**call_kwargs).choices[0].message.content
            except Exception:
                if attempt == 2:
                    # Final attempt: surface the error to the optimizer.
                    raise
                # Linear backoff: 1.5s, 3.0s.
                time.sleep(1.5 * (attempt + 1))
    return _reflect
def make_default_evaluator():
    """Factory for the DefaultAdapter scorer: case-insensitive exact match
    against the gold category scores 1.0, anything else 0.0, each with a
    short natural-language feedback string."""
    def evaluator(data, response_text):
        predicted = parse_category(response_text)
        matched = predicted.strip().lower() == data['answer'].strip().lower()
        if matched:
            return EvaluationResult(score=1.0, feedback=f"Correct: '{predicted}'.")
        return EvaluationResult(score=0.0, feedback=f"Incorrect. Expected '{data['answer']}', got '{predicted}'.")
    return evaluator
def run_default() -> Dict[str, object]:
    """Optimize the seed prompt with DefaultAdapter (binary 0/1 feedback)."""
    adapter = DefaultAdapter(
        model=TASK_MODEL,
        evaluator=make_default_evaluator(),
        litellm_batch_completion_kwargs={
            'response_format': RESPONSE_FORMAT,
            'seed': SEED,
            'temperature': 0,
            **PROVIDER_KWARGS,
        },
    )
    # One class-balanced minibatch per reflection step.
    minibatch = MINIBATCH_PER_CLASS * len(AG_NEWS_LABELS)
    # Budget formula: per iteration one minibatch + one full val pass, plus
    # one extra val pass (presumably for the seed candidate — confirm
    # against gepa.optimize's accounting).
    max_metric_calls = MAX_ITERATIONS * (minibatch + len(ag_news['val'])) + len(ag_news['val'])
    result = gepa.optimize(
        seed_candidate={'system_prompt': SEED_PROMPT},
        trainset=ag_news['train'],
        valset=ag_news['val'],
        adapter=adapter,
        reflection_lm=make_reflection_lm(),
        reflection_minibatch_size=minibatch,
        max_metric_calls=max_metric_calls,
        seed=SEED,
        raise_on_exception=False,
    )
    return {'result': result, 'history': result.val_aggregate_scores}
def run_confidence() -> Dict[str, object]:
    """Optimize the seed prompt with ConfidenceAdapter (logprob-based
    confidence feedback on the 'category' field)."""
    adapter = ConfidenceAdapter(
        model=TASK_MODEL,
        field_path='category',
        response_format=RESPONSE_FORMAT,
        response_schema=RESPONSE_FORMAT['json_schema']['schema'],
        # Confidence bands used by the adapter's feedback text.
        high_confidence_threshold=0.99,
        low_confidence_threshold=0.90,
        litellm_batch_completion_kwargs={
            'seed': SEED,
            'temperature': 0,
            **PROVIDER_KWARGS,
        },
        max_litellm_workers=BATCH_WORKERS,
    )
    # Same minibatch size and metric-call budget as run_default(), so the
    # two adapters are compared under an identical protocol.
    minibatch = MINIBATCH_PER_CLASS * len(AG_NEWS_LABELS)
    max_metric_calls = MAX_ITERATIONS * (minibatch + len(ag_news['val'])) + len(ag_news['val'])
    result = gepa.optimize(
        seed_candidate={'system_prompt': SEED_PROMPT},
        trainset=ag_news['train'],
        valset=ag_news['val'],
        adapter=adapter,
        reflection_lm=make_reflection_lm(),
        reflection_minibatch_size=minibatch,
        max_metric_calls=max_metric_calls,
        seed=SEED,
        raise_on_exception=False,
    )
    return {'result': result, 'history': result.val_aggregate_scores}
# Run both optimizations (each makes many API calls — this is slow).
default_bundle = run_default()
confidence_bundle = run_confidence()
default_prompt = default_bundle['result'].best_candidate['system_prompt']
confidence_prompt = confidence_bundle['result'].best_candidate['system_prompt']
# NOTE(review): these print the LAST entry of the val-score history, which
# is not necessarily the best candidate's score — confirm intent.
print('Default best val score:', default_bundle['history'][-1] if default_bundle['history'] else None)
print('Confidence best val score:', confidence_bundle['history'][-1] if confidence_bundle['history'] else None)
Part 5 — Evaluate and compare on AG News test set¶
We finish with weighted metrics and confidence-focused charts.
# Final test-set evaluation of each optimized prompt.
default_test = evaluate_on_dataset(default_prompt, ag_news['test'])
confidence_test = evaluate_on_dataset(confidence_prompt, ag_news['test'])
summary = pd.DataFrame([
    {'mode': 'Baseline', **build_summary(baseline_df)},
    {'mode': 'DefaultAdapter', **build_summary(default_test)},
    {'mode': 'ConfidenceAdapter', **build_summary(confidence_test)},
])
print(summary)
# Grouped bar chart: four weighted metrics per mode, values labelled on bars.
fig, ax = plt.subplots(figsize=(10, 4))
x = np.arange(len(summary))
w = 0.2
for i, m in enumerate(['accuracy', 'precision', 'recall', 'f1']):
    # Offset each metric's bars so the four series sit side by side.
    bars = ax.bar(x + (i - 1.5) * w, summary[m], width=w, label=m)
    for bar in bars:
        ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.008, f'{bar.get_height():.3f}', ha='center', va='bottom', fontsize=8)
ax.set_xticks(x)
ax.set_xticklabels(summary['mode'])
ax.set_ylim(0, 1.05)
ax.set_title('AG News test summary')
ax.set_ylabel('Score')
ax.legend(ncol=2)
ax.grid(axis='y', alpha=0.25)
plt.tight_layout()
plt.show()
def per_class_chart(metric: str) -> None:
    """Plot a grouped bar chart of per-class *metric* ('precision',
    'recall', or anything else → f1) for the baseline and both adapters."""
    sources = [
        ('Baseline', baseline_df),
        ('DefaultAdapter', default_test),
        ('ConfidenceAdapter', confidence_test),
    ]
    # Index into the (precision, recall, f1, support) tuple; f1 is the
    # fallback, matching the original if/elif/else.
    metric_slot = {'precision': 0, 'recall': 1}.get(metric, 2)
    values_by_name = {}
    for name, frame in sources:
        per_class = precision_recall_fscore_support(
            frame['expected'], frame['predicted'],
            labels=AG_NEWS_LABELS, average=None, zero_division=0.0,
        )
        values_by_name[name] = per_class[metric_slot]
    positions = np.arange(len(AG_NEWS_LABELS))
    width = 0.22
    fig, ax = plt.subplots(figsize=(10, 4))
    for idx, (name, vals) in enumerate(values_by_name.items()):
        bars = ax.bar(positions + (idx - 1) * width, vals, width=width, label=name)
        for bar in bars:
            ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.005, f'{bar.get_height():.2f}', ha='center', fontsize=8)
    ax.set_xticks(positions)
    ax.set_xticklabels(AG_NEWS_LABELS, rotation=25)
    ax.set_ylim(0, 1.05)
    ax.set_title(f'AG News per-class {metric}')
    ax.set_ylabel(metric)
    ax.legend()
    ax.grid(axis='y', alpha=0.25)
    plt.tight_layout()
    plt.show()
# Per-class breakdowns for each metric.
for metric in ['precision', 'recall', 'f1']:
    per_class_chart(metric)
# Convergence
# Validation-score trajectory across candidates for both adapters.
fig, ax = plt.subplots(figsize=(8, 4))
if default_bundle['history']:
    ax.plot(default_bundle['history'], marker='o', label='DefaultAdapter')
if confidence_bundle['history']:
    ax.plot(confidence_bundle['history'], marker='o', label='ConfidenceAdapter')
ax.set_title('Validation score over candidates')
ax.set_xlabel('Candidate')
ax.set_ylabel('Validation score')
ax.text(0.5, -0.22, 'GEPA selects the best candidate from all evaluated prompts.', transform=ax.transAxes, ha='center', fontsize=8)
ax.grid(alpha=0.25)
ax.legend()
plt.tight_layout()
plt.show()
def confidence_ecdf(df: pd.DataFrame, title: str) -> None:
    """Plot empirical CDFs of the confidence score, separately for correct
    (solid) and incorrect (dashed) predictions in *df*."""
    fig, ax = plt.subplots(figsize=(8, 4))
    subsets = [
        (df[df['correct']]['score'], '-', f'{title} correct'),
        (df[~df['correct']]['score'], '--', f'{title} wrong'),
    ]
    for scores, line_style, label in subsets:
        ordered = scores.sort_values().reset_index(drop=True)
        count = len(ordered)
        if count:
            cumulative = np.arange(1, count + 1) / count
            ax.plot(ordered, cumulative, line_style, label=label)
    ax.set_title(f'{title} confidence CDF')
    ax.set_xlim(0, 1.0)
    ax.set_ylim(0, 1.0)
    ax.set_xlabel('Joint probability')
    ax.set_ylabel('Cumulative share')
    ax.grid(alpha=0.25)
    ax.legend()
    plt.tight_layout()
    plt.show()
# Confidence ECDFs for both optimized prompts.
confidence_ecdf(default_test, 'DefaultAdapter')
confidence_ecdf(confidence_test, 'ConfidenceAdapter')
# Inspect one misclassified row with its confidence and top alternative.
sample_bad = confidence_test[~confidence_test['correct']].head(1)
if len(sample_bad):
    row = sample_bad.iloc[0]
    print('Sample wrong row:')
    print('Expected:', row['expected'])
    print('Pred:', row['predicted'])
    print('Score:', row['score'])
    print('Alternative:', row['top_alternative'], row['top_alternative_score'])
Part 6 — Next steps¶
Keep OPENAI_API_KEY and OPENAI_API_BASE configured, then run this notebook end-to-end for a fresh live reproduction.