Skip to content

ExperimentTracker

gepa.logging.experiment_tracker.ExperimentTracker(use_wandb: bool = False, wandb_api_key: str | None = None, wandb_init_kwargs: dict[str, Any] | None = None, wandb_attach_existing: bool = False, wandb_step_metric: str | None = None, use_mlflow: bool = False, mlflow_tracking_uri: str | None = None, mlflow_experiment_name: str | None = None, mlflow_attach_existing: bool = False, key_prefix: str = '')

Unified experiment tracking that supports both wandb and mlflow.

Source code in gepa/logging/experiment_tracker.py
def __init__(
    self,
    use_wandb: bool = False,
    wandb_api_key: str | None = None,
    wandb_init_kwargs: dict[str, Any] | None = None,
    wandb_attach_existing: bool = False,
    wandb_step_metric: str | None = None,
    use_mlflow: bool = False,
    mlflow_tracking_uri: str | None = None,
    mlflow_experiment_name: str | None = None,
    mlflow_attach_existing: bool = False,
    key_prefix: str = "",
):
    self.use_wandb = use_wandb
    self.use_mlflow = use_mlflow

    self.wandb_api_key = wandb_api_key
    self.wandb_init_kwargs = wandb_init_kwargs or {}
    self.wandb_attach_existing = wandb_attach_existing
    self.wandb_step_metric = wandb_step_metric
    self.mlflow_tracking_uri = mlflow_tracking_uri
    self.mlflow_experiment_name = mlflow_experiment_name
    self.mlflow_attach_existing = mlflow_attach_existing
    self.key_prefix = key_prefix

    self._created_mlflow_run = False
    self._mlflow_run_id: str | None = None
    self._mlflow_client: Any = None
    self._wandb_step_metric_defined = False

    # Accumulate table rows so each wandb.log() sends the full growing
    # table, not just the latest row.  Without this, commit=False causes
    # the pending dict to overwrite earlier single-row tables with the
    # newest one, and only the last row per commit cycle survives.
    self._wandb_table_rows: dict[str, tuple[list[str], list[list]]] = {}

Attributes

use_wandb = use_wandb instance-attribute

use_mlflow = use_mlflow instance-attribute

wandb_api_key = wandb_api_key instance-attribute

wandb_init_kwargs = wandb_init_kwargs or {} instance-attribute

wandb_attach_existing = wandb_attach_existing instance-attribute

wandb_step_metric = wandb_step_metric instance-attribute

mlflow_tracking_uri = mlflow_tracking_uri instance-attribute

mlflow_experiment_name = mlflow_experiment_name instance-attribute

mlflow_attach_existing = mlflow_attach_existing instance-attribute

key_prefix = key_prefix instance-attribute

Functions

__enter__()

Context manager entry.

Source code in gepa/logging/experiment_tracker.py
def __enter__(self):
    """Context manager entry."""
    self.initialize()
    self.start_run()
    return self

__exit__(exc_type, exc_val, exc_tb)

Context manager exit - always end the run.

Source code in gepa/logging/experiment_tracker.py
def __exit__(self, exc_type, exc_val, exc_tb):
    """Context manager exit - always end the run."""
    self.end_run()
    return False  # Don't suppress exceptions

initialize()

Initialize the logging backends.

Source code in gepa/logging/experiment_tracker.py
def initialize(self):
    """Set up each logging backend that was enabled at construction."""
    backends = (
        (self.use_wandb, self._initialize_wandb),
        (self.use_mlflow, self._initialize_mlflow),
    )
    for enabled, setup in backends:
        if enabled:
            setup()

start_run()

Start a new run.

When wandb_attach_existing=True the tracker skips wandb.init() and logs into whatever run is already active in the process. When mlflow_attach_existing=True the tracker skips mlflow.start_run() and logs into the already-active MLflow run. In both cases end_run() will not terminate the run.

