diff --git a/src/metatrain/utils/data/dataset.py b/src/metatrain/utils/data/dataset.py index 36a58b373..b0ec3de34 100644 --- a/src/metatrain/utils/data/dataset.py +++ b/src/metatrain/utils/data/dataset.py @@ -190,8 +190,8 @@ def get_atomic_types(datasets: Union[Dataset, List[Dataset]]) -> List[int]: types = set() for dataset in datasets: - for index in range(len(dataset)): - system = dataset[index]["system"] + for sample in dataset: + system = sample["system"] types.update(set(system.types.tolist())) return sorted(types)