Skip to content

Commit

Permalink
🐞 Fix HPO (#462)
Browse files Browse the repository at this point in the history
* Fix issue #452

* Use system node in pre-commit
  • Loading branch information
ashwinvaidya17 authored Aug 1, 2022
1 parent 33355be commit 1a4d963
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 4 deletions.
3 changes: 3 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
default_language_version:
node: system

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
Expand Down
4 changes: 4 additions & 0 deletions tools/benchmarking/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""Benchmarking Tools."""

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
12 changes: 8 additions & 4 deletions tools/hpo/sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,17 @@
from omegaconf import DictConfig, ListConfig, OmegaConf
from pytorch_lightning import seed_everything
from pytorch_lightning.loggers import WandbLogger
from utils import flatten_hpo_params

import wandb
from anomalib.config import get_configurable_parameters, update_input_size_config
from anomalib.data import get_datamodule
from anomalib.models import get_model
from anomalib.utils.sweep import flatten_sweep_params, set_in_nested_config

from .utils import flatten_hpo_params
from anomalib.utils.sweep import (
flatten_sweep_params,
get_sweep_callbacks,
set_in_nested_config,
)


class WandbSweep:
Expand Down Expand Up @@ -70,11 +73,12 @@ def sweep(self):

model = get_model(config)
datamodule = get_datamodule(config)
callbacks = get_sweep_callbacks(config)

# Disable saving checkpoints as all checkpoints from the sweep will get uploaded
config.trainer.checkpoint_callback = False

trainer = pl.Trainer(**config.trainer, logger=wandb_logger)
trainer = pl.Trainer(**config.trainer, logger=wandb_logger, callbacks=callbacks)
trainer.fit(model, datamodule=datamodule)


Expand Down

0 comments on commit 1a4d963

Please sign in to comment.