Skip to content

Commit

Permalink
fix: add the implementation of the tile's grad
Browse files Browse the repository at this point in the history
  • Loading branch information
Wanglongzhi2001 committed Oct 20, 2023
1 parent d5f5c57 commit a73694a
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 1 deletion.
24 changes: 24 additions & 0 deletions src/TensorFlowNET.Core/Gradients/array_grad.cs
Original file line number Diff line number Diff line change
Expand Up @@ -381,5 +381,29 @@ public static Tensor[] _ReverseV2Grad(Operation op, Tensor[] grads)
var axis = op.inputs[1];
return new Tensor[] { array_ops.reverse(grad, axis), null };
}

[RegisterGradient("Tile")]
public static Tensor[] _TileGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
var input_shape = array_ops.shape(op.inputs[0], out_type: op.inputs[1].dtype);
var split_shape = array_ops.reshape(array_ops.transpose(array_ops.stack(new Tensor[] { op.inputs[1], input_shape })), new Shape(-1));
var axes = math_ops.range(0, array_ops.size(split_shape), 2);

//# Sum reduces grad along the first dimension for IndexedSlices
//if isinstance(grad, indexed_slices_lib.IndexedSlices):
//input_shape_0 = math_ops.cast(input_shape[0], grad.indices.dtype)
//grad = math_ops.unsorted_segment_sum(
// grad.values, math_ops.mod(grad.indices, input_shape_0), input_shape_0)
//split_shape = array_ops.concat([[1], split_shape[1:]], axis = 0)

var input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes);
if (!tf.Context.executing_eagerly())
{
input_grad.set_shape(op.inputs[0].GetShape());
}
return new Tensor[] { input_grad, null };

}
}
}
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/Operations/array_ops.cs
Original file line number Diff line number Diff line change
Expand Up @@ -990,7 +990,7 @@ public static Tensor gather(ResourceVariable @params, Tensor indices, string nam
return @params.sparse_read(indices, name);
}

public static Tensor transpose<T1>(T1 a, Axis perm, string name = "transpose", bool conjugate = false)
public static Tensor transpose<T1>(T1 a, Axis perm = null, string name = "transpose", bool conjugate = false)
{
return tf_with(ops.name_scope(name, "transpose", new { a }), scope =>
{
Expand Down
14 changes: 14 additions & 0 deletions test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -173,5 +173,19 @@ public void ConditionalMultiply()
var result = grad(x, 4);
Assert.AreEqual((float)result, 4.0f);
}

[TestMethod]
public void Tile()
{
var a = tf.constant(new int[] { 1 }, TF_DataType.TF_FLOAT);
var b = tf.constant(new int[] { 2 });
using (var tape = tf.GradientTape())
{
tape.watch(a);
var y = tf.tile(a, b);
var grad = tape.gradient(y, a);
Assert.AreEqual((float)grad.numpy(), 2.0f);
}
}
}
}

0 comments on commit a73694a

Please sign in to comment.