Source code in gepa/logging/experiment_tracker.py
def start_run(self):
    """Open a run on each enabled backend.

    With ``wandb_attach_existing=True`` no ``wandb.init()`` is issued and
    logging goes to whatever run is already active in the process; with
    ``mlflow_attach_existing=True`` no ``mlflow.start_run()`` is issued
    and the already-active MLflow run is used.  In both cases
    ``end_run()`` leaves the run open.
    """
    if self.use_wandb and not self.wandb_attach_existing:
        # An attached run needs no init (and gets no finish later).
        import wandb  # type: ignore

        wandb.init(**self.wandb_init_kwargs)

    if self.use_mlflow:
        import mlflow  # type: ignore
        from mlflow import MlflowClient  # type: ignore

        if self.mlflow_attach_existing:
            # Piggyback on the caller's run; we will not end it either.
            self._created_mlflow_run = False
        elif mlflow.active_run() is None:
            # No run in this thread yet — open one and remember we own it.
            mlflow.start_run()
            self._created_mlflow_run = True
        else:
            self._created_mlflow_run = False

        # Pin the run_id and build an explicit client so logging stays
        # thread-safe: mlflow.active_run() is thread-local, and without
        # this, parallel threads (e.g. parallel proposals) would silently
        # auto-create brand-new runs.
        current = mlflow.active_run()
        if current is not None:
            self._mlflow_run_id = current.info.run_id
            try:
                self._mlflow_client = MlflowClient(tracking_uri=mlflow.get_tracking_uri())
            except Exception:
                # Client construction can fail (e.g. in test environments);
                # fall back to the fluent API — functional, not thread-safe.
                self._mlflow_client = None

log_config(config: dict[str, Any]) -> None

Log run configuration/hyperparameters to the active backends.

Parameters:

Name Type Description Default
config dict[str, Any]

Flat dict of config key-value pairs. Non-serializable values are converted to strings.

required
Source code in gepa/logging/experiment_tracker.py
def log_config(self, config: dict[str, Any]) -> None:
    """Record run configuration/hyperparameters on every enabled backend.

    Args:
        config: Flat mapping of config keys to values; anything that is
            not a plain scalar is stringified before logging.
    """
    scalar_types = (bool, int, float, str, type(None))
    safe_config = {k: (v if isinstance(v, scalar_types) else str(v)) for k, v in config.items()}

    if self.use_wandb:
        try:
            import wandb  # type: ignore

            wandb.config.update({self._p(k): v for k, v in safe_config.items()}, allow_val_change=True)
        except Exception as e:
            print(f"Warning: Failed to log config to wandb: {e}")

    if self.use_mlflow:
        try:
            str_params = {self._p(k): str(v) for k, v in safe_config.items()}
            client, run_id = self._mlflow_client, self._mlflow_run_id
            if client and run_id:
                # Explicit client + run_id keeps logging thread-safe.
                for name, value in str_params.items():
                    client.log_param(run_id, name, value)
            else:
                import mlflow  # type: ignore

                mlflow.log_params(str_params)
        except Exception as e:
            print(f"Warning: Failed to log config to mlflow: {e}")

log_metrics(metrics: dict[str, Any], step: int | None = None)

Log metrics to the active backends.

Source code in gepa/logging/experiment_tracker.py
def log_metrics(self, metrics: dict[str, Any], step: int | None = None):
    """Log metrics to the active backends."""
    if self.use_wandb:
        try:
            import wandb  # type: ignore

            # Lazily define the custom step metric on the first log call.
            self._define_wandb_step_metric()

            # Filter to numeric values only — non-numeric data (dicts, strings)
            # is logged via log_table() instead to avoid noisy flat charts
            numeric_metrics = {self._p(k): v for k, v in metrics.items() if isinstance(v, int | float)}
            if numeric_metrics:
                if self.wandb_step_metric and step is not None:
                    # Use custom x-axis: inject the step as a metric value
                    # and omit the step= arg to avoid conflicting with the
                    # host run's global step counter.
                    numeric_metrics[self.wandb_step_metric] = step
                    wandb.log(numeric_metrics)
                else:
                    wandb.log(numeric_metrics, step=step)
        except Exception as e:
            print(f"Warning: Failed to log to wandb: {e}")

    if self.use_mlflow:
        try:
            numeric_metrics = {self._p(k): float(v) for k, v in metrics.items() if isinstance(v, int | float)}
            if numeric_metrics:
                if self._mlflow_client and self._mlflow_run_id:
                    for k, v in numeric_metrics.items():
                        self._mlflow_client.log_metric(self._mlflow_run_id, k, v, step=step or 0)
                else:
                    import mlflow  # type: ignore
                    mlflow.log_metrics(numeric_metrics, step=step)
        except Exception as e:
            print(f"Warning: Failed to log to mlflow: {e}")

