From e6836be7e3a689eb7cfc054f0ebdda0b28f3fecd Mon Sep 17 00:00:00 2001 From: romesc Date: Thu, 22 Apr 2021 14:57:55 -0700 Subject: [PATCH 1/7] optim config registration function --- examples/register_configs.py | 18 +++++++++++++++++ .../hydra_configs/torch/optim/__init__.py | 20 +++++++++++++++++++ 2 files changed, 38 insertions(+) create mode 100644 examples/register_configs.py diff --git a/examples/register_configs.py b/examples/register_configs.py new file mode 100644 index 0000000..62655e3 --- /dev/null +++ b/examples/register_configs.py @@ -0,0 +1,18 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import hydra +from typing import Any +from omegaconf import OmegaConf + +# registers all 'base' optimizer configs with configstore instance +import hydra_configs.torch.optim + +hydra_configs.torch.optim.register_configs() + + +@hydra.main(config_name="torch/optim/adam") +def my_app(cfg: Any) -> None: + print(OmegaConf.to_yaml(cfg)) + + +if __name__ == "__main__": + my_app() diff --git a/hydra-configs-torch/hydra_configs/torch/optim/__init__.py b/hydra-configs-torch/hydra_configs/torch/optim/__init__.py index 76689e0..bd01f7e 100644 --- a/hydra-configs-torch/hydra_configs/torch/optim/__init__.py +++ b/hydra-configs-torch/hydra_configs/torch/optim/__init__.py @@ -1,3 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # flake8: noqa # Mirrors torch/optim __init__ to allow for symmetric import structure from .adadelta import AdadeltaConf @@ -14,6 +15,25 @@ from .lbfgs import LBFGSConf from . import lr_scheduler +from hydra.core.config_store import ConfigStore + + +def register_configs(): + cs = ConfigStore.instance() + cs.store(provider="torch", group="torch/optim", name="adadelta", node=AdadeltaConf) + cs.store(provider="torch", group="torch/optim", name="adagrad", node=AdagradConf) + cs.store(provider="torch", group="torch/optim", name="adam", node=AdamConf) + cs.store(provider="torch", group="torch/optim", name="adamw", node=AdamWConf) + cs.store( + provider="torch", group="torch/optim", name="sparseadam", node=SparseAdamConf + ) + cs.store(provider="torch", group="torch/optim", name="adamax", node=AdamaxConf) + cs.store(provider="torch", group="torch/optim", name="asgd", node=ASGDConf) + cs.store(provider="torch", group="torch/optim", name="sgd", node=SGDConf) + cs.store(provider="torch", group="torch/optim", name="lbfgs", node=LBFGSConf) + cs.store(provider="torch", group="torch/optim", name="rprop", node=RpropConf) + cs.store(provider="torch", group="torch/optim", name="rmsprop", node=RMSpropConf) + del adadelta del adagrad From f5dad4992e068e06f069978b12282840d4c3bd9c Mon Sep 17 00:00:00 2001 From: romesc Date: Thu, 22 Apr 2021 15:13:33 -0700 Subject: [PATCH 2/7] reformat example --- examples/register_configs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/register_configs.py b/examples/register_configs.py index 62655e3..2b5bb06 100644 --- a/examples/register_configs.py +++ b/examples/register_configs.py @@ -3,9 +3,9 @@ from typing import Any from omegaconf import OmegaConf -# registers all 'base' optimizer configs with configstore instance import hydra_configs.torch.optim +# registers all 'base' optimizer configs with configstore instance hydra_configs.torch.optim.register_configs() From b02b19361a2027180826e4367ad98de533a2fc9e Mon Sep 17 00:00:00 2001 From: romesc Date: Thu, 22 Apr 2021 16:54:03 -0700 Subject: [PATCH 3/7] add test, remove example --- examples/register_configs.py | 18 ------------ 
.../tests/test_register_optimizers.py | 29 +++++++++++++++++++ 2 files changed, 29 insertions(+), 18 deletions(-) delete mode 100644 examples/register_configs.py create mode 100644 hydra-configs-torch/tests/test_register_optimizers.py diff --git a/examples/register_configs.py b/examples/register_configs.py deleted file mode 100644 index 2b5bb06..0000000 --- a/examples/register_configs.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -import hydra -from typing import Any -from omegaconf import OmegaConf - -import hydra_configs.torch.optim - -# registers all 'base' optimizer configs with configstore instance -hydra_configs.torch.optim.register_configs() - - -@hydra.main(config_name="torch/optim/adam") -def my_app(cfg: Any) -> None: - print(OmegaConf.to_yaml(cfg)) - - -if __name__ == "__main__": - my_app() diff --git a/hydra-configs-torch/tests/test_register_optimizers.py b/hydra-configs-torch/tests/test_register_optimizers.py new file mode 100644 index 0000000..eef99dd --- /dev/null +++ b/hydra-configs-torch/tests/test_register_optimizers.py @@ -0,0 +1,29 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from hydra.core.config_store import ConfigStore +import hydra_configs.torch.optim + +cs = ConfigStore() + +# registers all 'base' optimizer configs with configstore instance +hydra_configs.torch.optim.register() + +expected = set( + [ + "adadelta", + "adagrad", + "adam", + "adamw", + "sparseadam", + "adamax", + "asgd", + "sgd", + "lbfgs", + "rprop", + "rmsprop", + ] +) + + +def test_instantiate_classes() -> None: + actual = set([conf.split(".yaml")[0] for conf in cs.list("torch/optim")]) + assert not actual ^ expected From c95ca8d7255348d3e718eeccb35fea8f76d29538 Mon Sep 17 00:00:00 2001 From: romesc Date: Thu, 22 Apr 2021 22:41:18 -0700 Subject: [PATCH 4/7] use ConfigStoreWithProvider --- .../hydra_configs/torch/optim/__init__.py | 30 +++++++++---------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/hydra-configs-torch/hydra_configs/torch/optim/__init__.py b/hydra-configs-torch/hydra_configs/torch/optim/__init__.py index bd01f7e..70c3b51 100644 --- a/hydra-configs-torch/hydra_configs/torch/optim/__init__.py +++ b/hydra-configs-torch/hydra_configs/torch/optim/__init__.py @@ -15,24 +15,22 @@ from .lbfgs import LBFGSConf from . 
import lr_scheduler -from hydra.core.config_store import ConfigStore +from hydra.core.config_store import ConfigStoreWithProvider -def register_configs(): - cs = ConfigStore.instance() - cs.store(provider="torch", group="torch/optim", name="adadelta", node=AdadeltaConf) - cs.store(provider="torch", group="torch/optim", name="adagrad", node=AdagradConf) - cs.store(provider="torch", group="torch/optim", name="adam", node=AdamConf) - cs.store(provider="torch", group="torch/optim", name="adamw", node=AdamWConf) - cs.store( - provider="torch", group="torch/optim", name="sparseadam", node=SparseAdamConf - ) - cs.store(provider="torch", group="torch/optim", name="adamax", node=AdamaxConf) - cs.store(provider="torch", group="torch/optim", name="asgd", node=ASGDConf) - cs.store(provider="torch", group="torch/optim", name="sgd", node=SGDConf) - cs.store(provider="torch", group="torch/optim", name="lbfgs", node=LBFGSConf) - cs.store(provider="torch", group="torch/optim", name="rprop", node=RpropConf) - cs.store(provider="torch", group="torch/optim", name="rmsprop", node=RMSpropConf) +def register(): + with ConfigStoreWithProvider("torch") as cs: + cs.store(group="torch/optim", name="adadelta", node=AdadeltaConf) + cs.store(group="torch/optim", name="adagrad", node=AdagradConf) + cs.store(group="torch/optim", name="adam", node=AdamConf) + cs.store(group="torch/optim", name="adamw", node=AdamWConf) + cs.store(group="torch/optim", name="sparseadam", node=SparseAdamConf) + cs.store(group="torch/optim", name="adamax", node=AdamaxConf) + cs.store(group="torch/optim", name="asgd", node=ASGDConf) + cs.store(group="torch/optim", name="sgd", node=SGDConf) + cs.store(group="torch/optim", name="lbfgs", node=LBFGSConf) + cs.store(group="torch/optim", name="rprop", node=RpropConf) + cs.store(group="torch/optim", name="rmsprop", node=RMSpropConf) del adadelta From 004011ee73c0306d93b16d218b07a61b81df88e3 Mon Sep 17 00:00:00 2001 From: romesc Date: Fri, 23 Apr 2021 14:43:11 -0700 Subject: [PATCH 5/7] add losses to registration --- .../torch/nn/modules/__init__.py | 85 +++++++++++++++++++ .../tests/test_register_losses.py | 37 ++++++++ 2 files changed, 122 insertions(+) create mode 100644 hydra-configs-torch/hydra_configs/torch/nn/modules/__init__.py create mode 100644 hydra-configs-torch/tests/test_register_losses.py diff --git a/hydra-configs-torch/hydra_configs/torch/nn/modules/__init__.py b/hydra-configs-torch/hydra_configs/torch/nn/modules/__init__.py new file mode 100644 index 0000000..e2be697 --- /dev/null +++ b/hydra-configs-torch/hydra_configs/torch/nn/modules/__init__.py @@ -0,0 +1,85 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +# flake8: noqa +from .loss import BCELossConf +from .loss import BCEWithLogitsLossConf +from .loss import CosineEmbeddingLossConf +from .loss import CTCLossConf +from .loss import L1LossConf +from .loss import HingeEmbeddingLossConf +from .loss import KLDivLossConf +from .loss import MarginRankingLossConf +from .loss import MSELossConf +from .loss import MultiLabelMarginLossConf +from .loss import MultiLabelSoftMarginLossConf +from .loss import MultiMarginLossConf +from .loss import NLLLossConf +from .loss import NLLLoss2dConf +from .loss import PoissonNLLLossConf +from .loss import SmoothL1LossConf +from .loss import SoftMarginLossConf +from .loss import TripletMarginLossConf +from hydra.core.config_store import ConfigStoreWithProvider + + +def register(): + with ConfigStoreWithProvider("torch") as cs: + cs.store(group="torch/nn/modules/loss", name="bceloss", node=BCELossConf) + cs.store( + group="torch/nn/modules/loss", + name="bcewithlogitsloss", + node=BCEWithLogitsLossConf, + ) + cs.store( + group="torch/nn/modules/loss", + name="cosineembeddingloss", + node=CosineEmbeddingLossConf, + ) + cs.store(group="torch/nn/modules/loss", name="ctcloss", node=CTCLossConf) + cs.store(group="torch/nn/modules/loss", name="l1loss", node=L1LossConf) + cs.store( + group="torch/nn/modules/loss", + name="hingeembeddingloss", + node=HingeEmbeddingLossConf, + ) + cs.store(group="torch/nn/modules/loss", name="kldivloss", node=KLDivLossConf) + cs.store( + group="torch/nn/modules/loss", + name="marginrankingloss", + node=MarginRankingLossConf, + ) + cs.store(group="torch/nn/modules/loss", name="mseloss", node=MSELossConf) + cs.store( + group="torch/nn/modules/loss", + name="multilabelmarginloss", + node=MultiLabelMarginLossConf, + ) + cs.store( + group="torch/nn/modules/loss", + name="multilabelsoftmarginloss", + node=MultiLabelSoftMarginLossConf, + ) + cs.store( + group="torch/nn/modules/loss", + name="multimarginloss", + node=MultiMarginLossConf, + ) + cs.store(group="torch/nn/modules/loss", name="nllloss", node=NLLLossConf) + cs.store(group="torch/nn/modules/loss", name="nllloss2d", node=NLLLoss2dConf) + cs.store( + group="torch/nn/modules/loss", + name="poissonnllloss", + node=PoissonNLLLossConf, + ) + cs.store( + group="torch/nn/modules/loss", name="smoothl1loss", node=SmoothL1LossConf + ) + cs.store( + group="torch/nn/modules/loss", + name="softmarginloss", + node=SoftMarginLossConf, + ) + cs.store( + group="torch/nn/modules/loss", + name="tripletmarginloss", + node=TripletMarginLossConf, + ) diff --git a/hydra-configs-torch/tests/test_register_losses.py b/hydra-configs-torch/tests/test_register_losses.py new file mode 100644 index 0000000..0bfdc1e --- /dev/null +++ b/hydra-configs-torch/tests/test_register_losses.py @@ -0,0 +1,37 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +from hydra.core.config_store import ConfigStore +import hydra_configs.torch.nn.modules + +cs = ConfigStore() + + +# registers all 'base' optimizer configs with configstore instance +hydra_configs.torch.nn.modules.register() + +expected = set( + [ + "bceloss", + "bcewithlogitsloss", + "cosineembeddingloss", + "ctcloss", + "l1loss", + "hingeembeddingloss", + "kldivloss", + "marginrankingloss", + "mseloss", + "multilabelmarginloss", + "multilabelsoftmarginloss", + "multimarginloss", + "nllloss", + "nllloss2d", + "poissonnllloss", + "smoothl1loss", + "softmarginloss", + "tripletmarginloss", + ] +) + + +def test_instantiate_classes() -> None: + actual = set([conf.split(".yaml")[0] for conf in cs.list("torch/nn/modules/loss")]) + assert not actual ^ expected From ceada173df1bcc33542ae91305ee9263c888630f Mon Sep 17 00:00:00 2001 From: romesc Date: Fri, 23 Apr 2021 14:58:40 -0700 Subject: [PATCH 6/7] add utils/data to registration --- .../torch/utils/data/__init__.py | 53 +++++++++++++++++++ .../tests/test_register_data.py | 33 ++++++++++++ 2 files changed, 86 insertions(+) create mode 100644 hydra-configs-torch/hydra_configs/torch/utils/data/__init__.py create mode 100644 hydra-configs-torch/tests/test_register_data.py diff --git a/hydra-configs-torch/hydra_configs/torch/utils/data/__init__.py b/hydra-configs-torch/hydra_configs/torch/utils/data/__init__.py new file mode 100644 index 0000000..9c96178 --- /dev/null +++ b/hydra-configs-torch/hydra_configs/torch/utils/data/__init__.py @@ -0,0 +1,53 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# flake8: noqa +from .dataloader import DataLoaderConf +from .dataset import DatasetConf +from .dataset import ChainDatasetConf +from .dataset import ConcatDatasetConf +from .dataset import IterableDatasetConf +from .dataset import TensorDatasetConf +from .dataset import SubsetConf +from .distributed import DistributedSamplerConf +from .sampler import SamplerConf +from .sampler import BatchSamplerConf +from .sampler import RandomSamplerConf +from .sampler import SequentialSamplerConf +from .sampler import SubsetRandomSamplerConf +from .sampler import WeightedRandomSamplerConf +from hydra.core.config_store import ConfigStoreWithProvider + + +def register(): + with ConfigStoreWithProvider("torch") as cs: + cs.store(group="torch/utils/data", name="dataloader", node=DataLoaderConf) + cs.store(group="torch/utils/data", name="dataset", node=DatasetConf) + cs.store(group="torch/utils/data", name="chaindataset", node=ChainDatasetConf) + cs.store(group="torch/utils/data", name="concatdataset", node=ConcatDatasetConf) + cs.store( + group="torch/utils/data", name="iterabledataset", node=IterableDatasetConf + ) + cs.store(group="torch/utils/data", name="tensordataset", node=TensorDatasetConf) + cs.store(group="torch/utils/data", name="subset", node=SubsetConf) + cs.store( + group="torch/utils/data", + name="distributedsampler", + node=DistributedSamplerConf, + ) + cs.store(group="torch/utils/data", name="sampler", node=SamplerConf) + cs.store(group="torch/utils/data", name="batchsampler", node=BatchSamplerConf) + cs.store(group="torch/utils/data", name="randomsampler", node=RandomSamplerConf) + cs.store( + group="torch/utils/data", + name="sequentialsampler", + node=SequentialSamplerConf, + ) + cs.store( + group="torch/utils/data", + name="subsetrandomsampler", + node=SubsetRandomSamplerConf, + ) + cs.store( + group="torch/utils/data", + name="weightedrandomsampler", + node=WeightedRandomSamplerConf, + ) diff --git 
a/hydra-configs-torch/tests/test_register_data.py b/hydra-configs-torch/tests/test_register_data.py new file mode 100644 index 0000000..3d1b208 --- /dev/null +++ b/hydra-configs-torch/tests/test_register_data.py @@ -0,0 +1,33 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from hydra.core.config_store import ConfigStore +import hydra_configs.torch.utils.data + +cs = ConfigStore() + + +# registers all 'base' optimizer configs with configstore instance +hydra_configs.torch.utils.data.register() + +expected = set( + [ + "dataloader", + "dataset", + "chaindataset", + "concatdataset", + "iterabledataset", + "tensordataset", + "subset", + "distributedsampler", + "sampler", + "batchsampler", + "randomsampler", + "sequentialsampler", + "subsetrandomsampler", + "weightedrandomsampler", + ] +) + + +def test_instantiate_classes() -> None: + actual = set([conf.split(".yaml")[0] for conf in cs.list("torch/utils/data")]) + assert not actual ^ expected From fe4114427c1f0376eef25a51f011bfa9e8d8d6f5 Mon Sep 17 00:00:00 2001 From: romesc Date: Fri, 23 Apr 2021 15:15:38 -0700 Subject: [PATCH 7/7] add hierarchical registration call --- hydra-configs-torch/hydra_configs/torch/__init__.py | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 hydra-configs-torch/hydra_configs/torch/__init__.py diff --git a/hydra-configs-torch/hydra_configs/torch/__init__.py b/hydra-configs-torch/hydra_configs/torch/__init__.py new file mode 100644 index 0000000..15eff3d --- /dev/null +++ b/hydra-configs-torch/hydra_configs/torch/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# flake8: noqa +from .optim import register as optim_register +from .nn.modules import register as modules_register +from .utils.data import register as data_register + + +def register(): + optim_register() + modules_register() + data_register()
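
A minimal end-to-end usage sketch of the API this series ends up with, adapted from the examples/register_configs.py file deleted in patch 3 and updated for the hierarchical register() call added in patch 7; the script itself and the choice of the adam config are illustrative only, not part of the patches:

    # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
    import hydra
    from typing import Any
    from omegaconf import OmegaConf

    import hydra_configs.torch

    # registers all 'base' optim, nn.modules, and utils.data configs
    # with the ConfigStore instance (patch 7's hierarchical register())
    hydra_configs.torch.register()


    @hydra.main(config_name="torch/optim/adam")
    def my_app(cfg: Any) -> None:
        print(OmegaConf.to_yaml(cfg))


    if __name__ == "__main__":
        my_app()

Because each register() stores nodes under groups such as "torch/optim" and "torch/nn/modules/loss", the registered configs can also be pulled in from a defaults list or overridden on the command line (e.g. torch/optim=sgd), which is what the tests added in patches 3, 5, and 6 verify via ConfigStore.list().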