From 727e5e4f0f406b499b399a5cbbeb9392d292513e Mon Sep 17 00:00:00 2001 From: "agustin-martin.picard" Date: Fri, 16 Feb 2024 12:16:15 +0100 Subject: [PATCH] rps lje: move feature extractor outside of map as it was causing problems in benchmark --- deel/influenciae/rps/rps_lje.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/deel/influenciae/rps/rps_lje.py b/deel/influenciae/rps/rps_lje.py index ec3f51e..e929cee 100644 --- a/deel/influenciae/rps/rps_lje.py +++ b/deel/influenciae/rps/rps_lje.py @@ -11,7 +11,6 @@ from .base_representer_point import BaseRepresenterPoint from ..common import InfluenceModel, InverseHessianVectorProductFactory -from ..utils import map_to_device from ..types import Union, Optional @@ -65,12 +64,16 @@ def __init__( # Get a dataset to compute the SGD step if n_samples_for_hessian is None: - dataset_to_estimate_hessian = map_to_device(dataset, lambda x, y: (self.feature_extractor(x), y)) + dataset_to_estimate_hessian = dataset else: - dataset_to_estimate_hessian = map_to_device( - dataset.shuffle(shuffle_buffer_size).take(n_samples_for_hessian), - lambda x, y: (self.feature_extractor(x), y) - ) + n_batches_for_hessian = max(n_samples_for_hessian // dataset._batch_size, 1) + dataset_to_estimate_hessian = dataset.shuffle(shuffle_buffer_size).take(n_batches_for_hessian) + f_array, y_array = None, None + for x, y in dataset_to_estimate_hessian: + f = self.feature_extractor(x) + f_array = f if f_array is None else tf.concat([f_array, f], axis=0) + y_array = y if y_array is None else tf.concat([y_array, y], axis=0) + dataset_to_estimate_hessian = tf.data.Dataset.from_tensor_slices((f_array, y_array)).batch(dataset._batch_size) # Accumulate the gradients for the whole dataset and then update trainable_vars = perturbed_head.trainable_variables