Currently, all of the metrics more or less follow this scheme:
```python
x: array
y: array
a: array

for x_instance, y_instance, a_instance in zip(x, y, a):
    for perturbation_step in range(perturbation_steps):
        x_perturbed = perturb_instance(x_instance, a_instance, perturbation_step)
        y_perturbed = model(x_perturbed)
        score = calculate_score_for_instance(y_instance, y_perturbed)
```
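To make the scheme concrete, here is a minimal runnable sketch of the per-instance loop. The `model`, `perturb_instance` (masking the top-attributed features), and `calculate_score_for_instance` stand-ins are hypothetical toys, not the actual library code:

```python
import numpy as np

rng = np.random.default_rng(0)

def model(x):
    # Toy model: sum of features (works on a single instance or a batch).
    return np.atleast_2d(x).sum(axis=-1)

def perturb_instance(x_instance, a_instance, perturbation_step):
    # Hypothetical perturbation: zero out the (step + 1) most-attributed features.
    ranking = np.argsort(-a_instance)
    x_perturbed = x_instance.copy()
    x_perturbed[ranking[: perturbation_step + 1]] = 0.0
    return x_perturbed

def calculate_score_for_instance(y_instance, y_perturbed):
    # Toy score: absolute change in the model output.
    return float(np.abs(y_instance - y_perturbed))

x = rng.normal(size=(4, 8))   # 4 instances, 8 features
a = rng.random(size=(4, 8))   # attributions
y = model(x)
perturbation_steps = 3

scores = []
for x_instance, y_instance, a_instance in zip(x, y, a):
    for perturbation_step in range(perturbation_steps):
        x_perturbed = perturb_instance(x_instance, a_instance, perturbation_step)
        y_perturbed = model(x_perturbed)
        scores.append(calculate_score_for_instance(y_instance, y_perturbed))

print(len(scores))  # one score per instance per step → 12
```

Note that every iteration of the inner loop triggers a separate `model` call on a single instance, which is exactly the overhead the batched approach below avoids.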
The choice of `perturb_instance` arguments is just for simplicity; the actual code is of course more complex than presented.
But this kind of implementation doesn't exploit the performance benefits of batched model prediction and vectorized NumPy functions. Instead, we could speed up computation by an order of magnitude with the following approach:
```python
x: array
y: array
a: array
batch_size: int

generator = BatchGenerator(x, y, a, batch_size)
for x_batch, y_batch, a_batch in generator:
    for perturbation_step in range(perturbation_steps):
        x_batch_perturbed = perturb_batch(x_batch, a_batch, perturbation_step)
        y_batch_perturbed = model(x_batch_perturbed)
        score = calculate_score_for_batch(y_batch, y_batch_perturbed)
```
Some `perturb_batch` functions may still need an inner for-loop, but others could certainly be computed on the whole batch at once.
Depending on the dataset size and model complexity, this should lead to significant improvements in performance.
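A runnable sketch of the batched variant might look like the following. `batch_generator` is a minimal stand-in for the suggested `BatchGenerator`, and `perturb_batch` is a hypothetical vectorized perturbation that masks the top-attributed features of every instance in one NumPy operation:

```python
import numpy as np

rng = np.random.default_rng(0)

def model(x_batch):
    # Toy batched model: one forward pass for the whole batch.
    return x_batch.sum(axis=-1)

def batch_generator(x, y, a, batch_size):
    # Minimal stand-in for the proposed BatchGenerator.
    for start in range(0, len(x), batch_size):
        stop = start + batch_size
        yield x[start:stop], y[start:stop], a[start:stop]

def perturb_batch(x_batch, a_batch, perturbation_step):
    # Vectorized over the batch axis: zero out the (step + 1)
    # most-attributed features of every instance, no Python loop.
    ranking = np.argsort(-a_batch, axis=-1)
    x_perturbed = x_batch.copy()
    rows = np.arange(len(x_batch))[:, None]
    x_perturbed[rows, ranking[:, : perturbation_step + 1]] = 0.0
    return x_perturbed

def calculate_score_for_batch(y_batch, y_batch_perturbed):
    # Toy score: absolute output change, one value per instance.
    return np.abs(y_batch - y_batch_perturbed)

x = rng.normal(size=(6, 8))   # 6 instances, 8 features
a = rng.random(size=(6, 8))   # attributions
y = model(x)
perturbation_steps = 3

all_scores = []
for x_batch, y_batch, a_batch in batch_generator(x, y, a, batch_size=4):
    for perturbation_step in range(perturbation_steps):
        x_batch_perturbed = perturb_batch(x_batch, a_batch, perturbation_step)
        y_batch_perturbed = model(x_batch_perturbed)
        all_scores.append(calculate_score_for_batch(y_batch, y_batch_perturbed))

scores = np.concatenate(all_scores)
print(scores.shape)  # (18,) → 6 instances × 3 perturbation steps
```

The model is now called once per batch and perturbation step rather than once per instance, so the number of forward passes drops by a factor of `batch_size`.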