Skip to content

Commit

Permalink
Recursive instantiation support
Browse files Browse the repository at this point in the history
  • Loading branch information
omry committed Sep 28, 2020
1 parent e7beb13 commit f386932
Show file tree
Hide file tree
Showing 31 changed files with 1,078 additions and 508 deletions.
10 changes: 10 additions & 0 deletions examples/instantiate/docs_example/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
trainer:
_target_: my_app.Trainer
optimizer:
_target_: my_app.Optimizer
algo: SGD
lr: 0.01
dataset:
_target_: my_app.Dataset
name: Imagenet
path: /datasets/imagenet
83 changes: 83 additions & 0 deletions examples/instantiate/docs_example/my_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from omegaconf import DictConfig

import hydra
from hydra.utils import instantiate


class Optimizer:
algo: str
lr: float

def __init__(self, algo: str, lr: float) -> None:
self.algo = algo
self.lr = lr

def __repr__(self) -> str:
return f"Optimizer(algo={self.algo},lr={self.lr})"


class Dataset:
name: str
path: str

def __init__(self, name: str, path: str) -> None:
self.name = name
self.path = path

def __repr__(self) -> str:
return f"Dataset(name={self.name}, path={self.path})"


class Trainer:
def __init__(self, optimizer: Optimizer, dataset: Dataset) -> None:
self.optimizer = optimizer
self.dataset = dataset

def __repr__(self) -> str:
return f"Trainer(\n optimizer={self.optimizer},\n dataset={self.dataset}\n)"


@hydra.main(config_name="config")
def my_app(cfg: DictConfig) -> None:
optimizer = instantiate(cfg.trainer.optimizer)
print(optimizer)
# Optimizer(algo=SGD,lr=0.01)

# override parameters on the call-site
optimizer = instantiate(cfg.trainer.optimizer, lr=0.2)
print(optimizer)
# Optimizer(algo=SGD,lr=0.2)

# recursive instantiation
trainer = instantiate(cfg.trainer)
print(trainer)
# Trainer(
# optimizer=Optimizer(algo=SGD,lr=0.01),
# dataset=Dataset(name=Imagenet, path=/datasets/imagenet)
# )

# override nested parameters from the call-site
trainer = instantiate(
cfg.trainer,
optimizer={"lr": 0.3},
dataset={"name": "cifar10", "path": "/datasets/cifar10"},
)
print(trainer)
# Trainer(
# optimizer=Optimizer(algo=SGD,lr=0.3),
# dataset=Dataset(name=cifar10, path=/datasets/cifar10)
# )

# non recursive instantiation
optimizer = instantiate(cfg.trainer, _recursive_=False)
print(optimizer)
# Trainer(
# optimizer={'_target_': 'my_app.Optimizer', 'algo': 'SGD', 'lr': 0.01},
# dataset={'_target_': 'my_app.Dataset', 'name': 'Imagenet', 'path': '/datasets/imagenet'}
# )


if __name__ == "__main__":
my_app()
File renamed without changes.
19 changes: 19 additions & 0 deletions examples/instantiate/object_recursive/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
car:
_target_: my_app.Car
driver:
_target_: my_app.Driver
name: James Bond
age: 7
wheels:
- _target_: my_app.Wheel
radius: 20
width: 1
- _target_: my_app.Wheel
radius: 20
width: 1
- _target_: my_app.Wheel
radius: 20
width: 1
- _target_: my_app.Wheel
radius: 20
width: 1
38 changes: 38 additions & 0 deletions examples/instantiate/object_recursive/my_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import List

from omegaconf import DictConfig

import hydra
from hydra.utils import instantiate


class Driver:
def __init__(self, name: str, age: int) -> None:
self.name = name
self.age = age


class Wheel:
def __init__(self, radius: int, width: int) -> None:
self.radius = radius
self.width = width


class Car:
def __init__(self, driver: Driver, wheels: List[Wheel]):
self.driver = driver
self.wheels = wheels

def drive(self) -> None:
print(f"Driver : {self.driver.name}, {len(self.wheels)} wheels")


@hydra.main(config_name="config")
def my_app(cfg: DictConfig) -> None:
car: Car = instantiate(cfg.car)
car.drive()


if __name__ == "__main__":
my_app()
File renamed without changes.
10 changes: 10 additions & 0 deletions examples/instantiate/schema_recursive/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
tree:
value: 1
left:
value: 20
right:
value: 30
right:
value: 300
left:
value: 400
56 changes: 56 additions & 0 deletions examples/instantiate/schema_recursive/my_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from dataclasses import dataclass
from typing import Optional

from omegaconf import MISSING

import hydra
from hydra.core.config_store import ConfigStore
from hydra.utils import instantiate


class Tree:
def __init__(
self, value: int, left: Optional["Tree"] = None, right: Optional["Tree"] = None
) -> None:
self.value = value
self.left = left
self.right = right


@dataclass
class TreeConf:
_target_: str = "my_app.Tree"
value: int = MISSING
left: Optional["TreeConf"] = None
right: Optional["TreeConf"] = None


# we will populate the tree from the config file with the matching name
@dataclass
class Config:
tree: TreeConf = MISSING


cs = ConfigStore.instance()
cs.store(name="config", node=Config)


# pretty print utility
def pretty_print(tree: Tree, name: str = "root", depth: int = 0) -> None:
pad = " " * depth * 2
print(f"{pad}{name}({tree.value})")
if tree.left is not None:
pretty_print(tree.left, name="left", depth=depth + 1)
if tree.right is not None:
pretty_print(tree.right, name="right", depth=depth + 1)


@hydra.main(config_name="config")
def my_app(cfg: Config) -> None:
tree: Tree = instantiate(cfg.tree)
pretty_print(tree)


if __name__ == "__main__":
my_app()
Loading

0 comments on commit f386932

Please sign in to comment.