From 47419a02851877cefa7e713616886a20df7b10d1 Mon Sep 17 00:00:00 2001 From: "agustin-martin.picard" Date: Fri, 2 Feb 2024 16:50:01 +0100 Subject: [PATCH] lint: clean-up --- deel/influenciae/rps/rps_lje.py | 18 ++++++++++++------ deel/influenciae/utils/tf_operations.py | 1 - 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/deel/influenciae/rps/rps_lje.py b/deel/influenciae/rps/rps_lje.py index 9844975..0db673b 100644 --- a/deel/influenciae/rps/rps_lje.py +++ b/deel/influenciae/rps/rps_lje.py @@ -10,7 +10,6 @@ from typing import Tuple import tensorflow as tf -from tensorflow.keras.models import Sequential # pylint: disable=E0611 from ..common import InfluenceModel, InverseHessianVectorProductFactory, BaseInfluenceCalculator from ..utils import map_to_device, split_model, assert_batched_dataset @@ -99,7 +98,11 @@ def __init__( self.perturbed_head = perturbed_head # Create the new model with the perturbed weights to compute the hessian matrix - model = InfluenceModel(self.perturbed_head, 1, loss_function=influence_model.loss_function) # layer 0 is InputLayer + model = InfluenceModel( + self.perturbed_head, + 1, # layer 0 is InputLayer + loss_function=influence_model.loss_function + ) self.ihvp_calculator = ihvp_calculator_factory.build(model, dataset_to_estimate_hessian) def _reshape_assign(self, weights, influence_vector: tf.Tensor) -> None: @@ -175,13 +178,16 @@ def _compute_alpha(self, z_batch: tf.Tensor, y_batch: tf.Tensor) -> tf.Tensor: ) ) second_term = tf.map_fn( - lambda v: self.ihvp_calculator._compute_ihvp_single_batch(tf.expand_dims(v, axis=0), use_gradient=False), + lambda v: self.ihvp_calculator._compute_ihvp_single_batch( # pylint: disable=protected-access + tf.expand_dims(v, axis=0), + use_gradient=False + ), grads - ) # pylint: disable=protected-access + ) second_term = tf.reduce_sum(tf.reshape(second_term, tf.shape(grads)), axis=1) # Second, we compute the first term, which contains the weights - first_term = tf.concat([w for w in weights], axis=0) + first_term = tf.concat(list(weights), axis=0) first_term = tf.multiply( first_term, tf.repeat( @@ -302,7 +308,7 @@ def _estimate_influence_value_from_influence_vector( A tensor with influence values for the (batch of) test samples. """ # Extract the different information inside the tuples - feature_maps_test, labels_test = preproc_test_sample + feature_maps_test, _ = preproc_test_sample alpha, feature_maps_train = influence_vector if len(alpha.shape) == 1 or (len(alpha.shape) == 2 and alpha.shape[1] == 1): diff --git a/deel/influenciae/utils/tf_operations.py b/deel/influenciae/utils/tf_operations.py index 0d5d896..5f98392 100644 --- a/deel/influenciae/utils/tf_operations.py +++ b/deel/influenciae/utils/tf_operations.py @@ -8,7 +8,6 @@ import numpy as np import tensorflow as tf -from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 from ..types import Union, Tuple, Optional, Callable