diff --git a/ndsl/stencils/testing/translate.py b/ndsl/stencils/testing/translate.py index 1a48c1e..e3fc884 100644 --- a/ndsl/stencils/testing/translate.py +++ b/ndsl/stencils/testing/translate.py @@ -116,12 +116,6 @@ def make_storage_data( elif not full_shape and len(array.shape) < 3 and axis == len(array.shape) - 1: use_shape[1] = 1 start = (int(istart), int(jstart), int(kstart)) - if "float" in str(array.dtype): - dtype = Float - elif "int" in str(array.dtype): - dtype = Int - else: - dtype = array.dtype if names_4d: return utils.make_storage_dict( array, @@ -132,7 +126,7 @@ def make_storage_data( axis=axis, names=names_4d, backend=self.stencil_factory.backend, - dtype=dtype, + dtype=array.dtype, ) else: if len(array.shape) == 4: @@ -147,7 +141,7 @@ def make_storage_data( axis=axis, read_only=read_only, backend=self.stencil_factory.backend, - dtype=dtype, + dtype=array.dtype, ) def storage_vars(self): diff --git a/ndsl/testing/comparison.py b/ndsl/testing/comparison.py index 9c59231..04cdc55 100644 --- a/ndsl/testing/comparison.py +++ b/ndsl/testing/comparison.py @@ -260,15 +260,19 @@ def _compute_all_metrics( self.ulp_distance_metric = self.ulp_distance <= self.ulp_threshold.value # Combine all distances into sucess or failure - # Success = no NANs & ( abs or rel or ulp ) - naninf_success = not np.logical_and( + # Success = + # - no unexpected NANs (e.g. NaN in the ref MUST BE in computation) OR + # - absolute distance pass OR + # - relative distance pass OR + # - ulp distance pass + naninf_success = np.logical_and( np.isnan(self.computed), np.isnan(self.references) - ).all() + ) metric_success = np.logical_or( self.relative_distance_metric, self.absolute_distance_metric ) metric_success = np.logical_or(metric_success, self.ulp_distance_metric) - success = np.logical_and(naninf_success, metric_success) + success = np.logical_or(naninf_success, metric_success) return success elif self.references.dtype in (np.bool_, bool): success = np.logical_xor(self.computed, self.references)