Bases: BatchSampler[DataId, DataInst]
Mirrors the original batching logic:
- Shuffle ids each epoch
- Pad to minibatch size with least frequent ids
- Deterministic via the sampler's seeded `rng` (defaults to `random.Random(0)`)
Source code in gepa/strategies/batch_sampler.py
def __init__(self, minibatch_size: int, rng: random.Random | None = None):
    """Set up an epoch-shuffled batch sampler with a fixed minibatch size.

    Args:
        minibatch_size: Number of ids emitted per minibatch.
        rng: Optional random source. When omitted, a generator seeded with 0
            is used so sampling is reproducible by default.
    """
    # Fall back to a fixed seed so the default behavior is deterministic.
    self.rng = rng if rng is not None else random.Random(0)
    self.minibatch_size = minibatch_size
    # Shuffled id order for the current epoch; rebuilt when a refresh is needed.
    self.shuffled_ids: list[DataId] = []
    # -1 means "no epoch has started yet".
    self.epoch = -1
    # Per-id emission counts — presumably used to pad batches with the
    # least-frequent ids (per the class docs); maintained elsewhere.
    self.id_freqs = Counter()
    # Size of the training set at the last shuffle; a change triggers a refresh.
    self.last_trainset_size = 0
Attributes
minibatch_size = minibatch_size
instance-attribute
shuffled_ids: list[DataId] = []
instance-attribute
epoch = -1
instance-attribute
id_freqs = Counter()
instance-attribute
last_trainset_size = 0
instance-attribute
rng = random.Random(0)
instance-attribute
Functions
next_minibatch_ids(loader: DataLoader[DataId, DataInst], state: GEPAState) -> list[DataId]
Source code in gepa/strategies/batch_sampler.py
def next_minibatch_ids(self, loader: DataLoader[DataId, DataInst], state: GEPAState) -> list[DataId]:
    """Return the ids of the next minibatch for the current iteration.

    The shuffled id order is rebuilt (via ``self._update_shuffled``) whenever
    no shuffle exists yet, the loader's size changed since the last shuffle,
    or the running position has advanced past the current epoch.

    Args:
        loader: Source of training ids; only its length is read here.
        state: Optimizer state; ``state.i`` indexes the current iteration.

    Returns:
        A list of exactly ``self.minibatch_size`` ids.

    Raises:
        ValueError: If the loader is empty.
    """
    n_train = len(loader)
    if n_train == 0:
        raise ValueError("Cannot sample a minibatch from an empty loader.")

    # Absolute position in the (conceptually infinite) shuffled stream.
    start = state.i * self.minibatch_size

    # Before the first shuffle there is nothing to divide by; epoch 0 then.
    if self.epoch == -1:
        current_epoch = 0
    else:
        current_epoch = start // max(len(self.shuffled_ids), 1)

    stale = (
        not self.shuffled_ids
        or n_train != self.last_trainset_size
        or current_epoch > self.epoch
    )
    if stale:
        self.epoch = current_epoch
        self._update_shuffled(loader)

    # _update_shuffled is expected to leave a pool that divides evenly
    # into minibatches and covers at least one.
    pool = len(self.shuffled_ids)
    assert pool >= self.minibatch_size
    assert pool % self.minibatch_size == 0

    # Wrap the absolute position into the current pool.
    start %= pool
    stop = start + self.minibatch_size
    assert stop <= pool
    return self.shuffled_ids[start:stop]