Skip to content

Commit

Permalink
lint: clean-up
Browse files Browse the repository at this point in the history
  • Loading branch information
Agustin-Picard committed Feb 2, 2024
1 parent 6289e71 commit 47419a0
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
18 changes: 12 additions & 6 deletions deel/influenciae/rps/rps_lje.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from typing import Tuple

import tensorflow as tf
from tensorflow.keras.models import Sequential # pylint: disable=E0611

from ..common import InfluenceModel, InverseHessianVectorProductFactory, BaseInfluenceCalculator
from ..utils import map_to_device, split_model, assert_batched_dataset
Expand Down Expand Up @@ -99,7 +98,11 @@ def __init__(
self.perturbed_head = perturbed_head

# Create the new model with the perturbed weights to compute the hessian matrix
model = InfluenceModel(self.perturbed_head, 1, loss_function=influence_model.loss_function) # layer 0 is InputLayer
model = InfluenceModel(
self.perturbed_head,
1, # layer 0 is InputLayer
loss_function=influence_model.loss_function
)
self.ihvp_calculator = ihvp_calculator_factory.build(model, dataset_to_estimate_hessian)

def _reshape_assign(self, weights, influence_vector: tf.Tensor) -> None:
Expand Down Expand Up @@ -175,13 +178,16 @@ def _compute_alpha(self, z_batch: tf.Tensor, y_batch: tf.Tensor) -> tf.Tensor:
)
)
second_term = tf.map_fn(
lambda v: self.ihvp_calculator._compute_ihvp_single_batch(tf.expand_dims(v, axis=0), use_gradient=False),
lambda v: self.ihvp_calculator._compute_ihvp_single_batch( # pylint: disable=protected-access
tf.expand_dims(v, axis=0),
use_gradient=False
),
grads
) # pylint: disable=protected-access
)
second_term = tf.reduce_sum(tf.reshape(second_term, tf.shape(grads)), axis=1)

# Second, we compute the first term, which contains the weights
first_term = tf.concat([w for w in weights], axis=0)
first_term = tf.concat(list(weights), axis=0)
first_term = tf.multiply(
first_term,
tf.repeat(
Expand Down Expand Up @@ -302,7 +308,7 @@ def _estimate_influence_value_from_influence_vector(
A tensor with influence values for the (batch of) test samples.
"""
# Extract the different information inside the tuples
feature_maps_test, labels_test = preproc_test_sample
feature_maps_test, _ = preproc_test_sample
alpha, feature_maps_train = influence_vector

if len(alpha.shape) == 1 or (len(alpha.shape) == 2 and alpha.shape[1] == 1):
Expand Down
1 change: 0 additions & 1 deletion deel/influenciae/utils/tf_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import numpy as np
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

from ..types import Union, Tuple, Optional, Callable

Expand Down

0 comments on commit 47419a0

Please sign in to comment.