Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proper handling of repeated fp16 conversion. #310

Closed

Conversation

LeoZDong
Copy link

@LeoZDong LeoZDong commented Dec 23, 2024

Description

When converting a model to fp16, PR #293 changes it such that we raise an error if the model already contains Cast nodes or fp16 types, while still allowing the user to force the conversion. This means that a user could repeatedly convert a model to fp16. However, the converter currently exhibits unexpected behaviors while handling repeated fp16 conversions, including creating orphaned and self-recurring nodes (see issue #276). This PR fixes the converter behavior when handling repeated fp16 conversions.

This PR is the combination of the following 3 fixes. I'm putting the 3 fixes in one PR to not lose context about the solution to issue #276. For readability in git blame, I'm happy to break it up into 3 separate PRs.

Fix 1: When the converter inserts Cast nodes, use UUID for naming.

Previously, we use a fixed string format to name the Cast nodes inserted by the converter, which results in name collision if we repeatedly run the converter. The converter currently does not handle name collisions at all, resulting in unexpected behavior (this is a separate issue that could be fixed).

This PR changes to using UUIDs when naming inserted Cast nodes to avoid name collision.

Fix 2: Add Cast output to value info block list.

The converter has a notion of "value info block list"; the converter changes every value_info in the input graph to fp16, except for those in the value info block list. Currently, we only add the value info from the output of "op block list" or "node block list" (i.e. ops and nodes that do not go through the fp16 conversion). However, when a graph contains Cast nodes, the output of Cast nodes cannot change types because otherwise it invalidates the very cast operation itself.

This PR adds the output of Cast nodes to value info block list.

NOTE: An alternative fix is to add Cast nodes to the DEFAULT_OP_BLOCK_LIST. I dislike it because it breaks the semantics that the default list contains "ops that are not supported for float16 in ONNX Runtime" according the manual here.

Fix 3: Only remove cast pairs with fp32 input types.

The converter currently has a heuristics to remove "unnecessary" cast node pairs (defined in remove_unnecessary_cast_node) with the following pattern:

  • From upstream --> cast_to_fp16 --> cast_to_fp32 --> downstream,
  • To upstream --> downstream.

However, the cast node pair is only truly "unnecessary" if the input type is fp16. This PR adds an additional guard to remove cast node pairs with the following pattern:

  • From upstream --fp32--> cast_to_fp16 --> cast_to_fp32 --> downstream,
  • To upstream --> downstream,

while remaining the following pattern intact:

  • upstream --fp16--> cast_to_fp16 --> cast_to_fp32 --> downstream.

Testing

Using the ONNX model submitted in #276, run a simple test script:

import onnx
from onnxconverter_common import float16

model = onnx.load("orphaned_cast.onnx")
model_fp16_converted = float16.convert_float_to_float16(model, check_fp16_ready=False)
model_fp16_converted_2x = float16.convert_float_to_float16(model_fp16_converted, check_fp16_ready=False)
onnx.save(model_fp16_converted, "orphaned_cast_converted.onnx")
onnx.save(model_fp16_converted_2x, "orphaned_cast_converted_2x.onnx")

Resulting graph before fix

repeated_conversion_before_fix

Resulting graph after fix

repeated_conversion_after_fix

Note on redundant cast nodes

Note that after the fix, repeated conversion of the same model to fp16 will result in redundant cast node chains (e.g.--fp32--> cast_to_fp16 --> cast_to_fp16 --> cast_to_fp16). This is a result of the remove_unnecessary_cast_node function using pattern matching heuristics instead of a true graph walking algorithm. While it is out of the scope of this PR, we could potentially implement an algorithm that inspects all "cast chains" in the graph and shorten them to a single "effective cast" operator.

Issues closed

Closes #276

@LeoZDong LeoZDong closed this Dec 23, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

convert to FP16 generate orphan and self-recurring nodes
1 participant