[WIP][hydra-configs-torchvision] v0.7 models #65
base: main

Changes from all commits: ae95a69, 266183e, aa906b1, a16b1b9, cf3cf9e, 4ef4b03
configen configuration (new file, `@@ -0,0 +1,58 @@`):

```yaml
defaults:
  - configen_schema

configen:
  # output directory
  output_dir: ${hydra:runtime.cwd}

  header: |
    # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
    #
    # Generated by configen, do not edit.
    # See https://github.com/facebookresearch/hydra/tree/master/tools/configen
    # fmt: off
    # isort:skip_file
    # flake8: noqa

  module_path_pattern: "hydra_configs/{{module_path}}.py"

  # list of modules to generate configs for
  modules:
    - name: torchvision.datasets.vision
      classes:
        - VisionDataset
        - StandardTransform

    - name: torchvision.datasets.mnist
      # mnist datasets
      classes:
        - MNIST
        - FashionMNIST
        - KMNIST
        # TODO: The following need to be manually created for torchvision==0.7
        # - EMNIST
        # - QMNIST

    - name: torchvision.models.alexnet
      classes:
        - AlexNet

    - name: torchvision.models.densenet
      classes:
        - DenseNet

    - name: torchvision.models.googlenet
      classes:
        - GoogLeNet

    - name: torchvision.models.mnasnet
      classes:
        - MNASNet

    - name: torchvision.models.squeezenet
      classes:
        - SqueezeNet

    - name: torchvision.models.resnet
      classes:
        - ResNet
```
`hydra_configs/torchvision/models/alexnet.py` (new file, `@@ -0,0 +1,16 @@`):

```python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#
# Generated by configen, do not edit.
# See https://github.com/facebookresearch/hydra/tree/master/tools/configen
# fmt: off
# isort:skip_file
# flake8: noqa

from dataclasses import dataclass, field
from typing import Any


@dataclass
class AlexNetConf:
    _target_: str = "torchvision.models.alexnet.AlexNet"
    num_classes: Any = 1000
```
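Each generated dataclass carries a `_target_` string that `hydra.utils.instantiate` resolves to a class and calls with the remaining fields as keyword arguments. A minimal stdlib-only sketch of that mechanism (`simple_instantiate` is a hypothetical name, and the real implementation handles much more, e.g. nested configs); it is demonstrated with `datetime.date` so it runs without torchvision installed:

```python
import importlib
from typing import Any


def simple_instantiate(cfg: dict, **passthrough: Any) -> Any:
    """Minimal sketch of hydra.utils.instantiate's _target_ handling."""
    params = dict(cfg)               # don't mutate the caller's config
    target = params.pop("_target_")  # e.g. "torchvision.models.alexnet.AlexNet"
    module_path, _, class_name = target.rpartition(".")
    cls = getattr(importlib.import_module(module_path), class_name)
    params.update(passthrough)       # passthrough kwargs win, as in the tests below
    return cls(**params)


# demo with a stdlib target so the sketch runs without torchvision
cfg = {"_target_": "datetime.date", "year": 2020, "month": 1, "day": 1}
print(simple_instantiate(cfg))  # → 2020-01-01
```

With `AlexNetConf`, the same flow would resolve `torchvision.models.alexnet.AlexNet` and call it with `num_classes=1000`.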
`hydra_configs/torchvision/models/densenet.py` (new file, `@@ -0,0 +1,22 @@`):

```python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#
# Generated by configen, do not edit.
# See https://github.com/facebookresearch/hydra/tree/master/tools/configen
# fmt: off
# isort:skip_file
# flake8: noqa

from dataclasses import dataclass, field
from typing import Any


@dataclass
class DenseNetConf:
    _target_: str = "torchvision.models.densenet.DenseNet"
    growth_rate: Any = 32
    block_config: Any = (6, 12, 24, 16)
    num_init_features: Any = 64
    bn_size: Any = 4
    drop_rate: Any = 0
    num_classes: Any = 1000
    memory_efficient: Any = False
```
`hydra_configs/torchvision/models/googlenet.py` (new file, `@@ -0,0 +1,20 @@`):

```python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#
# Generated by configen, do not edit.
# See https://github.com/facebookresearch/hydra/tree/master/tools/configen
# fmt: off
# isort:skip_file
# flake8: noqa

from dataclasses import dataclass, field
from typing import Any


@dataclass
class GoogLeNetConf:
    _target_: str = "torchvision.models.googlenet.GoogLeNet"
    num_classes: Any = 1000
    aux_logits: Any = True
    transform_input: Any = False
    init_weights: Any = None
    blocks: Any = None
```
`hydra_configs/torchvision/models/mnasnet.py` (new file, `@@ -0,0 +1,19 @@`):

```python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#
# Generated by configen, do not edit.
# See https://github.com/facebookresearch/hydra/tree/master/tools/configen
# fmt: off
# isort:skip_file
# flake8: noqa

from dataclasses import dataclass, field
from omegaconf import MISSING
from typing import Any


@dataclass
class MNASNetConf:
    _target_: str = "torchvision.models.mnasnet.MNASNet"
    alpha: Any = MISSING
    num_classes: Any = 1000
    dropout: Any = 0.2
```
`hydra_configs/torchvision/models/resnet.py` (new file, `@@ -0,0 +1,24 @@`):

```python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#
# Generated by configen, do not edit.
# See https://github.com/facebookresearch/hydra/tree/master/tools/configen
# fmt: off
# isort:skip_file
# flake8: noqa

from dataclasses import dataclass, field
from omegaconf import MISSING
from typing import Any


@dataclass
class ResNetConf:
    _target_: str = "torchvision.models.resnet.ResNet"
    block: Any = MISSING
    layers: Any = MISSING
    num_classes: Any = 1000
    zero_init_residual: Any = False
    groups: Any = 1
    width_per_group: Any = 64
    replace_stride_with_dilation: Any = None
    norm_layer: Any = None
```
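`ResNetConf` marks its required constructor arguments (`block`, `layers`) with omegaconf's `MISSING` sentinel, which is the literal string `"???"`; OmegaConf refuses to resolve such fields until a value is supplied. A stdlib-only sketch of that behavior, using a stand-in sentinel and a trimmed-down `ResNetConfSketch` (hypothetical names, no omegaconf required):

```python
from dataclasses import dataclass, fields
from typing import Any, List

MISSING = "???"  # omegaconf.MISSING is this exact sentinel string


@dataclass
class ResNetConfSketch:
    # trimmed-down stand-in for the generated ResNetConf
    _target_: str = "torchvision.models.resnet.ResNet"
    block: Any = MISSING
    layers: Any = MISSING
    num_classes: Any = 1000


def missing_fields(cfg: Any) -> List[str]:
    """Names of fields still set to the MISSING sentinel."""
    return [f.name for f in fields(cfg) if getattr(cfg, f.name) == MISSING]


cfg = ResNetConfSketch()
print(missing_fields(cfg))  # → ['block', 'layers']

cfg.layers = [2, 2, 2, 2]   # e.g. resnet18's per-stage block counts
print(missing_fields(cfg))  # → ['block']
```

This is why the test file below must supply `layers` via the config and `block` as a passthrough kwarg before instantiating `ResNet`.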
`hydra_configs/torchvision/models/squeezenet.py` (new file, `@@ -0,0 +1,17 @@`):

```python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#
# Generated by configen, do not edit.
# See https://github.com/facebookresearch/hydra/tree/master/tools/configen
# fmt: off
# isort:skip_file
# flake8: noqa

from dataclasses import dataclass, field
from typing import Any


@dataclass
class SqueezeNetConf:
    _target_: str = "torchvision.models.squeezenet.SqueezeNet"
    version: Any = "1_0"
    num_classes: Any = 1000
```
Test suite (new file, `@@ -0,0 +1,89 @@`):

```python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import pytest
from hydra.utils import get_class, instantiate
from omegaconf import OmegaConf

import torchvision.models as models

from torchvision.models.resnet import BasicBlock
from torchvision.models.resnet import Bottleneck
from typing import Any

bb = BasicBlock(10, 10)
mnasnet_dict = {"alpha": 1.0, "num_classes": 1000}


@pytest.mark.parametrize(
    "modulepath, classname, cfg, passthrough_args, passthrough_kwargs, expected",
    [
        pytest.param(
            "models.alexnet", "AlexNet", {}, [], {},
            models.AlexNet(),
            id="AlexNetConf",
        ),
        pytest.param(
            "models.resnet", "ResNet",
            {"layers": [2, 2, 2, 2]}, [], {"block": Bottleneck},
            models.ResNet(block=Bottleneck, layers=[2, 2, 2, 2]),
            id="ResNetConf",
        ),
        pytest.param(
            "models.densenet", "DenseNet", {}, [], {},
            models.DenseNet(),
            id="DenseNetConf",
        ),
        pytest.param(
            "models.squeezenet", "SqueezeNet", {}, [], {},
            models.SqueezeNet(),
            id="SqueezeNetConf",
        ),
        pytest.param(
            "models.mnasnet", "MNASNet", {"alpha": 1.0}, [], {},
            models.MNASNet(alpha=1.0),
            id="MNASNetConf",
        ),
        pytest.param(
            "models.googlenet", "GoogLeNet", {}, [], {},
            models.GoogLeNet(),
            id="GoogleNetConf",
        ),
    ],
)
def test_instantiate_classes(
    modulepath: str,
    classname: str,
    cfg: Any,
    passthrough_args: Any,
    passthrough_kwargs: Any,
    expected: Any,
) -> None:
    full_class = f"hydra_configs.torchvision.{modulepath}.{classname}Conf"
    schema = OmegaConf.structured(get_class(full_class))
    cfg = OmegaConf.merge(schema, cfg)
    obj = instantiate(cfg, *passthrough_args, **passthrough_kwargs)

    assert isinstance(obj, type(expected))
```

Review comment (on the module-level `bb` and `mnasnet_dict` variables): What is this for? Looks like it's unused. You can probably remove it.
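The test above only asserts `isinstance(obj, type(expected))`. One way to deepen it is to also verify that each configured value actually landed on the instantiated object. A sketch under stated assumptions: `assert_fields_applied` is a hypothetical helper, and `SimpleNamespace` stands in for a real model so the example runs without torch:

```python
from types import SimpleNamespace
from typing import Any, Dict


def assert_fields_applied(obj: Any, expected: Dict[str, Any]) -> None:
    # Deeper check than isinstance: each configured value must be
    # present on the instantiated object under the same attribute name.
    for name, value in expected.items():
        actual = getattr(obj, name)
        assert actual == value, f"{name}: {actual!r} != {value!r}"


# stand-in for an instantiated model
model = SimpleNamespace(num_classes=10, aux_logits=False)
assert_fields_applied(model, {"num_classes": 10, "aux_logits": False})
```

In practice, torchvision models do not always store constructor arguments as same-named attributes, so per-model checks (e.g. classifier output dimensionality for `num_classes`) may be needed instead.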
`requirements.txt`:

```diff
@@ -2,3 +2,4 @@ git+https://github.com/facebookresearch/hydra#subdirectory=tools/configen
 git+https://github.com/facebookresearch/hydra
 torch==1.6.0
 torchvision==0.7.0
+scipy
```

Review comments:

- We shouldn't make the primary `requirements.txt` the union of the dependencies of all subprojects.
- Yeah, we need to migrate this out of here. I think we should generally have per-project dependencies, both for using the library and for running tests / doing dev. It's not clear that someone who wants to use torchvision configs will ever use a model that depends on scipy, for example, so we shouldn't force them to install it. But if they plan to do dev / run tests, then it makes sense.
- Otherwise I think this PR looks good. Maybe after the minor changes we can merge, and then I can follow up with a reorganization of how we manage dependencies across projects in a new PR.
- The tests seem a bit too shallow to me. Is it really the best we can do to verify that the instantiated objects are good?
- I agree that the tests are minimal. @shivamdb, would it be possible to check that both:
- About the dependencies:
- Yes, that's the plan. We can have
- As long as we have a handful of projects, we can also have unified dev dependencies for all of them.
- Addressing the above comments in #59
Review comments (on the test file's imports):

- I might suggest tidying up the imports here slightly. Not the end of the world, but I think it would be better to follow PEP 8:
- I'm probably also guilty of this in some of the other tests, but I want to continue improving the standards.
- I suggest integrating with isort.