
Metrics support for sweeping #148

Closed
maximsch2 opened this issue Mar 29, 2021 · 15 comments · Fixed by #544
Labels: enhancement (New feature or request)

Comments

@maximsch2 (Contributor) commented Mar 29, 2021

🚀 Feature

We would like to have tighter integration of metrics and sweeping. This requires a few features:

  1. Knowing higher_is_better (i.e. whether we are trying to minimize or maximize the metric in a sweep)
  2. Knowing what value to optimize for. E.g. if a recall@precision metric returns both recall value and corresponding threshold, we want to optimize by maximizing recall and ignoring the threshold.

Alternatives

An alternative implementation would be for each metric to provide is_better(left: TMetricResult, right: TMetricResult), where TMetricResult is whatever compute returns.

If we don't have this, people will have to write wrappers around the metrics to support this functionality in sweepers.
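A minimal sketch of what that alternative could look like, assuming a free-standing helper rather than anything that exists in torchmetrics (the name, signature, and higher_is_better flag are illustrative only):

import torch

# Illustrative comparison helper: decides whether `left` is a better result
# than `right`, given the metric's preferred direction.
def is_better(left: torch.Tensor, right: torch.Tensor, higher_is_better: bool) -> bool:
    return bool(left > right) if higher_is_better else bool(left < right)

# accuracy-like metric: higher is better
assert is_better(torch.tensor(0.9), torch.tensor(0.8), higher_is_better=True)
# MSE-like metric: lower is better
assert is_better(torch.tensor(0.1), torch.tensor(0.2), higher_is_better=False)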

maximsch2 added the enhancement and help wanted labels Mar 29, 2021
@SkafteNicki (Member)
I think it is a great addition.
@maximsch2 is there a specific framework that you had in mind where this would give better integration?

@maximsch2 (Contributor, Author)
Yeah, maybe something like this: https://ax.dev/tutorials/tune_cnn.html

But I do imagine that any sort of sweeping requires us to be able to a) select the target metric and b) compare two runs to see which one has the better metric value.

@SkafteNicki (Member)
@maximsch2 what do you think about something like this:

import torch
from copy import deepcopy

from torchmetrics import MeanSquaredError

# registry mapping metric classes to (minimize, compare_fn, init_val, index)
_REGISTER = {}

def register(metric, minimize, index=None):
    if minimize:
        compare_fn = torch.less
        init_val = torch.tensor(float("inf"))
    else:
        compare_fn = torch.greater
        init_val = -torch.tensor(float("inf"))
    _REGISTER[metric] = (minimize, compare_fn, init_val, index)

register(MeanSquaredError, True)


class MetricCompare:
    def __init__(self, metric):
        self.base_metric = metric
        minimize, compare_fn, init_val, index = _REGISTER[type(metric)]
        self._minimize = minimize
        self._compare_fn = compare_fn
        self._index = index
        self._init_val = init_val
        self._new_val = deepcopy(init_val)
        self._old_val = deepcopy(init_val)

    def update(self, *args, **kwargs):
        self.base_metric.update(*args, **kwargs)

    def compute(self):
        # keep the previous value around so has_improved can compare against it
        self._old_val = self._new_val
        val = self.base_metric.compute()
        self._new_val = val.detach()
        return val

    def reset(self):
        self.base_metric.reset()
        self._new_val = deepcopy(self._init_val)
        self._old_val = deepcopy(self._init_val)

    @property
    def has_improved(self):
        if self._index is None:
            return self._compare_fn(self._new_val, self._old_val)
        return self._compare_fn(self._new_val[self._index], self._old_val[self._index])

    @property
    def minimize(self):
        return self._minimize

    @property
    def maximize(self):
        return not self.minimize


metric = MetricCompare(MeanSquaredError())
metric.update(torch.randn(100,), torch.randn(100,))
val = metric.compute()
print(metric.has_improved)

This is basically a wrapper for metrics that adds properties telling whether the metric should be minimized or maximized and, after compute is called, whether it has improved.

@maximsch2 (Contributor, Author) commented Mar 30, 2021

Usually sweeps will be run in a distributed fashion (e.g. schedule runs with different hyperparams separately, compute metric values, pick the one with the best metric), so has_improved might not be as useful there.

Thinking about it a bit more, just providing a way to convert a metric to an optimization value might be enough (with semantics indicating whether we are increasing or decreasing it); a minimal sketch of this follows below.

Another example of package for hyperparam optimization that also takes objective: http://hyperopt.github.io/hyperopt/
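A minimal sketch of that conversion, assuming the metric exposes a higher_is_better flag and that the sweeper (hyperopt's fmin, for example) always minimizes its objective; the helper name is illustrative:

import torch

# Illustrative adapter: turn a metric value into something a minimizing
# sweeper can consume, flipping the sign when the metric should be maximized.
def to_minimization_objective(value: torch.Tensor, higher_is_better: bool) -> float:
    return -float(value) if higher_is_better else float(value)

# e.g. an accuracy of 0.75 becomes -0.75, so minimizing the objective maximizes accuracy
print(to_minimization_objective(torch.tensor(0.75), higher_is_better=True))   # -0.75
print(to_minimization_objective(torch.tensor(0.25), higher_is_better=False))  # 0.25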

@breznak commented Apr 1, 2021

I'd like to see this implemented as well. We're using PL + Optuna (+ Hydra's plugin_sweeper_optuna) and running into the same problem, especially when a model's metric is configurable.

I think the approach with a direction() -> 'min'/'max' property is simple and would suffice (a rough sketch follows below).

While the solutions with wrappers work, I think it'd be good if PL somehow standardized this, so that other HP optimization libraries can integrate with it.
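A rough sketch of that direction property, assuming it would live on a wrapper around an existing metric; nothing here is actual torchmetrics API, and the class name is illustrative only:

from torchmetrics import MeanSquaredError

# Illustrative wrapper attaching a direction property to an existing metric.
class DirectionalMetric:
    def __init__(self, base_metric, direction: str):
        assert direction in ("min", "max")
        self.base_metric = base_metric
        self._direction = direction

    @property
    def direction(self) -> str:
        return self._direction

    def update(self, *args, **kwargs):
        self.base_metric.update(*args, **kwargs)

    def compute(self):
        return self.base_metric.compute()

    def reset(self):
        self.base_metric.reset()

mse = DirectionalMetric(MeanSquaredError(), direction="min")
print(mse.direction)  # "min"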

@SkafteNicki (Member)
Okay, then let's settle on adding a property to each metric.

  1. What should it be named?
    direction->'min'/'max',
    minimize->True/False,
    higher_is_better->True/False
  2. It should not be implemented for all metrics. ConfusionMatrix comes to mind, where it does not make sense to talk about when one result is better than another.
  3. How do we deal with metrics with multiple outputs and metrics with multidimensional output?

@breznak commented Apr 1, 2021

ConfusionMatrix comes to mind where it [min/max] does not make sense

add -> min/max/None?

How do we deal with metrics with multiple outputs and metrics with multidimensional output?

I.e. Optuna lets you define a tuple (a minimal Optuna sketch follows below):

direction: 
 - minimize
 - maximize

I'd say we don't care for the first iteration and just leave these as None. We cannot decide on a Pareto-optimal front anyway.

... and you probably meant a multi-dim metric output, not multi-dim optimization, right?
For the multi-dim output, we need a form of reduction.

Can we say that, for the first draft, this feature works only for metrics of the form Loss(y_hat: Tensor, y: Tensor) -> float?
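For reference, a minimal, self-contained sketch of how a sweeper consumes the direction setting mentioned above, using Optuna with a dummy objective standing in for an actual validation metric:

import optuna

def objective(trial):
    x = trial.suggest_float("x", -10.0, 10.0)
    return -(x - 2.0) ** 2  # stand-in for a validation metric we want to maximize

# the direction would ideally come from the metric's min/max semantics
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=20)
print(study.best_params)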

@maximsch2 (Contributor, Author)
For multi-output metrics we need the ability to extract the value that is actually being optimized over. E.g. some metrics can return a value and corresponding threshold (e.g. recall@precision=90%) and we only want to optimize over the actual value.

@Borda (Member) commented Apr 26, 2021

@maximsch2 @breznak @SkafteNicki how is it going here? Do we have a resolution on what to do?

@breznak commented Apr 28, 2021

I think we got stuck on the more advanced cases (e.g. metrics that return multiple values, as above). While I see it's important to design this well so it works for all use cases in the future, I think we should find an MVP that we can easily implement, otherwise this will likely stall.

In practice, what we're running into is that this would ideally be a coordinated "API" for pl.metrics and torchmetrics.

E.g. some metrics can return value and corresponding threshold (e.g. recall@precision=90%) and we only want to optimize over the actual value.

Could you elaborate on this example, please, @maximsch2? From what I understand, the metric returns multiple values for several thresholds. But wouldn't the direction still be the same for all of them? (recall -> max?)

@SkafteNicki (Member)
In practice, what we're running into is that this would ideally be coordinated "API" for pl.metrics and torchmetrics.

@breznak since pl.metrics will be deprecated in v1.3 of lightning and completely removed in v1.5, we only need to think about the torchmetrics API.

E.g. some metrics can return value and corresponding threshold (e.g. recall@precision=90%) and we only want to optimize over the actual value.

Could you elaborate on this example, please, @maximsch2? From what I understand, the metric returns multiple values for several thresholds. But wouldn't the direction still be the same for all of them? (recall -> max?)

I think what @maximsch2 is referring to is that metrics such as PrecisionRecallCurve have 3 outputs:

precision, recall, thresholds = pr_curve(pred, target)

where I basically want to optimize the precision/recall but not the threshold values.
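To make that concrete, here is a minimal sketch of reducing such a curve to a single scalar to sweep over, the recall@precision case mentioned earlier; the helper name and the 0.9 target are illustrative, and the precision/recall tensors would normally come from the metric's compute():

import torch

# Illustrative reduction: best recall achievable at precision >= min_precision.
def recall_at_precision(precision: torch.Tensor, recall: torch.Tensor, min_precision: float = 0.9) -> torch.Tensor:
    mask = precision >= min_precision
    # fall back to 0 if the precision target is never reached
    return recall[mask].max() if mask.any() else torch.tensor(0.0)

# stand-ins for the precision/recall outputs of a precision-recall curve
precision = torch.tensor([0.50, 0.80, 0.92, 0.95, 1.00])
recall = torch.tensor([1.00, 0.90, 0.70, 0.40, 0.00])
print(recall_at_precision(precision, recall))  # tensor(0.7000)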

@breznak commented Apr 28, 2021

we only need to think about the torchmetrics API.

Good to know, thanks! Then it should be easier.

precision, recall, thresholds = pr_curve(pred, target)
where I basically want to optimize the precision/recall but not the threshold values.

How about adding a way to tell the metric which single output is the optimization criterion, then?
Like precision, recall, thresholds = pr_curve(pred, target, optimize='recall')
Then we have one number that represents the "important" result from such a metric.

@maximsch2 (Contributor, Author)
I'm actually thinking that maybe we should defer multi-output metrics until later, as long as we can support them via CompositionalMetric. E.g. for single-output metrics we'll provide higher_is_better, but for multi-output metrics we'll skip it and rely on people doing something like CompositionalMetric(RecallAndThresholdMetric()[0], None, higher_is_better=True), which will implement the needed functions and return the single value?
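A rough sketch of the general idea only (not the actual CompositionalMetric API; the wrapper class, the stand-in multi-output metric, and the higher_is_better argument are all illustrative):

import torch

# Stand-in for a hypothetical multi-output metric like RecallAndThresholdMetric:
# compute() returns (recall, threshold).
class FakeRecallAndThreshold:
    def update(self, *args, **kwargs):
        pass

    def compute(self):
        return torch.tensor(0.83), torch.tensor(0.41)

# Illustrative wrapper: pick one element of a multi-output result and attach
# the higher_is_better flag that single-output metrics would expose directly.
class IndexedMetric:
    def __init__(self, base_metric, index: int, higher_is_better: bool):
        self.base_metric = base_metric
        self.index = index
        self.higher_is_better = higher_is_better

    def update(self, *args, **kwargs):
        self.base_metric.update(*args, **kwargs)

    def compute(self) -> torch.Tensor:
        return self.base_metric.compute()[self.index]

metric = IndexedMetric(FakeRecallAndThreshold(), index=0, higher_is_better=True)
print(metric.compute(), metric.higher_is_better)  # tensor(0.8300) True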

@breznak commented Apr 29, 2021

I'm for starting small, but doing it rather soon.
Btw, it'd be nice to get people from the Optuna/Ray/Ax/etc. PL sweepers here, as they might have valuable feedback.

stale bot commented Jun 28, 2021

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale bot added the wontfix label Jun 28, 2021
Borda added this to the v0.5 milestone Jul 2, 2021
stale bot removed the wontfix label Jul 2, 2021
Borda modified the milestones: v0.5, v0.6 Aug 3, 2021
Borda removed the help wanted label Sep 20, 2021
SkafteNicki mentioned this issue Sep 24, 2021