diff --git a/src/fairchem/core/common/data_parallel.py b/src/fairchem/core/common/data_parallel.py index 185fae0b2e..99e2cd9e99 100644 --- a/src/fairchem/core/common/data_parallel.py +++ b/src/fairchem/core/common/data_parallel.py @@ -127,6 +127,7 @@ def __init__( num_replicas: int, rank: int, device: torch.device, + seed: int, mode: str | bool = "atoms", shuffle: bool = True, drop_last: bool = False, @@ -159,6 +160,7 @@ def __init__( shuffle=shuffle, drop_last=drop_last, batch_size=batch_size, + seed=seed ) self.batch_sampler = BatchSampler( self.single_sampler, diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py index 388d3b3a58..418951b0ed 100644 --- a/src/fairchem/core/trainers/base_trainer.py +++ b/src/fairchem/core/trainers/base_trainer.py @@ -259,6 +259,7 @@ def get_sampler( mode=balancing_mode, shuffle=shuffle, force_balancing=force_balancing, + seed=self.config["cmd"]["seed"], ) def get_dataloader(self, dataset, sampler) -> DataLoader: diff --git a/tests/core/common/test_data_parallel_batch_sampler.py b/tests/core/common/test_data_parallel_batch_sampler.py index a2b04d2ff8..6205042652 100644 --- a/tests/core/common/test_data_parallel_batch_sampler.py +++ b/tests/core/common/test_data_parallel_batch_sampler.py @@ -90,6 +90,7 @@ def test_lowercase(invalid_dataset) -> None: device=None, mode="ATOMS", throw_on_error=False, + seed=0 ) assert sampler.mode == "atoms" @@ -101,6 +102,7 @@ def test_lowercase(invalid_dataset) -> None: device=None, mode="NEIGHBORS", throw_on_error=False, + seed=0 ) assert sampler.mode == "neighbors" @@ -117,6 +119,7 @@ def test_invalid_mode(invalid_dataset) -> None: device=None, mode="natoms", throw_on_error=True, + seed=0 ) with pytest.raises( @@ -130,6 +133,7 @@ def test_invalid_mode(invalid_dataset) -> None: device=None, mode="nneighbors", throw_on_error=True, + seed=0 ) @@ -147,6 +151,7 @@ def test_invalid_dataset(invalid_dataset) -> None: mode="atoms", throw_on_error=True, force_balancing=True, + seed=0 ) with pytest.raises( RuntimeError, @@ -161,6 +166,7 @@ def test_invalid_dataset(invalid_dataset) -> None: mode="atoms", throw_on_error=True, force_balancing=False, + seed=0 ) @@ -178,6 +184,7 @@ def test_invalid_path_dataset(invalid_path_dataset) -> None: mode="atoms", throw_on_error=True, force_balancing=True, + seed=0 ) with pytest.raises( RuntimeError, @@ -192,6 +199,7 @@ def test_invalid_path_dataset(invalid_path_dataset) -> None: mode="atoms", throw_on_error=True, force_balancing=False, + seed=0 ) @@ -204,6 +212,7 @@ def test_valid_dataset(valid_path_dataset) -> None: device=None, mode="atoms", throw_on_error=True, + seed=0 ) assert (sampler.sizes == np.array(SIZE_ATOMS)).all() @@ -215,6 +224,7 @@ def test_valid_dataset(valid_path_dataset) -> None: device=None, mode="neighbors", throw_on_error=True, + seed=0 ) assert (sampler.sizes == np.array(SIZE_NEIGHBORS)).all() @@ -228,6 +238,7 @@ def test_disabled(valid_path_dataset) -> None: device=None, mode=False, throw_on_error=True, + seed=0 ) assert sampler.balance_batches is False @@ -241,6 +252,7 @@ def test_single_node(valid_path_dataset) -> None: device=None, mode="atoms", throw_on_error=True, + seed=0 ) assert sampler.balance_batches is False