Refactor LXU cache logic in TBE fwd training #1295
Closed
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Summary:
The LXU cache logic is in the critical path of the forward TBE kernel.
Even when the LXU cache is not used, the kernel still checks whether a
row should be fetched from the cache or HBM at runtime. The branching
logic should be harmless for the memory (subsystem) bound case.
However, it could add significant overhead if TBE is conditional
bound. (We have observed that FP16 weight type is generally compute
or conditional bound, while FP32 weight type is memory bound.)
This diff adds a static conditional in the forward TBE kernel to
enable/disable the LXU cache code path at compile time. At runtime,
the host selects the kernel with/without cache enabled based on
whether the LXU cache is present.
This diff also moves the conditional outside the D loop. It should
add a small benefit for the large D cases when cache is used.
Differential Revision: D39353035