Hi, thanks for your great project! I'm curious about the rematerialization (remat) mechanism in this project. Could you explain how it works when the option is turned on? For example, does it follow Megatron's approach, where all operators are recomputed during the backward pass to save memory?
If turned on, Alpa applies remat to each 'layer': it wraps all JaxprEqns of a layer's forward computation with JAX's remat_call_p. When JAX traces the computation and encounters remat_call_p, it automatically generates the corresponding recomputation in the backward pass. Since every equation belongs to a 'layer', the answer to

> will the mechanism follow with the Megatron that all operators will be re-computed during backward

is yes.
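To make this concrete, here is a minimal sketch in plain JAX (not Alpa's internal code) of what wrapping a layer's forward pass for rematerialization does: the intermediates are dropped after the forward pass and recomputed when the backward pass needs them. The function and variable names below are illustrative, not taken from Alpa.

```python
import jax
import jax.numpy as jnp

def layer_forward(params, x):
    # A toy "layer": these intermediates would normally be saved for backward.
    h = jnp.tanh(x @ params["w1"])
    return jnp.tanh(h @ params["w2"])

# Without remat: intermediates of layer_forward are stored for the backward pass.
def loss_no_remat(params, x):
    return jnp.sum(layer_forward(params, x) ** 2)

# With remat: jax.checkpoint (a.k.a. jax.remat) drops the intermediates and
# recomputes them during backward, trading extra FLOPs for lower peak memory --
# the same trade-off Megatron's activation recomputation makes.
def loss_with_remat(params, x):
    return jnp.sum(jax.checkpoint(layer_forward)(params, x) ** 2)

params = {"w1": jnp.ones((4, 4)), "w2": jnp.ones((4, 4))}
x = jnp.ones((2, 4))
g1 = jax.grad(loss_no_remat)(params, x)
g2 = jax.grad(loss_with_remat)(params, x)  # same gradients, lower activation memory
```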
A 'layer' can be assigned manually or generated automatically. To assign layers manually, add mark_pipeline_boundary() between two layers and set the layer_option of PipeshardParallel to a ManualLayerOption with remat_layer=True. Otherwise, set layer_option to AutoLayerOption to let Alpa's layer-clustering algorithm slice the computation into layers and apply remat accordingly. A sketch of both configurations is shown below.
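A hedged sketch of the two configurations described above; the toy layers, parameter names, and layer_num value are illustrative, and exact argument names may differ across Alpa versions.

```python
import alpa
import jax.numpy as jnp
from alpa import PipeshardParallel, ManualLayerOption, AutoLayerOption

# Two toy "layers"; a real model would go here instead.
def layer1_apply(params, x):
    return jnp.tanh(x @ params["w1"])

def layer2_apply(params, h):
    return h @ params["w2"]

# 1) Manual layers: mark the boundaries yourself and enable remat per marked layer.
manual_method = PipeshardParallel(
    num_micro_batches=16,
    layer_option=ManualLayerOption(remat_layer=True),
)

# 2) Automatic layers: let Alpa's layer-clustering algorithm slice the computation
#    into layers (layer_num is illustrative) and apply remat to each of them.
auto_method = PipeshardParallel(
    num_micro_batches=16,
    layer_option=AutoLayerOption(layer_num=2),
)

@alpa.parallelize(method=manual_method)
def train_step(params, batch):
    def loss_fn(params):
        h = layer1_apply(params, batch["x"])
        alpa.mark_pipeline_boundary()  # boundary between the two manual layers
        out = layer2_apply(params, h)
        return ((out - batch["y"]) ** 2).mean()

    # alpa.grad is used in place of jax.grad for pipeline-parallel training.
    return alpa.grad(loss_fn)(params)
```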
There are still some known issues related to remat + RNG that are work in progress: #535 and #592.