Speed up TF tests by reducing hidden layer counts #24595
Conversation
# Make sure fit works with tf.data.Dataset and results are consistent
dataset = tf.data.Dataset.from_tensor_slices(prepared_for_class)

if sample_weight is not None:
    # Add in the sample weight
    weighted_dataset = dataset.map(lambda x: (x, None, tf.convert_to_tensor(0.5, dtype=tf.float32)))
else:
    weighted_dataset = dataset
# Pass in all samples as a batch to match other `fit` calls
weighted_dataset = weighted_dataset.batch(len(dataset))
dataset = dataset.batch(len(dataset))
# Reinitialize to fix batchnorm again
model.set_weights(model_weights)

# To match the other calls, don't pass sample weights in the validation data
history3 = model.fit(
    weighted_dataset,
    validation_data=dataset,
    steps_per_epoch=1,
    validation_steps=1,
    shuffle=False,
)
val_loss3 = history3.history["val_loss"][0]
self.assertTrue(not isnan(val_loss3))
accuracy3 = {key: val[0] for key, val in history3.history.items() if key.endswith("accuracy")}
self.check_keras_fit_results(val_loss1, val_loss3)
self.assertEqual(history1.history.keys(), history3.history.keys())
if metrics:
    self.assertTrue(len(accuracy1) == len(accuracy3) > 0, "Missing metrics!")
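As an aside for readers unfamiliar with the pattern: the three-element tuples produced by the map() call rely on Keras's (inputs, targets, sample_weights) convention for tf.data; targets are None above because the labels already live inside the prepared input dict. Here is a minimal, self-contained sketch of the same convention (illustrative only, not code from the PR):

import tensorflow as tf

# Toy data: 8 samples, 4 features, binary integer labels.
x = tf.random.normal((8, 4))
y = tf.random.uniform((8,), maxval=2, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((x, y))

# Attach a constant per-sample weight of 0.5, mirroring the map() call above.
weighted = dataset.map(lambda xi, yi: (xi, yi, tf.constant(0.5)))

model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
# Keras unpacks each batch as (inputs, targets, sample_weights).
model.fit(weighted.batch(8), epochs=1)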
I added this section in test_keras_fit a long time ago when I was paranoid about issues from fitting tf.data.Dataset, but I believe it should not be possible for this section to fail if the other model tests pass, so I removed it!
It would be nice if you could show the timing for one model (before vs. after) 🙏. Thanks.
The documentation is not available anymore as the PR was closed or merged.
@ydshieh testing locally, BERT went from 510 seconds -> 220 seconds.
@@ -57,7 +57,7 @@ def __init__(
         use_labels=True,
         vocab_size=99,
         hidden_size=32,
-        num_hidden_layers=5,
+        num_hidden_layers=1,
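For a sense of scale, here is a minimal, illustrative sketch (not the tester code itself) of the kind of tiny model such a config produces. BertConfig and TFBertModel are the public transformers classes; the num_attention_heads and intermediate_size values are assumptions chosen only so the config is valid:

from transformers import BertConfig, TFBertModel

# Illustrative only: mirrors the tester fields shown in the diff above.
config = BertConfig(
    vocab_size=99,
    hidden_size=32,
    num_hidden_layers=1,   # value from the diff; the review below suggests 2 instead
    num_attention_heads=4,  # assumed: must divide hidden_size
    intermediate_size=37,   # assumed small value
)
model = TFBertModel(config)
# A 1-2 layer model gives TF far fewer ops to trace than the old 5-layer one,
# which is where the test-time saving comes from.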
Let's not use 1 but 2 if possible. 1 is kind of an exceptional case.
Fixed!
Works for me, except let's not use 1 as the number of layers 🙏. Let's also get an approval from @sgugger, as I am not the one who decided to use 5.
I don't know if it was @sgugger either - a lot of this code is really old! I see
I know he is probably not the one who decided to use
Thanks!
A lot of our slow TF tests are caused by TF compilation. TF compilation isn't really affected by layer width at all - the main thing is just the number of operations it has to build a graph for. By reducing the number of hidden layers, compilation gets much faster, hopefully without interfering with test coverage at all.
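As a rough illustration of that claim (a minimal sketch, not the transformers test harness): first-call tracing time for a tf.function grows with the number of stacked layers, i.e. the number of ops in the graph, while widening each layer changes it very little.

import time
import tensorflow as tf

def first_trace_time(num_layers, width):
    # A plain stack of Dense layers: depth sets the number of graph ops,
    # width sets the size of each op.
    model = tf.keras.Sequential(
        [tf.keras.layers.Dense(width, activation="relu") for _ in range(num_layers)]
    )
    fn = tf.function(model)
    x = tf.random.normal((4, width))
    start = time.perf_counter()
    fn(x)  # the first call triggers tracing / graph construction
    return time.perf_counter() - start

for layers, width in [(2, 32), (5, 32), (2, 512)]:
    print(f"{layers} layers, width {width}: {first_trace_time(layers, width):.3f}s to trace")

Real transformer layers each expand into many more ops (attention, layer norm, MLP), so the gap between 5 layers and 2 is much larger in the actual tests, which is consistent with the BERT timing reported above.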