Skip to content

Commit

Permalink
webgpu: enable more test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
xhcao committed Jan 17, 2023
1 parent 8fa597b commit 8cc7499
Showing 1 changed file with 13 additions and 163 deletions.
176 changes: 13 additions & 163 deletions tfjs-backend-webgpu/src/setup_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,55 +26,19 @@ import './backend_webgpu_test_registry';
import {parseTestEnvFromKarmaFlags, setTestEnvs, setupTestFilters, TEST_ENVS, TestFilter} from '@tensorflow/tfjs-core/dist/jasmine_util';

const TEST_FILTERS: TestFilter[] = [
// skip test cases include gradients webgpu
{
include: 'gradients webgpu',
excludes: ['webgpu '],
},

// skip specific test cases for supported kernels
{
startsWith: 'abs ',
excludes: [
'complex64', // Kernel 'ComplexAbs' not registered.
]
},
{
startsWith: 'atan2 ',
excludes: [
'gradient', // Not yet implemented.
]
},
{
startsWith: 'batchToSpaceND ',
excludes: [
'gradient', // Not yet implemented.
]
},
{
startsWith: 'conv2dTranspose ',
excludes: [
'gradient', // gradient function not found.
]
},
{
startsWith: 'conv3d ',
excludes: [
'gradient', // Not yet implemented.
]
},
{
startsWith: 'cumprod ',
excludes: [
'gradient', // gradient function not found.
]
},
{
startsWith: 'prod ',
excludes: [
'gradient', // gradient function not found.
]
},
{
startsWith: 'cumsum ',
excludes: [
Expand All @@ -93,137 +57,35 @@ const TEST_FILTERS: TestFilter[] = [
'int32', // TODO: fix precision problem.
]
},
{
startsWith: 'fused conv2d ',
excludes: [
'gradient x=[2,3,3,1] f=[2,2,1,1] s=1 p=0', // conv2dDerInput not yet
// implemented
'backProp', // Conv2DBackpropFilter not yet
// implemented
]
},
{
startsWith: 'fused matmul ',
excludes: [
'gradient', // Not yet implemented.
]
},
{
startsWith: 'gather ',
excludes: [
'throws when index is out of bound',
'gradient' // gradient function not found.
]
},
{
startsWith: 'matmul',
excludes: [
'has zero in its shape', // Test times out.
'valueAndGradients', // backend.sum() not yet implemented.
]
},
{
startsWith: 'maxPool ',
excludes: [
'maxPoolBackprop', // Not yet implemented.
'maxPool3d', // Not yet implemented.
'maxPoolWithArgmax' // Not yet implemented.
]
},
{
startsWith: 'max ',
excludes: [
'AdamaxOptimizer', // gradient function not found.
'sparseSegmentMean', // 'SparseSegmentMean' not registered.
]
},
{
startsWith: 'mean ',
excludes: [
'meanSquaredError',
]
},
{
startsWith: 'mul ',
excludes: [
'broadcast', // Various: Actual != Expected, compile fails, etc.
]
},
{
startsWith: 'nonMaxSuppression ',
excludes: [
'NonMaxSuppressionPadded' // NonMaxSuppressionV4 not yet implemented.
]
},
{
startsWith: 'pool ',
excludes: [
'poolBackprop', // maxPoolBackprop not yet implemented.
]
},
{
startsWith: 'poolBackprop ',
excludes: [
'max', // maxPoolBackprop not yet implemented.
]
},
{
startsWith: 'prod ',
excludes: [
'gradients', // Not yet implemented
]
},
{
startsWith: 'range ',
excludes: [
'sparseSegmentMean', // 'SparseSegmentMean' not registered.
]
},
{
startsWith: 'relu ',
excludes: [
'valueAndGradients', // gradient function not found.
'propagates NaNs', // Arrays differ.
'derivative', // gradient function not found.
'gradient' // gradient function not found.
]
},
{
startsWith: 'softmax ',
excludes: [
'MEAN',
'Weighted - Reduction.SUM_BY_NONZERO_WEIGHTS',
]
},
{
startsWith: 'spaceToBatchND ',
startsWith: 'resizeBilinear ',
excludes: [
'tensor4d',
'accepts a tensor-like object',
]
},
{
startsWith: 'square ',
excludes: [
'dilation2d', // 'dilation2d' not yet implemented.
]
},
{
startsWith: 'squaredDifference ',
excludes: [
'dilation2d', // 'dilation2d' not yet implemented.
]
},
{
startsWith: 'tensor ',
excludes: [
'bool tensor' // Expected object not to have properties.
'gradients', // Not yet implemented
]
},
{
startsWith: 'transpose ',
startsWith: 'resizeNearestNeighbor ',
excludes: [
'fused', // Not yet implemented.
'gradients', // Not yet implemented
]
},

Expand All @@ -232,28 +94,16 @@ const TEST_FILTERS: TestFilter[] = [
include: ' webgpu ',
excludes: [
// Not implemented kernel list.
'avgPool3d ',
'avgPool3dBackprop ',
'conv2DBackpropFilter ',
'gradient with clones, input=2x2x1,d2=1,f=1,s=1,d=1,p=same', // Conv2DBackpropFilter
'conv1d gradients', // Conv2DBackpropFilter
'conv3dTranspose ',
'maxPool3d ',
'maxPool3dBackprop ',
'raggedGather ',
'raggedRange ',
'raggedTensorToTensor ',
'avgPool3d ', 'avgPool3dBackprop ',
'conv3dTranspose ', 'maxPool3d ',
'maxPool3dBackprop ', 'raggedGather ',
'raggedRange ', 'raggedTensorToTensor ',
'method otsu', // round
'sparseFillEmptyRows ',
'sparseReshape ',
'sparseSegmentMean ',
'sparseSegmentSum ',
'stringSplit ',
'stringToHashBucketFast ',
'tensorScatterUpdate ',
'unique ',
'unsortedSegmentSum ',
'valueAndGradients ',
'sparseFillEmptyRows ', 'sparseReshape ',
'sparseSegmentMean ', 'sparseSegmentSum ',
'stringSplit ', 'stringToHashBucketFast ',
'tensorScatterUpdate ', 'unique ',
'unsortedSegmentSum ', 'valueAndGradients ',
]
},
];
Expand Down

0 comments on commit 8cc7499

Please sign in to comment.