diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index a98d5e4d92..8206f4106b 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -943,7 +943,7 @@ defmodule Nx do for t <- [:u2, :u4, :u8, :u16, :u32, :u64, :s2, :s4, :s8, :s16, :s32, :s64] ++ - [:f8, :bf16, :f16, :f32, :f64] do + [:f8, :bf16, :f16, :f32, :f64, :c64, :c128] do @doc """ Short-hand function for creating tensor of type `#{t}`. diff --git a/nx/lib/nx/defn/expr.ex b/nx/lib/nx/defn/expr.ex index 638891eaf1..6386769e22 100644 --- a/nx/lib/nx/defn/expr.ex +++ b/nx/lib/nx/defn/expr.ex @@ -249,9 +249,18 @@ defmodule Nx.Defn.Expr do result = for expr <- [last | exprs] do - expr - |> Nx.as_type(type) - |> Nx.broadcast(shape, names: names) + typed_expr = + case expr do + %T{data: %Expr{op: :constant}} -> + expr + |> maybe_upcast_float_constant(type) + |> Nx.as_type(type) + + expr -> + Nx.as_type(expr, type) + end + + Nx.broadcast(typed_expr, shape, names: names) end {result, vectorized_axes} @@ -1401,6 +1410,10 @@ defmodule Nx.Defn.Expr do defp constant(%{shape: shape, type: type} = out, number) do number = cond do + Nx.Type.complex?(type) and + (is_number(number) or number in [:infinity, :neg_infinity, :nan]) -> + Complex.new(number, 0.0) + is_integer(number) and Nx.Type.float?(type) -> Complex.multiply(1.0, number) diff --git a/nx/test/nx/defn_test.exs b/nx/test/nx/defn_test.exs index d532ea4043..42bc0f8dcb 100644 --- a/nx/test/nx/defn_test.exs +++ b/nx/test/nx/defn_test.exs @@ -1176,6 +1176,48 @@ defmodule Nx.DefnTest do ) end + defn cond_upcast_float_literals(n) do + cond do + n == 1 -> 1.4 + n == 2 -> 2 + true -> n + end + end + + test "upcasts float literals based on the accumulated clause type" do + for input_type <- [f: 32, f: 64] do + assert %T{ + type: ^input_type, + data: %Expr{op: :cond, args: [[clause1, clause2], _last]} + } = + cond_upcast_float_literals(Nx.tensor(10.0, type: input_type)) + + assert {_, %T{type: ^input_type, data: %Expr{op: :constant, args: [1.4]}}} = clause1 + assert {_, %T{type: ^input_type, data: %Expr{op: :constant, args: [2.0]}}} = clause2 + end + + for input_type <- [c: 64, c: 128] do + assert %T{ + type: ^input_type, + data: %Expr{op: :cond, args: [[clause1, clause2], _last]} + } = + cond_upcast_float_literals(Nx.tensor(10.0, type: input_type)) + + assert {_, + %T{ + type: ^input_type, + data: %Expr{op: :constant, args: [%Complex{re: 1.4, im: +0.0}]} + }} = clause1 + + assert {_, + %T{ + type: ^input_type, + data: %Expr{op: :constant, args: [%Complex{re: 2.0, im: +0.0}]} + }} = + clause2 + end + end + defn cond_list(a) do if Nx.any(a), do: 1, else: -1 end