[WIP] Add datasets to config api (pipelines) #585

Open · wants to merge 16 commits into base: main

Changes from all commits
12 changes: 12 additions & 0 deletions oml/configs/datasets/oml_image_dataset.yaml
@@ -0,0 +1,12 @@
name: oml_image_dataset
args:
  dataframe_name: df.csv
  cache_size: 100
  transforms_train:
    name: norm_resize_torch
    args:
      im_size: 224
  transforms_val:
    name: norm_resize_torch
    args:
      im_size: 224
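For context, a minimal sketch of how such a config node is consumed by the pipelines in this PR: the name is resolved through the registry and args are forwarded as keyword arguments. The import is visible in the train.py diff below; whether a dataset_root arg is also required (as in the re-ranking config) is left open here.

```python
# Sketch: consuming the config above via the registry entrypoint added in this PR.
from oml.registry.datasets import get_image_datasets  # import visible in train.py below

datasets_cfg = {
    "name": "oml_image_dataset",
    "args": {
        "dataframe_name": "df.csv",
        "cache_size": 100,
        "transforms_train": {"name": "norm_resize_torch", "args": {"im_size": 224}},
        "transforms_val": {"name": "norm_resize_torch", "args": {"im_size": 224}},
    },
}

# Returns a (train, val) pair, asserted in train.py to be
# ILabeledDataset / IQueryGalleryLabeledDataset respectively.
train_dataset, valid_dataset = get_image_datasets(datasets_cfg["name"], **datasets_cfg["args"])
```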
29 changes: 29 additions & 0 deletions oml/configs/datasets/oml_reranking_dataset.yaml
@@ -0,0 +1,29 @@
name: "oml_reranking_dataset"
Contributor: "oml_reranking_dataset" assumes a way the dataset is going to be used, which is not a dataset's business. Let's name it oml_image_dataset_with_embeddings.

args:
  # dataset_root: /path/to/dataset/ # use your dataset root here
Contributor: let's do as in the test_pipelines (check the entrypoint and config).

  dataframe_name: df.csv
  # embeddings_cache_dir: ${args.dataset_root} # CACHE EMBEDDINGS PRODUCED BY BASELINE FEATURE EXTRACTOR
Contributor: let's use None. And why is the comment in UPPER CASE?

  cache_size: 0
  batch_size_inference: 128
  num_workers: 2
  precision: 16
  accelerator: cpu

  feature_extractor:
    name: vit
    args:
      normalise_features: True
      use_multi_scale: False
      weights: null # or /path/to/extractor.ckpt # <---- specify path to baseline feature extractor
      arch: vits16

  transforms_extraction:
    name: norm_resize_hypvit_torch
    args:
      im_size: 224
      crop_size: 224

  transforms_train:
    name: augs_hypvit_torch
    args:
      im_size: 224
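The batching keys here (batch_size_inference, num_workers, precision) presumably configure the embedding-extraction step for the baseline extractor. A sketch of that mapping, mirroring the code this PR removes from train_postprocessor.py further down (cfg stands for this file's args node; how the new builder actually wires this is not shown in the diff):

```python
from typing import Any, Dict

def inference_args_from_cfg(cfg: Dict[str, Any], extractor: Any) -> Dict[str, Any]:
    # Mirrors the mapping in the code removed from train_postprocessor.py below.
    return {
        "model": extractor,                               # built from the feature_extractor node
        "num_workers": cfg["num_workers"],                # 2
        "batch_size": cfg["batch_size_inference"],        # 128
        "use_fp16": int(cfg.get("precision", 32)) == 16,  # precision: 16 -> fp16 inference
    }
```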
40 changes: 26 additions & 14 deletions oml/lightning/pipelines/logging.py
@@ -116,8 +116,11 @@ def log_pipeline_info(self, cfg: TCfg) -> None:
)
self.run["code"].upload_files(source_files)

# log dataframe
self.run["dataset"].upload(str(Path(cfg["dataset_root"]) / cfg["dataframe_name"]))
if cfg["datasets"]["args"].get("dataframe_name", None) is not None:
# log dataframe
self.run["dataset"].upload(
str(Path(cfg["datasets"]["args"]["dataset_root"]) / cfg["datasets"]["args"]["dataframe_name"])
Contributor: you only checked that you have dataframe_name; what if you don't have dataset_root? I think text datasets don't have a root, but they do have a dataframe.

Contributor: the same for the changes above in this file.

Contributor: let's also check that dataset_root is present.

Contributor: and let's not repeat ourselves: introduce a function, def dataframe_from_cfg_if_presented(), and reuse it in all places (a sketch of such a helper follows this hunk).
)

def log_figure(self, fig: plt.Figure, title: str, idx: int) -> None:
from neptune.types import File # this is the optional dependency
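A minimal sketch of the helper the reviewer asks for above, assuming the new cfg["datasets"]["args"] layout from this PR. The function name comes from the review; returning None when either key is missing is an assumption that also covers the text-dataset case mentioned above.

```python
from pathlib import Path
from typing import Any, Dict, Optional

def dataframe_from_cfg_if_presented(cfg: Dict[str, Any]) -> Optional[Path]:
    """Return the dataframe path only when both dataset_root and dataframe_name are set."""
    args = cfg.get("datasets", {}).get("args", {})
    dataframe_name = args.get("dataframe_name")
    dataset_root = args.get("dataset_root")
    if dataframe_name is None or dataset_root is None:
        return None
    return Path(dataset_root) / dataframe_name
```

Each logger backend in this file would then reduce to: path = dataframe_from_cfg_if_presented(cfg), followed by the upload only when path is not None.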
@@ -149,9 +152,12 @@ def log_pipeline_info(self, cfg: TCfg) -> None:
self.experiment.log_artifact(code)

# log dataset
dataset = wandb.Artifact("dataset", type="dataset")
dataset.add_file(str(Path(cfg["dataset_root"]) / cfg["dataframe_name"]))
self.experiment.log_artifact(dataset)
if cfg["datasets"]["args"].get("dataframe_name", None) is not None:
dataset = wandb.Artifact("dataset", type="dataset")
dataset.add_file(
str(Path(cfg["datasets"]["args"]["dataset_root"]) / cfg["datasets"]["args"]["dataframe_name"])
)
self.experiment.log_artifact(dataset)

def log_figure(self, fig: plt.Figure, title: str, idx: int) -> None:
fig_img = figure_to_nparray(fig)
@@ -186,11 +192,14 @@ def log_pipeline_info(self, cfg: TCfg) -> None:
self.experiment.log_artifacts(run_id=self.run_id, local_dir=OML_PATH, artifact_path="code")

# log dataframe
self.experiment.log_artifact(
run_id=self.run_id,
local_path=str(Path(cfg["dataset_root"]) / cfg["dataframe_name"]),
artifact_path="dataset",
)
if cfg["datasets"]["args"].get("dataframe_name", None) is not None:
self.experiment.log_artifact(
run_id=self.run_id,
local_path=str(
Path(cfg["datasets"]["args"]["dataset_root"]) / cfg["datasets"]["args"]["dataframe_name"]
),
artifact_path="dataset",
)

def log_figure(self, fig: plt.Figure, title: str, idx: int) -> None:
self.experiment.log_figure(figure=fig, artifact_file=f"{title}.png", run_id=self.run_id)
@@ -214,10 +223,13 @@ def log_pipeline_info(self, cfg: TCfg) -> None:
self.task.upload_artifact(name="code", artifact_object=OML_PATH)

# log dataframe
self.task.upload_artifact(
name="dataset",
artifact_object=str(Path(cfg["dataset_root"]) / cfg["dataframe_name"]),
)
if cfg["datasets"]["args"].get("dataframe_name", None) is not None:
self.task.upload_artifact(
name="dataset",
artifact_object=str(
Path(cfg["datasets"]["args"]["dataset_root"]) / cfg["datasets"]["args"]["dataframe_name"]
),
)

def log_figure(self, fig: plt.Figure, title: str, idx: int) -> None:
self.logger.report_matplotlib_figure(
25 changes: 25 additions & 0 deletions oml/lightning/pipelines/parser.py
@@ -104,7 +104,32 @@ def parse_ckpt_callback_from_config(cfg: TCfg) -> ModelCheckpoint:
)


def convert_to_new_format_if_needed(
Contributor: discussed offline: we don't support backward compatibility for the re-ranking pipeline.

    cfg: TCfg,
) -> Dict[str, Any]:
    old_keys = {
        "dataset_root",
        "dataframe_name",
        "cache_size",
        "transforms_train",
        "transforms_val",
    }
    keys_to_process = old_keys.intersection(set(cfg.keys()))
    if keys_to_process:
        print("Used old cfg version (before oml 3.1.0). Converting to new format")
        args = dict()
        for key in keys_to_process:
            args[key] = cfg.pop(key)

        cfg["datasets"] = {
            "name": "oml_image_dataset",
            "args": args,
        }
    return cfg


__all__ = [
    "convert_to_new_format_if_needed",
    "parse_engine_params_from_config",
    "check_is_config_for_ddp",
    "parse_scheduler_from_config",
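To make the conversion concrete, an illustrative before/after (values are invented; only the keys listed in old_keys are relocated, everything else stays at the top level):

```python
from oml.lightning.pipelines.parser import convert_to_new_format_if_needed  # added in this PR

old_cfg = {
    "dataset_root": "/path/to/dataset/",
    "dataframe_name": "df.csv",
    "cache_size": 100,
    "transforms_train": {"name": "norm_resize_torch", "args": {"im_size": 224}},
    "optimizer": {"name": "adam", "args": {"lr": 1e-5}},  # non-dataset keys are untouched
}

new_cfg = convert_to_new_format_if_needed(old_cfg)
assert "dataset_root" not in new_cfg
assert new_cfg["datasets"] == {
    "name": "oml_image_dataset",
    "args": {
        "dataset_root": "/path/to/dataset/",
        "dataframe_name": "df.csv",
        "cache_size": 100,
        "transforms_train": {"name": "norm_resize_torch", "args": {"im_size": 224}},
    },
}
```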
6 changes: 3 additions & 3 deletions oml/lightning/pipelines/predict.py
@@ -30,12 +30,12 @@ def extractor_prediction_pipeline(cfg: TCfg) -> None:

pprint(cfg)

transforms = get_transforms_by_cfg(cfg["transforms_predict"])
filenames = [list(Path(cfg["data_dir"]).glob(f"**/*.{ext}")) for ext in IMAGE_EXTENSIONS]
transforms = get_transforms_by_cfg(cfg["datasets"]["args"]["transforms_predict"])
Contributor: something weird happened here; I think you should not touch this pipeline at the moment.
filenames = [list(Path(cfg["datasets"]["args"]["data_dir"]).glob(f"**/*.{ext}")) for ext in IMAGE_EXTENSIONS]
filenames = list(itertools.chain(*filenames))

if len(filenames) == 0:
raise RuntimeError(f"There are no images in the provided directory: {cfg['data_dir']}")
raise RuntimeError(f"There are no images in the provided directory: {cfg['datasets']['args']['data_dir']}")

f_imread = get_im_reader_for_transforms(transforms)

18 changes: 8 additions & 10 deletions oml/lightning/pipelines/train.py
@@ -6,35 +6,32 @@
from torch.utils.data import DataLoader

from oml.const import TCfg
from oml.datasets.images import get_retrieval_images_datasets
from oml.interfaces.datasets import ILabeledDataset, IQueryGalleryLabeledDataset
from oml.lightning.callbacks.metric import MetricValCallback
from oml.lightning.modules.extractor import ExtractorModule, ExtractorModuleDDP
from oml.lightning.pipelines.parser import (
check_is_config_for_ddp,
convert_to_new_format_if_needed,
parse_ckpt_callback_from_config,
parse_engine_params_from_config,
parse_logger_from_config,
parse_sampler_from_config,
parse_scheduler_from_config,
)
from oml.metrics.embeddings import EmbeddingMetrics
from oml.registry.datasets import get_image_datasets
from oml.registry.losses import get_criterion_by_cfg
from oml.registry.models import get_extractor_by_cfg
from oml.registry.optimizers import get_optimizer_by_cfg
from oml.registry.transforms import get_transforms_by_cfg
from oml.utils.misc import dictconfig_to_dict, set_global_seed


def get_retrieval_loaders(cfg: TCfg) -> Tuple[DataLoader, DataLoader, ILabeledDataset, IQueryGalleryLabeledDataset]:
train_dataset, valid_dataset = get_retrieval_images_datasets(
dataset_root=Path(cfg["dataset_root"]),
transforms_train=get_transforms_by_cfg(cfg["transforms_train"]),
transforms_val=get_transforms_by_cfg(cfg["transforms_val"]),
dataframe_name=cfg["dataframe_name"],
cache_size=cfg.get("cache_size", 0),
verbose=cfg.get("show_dataset_warnings", True),
)

train_dataset, valid_dataset = get_image_datasets(cfg["datasets"]["name"], **cfg["datasets"]["args"])

assert isinstance(train_dataset, ILabeledDataset)
assert isinstance(valid_dataset, IQueryGalleryLabeledDataset)

sampler = parse_sampler_from_config(cfg, train_dataset)

@@ -70,6 +67,7 @@ def extractor_training_pipeline(cfg: TCfg) -> None:
set_global_seed(cfg["seed"])

cfg = dictconfig_to_dict(cfg)
cfg = convert_to_new_format_if_needed(cfg)
pprint(cfg)

logger = parse_logger_from_config(cfg)
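For reference, a hypothetical sketch of the registry entrypoint used in get_retrieval_loaders above. Only the call shape get_image_datasets(name, **args) is visible in this diff, so the registration mechanics below are assumptions:

```python
from typing import Any, Callable, Dict, Tuple

# Hypothetical registry: maps a config `name` onto a builder returning (train, val) datasets.
DATASET_BUILDERS: Dict[str, Callable[..., Tuple[Any, Any]]] = {}

def get_image_datasets(name: str, **args: Any) -> Tuple[Any, Any]:
    # `args` is cfg["datasets"]["args"] forwarded verbatim from the pipeline config.
    if name not in DATASET_BUILDERS:
        raise KeyError(f"Unknown dataset config name: {name!r}, known: {sorted(DATASET_BUILDERS)}")
    return DATASET_BUILDERS[name](**args)
```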
73 changes: 7 additions & 66 deletions oml/lightning/pipelines/train_postprocessor.py
@@ -1,7 +1,6 @@
import hashlib
from pathlib import Path
from pprint import pprint
from typing import Any, Dict, Tuple
from typing import Tuple

import pytorch_lightning as pl
import torch
@@ -10,9 +9,6 @@
from torch.utils.data import DataLoader

from oml.const import EMBEDDINGS_KEY, TCfg
from oml.datasets.base import ImageLabeledDataset, ImageQueryGalleryLabeledDataset
from oml.datasets.images import get_retrieval_images_datasets
from oml.inference import inference, inference_cached
from oml.interfaces.datasets import ILabeledDataset, IQueryGalleryLabeledDataset
from oml.interfaces.models import IPairwiseModel
from oml.lightning.callbacks.metric import MetricValCallback
@@ -22,6 +18,7 @@
)
from oml.lightning.pipelines.parser import (
check_is_config_for_ddp,
convert_to_new_format_if_needed,
parse_ckpt_callback_from_config,
parse_engine_params_from_config,
parse_logger_from_config,
@@ -30,77 +27,20 @@
)
from oml.metrics.embeddings import EmbeddingMetrics
from oml.miners.pairs import PairsMiner
from oml.registry.models import get_extractor_by_cfg
from oml.registry.datasets import get_image_datasets
from oml.registry.optimizers import get_optimizer_by_cfg
from oml.registry.postprocessors import get_postprocessor_by_cfg
from oml.registry.transforms import get_transforms_by_cfg
from oml.retrieval.postprocessors.pairwise import PairwiseReranker
from oml.utils.misc import dictconfig_to_dict, flatten_dict, set_global_seed


def get_hash_of_extraction_stage_cfg(cfg: TCfg) -> str:
def dict2str(dictionary: Dict[str, Any]) -> str:
flatten_items = flatten_dict(dictionary).items()
sorted(flatten_items, key=lambda x: x[0])
return str(flatten_items)

cfg_extraction_str = (
dict2str(cfg["extractor"])
+ dict2str(cfg["transforms_extraction"])
+ str(cfg["dataframe_name"])
+ str(cfg.get("precision", 32))
)

md5sum = hashlib.md5(cfg_extraction_str.encode("utf-8")).hexdigest()
return md5sum
from oml.utils.misc import dictconfig_to_dict, set_global_seed


def get_loaders_with_embeddings(
cfg: TCfg,
) -> Tuple[DataLoader, DataLoader, ILabeledDataset, IQueryGalleryLabeledDataset]:
device = tdevice("cuda:0") if parse_engine_params_from_config(cfg)["accelerator"] == "gpu" else tdevice("cpu")
extractor = get_extractor_by_cfg(cfg["extractor"]).to(device)

transforms_extraction = get_transforms_by_cfg(cfg["transforms_extraction"])

train_extraction, val_extraction = get_retrieval_images_datasets(
dataset_root=Path(cfg["dataset_root"]),
dataframe_name=cfg["dataframe_name"],
transforms_train=transforms_extraction,
transforms_val=transforms_extraction,
)

args = {
"model": extractor,
"num_workers": cfg["num_workers"],
"batch_size": cfg["batch_size_inference"],
"use_fp16": int(cfg.get("precision", 32)) == 16,
}

if cfg["embeddings_cache_dir"] is not None:
hash_ = get_hash_of_extraction_stage_cfg(cfg)[:5]
dir_ = Path(cfg["embeddings_cache_dir"])
emb_train = inference_cached(dataset=train_extraction, cache_path=str(dir_ / f"emb_train_{hash_}.pkl"), **args)
emb_val = inference_cached(dataset=val_extraction, cache_path=str(dir_ / f"emb_val_{hash_}.pkl"), **args)
else:
emb_train = inference(dataset=train_extraction, **args)
emb_val = inference(dataset=val_extraction, **args)

train_dataset = ImageLabeledDataset(
dataset_root=cfg["dataset_root"],
cache_size=cfg.get("cache_size", 0),
df=train_extraction.df,
transform=get_transforms_by_cfg(cfg["transforms_train"]),
extra_data={EMBEDDINGS_KEY: emb_train},
)

valid_dataset = ImageQueryGalleryLabeledDataset(
dataset_root=cfg["dataset_root"],
cache_size=cfg.get("cache_size", 0),
df=val_extraction.df,
transform=transforms_extraction,
extra_data={EMBEDDINGS_KEY: emb_val},
)
cfg["datasets"]["args"]["device"] = device
train_dataset, valid_dataset = get_image_datasets(cfg["datasets"]["name"], **cfg["datasets"]["args"])

sampler = parse_sampler_from_config(cfg, dataset=train_dataset)
assert sampler is not None, "We will be training on pairs, so, having sampler is obligatory."
@@ -126,6 +66,7 @@ def postprocessor_training_pipeline(cfg: DictConfig) -> None:
set_global_seed(cfg["seed"])

cfg = dictconfig_to_dict(cfg)
cfg = convert_to_new_format_if_needed(cfg)
pprint(cfg)

logger = parse_logger_from_config(cfg)
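The removed get_loaders_with_embeddings logic presumably moves into the new dataset builder (the reviewer's oml_image_dataset_with_embeddings). A sketch reconstructed from the deleted code above, with the caching (inference_cached plus config hash) and fp16 handling omitted for brevity; the builder name and the feature_extractor key are assumptions:

```python
from pathlib import Path
from typing import Any, Dict, Tuple

from oml.const import EMBEDDINGS_KEY
from oml.datasets.base import ImageLabeledDataset, ImageQueryGalleryLabeledDataset
from oml.datasets.images import get_retrieval_images_datasets
from oml.inference import inference
from oml.registry.models import get_extractor_by_cfg
from oml.registry.transforms import get_transforms_by_cfg


def build_image_datasets_with_embeddings(args: Dict[str, Any], device: Any) -> Tuple[Any, Any]:
    # Baseline extractor and extraction transforms, as in the deleted code.
    extractor = get_extractor_by_cfg(args["feature_extractor"]).to(device)
    transforms_extraction = get_transforms_by_cfg(args["transforms_extraction"])

    train_ext, val_ext = get_retrieval_images_datasets(
        dataset_root=Path(args["dataset_root"]),
        dataframe_name=args["dataframe_name"],
        transforms_train=transforms_extraction,
        transforms_val=transforms_extraction,
    )

    # Precompute embeddings with the baseline extractor (inference_cached would go here
    # when embeddings_cache_dir is set).
    emb_train = inference(model=extractor, dataset=train_ext,
                          batch_size=args["batch_size_inference"], num_workers=args["num_workers"])
    emb_val = inference(model=extractor, dataset=val_ext,
                        batch_size=args["batch_size_inference"], num_workers=args["num_workers"])

    # Rebuild the datasets with train/val transforms and attach the embeddings.
    train_dataset = ImageLabeledDataset(
        dataset_root=args["dataset_root"],
        df=train_ext.df,
        transform=get_transforms_by_cfg(args["transforms_train"]),
        extra_data={EMBEDDINGS_KEY: emb_train},
    )
    valid_dataset = ImageQueryGalleryLabeledDataset(
        dataset_root=args["dataset_root"],
        df=val_ext.df,
        transform=transforms_extraction,
        extra_data={EMBEDDINGS_KEY: emb_val},
    )
    return train_dataset, valid_dataset
```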
9 changes: 6 additions & 3 deletions oml/lightning/pipelines/validate.py
@@ -11,6 +11,7 @@
from oml.lightning.modules.extractor import ExtractorModule, ExtractorModuleDDP
from oml.lightning.pipelines.parser import (
check_is_config_for_ddp,
convert_to_new_format_if_needed,
parse_engine_params_from_config,
)
from oml.metrics.embeddings import EmbeddingMetrics
@@ -30,16 +31,18 @@ def extractor_validation_pipeline(cfg: TCfg) -> Tuple[pl.Trainer, Dict[str, Any]

"""
cfg = dictconfig_to_dict(cfg)
cfg = convert_to_new_format_if_needed(cfg)

trainer_engine_params = parse_engine_params_from_config(cfg)
is_ddp = check_is_config_for_ddp(trainer_engine_params)

pprint(cfg)

_, valid_dataset = get_retrieval_images_datasets(
dataset_root=Path(cfg["dataset_root"]),
dataset_root=Path(cfg["datasets"]["args"]["dataset_root"]),
transforms_train=None,
transforms_val=get_transforms_by_cfg(cfg["transforms_val"]),
dataframe_name=cfg["dataframe_name"],
transforms_val=get_transforms_by_cfg(cfg["datasets"]["args"]["transforms_val"]),
dataframe_name=cfg["datasets"]["args"]["dataframe_name"],
)
loader_val = DataLoader(dataset=valid_dataset, batch_size=cfg["bs_val"], num_workers=cfg["num_workers"])
