diff --git a/iree/task/task.c b/iree/task/task.c index 8a8110fbe9f7..a96c4400b1c4 100644 --- a/iree/task/task.c +++ b/iree/task/task.c @@ -522,7 +522,6 @@ void iree_task_dispatch_issue_sharded( iree_task_dispatch_shard_state_t* shared_state = &dispatch_task->shared.shard_state; - shared_state->dispatch_task = dispatch_task; // Fetch the workgroup count (directly or indirectly). if (dispatch_task->header.flags & IREE_TASK_FLAG_DISPATCH_INDIRECT) { @@ -754,6 +753,7 @@ void iree_task_dispatch_shard_initialize( iree_task_initialize(IREE_TASK_TYPE_DISPATCH_SHARD, dispatch_task->header.scope, &out_task->header); iree_task_set_completion_task(&out_task->header, &dispatch_task->header); + out_task->dispatch_task = dispatch_task; out_task->shared_state = shared_state; } @@ -778,12 +778,12 @@ iree_status_t iree_task_dispatch_shard_execute( iree_task_submission_t* pending_submission) { IREE_TRACE_ZONE_BEGIN(z0); - iree_task_dispatch_shard_state_t* shared_state = task->shared_state; - iree_task_dispatch_t* dispatch_task = shared_state->dispatch_task; + iree_task_dispatch_t* dispatch_task = task->dispatch_task; IREE_TRACE_ZONE_SET_COLOR( z0, iree_math_ptr_to_xrgb(dispatch_task->closure.user_context)); // Prepare context shared for all tiles in the shard. + iree_task_dispatch_shard_state_t* shared_state = task->shared_state; iree_task_tile_context_t tile_context; memcpy(&tile_context.workgroup_size, dispatch_task->workgroup_size, sizeof(tile_context.workgroup_size)); diff --git a/iree/task/task.h b/iree/task/task.h index 0010d3a52f24..ddb841b9c2db 100644 --- a/iree/task/task.h +++ b/iree/task/task.h @@ -445,9 +445,6 @@ typedef struct iree_task_dispatch_s iree_task_dispatch_t; // Shared state for all shards processing a dispatch. typedef iree_alignas(iree_max_align_t) struct { - // Direct reference to the parent dispatch that all shards are processing. - iree_task_dispatch_t* dispatch_task; - // The tail tile index; the next reservation will start from here. iree_atomic_int32_t tile_index; @@ -649,6 +646,9 @@ typedef iree_alignas(iree_max_align_t) struct { // Task header: implementation detail, do not use. iree_task_t header; + // The root dispatch task that this shard is a part of. + iree_task_dispatch_t* dispatch_task; + // Active dispatch progress shared across all shards. // Each shard will be read/modify/writing this and there's likely to be // contention.