log_summary(summary: dict[str, Any]) -> None

Log run summary data (visible on the run overview page).

Parameters:

Name Type Description Default
summary dict[str, Any]

Key-value pairs for the run summary. Supports strings, numbers, and other serializable values.

required
Source code in gepa/logging/experiment_tracker.py
def log_summary(self, summary: dict[str, Any]) -> None:
    """Write run-level summary values (shown on the run overview page).

    Args:
        summary: Key-value pairs for the run summary; strings, numbers,
            and other serializable values are supported.
    """
    if self.use_wandb:
        try:
            import wandb  # type: ignore

            for name, value in summary.items():
                wandb.run.summary[self._p(name)] = value  # type: ignore[union-attr]
        except Exception as e:
            print(f"Warning: Failed to log summary to wandb: {e}")

    if self.use_mlflow:
        try:
            # mlflow has no summary concept: numbers become metrics,
            # strings become "summary/"-prefixed params.
            numeric: dict[str, float] = {}
            text: dict[str, str] = {}
            for name, value in summary.items():
                if isinstance(value, (int, float)):
                    numeric[self._p(name)] = float(value)
                elif isinstance(value, str):
                    text[self._p(name)] = str(value)
            client, run_id = self._mlflow_client, self._mlflow_run_id
            if client and run_id:
                for name, value in numeric.items():
                    client.log_metric(run_id, name, value)
                for name, value in text.items():
                    client.log_param(run_id, f"summary/{name}", value)
            else:
                import mlflow  # type: ignore

                if numeric:
                    mlflow.log_metrics(numeric)
                if text:
                    mlflow.log_params({f"summary/{name}": value for name, value in text.items()})
        except Exception as e:
            print(f"Warning: Failed to log summary to mlflow: {e}")

log_table(table_name: str, columns: list[str], data: list[list[Any]]) -> None

Log a table to the active backends.

Parameters:

Name Type Description Default
table_name str

Name/key for the table.

required
columns list[str]

Column headers.

required
data list[list[Any]]

Rows of data (each row is a list matching columns).

required
Source code in gepa/logging/experiment_tracker.py
def log_table(self, table_name: str, columns: list[str], data: list[list[Any]]) -> None:
    """Log a table to every enabled backend.

    Args:
        table_name: Name/key for the table.
        columns: Column headers.
        data: Rows of data (each row is a list matching columns).
    """
    if self.use_wandb:
        try:
            import wandb  # type: ignore

            # Rows accumulate across calls and every wandb.log() carries
            # the full table.  With commit=False the pending dict would
            # otherwise let single-row tables overwrite one another, so
            # only the newest row per commit cycle would survive.
            key = self._p(table_name)
            stored = self._wandb_table_rows.get(key)
            if stored is None:
                self._wandb_table_rows[key] = (columns, list(data))
            else:
                stored[1].extend(data)
            header, rows = self._wandb_table_rows[key]
            wandb.log({key: wandb.Table(columns=header, data=rows)}, commit=False)
        except Exception as e:
            print(f"Warning: Failed to log table to wandb: {e}")

    if self.use_mlflow:
        try:
            import mlflow  # type: ignore

            if self._mlflow_run_id:
                # Only the thread that owns the run may log; the fluent API
                # would auto-create new runs from other threads.
                active = mlflow.active_run()
                if active is None or active.info.run_id != self._mlflow_run_id:
                    return
            columnar = {name: [row[idx] for row in data] for idx, name in enumerate(columns)}
            mlflow.log_table(data=columnar, artifact_file=f"{self._p(table_name)}.json")
        except Exception as e:
            print(f"Warning: Failed to log table to mlflow: {e}")

