Skip to content

Commit

Permalink
Add mart.nn.Get() to extract a value from the kwargs dict. (#251)
Browse files Browse the repository at this point in the history
* Add batch_c15n for [0,1] image input and imagenet-normalized input.

* Turn off inference mode before creating perturbations.

* Switch to training mode before running LightningModule.training_step().

* Add utils for config instantiation.

* Add mart.utils.Get() to extract a value from kwargs dict.

* Comment

* Clean up.

* Move to mart.nn.Get().

* Add support to multi-level dicts.
  • Loading branch information
mzweilin authored May 16, 2024
1 parent 07409b4 commit 3689976
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion mart/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import torch

__all__ = ["GroupNorm32", "SequentialDict", "ReturnKwargs", "CallWith", "Sum"]
__all__ = ["GroupNorm32", "SequentialDict", "ReturnKwargs", "CallWith", "Sum", "Get"]

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -300,3 +300,19 @@ def __init__(self):

def forward(self, *args):
return sum(args)


class Get:
"""Get a value from the kwargs dictionary by key.
The key can be a path to a nested dictionary, concatenated by dots. For example,
`Get(key="a.b")(a={"b": 1}) == 1`.
"""

def __init__(self, key):
self.key = key

def __call__(self, **kwargs):
# Add support to nested dicts.
kwargs = DotDict(kwargs)
return kwargs[self.key]

0 comments on commit 3689976

Please sign in to comment.