Skip to content

Commit

Permalink
Adding a unit test reflecting microsoft/CNTK#2698
Browse files Browse the repository at this point in the history
  • Loading branch information
cesarsouza committed Nov 28, 2017
1 parent a86aa27 commit 8134e8e
Showing 1 changed file with 55 additions and 0 deletions.
55 changes: 55 additions & 0 deletions Tests/Backend/CNTK/CNTKBackendTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,61 @@ public void cntk_sum_test()
}
}

[Test]
public void cntk_sum_test_direct_api()
{
double[][] r;

r = sum(null); // first, a sanity check to verify that values are being read correctly
double[,] d = new double[2, 3]; // result will be { 0, 1, 2, 3, 4, 5, 6 }
Buffer.BlockCopy(r[0], 0, d, 0, sizeof(double) * d.Length);
Assert.AreEqual(new[,] { { 1, 2, 3 }, { 4, 5, 6 } }, d); // ok

r = sum(0, 1); // sum over all axes
double a = r[0][0]; // result will be { 21 }
Assert.AreEqual(21, a); // ok

r = sum(0); // sum over first axis
double[] b = r[0]; // result will be { 3, 7, 11 }
Assert.AreEqual(new[] { 5.0, 7.0, 9.0 }, b); // fails

r = sum(1); // sum over second axis
double[] c = r[0]; // result will be { 9, 12 }
Assert.AreEqual(new[] { 6.0, 15.0 }, b); // fails
}

private static double[][] sum(params int[] axes)
{
var arr = new[]
{
/* total:
/* */ 1.0, 2.0, 3.0, /* 6.0 */
/* */ 4.0, 5.0, 6.0, /* 15.0 */
/* total: 5.0, 7.0, 9.0 21.0 */
};

var shape = NDShape.CreateNDShape(new[] { 2, 3 });
Value vx = Value.CreateBatch(shape, arr, DeviceDescriptor.CPUDevice, readOnly: true);
Variable x = Variable.InputVariable(shape, CNTK.DataType.Double, name: "input");

CNTK.Function f;
if (axes == null)
{
f = CNTKLib.Alias(x);
}
else
{
var axisVector = new AxisVector(axes.Select(ax => new Axis(ax)).ToArray());
f = CNTKLib.ReduceSum(x, axis: axisVector);
}

var inputs = new Dictionary<Variable, Value>() { { x, vx } };
var outputs = new Dictionary<Variable, Value>() { { f, null } };
f.Evaluate(inputs, outputs, DeviceDescriptor.CPUDevice);
var r = outputs[f].GetDenseData<double>((Variable)f);
return r.Select(ri => ri.ToArray()).ToArray();
}

[Test]
public void cntk_mean_test()
{
Expand Down

0 comments on commit 8134e8e

Please sign in to comment.