Setting class_weight in model.fit() with tf.data.Dataset causes error #47032
Comments
@tensortorch,
@amahendrakar, it does help, since I can make my dataset return a sample weight as a third value, which is what model.fit() does anyway under the hood when class_weight is provided. So one can work around this issue by using sample weights instead. However, I believe providing class_weight in model.fit() should still work. The linked comment explains that it is not expected to work for targets with three or more dimensions, but that is not the case here. In fact, the error occurs within the check for target dimensionality and is caused by the target rank being None for some reason.
Also, it does work for the same inputs when the dataset is converted to an iterator, as my minimal example shows. The error only occurs when a tf.data.Dataset is passed as input. In that case, the DataHandler attempts to add the third element to the dataset by calling map() to convert class_weight to sample_weight, which fails due to the aforementioned error.
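The sample-weight workaround described above can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not the reporter's original code: the toy data, the three-class `class_weight` values, the `add_sample_weight` helper and the tiny model are all hypothetical. Each batch gets a third element looked up from the class weights, which Keras treats as `sample_weight`, so `class_weight` is never passed to `fit()`.

```python
import numpy as np
import tensorflow as tf

# Hypothetical class weights for a 3-class problem (illustrative values only).
class_weight = {0: 1.0, 1: 2.0, 2: 0.5}

# Dense lookup table: entry c holds the weight of class c.
weight_table = tf.constant(
    [class_weight[c] for c in range(len(class_weight))], dtype=tf.float32
)

# Toy data: 6 samples with 4 features each and integer class labels.
features = np.random.rand(6, 4).astype("float32")
labels = np.array([0, 1, 2, 1, 0, 2], dtype="int32")


def add_sample_weight(x, y):
    # Hypothetical helper: look up each sample's class weight and return it as
    # a third element; Keras reads (x, y, w) as (inputs, targets, sample_weight).
    return x, y, tf.gather(weight_table, y)


ds = (
    tf.data.Dataset.from_tensor_slices((features, labels))
    .batch(2)
    .map(add_sample_weight)
)

model = tf.keras.Sequential(
    [
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(3, activation="softmax"),
    ]
)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

# No class_weight argument: the weights already travel with each batch.
history = model.fit(ds, epochs=1, verbose=0)
```

Because the weights are attached before `fit()` sees the dataset, Keras never has to run its own class_weight-to-sample_weight conversion, which is where the rank check fails.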
The workaround for a somewhat related problem also works in this case: an additional call to map() that manually sets the tensor shapes makes it work. I had previously tried converting the outputs of the py_function to tensors and setting their shapes manually, but that did not work, so the key here is the second call to map() that sets the shapes after batch().
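A minimal sketch of this shape-setting workaround, assuming a dataset built with `tf.py_function`. The `load_example` function, `NUM_FEATURES`, `NUM_CLASSES`, the labels and the model are hypothetical placeholders, and whether the behaviour is identical on every TF version mentioned in this thread is an assumption. The second `map()` after `batch()` restores static shapes with `Tensor.set_shape`, so the target rank is known when Keras converts `class_weight` to sample weights.

```python
import numpy as np
import tensorflow as tf

NUM_FEATURES = 4  # hypothetical feature dimension
NUM_CLASSES = 3   # hypothetical number of classes


def load_example(i):
    # Hypothetical stand-in for arbitrary Python logic (e.g. decoding a file).
    i = int(i)
    x = np.full((NUM_FEATURES,), i, dtype="float32")
    y = np.int32(i % NUM_CLASSES)
    return x, y


def py_load(i):
    # tf.py_function discards static shape information: x and y get unknown rank.
    x, y = tf.py_function(load_example, [i], Tout=[tf.float32, tf.int32])
    return x, y


def set_shapes(x, y):
    # Second map() after batch(): restore the static shapes lost by py_function.
    x.set_shape([None, NUM_FEATURES])
    y.set_shape([None])
    return x, y


ds = (
    tf.data.Dataset.range(6)
    .map(py_load)
    .batch(2)
    .map(set_shapes)
)

model = tf.keras.Sequential(
    [
        tf.keras.Input(shape=(NUM_FEATURES,)),
        tf.keras.layers.Dense(NUM_CLASSES, activation="softmax"),
    ]
)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

# With the target rank known, fit() can convert class_weight to sample weights.
history = model.fit(
    ds, epochs=1, class_weight={0: 1.0, 1: 2.0, 2: 0.5}, verbose=0
)
```

Setting the shapes after `batch()` rather than inside `py_load` matches the observation above that only the post-batch `map()` made the difference.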
@tensortorch,
@amahendrakar
@jvishnuvardhan,
Same error; checked on Python 3.8 with TF v2.2.0, 2.3.0, 2.4.0 and 2.4.1.
This still fails in TF 2.5.0. Along the way, I've created another reproduction script (https://gist.github.com/kretes/ca911085b2eb0fa3985894245ce3fd0c). I suggest renaming the issue to 'Setting class_weight in model.fit() with tf.data.Dataset using py_function causes error', as py_function is the step that makes the shape unknown.
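The unknown shape attributed to `py_function` here can be observed directly from `element_spec`. In this sketch, a hypothetical doubling function (`py_double`) stands in for real Python logic: the mapped dataset has unknown rank, and batching does not recover it, whereas the same pipeline without `py_function` has a known rank of 1 after batching. An unknown rank is consistent with the `None` target rank reported earlier in the thread.

```python
import tensorflow as tf


def py_double(i):
    # Hypothetical stand-in for real Python logic wrapped in tf.py_function.
    return tf.py_function(lambda t: t * 2, [i], Tout=tf.int64)


plain = tf.data.Dataset.range(4).batch(2)
wrapped = tf.data.Dataset.range(4).map(py_double)
wrapped_batched = wrapped.batch(2)

print(plain.element_spec)            # rank 1: shape (None,)
print(wrapped.element_spec)          # unknown shape after py_function
print(wrapped_batched.element_spec)  # still unknown rank after batch()
```

The values themselves are unaffected; only the static shape metadata is lost, which is why code that inspects `y.shape.rank` at trace time is the part that breaks.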
This still shows the error, please help.
When I do the
For validation data, you can pass the validation data to fit() explicitly, as shown below.
@sachinprasadhs you are referring to … Can you try running the gist, e.g. in Colab, to see whether it fails on your side, and modify it accordingly so that it passes?
@kretes, if you still face the issue, could you please open it in the keras-team/keras repo? Thanks!
OK, I will move it to the Keras repo. Just that I am not passing …
This issue has been automatically marked as stale because it has no recent activity. It will be closed if no further activity occurs. Thank you.
@sachinprasadhs, as @kretes already mentioned, the fit call in the example follows the documentation and does not use a separate y argument. You can also see from the example that the fit call works when the same dataset is converted to an iterator.
Development of Keras has moved to a separate repository (https://github.com/keras-team/keras/issues). Please post this issue on the keras-team/keras repo.
This issue has been automatically marked as stale because it has no recent activity. It will be closed if no further activity occurs. Thank you.
Closing as stale. Please reopen if you'd like to work on this further.
System information
Describe the current behavior
When a tf.data.Dataset is used in model.fit(), setting class_weight causes an error.
Describe the expected behavior
No error occurs.
Standalone code to reproduce the issue
Error message