Skip to content

Commit

Permalink
Support :node on Series.from_*
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim committed Jun 29, 2024
1 parent f9ff532 commit 85c3383
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 8 deletions.
18 changes: 10 additions & 8 deletions lib/explorer/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ defmodule Explorer.Series do
## Options
* `:backend` - The backend to allocate the series on.
* `:node` - The Erlang node to allocate the series on.
* `:dtype` - Create a series of a given `:dtype`. By default this is `nil`, which means
that Explorer will infer the type from the values in the list.
See the module docs for the list of valid dtypes and aliases.
Expand Down Expand Up @@ -469,14 +470,13 @@ defmodule Explorer.Series do
@doc type: :conversion
@spec from_list(list :: list(), opts :: Keyword.t()) :: Series.t()
def from_list(list, opts \\ []) do
opts = Keyword.validate!(opts, [:dtype, :backend])
opts = Keyword.validate!(opts, [:dtype, :backend, :node])
backend = backend_from_options!(opts)

normalised_dtype = if opts[:dtype], do: Shared.normalise_dtype!(opts[:dtype])

type = Shared.dtype_from_list!(list, normalised_dtype)

backend.from_list(list, type)
dtype = Shared.dtype_from_list!(list, normalised_dtype)
Shared.apply_init(backend, :from_list, [list, dtype], opts)
end

defp from_same_value(%{data: %backend{}}, value) do
Expand All @@ -495,6 +495,7 @@ defmodule Explorer.Series do
## Options
* `:backend` - The backend to allocate the series on.
* `:node` - The Erlang node to allocate the series on.
## Examples
Expand Down Expand Up @@ -565,7 +566,7 @@ defmodule Explorer.Series do
) ::
Series.t()
def from_binary(binary, dtype, opts \\ []) when K.and(is_binary(binary), is_list(opts)) do
opts = Keyword.validate!(opts, [:backend])
opts = Keyword.validate!(opts, [:backend, :node])
dtype = Shared.normalise_dtype!(dtype)

{_type, alignment} = dtype |> Shared.dtype_to_iotype!()
Expand All @@ -576,7 +577,7 @@ defmodule Explorer.Series do
end

backend = backend_from_options!(opts)
backend.from_binary(binary, dtype)
Shared.apply_init(backend, :from_binary, [binary, dtype], opts)
end

@doc """
Expand All @@ -589,6 +590,7 @@ defmodule Explorer.Series do
## Options
* `:backend` - The backend to allocate the series on.
* `:node` - The Erlang node to allocate the series on.
* `:dtype` - The dtype of the series that must match the underlying tensor type.
The series can have a different dtype if the tensor is compatible with it.
Expand Down Expand Up @@ -664,7 +666,7 @@ defmodule Explorer.Series do
@doc type: :conversion
@spec from_tensor(tensor :: Nx.Tensor.t(), opts :: Keyword.t()) :: Series.t()
def from_tensor(tensor, opts \\ []) when is_struct(tensor, Nx.Tensor) do
opts = Keyword.validate!(opts, [:dtype, :backend])
opts = Keyword.validate!(opts, [:dtype, :backend, :node])
type = Nx.type(tensor)
{dtype, opts} = Keyword.pop_lazy(opts, :dtype, fn -> Shared.iotype_to_dtype!(type) end)

Expand All @@ -677,7 +679,7 @@ defmodule Explorer.Series do
end

backend = backend_from_options!(opts)
tensor |> Nx.to_binary() |> backend.from_binary(dtype)
Shared.apply_init(backend, :from_binary, [Nx.to_binary(tensor), dtype], opts)
end

@doc """
Expand Down
11 changes: 11 additions & 0 deletions lib/explorer/shared.ex
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,17 @@ defmodule Explorer.Shared do

## Apply

@doc """
Initializes a series or a dataframe with node placement.
"""
def apply_init(impl, fun, args, opts) do
if node = opts[:node] do
Explorer.Remote.apply(node, impl, fun, [], fn _ -> args end)
else
apply(impl, fun, args)
end
end

@doc """
Applies a function to a series.
"""
Expand Down
10 changes: 10 additions & 0 deletions test/explorer/remote_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,14 @@ defmodule Explorer.RemoteTest do
assert_receive {:DOWN, ^ref, _, _, _}
end
end

describe "init placement" do
test "series" do
series = S.from_list([1, 2, 3], node: @node2)
assert series.remote

series = S.from_binary(<<1, 2, 3>>, {:s, 8}, node: @node2)
assert series.remote
end
end
end

0 comments on commit 85c3383

Please sign in to comment.