Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix the bug of boolean_mask #1205

Merged
merged 1 commit into from
Nov 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
Original file line number Diff line number Diff line change
Expand Up @@ -428,9 +428,9 @@ public static Tensor _transpose_batch_time(Tensor x)
return x;

var x_rank = array_ops.rank(x);
var con1 = new object[]
var con1 = new Tensor[]
{
new []{1, 0 },
new Tensor(new int[]{0, 2}),
math_ops.range(2, x_rank)
};
var x_t = array_ops.transpose(x, array_ops.concat(con1, 0));
Expand Down
13 changes: 9 additions & 4 deletions src/TensorFlowNET.Core/Operations/array_ops.cs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,11 @@ public static Tensor boolean_mask<T1, T2>(T1 tensor, T2 mask, string name = "boo
throw new ValueError("mask cannot be scalar.");

var leading_size = gen_math_ops.prod(shape(tensor_tensor)[$"{axis}:{axis + ndims_mask}"], ops.convert_to_tensor(new[] { 0 }));
if (leading_size.rank == 0)
{
leading_size = expand_dims(leading_size, 0);
}

var shape1 = concat(new[]
{
shape(tensor_tensor)[$":{axis}"],
Expand All @@ -185,7 +190,7 @@ public static Tensor boolean_mask<T1, T2>(T1 tensor, T2 mask, string name = "boo

private static Tensor _apply_mask_1d(Tensor reshaped_tensor, Tensor mask, int axis = 0)
{
var indices = squeeze(where(mask), axis: new[] { 1 });
var indices = squeeze(where_v2(mask), axis: new[] { 1 });
return gather(reshaped_tensor, indices, axis: ops.convert_to_tensor(axis));
}

Expand Down Expand Up @@ -940,12 +945,12 @@ public static Tensor broadcast_static_shape(Tensor shape_x, Tensor shape_y)
/// <returns></returns>
public static Tensor concat(Tensor[] values, Tensor axis, string name = "concat")
{
return tf.Context.ExecuteOp("ConcatV2", name, new ExecuteOpArgs(values, axis));
return gen_array_ops.concat_v2(values, axis, name: name);
}

public static Tensor concat(object[] values, int axis, string name = "concat")
public static Tensor concat(Tensor[] values, Axis axis, string name = "concat")
{
return tf.Context.ExecuteOp("ConcatV2", name, new ExecuteOpArgs(values, axis));
return gen_array_ops.concat_v2(values, axis, name: name);
}

/// <summary>
Expand Down
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/Operations/nn_ops.cs
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ private static Tensor _flatten_outer_dims(Tensor logits)
new[] { math_ops.subtract(rank, 1) },
new[] { constant_op.constant(1) });

var ops = array_ops.concat(new[] { new[] { -1 }, (object)last_dim_size }, 0);
var ops = array_ops.concat(new Tensor[] { new Tensor(new int[] {1}), last_dim_size }, 0);
var output = array_ops.reshape(logits, ops);

// Set output shape if known.
Expand Down
7 changes: 4 additions & 3 deletions test/TensorFlowNET.Graph.UnitTest/Basics/TensorTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System;
using System.Linq;
using static Tensorflow.Binding;
using Tensorflow;

namespace TensorFlowNET.UnitTest.Basics
{
Expand Down Expand Up @@ -60,14 +61,14 @@ public void batch_to_space_nd()
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 15, 21, 16, 22, 17, 23 }, result[0, 3].ToArray<int>()));
}

[TestMethod, Ignore]
[TestMethod]
public void boolean_mask()
{
if (!tf.executing_eagerly())
tf.enable_eager_execution();
var tensor = new[] { 0, 1, 2, 3 };
var mask = np.array(new[] { true, false, true, false });
var masked = tf.boolean_mask(tensor, mask);
var sess = tf.Session();
var result = sess.run(masked);
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 2 }, masked.ToArray<int>()));
}
}
Expand Down
Loading