
Rematerialization Explanation Help #612

Closed
zhiqi-0 opened this issue Jul 18, 2022 · 2 comments


zhiqi-0 commented Jul 18, 2022

Hi, thanks for your great project!

I'm curious about the rematerialization (remat) mechanism in the project, could you help explain how this will work when the option turns on? For example, will the mechanism follow with the Megatron that all operators will be re-computed during backward to save memory?

ZYHowell (Collaborator) commented

If turned on, Alpa does remat for each 'layer': it wraps all JaxprEqns of a layer's forward pass with JAX's remat_call_p. When JAX traces the computation and encounters remat_call_p, it automatically generates a rematerialization part in the backward computation. Since every equation belongs to a 'layer', the answer to

will the mechanism follow with the Megatron that all operators will be re-computed during backward

is yes.
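For reference, here is a minimal standalone sketch of what JAX's rematerialization does for a single function, using the public jax.checkpoint / jax.remat API (Alpa applies the same mechanism per layer via remat_call_p). The toy layer function, shapes, and names are made up for illustration:

```python
import jax
import jax.numpy as jnp

def layer(x, w):
    # With remat, the intermediate activations of this function are not
    # saved for the backward pass; they are recomputed when needed.
    return jnp.tanh(x @ w)

remat_layer = jax.checkpoint(layer)  # jax.remat is an alias of jax.checkpoint

def loss(x, w):
    return jnp.sum(remat_layer(x, w))

x = jnp.ones((4, 8))
w = jnp.ones((8, 8))
# The backward pass re-runs `layer` to rebuild its activations instead of
# keeping them in memory.
grads = jax.grad(loss, argnums=1)(x, w)
```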

A 'layer' can be assigned manually or generated automatically. To assign layers manually, add mark_pipeline_boundary() between two layers and set the layer_option of PipeshardParallel to a ManualLayerOption with remat_layer=True. Otherwise, set layer_option to an AutoLayerOption to let Alpa's layer-clustering algorithm slice the computation into layers and apply remat accordingly. A sketch of both options follows.
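A hedged sketch of the manual variant, assuming a toy two-layer model (the loss function, parameter names, and batch layout below are placeholders; PipeshardParallel, ManualLayerOption, AutoLayerOption, and mark_pipeline_boundary are the Alpa names mentioned above, and the AutoLayerOption arguments in the comment are an assumption):

```python
import alpa
import jax.numpy as jnp

def loss_fn(params, batch):
    x = jnp.tanh(batch["x"] @ params["w1"])
    alpa.mark_pipeline_boundary()   # manual boundary between layer 1 and layer 2
    x = jnp.tanh(x @ params["w2"])
    return jnp.mean((x - batch["y"]) ** 2)

method = alpa.PipeshardParallel(
    # Remat every manually marked layer during the backward pass.
    layer_option=alpa.ManualLayerOption(remat_layer=True),
)
# Alternative: drop the manual boundaries and let Alpa slice layers itself,
# e.g. layer_option=alpa.AutoLayerOption(layer_num=2); remat is then applied
# to the automatically clustered layers.

@alpa.parallelize(method=method)
def train_step(params, batch):
    grads = alpa.grad(loss_fn)(params, batch)
    # The optimizer update would go here; omitted in this sketch.
    return grads
```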

There are some known issues related to remat + RNG that are still WIP: #535 and #592

zhiqi-0 (Author) commented Jul 19, 2022

Thanks, this helped a lot!

zhiqi-0 closed this as completed Jul 19, 2022