Skip to content

EpochShuffledBatchSampler

gepa.strategies.batch_sampler.EpochShuffledBatchSampler(minibatch_size: int, rng: random.Random | None = None)

Bases: BatchSampler[DataId, DataInst]

Mirrors the original batching logic:

- Shuffle ids each epoch
- Pad to minibatch size with least frequent ids
- Deterministic via state.rng

Source code in gepa/strategies/batch_sampler.py
def __init__(self, minibatch_size: int, rng: random.Random | None = None):
    """Set up per-epoch shuffled sampling state.

    Args:
        minibatch_size: Number of ids returned per minibatch.
        rng: Random source; a fresh seeded generator is created when omitted
            so sampling stays deterministic by default.
    """
    self.minibatch_size = minibatch_size
    # Shuffled id order for the current epoch; rebuilt on refresh.
    self.shuffled_ids: list[DataId] = []
    # -1 marks "no epoch started yet".
    self.epoch = -1
    # Tracks how often each id has been emitted (used for padding choices).
    self.id_freqs = Counter()
    # Remembered trainset size, to detect dataset changes between calls.
    self.last_trainset_size = 0
    # Fall back to a deterministic seed-0 generator when no rng is supplied.
    self.rng = rng if rng is not None else random.Random(0)

Attributes

minibatch_size = minibatch_size instance-attribute

shuffled_ids: list[DataId] = [] instance-attribute

epoch = -1 instance-attribute

id_freqs = Counter() instance-attribute

last_trainset_size = 0 instance-attribute

rng = random.Random(0) instance-attribute

Functions

next_minibatch_ids(loader: DataLoader[DataId, DataInst], state: GEPAState) -> list[DataId]

Source code in gepa/strategies/batch_sampler.py
def next_minibatch_ids(self, loader: DataLoader[DataId, DataInst], state: GEPAState) -> list[DataId]:
    """Return the ids for the next minibatch, reshuffling at epoch boundaries.

    The position is derived from ``state.i`` so sampling is reproducible;
    the shuffled id list is rebuilt whenever the loader size changes or a
    new epoch is reached.

    Raises:
        ValueError: If the loader is empty.
    """
    n_examples = len(loader)
    if n_examples == 0:
        raise ValueError("Cannot sample a minibatch from an empty loader.")

    # Absolute sample offset implied by the global iteration counter.
    start = state.i * self.minibatch_size

    # Before the first epoch there is nothing to divide by; otherwise the
    # epoch index is how many full passes over the shuffled ids we've made.
    if self.epoch == -1:
        current_epoch = 0
    else:
        current_epoch = start // max(len(self.shuffled_ids), 1)

    # Rebuild the shuffled order when stale: never built, dataset resized,
    # or a new epoch has begun.
    stale = (
        not self.shuffled_ids
        or n_examples != self.last_trainset_size
        or current_epoch > self.epoch
    )
    if stale:
        self.epoch = current_epoch
        self._update_shuffled(loader)

    # Invariants guaranteed by _update_shuffled: the id list is padded to a
    # whole number of minibatches.
    assert len(self.shuffled_ids) >= self.minibatch_size
    assert len(self.shuffled_ids) % self.minibatch_size == 0

    # Wrap the absolute offset into the current shuffled order and slice.
    start = start % len(self.shuffled_ids)
    stop = start + self.minibatch_size
    assert stop <= len(self.shuffled_ids)
    return self.shuffled_ids[start:stop]