Skip to content

Commit

Permalink
Improve variance and standard deviation (#621)
Browse files Browse the repository at this point in the history
  • Loading branch information
tiagodavi authored Feb 7, 2022
1 parent cab4408 commit 8d74000
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 11 deletions.
90 changes: 83 additions & 7 deletions nx/lib/nx.ex
Original file line number Diff line number Diff line change
Expand Up @@ -5417,7 +5417,7 @@ defmodule Nx do
@doc """
Returns the mean for the tensor.
If the `:axis` option is given, it aggregates over
If the `:axes` option is given, it aggregates over
that dimension, effectively removing it. `axes: [0]`
implies aggregating over the highest order dimension
and so forth. If the axis is negative, then counts
Expand Down Expand Up @@ -9278,21 +9278,63 @@ defmodule Nx do
f32
1.6666666269302368
>
iex> Nx.variance(Nx.tensor([[1, 2], [3, 4]]), axes: [0])
#Nx.Tensor<
f32[2]
[1.0, 1.0]
>
iex> Nx.variance(Nx.tensor([[1, 2], [3, 4]]), axes: [1])
#Nx.Tensor<
f32[2]
[0.25, 0.25]
>
iex> Nx.variance(Nx.tensor([[1, 2], [3, 4]]), axes: [0], ddof: 1)
#Nx.Tensor<
f32[2]
[2.0, 2.0]
>
iex> Nx.variance(Nx.tensor([[1, 2], [3, 4]]), axes: [1], ddof: 1)
#Nx.Tensor<
f32[2]
[0.5, 0.5]
>
### Keeping axes
iex> Nx.variance(Nx.tensor([[1, 2], [3, 4]]), axes: [1], keep_axes: true)
#Nx.Tensor<
f32[2][1]
[
[0.25],
[0.25]
]
>
"""
@doc type: :aggregation
@spec variance(tensor :: Nx.Tensor.t(), opts :: Keyword.t()) :: Nx.Tensor.t()
def variance(tensor, opts \\ []) do
%T{shape: shape} = tensor = to_tensor(tensor)
%T{shape: shape, names: names} = tensor = to_tensor(tensor)
opts = keyword!(opts, [:axes, ddof: 0, keep_axes: false])
axes = opts[:axes]
{ddof, opts} = Keyword.pop!(opts, :ddof)

total =
if axes do
mean_den(shape, Nx.Shape.normalize_axes(shape, axes, names))
else
size(shape)
end

opts = keyword!(opts, ddof: 0)
total = size(shape)
ddof = Keyword.fetch!(opts, :ddof)
mean = mean(tensor)
mean = mean(tensor, Keyword.put(opts, :keep_axes, true))

tensor
|> subtract(mean)
|> power(2)
|> sum()
|> sum(opts)
|> divide(total - ddof)
end

Expand All @@ -9316,6 +9358,40 @@ defmodule Nx do
f32
1.29099440574646
>
iex> Nx.standard_deviation(Nx.tensor([[1, 2], [3, 4]]), axes: [0])
#Nx.Tensor<
f32[2]
[1.0, 1.0]
>
iex> Nx.standard_deviation(Nx.tensor([[1, 2], [3, 4]]), axes: [1])
#Nx.Tensor<
f32[2]
[0.5, 0.5]
>
iex> Nx.standard_deviation(Nx.tensor([[1, 2], [3, 4]]), axes: [0], ddof: 1)
#Nx.Tensor<
f32[2]
[1.4142135381698608, 1.4142135381698608]
>
iex> Nx.standard_deviation(Nx.tensor([[1, 2], [3, 4]]), axes: [1], ddof: 1)
#Nx.Tensor<
f32[2]
[0.7071067690849304, 0.7071067690849304]
>
### Keeping axes
iex> Nx.standard_deviation(Nx.tensor([[1, 2], [3, 4]]), keep_axes: true)
#Nx.Tensor<
f32[1][1]
[
[1.1180340051651]
]
>
"""
@doc type: :aggregation
@spec standard_deviation(tensor :: Nx.Tensor.t(), opts :: Keyword.t()) :: Nx.Tensor.t()
Expand Down
35 changes: 31 additions & 4 deletions nx/test/nx_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -1922,26 +1922,53 @@ defmodule NxTest do
end

describe "variance/1" do
test "should calculate the variance of a tensor" do
test "calculates variance of a tensor" do
t = Nx.tensor([[4, 5], [2, 3], [1, 0]])
assert Nx.variance(t) == Nx.tensor(2.9166667461395264)
end

test "should use the optional ddof" do
test "uses optional ddof" do
t = Nx.tensor([[4, 5], [2, 3], [1, 0]])
assert Nx.variance(t, ddof: 1) == Nx.tensor(3.5)
end

test "uses optional axes" do
t = Nx.tensor([[4, 5], [2, 3], [1, 0]], names: [:x, :y])

assert Nx.variance(t, axes: [:x]) ==
Nx.tensor([1.5555557012557983, 4.222222328186035], names: [:y])

t = Nx.tensor([[4, 5], [2, 3], [1, 0]], names: [:x, :y])
assert Nx.variance(t, axes: [:y]) == Nx.tensor([0.25, 0.25, 0.25], names: [:x])
end

test "uses optional keep axes" do
t = Nx.tensor([[4, 5], [2, 3], [1, 0]])
assert Nx.variance(t, keep_axes: true) == Nx.tensor([[2.9166667461395264]])
end
end

describe "standard_deviation/1" do
test "should calculate the standard deviation of a tensor" do
test "calculates the standard deviation of a tensor" do
t = Nx.tensor([[4, 5], [2, 3], [1, 0]])
assert Nx.standard_deviation(t) == Nx.tensor(1.707825127659933)
end

test "should use the optional ddof" do
test "uses optional ddof" do
t = Nx.tensor([[4, 5], [2, 3], [1, 0]])
assert Nx.standard_deviation(t, ddof: 1) == Nx.tensor(1.8708287477493286)
end

test "uses optional axes" do
t = Nx.tensor([[4, 5], [2, 3], [1, 0]])

assert Nx.standard_deviation(t, axes: [0]) ==
Nx.tensor([1.247219204902649, 2.054804801940918])
end

test "uses optional keep axes" do
t = Nx.tensor([[4, 5], [2, 3], [1, 0]])
assert Nx.standard_deviation(t, keep_axes: true) == Nx.tensor([[1.7078251838684082]])
end
end
end

0 comments on commit 8d74000

Please sign in to comment.