Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix test util
Browse files Browse the repository at this point in the history
seanmor5 committed Nov 22, 2023
1 parent 502e8ca commit 5fcca84
Showing 2 changed files with 1 addition and 9 deletions.
4 changes: 0 additions & 4 deletions test/axon/integration_test.exs
Original file line number Diff line number Diff line change
@@ -389,21 +389,17 @@ defmodule Axon.IntegrationTest do
end)

input = Axon.input("input")
carry = {Axon.constant(Nx.broadcast(0.0, {2, 5})), Axon.constant(Nx.broadcast(0.0, {2, 5}))}
# mask = Axon.mask(input, 0)

dynamic_model =
input
|> Axon.embedding(2, 8)
# |> Axon.lstm(5, seed: 40)
|> Axon.lstm(5, recurrent_initializer: :zeros)
|> elem(0)
|> Axon.nx(fn seq -> Nx.squeeze(seq[[0..-1//1, -1, 0..-1//1]]) end)

static_model =
input
|> Axon.embedding(2, 8)
# |> Axon.lstm(5, seed: 40, unroll: :static)
|> Axon.lstm(5, unroll: :static, recurrent_initializer: :zeros)
|> elem(0)
|> Axon.nx(fn seq -> Nx.squeeze(seq[[0..-1//1, -1, 0..-1//1]]) end)
6 changes: 1 addition & 5 deletions test/support/axon_test_util.ex
Original file line number Diff line number Diff line change
@@ -27,7 +27,7 @@ defmodule AxonTestUtil do
|> Enum.zip_with(Tuple.to_list(rhs), &assert_all_close(&1, &2, opts))
end

def assert_all_close(%Nx.Tensor{} = lhs, %Nx.Tensor{} = rhs, opts) do
def assert_all_close(lhs, rhs, opts) do
res = Nx.all_close(lhs, rhs, opts) |> Nx.backend_transfer(Nx.BinaryBackend)

unless Nx.to_number(res) == 1 do
@@ -43,10 +43,6 @@ defmodule AxonTestUtil do
end
end

def assert_all_close(lhs, rhs, opts) when is_map(lhs) and is_map(rhs) do
Axon.Shared.deep_merge(lhs, rhs, &assert_all_close(&1, &2, opts))
end

def assert_equal(lhs, rhs) when is_tuple(lhs) and is_tuple(rhs) do
lhs
|> Tuple.to_list()

0 comments on commit 5fcca84

Please sign in to comment.