Add MultitaskWrapper
#1762
@SkafteNicki I'd love to have some of your feedback on this before going further!
Codecov Report
Additional details and impacted files

@@           Coverage Diff            @@
##           master   #1762    +/-   ##
========================================
- Coverage      87%     41%    -46%
========================================
  Files         253     254      +1
  Lines       14164   14224     +60
========================================
- Hits        12387    5847   -6540
- Misses       1777    8377   +6600
Already looking good.
Of course the tests are missing, but I think the interface is nice :]
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
- It enables automatic forwarding of some `nn.Module` method calls (`to` and `cuda` for instance) to the modules contained in `task_metrics`
Thanks for the feedback! I made most of the changes that you suggested!
@ValerianRey cool. Could you try to write some unit tests for the wrapper?
Yes I will do that this week!
Answers to your questions
Agree that this is a fine interface.
Instead of setting this, please add a custom `forward`:

    def forward(self, task_preds: Dict[str, Tensor], task_targets: Dict[str, Tensor]) -> Dict[str, Any]:
        # Build a dict keyed by task name; calling a metric runs its own forward.
        return {
            task_name: metric(task_preds[task_name], task_targets[task_name])
            for task_name, metric in self.task_metrics.items()
        }

The default forward function is not needed here because all computations are handled by the task-specific metrics.
See previous answer above
I agree that this is the way to go. The plot interface is meant for quick and dirty plotting of metrics, not for producing highly customized ones. If users want that, they can probably figure out how to code it themselves :]
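To make the `forward` suggestion in the answers above concrete, here is a minimal pure-Python sketch. It is only an illustration under stated assumptions: `ToyAbsoluteError` and `ForwardingWrapperSketch` are hypothetical stand-ins, not torchmetrics classes, and plain lists replace tensors. The toy metric mimics the usual behaviour of calling a metric: it accumulates state and returns the value for the current batch.

```python
from typing import Dict, Sequence


class ToyAbsoluteError:
    """Hypothetical toy metric: mean absolute error over plain Python floats."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def __call__(self, preds: Sequence[float], targets: Sequence[float]) -> float:
        # Accumulate global state and return the value for this batch only.
        batch_total = sum(abs(p - t) for p, t in zip(preds, targets))
        self.total += batch_total
        self.count += len(preds)
        return batch_total / len(preds)

    def compute(self) -> float:
        # Value over everything seen so far.
        return self.total / self.count


class ForwardingWrapperSketch:
    """Hypothetical wrapper: forward delegates to each task's own metric."""

    def __init__(self, task_metrics: Dict[str, ToyAbsoluteError]) -> None:
        self.task_metrics = task_metrics

    def forward(
        self,
        task_preds: Dict[str, Sequence[float]],
        task_targets: Dict[str, Sequence[float]],
    ) -> Dict[str, float]:
        # One batch value per task, keyed by task name.
        return {
            task_name: metric(task_preds[task_name], task_targets[task_name])
            for task_name, metric in self.task_metrics.items()
        }

    def compute(self) -> Dict[str, float]:
        return {name: metric.compute() for name, metric in self.task_metrics.items()}


wrapper = ForwardingWrapperSketch({"depth": ToyAbsoluteError(), "pose": ToyAbsoluteError()})
first = wrapper.forward({"depth": [1.0, 3.0], "pose": [0.0]}, {"depth": [2.0, 3.0], "pose": [2.0]})
second = wrapper.forward({"depth": [5.0, 5.0], "pose": [1.0]}, {"depth": [2.0, 4.0], "pose": [1.0]})
accumulated = wrapper.compute()
```

In this sketch the wrapper keeps no state of its own: each call to `forward` returns per-batch values, while `compute` reports what the task metrics accumulated across both calls.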
A lot of the observations made here probably indicate that we should make a common interface for wrappers if possible, since there seems to be some duplication going on. Let's tackle that in a future PR.
for more information, see https://pre-commit.ci
I added the The last 2 steps for me are the docstring of
Cool, the code still looks good to me. Going to wait for the unit tests before reviewing the code again :]
* Rename `ax` to `axes`
* Add checks on argument `axes`
* Add docstring
* Improve code readability
LGTM :]
It seems that these tests are failing because of GH Actions downtime, let's try to restart a bit later 🐿️
What does this PR do?
Fixes #1741
Before submitting
PR review
This is the first draft of my work on #1741. Before going further in the development, I want to discuss several points:
- I made the choice to make this wrapper work entirely with dictionaries. In the `__init__` method, `task_metrics` is a dict associating metrics to each task. In the `update` method, `task_preds` and `task_targets` are dictionaries associating a pred and a target to each task. Lastly, the output of the `compute` method is also a dictionary associating a computed metric to each task. Arguably, we could make it work entirely with Sequences (of metrics, of preds, of targets, of outputs). Another alternative would be to use a Sequence of metrics along with a Sequence of str in the `__init__` method, for the metrics and the task names. The preds and targets could be Sequences of Tensors, and the output could be a dictionary mapping each task name to the computed output. Lastly, we could have both options supported, but I think this would make the code messy and the interface harder to understand. So I think the current version (using dicts only) is the best, but I want an external opinion on this before going further.
- I have a hard time understanding the `full_state_update` bool defined in `Metric`. What should I do with it?
- In `MultioutputWrapper` (which has many common points with `MultitaskWrapper`), I have seen that you have overridden the following methods as identity, with the docstring "Overwrite to do nothing.": `_wrap_update` (code) and `_wrap_compute` (code). What was the intention, and should I do the same in `MultitaskWrapper`?
- I don't know whether I need to override `forward` or not (it is overridden in `MultioutputWrapper` (code)).
- The `plot` method is greatly inspired by `MetricCollection.plot` (code). I decided to only support plotting each task's metrics separately, as they are heterogeneous in general. Arguably, each `MetricCollection` in the `task_metrics` would need to have a `together` bool parameter associated with it when calling `plot`, but that's very messy to implement and probably useless (the users could simply call `plot` separately, with the right value of `together`, for each metric in `task_metrics`).

Please do not hesitate to give negative feedback!
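The dictionary-only interface described in the first point can be sketched roughly as follows. This is a pure-Python illustration, not the code of this PR: `ToyAccuracy`, `ToyMeanAbsoluteError` and `DictMultitaskSketch` are hypothetical names, plain lists stand in for tensors, and the key-consistency check in `update` is an assumption about how mismatched task names might be handled.

```python
from typing import Any, Dict, Sequence


class ToyAccuracy:
    """Hypothetical toy classification metric (not a torchmetrics class)."""

    def __init__(self) -> None:
        self.correct = 0
        self.count = 0

    def update(self, preds: Sequence[int], targets: Sequence[int]) -> None:
        self.correct += sum(int(p == t) for p, t in zip(preds, targets))
        self.count += len(preds)

    def compute(self) -> float:
        return self.correct / self.count


class ToyMeanAbsoluteError:
    """Hypothetical toy regression metric (not a torchmetrics class)."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, preds: Sequence[float], targets: Sequence[float]) -> None:
        self.total += sum(abs(p - t) for p, t in zip(preds, targets))
        self.count += len(preds)

    def compute(self) -> float:
        return self.total / self.count


class DictMultitaskSketch:
    """Hypothetical wrapper where metrics, preds, targets and outputs are all dicts."""

    def __init__(self, task_metrics: Dict[str, Any]) -> None:
        self.task_metrics = task_metrics

    def update(self, task_preds: Dict[str, Any], task_targets: Dict[str, Any]) -> None:
        # Assumption: all three dicts must share exactly the same task names.
        if not (self.task_metrics.keys() == task_preds.keys() == task_targets.keys()):
            raise ValueError("task_metrics, task_preds and task_targets must have the same keys")
        for task_name, metric in self.task_metrics.items():
            metric.update(task_preds[task_name], task_targets[task_name])

    def compute(self) -> Dict[str, Any]:
        # One computed value per task, keyed by task name.
        return {name: metric.compute() for name, metric in self.task_metrics.items()}


wrapper = DictMultitaskSketch(
    {"classification": ToyAccuracy(), "regression": ToyMeanAbsoluteError()}
)
wrapper.update(
    {"classification": [0, 1, 1], "regression": [2.0, 4.0]},
    {"classification": [0, 1, 0], "regression": [1.0, 5.0]},
)
results = wrapper.compute()
```

With dictionaries, each task can use a different kind of metric and a different batch shape, and a mismatch between task names surfaces as an explicit error rather than a silent misalignment, which is harder to guarantee with positional Sequences.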
Next steps
plot