Skip to content

Commit

Permalink
Fix bug where custom environment plugins were lost on dataset merge (#1582)
Browse files Browse the repository at this point in the history

I ran into an issue where plugins loaded into the `Environment` were not
retained after merging. It turned out that a generator was being passed to
[Environment.merge](https://github.com/openvinotoolkit/datumaro/blob/30b1add52d6fe458dba5e32e1f63ffeea999ad7b/src/datumaro/components/environment.py#L271)
instead of the expected `Sequence`. As a result, custom plugins were
never registered in the merged environment.

Reproducible example:
```python
# Minimal reproduction: a plugin registered on a custom Environment
# should still be available on the environment of the merged dataset.
from datumaro.components.dataset import Dataset
from datumaro.components.environment import Environment
from datumaro.components.exporter import Exporter
from datumaro.components.hl_ops import HLOps
from datumaro.components.media import Image


# An empty Exporter subclass used as a stand-in for a custom plugin.
class MyPlugin(Exporter):
    pass


# Register the custom plugin under the name 'my_plugin'.
environment = Environment()
environment.exporters.register('my_plugin', MyPlugin)

# Merge a dataset that uses this environment with a clone of itself.
dataset1 = Dataset(media_type=Image, env=environment)
datasets = [dataset1, dataset1.clone()]
merged = HLOps.merge(*datasets)

assert 'my_plugin' in dataset1.env.exporters  # Passes: plugin is on the source environment
assert 'my_plugin' in merged.env.exporters  # Fails: plugin is missing from the merged environment

```
---------

Co-authored-by: Sooah Lee <[email protected]>
  • Loading branch information
williamcorsel and sooahleex authored Aug 20, 2024
1 parent 5d669f7 commit f93ae27
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/datumaro/components/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def merge(cls, envs: Sequence["Environment"]) -> "Environment":
merged = Environment()

def _register(registry: PluginRegistry):
merged.register_plugins(plugin for plugin in registry)
merged.register_plugins(list(registry._items.values()))

for env in envs:
_register(env.extractors)
Expand Down
15 changes: 8 additions & 7 deletions src/datumaro/components/hl_ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,13 +282,14 @@ def merge(
merger = get_merger(merge_policy, **kwargs)
merged = merger(*datasets)
env = Environment.merge(
dataset.env
for dataset in datasets
if hasattr(
dataset, "env"
) # TODO: Sometimes, there is dataset which is not exactly "Dataset",
# e.g., VocClassificationBase. this should be fixed and every object from
# Dataset.import_from should have "Dataset" type.
[
dataset.env
for dataset in datasets
if hasattr(dataset, "env")
# TODO: Sometimes, there is dataset which is not exactly "Dataset",
# e.g., VocClassificationBase. this should be fixed and every object from
# Dataset.import_from should have "Dataset" type.
]
)
if report_path:
merger.save_merge_report(report_path)
Expand Down
17 changes: 16 additions & 1 deletion tests/unit/test_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import pytest

import datumaro.components.lazy_plugin
from datumaro.components.environment import Environment, PluginRegistry
from datumaro.components.environment import DEFAULT_ENVIRONMENT, Environment, PluginRegistry
from datumaro.components.exporter import Exporter

real_find_spec = datumaro.components.lazy_plugin.find_spec

Expand Down Expand Up @@ -77,3 +78,17 @@ def test_extra_deps_req(self, fxt_tf_failure_env):
)

assert "tf_detection_api" not in loaded_plugin_names

def test_merge_default_env(self):
merged_env = Environment.merge([DEFAULT_ENVIRONMENT, DEFAULT_ENVIRONMENT])
assert merged_env is DEFAULT_ENVIRONMENT

def test_merge_custom_env(self):
class TestPlugin(Exporter):
pass

envs = [Environment(), Environment()]
envs[0].exporters.register("test_plugin", TestPlugin)

merged = Environment.merge(envs)
assert "test_plugin" in merged.exporters

0 comments on commit f93ae27

Please sign in to comment.