-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve ludwig feature dict #3904
base: master
Are you sure you want to change the base?
Changes from 4 commits
c27d50d
f8d7982
9dc7ef8
4ca184a
c5dcda9
3edaa08
cd36fed
4a9aa22
326c40d
1c685cf
23a2da7
acfa198
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -65,13 +65,13 @@ def __iter__(self) -> None: | |
return iter(self.obj.keys()) | ||
|
||
def keys(self) -> List[str]: | ||
return self.obj.keys() | ||
return self.obj.key_list() | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @dennisrall Do you think it might be simpler to retain the coding pattern of returning |
||
|
||
def values(self) -> List[torch.nn.Module]: | ||
return self.obj.values() | ||
return self.obj.value_list() | ||
|
||
def items(self) -> List[Tuple[str, torch.nn.Module]]: | ||
return self.obj.items() | ||
return self.obj.item_list() | ||
|
||
def update(self, modules: Dict[str, torch.nn.Module]) -> None: | ||
self.obj.update(modules) | ||
|
@@ -148,7 +148,8 @@ def __init__( | |
) | ||
|
||
# Extract the decoder object for the forward pass | ||
self._output_feature_decoder = ModuleWrapper(self.output_features.items()[0][1]) | ||
decoder = next(iter(self.output_features.values())) | ||
self._output_feature_decoder = ModuleWrapper(decoder) | ||
|
||
self.attention_masks = None | ||
|
||
|
@@ -401,7 +402,7 @@ def _unpack_inputs( | |
else: | ||
targets = None | ||
|
||
assert list(inputs.keys()) == self.input_features.keys() | ||
assert list(inputs.keys()) == list(self.input_features.keys()) | ||
|
||
input_ids = self.get_input_ids(inputs) | ||
target_ids = self.get_target_ids(targets) if targets else None | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,114 @@ | |
from ludwig.features import feature_utils | ||
|
||
|
||
@pytest.fixture | ||
def to_module() -> torch.nn.Module: | ||
return torch.nn.Module() | ||
|
||
|
||
@pytest.fixture | ||
def type_module() -> torch.nn.Module: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @dennisrall This fixture and the There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks for your feedback. I added a docstring for both fixtures and also for the |
||
return torch.nn.Module() | ||
|
||
|
||
@pytest.fixture | ||
def feature_dict(to_module: torch.nn.Module, type_module: torch.nn.Module) -> feature_utils.LudwigFeatureDict: | ||
fdict = feature_utils.LudwigFeatureDict() | ||
fdict.set("to", to_module) | ||
fdict["type"] = type_module | ||
return fdict | ||
|
||
|
||
def test_ludwig_feature_dict_get( | ||
feature_dict: feature_utils.LudwigFeatureDict, to_module: torch.nn.Module, type_module: torch.nn.Module | ||
): | ||
assert feature_dict["to"] == to_module | ||
assert feature_dict.get("type") == type_module | ||
assert feature_dict.get("other_key", default=None) is None | ||
|
||
|
||
def test_ludwig_feature_dict_keys(feature_dict: feature_utils.LudwigFeatureDict): | ||
assert list(feature_dict.keys()) == ["to", "type"] | ||
assert feature_dict.key_list() == ["to", "type"] | ||
|
||
|
||
def test_ludwig_feature_dict_values( | ||
feature_dict: feature_utils.LudwigFeatureDict, to_module: torch.nn.Module, type_module: torch.nn.Module | ||
): | ||
assert list(feature_dict.values()) == [to_module, type_module] | ||
assert feature_dict.value_list() == [to_module, type_module] | ||
|
||
|
||
def test_ludwig_feature_dict_items( | ||
feature_dict: feature_utils.LudwigFeatureDict, to_module: torch.nn.Module, type_module: torch.nn.Module | ||
): | ||
assert list(feature_dict.items()) == [("to", to_module), ("type", type_module)] | ||
assert feature_dict.item_list() == [("to", to_module), ("type", type_module)] | ||
|
||
|
||
def test_ludwig_feature_dict_iter(feature_dict: feature_utils.LudwigFeatureDict): | ||
assert list(iter(feature_dict)) == ["to", "type"] | ||
assert list(feature_dict) == ["to", "type"] | ||
|
||
|
||
def test_ludwig_feature_dict_len(feature_dict: feature_utils.LudwigFeatureDict): | ||
assert len(feature_dict) == 2 | ||
|
||
|
||
def test_ludwig_feature_dict_contains(feature_dict: feature_utils.LudwigFeatureDict): | ||
assert "to" in feature_dict and "type" in feature_dict | ||
|
||
|
||
def test_ludwig_feature_dict_eq(feature_dict: feature_utils.LudwigFeatureDict): | ||
other_dict = feature_utils.LudwigFeatureDict() | ||
assert not feature_dict == other_dict | ||
other_dict.update(feature_dict.item_list()) | ||
assert feature_dict == other_dict | ||
|
||
|
||
def test_ludwig_feature_dict_update( | ||
feature_dict: feature_utils.LudwigFeatureDict, to_module: torch.nn.Module, type_module: torch.nn.Module | ||
): | ||
feature_dict.update({"to": torch.nn.Module(), "new": torch.nn.Module()}) | ||
assert len(feature_dict) == 3 | ||
assert not feature_dict.get("to") == to_module | ||
assert feature_dict.get("type") == type_module | ||
|
||
|
||
def test_ludwig_feature_dict_del(feature_dict: feature_utils.LudwigFeatureDict): | ||
del feature_dict["to"] | ||
assert len(feature_dict) == 1 | ||
|
||
|
||
def test_ludwig_feature_dict_clear(feature_dict: feature_utils.LudwigFeatureDict): | ||
feature_dict.clear() | ||
assert len(feature_dict) == 0 | ||
|
||
|
||
def test_ludwig_feature_dict_pop(feature_dict: feature_utils.LudwigFeatureDict, type_module: torch.nn.Module): | ||
assert feature_dict.pop("type") == type_module | ||
assert len(feature_dict) == 1 | ||
assert feature_dict.pop("type", default=None) is None | ||
|
||
|
||
def test_ludwig_feature_dict_popitem(feature_dict: feature_utils.LudwigFeatureDict, to_module: torch.nn.Module): | ||
assert feature_dict.popitem() == ("to", to_module) | ||
assert len(feature_dict) == 1 | ||
|
||
|
||
def test_ludwig_feature_dict_setdefault(feature_dict: feature_utils.LudwigFeatureDict, to_module: torch.nn.Module): | ||
assert feature_dict.setdefault("to") == to_module | ||
assert feature_dict.get("other_key") is None | ||
|
||
|
||
@pytest.mark.parametrize("name", ["to", "type", "foo", "foo.bar"]) | ||
def test_name_to_module_dict_key(name: str): | ||
key = feature_utils.get_module_dict_key_from_name(name) | ||
assert key != name | ||
assert "." not in key | ||
assert feature_utils.get_name_from_module_dict_key(key) == name | ||
|
||
|
||
def test_ludwig_feature_dict(): | ||
feature_dict = feature_utils.LudwigFeatureDict() | ||
|
||
|
@@ -15,10 +123,10 @@ def test_ludwig_feature_dict(): | |
feature_dict.set("type", type_module) | ||
|
||
assert iter(feature_dict) is not None | ||
assert next(feature_dict) is not None | ||
# assert next(feature_dict) is not None | ||
assert len(feature_dict) == 2 | ||
assert feature_dict.keys() == ["to", "type"] | ||
assert feature_dict.items() == [("to", to_module), ("type", type_module)] | ||
assert feature_dict.key_list() == ["to", "type"] | ||
assert feature_dict.item_list() == [("to", to_module), ("type", type_module)] | ||
assert feature_dict.get("to"), to_module | ||
|
||
feature_dict.update({"to_empty": torch.nn.Module()}) | ||
|
@@ -34,8 +142,8 @@ def test_ludwig_feature_dict_with_periods(): | |
|
||
feature_dict.set("to.", to_module) | ||
|
||
assert feature_dict.keys() == ["to."] | ||
assert feature_dict.items() == [("to.", to_module)] | ||
assert feature_dict.key_list() == ["to."] | ||
assert feature_dict.item_list() == [("to.", to_module)] | ||
assert feature_dict.get("to.") == to_module | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@dennisrall Thank you for incorporating my previous suggestion -- I think that part looks clean now. Thank you for the idea and the implementation!
For this one, I am not sure if the benefits due to adding the MutableMapping subclassing justify taking the risk brought about the multiple inheritance. Do the test cover all the eventualities that might happen with this change?
Thank you. /cc @justinxzhao @arnavgarg1 @Infernaught
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No problem, I was just playing around a bit.
I can't think of any problems about the multiple inheritance, but you know the code better than me 😉
It is also possible to remove the
MutableMapping
inheritance and implement the other methods by hand. But I think this way it is a bit cleaner, if it doesn't cause any problems...