Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MultitaskWrapper #1762

Merged
merged 24 commits into from
May 15, 2023
Merged

Conversation

ValerianRey
Copy link
Contributor

@ValerianRey ValerianRey commented May 6, 2023

What does this PR do?

Fixes #1741

Before submitting
  • Was this discussed/agreed via a Github issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?
PR review

This is the first draft of my work on #1741. Before going further in the development, I want to discuss several points:

  1. I made the choice to make this wrapper work entirely with dictionaries. In the __init__ method, task_metrics is a dict associating metrics to each task. In the update method, task_preds and task_targets are dictionaries associating a pred and a target to each task. Lastly, the output of the compute method is also a dictionary associating a computed metric to each task. Arguably, we could make it work entirely with Sequences (of metrics, of preds, of targets, of outputs). Another alternative would be to use a Sequence of metrics along with a Sequence of str in the __init__ method, for the metrics and the task names. The preds and targets could be Sequences of Tensors, and the output could be a Dictionary, mapping each task name to the computed output. Lastly, we could have both options supported, but I think this will make the code messy and the interface harder to understand.
    So I think the current version (using dicts only) is the best, but I want an external opinion on this before going further.

  2. I have a hard time understanding the full_state_update bool defined in Metric. What should I do with it?

  3. In MultioutputWrapper (which has many common points with MultitaskWrapper), I have seen that you have overriden the following methods as identity, with the docstring "Overwrite to do nothing.":

    • _wrap_update (code)
    • _wrap_compute (code)

    What was the intention, and should I do the same in MultitaskWrapper?

  4. I don't know whether I need to override forward or not (it is overriden in MultioutputWrapper (code))

  5. The plot method is greatly inspired from MetricCollection.plot (code). I decided to only support plotting each task's metrics separately, as they are heterogeneous in general. Arguably, each MetricCollection in the task_metrics would need to have a together bool parameter associated to it when calling plot, but that's very messy to implement and probably useless (the users could simply call plot separately and with the right value of together, for each metric in task_metric).

Please do not hesitate to give negative feedback!

Next steps
  • Add unit tests
  • Add docstring for plot
  • Make the necessary changes based on answers to 1., 2., 3., 4. and 5.
  • Make the necessary changes based on feedback

@ValerianRey ValerianRey marked this pull request as draft May 6, 2023 14:56
@ValerianRey
Copy link
Contributor Author

@SkafteNicki I'd love to have some of your feedback on this before going further!

@codecov
Copy link

codecov bot commented May 6, 2023

Codecov Report

Merging #1762 (7dcc995) into master (2d35650) will decrease coverage by 46%.
The diff coverage is 63%.

Additional details and impacted files
@@           Coverage Diff            @@
##           master   #1762     +/-   ##
========================================
- Coverage      87%     41%    -46%     
========================================
  Files         253     254      +1     
  Lines       14164   14224     +60     
========================================
- Hits        12387    5847   -6540     
- Misses       1777    8377   +6600     

@SkafteNicki SkafteNicki added enhancement New feature or request New metric labels May 6, 2023
@SkafteNicki SkafteNicki added this to the v1.0.0 milestone May 6, 2023
Copy link
Member

@SkafteNicki SkafteNicki left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Already looking good.
Of cause the testing is missing, but I think the interface is nice :]

src/torchmetrics/wrappers/multitask.py Outdated Show resolved Hide resolved
src/torchmetrics/wrappers/multitask.py Show resolved Hide resolved
src/torchmetrics/wrappers/multitask.py Outdated Show resolved Hide resolved
src/torchmetrics/wrappers/multitask.py Show resolved Hide resolved
ValerianRey and others added 4 commits May 6, 2023 21:06
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
- It enables automatic forwarding of some `nn.Module` method calls (`to` and `cuda` for instance) to the modules contained in `task_metrics`
@ValerianRey
Copy link
Contributor Author

Thanks for the feedback! I made most of the changes that you suggested!

@SkafteNicki
Copy link
Member

@ValerianRey cool. Could you try to write some unit tests for the wrapper?
I do not think that is should be that hard as the metric is not that complex. We basically just need to check that inputs gets correctly send to their respective task.

@ValerianRey
Copy link
Contributor Author

@ValerianRey cool. Could you try to write some unit tests for the wrapper? I do not think that is should be that hard as the metric is not that complex. We basically just need to check that inputs gets correctly send to their respective task.

Yes I will do that this week!
Could you please have a look at the questions that I asked in the original post of this PR and try to answer those for which you know the answer / have an opinion? It would help me a lot!

@SkafteNicki
Copy link
Member

Answers to your questions

