Skip to content

Commit

Permalink
Vectorize TensorPrimitives.Sigmoid and TensorPrimitives.SoftMax (dotn…
Browse files Browse the repository at this point in the history
…et#93029)

* Vectorize TensorPrimitives.Sigmoid and TensorPrimitives.SoftMax

- Adds a SigmoidOperator that just wraps the ExpOperator
- Vectorizes both passes of SoftMax, on top of ExpOperator. Simplest way to do this was to augment the existing InvokeSpanScalarIntoSpan to take a transform operator.
- In doing so, found some naming inconsistencies I'd previously introduced, so I did some automatic renaming to make things more consistent.
- Added XML comments to all the internal/private surface area.
- Fleshes out some tests (and test values).

* Disable tests on mono

* Address PR feedback
  • Loading branch information
stephentoub authored and michaelgsharp committed Oct 20, 2023
1 parent f48d8b0 commit 6c63ae7
Show file tree
Hide file tree
Showing 4 changed files with 537 additions and 234 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -988,17 +988,7 @@ public static void Sigmoid(ReadOnlySpan<float> x, Span<float> destination)
ThrowHelper.ThrowArgument_SpansMustBeNonEmpty();
}

if (x.Length > destination.Length)
{
ThrowHelper.ThrowArgument_DestinationTooShort();
}

ValidateInputOutputSpanNonOverlapping(x, destination);

for (int i = 0; i < x.Length; i++)
{
destination[i] = 1f / (1f + MathF.Exp(-x[i]));
}
InvokeSpanIntoSpan<SigmoidOperator>(x, destination);
}

/// <summary>Computes the element-wise hyperbolic sine of each single-precision floating-point radian angle in the specified tensor.</summary>
Expand Down Expand Up @@ -1067,17 +1057,9 @@ public static void SoftMax(ReadOnlySpan<float> x, Span<float> destination)

ValidateInputOutputSpanNonOverlapping(x, destination);

float expSum = 0f;

for (int i = 0; i < x.Length; i++)
{
expSum += MathF.Exp(x[i]);
}
float expSum = Aggregate<ExpOperator, AddOperator>(x);

for (int i = 0; i < x.Length; i++)
{
destination[i] = MathF.Exp(x[i]) / expSum;
}
InvokeSpanScalarIntoSpan<ExpOperator, DivideOperator>(x, expSum, destination);
}

/// <summary>Computes the element-wise difference between single-precision floating-point numbers in the specified tensors.</summary>
Expand Down
Loading

0 comments on commit 6c63ae7

Please sign in to comment.