From e072477c011c1a651ff12e858f66b746b0637218 Mon Sep 17 00:00:00 2001 From: Alan Kelly Date: Tue, 28 Jan 2025 02:07:45 -0800 Subject: [PATCH] Allow fp16 re-write of pack-lh nodes PiperOrigin-RevId: 720489462 --- src/subgraph.c | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/src/subgraph.c b/src/subgraph.c index be3ff741acc..b33e16e6915 100644 --- a/src/subgraph.c +++ b/src/subgraph.c @@ -903,6 +903,7 @@ bool xnn_subgraph_rewrite_for_fp16(xnn_subgraph_t subgraph) case xnn_node_type_global_average_pooling_2d: case xnn_node_type_global_sum_pooling_2d: case xnn_node_type_max_pooling_2d: + case xnn_node_type_pack_lh: case xnn_node_type_softmax: case xnn_node_type_space_to_depth_2d: case xnn_node_type_static_constant_pad: @@ -961,6 +962,18 @@ bool xnn_subgraph_rewrite_for_fp16(xnn_subgraph_t subgraph) subgraph->values[node->inputs[2]].datatype == xnn_datatype_fp32) { subgraph->values[node->inputs[2]].fp16_compatible = true; } + } else if (subgraph->values[node->inputs[0]].datatype == + xnn_datatype_pfp32 && + subgraph->values[node->inputs[1]].datatype == + xnn_datatype_fp32 && + subgraph->values[node->outputs[0]].datatype == + xnn_datatype_fp32) { + subgraph->values[node->inputs[0]].fp16_compatible = true; + subgraph->values[node->outputs[0]].fp16_compatible = true; + if (node->num_inputs > 2 && + subgraph->values[node->inputs[2]].datatype == xnn_datatype_fp32) { + subgraph->values[node->inputs[2]].fp16_compatible = true; + } } else if (all_values_fp32(subgraph, node)) { subgraph->values[node->inputs[0]].fp16_compatible = true; subgraph->values[node->outputs[0]].fp16_compatible = true; @@ -1009,7 +1022,7 @@ bool xnn_subgraph_rewrite_for_fp16(xnn_subgraph_t subgraph) value->fp16_id = XNN_INVALID_VALUE_ID; value->fp32_id = XNN_INVALID_VALUE_ID; if (value->fp16_compatible) { - assert(value->datatype == xnn_datatype_fp32); + assert(value->datatype == xnn_datatype_fp32 || value->datatype == xnn_datatype_pfp32); if (xnn_value_is_static(value)) { assert(value->producer == XNN_INVALID_NODE_ID); const size_t fp16_size = xnn_tensor_get_size_by_id(subgraph, n) / 2 + XNN_EXTRA_BYTES; @@ -1111,8 +1124,18 @@ bool xnn_subgraph_rewrite_for_fp16(xnn_subgraph_t subgraph) value->first_consumer = XNN_INVALID_NODE_ID; xnn_log_debug("FP16 rewrite: created FP16 tensor #%" PRIu32 " for FP32 tensor #%" PRIu32, fp16_value->id, n); } else { - xnn_log_debug("FP16 rewrite: converted FP32 tensor #%" PRIu32 " to FP16", n); - value->datatype = xnn_datatype_fp16; + switch (value->datatype) { + case xnn_datatype_fp32: + xnn_log_debug("FP16 rewrite: converted FP32 tensor #%" PRIu32 " to FP16", n); + value->datatype = xnn_datatype_fp16; + break; + case xnn_datatype_pfp32: + xnn_log_debug("FP16 rewrite: converted PFP32 tensor #%" PRIu32 " to PFP16", n); + value->datatype = xnn_datatype_pfp16; + break; + default: + XNN_UNREACHABLE; + } } } }