Skip to content

Commit

Permalink
fixed linter issue
Browse files Browse the repository at this point in the history
  • Loading branch information
thib-s committed May 25, 2021
1 parent 518f1e3 commit c09ed37
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 16 deletions.
7 changes: 4 additions & 3 deletions deel/lip/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,9 +588,10 @@ def vanilla_export(self):
@_deel_export
class FrobeniusDense(Dense, LipschitzLayer, Condensable):
"""
Same a SpectralDense, but in the case of a single output. In the multiclass setting,
the behaviour of this layer is similar to the stacking of 1 lipschitz layer (each output
is 1-lipschitz, but the no orthogonality is enforced between outputs ).
Same as SpectralDense, but in the case of a single output. In the multiclass
setting, the behaviour of this layer is similar to the stacking of 1-Lipschitz
layers (each output is 1-Lipschitz, but no orthogonality is enforced between
outputs).
"""

def __init__(
Expand Down
27 changes: 16 additions & 11 deletions deel/lip/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,9 @@ def hinge_margin_fct(y_true, y_pred):
@_deel_export
def KR_multiclass_loss():
r"""
Loss to estimate average of W1 distance using Kantorovich-Rubinstein duality over outputs.
Note y_true should be one hot encoding (labels being 1s and 0s ).
In this multiclass setup thr KR term is computed for each class and then averaged.
Loss to estimate the average W1 distance using the Kantorovich-Rubinstein
duality over outputs. Note that y_true should be one-hot encoded (labels being
1s and 0s). In this multiclass setup the KR term is computed for each class
and then averaged.
Returns:
Callable, the function to compute Wasserstein multiclass loss.
Expand All @@ -147,17 +147,20 @@ def KR_multiclass_loss():
@tf.function
def KR_multiclass_loss_fct(y_true, y_pred):
# use y_true to zero out y_pred where y_true != 1
# espYtrue is the avg value of y_pred when y_true==1 (one average per output neuron)
# espYtrue is the avg value of y_pred when y_true==1
# (one average per output neuron)
espYtrue = tf.reduce_sum(y_pred * y_true, axis=0) / tf.reduce_sum(
y_true, axis=0
)
# use (1 - y_true) to zero out y_pred where y_true == 1
# espNotYtrue is the avg value of y_pred when y_true==0 (one average per output neuron)
# espNotYtrue is the avg value of y_pred when y_true==0
# (one average per output neuron)
espNotYtrue = tf.reduce_sum(y_pred * (1 - y_true), axis=0) / (
tf.cast(tf.shape(y_true)[0], dtype="float32")
- tf.reduce_sum(y_true, axis=0)
)
# compute the differences to have the KR term for each output neuron, and compute the average over the classes
# compute the differences to have the KR term for each output neuron,
# then compute the average over the classes
return tf.reduce_mean(-espNotYtrue + espYtrue)

return KR_multiclass_loss_fct
Expand All @@ -166,10 +169,11 @@ def KR_multiclass_loss_fct(y_true, y_pred):
@_deel_export
def Hinge_multiclass_loss(min_margin=1):
"""
Loss to estimate the Hinge loss in a multiclass setup. It compute the elementwise hinge term. Note that this
formulation differs from the one commonly found in tensorflow/pytorch (with marximise the difference between the two
largest logits). This formulation is consistent with the binary classification loss used in a multiclass fashion.
Note y_true should be one hot encoded. labels in (1,0)
Loss to estimate the Hinge loss in a multiclass setup. It computes the
elementwise hinge term. Note that this formulation differs from the one
commonly found in tensorflow/pytorch (which maximises the difference between
the two largest logits). This formulation is consistent with the binary
classification loss used in a multiclass fashion. Note that y_true should be
one-hot encoded, with labels in (1, 0).
Returns:
Callable, the function to compute multiclass Hinge loss
Expand All @@ -193,7 +197,8 @@ def Hinge_multiclass_loss_fct(y_true, y_pred):
@_deel_export
def HKR_multiclass_loss(alpha=0.0, min_margin=1):
"""
The multiclass version of HKR. This is done by computing the HKR term over each class and averaging the results.
The multiclass version of HKR. This is done by computing the HKR term over each
class and averaging the results.
Args:
alpha: regularization factor
Expand Down
8 changes: 6 additions & 2 deletions doc/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,11 @@
html_static_path = ["_static"]

html_context = {
"css_files": ["_static/theme_overrides.css",], # override wide tables in RTD theme
"css_files": [
"_static/theme_overrides.css",
], # override wide tables in RTD theme
}

autodoc_member_order = ["bysource", ]
autodoc_member_order = [
"bysource",
]

0 comments on commit c09ed37

Please sign in to comment.