[dev] Config Registration Functions in __init__.py #72

Open · romesco wants to merge 7 commits into main

Changes from 5 commits
85 changes: 85 additions & 0 deletions hydra-configs-torch/hydra_configs/torch/nn/modules/__init__.py
@@ -0,0 +1,85 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# flake8: noqa
from .loss import BCELossConf
from .loss import BCEWithLogitsLossConf
from .loss import CosineEmbeddingLossConf
from .loss import CTCLossConf
from .loss import L1LossConf
from .loss import HingeEmbeddingLossConf
from .loss import KLDivLossConf
from .loss import MarginRankingLossConf
from .loss import MSELossConf
from .loss import MultiLabelMarginLossConf
from .loss import MultiLabelSoftMarginLossConf
from .loss import MultiMarginLossConf
from .loss import NLLLossConf
from .loss import NLLLoss2dConf
from .loss import PoissonNLLLossConf
from .loss import SmoothL1LossConf
from .loss import SoftMarginLossConf
from .loss import TripletMarginLossConf
Comment on lines +3 to +20
Contributor
feels like this would be better inside loss.py, and __init__.py would just call it from its register function.
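
A minimal sketch of the suggested shape (hypothetical — in this PR, loss.py has no register function of its own):

# __init__.py -- hypothetical delegation, not what this PR does
from . import loss


def register():
    loss.register()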

Contributor Author
It makes sense there, but loss.py is a configen output. This is manual code.

Contributor
maybe we should create a convention:

loss.py # generated
loss_handcoded.py # manual additions.

At this point, the __init__.py file can import them both (and possibly also export from both, making them appear like one file).
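
A sketch of what __init__.py might look like under that convention (loss_handcoded.py is the proposed name, not an existing file):

# __init__.py -- hypothetical, under the proposed convention
from .loss import *  # generated configs

try:
    from .loss_handcoded import *  # manual additions, if the file exists
except ImportError:
    pass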

Contributor Author

not a bad idea. I think we need a convention for handling manual configs anyways.

Contributor

here is an even better idea.
We can add support in configen for custom imports at the top and the bottom of the generated configs, allowing automatic import of user-provided files if they exist.

Contributor Author (@romesco, Apr 26, 2021)

Could you elaborate on this?

Do you mean chaining the 'manual' file to the 'generated' one through an import that configen adds automatically if a field is passed in the configen conf yaml?

Contributor (@omry, Apr 27, 2021)

not exactly.
I mean importing a file with a specific name, if it exists.

# generated/foo.py

try:
    from . import foo_header
except ImportError:
    pass

# generated code here

try:
    from . import foo_footer
except ImportError:
    pass

This way, you can add manual foo_footer.py and foo_header.py files that will get imported automatically when the user imports foo.py.
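
A hypothetical foo_footer.py (name and contents illustrative only) would then just be an ordinary module that the generated try/except pulls in:

# foo_footer.py -- hypothetical manual companion to the generated foo.py
from dataclasses import dataclass


@dataclass
class FooManualConf:
    # hand-written config that the generator cannot express
    verbose: bool = False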

from hydra.core.config_store import ConfigStoreWithProvider


def register():
    with ConfigStoreWithProvider("torch") as cs:
        cs.store(group="torch/nn/modules/loss", name="bceloss", node=BCELossConf)
        cs.store(
            group="torch/nn/modules/loss",
            name="bcewithlogitsloss",
            node=BCEWithLogitsLossConf,
        )
        cs.store(
            group="torch/nn/modules/loss",
            name="cosineembeddingloss",
            node=CosineEmbeddingLossConf,
        )
        cs.store(group="torch/nn/modules/loss", name="ctcloss", node=CTCLossConf)
        cs.store(group="torch/nn/modules/loss", name="l1loss", node=L1LossConf)
        cs.store(
            group="torch/nn/modules/loss",
            name="hingeembeddingloss",
            node=HingeEmbeddingLossConf,
        )
        cs.store(group="torch/nn/modules/loss", name="kldivloss", node=KLDivLossConf)
        cs.store(
            group="torch/nn/modules/loss",
            name="marginrankingloss",
            node=MarginRankingLossConf,
        )
        cs.store(group="torch/nn/modules/loss", name="mseloss", node=MSELossConf)
        cs.store(
            group="torch/nn/modules/loss",
            name="multilabelmarginloss",
            node=MultiLabelMarginLossConf,
        )
        cs.store(
            group="torch/nn/modules/loss",
            name="multilabelsoftmarginloss",
            node=MultiLabelSoftMarginLossConf,
        )
        cs.store(
            group="torch/nn/modules/loss",
            name="multimarginloss",
            node=MultiMarginLossConf,
        )
        cs.store(group="torch/nn/modules/loss", name="nllloss", node=NLLLossConf)
        cs.store(group="torch/nn/modules/loss", name="nllloss2d", node=NLLLoss2dConf)
        cs.store(
            group="torch/nn/modules/loss",
            name="poissonnllloss",
            node=PoissonNLLLossConf,
        )
        cs.store(
            group="torch/nn/modules/loss", name="smoothl1loss", node=SmoothL1LossConf
        )
        cs.store(
            group="torch/nn/modules/loss",
            name="softmarginloss",
            node=SoftMarginLossConf,
        )
        cs.store(
            group="torch/nn/modules/loss",
            name="tripletmarginloss",
            node=TripletMarginLossConf,
        )
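
For reference, a minimal sketch of exercising register() from user code (ConfigStore is a singleton, so register() populates the same store the application reads; the group and names are the ones stored above):

from hydra.core.config_store import ConfigStore
import hydra_configs.torch.nn.modules

hydra_configs.torch.nn.modules.register()
cs = ConfigStore.instance()
print(sorted(cs.list("torch/nn/modules/loss")))
# expected to include e.g. 'bceloss.yaml', 'mseloss.yaml', ...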
18 changes: 18 additions & 0 deletions hydra-configs-torch/hydra_configs/torch/optim/__init__.py
@@ -1,3 +1,4 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# flake8: noqa
# Mirrors torch/optim __init__ to allow for symmetric import structure
from .adadelta import AdadeltaConf
@@ -14,6 +15,23 @@
from .lbfgs import LBFGSConf
from . import lr_scheduler

from hydra.core.config_store import ConfigStoreWithProvider


def register():
    with ConfigStoreWithProvider("torch") as cs:
        cs.store(group="torch/optim", name="adadelta", node=AdadeltaConf)
        cs.store(group="torch/optim", name="adagrad", node=AdagradConf)
        cs.store(group="torch/optim", name="adam", node=AdamConf)
        cs.store(group="torch/optim", name="adamw", node=AdamWConf)
        cs.store(group="torch/optim", name="sparseadam", node=SparseAdamConf)
        cs.store(group="torch/optim", name="adamax", node=AdamaxConf)
        cs.store(group="torch/optim", name="asgd", node=ASGDConf)
        cs.store(group="torch/optim", name="sgd", node=SGDConf)
        cs.store(group="torch/optim", name="lbfgs", node=LBFGSConf)
        cs.store(group="torch/optim", name="rprop", node=RpropConf)
        cs.store(group="torch/optim", name="rmsprop", node=RMSpropConf)


del adadelta
del adagrad
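The same pattern applies on the consuming side; a minimal usage sketch, assuming the generated AdamConf carries the usual _target_ pointing at torch.optim.Adam:

# Sketch: instantiating an optimizer from a registered config (illustrative).
import torch
from hydra.utils import instantiate
from hydra_configs.torch.optim import AdamConf

model = torch.nn.Linear(4, 2)
optimizer = instantiate(AdamConf(lr=1e-3), params=model.parameters())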
37 changes: 37 additions & 0 deletions hydra-configs-torch/tests/test_register_losses.py
@@ -0,0 +1,37 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from hydra.core.config_store import ConfigStore
import hydra_configs.torch.nn.modules

cs = ConfigStore()


# registers all 'base' loss configs with configstore instance
hydra_configs.torch.nn.modules.register()

expected = set(
    [
        "bceloss",
        "bcewithlogitsloss",
        "cosineembeddingloss",
        "ctcloss",
        "l1loss",
        "hingeembeddingloss",
        "kldivloss",
        "marginrankingloss",
        "mseloss",
        "multilabelmarginloss",
        "multilabelsoftmarginloss",
        "multimarginloss",
        "nllloss",
        "nllloss2d",
        "poissonnllloss",
        "smoothl1loss",
        "softmarginloss",
        "tripletmarginloss",
    ]
)


def test_instantiate_classes() -> None:
    actual = set([conf.split(".yaml")[0] for conf in cs.list("torch/nn/modules/loss")])
    # empty symmetric difference <=> the two sets match exactly
    assert not actual ^ expected
29 changes: 29 additions & 0 deletions hydra-configs-torch/tests/test_register_optimizers.py
@@ -0,0 +1,29 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from hydra.core.config_store import ConfigStore
import hydra_configs.torch.optim

cs = ConfigStore()

# registers all 'base' optimizer configs with configstore instance
hydra_configs.torch.optim.register()

expected = set(
    [
        "adadelta",
        "adagrad",
        "adam",
        "adamw",
        "sparseadam",
        "adamax",
        "asgd",
        "sgd",
        "lbfgs",
        "rprop",
        "rmsprop",
    ]
)


def test_instantiate_classes() -> None:
    actual = set([conf.split(".yaml")[0] for conf in cs.list("torch/optim")])
    assert not actual ^ expected