log_html(html_content: str, key: str = 'candidate_tree') -> None

Log an HTML string as a rich media artifact.

Parameters:

Name Type Description Default
html_content str

Self-contained HTML string.

required
key str

Artifact key / name used in the dashboard.

'candidate_tree'
Source code in gepa/logging/experiment_tracker.py
def log_html(self, html_content: str, key: str = "candidate_tree") -> None:
    """Log an HTML string as a rich media artifact.

    Args:
        html_content: Self-contained HTML string.
        key: Artifact key / name used in the dashboard.
    """
    if self.use_wandb:
        try:
            import wandb  # type: ignore

            html_obj = wandb.Html(html_content)
            pkey = self._p(key)
            wandb.log({pkey: html_obj}, commit=False)
            # Also write to run summary so the panel always shows the latest tree
            wandb.run.summary[pkey] = html_obj  # type: ignore[union-attr]
        except Exception as e:
            print(f"Warning: Failed to log HTML to wandb: {e}")

    if self.use_mlflow:
        try:
            import os
            import tempfile

            # mlflow logs artifacts from disk, so stage the HTML in a temp
            # file and remove it once the upload attempt finishes.
            with tempfile.NamedTemporaryFile(mode="w", suffix=".html", delete=False) as f:
                f.write(html_content)
                tmp_path = f.name
            try:
                if self._mlflow_client and self._mlflow_run_id:
                    self._mlflow_client.log_artifact(self._mlflow_run_id, tmp_path, artifact_path=self._p(key))
                else:
                    import mlflow  # type: ignore

                    mlflow.log_artifact(tmp_path, artifact_path=self._p(key))
            finally:
                # Fix: previously the temp file was never unlinked
                # (delete=False with no cleanup), leaking one file per call.
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass
        except Exception as e:
            print(f"Warning: Failed to log HTML to mlflow: {e}")

end_run()

End the current run.

When wandb_attach_existing=True or mlflow_attach_existing=True the respective run is left open — the caller owns its lifecycle.

Source code in gepa/logging/experiment_tracker.py
def end_run(self):
    """Close out the current run.

    Runs attached via ``wandb_attach_existing=True`` or
    ``mlflow_attach_existing=True`` are left open — their lifecycle
    belongs to the caller, not to this tracker.
    """
    owns_wandb_run = self.use_wandb and not self.wandb_attach_existing
    if owns_wandb_run:
        try:
            import wandb  # type: ignore

            if wandb.run is not None:
                wandb.finish()
        except Exception as e:
            print(f"Warning: Failed to end wandb run: {e}")

    if self.use_mlflow:
        try:
            import mlflow  # type: ignore

            # Only end a run we started ourselves, and only while active.
            if self._created_mlflow_run and mlflow.active_run() is not None:
                mlflow.end_run()
                self._created_mlflow_run = False
        except Exception as e:
            print(f"Warning: Failed to end mlflow run: {e}")

is_active() -> bool

Check if any backend has an active run.

Source code in gepa/logging/experiment_tracker.py
def is_active(self) -> bool:
    """Report whether any enabled backend currently has a live run."""

    def _wandb_live() -> bool:
        import wandb  # type: ignore

        return wandb.run is not None

    def _mlflow_live() -> bool:
        import mlflow  # type: ignore

        return mlflow.active_run() is not None

    probes = []
    if self.use_wandb:
        probes.append(_wandb_live)
    if self.use_mlflow:
        probes.append(_mlflow_live)

    for probe in probes:
        try:
            if probe():
                return True
        except Exception:
            # Backend missing or misbehaving — treat as not active.
            pass
    return False