Conversion from FP32 to Mixed Precision Models #14584
Comments
Hi @anirudh2290, it's really good to have AMP supported for symbolic models. One thing I want to mention: we want AMP to support all kinds of low-precision floats, not only FP16 but also BF16.
@ZhennanQin Thanks for the suggestion! Yes, the plan is to keep the NNVM pass target_dtype agnostic and allow easy extension for BF16. In the future, the expectation is that the amp_cast and amp_multicast ops will support BF16 in addition to FP16, and there will be corresponding op lists for BF16 like the existing ones (https://github.com/apache/incubator-mxnet/pull/14173/files#diff-b79bfa3e02355c43ca5b195ef67172a5R21).
Hi, also, I would love to know if you could quickly guide me on how to cast Gluon models to FP16.
@AnaRhisT94 Yes, you can do that if you want to run your entire model in FP16 precision: cast your inputs to FP16 and cast your params to FP16. The FP16 tutorial shows how to do this: http://mxnet.incubator.apache.org/versions/master/faq/float16.html?highlight=mixed#using-the-gluon-api . This AMP feature, in contrast, targets situations where you want to run specific layers in FP16 while keeping others (such as softmax) in FP32, and want control over which layers run in FP16 versus FP32.
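For reference, a minimal sketch of full-FP16 Gluon inference along the lines of that tutorial; the ResNet-50 model and GPU context below are just placeholders, and FP16 compute generally assumes a GPU:

```python
import mxnet as mx
from mxnet.gluon.model_zoo import vision

ctx = mx.gpu(0)                                    # FP16 compute needs a GPU
net = vision.resnet50_v1(pretrained=True, ctx=ctx)

# Cast every parameter of the block to FP16, then feed FP16 inputs.
net.cast('float16')
data = mx.nd.random.uniform(shape=(1, 3, 224, 224), ctx=ctx).astype('float16')
out = net(data)
print(out.dtype)                                   # numpy.float16
```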
API Addition
Users want to bring an FP32 model and convert it to a mixed precision model to run inference on it. They want to use the model zoo to convert pretrained models in Python and other frontends. This can be done with Gluon models today by casting the inputs and the blocks, but the same cannot be done for symbolic models (json and params). Proposing to add an API to convert FP32 models to FP16.
Considering the recent AMP work in progress in #14173, I think we should add an FP32-to-mixed-precision conversion API under the AMP namespace:
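A possible shape for such an API (a sketch only; `convert_model`, its exact signature, and the `mxnet.contrib.amp` location are assumptions here, with the parameter names taken from the override lists mentioned below):

```python
from mxnet.contrib import amp

# Hypothetical sketch of the proposed conversion API: take an FP32 symbolic
# model (sym, arg_params, aux_params) and return a mixed precision model.
sym_mp, arg_params_mp, aux_params_mp = amp.convert_model(
    sym, arg_params, aux_params,
    target_dtype="float16",
    target_precision_ops=None,    # ops to force to the target (FP16) dtype
    original_precision_ops=None,  # ops to keep in FP32 (e.g. softmax)
    widest_precision_ops=None)    # ops cast to the widest type among inputs
```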
With `target_precision_ops`, `original_precision_ops` and `widest_precision_ops`, users should be able to override the defaults in the AMP lists.
Backend Changes
Additionally, add an NNVM pass in the backend. By default it would use the AMP lists for FP16 ops, FP32 ops and widest-type casts to decide which inputs should be FP16 or FP32. The pass will traverse the graph and insert amp_cast and amp_multicast layers between FP16 and FP32 ops.
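As an illustration of the intended output of the pass (a hand-written sketch, assuming the amp_cast operator introduced in #14173): a convolution runs in FP16 while softmax stays in FP32, with casts inserted at the boundary.

```python
import mxnet as mx

data = mx.sym.Variable('data')                     # FP32 input
data16 = mx.sym.amp_cast(data, dtype='float16')    # cast for an FP16 op
conv = mx.sym.Convolution(data16, num_filter=64,
                          kernel=(3, 3), name='conv')
conv32 = mx.sym.amp_cast(conv, dtype='float32')    # cast back for an FP32 op
out = mx.sym.SoftmaxOutput(conv32, name='softmax')
```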
I'm planning to start working on the PoC unless someone is already working on this.
@ptrendx @DickJC123 @pengzhao-intel @ZhennanQin @eric-haibin-lin @Caenorst
EDIT: Proposal posted on dev list: https://cwiki.apache.org/confluence/display/MXNET/Conversion+from+FP32+to+Mixed+Precision+Models