[dev] Config Registration Functions in __init__.py #72
Open: romesco wants to merge 7 commits into `main` from `dev/config_registration`.

Commits:
- `e6836be` optim config registration function (romesco)
- `f5dad49` reformat example (romesco)
- `b02b193` add test, remove example (romesco)
- `c95ca8d` use ConfigStoreWithProvider (romesco)
- `004011e` add losses to registration (romesco)
- `ceada17` add utils/data to registration (romesco)
- `fe41144` add hierarchical registration call (romesco)
`hydra-configs-torch/hydra_configs/torch/nn/modules/__init__.py` (85 additions, 0 deletions):

```python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# flake8: noqa
from .loss import BCELossConf
from .loss import BCEWithLogitsLossConf
from .loss import CosineEmbeddingLossConf
from .loss import CTCLossConf
from .loss import L1LossConf
from .loss import HingeEmbeddingLossConf
from .loss import KLDivLossConf
from .loss import MarginRankingLossConf
from .loss import MSELossConf
from .loss import MultiLabelMarginLossConf
from .loss import MultiLabelSoftMarginLossConf
from .loss import MultiMarginLossConf
from .loss import NLLLossConf
from .loss import NLLLoss2dConf
from .loss import PoissonNLLLossConf
from .loss import SmoothL1LossConf
from .loss import SoftMarginLossConf
from .loss import TripletMarginLossConf
from hydra.core.config_store import ConfigStoreWithProvider


def register():
    with ConfigStoreWithProvider("torch") as cs:
        cs.store(group="torch/nn/modules/loss", name="bceloss", node=BCELossConf)
        cs.store(
            group="torch/nn/modules/loss",
            name="bcewithlogitsloss",
            node=BCEWithLogitsLossConf,
        )
        cs.store(
            group="torch/nn/modules/loss",
            name="cosineembeddingloss",
            node=CosineEmbeddingLossConf,
        )
        cs.store(group="torch/nn/modules/loss", name="ctcloss", node=CTCLossConf)
        cs.store(group="torch/nn/modules/loss", name="l1loss", node=L1LossConf)
        cs.store(
            group="torch/nn/modules/loss",
            name="hingeembeddingloss",
            node=HingeEmbeddingLossConf,
        )
        cs.store(group="torch/nn/modules/loss", name="kldivloss", node=KLDivLossConf)
        cs.store(
            group="torch/nn/modules/loss",
            name="marginrankingloss",
            node=MarginRankingLossConf,
        )
        cs.store(group="torch/nn/modules/loss", name="mseloss", node=MSELossConf)
        cs.store(
            group="torch/nn/modules/loss",
            name="multilabelmarginloss",
            node=MultiLabelMarginLossConf,
        )
        cs.store(
            group="torch/nn/modules/loss",
            name="multilabelsoftmarginloss",
            node=MultiLabelSoftMarginLossConf,
        )
        cs.store(
            group="torch/nn/modules/loss",
            name="multimarginloss",
            node=MultiMarginLossConf,
        )
        cs.store(group="torch/nn/modules/loss", name="nllloss", node=NLLLossConf)
        cs.store(group="torch/nn/modules/loss", name="nllloss2d", node=NLLLoss2dConf)
        cs.store(
            group="torch/nn/modules/loss",
            name="poissonnllloss",
            node=PoissonNLLLossConf,
        )
        cs.store(
            group="torch/nn/modules/loss", name="smoothl1loss", node=SmoothL1LossConf
        )
        cs.store(
            group="torch/nn/modules/loss",
            name="softmarginloss",
            node=SoftMarginLossConf,
        )
        cs.store(
            group="torch/nn/modules/loss",
            name="tripletmarginloss",
            node=TripletMarginLossConf,
        )
```
A loss-registration test (37 additions, 0 deletions):
```python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from hydra.core.config_store import ConfigStore
import hydra_configs.torch.nn.modules

cs = ConfigStore()

# registers all 'base' loss configs with configstore instance
hydra_configs.torch.nn.modules.register()

expected = set(
    [
        "bceloss",
        "bcewithlogitsloss",
        "cosineembeddingloss",
        "ctcloss",
        "l1loss",
        "hingeembeddingloss",
        "kldivloss",
        "marginrankingloss",
        "mseloss",
        "multilabelmarginloss",
        "multilabelsoftmarginloss",
        "multimarginloss",
        "nllloss",
        "nllloss2d",
        "poissonnllloss",
        "smoothl1loss",
        "softmarginloss",
        "tripletmarginloss",
    ]
)


def test_instantiate_classes() -> None:
    actual = set([conf.split(".yaml")[0] for conf in cs.list("torch/nn/modules/loss")])
    assert not actual ^ expected
```
An optimizer-registration test (29 additions, 0 deletions):
```python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from hydra.core.config_store import ConfigStore
import hydra_configs.torch.optim

cs = ConfigStore()

# registers all 'base' optimizer configs with configstore instance
hydra_configs.torch.optim.register()

expected = set(
    [
        "adadelta",
        "adagrad",
        "adam",
        "adamw",
        "sparseadam",
        "adamax",
        "asgd",
        "sgd",
        "lbfgs",
        "rprop",
        "rmsprop",
    ]
)


def test_instantiate_classes() -> None:
    actual = set([conf.split(".yaml")[0] for conf in cs.list("torch/optim")])
    assert not actual ^ expected
```
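Both tests hinge on the set symmetric-difference operator: `actual ^ expected` is empty exactly when the two sets match, and otherwise it contains whatever is missing or extra, which makes a failure easy to diagnose. A small stdlib-only sketch of the idiom, including the `.yaml`-suffix stripping used above (the `listed` values are hypothetical store output):

```python
# Registered entries come back from the store as "<name>.yaml", so the tests
# strip the suffix before comparing names.
listed = ["adam.yaml", "sgd.yaml", "rmsprop.yaml"]  # hypothetical cs.list() output
actual = set(conf.split(".yaml")[0] for conf in listed)

expected = {"adam", "sgd", "rmsprop"}

# The symmetric difference is empty iff the sets match exactly.
assert not actual ^ expected

# When they disagree, the symmetric difference names the discrepancy directly.
assert ({"adam", "sgd"} ^ expected) == {"rmsprop"}                  # missing entry
assert ({"adam", "sgd", "rmsprop", "lars"} ^ expected) == {"lars"}  # extra entry
```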
Review discussion:

**Reviewer:** Feels like this would better live inside `loss.py`, and `__init__.py` would just call it from its `register` function.

**romesco:** It makes sense there, but `loss.py` is a configen output. This is manual code.

**Reviewer:** Maybe we should create a convention. At that point, the `__init__.py` file can import them both (and possibly also re-export from both, making them appear like one file).

**romesco:** Not a bad idea. I think we need a convention for handling manual configs anyway.

**Reviewer:** Here is an even better idea: we can add support in configen for custom imports at the top and the bottom of the generated configs, allowing automatic import of user-provided files if they exist.

**romesco:** Could you elaborate on this? Do you mean chaining the 'manual' file to the 'generated' one through an import that configen adds automatically if a field is passed in the configen conf yaml?

**Reviewer:** Not exactly. I mean importing a file with a specific name if it exists. For `generated/foo.py`, you could add manual `foo_header.py` and `foo_footer.py` files that get imported automatically when the user imports `foo.py`.
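The header/footer convention discussed above could be implemented with `importlib`: the generated module probes for a sibling module by name and imports it only when it exists. This is an illustrative sketch of the idea under that assumption, not configen's actual behavior; `import_if_present` is a hypothetical helper.

```python
import importlib
import importlib.util


def import_if_present(module_name):
    """Import module_name if it can be found; return None otherwise.

    Hypothetical helper: configen would emit a call like this at the top
    (foo_header) and bottom (foo_footer) of a generated foo.py, so manual
    companion files are picked up automatically on import.
    """
    if importlib.util.find_spec(module_name) is None:
        return None
    return importlib.import_module(module_name)


# A module that exists (stdlib here, for demonstration) is imported;
# a missing one is skipped silently instead of raising ImportError.
header = import_if_present("json")
footer = import_if_present("foo_footer_that_does_not_exist")
print(header is not None, footer is None)  # prints True True
```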