Support language dataset for DmTorchDataset #1592

Merged
4 changes: 2 additions & 2 deletions .github/workflows/codeql.yml
@@ -52,7 +52,7 @@ jobs:
 
       # Initializes the CodeQL tools for scanning.
       - name: Initialize CodeQL
-        uses: github/codeql-action/init@429e1977040da7a23b6822b13c129cd1ba93dbb2 # v3.26.2
+        uses: github/codeql-action/init@f0f3afee809481da311ca3a6ff1ff51d81dbeb24 # v3.26.4
         with:
           languages: ${{ matrix.language }}
           # If you wish to specify custom queries, you can do so here or in a config file.
@@ -73,7 +73,7 @@ jobs:
           python -m build
 
       - name: Perform CodeQL Analysis
-        uses: github/codeql-action/analyze@429e1977040da7a23b6822b13c129cd1ba93dbb2 # v3.26.2
+        uses: github/codeql-action/analyze@f0f3afee809481da311ca3a6ff1ff51d81dbeb24 # v3.26.4
         with:
           category: "/language:${{matrix.language}}"
       - name: Generate Security Report
2 changes: 1 addition & 1 deletion .github/workflows/scorecard.yml
@@ -67,6 +67,6 @@ jobs:
 
       # Upload the results to GitHub's code scanning dashboard.
       - name: "Upload to code-scanning"
-        uses: github/codeql-action/upload-sarif@429e1977040da7a23b6822b13c129cd1ba93dbb2 # v3.26.2
+        uses: github/codeql-action/upload-sarif@f0f3afee809481da311ca3a6ff1ff51d81dbeb24 # v3.26.4
         with:
           sarif_file: results.sarif
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 ### New features
 - Add a new CLI command: datum format
   (<https://github.com/openvinotoolkit/datumaro/pull/1570>)
+- Support language dataset for DmTorchDataset
+  (<https://github.com/openvinotoolkit/datumaro/pull/1592>)
 
 ### Enhancements
 - Change _Shape to Shape and add comments for subclasses of Shape
3 changes: 3 additions & 0 deletions requirements-core.txt
@@ -64,3 +64,6 @@ json-stream
 
 # TabularValidator
 nltk
+
+# torch converter for language
+portalocker
2 changes: 1 addition & 1 deletion setup.py
@@ -85,7 +85,7 @@ def parse_requirements(filename=CORE_REQUIREMENTS_FILE):
     extras_require={
         "tf": ["tensorflow"],
         "tfds": ["tensorflow-datasets<4.9.3"],
-        "torch": ["torch", "torchvision"],
+        "torch": ["torch", "torchvision", "torchtext==0.16.0"],
        "default": DEFAULT_REQUIREMENTS,
     },
     ext_modules=ext_modules,
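The torch extra now pins torchtext, which the new tabular code path imports at runtime. A hedged sanity check after installing the extra (assuming the pin above is what got installed; local builds may carry a suffix such as +cpu):

import torchtext

# The "torch" extra above pins torchtext to 0.16.0.
assert torchtext.__version__.startswith("0.16.0")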
2 changes: 1 addition & 1 deletion src/datumaro/components/environment.py
@@ -275,7 +275,7 @@ def merge(cls, envs: Sequence["Environment"]) -> "Environment":
         merged = Environment()
 
         def _register(registry: PluginRegistry):
-            merged.register_plugins(plugin for plugin in registry)
+            merged.register_plugins(list(registry._items.values()))
 
         for env in envs:
             _register(env.extractors)
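Materializing the plugins into a list instead of handing register_plugins a generator avoids a classic pitfall: a generator is exhausted after one pass, so any consumer that iterates its argument more than once sees nothing the second time. A minimal standalone sketch of the failure mode (toy names, not Datumaro's actual classes):

def consume_twice(plugins):
    first = [p for p in plugins]   # a list survives this pass...
    second = [p for p in plugins]  # ...a generator does not
    return first, second

gen = (name for name in ["extractor_a", "exporter_b"])
print(consume_twice(gen))                            # (['extractor_a', 'exporter_b'], [])
print(consume_twice(["extractor_a", "exporter_b"]))  # both passes see the plugins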
15 changes: 8 additions & 7 deletions src/datumaro/components/hl_ops/__init__.py
@@ -282,13 +282,14 @@ def merge(
     merger = get_merger(merge_policy, **kwargs)
     merged = merger(*datasets)
     env = Environment.merge(
-        dataset.env
-        for dataset in datasets
-        if hasattr(
-            dataset, "env"
-        )  # TODO: Sometimes, there is dataset which is not exactly "Dataset",
-        # e.g., VocClassificationBase. this should be fixed and every object from
-        # Dataset.import_from should have "Dataset" type.
+        [
+            dataset.env
+            for dataset in datasets
+            if hasattr(dataset, "env")
+            # TODO: Sometimes, there is dataset which is not exactly "Dataset",
+            # e.g., VocClassificationBase. this should be fixed and every object from
+            # Dataset.import_from should have "Dataset" type.
+        ]
     )
     if report_path:
         merger.save_merge_report(report_path)
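With the list comprehension, merge now passes Environment.merge an actual list, matching the Sequence["Environment"] signature shown in the previous file. For orientation, a hedged sketch of how this high-level merge is typically driven (the paths and format names are illustrative assumptions):

from datumaro import Dataset
from datumaro.components.hl_ops import HLOps

a = Dataset.import_from("path/to/first", "voc")
b = Dataset.import_from("path/to/second", "coco")

# merge_policy and report_path are the parameters handled in the body above.
merged = HLOps.merge(a, b, merge_policy="union", report_path="merge_report.json")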
51 changes: 49 additions & 2 deletions src/datumaro/plugins/framework_converter.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2023 Intel Corporation
+# Copyright (C) 2023-2024 Intel Corporation
 #
 # SPDX-License-Identifier: MIT
 
@@ -17,6 +17,7 @@
     "detection": AnnotationType.bbox,
     "instance_segmentation": AnnotationType.polygon,
     "semantic_segmentation": AnnotationType.mask,
+    "tabular": [AnnotationType.label, AnnotationType.caption],
 }
 
@@ -88,7 +89,10 @@ def _gen_item(self, idx: int):
                 if ann.type == TASK_ANN_TYPE[self.task]
             ]
             label = mask_tools.merge_masks((mask, label_id) for mask, label_id in masks)
-
+        elif self.task == "tabular":
+            label = [
+                ann.as_dict() for ann in item.annotations if ann.type in TASK_ANN_TYPE[self.task]
+            ]
         return image, label
 
@@ -103,15 +107,58 @@ def __init__(
         task: str,
         transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
+        target: Optional[dict] = None,
+        tokenizer: Optional[tuple[Callable, Callable]] = None,
+        vocab: Optional[tuple[Callable, Callable]] = None,
     ):
         super().__init__(dataset=dataset, subset=subset, task=task)
 
         self.transform = transform
         self.target_transform = target_transform
 
+        if self.task == "tabular":
+            if not isinstance(target, dict):
+                raise ValueError(
+                    "Target should be a dictionary with 'input' and 'output' keys."
+                )
+            self.input_target = target.get("input")
+            self.output_target = target.get("output")
+            if not self.input_target:
+                raise ValueError(
+                    "Please provide a target column for the tabular task to be used as input"
+                )
+
+            if not (tokenizer and vocab):
+                raise ValueError("Both tokenizer and vocab must be provided for the tabular task")
+            self.tokenizer = tokenizer
+            self.vocab = vocab
+
     def __getitem__(self, idx):
         image, label = self._gen_item(idx)
 
+        if self.task == "tabular":
+            text = image()[self.input_target]
+
+            if self.output_target:
+                # Sequence-to-sequence: tokenizer and vocab are (source, target) pairs.
+                src_tokenizer, tgt_tokenizer = self.tokenizer
+                src_vocab, tgt_vocab = self.vocab
+                src_tokens = src_tokenizer(text)
+                src_token_ids = src_vocab(src_tokens)
+
+                label_text = label[0]["caption"].split(f"{self.output_target}:")[-1]
+                tgt_tokens = tgt_tokenizer(label_text)
+                tgt_token_ids = tgt_vocab(tgt_tokens)
+
+                return torch.tensor(src_token_ids, dtype=torch.long), torch.tensor(
+                    tgt_token_ids, dtype=torch.long
+                )
+            else:
+                # Classification: a single tokenizer/vocab and an integer class label.
+                tokens = self.tokenizer(text)
+                token_ids = self.vocab(tokens)
+                return torch.tensor(token_ids, dtype=torch.long), torch.tensor(
+                    label[0]["label"], dtype=torch.long
+                )
+
         if len(image.shape) == 2:
             image = np.expand_dims(image, axis=-1)
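Taken together, the new target, tokenizer, and vocab parameters make text datasets consumable through the standard PyTorch pipeline. A hedged usage sketch for the classification case; the dataset path, the "text" column name, and the vocab-building helper are illustrative assumptions rather than part of this PR, while get_tokenizer and build_vocab_from_iterator come from the torchtext version pinned in setup.py above:

from datumaro import Dataset
from datumaro.plugins.framework_converter import DmTorchDataset
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

dm_dataset = Dataset.import_from("path/to/language_dataset", "tabular")
tokenizer = get_tokenizer("basic_english")

def yield_tokens(dataset, column):
    # Mirrors how DmTorchDataset reads the input column: media.data()[column].
    for item in dataset:
        yield tokenizer(item.media.data()[column])

vocab = build_vocab_from_iterator(yield_tokens(dm_dataset, "text"), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

torch_dataset = DmTorchDataset(
    dataset=dm_dataset,
    subset="train",
    task="tabular",
    target={"input": "text"},  # no "output" key, so the classification branch runs
    tokenizer=tokenizer,
    vocab=vocab,
)
token_ids, label = torch_dataset[0]

For the sequence-to-sequence case, target would also carry an "output" column, and tokenizer and vocab would each become a (source, target) pair, matching the tuple unpacking in __getitem__ above.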
17 changes: 16 additions & 1 deletion tests/unit/test_environment.py
@@ -5,7 +5,8 @@
 import pytest
 
 import datumaro.components.lazy_plugin
-from datumaro.components.environment import Environment, PluginRegistry
+from datumaro.components.environment import DEFAULT_ENVIRONMENT, Environment, PluginRegistry
+from datumaro.components.exporter import Exporter
 
 real_find_spec = datumaro.components.lazy_plugin.find_spec
 
@@ -77,3 +78,17 @@ def test_extra_deps_req(self, fxt_tf_failure_env):
         )
 
         assert "tf_detection_api" not in loaded_plugin_names
+
+    def test_merge_default_env(self):
+        merged_env = Environment.merge([DEFAULT_ENVIRONMENT, DEFAULT_ENVIRONMENT])
+        assert merged_env is DEFAULT_ENVIRONMENT
+
+    def test_merge_custom_env(self):
+        class TestPlugin(Exporter):
+            pass
+
+        envs = [Environment(), Environment()]
+        envs[0].exporters.register("test_plugin", TestPlugin)
+
+        merged = Environment.merge(envs)
+        assert "test_plugin" in merged.exporters