Skip to content

Commit

Permalink
rps lje: move feature extractor outside of map as it was causing prob…
Browse files Browse the repository at this point in the history
…lems in benchmark
  • Loading branch information
Agustin-Picard committed Feb 16, 2024
1 parent 31ec527 commit 727e5e4
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions deel/influenciae/rps/rps_lje.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

from .base_representer_point import BaseRepresenterPoint
from ..common import InfluenceModel, InverseHessianVectorProductFactory
from ..utils import map_to_device
from ..types import Union, Optional


Expand Down Expand Up @@ -65,12 +64,16 @@ def __init__(

# Get a dataset to compute the SGD step
if n_samples_for_hessian is None:
dataset_to_estimate_hessian = map_to_device(dataset, lambda x, y: (self.feature_extractor(x), y))
dataset_to_estimate_hessian = dataset
else:
dataset_to_estimate_hessian = map_to_device(
dataset.shuffle(shuffle_buffer_size).take(n_samples_for_hessian),
lambda x, y: (self.feature_extractor(x), y)
)
n_batches_for_hessian = max(n_samples_for_hessian // dataset._batch_size, 1)
dataset_to_estimate_hessian = dataset.shuffle(shuffle_buffer_size).take(n_batches_for_hessian)
f_array, y_array = None, None
for x, y in dataset_to_estimate_hessian:
f = self.feature_extractor(x)
f_array = f if f_array is None else tf.concat([f_array, f], axis=0)
y_array = y if y_array is None else tf.concat([y_array, y], axis=0)
dataset_to_estimate_hessian = tf.data.Dataset.from_tensor_slices((f_array, y_array)).batch(dataset._batch_size)

# Accumulate the gradients for the whole dataset and then update
trainable_vars = perturbed_head.trainable_variables
Expand Down

0 comments on commit 727e5e4

Please sign in to comment.