Skip to content

Commit

Permalink
assert all close
Browse files Browse the repository at this point in the history
  • Loading branch information
novikov-alexander committed Nov 10, 2023
1 parent fc8f493 commit 165e916
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 30 deletions.
22 changes: 1 addition & 21 deletions test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -625,25 +625,6 @@ public void testPartialDerivatives()
}
}

// TODO: remove when np.testing.assert_allclose(a, b) is implemented
private class CollectionComparer : System.Collections.IComparer
{
private readonly double _epsilon = 1e-07;

public int Compare(object x, object y)
{
var a = (double)x;
var b = (double)y;

double delta = Math.Abs(a - b);
if (delta < _epsilon)
{
return 0;
}
return a.CompareTo(b);
}
}

private struct Case
{
public Tensor[] grad1;
Expand Down Expand Up @@ -748,8 +729,7 @@ Tensor[] gradients(Tensor[] ys, Tensor[] xs, Tensor[] stop_gradients = null)
var npgrad2 = result[1];
foreach (var (a, b) in npgrad1.Zip(npgrad2))
{
// TODO: np.testing.assert_allclose(a, b);
CollectionAssert.AreEqual(a.ToArray(), b.ToArray(), new CollectionComparer());
self.assertAllClose(a, b);
}
}
}
Expand Down
18 changes: 9 additions & 9 deletions test/Tensorflow.UnitTest/PythonTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,9 @@ public void assertProtoEquals(object toProto, object o)

#region tensor evaluation and test session

private Session _cached_session = null;
private Graph _cached_graph = null;
private object _cached_config = null;
private Session? _cached_session = null;
private Graph? _cached_graph = null;
private object? _cached_config = null;
private bool _cached_force_gpu = false;

private void _ClearCachedSession()
Expand Down Expand Up @@ -237,7 +237,7 @@ protected object _eval_tensor(object tensor)
/// </summary>
public T evaluate<T>(Tensor tensor)
{
object result = null;
object? result = null;
// if context.executing_eagerly():
// return self._eval_helper(tensors)
// else:
Expand Down Expand Up @@ -274,7 +274,7 @@ public T evaluate<T>(Tensor tensor)

///Returns a TensorFlow Session for use in executing tests.
public Session cached_session(
Graph graph = null, object config = null, bool use_gpu = false, bool force_gpu = false)
Graph? graph = null, object? config = null, bool use_gpu = false, bool force_gpu = false)
{
// This method behaves differently than self.session(): for performance reasons
// `cached_session` will by default reuse the same session within the same
Expand Down Expand Up @@ -325,7 +325,7 @@ public Session cached_session(
}

//Returns a TensorFlow Session for use in executing tests.
public Session session(Graph graph = null, object config = null, bool use_gpu = false, bool force_gpu = false)
public Session session(Graph? graph = null, object? config = null, bool use_gpu = false, bool force_gpu = false)
{
//Note that this will set this session and the graph as global defaults.

Expand Down Expand Up @@ -359,7 +359,7 @@ public Session session(Graph graph = null, object config = null, bool use_gpu =
// A Session object that should be used as a context manager to surround
// the graph building and execution code in a test case.

Session s = null;
Session? s = null;
//if (context.executing_eagerly())
// yield None
//else
Expand Down Expand Up @@ -448,8 +448,8 @@ private Session _create_session(Graph graph, object cfg, bool forceGpu)
}

private Session _get_cached_session(
Graph graph = null,
object config = null,
Graph? graph = null,
object? config = null,
bool force_gpu = false,
bool crash_if_inconsistent_args = true)
{
Expand Down

0 comments on commit 165e916

Please sign in to comment.