diff --git a/README.md b/README.md index dac8a39..e3b17d2 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,7 @@ wandb login ### Usage -`spacy.WandbLogger.v4` is a logger that sends the results of each training step +`spacy.WandbLogger.v5` is a logger that sends the results of each training step to the dashboard of the [Weights & Biases](https://www.wandb.com/) tool. To use this logger, Weights & Biases should be installed, and you should be logged in. The logger will send the full config file to W&B, as well as various system @@ -58,6 +58,11 @@ information such as memory utilization, network traffic, disk IO, GPU statistics, etc. This will also include information such as your hostname and operating system, as well as the location of your Python executable. +`spacy.WandbLogger.v4` and below automatically call the [default console logger](https://spacy.io/api/top-level#ConsoleLogger). +However, starting with `spacy.WandbLogger.v5`, console logging must be activated +through the use of the [ChainLogger](#chainlogger). This allows the user to configure +the console logger's parameters according to their preferences. + **Note** that by default, the full (interpolated) [training config](https://spacy.io/usage/training#config) is sent over to the W&B dashboard. If you prefer to **exclude certain information** such as path @@ -70,23 +75,24 @@ on your local system. ```ini [training.logger] -@loggers = "spacy.WandbLogger.v4" +@loggers = "spacy.WandbLogger.v5" project_name = "monitor_spacy_training" remove_config_values = ["paths.train", "paths.dev", "corpora.train.path", "corpora.dev.path"] log_dataset_dir = "corpus" model_log_interval = 1000 ``` -| Name | Type | Description | -| ---------------------- | --------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `project_name` | `str` | The name of the project in the Weights & Biases interface. The project will be created automatically if it doesn't exist yet. | -| `remove_config_values` | `List[str]` | A list of values to exclude from the config before it is uploaded to W&B (default: `[]`). | -| `model_log_interval` | `Optional[int]` | Steps to wait between logging model checkpoints to the W&B dasboard (default: `None`). Added in `spacy.WandbLogger.v2`. | -| `log_dataset_dir` | `Optional[str]` | Directory containing the dataset to be logged and versioned as a W&B artifact (default: `None`). Added in `spacy.WandbLogger.v2`. | -| `run_name` | `Optional[str]` | The name of the run. If you don't specify a run name, the name will be created by the `wandb` library (default: `None`). Added in `spacy.WandbLogger.v3`. | -| `entity` | `Optional[str]` | An entity is a username or team name where you're sending runs. If you don't specify an entity, the run will be sent to your default entity, which is usually your username (default: `None`). Added in `spacy.WandbLogger.v3`. | -| `log_best_dir` | `Optional[str]` | Directory containing the best trained model as saved by spaCy (by default in `training/model-best`), to be logged and versioned as a W&B artifact (default: `None`). Added in `spacy.WandbLogger.v4`. | -| `log_latest_dir` | `Optional[str]` | Directory containing the latest trained model as saved by spaCy (by default in `training/model-latest`), to be logged and versioned as a W&B artifact (default: `None`). Added in `spacy.WandbLogger.v4`. | +| Name | Type | Description | +| ---------------------- | --------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `project_name` | `str` | The name of the project in the Weights & Biases interface. The project will be created automatically if it doesn't exist yet. | +| `remove_config_values` | `List[str]` | A list of values to exclude from the config before it is uploaded to W&B (default: `[]`). | +| `model_log_interval` | `Optional[int]` | Steps to wait between logging model checkpoints to the W&B dasboard (default: `None`). Added in `spacy.WandbLogger.v2`. | +| `log_dataset_dir` | `Optional[str]` | Directory containing the dataset to be logged and versioned as a W&B artifact (default: `None`). Added in `spacy.WandbLogger.v2`. | +| `entity` | `Optional[str]` | An entity is a username or team name where you're sending runs. If you don't specify an entity, the run will be sent to your default entity, which is usually your username (default: `None`). Added in `spacy.WandbLogger.v3`. | +| `run_name` | `Optional[str]` | The name of the run. If you don't specify a run name, the name will be created by the `wandb` library (default: `None`). Added in `spacy.WandbLogger.v3`. | +| `log_best_dir` | `Optional[str]` | Directory containing the best trained model as saved by spaCy (by default in `training/model-best`), to be logged and versioned as a W&B artifact (default: `None`). Added in `spacy.WandbLogger.v4`. | +| `log_latest_dir` | `Optional[str]` | Directory containing the latest trained model as saved by spaCy (by default in `training/model-latest`), to be logged and versioned as a W&B artifact (default: `None`). Added in `spacy.WandbLogger.v4`. | +| `log_custom_stats` | `Optional[List[str]]` | A list of regular expressions that will be applied to the info dictionary passed to the logger (default: `None`). Statistics and metrics that match these regexps will be automatically logged. Added in `spacy.WandbLogger.v5`. | ## MLflowLogger diff --git a/setup.cfg b/setup.cfg index 050abf0..e62606f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -15,6 +15,7 @@ python_requires = >=3.6 [options.entry_points] spacy_loggers = + spacy.WandbLogger.v5 = spacy_loggers.wandb:wandb_logger_v5 spacy.WandbLogger.v4 = spacy_loggers.wandb:wandb_logger_v4 spacy.WandbLogger.v3 = spacy_loggers.wandb:wandb_logger_v3 spacy.WandbLogger.v2 = spacy_loggers.wandb:wandb_logger_v2 diff --git a/spacy_loggers/tests/test_registry.py b/spacy_loggers/tests/test_registry.py index e4e9290..218ca82 100644 --- a/spacy_loggers/tests/test_registry.py +++ b/spacy_loggers/tests/test_registry.py @@ -6,6 +6,7 @@ ("loggers", "spacy.WandbLogger.v2"), ("loggers", "spacy.WandbLogger.v3"), ("loggers", "spacy.WandbLogger.v4"), + ("loggers", "spacy.WandbLogger.v5"), ("loggers", "spacy.MLflowLogger.v1"), ("loggers", "spacy.ClearMLLogger.v1"), ("loggers", "spacy.ChainLogger.v1"), diff --git a/spacy_loggers/wandb.py b/spacy_loggers/wandb.py index c829db7..28bd112 100644 --- a/spacy_loggers/wandb.py +++ b/spacy_loggers/wandb.py @@ -2,119 +2,181 @@ A logger that logs training activity to Weights and Biases. """ -from typing import Dict, Any, Tuple, Callable, List, Optional, IO +from typing import Dict, Any, Tuple, Callable, List, IO, Optional +from types import ModuleType import sys -from spacy import util from spacy import Language -from spacy.training.loggers import console_logger +from spacy.util import SimpleFrozenList -# entry point: spacy.WandbLogger.v4 -def wandb_logger_v4( +from .util import dict_to_dot, dot_to_dict, matcher_for_regex_patterns +from .util import setup_default_console_logger, LoggerT + + +# entry point: spacy.WandbLogger.v5 +def wandb_logger_v5( project_name: str, - remove_config_values: List[str] = [], + remove_config_values: List[str] = SimpleFrozenList(), model_log_interval: Optional[int] = None, log_dataset_dir: Optional[str] = None, entity: Optional[str] = None, run_name: Optional[str] = None, log_best_dir: Optional[str] = None, log_latest_dir: Optional[str] = None, -): - try: - import wandb + log_custom_stats: Optional[List[str]] = None, +) -> LoggerT: + """Creates a logger that interoperates with the Weights & Biases framework. + + Args: + project_name (str): + The name of the project in the Weights & Biases interface. The project will be created automatically if it doesn't exist yet. + remove_config_values (List[str]): + A list of values to exclude from the config before it is uploaded to W&B. Defaults to []. + model_log_interval (Optional[int]): + Steps to wait between logging model checkpoints to the W&B dasboard. Defaults to None. + log_dataset_dir (Optional[str]): + Directory containing the dataset to be logged and versioned as a W&B artifact. Defaults to None. + entity (Optional[str]): + An entity is a username or team name where you're sending runs. If you don't specify an entity, the run will be sent to your default entity, which is usually your username. Defaults to None. + run_name (Optional[str]): + The name of the run. If you don't specify a run name, the name will be created by the `wandb` library. Defaults to None. + log_best_dir (Optional[str]): + Directory containing the best trained model as saved by spaCy, to be logged and versioned as a W&B artifact. Defaults to None. + log_latest_dir (Optional[str]): + Directory containing the latest trained model as saved by spaCy, to be logged and versioned as a W&B artifact. Defaults to None. + log_custom_stats (Optional[List[str]]): + A list of regular expressions that will be applied to the info dictionary passed to the logger. Statistics and metrics that match these regexps will be automatically logged. Defaults to None. + + Returns: + LoggerT: Logger instance. + """ + wandb = _import_wandb() - # test that these are available - from wandb import init, log, join # noqa: F401 - except ImportError: - raise ImportError( - "The 'wandb' library could not be found - did you install it? " - "Alternatively, specify the 'ConsoleLogger' in the " - "'training.logger' config section, instead of the 'WandbLogger'." + def setup_logger( + nlp: "Language", stdout: IO = sys.stdout, stderr: IO = sys.stderr + ) -> Tuple[Callable[[Dict[str, Any]], None], Callable[[], None]]: + match_stat = matcher_for_regex_patterns(log_custom_stats) + run = _setup_wandb( + wandb, + nlp, + project_name, + remove_config_values=remove_config_values, + entity=entity, ) + if run_name: + wandb.run.name = run_name + + if log_dataset_dir: + _log_dir_artifact( + wandb, path=log_dataset_dir, name="dataset", type="dataset" + ) + + def log_step(info: Optional[Dict[str, Any]]): + _log_scores(wandb, info) + _log_model_artifact(wandb, info, run, model_log_interval) + _log_custom_stats(wandb, info, match_stat) + + def finalize() -> None: + if log_best_dir: + _log_dir_artifact( + wandb, + path=log_best_dir, + name="model_best", + type="model", + ) + + if log_latest_dir: + _log_dir_artifact( + wandb, + path=log_latest_dir, + name="model_last", + type="model", + ) + wandb.join() + + return log_step, finalize - console = console_logger(progress_bar=False) + return setup_logger + + +# entry point: spacy.WandbLogger.v4 +def wandb_logger_v4( + project_name: str, + remove_config_values: List[str] = SimpleFrozenList(), + model_log_interval: Optional[int] = None, + log_dataset_dir: Optional[str] = None, + entity: Optional[str] = None, + run_name: Optional[str] = None, + log_best_dir: Optional[str] = None, + log_latest_dir: Optional[str] = None, +) -> LoggerT: + """Creates a logger that interoperates with the Weights & Biases framework. + + Args: + project_name (str): + The name of the project in the Weights & Biases interface. The project will be created automatically if it doesn't exist yet. + remove_config_values (List[str]): + A list of values to exclude from the config before it is uploaded to W&B. Defaults to []. + model_log_interval (Optional[int]): + Steps to wait between logging model checkpoints to the W&B dasboard. Defaults to None. + log_dataset_dir (Optional[str]): + Directory containing the dataset to be logged and versioned as a W&B artifact. Defaults to None. + entity (Optional[str]): + An entity is a username or team name where you're sending runs. If you don't specify an entity, the run will be sent to your default entity, which is usually your username. Defaults to None. + run_name (Optional[str]): + The name of the run. If you don't specify a run name, the name will be created by the `wandb` library. Defaults to None. + log_best_dir (Optional[str]): + Directory containing the best trained model as saved by spaCy, to be logged and versioned as a W&B artifact. Defaults to None. + log_latest_dir (Optional[str]): + Directory containing the latest trained model as saved by spaCy, to be logged and versioned as a W&B artifact. Defaults to None. + + Returns: + LoggerT: Logger instance. + """ + wandb = _import_wandb() def setup_logger( nlp: "Language", stdout: IO = sys.stdout, stderr: IO = sys.stderr ) -> Tuple[Callable[[Dict[str, Any]], None], Callable[[], None]]: - config = nlp.config.interpolate() - config_dot = util.dict_to_dot(config) - for field in remove_config_values: - del config_dot[field] - config = util.dot_to_dict(config_dot) - run = wandb.init( - project=project_name, config=config, entity=entity, reinit=True + console_log_step, console_finalize = setup_default_console_logger( + nlp, stdout, stderr + ) + run = _setup_wandb( + wandb, + nlp, + project_name, + remove_config_values=remove_config_values, + entity=entity, ) - if run_name: wandb.run.name = run_name - console_log_step, console_finalize = console(nlp, stdout, stderr) - - def log_dir_artifact( - path: str, - name: str, - type: str, - metadata: Optional[Dict[str, Any]] = {}, - aliases: Optional[List[str]] = [], - ): - dataset_artifact = wandb.Artifact( - name, type=type, metadata=metadata - ) - dataset_artifact.add_dir(path, name=name) - wandb.log_artifact(dataset_artifact, aliases=aliases) - if log_dataset_dir: - log_dir_artifact( - path=log_dataset_dir, name="dataset", type="dataset" + _log_dir_artifact( + wandb, path=log_dataset_dir, name="dataset", type="dataset" ) def log_step(info: Optional[Dict[str, Any]]): console_log_step(info) - if info is not None: - score = info["score"] - other_scores = info["other_scores"] - losses = info["losses"] - wandb.log({"score": score}) - if losses: - wandb.log({f"loss_{k}": v for k, v in losses.items()}) - if isinstance(other_scores, dict): - wandb.log(other_scores) - if model_log_interval and info.get("output_path"): - if ( - info["step"] % model_log_interval == 0 - and info["step"] != 0 - ): - log_dir_artifact( - path=info["output_path"], - name="pipeline_" + run.id, - type="checkpoint", - metadata=info, - aliases=[ - f"epoch {info['epoch']} step {info['step']}", - "latest", - "best" - if info["score"] == max(info["checkpoints"])[0] - else "", - ], - ) + _log_scores(wandb, info) + _log_model_artifact(wandb, info, run, model_log_interval) def finalize() -> None: - if log_best_dir: - log_dir_artifact( + _log_dir_artifact( + wandb, path=log_best_dir, name="model_best", type="model", ) if log_latest_dir: - log_dir_artifact( + _log_dir_artifact( + wandb, path=log_latest_dir, name="model_last", type="model", ) - console_finalize() wandb.join() @@ -126,90 +188,58 @@ def finalize() -> None: # entry point: spacy.WandbLogger.v3 def wandb_logger_v3( project_name: str, - remove_config_values: List[str] = [], + remove_config_values: List[str] = SimpleFrozenList(), model_log_interval: Optional[int] = None, log_dataset_dir: Optional[str] = None, entity: Optional[str] = None, run_name: Optional[str] = None, -): - try: - import wandb - - # test that these are available - from wandb import init, log, join # noqa: F401 - except ImportError: - raise ImportError( - "The 'wandb' library could not be found - did you install it? " - "Alternatively, specify the 'ConsoleLogger' in the " - "'training.logger' config section, instead of the 'WandbLogger'." - ) - - console = console_logger(progress_bar=False) +) -> LoggerT: + """Creates a logger that interoperates with the Weights & Biases framework. + + Args: + project_name (str): + The name of the project in the Weights & Biases interface. The project will be created automatically if it doesn't exist yet. + remove_config_values (List[str]): + A list of values to exclude from the config before it is uploaded to W&B. Defaults to []. + model_log_interval (Optional[int]): + Steps to wait between logging model checkpoints to the W&B dasboard. Defaults to None. + log_dataset_dir (Optional[str]): + Directory containing the dataset to be logged and versioned as a W&B artifact. Defaults to None. + entity (Optional[str]): + An entity is a username or team name where you're sending runs. If you don't specify an entity, the run will be sent to your default entity, which is usually your username. Defaults to None. + run_name (Optional[str]): + The name of the run. If you don't specify a run name, the name will be created by the `wandb` library. Defaults to None. + + Returns: + LoggerT: Logger instance. + """ + wandb = _import_wandb() def setup_logger( nlp: "Language", stdout: IO = sys.stdout, stderr: IO = sys.stderr ) -> Tuple[Callable[[Dict[str, Any]], None], Callable[[], None]]: - config = nlp.config.interpolate() - config_dot = util.dict_to_dot(config) - for field in remove_config_values: - del config_dot[field] - config = util.dot_to_dict(config_dot) - run = wandb.init( - project=project_name, config=config, entity=entity, reinit=True + console_log_step, console_finalize = setup_default_console_logger( + nlp, stdout, stderr + ) + run = _setup_wandb( + wandb, + nlp, + project_name, + remove_config_values=remove_config_values, + entity=entity, ) - if run_name: wandb.run.name = run_name - console_log_step, console_finalize = console(nlp, stdout, stderr) - - def log_dir_artifact( - path: str, - name: str, - type: str, - metadata: Optional[Dict[str, Any]] = {}, - aliases: Optional[List[str]] = [], - ): - dataset_artifact = wandb.Artifact( - name, type=type, metadata=metadata - ) - dataset_artifact.add_dir(path, name=name) - wandb.log_artifact(dataset_artifact, aliases=aliases) - if log_dataset_dir: - log_dir_artifact( - path=log_dataset_dir, name="dataset", type="dataset" + _log_dir_artifact( + wandb, path=log_dataset_dir, name="dataset", type="dataset" ) def log_step(info: Optional[Dict[str, Any]]): console_log_step(info) - if info is not None: - score = info["score"] - other_scores = info["other_scores"] - losses = info["losses"] - wandb.log({"score": score}) - if losses: - wandb.log({f"loss_{k}": v for k, v in losses.items()}) - if isinstance(other_scores, dict): - wandb.log(other_scores) - if model_log_interval and info.get("output_path"): - if ( - info["step"] % model_log_interval == 0 - and info["step"] != 0 - ): - log_dir_artifact( - path=info["output_path"], - name="pipeline_" + run.id, - type="checkpoint", - metadata=info, - aliases=[ - f"epoch {info['epoch']} step {info['step']}", - "latest", - "best" - if info["score"] == max(info["checkpoints"])[0] - else "", - ], - ) + _log_scores(wandb, info) + _log_model_artifact(wandb, info, run, model_log_interval) def finalize() -> None: console_finalize() @@ -223,82 +253,46 @@ def finalize() -> None: # entry point: spacy.WandbLogger.v2 def wandb_logger_v2( project_name: str, - remove_config_values: List[str] = [], + remove_config_values: List[str] = SimpleFrozenList(), model_log_interval: Optional[int] = None, log_dataset_dir: Optional[str] = None, -): - try: - import wandb - - # test that these are available - from wandb import init, log, join # noqa: F401 - except ImportError: - raise ImportError( - "The 'wandb' library could not be found - did you install it? " - "Alternatively, specify the 'ConsoleLogger' in the " - "'training.logger' config section, instead of the 'WandbLogger'." - ) - - console = console_logger(progress_bar=False) +) -> LoggerT: + """Creates a logger that interoperates with the Weights & Biases framework. + + Args: + project_name (str): + The name of the project in the Weights & Biases interface. The project will be created automatically if it doesn't exist yet. + remove_config_values (List[str]): + A list of values to exclude from the config before it is uploaded to W&B. Defaults to []. + model_log_interval (Optional[int]): + Steps to wait between logging model checkpoints to the W&B dasboard. Defaults to None. + log_dataset_dir (Optional[str]): + Directory containing the dataset to be logged and versioned as a W&B artifact. Defaults to None. + + Returns: + LoggerT: Logger instance. + """ + wandb = _import_wandb() def setup_logger( nlp: "Language", stdout: IO = sys.stdout, stderr: IO = sys.stderr ) -> Tuple[Callable[[Dict[str, Any]], None], Callable[[], None]]: - config = nlp.config.interpolate() - config_dot = util.dict_to_dot(config) - for field in remove_config_values: - del config_dot[field] - config = util.dot_to_dict(config_dot) - run = wandb.init(project=project_name, config=config, reinit=True) - console_log_step, console_finalize = console(nlp, stdout, stderr) - - def log_dir_artifact( - path: str, - name: str, - type: str, - metadata: Optional[Dict[str, Any]] = {}, - aliases: Optional[List[str]] = [], - ): - dataset_artifact = wandb.Artifact( - name, type=type, metadata=metadata - ) - dataset_artifact.add_dir(path, name=name) - wandb.log_artifact(dataset_artifact, aliases=aliases) + console_log_step, console_finalize = setup_default_console_logger( + nlp, stdout, stderr + ) + run = _setup_wandb( + wandb, nlp, project_name, remove_config_values=remove_config_values + ) if log_dataset_dir: - log_dir_artifact( - path=log_dataset_dir, name="dataset", type="dataset" + _log_dir_artifact( + wandb, path=log_dataset_dir, name="dataset", type="dataset" ) def log_step(info: Optional[Dict[str, Any]]): console_log_step(info) - if info is not None: - score = info["score"] - other_scores = info["other_scores"] - losses = info["losses"] - wandb.log({"score": score}) - if losses: - wandb.log({f"loss_{k}": v for k, v in losses.items()}) - if isinstance(other_scores, dict): - wandb.log(other_scores) - if model_log_interval and info.get("output_path"): - if ( - info["step"] % model_log_interval == 0 - and info["step"] != 0 - ): - log_dir_artifact( - path=info["output_path"], - name="pipeline_" + run.id, - type="checkpoint", - metadata=info, - aliases=[ - f"epoch {info['epoch']} step {info['step']}", - "latest", - "best" - if info["score"] == max(info["checkpoints"])[0] - else "", - ], - ) + _log_scores(wandb, info) + _log_model_artifact(wandb, info, run, model_log_interval) def finalize() -> None: console_finalize() @@ -310,12 +304,53 @@ def finalize() -> None: # entry point: spacy.WandbLogger.v1 -def wandb_logger_v1(project_name: str, remove_config_values: List[str] = []): +def wandb_logger_v1( + project_name: str, remove_config_values: List[str] = SimpleFrozenList() +) -> LoggerT: + """Creates a logger that interoperates with the Weights & Biases framework. + + Args: + project_name (str): + The name of the project in the Weights & Biases interface. The project will be created automatically if it doesn't exist yet. + remove_config_values (List[str]): + A list of values to exclude from the config before it is uploaded to W&B. Defaults to []. + + Returns: + LoggerT: Logger instance. + """ + wandb = _import_wandb() + + def setup_logger( + nlp: "Language", stdout: IO = sys.stdout, stderr: IO = sys.stderr + ) -> Tuple[Callable[[Dict[str, Any]], None], Callable[[], None]]: + console_log_step, console_finalize = setup_default_console_logger( + nlp, stdout, stderr + ) + _setup_wandb( + wandb, nlp, project_name, remove_config_values=remove_config_values + ) + + def log_step(info: Optional[Dict[str, Any]]): + console_log_step(info) + _log_scores(wandb, info) + + def finalize() -> None: + console_finalize() + wandb.join() + + return log_step, finalize + + return setup_logger + + +def _import_wandb() -> ModuleType: try: import wandb # test that these are available from wandb import init, log, join # noqa: F401 + + return wandb except ImportError: raise ImportError( "The 'wandb' library could not be found - did you install it? " @@ -323,35 +358,75 @@ def wandb_logger_v1(project_name: str, remove_config_values: List[str] = []): "'training.logger' config section, instead of the 'WandbLogger'." ) - console = console_logger(progress_bar=False) - def setup_logger( - nlp: "Language", stdout: IO = sys.stdout, stderr: IO = sys.stderr - ) -> Tuple[Callable[[Dict[str, Any]], None], Callable[[], None]]: - config = nlp.config.interpolate() - config_dot = util.dict_to_dot(config) - for field in remove_config_values: - del config_dot[field] - config = util.dot_to_dict(config_dot) - wandb.init(project=project_name, config=config, reinit=True) - console_log_step, console_finalize = console(nlp, stdout, stderr) +def _setup_wandb( + wandb: ModuleType, + nlp: "Language", + project: str, + entity: Optional[str] = None, + remove_config_values: List[str] = SimpleFrozenList(), +) -> Any: + config = nlp.config.interpolate() + config_dot = dict_to_dot(config) + for field in remove_config_values: + del config_dot[field] + config = dot_to_dict(config_dot) + run = wandb.init(project=project, config=config, entity=entity, reinit=True) + return run + + +def _log_scores(wandb: ModuleType, info: Optional[Dict[str, Any]]): + if info is not None: + score = info["score"] + other_scores = info["other_scores"] + losses = info["losses"] + wandb.log({"score": score}) + if losses: + wandb.log({f"loss_{k}": v for k, v in losses.items()}) + if isinstance(other_scores, dict): + wandb.log(other_scores) + + +def _log_model_artifact( + wandb: ModuleType, + info: Optional[Dict[str, Any]], + run: Any, + model_log_interval: Optional[int] = None, +): + if info is not None: + if model_log_interval and info.get("output_path"): + if info["step"] % model_log_interval == 0 and info["step"] != 0: + _log_dir_artifact( + wandb, + path=info["output_path"], + name="pipeline_" + run.id, + type="checkpoint", + metadata=info, + aliases=[ + f"epoch {info['epoch']} step {info['step']}", + "latest", + "best" if info["score"] == max(info["checkpoints"])[0] else "", + ], + ) - def log_step(info: Optional[Dict[str, Any]]): - console_log_step(info) - if info is not None: - score = info["score"] - other_scores = info["other_scores"] - losses = info["losses"] - wandb.log({"score": score}) - if losses: - wandb.log({f"loss_{k}": v for k, v in losses.items()}) - if isinstance(other_scores, dict): - wandb.log(other_scores) - def finalize() -> None: - console_finalize() - wandb.join() +def _log_dir_artifact( + wandb: ModuleType, + path: str, + name: str, + type: str, + metadata: Optional[Dict[str, Any]] = None, + aliases: Optional[List[str]] = None, +): + dataset_artifact = wandb.Artifact(name, type=type, metadata=metadata) + dataset_artifact.add_dir(path, name=name) + wandb.log_artifact(dataset_artifact, aliases=aliases) - return log_step, finalize - return setup_logger +def _log_custom_stats( + wandb: ModuleType, info: Optional[Dict[str, Any]], matcher: Callable[[str], bool] +): + if info is not None: + for k, v in info.items(): + if matcher(k): + wandb.log({k: v})