[WIP] Add datasets to config api (pipelines) #585
base: main
Changes from all commits
8309c5a
eca39c6
118f36a
cebaf77
c77adff
a9f9bf4
3d94785
4f013eb
326a658
d14e406
c3f627c
71ec834
014bf17
5b3e9f1
b8d222d
1e0363e
@@ -0,0 +1,12 @@
name: oml_image_dataset
args:
  dataframe_name: df.csv
  cache_size: 100
  transforms_train:
    name: norm_resize_torch
    args:
      im_size: 224
  transforms_val:
    name: norm_resize_torch
    args:
      im_size: 224
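Note: the new config file above describes the training dataset as a name plus an args block. Below is a minimal sketch of how such an entry could be resolved through a registry, in the style OML already uses for transforms and extractors; `DATASETS_REGISTRY` and `get_dataset_by_cfg` are illustrative names, not the API introduced by this PR.

```python
# Illustrative only: a registry-style resolver for the new `datasets` config block.
# DATASETS_REGISTRY and get_dataset_by_cfg are assumed names, not this PR's actual API.
from typing import Any, Callable, Dict

DATASETS_REGISTRY: Dict[str, Callable[..., Any]] = {}  # maps config names to dataset constructors


def get_dataset_by_cfg(cfg: Dict[str, Any]) -> Any:
    # cfg is the `datasets` block: {"name": ..., "args": {...}}
    constructor = DATASETS_REGISTRY[cfg["name"]]
    return constructor(**cfg.get("args", {}))
```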
@@ -0,0 +1,29 @@
name: "oml_reranking_dataset"
args:
  # dataset_root: /path/to/dataset/  # use your dataset root here

Review comment: let's do as in the

  dataframe_name: df.csv
  # embeddings_cache_dir: ${args.dataset_root}  # CACHE EMBEDDINGS PRODUCED BY BASELINE FEATURE EXTRACTOR

Review comment: let's use None, and why is the comment in UPPER CASE?

  cache_size: 0
  batch_size_inference: 128
  num_workers: 2
  precision: 16
  accelerator: cpu

  feature_extractor:
    name: vit
    args:
      normalise_features: True
      use_multi_scale: False
      weights: null  # or /path/to/extractor.ckpt  # <---- specify path to baseline feature extractor
      arch: vits16

  transforms_extraction:
    name: norm_resize_hypvit_torch
    args:
      im_size: 224
      crop_size: 224

  transforms_train:
    name: augs_hypvit_torch
    args:
      im_size: 224
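Note: the commented-out `embeddings_cache_dir: ${args.dataset_root}` line relies on config interpolation. A minimal sketch of how that reference resolves, assuming the config is loaded with OmegaConf (as in Hydra-based pipelines); the paths are placeholders.

```python
# Sketch of the ${args.dataset_root} interpolation used in the config above.
# Assumes OmegaConf loads the config; paths are placeholders, not real data.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "args": {
            "dataset_root": "/path/to/dataset/",
            "embeddings_cache_dir": "${args.dataset_root}",
        }
    }
)
print(cfg.args.embeddings_cache_dir)  # -> /path/to/dataset/
```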
@@ -116,8 +116,11 @@ def log_pipeline_info(self, cfg: TCfg) -> None:
         )
         self.run["code"].upload_files(source_files)

-        # log dataframe
-        self.run["dataset"].upload(str(Path(cfg["dataset_root"]) / cfg["dataframe_name"]))
+        if cfg["datasets"]["args"].get("dataframe_name", None) is not None:
+            # log dataframe
+            self.run["dataset"].upload(
+                str(Path(cfg["datasets"]["args"]["dataset_root"]) / cfg["datasets"]["args"]["dataframe_name"])
+            )

Review comment: you only checked you have
Review comment: the same for the changes above in this file
Review comment: let's also check that dataset_root is present
Review comment: and let's not repeat ourselves and introduce a function:
and reuse it in all places

     def log_figure(self, fig: plt.Figure, title: str, idx: int) -> None:
         from neptune.types import File  # this is the optional dependency
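In the spirit of the review comments above (also check `dataset_root`, and stop repeating the path construction in every logger), a hypothetical helper might look like the sketch below; the name and exact checks are illustrative, not the snippet the reviewer proposed.

```python
# Hypothetical helper illustrating the reviewer's suggestion; not code from this PR.
from pathlib import Path
from typing import Any, Dict, Optional


def dataframe_path_from_cfg(cfg: Dict[str, Any]) -> Optional[Path]:
    # Return the dataframe path only if both dataset_root and dataframe_name are present.
    args = cfg["datasets"]["args"]
    dataset_root = args.get("dataset_root", None)
    dataframe_name = args.get("dataframe_name", None)
    if dataset_root is not None and dataframe_name is not None:
        return Path(dataset_root) / dataframe_name
    return None
```

Each logger could then guard on the return value of such a helper instead of rebuilding the path inline.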
@@ -149,9 +152,12 @@ def log_pipeline_info(self, cfg: TCfg) -> None:
         self.experiment.log_artifact(code)

         # log dataset
-        dataset = wandb.Artifact("dataset", type="dataset")
-        dataset.add_file(str(Path(cfg["dataset_root"]) / cfg["dataframe_name"]))
-        self.experiment.log_artifact(dataset)
+        if cfg["datasets"]["args"].get("dataframe_name", None) is not None:
+            dataset = wandb.Artifact("dataset", type="dataset")
+            dataset.add_file(
+                str(Path(cfg["datasets"]["args"]["dataset_root"]) / cfg["datasets"]["args"]["dataframe_name"])
+            )
+            self.experiment.log_artifact(dataset)

     def log_figure(self, fig: plt.Figure, title: str, idx: int) -> None:
         fig_img = figure_to_nparray(fig)
@@ -186,11 +192,14 @@ def log_pipeline_info(self, cfg: TCfg) -> None:
         self.experiment.log_artifacts(run_id=self.run_id, local_dir=OML_PATH, artifact_path="code")

         # log dataframe
-        self.experiment.log_artifact(
-            run_id=self.run_id,
-            local_path=str(Path(cfg["dataset_root"]) / cfg["dataframe_name"]),
-            artifact_path="dataset",
-        )
+        if cfg["datasets"]["args"].get("dataframe_name", None) is not None:
+            self.experiment.log_artifact(
+                run_id=self.run_id,
+                local_path=str(
+                    Path(cfg["datasets"]["args"]["dataset_root"]) / cfg["datasets"]["args"]["dataframe_name"]
+                ),
+                artifact_path="dataset",
+            )

     def log_figure(self, fig: plt.Figure, title: str, idx: int) -> None:
         self.experiment.log_figure(figure=fig, artifact_file=f"{title}.png", run_id=self.run_id)
@@ -214,10 +223,13 @@ def log_pipeline_info(self, cfg: TCfg) -> None:
         self.task.upload_artifact(name="code", artifact_object=OML_PATH)

         # log dataframe
-        self.task.upload_artifact(
-            name="dataset",
-            artifact_object=str(Path(cfg["dataset_root"]) / cfg["dataframe_name"]),
-        )
+        if cfg["datasets"]["args"].get("dataframe_name", None) is not None:
+            self.task.upload_artifact(
+                name="dataset",
+                artifact_object=str(
+                    Path(cfg["datasets"]["args"]["dataset_root"]) / cfg["datasets"]["args"]["dataframe_name"]
+                ),
+            )

     def log_figure(self, fig: plt.Figure, title: str, idx: int) -> None:
         self.logger.report_matplotlib_figure(
@@ -104,7 +104,32 @@ def parse_ckpt_callback_from_config(cfg: TCfg) -> ModelCheckpoint:
     )


+def convert_to_new_format_if_needed(

Review comment: discussed offline: we don't support back compatibility for re-ranking pipeline

+    cfg: TCfg,
+) -> Dict[str, Any]:
+    old_keys = {
+        "dataset_root",
+        "dataframe_name",
+        "cache_size",
+        "transforms_train",
+        "transforms_val",
+    }
+    keys_to_process = old_keys.intersection(set(cfg.keys()))
+    if keys_to_process:
+        print("Used old cfg version (before oml 3.1.0). Converting to new format")
+        args = dict()
+        for key in keys_to_process:
+            args[key] = cfg.pop(key)
+
+        cfg["datasets"] = {
+            "name": "oml_image_dataset",
+            "args": args,
+        }
+    return cfg
+
+
 __all__ = [
+    "convert_to_new_format_if_needed",
     "parse_engine_params_from_config",
     "check_is_config_for_ddp",
     "parse_scheduler_from_config",
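For reference, here is a rough illustration of the conversion this helper performs on an old-style config; the concrete values are made up, and the import path is an assumption.

```python
# Illustrative input/output for convert_to_new_format_if_needed; values are made up.
# from oml... import convert_to_new_format_if_needed  # exact module path is an assumption

old_style_cfg = {
    "dataset_root": "/data/my_dataset/",
    "dataframe_name": "df.csv",
    "cache_size": 100,
    "max_epochs": 10,  # keys outside old_keys stay at the top level
}

new_style_cfg = convert_to_new_format_if_needed(old_style_cfg)
# new_style_cfg now looks like:
# {
#     "max_epochs": 10,
#     "datasets": {
#         "name": "oml_image_dataset",
#         "args": {"dataset_root": "/data/my_dataset/", "dataframe_name": "df.csv", "cache_size": 100},
#     },
# }
```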
@@ -30,12 +30,12 @@ def extractor_prediction_pipeline(cfg: TCfg) -> None:

     pprint(cfg)

-    transforms = get_transforms_by_cfg(cfg["transforms_predict"])
-    filenames = [list(Path(cfg["data_dir"]).glob(f"**/*.{ext}")) for ext in IMAGE_EXTENSIONS]
+    transforms = get_transforms_by_cfg(cfg["datasets"]["args"]["transforms_predict"])

Review comment: something weird happened here, I think you should not touch this pipeline at the moment

+    filenames = [list(Path(cfg["datasets"]["args"]["data_dir"]).glob(f"**/*.{ext}")) for ext in IMAGE_EXTENSIONS]
     filenames = list(itertools.chain(*filenames))

     if len(filenames) == 0:
-        raise RuntimeError(f"There are no images in the provided directory: {cfg['data_dir']}")
+        raise RuntimeError(f"There are no images in the provided directory: {cfg['datasets']['args']['data_dir']}")

     f_imread = get_im_reader_for_transforms(transforms)
Review comment: "oml_reranking_dataset" assumes a way the dataset is going to be used, which is not a dataset's business. Let's name it oml_image_dataset_with_embeddings.