diff --git a/tests/e2e/convolution/CMakeLists.txt b/tests/e2e/convolution/CMakeLists.txt
index 8ddad849b082..b0f636249650 100644
--- a/tests/e2e/convolution/CMakeLists.txt
+++ b/tests/e2e/convolution/CMakeLists.txt
@@ -323,3 +323,176 @@ iree_generated_e2e_runner_test(
 )
 
 ### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
+
+# Distinguish between CDNA (gfx9) and RDNA3 (gfx11) targets.
+if(IREE_HIP_TEST_TARGET_CHIP MATCHES "^gfx9")
+
+unset(IREE_HIP_TEST_COMPILER_FLAGS)
+list(APPEND IREE_HIP_TEST_COMPILER_FLAGS
+  "--iree-rocm-target-chip=${IREE_HIP_TEST_TARGET_CHIP}"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_conv2d_rocm_f16_f16_f32_large_cdna3
+  TEST_TYPE
+    conv2d
+  GENERATOR
+    "generate_e2e_conv2d_tests.py"
+  GENERATOR_ARGS
+    "--input_type=f16"
+    "--input_layout=nhwc"
+    "--kernel_type=f16"
+    "--kernel_layout=hwcf"
+    "--acc_type=f32"
+    "--shapes=gpu_large"
+    "--compilation_info=LLVMGPUVectorDistributeMFMA"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-conv2d-test
+  TARGET_BACKENDS
+    "rocm"
+  DRIVERS
+    "hip"
+  COMPILER_FLAGS
+    ${IREE_HIP_TEST_COMPILER_FLAGS}
+  LABELS
+    "noasan"
+    "nomsan"
+    "notsan"
+    "noubsan"
+    "requires-gpu-cdna3"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_conv2d_rocm_f16_f16_f32_large_gpu_vectorize_cdna3
+  TEST_TYPE
+    conv2d
+  GENERATOR
+    "generate_e2e_conv2d_tests.py"
+  GENERATOR_ARGS
+    "--input_type=f16"
+    "--input_layout=nhwc"
+    "--kernel_type=f16"
+    "--kernel_layout=hwcf"
+    "--acc_type=f32"
+    "--shapes=gpu_large"
+    "--compilation_info=LLVMGPUVectorizeCDNA"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-conv2d-test
+  TARGET_BACKENDS
+    "rocm"
+  DRIVERS
+    "hip"
+  COMPILER_FLAGS
+    ${IREE_HIP_TEST_COMPILER_FLAGS}
+  LABELS
+    "noasan"
+    "nomsan"
+    "notsan"
+    "noubsan"
+    "requires-gpu-cdna3"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_conv2d_rocm_nchw_f16_f16_f32_large_gpu_vectorize_cdna3
+  TEST_TYPE
+    conv2d
+  GENERATOR
+    "generate_e2e_conv2d_tests.py"
+  GENERATOR_ARGS
+    "--input_type=f16"
+    "--input_layout=nchw"
+    "--kernel_type=f16"
+    "--kernel_layout=fchw"
+    "--acc_type=f32"
+    "--shapes=gpu_large"
+    "--compilation_info=LLVMGPUVectorizeCDNA"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-conv2d-test
+  TARGET_BACKENDS
+    "rocm"
+  DRIVERS
+    "hip"
+  COMPILER_FLAGS
+    ${IREE_HIP_TEST_COMPILER_FLAGS}
+  LABELS
+    "noasan"
+    "nomsan"
+    "notsan"
+    "noubsan"
+    "requires-gpu-cdna3"
+)
+
+
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_conv2d_rocm_i8_large_cdna3
+  TEST_TYPE
+    conv2d
+  GENERATOR
+    "generate_e2e_conv2d_tests.py"
+  GENERATOR_ARGS
+    "--input_type=i8"
+    "--input_layout=nhwc"
+    "--kernel_type=i8"
+    "--kernel_layout=hwcf"
+    "--acc_type=i32"
+    "--shapes=gpu_large"
+    "--compilation_info=LLVMGPUVectorDistributeMFMA"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-conv2d-test
+  TARGET_BACKENDS
+    "rocm"
+  DRIVERS
+    "hip"
+  COMPILER_FLAGS
+    ${IREE_HIP_TEST_COMPILER_FLAGS}
+  LABELS
+    "noasan"
+    "nomsan"
+    "notsan"
+    "noubsan"
+    "requires-gpu-cdna3"
+)
+
+elseif(IREE_HIP_TEST_TARGET_CHIP MATCHES "^gfx11")
+
+unset(IREE_HIP_TEST_COMPILER_FLAGS)
+list(APPEND IREE_HIP_TEST_COMPILER_FLAGS
+  "--iree-rocm-target-chip=${IREE_HIP_TEST_TARGET_CHIP}"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_conv2d_rocm_f16_f16_f32_large_rdna3
+  TEST_TYPE
+    conv2d
+  GENERATOR
+    "generate_e2e_conv2d_tests.py"
+  GENERATOR_ARGS
+    "--input_type=f16"
+    "--input_layout=nhwc"
+    "--kernel_type=f16"
+    "--kernel_layout=hwcf"
+    "--acc_type=f32"
+    "--shapes=gpu_large"
+    "--compilation_info=LLVMGPUVectorDistributeWMMA"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-conv2d-test
+  TARGET_BACKENDS
+    "rocm"
+  DRIVERS
+    "hip"
+  COMPILER_FLAGS
+    ${IREE_HIP_TEST_COMPILER_FLAGS}
+  LABELS
+    "noasan"
+    "nomsan"
+    "notsan"
+    "noubsan"
+    "requires-gpu-rdna3"
+)
+
+endif()
diff --git a/tests/e2e/convolution/generate_e2e_conv2d_tests.py b/tests/e2e/convolution/generate_e2e_conv2d_tests.py
index 0982e1801679..33edfb22e6d2 100644
--- a/tests/e2e/convolution/generate_e2e_conv2d_tests.py
+++ b/tests/e2e/convolution/generate_e2e_conv2d_tests.py
@@ -7,17 +7,21 @@
 """Generator for e2e conv2d tests.
 """
 
+from typing import Optional
 import argparse
 import enum
 import dataclasses
 import typing
 import math
+import itertools
+import re
 
 
 # Data type of kernel entries. The string values must match MLIR data types.
 @enum.unique
 class KernelElemTypeId(enum.Enum):
     NONE = ""
+    I8 = "i8"
     F32 = "f32"
     F16 = "f16"
 
@@ -26,6 +30,16 @@ class KernelElemTypeId(enum.Enum):
 @enum.unique
 class InputElemTypeId(enum.Enum):
     NONE = ""
+    I8 = "i8"
+    F32 = "f32"
+    F16 = "f16"
+
+
+# Data type of accumulator entries. The string values must match MLIR data types.
+@enum.unique
+class AccElemTypeId(enum.Enum):
+    NONE = ""
+    I32 = "i32"
     F32 = "f32"
     F16 = "f16"
 
@@ -37,6 +51,55 @@ class ShapesId(enum.Enum):
     SMALL = "small"
     MEDIUM = "medium"
     LARGE = "large"
+    GPU_LARGE = "gpu_large"
+
+
+# Describes a workgroup and tiling schedule to target a specific MMA intrinsic.
+@dataclasses.dataclass
+class MMASchedule:
+    intrinsic: str
+    m_count: int  # Number of subgroups per workgroup along M
+    n_count: int  # Number of subgroups per workgroup along N
+    m_tile_count: int
+    n_tile_count: int
+    k_tile_count: int
+
+    def __str__(self):
+        return (
+            "mma_schedule = #iree_gpu.mma_schedule<"
+            + f"intrinsic = #iree_gpu.mma_layout<{self.intrinsic}>, "
+            + f"subgroup_m_count = {self.m_count}, "
+            + f"subgroup_n_count = {self.n_count}>"
+        )
+
+
+# Enumerates the collections of compilation info that we can generate tests
+# for. The values are the accepted values for the --compilation_info= flag.
+@enum.unique
+class CompilationInfoId(enum.Enum):
+    NONE = ""
+    LLVMGPUVectorDistributeMFMA = "LLVMGPUVectorDistributeMFMA"
+    LLVMGPUVectorDistributeWMMA = "LLVMGPUVectorDistributeWMMA"
+    LLVMGPUVectorizeCDNA = "LLVMGPUVectorizeCDNA"
+    LLVMGPUVectorizeRDNA = "LLVMGPUVectorizeRDNA"
+
+
+# Describes how to construct compilation info for the testcase.
+@dataclasses.dataclass
+class CompilationInfo:
+    # Lowering config
+    tile_sizes: typing.List[typing.List[int]]
+    # Translation info
+    dispatch_lowering_pass_pipeline: str
+    software_pipeline_depth: int
+    mma_schedule: typing.Optional[MMASchedule]
+    # Compilation info
+    workgroup_size: typing.List[int]
+    subgroup_size: Optional[int] = None
+
+    # Returns the workgroup size rendered as an attribute string.
+    def workgroup_size_str(self):
+        return "workgroup_size = [" + ", ".join(map(str, self.workgroup_size)) + "]"
 
 
 # Enumerates ways to construct MLIR tensor types.
@@ -125,6 +188,8 @@ def get_test_shapes(shapes_id: ShapesId):
             TestShape(n=2, c=4, h=128, w=128, kh=3, kw=3, f=8, accumulate=True),
             TestShape(n=2, c=3, h=128, w=128, kh=3, kw=3, f=12, accumulate=True),
         ]
+    if shapes_id == ShapesId.GPU_LARGE:
+        return [TestShape(n=1, c=64, h=130, w=130, kh=3, kw=3, f=64, accumulate=True)]
 
     raise ValueError(shapes_id)
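An aside on the new `gpu_large` shape: with stride 1 and dilation 1, the 130x130 input planes and 3x3 kernel produce 128x128 output planes. A quick sketch of that arithmetic (the `out_size` helper below is illustrative only, mirroring the semantics assumed by the C++ runner's `out_shape_calc`):

```python
# Illustrative sketch: conv2d output extent given stride and dilation.
def out_size(i: int, k: int, stride: int = 1, dilation: int = 1) -> int:
    return (i - (k - 1) * dilation - 1) // stride + 1

# The GPU_LARGE shape: n=1, c=64, h=w=130, kh=kw=3, f=64.
assert out_size(130, 3) == 128  # yields a 1x128x128x64 output in NHWC
```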
@@ -132,7 +197,7 @@
 # Returns the list of Dynamicity's to use for the collection of shapes
 # identified by shapes_id.
 def get_dynamicities(shapes_id: ShapesId):
-    if shapes_id == ShapesId.LARGE:
+    if shapes_id == ShapesId.LARGE or shapes_id == ShapesId.GPU_LARGE:
         return [
             Dynamicity.STATIC,
         ]
@@ -144,6 +209,156 @@ def get_dynamicities(shapes_id: ShapesId):
     raise ValueError(shapes_id)
 
 
+@dataclasses.dataclass
+class TileWorkgroupSizePair:
+    tile_size: typing.List[typing.List[int]]
+    workgroup_size: typing.List[int]
+
+
+def get_rocm_test_compilation_infos(
+    compilation_info_id: CompilationInfoId,
+    lhs_rhs_type: InputElemTypeId,
+    acc_type: AccElemTypeId,
+):
+    intrinsic = ""
+    if compilation_info_id == CompilationInfoId.LLVMGPUVectorDistributeMFMA:
+        intrinsic = "MFMA"
+    elif compilation_info_id == CompilationInfoId.LLVMGPUVectorDistributeWMMA:
+        intrinsic = "WMMA"
+    else:
+        raise ValueError("Unknown pipeline for ROCm")
+
+    schedules = []
+    if intrinsic == "MFMA":
+        schedules = [
+            MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 1, 1, 1),
+            MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 1, 1, 2),
+            MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 1, 2, 1),
+            MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 2, 1, 1),
+            MMASchedule("MFMA_F16_16x16x16_F32", 2, 2, 1, 1, 1),
+            MMASchedule("MFMA_F16_16x16x16_F32", 2, 4, 2, 1, 2),
+            MMASchedule("MFMA_F16_32x32x8_F32", 1, 1, 1, 2, 2),
+            MMASchedule("MFMA_F16_32x32x8_F32", 2, 2, 1, 1, 1),
+            MMASchedule("MFMA_I8_16x16x32_I32", 1, 1, 1, 1, 1),
+            MMASchedule("MFMA_I8_16x16x32_I32", 2, 2, 1, 1, 2),
+            MMASchedule("MFMA_I8_16x16x32_I32", 4, 1, 4, 1, 1),
+            MMASchedule("MFMA_I8_16x16x32_I32", 4, 2, 4, 2, 1),
+            MMASchedule("MFMA_I8_32x32x16_I32", 1, 1, 1, 1, 1),
+            MMASchedule("MFMA_I8_32x32x16_I32", 2, 2, 1, 1, 2),
+            MMASchedule("MFMA_I8_32x32x16_I32", 4, 1, 1, 2, 2),
+        ]
+    elif intrinsic == "WMMA":
+        schedules = [
+            MMASchedule("WMMA_F16_16x16x16_F32", 1, 1, 1, 1, 1),
+            MMASchedule("WMMA_F16_16x16x16_F32", 1, 1, 1, 1, 2),
+            MMASchedule("WMMA_F16_16x16x16_F32", 1, 1, 1, 2, 1),
+            MMASchedule("WMMA_F16_16x16x16_F32", 1, 1, 2, 1, 1),
+            MMASchedule("WMMA_F16_16x16x16_F32", 2, 2, 1, 1, 1),
+            MMASchedule("WMMA_F16_16x16x16_F32", 2, 4, 2, 1, 2),
+            MMASchedule("WMMA_F16_16x16x16_F32", 4, 2, 4, 2, 2),
+        ]
+    else:
+        raise NotImplementedError("unhandled intrinsic case")
+
+    subgroup_size = 64 if intrinsic == "MFMA" else 32
+
+    infos = []
+    for schedule in schedules:
+        # Skip schedules whose intrinsic element types do not match the
+        # requested input and accumulator types.
+
+        # Extract the input and accumulator types from the intrinsic string.
+        extract_input_type = lambda s: re.search(r"(?:MFMA|WMMA)_([^_]+)_", s).group(1)
+        extract_output_type = lambda s: (
+            match := re.search(r"_(F\d+|I\d+)$", s)
+        ) and match.group(1)
+
+        if lhs_rhs_type.value.upper() != extract_input_type(
+            schedule.intrinsic
+        ) or acc_type.value.upper() != extract_output_type(schedule.intrinsic):
+            continue
+
+        if schedule.intrinsic == "MFMA_F16_16x16x16_F32":
+            wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
+            wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
+            wg_tile_k = schedule.k_tile_count * 16
+        elif schedule.intrinsic == "MFMA_F16_32x32x8_F32":
+            wg_tile_m = schedule.m_count * schedule.m_tile_count * 32
+            wg_tile_n = schedule.n_count * schedule.n_tile_count * 32
+            wg_tile_k = schedule.k_tile_count * 8
+        elif schedule.intrinsic == "MFMA_I8_16x16x32_I32":
+            wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
+            wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
+            wg_tile_k = schedule.k_tile_count * 32
+        elif schedule.intrinsic == "MFMA_I8_32x32x16_I32":
+            wg_tile_m = schedule.m_count * schedule.m_tile_count * 32
+            wg_tile_n = schedule.n_count * schedule.n_tile_count * 32
+            wg_tile_k = schedule.k_tile_count * 16
+        elif schedule.intrinsic == "WMMA_F16_16x16x16_F32":
+            wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
+            wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
+            wg_tile_k = schedule.k_tile_count * 16
+        else:
+            raise NotImplementedError("unhandled intrinsic case")
+
+        workgroup_tile = [[1, 1, wg_tile_m, wg_tile_n, 1, 1, wg_tile_k]]
+        workgroup_size = [schedule.n_count * subgroup_size, schedule.m_count, 1]
+
+        infos.append(
+            CompilationInfo(
+                tile_sizes=workgroup_tile,
+                dispatch_lowering_pass_pipeline="LLVMGPUVectorDistribute",
+                workgroup_size=workgroup_size,
+                software_pipeline_depth=0,
+                mma_schedule=schedule,
+                subgroup_size=subgroup_size,
+            )
+        )
+    return infos
+
+
+# Returns the list of CompilationInfo's to use for the CompilationInfoId.
+def get_test_compilation_infos(
+    compilation_info_id: CompilationInfoId,
+    lhs_rhs_type: InputElemTypeId,
+    acc_type: AccElemTypeId,
+) -> typing.List[typing.Optional[CompilationInfo]]:
+    if compilation_info_id == CompilationInfoId.NONE:
+        return [None]
+
+    if compilation_info_id in [
+        CompilationInfoId.LLVMGPUVectorDistributeMFMA,
+        CompilationInfoId.LLVMGPUVectorDistributeWMMA,
+    ]:
+        return get_rocm_test_compilation_infos(
+            compilation_info_id, lhs_rhs_type, acc_type
+        )
+
+    subgroup_size = -1
+    if compilation_info_id == CompilationInfoId.LLVMGPUVectorizeCDNA:
+        subgroup_size = 64
+    elif compilation_info_id == CompilationInfoId.LLVMGPUVectorizeRDNA:
+        subgroup_size = 32
+
+    if subgroup_size == 64 or subgroup_size == 32:
+        workgroup_tile = [[1, 1, 8, 128, 1, 1, 4]]
+        workgroup_size = [64, 1, 1]
+        infos = []
+        infos.append(
+            CompilationInfo(
+                tile_sizes=workgroup_tile,
+                dispatch_lowering_pass_pipeline="LLVMGPUVectorize",
+                workgroup_size=workgroup_size,
+                software_pipeline_depth=0,
+                mma_schedule=None,
+                subgroup_size=subgroup_size,
+            )
+        )
+        return infos
+
+    raise ValueError("Unknown pipeline for ROCm")
+
+
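To make the tiling arithmetic concrete, here is one schedule from the MFMA list worked through by hand (a sketch assuming the `MMASchedule` and `CompilationInfo` definitions from this diff):

```python
# One schedule from the MFMA list above, worked through by hand.
sched = MMASchedule("MFMA_F16_16x16x16_F32", 2, 4, 2, 1, 2)
# For the 16x16x16 intrinsic:
#   wg_tile_m = m_count * m_tile_count * 16 = 2 * 2 * 16 = 64
#   wg_tile_n = n_count * n_tile_count * 16 = 4 * 1 * 16 = 64
#   wg_tile_k = k_tile_count * 16           = 2 * 16     = 32
# so tile_sizes = [[1, 1, 64, 64, 1, 1, 32]], and with the MFMA subgroup
# size of 64, workgroup_size = [4 * 64, 2, 1] = [256, 2, 1].
print(sched)
# mma_schedule = #iree_gpu.mma_schedule<intrinsic =
#     #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 2,
#     subgroup_n_count = 4>
```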
 # Intentionally fixed seed! We want full reproducibility here, both across runs
 # and across machines.
 # Intentionally not shared with pseudorandom_generator_seed to limit the ways
@@ -267,7 +482,7 @@ def get_tensor_shape(
     if kernel_layout == KernelLayout.FCHW:
         kernel_tensor_shape = [f, c, kh, kw]
     elif kernel_layout == KernelLayout.HWCF:
-        kernel_tensor_shape = [f, c, kh, kw]
+        kernel_tensor_shape = [kh, kw, c, f]
     else:
         raise ValueError(kernel_layout)
 
@@ -279,9 +494,10 @@
 def generate_function_name(
     input_type: InputElemTypeId,
     kernel_type: KernelElemTypeId,
-    output_type: InputElemTypeId,
+    output_type: AccElemTypeId,
     shapes: TestInputTensorShapes,
     accumulate: bool,
+    compilation_info: typing.Optional[CompilationInfo] = None,
 ):
     input_t = input_type.value
     kernel_t = kernel_type.value
@@ -294,10 +510,20 @@ def generate_function_name(
     kw = int_or_DYN(shapes.kw)
     f = int_or_DYN(shapes.f)
 
+    info = ""
+    if compilation_info:
+        tile_sizes = list(itertools.chain(*compilation_info.tile_sizes))
+        tile_workgroup_key = (
+            "_".join([str(a) for a in tile_sizes])
+            + "_"
+            + "_".join([str(a) for a in compilation_info.workgroup_size])
+        )
+        info = f"_for_{compilation_info.dispatch_lowering_pass_pipeline}_{tile_workgroup_key}"
+
     conv2d_kind = "conv2d_accumulate" if accumulate else "conv2d"
     return (
         f"{conv2d_kind}_{n}_{c}_{h}_{w}_times_"
-        + f"{kh}_{kw}_{f}_dtype_{input_t}_{kernel_t}_{acc_t}"
+        + f"{kh}_{kw}_{f}_dtype_{input_t}_{kernel_t}_{acc_t}{info}"
     )
 
 
@@ -318,18 +544,15 @@ def generate_function(
     input_type: InputElemTypeId,
     input_layout: InputLayout,
     kernel_type: KernelElemTypeId,
     kernel_layout: KernelLayout,
-    acc_type: InputElemTypeId,
+    acc_type: AccElemTypeId,
     conv2d_attr: ConvAttrs,
     shape: TestShape,
     dynamicity: Dynamicity,
+    compilation_info: typing.Optional[CompilationInfo] = None,
 ):
     shapes = generate_shapes(shape, dynamicity)
     func_name = generate_function_name(
-        input_type,
-        kernel_type,
-        acc_type,
-        shapes,
-        shape.accumulate,
+        input_type, kernel_type, acc_type, shapes, shape.accumulate, compilation_info
     )
 
     input_shape, kernel_shape, output_shape = get_tensor_shape(
@@ -338,7 +561,7 @@
 
     input_tensor_type = f"tensor<{input_shape[0]}x{input_shape[1]}x{input_shape[2]}x{input_shape[3]}x{input_type.value}>"
     kernel_tensor_type = f"tensor<{kernel_shape[0]}x{kernel_shape[1]}x{kernel_shape[2]}x{kernel_shape[3]}x{kernel_type.value}>"
-    acc_tensor_type = f"tensor<{output_shape[0]}x{output_shape[1]}x{output_shape[2]}x{output_shape[3]}x{input_type.value}>"
+    acc_tensor_type = f"tensor<{output_shape[0]}x{output_shape[1]}x{output_shape[2]}x{output_shape[3]}x{acc_type.value}>"
 
     op_name = None
     if input_layout == InputLayout.NCHW:
@@ -350,10 +573,38 @@
         if kernel_layout == KernelLayout.HWCF:
             op_name = "linalg.conv_2d_nhwc_hwcf"
 
+    if op_name is None:
+        raise ValueError("Invalid combination of input_layout and kernel_layout")
+
     conv_attr = f"{{dilations = dense<{list(conv2d_attr.DILATION)}> : tensor<2xi64>, strides = dense<{list(conv2d_attr.STRIDE)}> : tensor<2xi64>}}"
 
     # Compilation info is optional; prints empty string by default.
func_definition = "" + # compilation_info_attr = "" + if compilation_info: + requested_pipeline = compilation_info.dispatch_lowering_pass_pipeline + compiler_pipeline = requested_pipeline + + mma_schedule = "" + if compilation_info.mma_schedule is not None: + mma_schedule = ", {}".format(compilation_info.mma_schedule) + subgroup_size_str = "" + if compilation_info.subgroup_size is not None: + subgroup_size_str = f"subgroup_size = {compilation_info.subgroup_size}" + + compilation_info_string = ( + f"#compilation{generate_function.compilation_index} = " + "#iree_codegen.compilation_info<\n" + f" lowering_config = #iree_codegen.lowering_config,\n" + f" translation_info = <{compiler_pipeline} {compilation_info.workgroup_size_str()}\n" + f" {subgroup_size_str},\n" + f" {{ pipeline_depth = {compilation_info.software_pipeline_depth}, " + f" store_stage = 1{mma_schedule} }}>>\n" + ) + + func_definition = func_definition + compilation_info_string + conv_attr = f"{{compilation_info = #compilation{generate_function.compilation_index}, dilations = dense<{list(conv2d_attr.DILATION)}> : tensor<2xi64>, strides = dense<{list(conv2d_attr.STRIDE)}> : tensor<2xi64>}}" + generate_function.compilation_index += 1 signature = f"({input_tensor_type}, {kernel_tensor_type}, {acc_tensor_type}) -> {acc_tensor_type}" import_declaration = f"func.func private @module.{func_name}(%input: !hal.buffer_view, %kernel: !hal.buffer_view, %acc: !hal.buffer_view) -> !hal.buffer_view" @@ -363,7 +614,6 @@ def generate_function( f" return %result: {acc_tensor_type}\n" f"}}\n" ) - return MLIRFunction( name=func_name, signature=signature, @@ -372,6 +622,10 @@ def generate_function( ) +# Counter for producing unique compilation info attrs +generate_function.compilation_index = 0 + + # Represents a call to a generated test function. 
 # Represents a call to a generated test function.
 @dataclasses.dataclass
 class TestCall:
@@ -433,7 +687,7 @@ def generate_call(
     kernel_type: KernelElemTypeId,
     kernel_layout: KernelLayout,
     conv2d_attr: ConvAttrs,
-    acc_type: InputElemTypeId,
+    acc_type: AccElemTypeId,
     shape: TestShape,
 ):
     global call_id
@@ -443,7 +697,22 @@
     func_name = f"{func_name}_{call_id}"
     call_id = call_id + 1
 
+    # Layout of the output tensor, used for checking correctness.
+    layout = -1
     description = f"Conv2d shape (NxCxHxWxFxKHxKW): {shape.n}x{shape.c}x{shape.h}x{shape.w}x{shape.f}x{shape.kh}x{shape.kw}"
+    if input_layout == InputLayout.NCHW:
+        if kernel_layout == KernelLayout.FCHW or kernel_layout == KernelLayout.HWCF:
+            layout = 0  # for output tensor NxFxOHxOW
+        else:
+            raise ValueError(kernel_layout)
+    elif input_layout == InputLayout.NHWC:
+        if kernel_layout == KernelLayout.HWCF:
+            layout = 1  # for output tensor NxOHxOWxF
+        else:
+            raise ValueError(kernel_layout)
+    else:
+        raise ValueError(input_layout)
+
     op = (
         f"func.func @{func_name}() attributes {{\n"
         f'  iree.reflection = {{description = "{description}"}}\n'
@@ -485,11 +754,12 @@
         f"  %f = arith.constant {shape.f} : i64\n"
         f"  %kh = arith.constant {shape.kh} : i64\n"
         f"  %kw = arith.constant {shape.kw} : i64\n"
+        f"  %layout = arith.constant {layout} : i64\n"
         f"  %sh = arith.constant {conv2d_attr.STRIDE[0]} : i64\n"
        f"  %sw = arith.constant {conv2d_attr.STRIDE[1]} : i64\n"
         f"  %dh = arith.constant {conv2d_attr.DILATION[0]} : i64\n"
         f"  %dw = arith.constant {conv2d_attr.DILATION[1]} : i64\n"
-        f"  call @conv2d_test.check_conv2d_results(%device, %n, %c, %h, %w, %f, %kh, %kw, %sh, %sw, %dh, %dw, %input, %kernel, %acc, %result) : (!hal.device, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view) -> ()\n"
+        f"  call @conv2d_test.check_conv2d_results(%device, %n, %c, %h, %w, %f, %kh, %kw, %layout, %sh, %sw, %dh, %dw, %input, %kernel, %acc, %result) : (!hal.device, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view) -> ()\n"
     )
 
     op = op + "  return\n"
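The layout flag computed here is the contract shared with the C++ test runner: 0 selects an NxFxOHxOW output (NCHW input), 1 selects NxOHxOWxF (NHWC input). Condensed as a sketch (the helper name is illustrative; `generate_call` inlines this logic):

```python
# Illustrative restatement of the layout flag passed to check_conv2d_results.
def output_layout_flag(input_layout: InputLayout, kernel_layout: KernelLayout) -> int:
    if input_layout == InputLayout.NCHW and kernel_layout in (
        KernelLayout.FCHW,
        KernelLayout.HWCF,
    ):
        return 0  # output tensor is NxFxOHxOW
    if input_layout == InputLayout.NHWC and kernel_layout == KernelLayout.HWCF:
        return 1  # output tensor is NxOHxOWxF
    raise ValueError((input_layout, kernel_layout))
```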
@@ -505,43 +775,48 @@ def generate(
     kernel_elem_type: KernelElemTypeId,
     kernel_layout: KernelLayout,
     conv2d_attr: ConvAttrs,
-    acc_type: InputElemTypeId,
+    acc_elem_type: AccElemTypeId,
     shapes_id: ShapesId,
+    compilation_info_id: CompilationInfoId,
 ):
     functions = {}
     calls = []
 
-    for shape in get_test_shapes(shapes_id):
-        for dynamicity in get_dynamicities(shapes_id):
-            function = generate_function(
-                input_elem_type,
-                input_layout,
-                kernel_elem_type,
-                kernel_layout,
-                acc_type,
-                conv2d_attr,
-                shape,
-                dynamicity,
-            )
-            # Different testcases may differ only by runtime parameters but
-            # share the same code. For example, dynamic-shapes testcases
-            # share the same code involing tensor even though the runtime
-            # value in the trace are different. That's why we append conditionally
-            # to calls, but unconditionally to function_definitions.
-            if function.name not in functions:
-                functions[function.name] = function
-            calls.append(
-                generate_call(
-                    function,
-                    input_elem_type,
-                    input_layout,
-                    kernel_elem_type,
-                    kernel_layout,
-                    conv2d_attr,
-                    acc_type,
-                    shape,
-                )
-            )
+    for compilation_info in get_test_compilation_infos(
+        compilation_info_id, input_elem_type, acc_elem_type
+    ):
+        for shape in get_test_shapes(shapes_id):
+            for dynamicity in get_dynamicities(shapes_id):
+                function = generate_function(
+                    input_elem_type,
+                    input_layout,
+                    kernel_elem_type,
+                    kernel_layout,
+                    acc_elem_type,
+                    conv2d_attr,
+                    shape,
+                    dynamicity,
+                    compilation_info,
+                )
+                # Different testcases may differ only by runtime parameters but
+                # share the same code. For example, dynamic-shapes testcases
+                # share the same code involving tensor even though the runtime
+                # values in the trace are different. That's why we append
+                # conditionally to calls, but unconditionally to
+                # function_definitions.
+                if function.name not in functions:
+                    functions[function.name] = function
+                calls.append(
+                    generate_call(
+                        function,
+                        input_elem_type,
+                        input_layout,
+                        kernel_elem_type,
+                        kernel_layout,
+                        conv2d_attr,
+                        acc_elem_type,
+                        shape,
+                    )
+                )
 
     return (functions, calls)
 
@@ -563,7 +838,7 @@ def parse_arguments():
     parser.add_argument(
         "--input_type",
         type=str,
-        choices=["f32", "f16"],
+        choices=["i8", "f32", "f16"],
         help="Numeric type of input tensors",
         required=True,
     )
@@ -578,7 +853,7 @@ def parse_arguments():
     parser.add_argument(
         "--kernel_type",
         type=str,
-        choices=["f32", "f16"],
+        choices=["i8", "f32", "f16"],
         help="Numeric type of input tensors",
         required=True,
     )
@@ -593,7 +868,7 @@ def parse_arguments():
     parser.add_argument(
         "--acc_type",
         type=str,
-        choices=["f32", "f16"],
+        choices=["i32", "f32", "f16"],
         help="Numeric type of input tensors",
         default="",
         required=False,
     )
@@ -619,6 +894,14 @@ def parse_arguments():
         help="The stride factor for the convolution operation. Comma-separated. As in 1,1",
         required=False,
     )
+    parser.add_argument(
+        "--compilation_info",
+        type=str,
+        choices=[i.value for i in CompilationInfoId],
+        help="Collection of compilation info setups to test",
+        default="",
+        required=False,
+    )
     parser.add_argument(
         "--requirements",
         type=str,
@@ -652,7 +935,7 @@ def write_calls_file(functions, calls, filename, requirements):
     # Declare the custom module that generates arguments.
     module_definition = module_definition + (
         "func.func private @conv2d_test.generate_random_tensor(%device: !hal.device, %dim0: i64, %dim1: i64, %dim2: i64, %dim3: i64, %element_type: i32, %seed: i32) -> !hal.buffer_view\n"
-        "func.func private @conv2d_test.check_conv2d_results(%device: !hal.device, %n: i64, %c: i64, %h: i64, %w: i64, %f:i64, %kh:i64, %kw:i64, %sh:i64, %sw:i64, %dh:i64, %dw:i64, %input: !hal.buffer_view, %kernel: !hal.buffer_view, %acc: !hal.buffer_view, %actual_result: !hal.buffer_view)\n"
+        "func.func private @conv2d_test.check_conv2d_results(%device: !hal.device, %n: i64, %c: i64, %h: i64, %w: i64, %f:i64, %kh:i64, %kw:i64, %layout:i64, %sh:i64, %sw:i64, %dh:i64, %dw:i64, %input: !hal.buffer_view, %kernel: !hal.buffer_view, %acc: !hal.buffer_view, %actual_result: !hal.buffer_view)\n"
         "\n"
     )
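As a sanity check on how the new generation axis multiplies out: for `--compilation_info=LLVMGPUVectorDistributeMFMA` with f16 inputs and an f32 accumulator, the type filter keeps the eight f16 MFMA schedules, and `gpu_large` contributes a single static shape, so `generate()` emits one function and one call per schedule (a sketch assuming the definitions above):

```python
# Eight f16 schedules survive the type filter (6x 16x16x16 + 2x 32x32x8).
infos = get_test_compilation_infos(
    CompilationInfoId.LLVMGPUVectorDistributeMFMA,
    InputElemTypeId.F16,
    AccElemTypeId.F32,
)
assert len(infos) == 8
```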
@@ -676,9 +959,12 @@ def main(args):
     input_layout = InputLayout(args.input_layout)
     kernel_type = KernelElemTypeId(args.kernel_type)
     kernel_layout = KernelLayout(args.kernel_layout)
-    # TODO: The output type is same as the input type for now.
-    acc_type = input_type
+
+    acc_type = AccElemTypeId(args.acc_type)
+
     shapes_id = ShapesId(args.shapes)
+    compilation_info_id = CompilationInfoId(args.compilation_info)
+
     conv2d_attr = ConvAttrs(
         tuple(map(int, args.stride.split(","))),
         tuple(map(int, args.dilation.split(","))),
@@ -692,6 +978,7 @@
         conv2d_attr,
         acc_type,
         shapes_id,
+        compilation_info_id,
     )
 
     write_code_file(functions, args.output_conv2d_mlir)
diff --git a/tools/testing/e2e/iree-e2e-conv2d-test.cc b/tools/testing/e2e/iree-e2e-conv2d-test.cc
index 31d02e953523..10e64521ef25 100644
--- a/tools/testing/e2e/iree-e2e-conv2d-test.cc
+++ b/tools/testing/e2e/iree-e2e-conv2d-test.cc
@@ -21,7 +21,7 @@
 #include "tools/testing/e2e/test_utils.h"
 
 //===----------------------------------------------------------------------===//
-// Reference conv2d (NCHW-FCHW)
+// Reference conv2d (NCHW-FCHW) and (NHWC-HWCF)
 //===----------------------------------------------------------------------===//
 
 // Conversion from 4D indices in row major order to 1D index.
@@ -36,57 +36,219 @@ static int convert_to_1d_index(iree_hal_dim_t channels, iree_hal_dim_t height,
 static void reference_conv2d_f16_f16_f16_f16(
     iree_hal_dim_t n_size, iree_hal_dim_t c_size, iree_hal_dim_t h_size,
     iree_hal_dim_t w_size, iree_hal_dim_t f_size, iree_hal_dim_t kh_size,
-    iree_hal_dim_t kw_size, iree_hal_dim_t sh_size, iree_hal_dim_t sw_size,
-    iree_hal_dim_t dh_size, iree_hal_dim_t dw_size, iree_hal_dim_t oh_size,
-    iree_hal_dim_t ow_size, const uint16_t* input_data,
+    iree_hal_dim_t kw_size, iree_hal_dim_t layout, iree_hal_dim_t sh_size,
+    iree_hal_dim_t sw_size, iree_hal_dim_t dh_size, iree_hal_dim_t dw_size,
+    iree_hal_dim_t oh_size, iree_hal_dim_t ow_size, const uint16_t* input_data,
     const uint16_t* kernel_data, const uint16_t* acc_data,
     uint16_t* result_data, iree_hal_dim_t n, iree_hal_dim_t oc,
     iree_hal_dim_t oh, iree_hal_dim_t ow) {
-  iree_hal_dim_t out_idx =
-      convert_to_1d_index(f_size, oh_size, ow_size, n, oc, oh, ow);
-
-  float acc = acc_data ? iree_math_f16_to_f32(acc_data[out_idx]) : 0.f;
-
-  for (iree_hal_dim_t ic = 0; ic < c_size; ++ic) {
-    for (iree_hal_dim_t kh = 0; kh < kh_size; ++kh) {
-      for (iree_hal_dim_t kw = 0; kw < kw_size; ++kw) {
-        iree_hal_dim_t inp_idx = convert_to_1d_index(
-            c_size, h_size, w_size, n, ic, (oh * sh_size + kh * dh_size),
-            (ow * sw_size + kw * dw_size));
-        iree_hal_dim_t krnl_idx =
-            convert_to_1d_index(c_size, kh_size, kw_size, oc, ic, kh, kw);
-
-        acc += iree_math_f16_to_f32(input_data[inp_idx]) *
-               iree_math_f16_to_f32(kernel_data[krnl_idx]);
-      }
-    }
-  }
-  result_data[out_idx] = iree_math_f32_to_f16(acc);
+  if (layout == 0) {
+    // The layout of the output tensor is NxFxOHxOW.
+    iree_hal_dim_t out_idx =
+        convert_to_1d_index(f_size, oh_size, ow_size, n, oc, oh, ow);
+
+    float acc = acc_data ? iree_math_f16_to_f32(acc_data[out_idx]) : 0.f;
+
+    for (iree_hal_dim_t ic = 0; ic < c_size; ++ic) {
+      for (iree_hal_dim_t kh = 0; kh < kh_size; ++kh) {
+        for (iree_hal_dim_t kw = 0; kw < kw_size; ++kw) {
+          iree_hal_dim_t inp_idx = convert_to_1d_index(
+              c_size, h_size, w_size, n, ic, (oh * sh_size + kh * dh_size),
+              (ow * sw_size + kw * dw_size));
+          iree_hal_dim_t krnl_idx =
+              convert_to_1d_index(c_size, kh_size, kw_size, oc, ic, kh, kw);
+
+          acc += iree_math_f16_to_f32(input_data[inp_idx]) *
+                 iree_math_f16_to_f32(kernel_data[krnl_idx]);
+        }
+      }
+    }
+    result_data[out_idx] = iree_math_f32_to_f16(acc);
+  } else if (layout == 1) {
+    // The layout of the output tensor is NxOHxOWxF.
+    iree_hal_dim_t out_idx =
+        convert_to_1d_index(oh_size, ow_size, f_size, n, oh, ow, oc);
+
+    float acc = acc_data ? iree_math_f16_to_f32(acc_data[out_idx]) : 0.f;
+
+    for (iree_hal_dim_t kh = 0; kh < kh_size; ++kh) {
+      for (iree_hal_dim_t kw = 0; kw < kw_size; ++kw) {
+        for (iree_hal_dim_t ic = 0; ic < c_size; ++ic) {
+          iree_hal_dim_t inp_idx = convert_to_1d_index(
+              h_size, w_size, c_size, n, (oh * sh_size + kh * dh_size),
+              (ow * sw_size + kw * dw_size), ic);
+          iree_hal_dim_t krnl_idx =
+              convert_to_1d_index(kw_size, c_size, f_size, kh, kw, ic, oc);
+
+          acc += iree_math_f16_to_f32(input_data[inp_idx]) *
+                 iree_math_f16_to_f32(kernel_data[krnl_idx]);
+        }
+      }
+    }
+    result_data[out_idx] = iree_math_f32_to_f16(acc);
+  }
 }
 
-static void reference_conv2d_f32_f32_f32_f32(
-    iree_hal_dim_t n_size, iree_hal_dim_t c_size, iree_hal_dim_t h_size,
-    iree_hal_dim_t w_size, iree_hal_dim_t f_size, iree_hal_dim_t kh_size,
-    iree_hal_dim_t kw_size, iree_hal_dim_t sh_size, iree_hal_dim_t sw_size,
-    iree_hal_dim_t dh_size, iree_hal_dim_t dw_size, iree_hal_dim_t oh_size,
-    iree_hal_dim_t ow_size, const float* input_data, const float* kernel_data,
-    const float* acc_data, float* result_data, iree_hal_dim_t n,
-    iree_hal_dim_t oc, iree_hal_dim_t oh, iree_hal_dim_t ow) {
-  iree_hal_dim_t out_idx =
-      convert_to_1d_index(f_size, oh_size, ow_size, n, oc, oh, ow);
-
-  float acc = acc_data ? acc_data[out_idx] : 0;
-
-  for (iree_hal_dim_t ic = 0; ic < c_size; ++ic) {
-    for (iree_hal_dim_t kh = 0; kh < kh_size; ++kh) {
-      for (iree_hal_dim_t kw = 0; kw < kw_size; ++kw) {
-        iree_hal_dim_t inp_idx = convert_to_1d_index(
-            c_size, h_size, w_size, n, ic, (oh * sh_size + kh * dh_size),
-            (ow * sw_size + kw * dw_size));
-        iree_hal_dim_t krnl_idx =
-            convert_to_1d_index(c_size, kh_size, kw_size, oc, ic, kh, kw);
-
-        acc += input_data[inp_idx] * kernel_data[krnl_idx];
-      }
-    }
-  }
-  result_data[out_idx] = acc;
-}
+// [f32 <= f16 * f16 + f32]
+static void reference_conv2d_f16_f16_f32_f32(
+    iree_hal_dim_t n_size, iree_hal_dim_t c_size, iree_hal_dim_t h_size,
+    iree_hal_dim_t w_size, iree_hal_dim_t f_size, iree_hal_dim_t kh_size,
+    iree_hal_dim_t kw_size, iree_hal_dim_t layout, iree_hal_dim_t sh_size,
+    iree_hal_dim_t sw_size, iree_hal_dim_t dh_size, iree_hal_dim_t dw_size,
+    iree_hal_dim_t oh_size, iree_hal_dim_t ow_size, const uint16_t* input_data,
+    const uint16_t* kernel_data, const float* acc_data, float* result_data,
+    iree_hal_dim_t n, iree_hal_dim_t oc, iree_hal_dim_t oh, iree_hal_dim_t ow) {
+  if (layout == 0) {
+    // The layout of the output tensor is NxFxOHxOW.
+    iree_hal_dim_t out_idx =
+        convert_to_1d_index(f_size, oh_size, ow_size, n, oc, oh, ow);
+
+    float acc = acc_data ? acc_data[out_idx] : 0.f;
+
+    for (iree_hal_dim_t ic = 0; ic < c_size; ++ic) {
+      for (iree_hal_dim_t kh = 0; kh < kh_size; ++kh) {
+        for (iree_hal_dim_t kw = 0; kw < kw_size; ++kw) {
+          iree_hal_dim_t inp_idx = convert_to_1d_index(
+              c_size, h_size, w_size, n, ic, (oh * sh_size + kh * dh_size),
+              (ow * sw_size + kw * dw_size));
+          iree_hal_dim_t krnl_idx =
+              convert_to_1d_index(c_size, kh_size, kw_size, oc, ic, kh, kw);
+
+          acc += iree_math_f16_to_f32(input_data[inp_idx]) *
+                 iree_math_f16_to_f32(kernel_data[krnl_idx]);
+        }
+      }
+    }
+    result_data[out_idx] = acc;
+  } else if (layout == 1) {
+    // The layout of the output tensor is NxOHxOWxF.
+    iree_hal_dim_t out_idx =
+        convert_to_1d_index(oh_size, ow_size, f_size, n, oh, ow, oc);
+
+    float acc = acc_data ? acc_data[out_idx] : 0.f;
+
+    for (iree_hal_dim_t kh = 0; kh < kh_size; ++kh) {
+      for (iree_hal_dim_t kw = 0; kw < kw_size; ++kw) {
+        for (iree_hal_dim_t ic = 0; ic < c_size; ++ic) {
+          iree_hal_dim_t inp_idx = convert_to_1d_index(
+              h_size, w_size, c_size, n, (oh * sh_size + kh * dh_size),
+              (ow * sw_size + kw * dw_size), ic);
+          iree_hal_dim_t krnl_idx =
+              convert_to_1d_index(kw_size, c_size, f_size, kh, kw, ic, oc);
+
+          acc += iree_math_f16_to_f32(input_data[inp_idx]) *
+                 iree_math_f16_to_f32(kernel_data[krnl_idx]);
+        }
+      }
+    }
+    result_data[out_idx] = acc;
+  }
+}
+
+// [i32 <= i8 * i8 + i32]
+static void reference_conv2d_i8_i8_i32_i32(
+    iree_hal_dim_t n_size, iree_hal_dim_t c_size, iree_hal_dim_t h_size,
+    iree_hal_dim_t w_size, iree_hal_dim_t f_size, iree_hal_dim_t kh_size,
+    iree_hal_dim_t kw_size, iree_hal_dim_t layout, iree_hal_dim_t sh_size,
+    iree_hal_dim_t sw_size, iree_hal_dim_t dh_size, iree_hal_dim_t dw_size,
+    iree_hal_dim_t oh_size, iree_hal_dim_t ow_size, const int8_t* input_data,
+    const int8_t* kernel_data, const int32_t* acc_data, int32_t* result_data,
+    iree_hal_dim_t n, iree_hal_dim_t oc, iree_hal_dim_t oh, iree_hal_dim_t ow) {
+  if (layout == 0) {
+    // The layout of the output tensor is NxFxOHxOW.
+    iree_hal_dim_t out_idx =
+        convert_to_1d_index(f_size, oh_size, ow_size, n, oc, oh, ow);
+
+    int32_t acc = acc_data ? acc_data[out_idx] : 0;
+
+    for (iree_hal_dim_t ic = 0; ic < c_size; ++ic) {
+      for (iree_hal_dim_t kh = 0; kh < kh_size; ++kh) {
+        for (iree_hal_dim_t kw = 0; kw < kw_size; ++kw) {
+          iree_hal_dim_t inp_idx = convert_to_1d_index(
+              c_size, h_size, w_size, n, ic, (oh * sh_size + kh * dh_size),
+              (ow * sw_size + kw * dw_size));
+          iree_hal_dim_t krnl_idx =
+              convert_to_1d_index(c_size, kh_size, kw_size, oc, ic, kh, kw);
+
+          int8_t lhs_value = input_data[inp_idx];
+          int8_t rhs_value = kernel_data[krnl_idx];
+          acc += (int32_t)lhs_value * (int32_t)rhs_value;
+        }
+      }
+    }
+    result_data[out_idx] = acc;
+  } else if (layout == 1) {
+    // The layout of the output tensor is NxOHxOWxF.
+    iree_hal_dim_t out_idx =
+        convert_to_1d_index(oh_size, ow_size, f_size, n, oh, ow, oc);
+
+    int32_t acc = acc_data ? acc_data[out_idx] : 0;
+
+    for (iree_hal_dim_t kh = 0; kh < kh_size; ++kh) {
+      for (iree_hal_dim_t kw = 0; kw < kw_size; ++kw) {
+        for (iree_hal_dim_t ic = 0; ic < c_size; ++ic) {
+          iree_hal_dim_t inp_idx = convert_to_1d_index(
+              h_size, w_size, c_size, n, (oh * sh_size + kh * dh_size),
+              (ow * sw_size + kw * dw_size), ic);
+          iree_hal_dim_t krnl_idx =
+              convert_to_1d_index(kw_size, c_size, f_size, kh, kw, ic, oc);
+
+          int8_t lhs_value = input_data[inp_idx];
+          int8_t rhs_value = kernel_data[krnl_idx];
+          acc += (int32_t)lhs_value * (int32_t)rhs_value;
+        }
+      }
+    }
+    result_data[out_idx] = acc;
+  }
+}
+
+static void reference_conv2d_f32_f32_f32_f32(
+    iree_hal_dim_t n_size, iree_hal_dim_t c_size, iree_hal_dim_t h_size,
+    iree_hal_dim_t w_size, iree_hal_dim_t f_size, iree_hal_dim_t kh_size,
+    iree_hal_dim_t kw_size, iree_hal_dim_t layout, iree_hal_dim_t sh_size,
+    iree_hal_dim_t sw_size, iree_hal_dim_t dh_size, iree_hal_dim_t dw_size,
+    iree_hal_dim_t oh_size, iree_hal_dim_t ow_size, const float* input_data,
+    const float* kernel_data, const float* acc_data, float* result_data,
+    iree_hal_dim_t n, iree_hal_dim_t oc, iree_hal_dim_t oh, iree_hal_dim_t ow) {
+  if (layout == 0) {
+    // The layout of the output tensor is NxFxOHxOW.
+    iree_hal_dim_t out_idx =
+        convert_to_1d_index(f_size, oh_size, ow_size, n, oc, oh, ow);
+
+    float acc = acc_data ? acc_data[out_idx] : 0;
+
+    for (iree_hal_dim_t ic = 0; ic < c_size; ++ic) {
+      for (iree_hal_dim_t kh = 0; kh < kh_size; ++kh) {
+        for (iree_hal_dim_t kw = 0; kw < kw_size; ++kw) {
+          iree_hal_dim_t inp_idx = convert_to_1d_index(
+              c_size, h_size, w_size, n, ic, (oh * sh_size + kh * dh_size),
+              (ow * sw_size + kw * dw_size));
+          iree_hal_dim_t krnl_idx =
+              convert_to_1d_index(c_size, kh_size, kw_size, oc, ic, kh, kw);
+
+          acc += input_data[inp_idx] * kernel_data[krnl_idx];
+        }
+      }
+    }
+    result_data[out_idx] = acc;
+  } else if (layout == 1) {
+    // The layout of the output tensor is NxOHxOWxF.
+    iree_hal_dim_t out_idx =
+        convert_to_1d_index(oh_size, ow_size, f_size, n, oh, ow, oc);
+
+    float acc = acc_data ? acc_data[out_idx] : 0;
+
+    for (iree_hal_dim_t kh = 0; kh < kh_size; ++kh) {
+      for (iree_hal_dim_t kw = 0; kw < kw_size; ++kw) {
+        for (iree_hal_dim_t ic = 0; ic < c_size; ++ic) {
+          iree_hal_dim_t inp_idx = convert_to_1d_index(
+              h_size, w_size, c_size, n, (oh * sh_size + kh * dh_size),
+              (ow * sw_size + kw * dw_size), ic);
+          iree_hal_dim_t krnl_idx =
+              convert_to_1d_index(kw_size, c_size, f_size, kh, kw, ic, oc);
+
+          acc += input_data[inp_idx] * kernel_data[krnl_idx];
+        }
+      }
+    }
+    result_data[out_idx] = acc;
+  }
+}
 
@@ -97,28 +259,45 @@
 static iree_status_t reference_conv2d_element(
     iree_hal_dim_t n_size, iree_hal_dim_t c_size, iree_hal_dim_t h_size,
     iree_hal_dim_t w_size, iree_hal_dim_t f_size, iree_hal_dim_t kh_size,
-    iree_hal_dim_t kw_size, iree_hal_dim_t sh_size, iree_hal_dim_t sw_size,
-    iree_hal_dim_t dh_size, iree_hal_dim_t dw_size, iree_hal_dim_t oh_size,
-    iree_hal_dim_t ow_size, iree_hal_element_type_t input_type,
-    iree_hal_element_type_t kernel_type, iree_hal_element_type_t acc_type,
-    void* input_data, void* kernel_data, void* acc_data, void* result_data,
-    iree_hal_dim_t n, iree_hal_dim_t oc, iree_hal_dim_t oh, iree_hal_dim_t ow) {
+    iree_hal_dim_t kw_size, iree_hal_dim_t layout, iree_hal_dim_t sh_size,
+    iree_hal_dim_t sw_size, iree_hal_dim_t dh_size, iree_hal_dim_t dw_size,
+    iree_hal_dim_t oh_size, iree_hal_dim_t ow_size,
+    iree_hal_element_type_t input_type, iree_hal_element_type_t kernel_type,
+    iree_hal_element_type_t acc_type, void* input_data, void* kernel_data,
+    void* acc_data, void* result_data, iree_hal_dim_t n, iree_hal_dim_t oc,
+    iree_hal_dim_t oh, iree_hal_dim_t ow) {
   if (input_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32 &&
       kernel_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32 &&
      acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
     reference_conv2d_f32_f32_f32_f32(
-        n_size, c_size, h_size, w_size, f_size, kh_size, kw_size, sh_size,
-        sw_size, dh_size, dw_size, oh_size, ow_size, (const float*)input_data,
-        (const float*)kernel_data, (const float*)acc_data, (float*)result_data,
-        n, oc, oh, ow);
+        n_size, c_size, h_size, w_size, f_size, kh_size, kw_size, layout,
+        sh_size, sw_size, dh_size, dw_size, oh_size, ow_size,
+        (const float*)input_data, (const float*)kernel_data,
+        (const float*)acc_data, (float*)result_data, n, oc, oh, ow);
   } else if (input_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16 &&
              kernel_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16 &&
             acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16) {
     reference_conv2d_f16_f16_f16_f16(
-        n_size, c_size, h_size, w_size, f_size, kh_size, kw_size, sh_size,
-        sw_size, dh_size, dw_size, oh_size, ow_size,
+        n_size, c_size, h_size, w_size, f_size, kh_size, kw_size, layout,
+        sh_size, sw_size, dh_size, dw_size, oh_size, ow_size,
         (const uint16_t*)input_data, (const uint16_t*)kernel_data,
         (const uint16_t*)acc_data, (uint16_t*)result_data, n, oc, oh, ow);
+  } else if (input_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16 &&
+             kernel_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16 &&
+             acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
+    reference_conv2d_f16_f16_f32_f32(
+        n_size, c_size, h_size, w_size, f_size, kh_size, kw_size, layout,
+        sh_size, sw_size, dh_size, dw_size, oh_size, ow_size,
+        (const uint16_t*)input_data, (const uint16_t*)kernel_data,
+        (const float*)acc_data, (float*)result_data, n, oc, oh, ow);
+  } else if (input_type == IREE_HAL_ELEMENT_TYPE_INT_8 &&
+             kernel_type == IREE_HAL_ELEMENT_TYPE_INT_8 &&
+             acc_type == IREE_HAL_ELEMENT_TYPE_INT_32) {
+    reference_conv2d_i8_i8_i32_i32(
+        n_size, c_size, h_size, w_size, f_size, kh_size, kw_size, layout,
+        sh_size, sw_size, dh_size, dw_size, oh_size, ow_size,
+        (const int8_t*)input_data, (const int8_t*)kernel_data,
+        (const int32_t*)acc_data, (int32_t*)result_data, n, oc, oh, ow);
   } else {
     return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                             "unhandled combination of element types in conv2d");
@@ -141,8 +320,8 @@ static iree_hal_dim_t out_shape_calc(iree_hal_dim_t i_shape,
 static iree_status_t reference_conv2d(
     iree_hal_dim_t n_size, iree_hal_dim_t c_size, iree_hal_dim_t h_size,
     iree_hal_dim_t w_size, iree_hal_dim_t f_size, iree_hal_dim_t kh_size,
-    iree_hal_dim_t kw_size, iree_hal_dim_t sh_size, iree_hal_dim_t sw_size,
-    iree_hal_dim_t dh_size, iree_hal_dim_t dw_size,
+    iree_hal_dim_t kw_size, iree_hal_dim_t layout, iree_hal_dim_t sh_size,
+    iree_hal_dim_t sw_size, iree_hal_dim_t dh_size, iree_hal_dim_t dw_size,
     iree_hal_element_type_t input_type, iree_hal_element_type_t kernel_type,
     iree_hal_element_type_t acc_type, iree_byte_span_t input_contents,
     iree_byte_span_t kernel_contents, iree_byte_span_t acc_contents,
@@ -156,20 +335,43 @@ static iree_status_t reference_conv2d(
   iree_hal_dim_t oh_size = out_shape_calc(h_size, kh_size, sh_size, dh_size);
   iree_hal_dim_t ow_size = out_shape_calc(w_size, kw_size, sw_size, dw_size);
 
-  for (iree_hal_dim_t n = 0; n < n_size; ++n) {
-    for (iree_hal_dim_t oc = 0; oc < f_size; ++oc) {
-      for (iree_hal_dim_t oh = 0; oh < oh_size; ++oh) {
-        for (iree_hal_dim_t ow = 0; ow < ow_size; ++ow) {
-          IREE_RETURN_AND_END_ZONE_IF_ERROR(
-              z0, reference_conv2d_element(
-                      n_size, c_size, h_size, w_size, f_size, kh_size, kw_size,
-                      sh_size, sw_size, dh_size, dw_size, oh_size, ow_size,
-                      input_type, kernel_type, acc_type, input_contents.data,
-                      kernel_contents.data, acc_contents.data,
-                      result_contents.data, n, oc, oh, ow));
-        }
-      }
-    }
-  }
+  if (layout == 0) {
+    for (iree_hal_dim_t n = 0; n < n_size; ++n) {
+      for (iree_hal_dim_t oc = 0; oc < f_size; ++oc) {
+        for (iree_hal_dim_t oh = 0; oh < oh_size; ++oh) {
+          for (iree_hal_dim_t ow = 0; ow < ow_size; ++ow) {
+            IREE_RETURN_AND_END_ZONE_IF_ERROR(
+                z0,
+                reference_conv2d_element(
+                    n_size, c_size, h_size, w_size, f_size, kh_size, kw_size,
+                    layout, sh_size, sw_size, dh_size, dw_size, oh_size,
+                    ow_size, input_type, kernel_type, acc_type,
+                    input_contents.data, kernel_contents.data,
+                    acc_contents.data, result_contents.data, n, oc, oh, ow));
+          }
+        }
+      }
+    }
+  } else if (layout == 1) {
+    for (iree_hal_dim_t n = 0; n < n_size; ++n) {
+      for (iree_hal_dim_t oh = 0; oh < oh_size; ++oh) {
+        for (iree_hal_dim_t ow = 0; ow < ow_size; ++ow) {
+          for (iree_hal_dim_t oc = 0; oc < f_size; ++oc) {
+            IREE_RETURN_AND_END_ZONE_IF_ERROR(
+                z0,
+                reference_conv2d_element(
+                    n_size, c_size, h_size, w_size, f_size, kh_size, kw_size,
+                    layout, sh_size, sw_size, dh_size, dw_size, oh_size,
+                    ow_size, input_type, kernel_type, acc_type,
+                    input_contents.data, kernel_contents.data,
+                    acc_contents.data, result_contents.data, n, oc, oh, ow));
+          }
+        }
+      }
+    }
+  } else {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "unhandled conv2d layout");
+  }
 
   IREE_TRACE_ZONE_END(z0);
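The same reference computation, restated compactly in Python for the NHWC/HWCF case (layout == 1); this is an illustrative sketch of what the C++ above computes per output element, not test code:

```python
# One output element of an NHWC/HWCF conv2d, mirroring the C++ reference.
# inp is indexed [n][ih][iw][ic], krn is [kh][kw][ic][oc], acc is [n][oh][ow][oc].
def conv2d_nhwc_hwcf_element(inp, krn, acc, n, oh, ow, oc,
                             c_size, kh_size, kw_size, sh, sw, dh, dw):
    value = acc[n][oh][ow][oc]
    for kh in range(kh_size):
        for kw in range(kw_size):
            for ic in range(c_size):
                ih = oh * sh + kh * dh  # strided, dilated input row
                iw = ow * sw + kw * dw  # strided, dilated input column
                value += inp[n][ih][iw][ic] * krn[kh][kw][ic][oc]
    return value
```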
@@ -182,17 +384,18 @@
 
 typedef struct {
   iree_allocator_t host_allocator;
-  iree_hal_dim_t n;   // batch dim
-  iree_hal_dim_t c;   // input channels
-  iree_hal_dim_t h;   // input height
-  iree_hal_dim_t w;   // input width
-  iree_hal_dim_t f;   // output channels
-  iree_hal_dim_t kh;  // kernel height
-  iree_hal_dim_t kw;  // kernel width
-  iree_hal_dim_t sh;  // stride along height dim
-  iree_hal_dim_t sw;  // stride along width dim
-  iree_hal_dim_t dh;  // dilation along height dim
-  iree_hal_dim_t dw;  // dilation along width dim
+  iree_hal_dim_t n;       // batch dim
+  iree_hal_dim_t c;       // input channels
+  iree_hal_dim_t h;       // input height
+  iree_hal_dim_t w;       // input width
+  iree_hal_dim_t f;       // output channels
+  iree_hal_dim_t kh;      // kernel height
+  iree_hal_dim_t kw;      // kernel width
+  iree_hal_dim_t layout;  // conv layout: 0 = NCHW x FCHW (default); 1 = NHWC x HWCF
+  iree_hal_dim_t sh;      // stride along height dim
+  iree_hal_dim_t sw;      // stride along width dim
+  iree_hal_dim_t dh;      // dilation along height dim
+  iree_hal_dim_t dw;      // dilation along width dim
   iree_hal_element_type_t input_type;
   iree_hal_element_type_t kernel_type;
   iree_hal_element_type_t acc_type;
@@ -209,11 +412,12 @@ static void conv2d_results_deinitialize(conv2d_results_t* results);
 static iree_status_t conv2d_results_initialize(
     iree_hal_device_t* device, iree_hal_dim_t n_size, iree_hal_dim_t c_size,
     iree_hal_dim_t h_size, iree_hal_dim_t w_size, iree_hal_dim_t f_size,
-    iree_hal_dim_t kh_size, iree_hal_dim_t kw_size, iree_hal_dim_t sh_size,
-    iree_hal_dim_t sw_size, iree_hal_dim_t dh_size, iree_hal_dim_t dw_size,
-    iree_hal_buffer_view_t* input, iree_hal_buffer_view_t* kernel,
-    iree_hal_buffer_view_t* acc, iree_hal_buffer_view_t* result,
-    iree_allocator_t host_allocator, conv2d_results_t* out_results) {
+    iree_hal_dim_t kh_size, iree_hal_dim_t kw_size, iree_hal_dim_t layout,
+    iree_hal_dim_t sh_size, iree_hal_dim_t sw_size, iree_hal_dim_t dh_size,
+    iree_hal_dim_t dw_size, iree_hal_buffer_view_t* input,
+    iree_hal_buffer_view_t* kernel, iree_hal_buffer_view_t* acc,
+    iree_hal_buffer_view_t* result, iree_allocator_t host_allocator,
+    conv2d_results_t* out_results) {
   IREE_TRACE_ZONE_BEGIN(z0);
 
   memset(out_results, 0, sizeof(*out_results));
@@ -226,6 +430,7 @@
   out_results->f = f_size;
   out_results->kh = kh_size;
   out_results->kw = kw_size;
+  out_results->layout = layout;
   out_results->sh = sh_size;
   out_results->sw = sw_size;
   out_results->dh = dh_size;
@@ -340,13 +545,13 @@ static iree_status_t check_conv2d_results_impl(FILE* file,
   IREE_TRACE_ZONE_BEGIN(z0);
 
   IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0, reference_conv2d(results->n, results->c, results->h, results->w,
-                           results->f, results->kh, results->kw, results->sh,
-                           results->sw, results->dh, results->dw,
-                           results->input_type, results->acc_type,
-                           results->kernel_type, results->input_contents,
-                           results->kernel_contents, results->acc_contents,
-                           results->expected_contents, check_every));
+      z0,
+      reference_conv2d(
+          results->n, results->c, results->h, results->w, results->f,
+          results->kh, results->kw, results->layout, results->sh, results->sw,
+          results->dh, results->dw, results->input_type, results->kernel_type,
+          results->acc_type, results->input_contents, results->kernel_contents,
+          results->acc_contents, results->expected_contents, check_every));
 
   int count = 0;
@@ -466,7 +671,8 @@
     int32_t max = 0;
     iree_test_utils_get_min_max_for_element_type(
         callback_state.element_type, &min, &max);
-    uint32_t range = (max - min + 1);
+    // Divide by 4 to narrow the sampled range and keep accumulation stable.
+    uint32_t range = (max - min + 1) / 4;
     iree_host_size_t element_byte_count =
         iree_hal_element_dense_byte_count(callback_state.element_type);
     uint8_t* data_end = span.data + span.data_length;
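Why the narrower random range above helps: each output element of the `gpu_large` shape accumulates c * kh * kw products, so shrinking operand magnitude shrinks the worst-case partial sum quadratically and keeps f16 accumulation away from precision trouble. A back-of-the-envelope sketch (the actual min/max values come from `iree_test_utils_get_min_max_for_element_type`):

```python
# MACs per output element for the gpu_large shape: c=64, kh=kw=3.
macs = 64 * 3 * 3  # 576 multiply-accumulates
worst_case = lambda max_val: macs * max_val**2
print(worst_case(8), worst_case(2))  # 4x smaller values -> 16x smaller bound
```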
@@ -487,8 +693,9 @@
   Status CheckConv2dResults(
       const vm::ref<iree_hal_device_t> device, int64_t n, int64_t c, int64_t h,
-      int64_t w, int64_t f, int64_t kh, int64_t kw, int64_t sh, int64_t sw,
-      int64_t dh, int64_t dw, const vm::ref<iree_hal_buffer_view_t> input,
+      int64_t w, int64_t f, int64_t kh, int64_t kw, int64_t layout, int64_t sh,
+      int64_t sw, int64_t dh, int64_t dw,
+      const vm::ref<iree_hal_buffer_view_t> input,
       const vm::ref<iree_hal_buffer_view_t> kernel,
       const vm::ref<iree_hal_buffer_view_t> acc,
       const vm::ref<iree_hal_buffer_view_t> actual_result) {
@@ -496,9 +703,10 @@
     IREE_RETURN_IF_ERROR(conv2d_results_initialize(
         device.get(), (iree_hal_dim_t)n, (iree_hal_dim_t)c, (iree_hal_dim_t)h,
         (iree_hal_dim_t)w, (iree_hal_dim_t)f, (iree_hal_dim_t)kh,
-        (iree_hal_dim_t)kw, (iree_hal_dim_t)sh, (iree_hal_dim_t)sw,
-        (iree_hal_dim_t)dh, (iree_hal_dim_t)dw, input.get(), kernel.get(),
-        acc.get(), actual_result.get(), host_allocator_, &results));
+        (iree_hal_dim_t)kw, (iree_hal_dim_t)layout, (iree_hal_dim_t)sh,
+        (iree_hal_dim_t)sw, (iree_hal_dim_t)dh, (iree_hal_dim_t)dw, input.get(),
+        kernel.get(), acc.get(), actual_result.get(), host_allocator_,
+        &results));
     iree_status_t status = check_conv2d_results(stderr, &results);
     conv2d_results_deinitialize(&results);
     return status;