Skip to content

Commit

Permalink
Add simple multiply2 XNNPACK example
Browse files Browse the repository at this point in the history
This commit replaces the multiplication of tensors in the system
plugin example with a call to XNNPACK's multiply2. The plugin runs e2e
correctly.

Note: the CMake changes can definitely be improved. I just needed to
get the ball rolling.
  • Loading branch information
ramiro050 committed Oct 17, 2023
1 parent 9181525 commit 981e7ad
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 11 deletions.
12 changes: 12 additions & 0 deletions samples/custom_dispatch/cpu/plugin/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,16 @@ if(NOT IREE_TARGET_BACKEND_LLVM_CPU OR
return()
endif()

# Root directory of a prebuilt XNNPACK tree. The static archives are expected
# under <XNNPACK_DIR>/build/local and the public headers under
# <XNNPACK_DIR>/include. Cached as PATH (not STRING) so a relative value given
# on the command line is converted to an absolute path and GUIs offer a
# directory picker.
set(XNNPACK_DIR "" CACHE PATH "XNNPACK directory")

# With an empty XNNPACK_DIR the imported locations below collapse to paths
# rooted at "/", which fail much later with a confusing message. Flag it here,
# but only when the system plugin (the consumer of these targets) is enabled.
if(IREE_HAL_EXECUTABLE_PLUGIN_SYSTEM_LIBRARY AND NOT XNNPACK_DIR)
  message(WARNING
    "XNNPACK_DIR is not set; the custom dispatch system plugin sample links "
    "against XNNPACK and will not build. Pass -DXNNPACK_DIR=<path to XNNPACK>.")
endif()

# Prebuilt XNNPACK static library plus its public include directory.
# Paths are quoted so directories containing spaces or ';' stay one argument.
add_library(xnnpack STATIC IMPORTED)
set_target_properties(xnnpack PROPERTIES
  IMPORTED_LOCATION "${XNNPACK_DIR}/build/local/libXNNPACK.a"
  INTERFACE_INCLUDE_DIRECTORIES "${XNNPACK_DIR}/include"
)

# pthreadpool static library and headers, taken from inside the XNNPACK build
# tree (build/local/pthreadpool and build/local/pthreadpool-source).
add_library(pthreadpool STATIC IMPORTED)
set_target_properties(pthreadpool PROPERTIES
  IMPORTED_LOCATION "${XNNPACK_DIR}/build/local/pthreadpool/libpthreadpool.a"
  INTERFACE_INCLUDE_DIRECTORIES "${XNNPACK_DIR}/build/local/pthreadpool-source/include"
)

# system-library plugin mechanism using the system dynamic library loader.
if(IREE_HAL_EXECUTABLE_PLUGIN_SYSTEM_LIBRARY)

Expand All @@ -21,6 +31,8 @@ target_include_directories(iree_samples_custom_dispatch_cpu_system_plugin
${IREE_SOURCE_DIR}/runtime/src/
)

# Link XNNPACK and the libraries it needs into the plugin. PRIVATE because
# they are implementation details of the plugin, and because the keyword-less
# signature has legacy semantics.
# NOTE(review): no imported target named `cpuinfo` is defined in the visible
# part of this file, so it is presumably resolved by the linker as a plain
# library name — confirm it is found on the link path.
target_link_libraries(iree_samples_custom_dispatch_cpu_system_plugin
  PRIVATE
    xnnpack
    pthreadpool
    cpuinfo
)

# NOTE: this is only required because we want this sample to run on all
# platforms without needing to change the library name (libfoo.so/foo.dll).
set_target_properties(iree_samples_custom_dispatch_cpu_system_plugin
Expand Down
84 changes: 73 additions & 11 deletions samples/custom_dispatch/cpu/plugin/system_plugin.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@
// arbitrary threads concurrently. Be very careful and prefer standalone plugins
// instead except when debugging/profiling.

#include <assert.h>
#include <inttypes.h>
#include <stdio.h>
#include <xnnpack.h>

// The only header required from IREE:
#include "iree/hal/local/executable_plugin.h"
Expand Down Expand Up @@ -77,24 +79,84 @@ static int simple_mul_workgroup(void* params_ptr, void* context,
const uint64_t* restrict processor_data;
} params_t;
const params_t* params = (const params_t*)params_ptr;

enum xnn_status status;
const struct xnn_allocator* allocator = NULL;
status = xnn_initialize(allocator);
assert(status == xnn_status_success && "unable to initialize XNNPACK");

xnn_subgraph_t subgraph = NULL;
status =
xnn_create_subgraph(/*external_value_ids=*/3, /*flags=*/0, &subgraph);
assert(status == xnn_status_success && "unable to create subgraph");

const size_t dims[1] = {params->size};
uint32_t lhs_id = XNN_INVALID_VALUE_ID;
status = xnn_define_tensor_value(subgraph, /*datatype=*/xnn_datatype_fp32,
/*num_dims=*/1, /*dims=*/dims,
/*data=*/NULL,
/*external_id=*/0,
/*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT,
/*id_out=*/&lhs_id);
assert(status == xnn_status_success && "unable to define lhs input tensor");

uint32_t rhs_id = XNN_INVALID_VALUE_ID;
status = xnn_define_tensor_value(subgraph, /*datatype=*/xnn_datatype_fp32,
/*num_dims=*/1, /*dims=*/dims,
/*data=*/NULL,
/*external_id=*/1,
/*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT,
/*id_out=*/&rhs_id);
assert(status == xnn_status_success && "unable to define rhs input tensor");

uint32_t output_id = XNN_INVALID_VALUE_ID;
status = xnn_define_tensor_value(subgraph, xnn_datatype_fp32,
/*num_dims=*/1, dims, NULL,
/*external_id=*/2,
XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id);
assert(status == xnn_status_success && "unable to define output tensor");

status = xnn_define_multiply2(subgraph, -100.0f, 100.0f, lhs_id, rhs_id,
output_id, /*flags=*/0);
assert(status == xnn_status_success && "unable to define multiply2");

xnn_runtime_t runtime = NULL;
status = xnn_create_runtime(subgraph, &runtime);
assert(status == xnn_status_success && "unable to create runtime");
struct xnn_external_value lhs_external_value = {
lhs_id, (void*)&(params->binding0[params->binding0_offset])};
struct xnn_external_value rhs_external_value = {
rhs_id, (void*)&(params->binding1[params->binding1_offset])};
struct xnn_external_value output_external_value = {
output_id, (void*)&(params->binding2[params->binding2_offset])};
const struct xnn_external_value externals[3] = {
lhs_external_value, rhs_external_value, output_external_value};
status = xnn_setup_runtime(runtime, /*num_external_values=*/3, externals);
assert(status == xnn_status_success && "unable to setup runtime");

status = xnn_invoke_runtime(runtime);
assert(status == xnn_status_success && "unable to invoke runtime");

status = xnn_delete_runtime(runtime);
assert(status == xnn_status_success && "unable to delete runtime");
status = xnn_delete_subgraph(subgraph);
assert(status == xnn_status_success && "unable to delete subgraph");
status = xnn_deinitialize();
assert(status == xnn_status_success && "unable to deinitialize");

fprintf(plugin->file, "processor_id=%u\n", params->processor_id);
if (params->processor_data) {
fprintf(plugin->file, "processor_data[0]=%" PRIX64 "\n",
params->processor_data[0]);
}
// The operation `iree_codegen.ukernel.generic` always operates
// on a slice of the inputs to produce a slice of the output,
// so the loop here just needs to iterate from `0` to `size`,
// where `size` is the size of the slice to be executed by this call.
for (size_t i = 0; i < params->size; ++i) {
params->binding2[params->binding2_offset + i] =
params->binding0[params->binding0_offset + i] *
params->binding1[params->binding2_offset + i];
fprintf(plugin->file, "mul[%zu:%zu](%g * %g = %g)\n", params->tid, i,
params->binding0[params->binding0_offset + i],
params->binding1[params->binding1_offset + i],
params->binding2[params->binding2_offset + i]);
float curr_lhs = params->binding0[params->binding0_offset + i];
float curr_rhs = params->binding1[params->binding1_offset + i];
float curr_output = params->binding2[params->binding2_offset + i];
fprintf(plugin->file, "mul2[%zu:%zu](%g * %g = %g)\n", params->tid, i,
curr_lhs, curr_rhs, curr_output);
}

return 0;
}

Expand Down

0 comments on commit 981e7ad

Please sign in to comment.