I made the choice to make this wrapper work entirely with dictionaries. In the init method, task_metrics is a dict associating metrics to each task. In the update method, task_preds and task_targets are dictionaries associating a pred and a target to each task. Lastly, the output of the compute method is also a dictionary associating a computed metric to each task. Arguably, we could make it work entirely with Sequences (of metrics, of preds, of targets, of outputs). Another alternative would be to use a Sequence of metrics along with a Sequence of str in the init method, for the metrics and the task names. The preds and targets could be Sequences of Tensors, and the output could be a Dictionary, mapping each task name to the computed output. Lastly, we could have both options supported, but I think this will make the code messy and the interface harder to understand. So I think the current version (using dicts only) is the best, but I want an external opinion on this before going further.

Agree that this is a fine interface.

I have a hard time understanding the full_state_update bool defined in Metric. What should I do with it?

Instead of setting this, please add a custom forward method:

def forward(self, task_preds: Dict[str, Tensor], task_targets: Dict[str, Tensor]) -> Dict[str, Any]:
    return {metric(task_preds[task_name], tast_targets[task_name]) for task_name, metric in self.task_metrics.items()}:

The default forward function is not needed here because all computations are handled by the task specific metrics.

In MultioutputWrapper (which has many common points with MultitaskWrapper), I have seen that you have overriden the following methods as identity, with the docstring "Overwrite to do nothing.":
_wrap_update (code)
_wrap_compute (code)
What was the intention, and should I do the same in MultitaskWrapper?

_wrap_update and _wrap_compute for normal metrics takes care of keeping the states up to date and making sure that everything gets correctly synced in distributed settings. As this wrapper (and MultioutputWrapper) does not have states of their own, they rely on other metrics, we can safely override these because all the logic is dealt with the child metrics.
Short answer: yes, please overwrite them to do nothing.

I don't know whether I need to override forward or not (it is overriden in MultioutputWrapper (code))

See previous answer above

The plot method is greatly inspired from MetricCollection.plot (code). I decided to only support plotting each task's metrics separately, as they are heterogeneous in general. Arguably, each MetricCollection in the task_metrics would need to have a together bool parameter associated to it when calling plot, but that's very messy to implement and probably useless (the users could simply call plot separately and with the right value of together, for each metric in task_metric).

I agree that this is the way to go. The plot interface is meant for quick and dirty plotting of metrics, not to produce highly custom one. If users want that they can probably figure out to code it themselves :]

@SkafteNicki
Copy link
Member

A lot of the observations made here probably indicate that we should make a common interface for wrappers if possible since there seems to be some duplication going on. Lets tackle that in a future PR.

@ValerianRey
Copy link
Contributor Author

Instead of setting this, please add a custom forward method:

I added the forward method, with a docstring and a comment explaining this implementation choice, with what I understood. I hope that this is correct! Also, I'm not very familiar with the @torch.jit.unused decorator, and I'm unsure whether I should add it (like in Metric.forward, MultioutputWrapper.forward and MetricCollection.forward) or not (like in ClasswiseWrapper.forward, Running.forward and MetricTracker.forward).

The last 2 steps for me are the docstring of plot and the unit tests (I already did some quick tests locally but nothing clean yet, so I didn't commit them yet).

@SkafteNicki
Copy link
Member

Instead of setting this, please add a custom forward method:

I added the forward method, with a docstring and a comment explaining this implementation choice, with what I understood. I hope that this is correct! Also, I'm not very familiar with the @torch.jit.unused decorator, and I'm unsure whether I should add it (like in Metric.forward, MultioutputWrapper.forward and MetricCollection.forward) or not (like in ClasswiseWrapper.forward, Running.forward and MetricTracker.forward).

The last 2 steps for me are the docstring of plot and the unit tests (I already did some quick tests locally but nothing clean yet, so I didn't commit them yet).

Cool, the code still looks good to me. Going to wait for the unittest before reviewing the code again :]

* Rename `ax` to `axes`
* Add checks on argument `axes`
* Add docstring
* Improve code readability
@ValerianRey ValerianRey marked this pull request as ready for review May 13, 2023 19:19
@ValerianRey ValerianRey changed the title [WIP] Add MultitaskWrapper Add MultitaskWrapper May 13, 2023
@mergify mergify bot removed the has conflicts label May 13, 2023
Copy link
Member

@SkafteNicki SkafteNicki left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM :]

@Borda
Copy link
Member

Borda commented May 15, 2023

seem that these test are failing because of GH actions down-time, lets try to restart a bit later 🐿️

@Borda Borda enabled auto-merge (squash) May 15, 2023 21:08
@Borda Borda merged commit be1192f into Lightning-AI:master May 15, 2023
@mergify mergify bot added the ready label May 15, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add a wrapper adapted to multi-task learning
3 participants