Speed up TF tests by reducing hidden layer counts #24595

Merged on Jun 30, 2023 (6 commits)

Conversation

Rocketknight1
Member

A lot of our slow TF tests are caused by TF compilation. TF compilation isn't really affected by layer width at all; the main cost is the number of operations it has to build a graph for. Reducing the number of hidden layers makes compilation much faster, (hopefully) without reducing test coverage at all.
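
To make the width-versus-depth point concrete, here is a minimal pure-Python sketch of a toy tracer. It is not TensorFlow's tracer, and `Graph`, `dense_block`, and `trace_model` are hypothetical names invented for this illustration. It only assumes that each hidden layer contributes a fixed handful of graph nodes regardless of tensor shape.

```python
from dataclasses import dataclass, field


@dataclass
class Graph:
    """Toy stand-in for a traced graph: one recorded node per symbolic op."""
    ops: list = field(default_factory=list)

    def add_op(self, name, out_shape):
        # Record the node and return the shape of its output.
        self.ops.append((name, out_shape))
        return out_shape


def dense_block(graph, shape, hidden_size):
    """One hypothetical hidden layer: matmul, bias add, activation."""
    batch = shape[0]
    shape = graph.add_op("matmul", (batch, hidden_size))
    shape = graph.add_op("bias_add", shape)
    return graph.add_op("gelu", shape)


def trace_model(num_hidden_layers, hidden_size, batch_size=13):
    """Trace a toy model and return the graph that was built."""
    graph = Graph()
    shape = graph.add_op("embedding_lookup", (batch_size, hidden_size))
    for _ in range(num_hidden_layers):
        shape = dense_block(graph, shape, hidden_size)
    graph.add_op("classifier", (batch_size, 2))
    return graph


# Width changes the shapes stored on each node, not the number of nodes.
for layers, width in [(5, 32), (5, 1024), (2, 32)]:
    print(f"layers={layers} width={width} ops={len(trace_model(layers, width).ops)}")
```

In this toy, going from 32 to 1024 hidden units leaves the node count unchanged, while dropping from 5 layers to 2 removes more than half of the nodes. Real TF graphs contain many more operations per layer, but the scaling argument is the same.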

@Rocketknight1 Rocketknight1 requested a review from ydshieh June 30, 2023 13:21
Comment on lines -1530 to -1559
# Make sure fit works with tf.data.Dataset and results are consistent
dataset = tf.data.Dataset.from_tensor_slices(prepared_for_class)

if sample_weight is not None:
    # Add in the sample weight
    weighted_dataset = dataset.map(lambda x: (x, None, tf.convert_to_tensor(0.5, dtype=tf.float32)))
else:
    weighted_dataset = dataset
# Pass in all samples as a batch to match other `fit` calls
weighted_dataset = weighted_dataset.batch(len(dataset))
dataset = dataset.batch(len(dataset))
# Reinitialize to fix batchnorm again
model.set_weights(model_weights)

# To match the other calls, don't pass sample weights in the validation data
history3 = model.fit(
    weighted_dataset,
    validation_data=dataset,
    steps_per_epoch=1,
    validation_steps=1,
    shuffle=False,
)
val_loss3 = history3.history["val_loss"][0]
self.assertTrue(not isnan(val_loss3))
accuracy3 = {key: val[0] for key, val in history3.history.items() if key.endswith("accuracy")}
self.check_keras_fit_results(val_loss1, val_loss3)
self.assertEqual(history1.history.keys(), history3.history.keys())
if metrics:
    self.assertTrue(len(accuracy1) == len(accuracy3) > 0, "Missing metrics!")

Member Author

@Rocketknight1 Rocketknight1 Jun 30, 2023

I added this section in `test_keras_fit` a long time ago when I was paranoid about issues from fitting `tf.data.Dataset`, but I believe it should not be possible for this section to fail if the other model tests pass, so I removed it!

@ydshieh
Collaborator

ydshieh commented Jun 30, 2023

It would be nice if you could show the timing for one model (before vs. after) 🙏 . Thanks.

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Jun 30, 2023

The documentation is not available anymore as the PR was closed or merged.

@Rocketknight1
Member Author

@ydshieh Testing locally, BERT went from 510 seconds to 220 seconds.
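
As a quick worked check of the locally measured BERT figures reported above, the short snippet below converts them into a speedup factor and a relative reduction. The variable names are illustrative only, and the two timings are the only inputs taken from the thread.

```python
# Locally measured BERT TF test timings quoted in the comment above.
before_s = 510
after_s = 220

speedup = before_s / after_s                  # how many times faster
reduction = (before_s - after_s) / before_s   # fraction of wall-clock time saved

print(f"{speedup:.2f}x faster, {reduction:.0%} less wall-clock time")
# prints "2.32x faster, 57% less wall-clock time"
```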

@@ -57,7 +57,7 @@ def __init__(
         use_labels=True,
         vocab_size=99,
         hidden_size=32,
-        num_hidden_layers=5,
+        num_hidden_layers=1,
Collaborator

Let's use 2 rather than 1 if possible. 1 is kind of exceptional.

Member Author

Fixed!

Collaborator

@ydshieh ydshieh left a comment

Works for me, except let's not use 1 as the number of layers 🙏 .

Let's also get an approval from @sgugger, as I am not the one who decided to use 5.

@Rocketknight1 Rocketknight1 requested a review from sgugger June 30, 2023 14:23
@Rocketknight1
Member Author

I don't know if it was @sgugger either - a lot of this code is really old! I see `tf.tuple()` in there, and even I had to look up the TF 1.x docs to remember what that was supposed to do, lol

@ydshieh
Collaborator

ydshieh commented Jun 30, 2023

I know he is probably not the one who decided to use 5, but he might know the history :-)

Collaborator

@sgugger sgugger left a comment

Thanks!

@Rocketknight1 Rocketknight1 merged commit 134caef into main Jun 30, 2023
@Rocketknight1 Rocketknight1 deleted the tf_test_speedup branch June 30, 2023 15:30