From cf185e69f4608ab346aa98f56b2c803266589e91 Mon Sep 17 00:00:00 2001
From: AnFreTh
Date: Fri, 3 Jan 2025 14:16:50 +0000
Subject: [PATCH] fix input names in sparsemax

---
 mambular/arch_utils/layer_utils/sparsemax.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mambular/arch_utils/layer_utils/sparsemax.py b/mambular/arch_utils/layer_utils/sparsemax.py
index b874a6a..6c8e1ef 100644
--- a/mambular/arch_utils/layer_utils/sparsemax.py
+++ b/mambular/arch_utils/layer_utils/sparsemax.py
@@ -36,13 +36,13 @@ class SparsemaxFunction(Function):
     """
 
     @staticmethod
-    def forward(ctx, x, dim=-1):
+    def forward(ctx, input, dim=-1):
         """
         Forward pass of sparsemax: a normalizing, sparse transformation.
 
         Parameters
         ----------
-        x : torch.Tensor
+        input : torch.Tensor
             The input tensor on which sparsemax will be applied.
         dim : int, optional
             Dimension along which to apply sparsemax. Default is -1.
@@ -53,8 +53,8 @@ def forward(ctx, x, dim=-1):
             A tensor with the same shape as the input, with sparsemax applied.
         """
         ctx.dim = dim
-        max_val, _ = x.max(dim=dim, keepdim=True)
-        x -= max_val  # Numerical stability trick, as with softmax.
+        max_val, _ = input.max(dim=dim, keepdim=True)
+        input -= max_val  # Numerical stability trick, as with softmax.
         tau, supp_size = SparsemaxFunction._threshold_and_support(input, dim=dim)
         output = torch.clamp(input - tau, min=0)
         ctx.save_for_backward(supp_size, output)