Skip to content

Commit

Permalink
benchdnn: ref_prim: fix passed memory object with proper data type
Browse files Browse the repository at this point in the history
  • Loading branch information
dzarukin committed Jul 18, 2024
1 parent 9116681 commit 91c35d8
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions tests/benchdnn/dnnl_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1494,19 +1494,20 @@ int update_ref_mem_map_from_prim(dnnl_primitive_t prim_ref,
// have dedicated query mechanism for those. Process potential outcomes:
while (query_md_ndims(ref_md) == 0) {
bool is_scales_arg = (exec_arg & DNNL_ARG_ATTR_SCALES);
// Ref memory for scales is f32, the library expects it same data type.
// Skip replacement.
// Scales received data type support in the library. The reference
// primitive expects them in the same data type.
if (is_scales_arg) {
skip_replace = true;
prim_ref_mem = dnn_mem_t(
library_mem.md_, library_mem.dt(), tag::abx, ref_engine);
break;
}

bool is_zero_point_arg = (exec_arg & DNNL_ARG_ATTR_ZERO_POINTS);
// Ref memory for zps is f32, but the library expects it in s32. Update
// the memory and proceed to replacement.
// Zero-points received data type support in the library. The reference
// primitive expects them in the same data type.
if (is_zero_point_arg) {
prim_ref_mem = dnn_mem_t(
library_mem.md_, dnnl_s32, tag::abx, ref_engine);
library_mem.md_, library_mem.dt(), tag::abx, ref_engine);
break;
}

Expand Down

0 comments on commit 91c35d8

Please sign in to comment.