diff --git a/make_script.py b/make_script.py new file mode 100644 index 00000000..43353b81 --- /dev/null +++ b/make_script.py @@ -0,0 +1,34 @@ +import json + +with open("./tasks_config.json", "r") as f: + tasks_config = json.load(f) + +scriptstring = """ +from datetime import datetime +import shifthappens.benchmark +import shifthappens.utils +""" + +print(tasks_config["tasks"]) +for task in tasks_config["tasks"]: + scriptstring += tasks_config["import_lines"][task] + scriptstring += "\n" + +scriptstring += tasks_config["import_lines"][tasks_config["model"]] + +out_file_location = tasks_config["out_file_location"] +relative_data_folder = tasks_config["relative_data_folder"] + +scriptstring += f""" +tasks = shifthappens.benchmark.get_task_registrations() +model = {tasks_config['model']}() +results = shifthappens.benchmark.evaluate_model( + model, "{relative_data_folder}" +) +results_string = shifthappens.utils.serialize_model_results(results) +out_file_location = "{out_file_location}" +with open(out_file_location, 'w') as outfile: + outfile.write(results_string) +""" +with open("./run_tasks.py", "w") as run_script_file: + run_script_file.write(scriptstring) diff --git a/run.sh b/run.sh new file mode 100644 index 00000000..95e7c745 --- /dev/null +++ b/run.sh @@ -0,0 +1,2 @@ +python3 make_script.py +python3 run_tasks.py \ No newline at end of file diff --git a/shifthappens/task_data/task_metadata.py b/shifthappens/task_data/task_metadata.py index 9c690672..a6a2b974 100644 --- a/shifthappens/task_data/task_metadata.py +++ b/shifthappens/task_data/task_metadata.py @@ -1,6 +1,7 @@ """Class for storing a task's metadata.""" from dataclasses import dataclass +import json @dataclass(frozen=True, eq=True) @@ -22,5 +23,29 @@ class TaskMetadata: relative_data_folder: str standalone: bool = True + def serialize_task_metadata(self) -> str: + """ + Serialize TaskMetadata object into json string. + """ + metadata_dict = { + "name": self.name, + "relative_data_folder": self.relative_data_folder, + "standalone": self.standalone, + } + return json.dumps(metadata_dict) + + @staticmethod + def deserialize_task_metadata(metadata_str: str): + """ + Deserialize valid json string into TaskMetadata object. + """ + metadata_dict = json.loads(metadata_str) + metadata = TaskMetadata( + name=metadata_dict["name"], + relative_data_folder=metadata_dict["relative_data_folder"], + standalone=metadata_dict["standalone"], + ) + return metadata + _TASK_METADATA_FIELD = "__task_metadata__" diff --git a/shifthappens/tasks/task_result.py b/shifthappens/tasks/task_result.py index 4208826f..90e01d65 100644 --- a/shifthappens/tasks/task_result.py +++ b/shifthappens/tasks/task_result.py @@ -64,3 +64,44 @@ def __getattr__(self, item) -> float: return self[item] else: return super().__getattribute__(item) + + def serialize_summary_metrics(self) -> str: + """ + Serializes summary metrics of the objects into a string. + """ + return str({key.name: value for (key, value) in self.summary_metrics.items()}) + + def serialize_task_result(self) -> str: + """ + Serializes TaskResult object into a string. + """ + result_dict = { + "summary_metrics": self.serialize_summary_metrics(), + "metrics": str(self._metrics), + } + return str(result_dict) + + @staticmethod + def deserialize_summary_metrics( + summary_metrics_str: str, + ) -> Dict[Metric, Union[str, Tuple[str, ...]]]: + """ + Deserializes valid string into summary_metrics. + """ + summary_metrics = eval(summary_metrics_str) + result = {} + for key, value in summary_metrics.items(): + result[Metric.__members__.get(key)] = value + return result + + @staticmethod + def deserialize_task_result(task_result_str: str): + """ + Deserializes valid string into a TaskResult object. + """ + result_dict = eval(task_result_str) + metrics = eval(result_dict["metrics"]) + summary_metrics = TaskResult.deserialize_summary_metrics( + result_dict["summary_metrics"] + ) + return TaskResult(summary_metrics=summary_metrics, **metrics) diff --git a/shifthappens/utils.py b/shifthappens/utils.py index 57d696db..528c646f 100644 --- a/shifthappens/utils.py +++ b/shifthappens/utils.py @@ -1,12 +1,16 @@ """Utility functions that are needed for the entire package.""" import errno +import json import os import sys import time import urllib.error from itertools import product -from typing import Optional +from typing import Dict, Optional, Union + +from shifthappens.task_data import task_metadata +from shifthappens.tasks.task_result import TaskResult def dict_product(d): @@ -135,3 +139,33 @@ def download_and_extract_archive( archive = os.path.join(data_folder, filename) print(f"Extracting {archive} to {data_folder}") tv_utils.extract_archive(archive, data_folder, remove_finished) + + +def serialize_model_results( + results: Dict[task_metadata.TaskMetadata, Union[TaskResult, None]] +) -> str: + """ + Converts evaluation results of a model into json objects. + """ + return json.dumps( + { + key.serialize_task_metadata(): value.serialize_task_result() + for (key, value) in results.items() + if value is not None + } + ) + + +def deserialize_model_results( + results_str, +) -> Dict[task_metadata.TaskMetadata, TaskResult]: + """ + Converts json objects to a dictionary with (TaskMetadata, TaskResult) as (key, value) + """ + results_json_dict = json.loads(results_str) + results = {} + for key, value in results_json_dict.items(): + results[ + task_metadata.TaskMetadata.deserialize_task_metadata(key) + ] = TaskResult.deserialize_task_result(value) + return results diff --git a/tasks_config.json b/tasks_config.json new file mode 100644 index 00000000..200ebb72 --- /dev/null +++ b/tasks_config.json @@ -0,0 +1,40 @@ +{ + "out_file_location": "/mnt/qb/work/bethge/<>/<>", + "relative_data_folder": "/mnt/qb/work/bethge/<>/<>", + "tasks": [ + "imagenet_c", + "ccc", + "imagenet_3dcc", + "imagenet_cartoon", + "imagenet_d", + "imagenet_drawing", + "imagenet_m", + "imagenet_metashift", + "imagenet_patch", + "imagenet_r", + "objectnet", + "raccoons_ood", + "siscore", + "ssb", + "worst_case" + ], + "model": "ResNet18", + "import_lines": { + "imagenet_c": "from shifthappens.tasks.imagenet_c.imagenet_c import ImageNetCSeparateCorruptions", + "ccc": "from shifthappens.tasks.ccc.ccc import CCC", + "imagenet_3dcc": "from shifthappens.tasks.imagenet_3dcc.imagenet_3dcc import ImageNet3DCCSeparateCorruptions", + "imagenet_cartoon": "from shifthappens.tasks.imagenet_cartoon.imagenet_cartoon import ImageNetCartoon", + "imagenet_d": "from shifthappens.tasks.imagenet_d.imagenet_d import *", + "imagenet_drawing": "from shifthappens.tasks.imagenet_drawing.imagenet_drawing import ImageNetDrawing", + "imagenet_m": "from shifthappens.tasks.imagenet_m.imagenet_m import ImageNetM", + "imagenet_metashift": "from shifthappens.tasks.imagenet_metashift.imagenet_metashift import ImageNetMetaShift", + "imagenet_patch": "from shifthappens.tasks.imagenet_patch.imagenet_patch import ImageNetPatchCorruptions", + "imagenet_r": "from shifthappens.tasks.imagenet_r.imagenet_r import ImageNetR", + "objectnet": "from shifthappens.tasks.objectnet.objectnet import ObjectNet", + "raccoons_ood": "from shifthappens.tasks.raccoons_ood.raccoons_ood import RaccOOD", + "siscore": "from shifthappens.tasks.siscore.siscore import *", + "ssb": "from shifthappens.tasks.ssb.semantic_shift_benchmark import *", + "worst_case": "from shifthappens.tasks.worst_case.worst_case import WorstCase", + "ResNet18": "from shifthappens.models.torchvision import ResNet18" + } +} \ No newline at end of file