Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
move MKLDNNSum test back to bottom
Browse files Browse the repository at this point in the history
  • Loading branch information
azai91 committed Jun 20, 2018
1 parent bfc729e commit 22ff293
Showing 1 changed file with 40 additions and 40 deletions.
80 changes: 40 additions & 40 deletions tests/cpp/operator/mkldnn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -708,46 +708,6 @@ TEST(MKLDNN_NDArray, CopyFrom) {
}
}

TEST(MKLDNN_BASE, MKLDNNSum) {
std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays();
std::vector<NDArrayAttrs> in_arrs2 = GetTestInputArrays(true);
TestArrayShapes tas = GetTestArrayShapes();
std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;

for (int i = 0; i < in_arrs.size(); i++) {
auto in_arr = in_arrs[i];
auto in_arr2 = in_arrs2[i];
std::vector<NDArrayAttrs> out_arrs = GetTestOutputArrays(in_arr.arr.shape(), pds);
if (!SupportMKLDNN(in_arr.arr) || !in_arr.arr.IsMKLDNNData() || in_arr.arr.IsView())
continue;

for (auto out_arr : out_arrs) {
auto in_mem1 = in_arr.arr.GetMKLDNNData();
auto in_mem2 = in_arr2.arr.GetMKLDNNData();
auto out_mem = out_arr.arr.GetMKLDNNData(in_mem1->get_primitive_desc());

// TODO(alexzai) : remove this noop when by reordering in MKLDNNSum
if (out_mem == nullptr)
continue;
PrintVerifyMsg(in_arr, in_arr);
op::MKLDNNSum(*in_mem1, *in_mem2, *out_mem);
MKLDNNStream::Get()->Submit();
VerifySumResult({&in_arr.arr, &in_arr2.arr}, {&out_arr.arr});
}

// in place
auto input_mem = in_arr.arr.GetMKLDNNData();
auto input_mem2 = in_arr2.arr.GetMKLDNNData();
NDArrayAttrs orig_arr(in_arr.arr.Copy(in_arr.arr.ctx()), "In Place Copy");
PrintVerifyMsg(orig_arr, in_arr);
InitMKLDNNArray(&orig_arr.arr, input_mem->get_primitive_desc());
orig_arr.arr.CopyFrom(*input_mem);
op::MKLDNNSum(*input_mem, *input_mem2, *input_mem);
MKLDNNStream::Get()->Submit();
VerifySumResult({&orig_arr.arr, &in_arr2.arr}, {&in_arr.arr});
}
}

void TestOp(const OpAttrs &attrs, VerifyFunc verify_fn) {
std::vector<NDArray*> inputs(attrs.num_inputs);
std::vector<NDArray*> outputs(attrs.num_outputs);
Expand Down Expand Up @@ -834,4 +794,44 @@ TEST(IMPERATIVE, SumBackwardsOp) {
TestOp(attrs, VerifySumBackwardsResult);
}

TEST(MKLDNN_BASE, MKLDNNSum) {
std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays();
std::vector<NDArrayAttrs> in_arrs2 = GetTestInputArrays(true);
TestArrayShapes tas = GetTestArrayShapes();
std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;

for (int i = 0; i < in_arrs.size(); i++) {
auto in_arr = in_arrs[i];
auto in_arr2 = in_arrs2[i];
std::vector<NDArrayAttrs> out_arrs = GetTestOutputArrays(in_arr.arr.shape(), pds);
if (!SupportMKLDNN(in_arr.arr) || !in_arr.arr.IsMKLDNNData() || in_arr.arr.IsView())
continue;

for (auto out_arr : out_arrs) {
auto in_mem1 = in_arr.arr.GetMKLDNNData();
auto in_mem2 = in_arr2.arr.GetMKLDNNData();
auto out_mem = out_arr.arr.GetMKLDNNData(in_mem1->get_primitive_desc());

// TODO(alexzai) : remove this noop when by reordering in MKLDNNSum
if (out_mem == nullptr)
continue;
PrintVerifyMsg(in_arr, in_arr);
op::MKLDNNSum(*in_mem1, *in_mem2, *out_mem);
MKLDNNStream::Get()->Submit();
VerifySumResult({&in_arr.arr, &in_arr2.arr}, {&out_arr.arr});
}

// in place
auto input_mem = in_arr.arr.GetMKLDNNData();
auto input_mem2 = in_arr2.arr.GetMKLDNNData();
NDArrayAttrs orig_arr(in_arr.arr.Copy(in_arr.arr.ctx()), "In Place Copy");
PrintVerifyMsg(orig_arr, in_arr);
InitMKLDNNArray(&orig_arr.arr, input_mem->get_primitive_desc());
orig_arr.arr.CopyFrom(*input_mem);
op::MKLDNNSum(*input_mem, *input_mem2, *input_mem);
MKLDNNStream::Get()->Submit();
VerifySumResult({&orig_arr.arr, &in_arr2.arr}, {&in_arr.arr});
}
}

#endif

0 comments on commit 22ff293

Please sign in to comment.