diff --git a/ranzen/torch/data.py b/ranzen/torch/data.py index 5d19843a..0d753790 100644 --- a/ranzen/torch/data.py +++ b/ranzen/torch/data.py @@ -264,7 +264,7 @@ class TrainingMode(Enum): """step-based training""" -class BatchSamplerBase(Sampler[Sequence[int]]): +class BatchSamplerBase(Sampler[List[int]]): def __init__(self, epoch_length: int | None = None) -> None: self.epoch_length: Final[int | None] = epoch_length