Skip to content

Commit

Permalink
pass seed to sampler (#697)
Browse files Browse the repository at this point in the history
* pass seed to sampler

* fix up tests

---------

Co-authored-by: Muhammed Shuaibi <[email protected]>
  • Loading branch information
misko and mshuaibii authored May 22, 2024
1 parent 2fe64a8 commit e68cdf6
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/fairchem/core/common/data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def __init__(
num_replicas: int,
rank: int,
device: torch.device,
seed: int,
mode: str | bool = "atoms",
shuffle: bool = True,
drop_last: bool = False,
Expand Down Expand Up @@ -159,6 +160,7 @@ def __init__(
shuffle=shuffle,
drop_last=drop_last,
batch_size=batch_size,
seed=seed
)
self.batch_sampler = BatchSampler(
self.single_sampler,
Expand Down
1 change: 1 addition & 0 deletions src/fairchem/core/trainers/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ def get_sampler(
mode=balancing_mode,
shuffle=shuffle,
force_balancing=force_balancing,
seed=self.config["cmd"]["seed"],
)

def get_dataloader(self, dataset, sampler) -> DataLoader:
Expand Down
12 changes: 12 additions & 0 deletions tests/core/common/test_data_parallel_batch_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def test_lowercase(invalid_dataset) -> None:
device=None,
mode="ATOMS",
throw_on_error=False,
seed=0
)
assert sampler.mode == "atoms"

Expand All @@ -101,6 +102,7 @@ def test_lowercase(invalid_dataset) -> None:
device=None,
mode="NEIGHBORS",
throw_on_error=False,
seed=0
)
assert sampler.mode == "neighbors"

Expand All @@ -117,6 +119,7 @@ def test_invalid_mode(invalid_dataset) -> None:
device=None,
mode="natoms",
throw_on_error=True,
seed=0
)

with pytest.raises(
Expand All @@ -130,6 +133,7 @@ def test_invalid_mode(invalid_dataset) -> None:
device=None,
mode="nneighbors",
throw_on_error=True,
seed=0
)


Expand All @@ -147,6 +151,7 @@ def test_invalid_dataset(invalid_dataset) -> None:
mode="atoms",
throw_on_error=True,
force_balancing=True,
seed=0
)
with pytest.raises(
RuntimeError,
Expand All @@ -161,6 +166,7 @@ def test_invalid_dataset(invalid_dataset) -> None:
mode="atoms",
throw_on_error=True,
force_balancing=False,
seed=0
)


Expand All @@ -178,6 +184,7 @@ def test_invalid_path_dataset(invalid_path_dataset) -> None:
mode="atoms",
throw_on_error=True,
force_balancing=True,
seed=0
)
with pytest.raises(
RuntimeError,
Expand All @@ -192,6 +199,7 @@ def test_invalid_path_dataset(invalid_path_dataset) -> None:
mode="atoms",
throw_on_error=True,
force_balancing=False,
seed=0
)


Expand All @@ -204,6 +212,7 @@ def test_valid_dataset(valid_path_dataset) -> None:
device=None,
mode="atoms",
throw_on_error=True,
seed=0
)
assert (sampler.sizes == np.array(SIZE_ATOMS)).all()

Expand All @@ -215,6 +224,7 @@ def test_valid_dataset(valid_path_dataset) -> None:
device=None,
mode="neighbors",
throw_on_error=True,
seed=0
)
assert (sampler.sizes == np.array(SIZE_NEIGHBORS)).all()

Expand All @@ -228,6 +238,7 @@ def test_disabled(valid_path_dataset) -> None:
device=None,
mode=False,
throw_on_error=True,
seed=0
)
assert sampler.balance_batches is False

Expand All @@ -241,6 +252,7 @@ def test_single_node(valid_path_dataset) -> None:
device=None,
mode="atoms",
throw_on_error=True,
seed=0
)
assert sampler.balance_batches is False

Expand Down

0 comments on commit e68cdf6

Please sign in to comment.