Skip to content

Commit

Permalink
fix tests (#6679)
Browse files Browse the repository at this point in the history
BUG
  • Loading branch information
Linchenn authored Jul 26, 2022
1 parent 2233eac commit 4aca2fa
Showing 1 changed file with 56 additions and 11 deletions.
67 changes: 56 additions & 11 deletions tfjs-core/src/ops/depthwise_conv2d_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -726,26 +726,71 @@ describeWithFlags('depthwiseConv2D', ALL_ENVS, () => {

const x = tf.tensor4d(
[
1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8,
9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16,
17, 17, 18, 18, 19, 19, 20, 20, 21, 21, 22, 22, 23, 23, 24, 24,
25, 25, 26, 26, 27, 27, 28, 28, 29, 29, 30, 30, 31, 31, 32, 32,
33, 33, 34, 34, 35, 35, 36, 36, 37, 37, 38, 38, 39, 39, 40, 40,
41, 41, 42, 42, 43, 43, 44, 44, 45, 45, 46, 46, 47, 47, 48, 48,
49, 49, 50, 50, 51, 51, 52, 52, 53, 53, 54, 54, 55, 55, 56, 56,
57, 57, 58, 58, 59, 59, 60, 60, 61, 61, 62, 62, 63, 63, 64, 64
0.09941668063402176, 0.05248984694480896, 0.4567521810531616,
0.8002573847770691, 0.810535192489624, 0.7010623216629028,
0.5898630023002625, 0.05883334204554558, 0.2314797043800354,
0.45427876710891724, 0.10960108041763306, 0.9710874557495117,
0.18139968812465668, 0.8959258794784546, 0.35156702995300293,
0.6495933532714844, 0.5185067653656006, 0.3260101079940796,
0.7837356925010681, 0.9170011281967163, 0.465780109167099,
0.0857422724366188, 0.38354963064193726, 0.8134718537330627,
0.8768209218978882, 0.38151195645332336, 0.5045309066772461,
0.8152258396148682, 0.2782581150531769, 0.545160174369812,
0.1587309092283249, 0.5507456064224243, 0.2704062759876251,
0.7736618518829346, 0.9871141314506531, 0.29300180077552795,
0.3038032352924347, 0.36257433891296387, 0.967268168926239,
0.7251133918762207, 0.6244085431098938, 0.8398842215538025,
0.42696574330329895, 0.25569799542427063, 0.5784937143325806,
0.22755105793476105, 0.8869972229003906, 0.05128923058509827,
0.6748542785644531, 0.97468101978302, 0.5549167394638062,
0.5639380812644958, 0.821204662322998, 0.5207878947257996,
0.8831672668457031, 0.6721863746643066, 0.23375047743320465,
0.040671784430742264, 0.24522553384304047, 0.6293181777000427,
0.6886807680130005, 0.29527169466018677, 0.48199158906936646,
0.5751473307609558, 0.817806601524353, 0.38846832513809204,
0.5553714036941528, 0.1839468777179718, 0.5287416577339172,
0.4813096523284912, 0.477756530046463, 0.641162633895874,
0.03040425479412079, 0.20608118176460266, 0.7930338978767395,
0.727353572845459, 0.42868077754974365, 0.6136374473571777,
0.06312728673219681, 0.4346885681152344, 0.004786544945091009,
0.4951920807361603, 0.588252604007721, 0.724294126033783,
0.07830118387937546, 0.07353833317756653, 0.7818689346313477,
0.8137099742889404, 0.6505773067474365, 0.5716961026191711,
0.5416423678398132, 0.855529248714447, 0.8958709239959717,
0.3598312437534332, 0.31329575181007385, 0.5971285104751587,
0.034069616347551346, 0.6229354739189148, 0.24074052274227142,
0.3356363773345947, 0.1049640029668808, 0.2543765604496002,
0.1635538637638092, 0.8082090616226196, 0.9097364544868469,
0.6435819268226624, 0.6100808382034302, 0.29750677943229675,
0.0738643929362297, 0.8887753486633301, 0.7692861557006836,
0.6412256360054016, 0.16205888986587524, 0.9414404034614563,
0.5698712468147278, 0.6834514737129211, 0.41202589869499207,
0.9096908569335938, 0.8094117045402527, 0.42103442549705505,
0.8905773162841797, 0.069722980260849, 0.014392468146979809,
0.22018849849700928, 0.30076053738594055, 0.8472294211387634,
0.852762758731842, 0.5004454851150513
],
[1, 8, 8, inDepth]);

const w = tf.tensor4d(
[9, 1, 8, 2, 7, 3, 6, 4, 5, 5, 4, 6, 3, 7, 2, 8, 1, 9],
[
0.5785998106002808, 0.7439202666282654, 0.2178175300359726,
0.8782838582992554, 0.6579487919807434, 0.6556791067123413,
0.7341834306716919, 0.3332836329936981, 0.037182893604040146,
0.7394348382949829, 0.04031887650489807, 0.19104436039924622,
0.7014378309249878, 0.5309979319572449, 0.8485966920852661,
0.6609954237937927, 0.021728534251451492, 0.9289031624794006
],
[fSize, fSize, inDepth, 1],
);
const result = tf.depthwiseConv2d(x, w, stride, pad, 'NHWC', dilation);

expect(result.shape).toEqual([1, 2, 2, 2]);
expectArraysClose(
await result.data(), [810, 1710, 855, 1755, 1170, 2070, 1215, 2115]);
expectArraysClose(await result.data(), [
1.0257229804992676, 3.247040033340454, 1.9391249418258667,
2.9474055767059326, 2.0091731548309326, 3.600433826446533,
2.334312677383423, 2.548961877822876
]);
});

it('Tensor3D is allowed', async () => {
Expand Down

0 comments on commit 4aca2fa

Please sign in to comment.