Allow fp16 re-write of pack-lh nodes
PiperOrigin-RevId: 720489462
alankelly authored and xnnpack-bot committed Jan 29, 2025
1 parent d7f398e commit e072477
Showing 1 changed file with 26 additions and 3 deletions.
src/subgraph.c
@@ -903,6 +903,7 @@ bool xnn_subgraph_rewrite_for_fp16(xnn_subgraph_t subgraph)
       case xnn_node_type_global_average_pooling_2d:
       case xnn_node_type_global_sum_pooling_2d:
       case xnn_node_type_max_pooling_2d:
+      case xnn_node_type_pack_lh:
       case xnn_node_type_softmax:
       case xnn_node_type_space_to_depth_2d:
       case xnn_node_type_static_constant_pad:
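
For orientation: this first hunk adds pack_lh to the list of node types the fp16 rewrite pass will consider at all. A minimal sketch of that gating, using simplified stand-in names rather than XNNPACK's real (much longer) enum:

#include <stdbool.h>

/* Simplified stand-ins; the real switch in subgraph.c covers many more
 * node types. */
enum node_type {
  node_type_max_pooling_2d,
  node_type_pack_lh,
  node_type_softmax,
  node_type_unknown,
};

/* The rewrite pass only considers node types on an allow-list; this
 * commit adds pack_lh to that list. */
static bool node_type_is_fp16_rewritable(enum node_type type) {
  switch (type) {
    case node_type_max_pooling_2d:
    case node_type_pack_lh:  /* newly accepted by this commit */
    case node_type_softmax:
      return true;
    default:
      return false;
  }
}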
@@ -961,6 +962,18 @@ bool xnn_subgraph_rewrite_for_fp16(xnn_subgraph_t subgraph)
           subgraph->values[node->inputs[2]].datatype == xnn_datatype_fp32) {
         subgraph->values[node->inputs[2]].fp16_compatible = true;
       }
+    } else if (subgraph->values[node->inputs[0]].datatype ==
+                   xnn_datatype_pfp32 &&
+               subgraph->values[node->inputs[1]].datatype ==
+                   xnn_datatype_fp32 &&
+               subgraph->values[node->outputs[0]].datatype ==
+                   xnn_datatype_fp32) {
+      subgraph->values[node->inputs[0]].fp16_compatible = true;
+      subgraph->values[node->outputs[0]].fp16_compatible = true;
+      if (node->num_inputs > 2 &&
+          subgraph->values[node->inputs[2]].datatype == xnn_datatype_fp32) {
+        subgraph->values[node->inputs[2]].fp16_compatible = true;
+      }
     } else if (all_values_fp32(subgraph, node)) {
       subgraph->values[node->inputs[0]].fp16_compatible = true;
       subgraph->values[node->outputs[0]].fp16_compatible = true;
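
The second hunk is the substance of the change: a node whose first input is packed fp32 (pfp32, e.g. the output of a pack-lh node) can now be marked fp16-compatible when its weights and output are plain fp32, with an optional third input (bias) included if present. A self-contained sketch of that check, using pared-down stand-ins for XNNPACK's internal value/node structs, not the real definitions:

#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

/* Hypothetical, simplified stand-ins -- just enough to trace the new
 * pfp32 branch. */
enum datatype { DT_FP32, DT_FP16, DT_PFP32, DT_PFP16 };

struct value {
  enum datatype datatype;
  bool fp16_compatible;
};

struct node {
  uint32_t inputs[3];
  size_t num_inputs;
  uint32_t outputs[1];
};

/* Mirrors the added else-if: packed fp32 activations + fp32 weights +
 * fp32 output are all eligible for the fp16 rewrite. */
static void mark_pfp32_node(struct value* values, const struct node* node) {
  if (values[node->inputs[0]].datatype == DT_PFP32 &&
      values[node->inputs[1]].datatype == DT_FP32 &&
      values[node->outputs[0]].datatype == DT_FP32) {
    values[node->inputs[0]].fp16_compatible = true;
    values[node->outputs[0]].fp16_compatible = true;
    /* The optional bias input is rewritten too, if present and fp32. */
    if (node->num_inputs > 2 && values[node->inputs[2]].datatype == DT_FP32) {
      values[node->inputs[2]].fp16_compatible = true;
    }
  }
}

int main(void) {
  struct value values[4] = {
      {DT_PFP32, false},  /* packed activations from pack-lh */
      {DT_FP32, false},   /* weights */
      {DT_FP32, false},   /* bias */
      {DT_FP32, false},   /* output */
  };
  struct node node = {.inputs = {0, 1, 2}, .num_inputs = 3, .outputs = {3}};
  mark_pfp32_node(values, &node);
  for (int i = 0; i < 4; i++) {
    printf("value %d fp16_compatible=%d\n", i, values[i].fp16_compatible);
  }
  return 0;
}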
@@ -1009,7 +1022,7 @@ bool xnn_subgraph_rewrite_for_fp16(xnn_subgraph_t subgraph)
     value->fp16_id = XNN_INVALID_VALUE_ID;
     value->fp32_id = XNN_INVALID_VALUE_ID;
     if (value->fp16_compatible) {
-      assert(value->datatype == xnn_datatype_fp32);
+      assert(value->datatype == xnn_datatype_fp32 || value->datatype == xnn_datatype_pfp32);
       if (xnn_value_is_static(value)) {
         assert(value->producer == XNN_INVALID_NODE_ID);
         const size_t fp16_size = xnn_tensor_get_size_by_id(subgraph, n) / 2 + XNN_EXTRA_BYTES;
@@ -1111,8 +1124,18 @@ bool xnn_subgraph_rewrite_for_fp16(xnn_subgraph_t subgraph)
         value->first_consumer = XNN_INVALID_NODE_ID;
         xnn_log_debug("FP16 rewrite: created FP16 tensor #%" PRIu32 " for FP32 tensor #%" PRIu32, fp16_value->id, n);
       } else {
-        xnn_log_debug("FP16 rewrite: converted FP32 tensor #%" PRIu32 " to FP16", n);
-        value->datatype = xnn_datatype_fp16;
+        switch (value->datatype) {
+          case xnn_datatype_fp32:
+            xnn_log_debug("FP16 rewrite: converted FP32 tensor #%" PRIu32 " to FP16", n);
+            value->datatype = xnn_datatype_fp16;
+            break;
+          case xnn_datatype_pfp32:
+            xnn_log_debug("FP16 rewrite: converted PFP32 tensor #%" PRIu32 " to PFP16", n);
+            value->datatype = xnn_datatype_pfp16;
+            break;
+          default:
+            XNN_UNREACHABLE;
+        }
       }
     }
   }
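
The final hunk replaces the unconditional fp32-to-fp16 conversion with a switch, so packed values map to the packed half-precision type. A minimal standalone sketch of that mapping, again with stand-in names (assert here stands in for XNN_UNREACHABLE):

#include <assert.h>
#include <stdio.h>

/* Simplified stand-ins; XNNPACK's real datatype enum has many more
 * members. */
enum datatype { DT_FP32, DT_FP16, DT_PFP32, DT_PFP16 };

/* Mirrors the new switch: an in-place fp16-compatible value converts
 * fp32 -> fp16 or pfp32 -> pfp16; anything else is a logic error. */
static enum datatype convert_datatype(enum datatype dt) {
  switch (dt) {
    case DT_FP32:
      return DT_FP16;
    case DT_PFP32:
      return DT_PFP16;
    default:
      assert(!"fp16-compatible value must be fp32 or pfp32");
      return dt;
  }
}

int main(void) {
  printf("fp32  -> %d (expect DT_FP16  = %d)\n", convert_datatype(DT_FP32), DT_FP16);
  printf("pfp32 -> %d (expect DT_PFP16 = %d)\n", convert_datatype(DT_PFP32), DT_PFP16);
  return 0;
}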