fix: remove calls to Pytorch Dataset len #8647
Conversation
Force-pushed from 71b78f4 to 44e0e57.
Codecov Report

Attention: Patch coverage is …

Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##             main    #8647      +/-   ##
==========================================
- Coverage   47.16%   47.14%   -0.02%
==========================================
  Files        1150     1150
  Lines      141674   141671       -3
  Branches     2415     2417       +2
==========================================
- Hits        66814    66786      -28
- Misses      74670    74695      +25
  Partials      190      190
```

Flags with carried forward coverage won't be shown.
Force-pushed from 1452720 to a3c2ae0.
```python
for callback in self.callbacks.values():
    callback.on_validation_epoch_start()
# ...
idx = -1
```
You are taking a check which was fully contained in two consecutive lines and spreading it over a wider area here. I don't want you to rewrite it, because what you have is simple and easy, but can you just put a comment above `idx = -1` to explain why -1 is significant?
That's a nice idea. Done.
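For readers following along, here is a minimal sketch of the check that comment documents, assuming the counting approach described in this PR's description (the loader contents are stand-ins, not the actual trial code):

```python
# Stand-in for the real validation loader; two fake batches.
validation_loader = [("batch-0",), ("batch-1",)]

# idx stays -1 if the loader yields nothing, which lets us detect an empty
# loader without ever calling len(validation_loader).
idx = -1
for idx, batch in enumerate(validation_loader):
    pass  # evaluate the batch here
if idx == -1:
    raise RuntimeError("validation_loader is empty.")
```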
```
@@ -1059,12 +1054,17 @@ def _validate(self, searcher_op: Optional[core.SearcherOperation] = None) -> Dic
        # common than evaluate_batch() and we can't know how the user processed their
        # validation data.
        if self._evaluate_batch_defined():
            # Reshape and sum.
            # TODO: remove the type directive once we upgrade to mypy >= 1.7.0
            inputs_total, batches_total = [sum(n) for n in zip(*input_counts)]  # type: ignore
```
`input_counts` doesn't seem to be defined in the evaluate_full_dataset codepath.
The code is kind of unfortunately structured and pretty hard to figure out:

```
1.  if evaluate_batch:
2.      do stuff in a managed way
3.      if chief:
4.          do managed stuff relevant to the chief
5.  else:  # evaluate_full
6.      do whatever stuff evaluate_full says to do
7.
8.  if chief:
9.      if evaluate_batch:
10.         report batch-specific detail
11.     report general stuff
```

Before patch:

- `num_inputs` defined / calculated line 2
- `input_counts` defined line 2 (from gathered `num_inputs`)
- `num_inputs` re-defined line 4 when evaluate_batch (from `input_counts`)
- `num_inputs` independently defined line 6 when evaluate_full, to be given a meaning similar to that on line 4 (from `len(validation_loader)`)
- `num_inputs` (per its second definition) used line 10 and nowhere else

Pre-patch, `num_inputs` was being defined during both evaluate_batch and evaluate_full, but only used during evaluate_batch.

After patch:

- `num_inputs` defined / calculated line 2
- `input_counts` defined line 2 (from gathered `num_inputs`)
- `inputs_total` defined line 10 (from `input_counts`)
- `inputs_total` used line 10

Effects of this patch:

- `num_inputs` not defined during evaluate_full (which is fine, because it hadn't been used there)
- gathered calculation from `input_counts` moved to where it's used (line 10)
- gathered calculation from `input_counts` given a new name so `num_inputs` isn't overloaded
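To make that flow concrete, here is a small, self-contained sketch of the post-patch shape. The counts and the gathered list are faked, and the names follow the discussion above rather than the actual determined source:

```python
def validate(evaluate_batch_defined: bool, is_chief: bool) -> None:
    if evaluate_batch_defined:
        num_inputs, num_batches = 128, 4  # this worker's counts ("line 2")
        # Stand-in for gathering (num_inputs, num_batches) from every worker:
        input_counts = [(num_inputs, num_batches), (130, 4)]
    else:
        pass  # evaluate_full path: num_inputs is no longer defined here

    if is_chief:
        if evaluate_batch_defined:
            # The gathered totals are computed where they are used ("line 10"),
            # under new names, so num_inputs isn't overloaded.
            inputs_total, batches_total = [sum(n) for n in zip(*input_counts)]
            print("batch-specific:", inputs_total, batches_total)
        print("general reporting")  # "line 11"


validate(evaluate_batch_defined=True, is_chief=True)
```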
For full "determined" functionality, this must return an instance of | ||
:py:class:`determined.pytorch.DataLoader`. It can also return an unwrapped | ||
:py:class:`torch.utils.data.DataLoader` if you need more control over the underlying | ||
DataLoader and are willing to sacrifice some Determined features (ex: automatic data | ||
sharding). |
Thank you for catching that our typing and docstring had fallen out of correctness.

However, this is a user-facing docstring, and given that, I think this addition is much too vague. What is `full "determined" functionality`, anyway?

I wouldn't recommend answering that question here; I would link to our existing docs on the subject. Something like:

> Users with a MapDataset will normally return a :class:`determined.pytorch.DataLoader`, but users with an IterableDataset or with other advanced needs may return a bare ``torch.utils.data.DataLoader`` if they follow the steps described in :ref:`pytorch-reproducible-dataset`.
I agree about the vagueness. It made me feel a little gross, too.

Ideally, the docstring should say why a user would choose which class to return. It's somewhat easy to say "if you can't return a `det.pytorch.DataLoader`, return the other one." It's harder to come up with an explanation for "why managed" that's not vague.

I like your idea of referring to docs for that. That's something docs.determined.ai should do. The doc you suggested does have a nice note explaining how, but I can't find anything on the site for why.

I think a voice chat is probably the best way to work something out from here. When you get to this and you've got the time, could you please give me a call?
Ended up with a little from column A, a little from column B.
""" | ||
Defines the data loader to use during validation. | ||
|
||
Must return an instance of :py:class:`determined.pytorch.DataLoader`. | ||
For full "determined" functionality, this must return an instance of | ||
:py:class:`determined.pytorch.DataLoader`. It can also return an unwrapped |
Skip the `:py` in `:py:class:...`, as it is not necessary.
Nice. Thanks! (also changed in another existing place)
You are missing the most important usage of `len(data_loader)`, which is here.

I saw in #4303 (comment) that when you last looked at this line, you preferred a different solution then, too. I understand that there can be inaccuracy both from this line and from variable-length datasets, and that counting examples is more foolproof. But I haven't worked out what the implications of that inaccuracy might be (or the cost of a solution) well enough to create a ticket for it.
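As a toy illustration of the inaccuracy under discussion (not the code under review): `len(loader)` counts batches, so example counts derived from it can drift from the true dataset size, for instance with `drop_last=True`:

```python
import torch.utils.data

data = list(range(10))
loader = torch.utils.data.DataLoader(data, batch_size=4, drop_last=True)

print(len(loader))                  # 2 batches; the partial third batch is dropped
print(sum(len(b) for b in loader))  # 8 examples seen, though the dataset has 10
```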
Force-pushed from 0c692dc to e326a17.
nice
Datasets aren't guaranteed to have a `__len__` implemented. This removes two calls to `__len__` that weren't actually necessary.
This is functionally an abstract method, but actually marking it as abstract or making it raise an exception would change the interface people have to implement, which would be a breaking change. This is a problem for tomorrow.
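To illustrate the tradeoff that commit message describes (the class and method names below are illustrative stand-ins, not the actual determined interface):

```python
import abc


class TrialToday:
    # Status quo: a do-nothing default. Subclasses that never override this
    # still instantiate fine, even though the method is "functionally abstract".
    def build_validation_data_loader(self):
        pass


class TrialTomorrow(abc.ABC):
    # Marking the method abstract makes instantiating a non-implementing
    # subclass raise TypeError: a breaking change to the implied interface.
    @abc.abstractmethod
    def build_validation_data_loader(self):
        ...
```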
Force-pushed from cc0bc65 to 7cb26c5.
Description

PyTorch Datasets (`torch.utils.data.Dataset`) aren't guaranteed to have a `__len__` implemented. (Datasets can be either "map-style" or "iterable-style". When map-style, they must implement a `__len__`; when iterable-style, they may.) The `__len__` on a PyTorch DataLoader may pass the call through to its Dataset.

A `det.pytorch.PyTorchTrial` is typically constructed from a `det.pytorch.DataLoader`. `det.pytorch.DataLoader` cannot, itself, front an iterable-style PyTorch Dataset. It is, however, possible to construct a `det.pytorch.PyTorchTrial` with an unwrapped `torch.utils.data.DataLoader` if `context.experimental.disable_dataset_reproducibility_checks()` is called in the `PyTorchTrial`'s `__init__`.

Before this patch, during a `PyTorchTrialContext.run` we called `len` on the trial's validation dataloader. Per the above, it had been possible to construct a trial with a validation dataloader that did not have `__len__` implemented, and in this case `run` would raise a runtime `TypeError` exception.

Turns out, though, those existing calls to `__len__` weren't actually necessary. This patch revises them with no functional change in behavior:

- Instead of calling `len(validation_loader)` to check for emptiness before iterating through it, check the number of times the validation_loader is iterated through, raising the same error if it was empty.
- Removed a call to `len` where the result was entirely ignored.

This PR also makes a couple "continuous improvement" changes, including moving around a couple pieces of code and renaming variables so that its logic is a little more obvious.
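A minimal sketch of the escape hatch described above. The trial is stripped to the relevant pieces; a real `PyTorchTrial` must also implement its training methods, and `StreamingDataset`/`MyTrial` are illustrative names:

```python
import torch.utils.data

from determined import pytorch


class StreamingDataset(torch.utils.data.IterableDataset):
    """Iterable-style: implements __iter__ but not __len__."""

    def __iter__(self):
        return iter(range(100))  # stand-in for a real data stream


class MyTrial(pytorch.PyTorchTrial):
    def __init__(self, context: pytorch.PyTorchTrialContext) -> None:
        self.context = context
        # Required when build_validation_data_loader returns a bare
        # torch.utils.data.DataLoader instead of a det.pytorch.DataLoader.
        self.context.experimental.disable_dataset_reproducibility_checks()

    def build_validation_data_loader(self) -> torch.utils.data.DataLoader:
        # len() on this loader would raise TypeError, because its Dataset
        # implements no __len__.
        return torch.utils.data.DataLoader(StreamingDataset(), batch_size=8)
```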
Test Plan

I've tested this by hand by modifying `build_validation_training_loader` in the example https://github.com/determined-ai/determined/blob/main/examples/tutorials/mnist_pytorch/train.py to return a DataLoader for a `torch.utils.data.Dataset` subclass that has no implemented `__len__`, and then running the https://github.com/determined-ai/determined/blob/main/harness/tests/experiment/pytorch/test_examples.py tests on it. Without the patch, validation fails because of the call to `__len__`. With the patch, validation succeeds.

We don't have any unit tests of the function I modified (`_PyTorchTrialController._run`, or its caller `_PyTorchTrialController.run`), and this doesn't quite seem like enough of a patch to create unit tests for the class. It also doesn't seem quite appropriate to create another end-to-end test just to ensure `__len__` isn't called on a validation loader. Maybe more automated tests aren't needed?

For the release party, if you'd like to test this yourself, create a `PyTorchTrial` object of a class that implements `build_validation_data_loader` to return a plain, unwrapped `torch.utils.data.DataLoader` that itself has no `__len__`, then run a training loop for this trial object. A sketch of a dataset wrapper that hides `__len__` follows below.
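For that, a hypothetical wrapper like the following hides `__len__` from an existing map-style dataset (`NoLenDataset` is an illustrative name, not a helper in the repo):

```python
import torch.utils.data


class NoLenDataset(torch.utils.data.IterableDataset):
    """Exposes a map-style dataset's items through __iter__ only,
    so the wrapper itself has no __len__."""

    def __init__(self, wrapped) -> None:
        self.wrapped = wrapped

    def __iter__(self):
        return (self.wrapped[i] for i in range(len(self.wrapped)))


# len() on a DataLoader over this dataset raises TypeError:
loader = torch.utils.data.DataLoader(NoLenDataset(list(range(10))), batch_size=4)
```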
Commentary (optional)
Checklist

- Release note added under `docs/release-notes/`. See Release Note for details.
Ticket
[MLG-1022]