This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Explicitly pass serialization directory and local rank to trainer in train command #5180

Merged · 4 commits · May 7, 2021

5 changes: 2 additions & 3 deletions allennlp/commands/train.py
@@ -723,13 +723,12 @@ def from_partial_objects(
         for data_loader_ in data_loaders.values():
             data_loader_.index_with(model_.vocab)

-        # We don't need to pass serialization_dir and local_rank here, because they will have been
-        # passed through the trainer by from_params already, because they were keyword arguments to
-        # construct this class in the first place.
         trainer_ = trainer.construct(
+            serialization_dir=serialization_dir,
             model=model_,
             data_loader=data_loaders["train"],
             validation_data_loader=data_loaders.get("validation"),
+            local_rank=local_rank,
         )
         assert trainer_ is not None

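For context on the change above: `trainer` is a lazily constructed object, and `trainer.construct(...)` builds the trainer by combining the configured parameters with the keyword arguments passed at the call site. The deleted comment assumed `serialization_dir` and `local_rank` would be injected implicitly via `from_params`; the PR passes them explicitly instead. Below is a minimal, hypothetical sketch of the construct-with-explicit-kwargs pattern (`LazySketch` and `TrainerSketch` are illustrative names, not AllenNLP's real classes):

```python
# A minimal sketch (hypothetical, simplified) of the lazy-construction pattern used above:
# the wrapper holds a constructor and builds the object only when construct() supplies
# the remaining keyword arguments.
from typing import Callable, Generic, TypeVar

T = TypeVar("T")


class LazySketch(Generic[T]):
    def __init__(self, constructor: Callable[..., T]):
        self._constructor = constructor

    def construct(self, **kwargs) -> T:
        # In this sketch, anything not passed here falls back to the constructor's
        # defaults, e.g. local_rank=0, which would be wrong for non-primary
        # distributed workers.
        return self._constructor(**kwargs)


class TrainerSketch:
    def __init__(self, model, serialization_dir: str = "", local_rank: int = 0):
        self.model = model
        self.serialization_dir = serialization_dir
        self.local_rank = local_rank


lazy_trainer = LazySketch(TrainerSketch)
# Explicitly forwarding the values, as the diff above now does:
trainer = lazy_trainer.construct(model=object(), serialization_dir="/tmp/run", local_rank=1)
assert trainer.local_rank == 1
```
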
8 changes: 7 additions & 1 deletion allennlp/training/trainer.py
@@ -296,7 +296,13 @@ def __init__(
         enable_default_callbacks: bool = True,
         run_sanity_checks: bool = True,
     ) -> None:
-        super().__init__(serialization_dir, cuda_device, distributed, local_rank, world_size)
+        super().__init__(
+            serialization_dir=serialization_dir,
+            cuda_device=cuda_device,
+            distributed=distributed,
+            local_rank=local_rank,
+            world_size=world_size,
+        )

         # I am not calling move_to_gpu here, because if the model is
         # not already on the GPU then the optimizer is going to be wrong.
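
The second diff only switches the `super().__init__` call from positional to keyword arguments. A plausible motivation, sketched below with hypothetical signatures: positional calls silently bind to the wrong parameters if the base-class signature ever changes order, while keyword calls either stay correct or fail loudly.

```python
# Minimal sketch (hypothetical signatures) of why the keyword-argument call above is safer.


class TrainerBaseSketch:
    # Imagine cuda_device was later inserted before local_rank in the base class.
    def __init__(self, serialization_dir: str, cuda_device: int = -1, local_rank: int = 0):
        self.serialization_dir = serialization_dir
        self.cuda_device = cuda_device
        self.local_rank = local_rank


# A positional call written against the old order (serialization_dir, local_rank):
broken = TrainerBaseSketch("/tmp/run", 1)           # 1 silently becomes cuda_device; local_rank stays 0
safe = TrainerBaseSketch("/tmp/run", local_rank=1)  # binds to the intended parameter

assert broken.local_rank == 0 and safe.local_rank == 1
```
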
22 changes: 21 additions & 1 deletion tests/commands/train_test.py
@@ -74,6 +74,20 @@ def on_batch( # type: ignore
         _seen_training_devices.add(tensor.device)


+@TrainerCallback.register("training_primary_check")
+class TrainingPrimaryCheckCallback(TrainerCallback):
+    """
+    Makes sure there is only one primary worker.
+    """
+
+    def on_start(
+        self, trainer: "GradientDescentTrainer", is_primary: bool = True, **kwargs
+    ) -> None:
+        super().on_start(trainer, is_primary=is_primary, **kwargs)
+        if is_primary:
+            assert torch.distributed.get_rank() == 0
+
+
 class TestTrain(AllenNlpTestCase):
     DEFAULT_PARAMS = Params(
         {
@@ -209,7 +223,13 @@ def test_train_model_distributed(self):
"train_data_path": SEQUENCE_TAGGING_DATA_PATH,
"validation_data_path": SEQUENCE_TAGGING_DATA_PATH,
"data_loader": {"batch_size": 2},
"trainer": {"num_epochs": 2, "optimizer": "adam"},
"trainer": {
"num_epochs": 2,
"optimizer": "adam",
# Need to use the fully qualified name here so the distributed workers
# can import it.
"callbacks": ["tests.commands.train_test.TrainingPrimaryCheckCallback"],
},
"distributed": {"cuda_devices": devices},
}
)
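
The inline comment about the fully qualified callback name is the key detail of the test change: distributed workers run in freshly spawned processes, so a callback registered only as a side effect of importing the test module in the parent process may be unknown there, whereas a dotted path can be imported on demand. A hypothetical, minimal resolver illustrating that idea (AllenNLP's actual Registrable machinery is more involved):

```python
# Hypothetical, minimal name resolver showing why a fully qualified dotted path
# (e.g. "tests.commands.train_test.TrainingPrimaryCheckCallback") works in a freshly
# spawned worker process: the module can be imported on demand, whereas a short
# registered name only resolves if the registering module was already imported.
import importlib

_REGISTRY: dict = {}  # populated only in processes that imported the registering module


def resolve(name: str):
    if name in _REGISTRY:
        return _REGISTRY[name]
    if "." in name:
        module_name, _, class_name = name.rpartition(".")
        module = importlib.import_module(module_name)  # registration happens as an import side effect
        return getattr(module, class_name)
    raise KeyError(f"{name!r} is not registered and is not a dotted path")


# In a spawned worker, resolve("training_primary_check") would raise KeyError,
# while a dotted path (here, a stand-in from the standard library) imports and succeeds:
print(resolve("collections.OrderedDict"))
```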