Skip to content

Commit

Permalink
Fix TensorExtensions.StdDev
Browse files Browse the repository at this point in the history
  • Loading branch information
lilinus committed Dec 4, 2024
1 parent bc23f63 commit c212984
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3512,7 +3512,7 @@ public static ref readonly TensorSpan<T> StackAlongDimension<T>(scoped ReadOnlyS
/// <param name="x">The <see cref="TensorSpan{T}"/> to take the standard deviation of.</param>
/// <returns><typeparamref name="T"/> representing the standard deviation.</returns>
public static T StdDev<T>(in ReadOnlyTensorSpan<T> x)
where T : IFloatingPoint<T>, IPowerFunctions<T>, IAdditionOperators<T, T, T>, IAdditiveIdentity<T, T>
where T : IFloatingPoint<T>, IPowerFunctions<T>, IAdditionOperators<T, T, T>, IAdditiveIdentity<T, T>, IRootFunctions<T>
{
T mean = Average(x);
Span<T> span = MemoryMarshal.CreateSpan(ref x._reference, (int)x._shape._memoryLength);
Expand All @@ -3521,7 +3521,7 @@ public static T StdDev<T>(in ReadOnlyTensorSpan<T> x)
TensorPrimitives.Abs(output, output);
TensorPrimitives.Pow((ReadOnlySpan<T>)output, T.CreateChecked(2), output);
T sum = TensorPrimitives.Sum((ReadOnlySpan<T>)output);
return T.CreateChecked(sum / T.CreateChecked(x._shape._memoryLength));
return T.Sqrt(sum / T.CreateChecked(x._shape._memoryLength));
}
#endregion

Expand Down
2 changes: 1 addition & 1 deletion src/libraries/System.Numerics.Tensors/tests/TensorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1113,7 +1113,7 @@ public static float StdDev(float[] values)
{
sum += MathF.Pow(values[i] - mean, 2);
}
return sum / values.Length;
return MathF.Sqrt(sum / values.Length);
}

[Fact]
Expand Down

0 comments on commit c212984

Please sign in to comment.