From 2e81cd880616a7e931261dbdffd8ca6018a4c416 Mon Sep 17 00:00:00 2001
From: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com>
Date: Fri, 13 Sep 2024 19:40:33 +0800
Subject: [PATCH] Implemented cpu offload (#197)

* Implemented cpu offload for other frameworks

* polish example code

* Added test for low mem settings.

* polish

* format

* fix arg

* update pipeline and test

---------

Co-authored-by: ExtremeViscent
---
 examples/latte/sample.py                      | 10 ++++++++
 examples/open_sora/sample.py                  | 10 ++++++++
 examples/open_sora_plan/sample.py             | 12 ++++++++-
 tests/examples/test_sample.py                 | 23 +++++++++++------
 tests/pipelines/cogvideox/test_cogvideox.py   | 10 ++++++++
 tests/pipelines/latte/test_latte.py           | 10 ++++++++
 tests/pipelines/open_sora/test_open_sora.py   | 10 ++++++++
 .../open_sora_plan/test_open_sora_plan.py     | 10 ++++++++
 videosys/core/engine.py                       |  6 ++---
 .../autoencoders/autoencoder_kl_open_sora.py  |  3 +++
 .../pipelines/cogvideox/pipeline_cogvideox.py |  2 +-
 videosys/pipelines/latte/pipeline_latte.py    | 14 +++++++++--
 .../pipelines/open_sora/pipeline_open_sora.py | 25 ++++++++++++++-----
 .../open_sora_plan/pipeline_open_sora_plan.py | 18 +++++++++----
 14 files changed, 136 insertions(+), 27 deletions(-)

diff --git a/examples/latte/sample.py b/examples/latte/sample.py
index 839b08c6..ec05913e 100644
--- a/examples/latte/sample.py
+++ b/examples/latte/sample.py
@@ -16,6 +16,15 @@ def run_base():
     engine.save_video(video, f"./outputs/{prompt}.mp4")
 
 
+def run_low_mem():
+    config = LatteConfig("maxin-cn/Latte-1", cpu_offload=True)
+    engine = VideoSysEngine(config)
+
+    prompt = "Sunset over the sea."
+    video = engine.generate(prompt).video[0]
+    engine.save_video(video, f"./outputs/{prompt}.mp4")
+
+
 def run_pab():
     config = LatteConfig("maxin-cn/Latte-1", enable_pab=True)
     engine = VideoSysEngine(config)
@@ -27,4 +36,5 @@ def run_pab():
 
 if __name__ == "__main__":
     run_base()
+    # run_low_mem()
     # run_pab()
diff --git a/examples/open_sora/sample.py b/examples/open_sora/sample.py
index 7af838cd..d7306e66 100644
--- a/examples/open_sora/sample.py
+++ b/examples/open_sora/sample.py
@@ -20,6 +20,15 @@ def run_base():
     engine.save_video(video, f"./outputs/{prompt}.mp4")
 
 
+def run_low_mem():
+    config = OpenSoraConfig(cpu_offload=True, tiling_size=1)
+    engine = VideoSysEngine(config)
+
+    prompt = "Sunset over the sea."
+    video = engine.generate(prompt).video[0]
+    engine.save_video(video, f"./outputs/{prompt}.mp4")
+
+
 def run_pab():
     config = OpenSoraConfig(enable_pab=True)
     engine = VideoSysEngine(config)
@@ -31,4 +40,5 @@ def run_pab():
 
 if __name__ == "__main__":
     run_base()
+    # run_low_mem()
     # run_pab()
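Note on the examples above: every model gets the same three-line low-memory entry point, and the only per-model difference is the VAE knob (tiling_size for Open-Sora, enable_tiling for Open-Sora-Plan below, vae_tiling for CogVideoX in the tests). A minimal end-to-end sketch, assuming the top-level videosys exports these example scripts already import:

    from videosys import OpenSoraConfig, VideoSysEngine

    # cpu_offload keeps idle submodules on the CPU and moves each one to the
    # GPU only while it runs; tiling_size=1 decodes the VAE in the smallest
    # tiles, trading throughput for a lower peak-memory footprint.
    config = OpenSoraConfig(cpu_offload=True, tiling_size=1)
    engine = VideoSysEngine(config)

    video = engine.generate("Sunset over the sea.").video[0]
    engine.save_video(video, "./outputs/sunset.mp4")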
diff --git a/examples/open_sora_plan/sample.py b/examples/open_sora_plan/sample.py
index ac802c11..4ff50c67 100644
--- a/examples/open_sora_plan/sample.py
+++ b/examples/open_sora_plan/sample.py
@@ -16,8 +16,17 @@ def run_base():
     engine.save_video(video, f"./outputs/{prompt}.mp4")
 
 
+def run_low_mem():
+    config = OpenSoraPlanConfig(cpu_offload=True, enable_tiling=True)
+    engine = VideoSysEngine(config)
+
+    prompt = "Sunset over the sea."
+    video = engine.generate(prompt).video[0]
+    engine.save_video(video, f"./outputs/{prompt}.mp4")
+
+
 def run_pab():
-    config = OpenSoraPlanConfig(num_gpus=1, enable_pab=True)
+    config = OpenSoraPlanConfig(enable_pab=True)
     engine = VideoSysEngine(config)
 
     prompt = "Sunset over the sea."
@@ -27,4 +36,5 @@ def run_pab():
 
 if __name__ == "__main__":
     run_base()
+    # run_low_mem()
     # run_pab()
diff --git a/tests/examples/test_sample.py b/tests/examples/test_sample.py
index eb8c07b7..8e4c749f 100644
--- a/tests/examples/test_sample.py
+++ b/tests/examples/test_sample.py
@@ -12,12 +12,19 @@
 import examples.open_sora.sample as open_sora
 import examples.open_sora_plan.sample as open_sora_plan
 
+files = [cogvideox, latte, open_sora, open_sora_plan]
+members = []
 
-@pytest.mark.parametrize("file", [cogvideox, latte, open_sora, open_sora_plan])
-def test_examples(file):
-    funcs = inspect.getmembers(file, inspect.isfunction)
-    for name, func in funcs:
-        try:
-            func()
-        except Exception as e:
-            raise Exception(f"Failed to run {name} in {file.__file__}") from e
+# record the owning module with each function so failures name the right file
+for file in files:
+    for name, func in inspect.getmembers(file, inspect.isfunction):
+        members.append((file.__name__, name, func))
+
+
+@pytest.mark.parametrize("member", members)
+def test_examples(member):
+    module_name, name, func = member
+    try:
+        func()
+    except Exception as e:
+        raise Exception(f"Failed to run {name} in {module_name}") from e
diff --git a/tests/pipelines/cogvideox/test_cogvideox.py b/tests/pipelines/cogvideox/test_cogvideox.py
index 97c01796..9257aa41 100644
--- a/tests/pipelines/cogvideox/test_cogvideox.py
+++ b/tests/pipelines/cogvideox/test_cogvideox.py
@@ -21,3 +21,13 @@ def test_pab(num_gpus):
     prompt = "Sunset over the sea."
     video = engine.generate(prompt).video[0]
     engine.save_video(video, f"./test_outputs/{prompt}_cogvideo_pab_{num_gpus}.mp4")
+
+
+@pytest.mark.parametrize("num_gpus", [1])
+def test_low_mem(num_gpus):
+    config = CogVideoXConfig(num_gpus=num_gpus, cpu_offload=True, vae_tiling=True)
+    engine = VideoSysEngine(config)
+
+    prompt = "Sunset over the sea."
+    video = engine.generate(prompt).video[0]
+    engine.save_video(video, f"./test_outputs/{prompt}_cogvideo_low_mem_{num_gpus}.mp4")
diff --git a/tests/pipelines/latte/test_latte.py b/tests/pipelines/latte/test_latte.py
index ab135b3c..e5838605 100644
--- a/tests/pipelines/latte/test_latte.py
+++ b/tests/pipelines/latte/test_latte.py
@@ -21,3 +21,13 @@ def test_pab(num_gpus):
     prompt = "Sunset over the sea."
     video = engine.generate(prompt).video[0]
     engine.save_video(video, f"./test_outputs/{prompt}_latte_pab_{num_gpus}.mp4")
+
+
+@pytest.mark.parametrize("num_gpus", [1])
+def test_low_mem(num_gpus):
+    config = LatteConfig(num_gpus=num_gpus, cpu_offload=True)
+    engine = VideoSysEngine(config)
+
+    prompt = "Sunset over the sea."
+    video = engine.generate(prompt).video[0]
+    engine.save_video(video, f"./test_outputs/{prompt}_latte_low_mem_{num_gpus}.mp4")
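Note on tests/examples/test_sample.py: expanding the module list into one parametrized case per example function lets pytest report every run_base/run_low_mem/run_pab variant separately instead of aborting a shared loop at the first failure, and carrying the owning module alongside each function keeps the error message pointing at the right file. A hypothetical further polish, not part of this patch, is to give the cases readable ids:

    import pytest

    # failures then read e.g. test_examples[examples.open_sora.sample.run_low_mem]
    params = [
        pytest.param((module_name, name, func), id=f"{module_name}.{name}")
        for module_name, name, func in members
    ]

Parametrizing over params instead of members would leave the test body itself unchanged.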
diff --git a/tests/pipelines/open_sora/test_open_sora.py b/tests/pipelines/open_sora/test_open_sora.py
index a811742a..f42de1ee 100644
--- a/tests/pipelines/open_sora/test_open_sora.py
+++ b/tests/pipelines/open_sora/test_open_sora.py
@@ -21,3 +21,13 @@ def test_pab(num_gpus):
     prompt = "Sunset over the sea."
     video = engine.generate(prompt).video[0]
     engine.save_video(video, f"./test_outputs/{prompt}_open_sora_pab_{num_gpus}.mp4")
+
+
+@pytest.mark.parametrize("num_gpus", [1])
+def test_low_mem(num_gpus):
+    config = OpenSoraConfig(num_gpus=num_gpus, cpu_offload=True, tiling_size=1)
+    engine = VideoSysEngine(config)
+
+    prompt = "Sunset over the sea."
+    video = engine.generate(prompt).video[0]
+    engine.save_video(video, f"./test_outputs/{prompt}_open_sora_low_mem_{num_gpus}.mp4")
diff --git a/tests/pipelines/open_sora_plan/test_open_sora_plan.py b/tests/pipelines/open_sora_plan/test_open_sora_plan.py
index e8a2d933..6a4278c4 100644
--- a/tests/pipelines/open_sora_plan/test_open_sora_plan.py
+++ b/tests/pipelines/open_sora_plan/test_open_sora_plan.py
@@ -21,3 +21,13 @@ def test_pab(num_gpus):
     prompt = "Sunset over the sea."
     video = engine.generate(prompt).video[0]
     engine.save_video(video, f"./test_outputs/{prompt}_open_sora_plan_pab_{num_gpus}.mp4")
+
+
+@pytest.mark.parametrize("num_gpus", [1])
+def test_low_mem(num_gpus):
+    config = OpenSoraPlanConfig(num_gpus=num_gpus, cpu_offload=True, enable_tiling=True)
+    engine = VideoSysEngine(config)
+
+    prompt = "Sunset over the sea."
+    video = engine.generate(prompt).video[0]
+    engine.save_video(video, f"./test_outputs/{prompt}_open_sora_plan_low_mem_{num_gpus}.mp4")
diff --git a/videosys/core/engine.py b/videosys/core/engine.py
index 82c62c47..6d1868f1 100644
--- a/videosys/core/engine.py
+++ b/videosys/core/engine.py
@@ -3,6 +3,7 @@
 from typing import Any, Optional
 
 import torch
+import torch.distributed as dist
 
 import videosys
 
@@ -22,9 +23,6 @@ def __init__(self, config):
     def _init_worker(self, pipeline_cls):
         world_size = self.config.num_gpus
 
-        if "CUDA_VISIBLE_DEVICES" not in os.environ:
-            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in range(world_size))
-
         # Disable torch async compiling which won't work with daemonic processes
         os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
 
@@ -124,7 +122,7 @@ def save_video(self, video, output_path):
     def shutdown(self):
         if (worker_monitor := getattr(self, "worker_monitor", None)) is not None:
             worker_monitor.close()
-        torch.distributed.destroy_process_group()
+        dist.destroy_process_group()
 
     def __del__(self):
         self.shutdown()
diff --git a/videosys/models/autoencoders/autoencoder_kl_open_sora.py b/videosys/models/autoencoders/autoencoder_kl_open_sora.py
index 920e03fb..d073fe4a 100644
--- a/videosys/models/autoencoders/autoencoder_kl_open_sora.py
+++ b/videosys/models/autoencoders/autoencoder_kl_open_sora.py
@@ -670,6 +670,9 @@ def encode(self, x):
         return (z - self.shift) / self.scale
 
     def decode(self, z, num_frames=None):
+        device = z.device
+        self.scale = self.scale.to(device)
+        self.shift = self.shift.to(device)
         if not self.cal_loss:
             z = z * self.scale.to(z.dtype) + self.shift.to(z.dtype)
 
diff --git a/videosys/pipelines/cogvideox/pipeline_cogvideox.py b/videosys/pipelines/cogvideox/pipeline_cogvideox.py
index d33f689b..163832bf 100644
--- a/videosys/pipelines/cogvideox/pipeline_cogvideox.py
+++ b/videosys/pipelines/cogvideox/pipeline_cogvideox.py
@@ -114,7 +114,7 @@ def __init__(
 
 
 class CogVideoXPipeline(VideoSysPipeline):
-    _optional_components = ["tokenizer", "text_encoder", "vae", "transformer", "scheduler"]
+    _optional_components = ["text_encoder", "tokenizer"]
     model_cpu_offload_seq = "text_encoder->transformer->vae"
     _callback_tensor_inputs = [
         "latents",
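Note on model_cpu_offload_seq: this is the diffusers convention that enable_model_cpu_offload() consumes; it names the heavy submodules in the order the pipeline invokes them. Conceptually, the base class installs accelerate hooks along these lines (a simplified sketch of the mechanism, not the actual diffusers implementation; pipe stands in for a constructed pipeline):

    import torch
    from accelerate import cpu_offload_with_hook

    device = torch.device("cuda")
    hook = None
    # each hooked module is copied to `device` when it is called and offloaded
    # back to CPU once the next module in the chain starts executing
    for module in (pipe.text_encoder, pipe.transformer, pipe.vae):
        _, hook = cpu_offload_with_hook(module, device, prev_module_hook=hook)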
diff --git a/videosys/pipelines/latte/pipeline_latte.py b/videosys/pipelines/latte/pipeline_latte.py
index f0689fc4..7c1d590c 100644
--- a/videosys/pipelines/latte/pipeline_latte.py
+++ b/videosys/pipelines/latte/pipeline_latte.py
@@ -138,6 +138,8 @@ def __init__(
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
         variance_type: str = "learned_range",
+        # ======= memory =======
+        cpu_offload: bool = False,
         # ======= pab ========
         enable_pab: bool = False,
         pab_config: PABConfig = LattePABConfig(),
@@ -148,6 +150,8 @@ def __init__(
         self.num_gpus = num_gpus
         # ======= vae ========
         self.enable_vae_temporal_decoder = enable_vae_temporal_decoder
+        # ======= memory ========
+        self.cpu_offload = cpu_offload
         # ======= scheduler ========
         self.beta_start = beta_start
         self.beta_end = beta_end
@@ -235,12 +239,18 @@ def __init__(
             set_pab_manager(config.pab_config)
 
         # set eval and device
-        self.set_eval_and_device(device, text_encoder, vae, transformer)
+        self.set_eval_and_device(device, vae, transformer)
 
         self.register_modules(
             tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
         )
 
+        # cpu offload
+        if config.cpu_offload:
+            self.enable_model_cpu_offload()
+        else:
+            self.set_eval_and_device(device, text_encoder)
+
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
@@ -744,7 +754,7 @@ def generate(
         else:
             batch_size = prompt_embeds.shape[0]
 
-        device = self.text_encoder.device or self._execution_device
+        device = self._execution_device
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
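Note on the device fallback removed above: `self.text_encoder.device or self._execution_device` was effectively dead code, and it misbehaves once offloading is on. torch.device objects are always truthy, so the `or` never falls through, and with the text encoder offloaded the expression returns cpu while the denoising loop expects the GPU:

    import torch

    # bool(torch.device(...)) is always True, so the right-hand side is unreachable
    assert (torch.device("cpu") or torch.device("cuda")) == torch.device("cpu")

Switching to diffusers' _execution_device property instead asks the offload hooks where the computation will actually run.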
diff --git a/videosys/pipelines/open_sora/pipeline_open_sora.py b/videosys/pipelines/open_sora/pipeline_open_sora.py
index 9c06000e..5fd34734 100644
--- a/videosys/pipelines/open_sora/pipeline_open_sora.py
+++ b/videosys/pipelines/open_sora/pipeline_open_sora.py
@@ -132,6 +132,8 @@ def __init__(
         # ======== scheduler ========
         num_sampling_steps: int = 30,
         cfg_scale: float = 7.0,
+        # ======= memory =======
+        cpu_offload: bool = False,
         # ======== vae ========
         tiling_size: int = 4,
         # ======== speedup ========
@@ -151,6 +153,8 @@ def __init__(
         self.cfg_scale = cfg_scale
         # ======== vae ========
         self.tiling_size = tiling_size
+        # ======= memory ========
+        self.cpu_offload = cpu_offload
         # ======== speedup ========
         self.enable_flash_attn = enable_flash_attn
         # ======== pab ========
@@ -184,7 +188,10 @@ class OpenSoraPipeline(VideoSysPipeline):
         r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
     )  # noqa
 
-    _optional_components = ["tokenizer", "text_encoder"]
+    _optional_components = [
+        "text_encoder",
+        "tokenizer",
+    ]
     model_cpu_offload_seq = "text_encoder->transformer->vae"
 
     def __init__(
@@ -228,12 +235,18 @@ def __init__(
             set_pab_manager(config.pab_config)
 
         # set eval and device
-        self.set_eval_and_device(device, text_encoder, vae, transformer)
+        self.set_eval_and_device(device, vae, transformer)
 
         self.register_modules(
             text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler, tokenizer=tokenizer
         )
 
+        # cpu offload
+        if config.cpu_offload:
+            self.enable_model_cpu_offload()
+        else:
+            self.set_eval_and_device(self._device, text_encoder)
+
     def get_text_embeddings(self, texts):
         text_tokens_and_mask = self.tokenizer(
             texts,
@@ -244,9 +257,9 @@ def get_text_embeddings(self, texts):
             add_special_tokens=True,
             return_tensors="pt",
         )
-
-        input_ids = text_tokens_and_mask["input_ids"].to(self.device)
-        attention_mask = text_tokens_and_mask["attention_mask"].to(self.device)
+        device = self._execution_device
+        input_ids = text_tokens_and_mask["input_ids"].to(device)
+        attention_mask = text_tokens_and_mask["attention_mask"].to(device)
         with torch.no_grad():
             text_encoder_embs = self.text_encoder(
                 input_ids=input_ids,
@@ -260,7 +273,7 @@ def encode_prompt(self, text):
         return dict(y=caption_embs, mask=emb_masks)
 
     def null_embed(self, n):
-        null_y = self.transformer.y_embedder.y_embedding[None].repeat(n, 1, 1)[:, None]
+        null_y = self.transformer.y_embedder.y_embedding[None].repeat(n, 1, 1)[:, None].to(self._execution_device)
         return null_y
 
     @staticmethod
diff --git a/videosys/pipelines/open_sora_plan/pipeline_open_sora_plan.py b/videosys/pipelines/open_sora_plan/pipeline_open_sora_plan.py
index b267714f..b973361b 100644
--- a/videosys/pipelines/open_sora_plan/pipeline_open_sora_plan.py
+++ b/videosys/pipelines/open_sora_plan/pipeline_open_sora_plan.py
@@ -169,7 +169,8 @@ def __init__(
         num_frames: int = 65,
         # ======= distributed ========
         num_gpus: int = 1,
-        # ======= vae =======
+        # ======= memory =======
+        cpu_offload: bool = False,
         enable_tiling: bool = True,
         tile_overlap_factor: float = 0.25,
         # ======= pab ========
@@ -185,7 +186,8 @@ def __init__(
         self.version = f"{num_frames}x512x512"
         # ======= distributed ========
         self.num_gpus = num_gpus
-        # ======= vae ========
+        # ======= memory ========
+        self.cpu_offload = cpu_offload
         self.enable_tiling = enable_tiling
         self.tile_overlap_factor = tile_overlap_factor
         # ======= pab ========
@@ -256,7 +258,7 @@ def __init__(
             transformer.force_images = False
 
         # set eval and device
-        self.set_eval_and_device(device, text_encoder, vae, transformer)
+        self.set_eval_and_device(device, vae, transformer)
 
         # pab
         if config.enable_pab:
@@ -266,6 +268,12 @@ def __init__(
             tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
         )
 
+        # cpu offload
+        if config.cpu_offload:
+            self.enable_model_cpu_offload()
+        else:
+            self.set_eval_and_device(device, text_encoder)
+
         # self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
 
     # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
@@ -320,7 +328,7 @@ def encode_prompt(
         embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None
 
         if device is None:
-            device = self.text_encoder.device or self._execution_device
+            device = self._execution_device
 
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -774,7 +782,7 @@ def generate(
         else:
             batch_size = prompt_embeds.shape[0]
 
-        device = self.text_encoder.device or self._execution_device
+        device = self._execution_device
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
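Note on the remaining device moves (the VAE's scale/shift in decode(), null_embed() above): tensors built from module state must be moved to the execution device at the point of use, because under CPU offload the owning module may be resident on the CPU whenever its weights are read outside a hooked forward call. A minimal illustration of the pattern; NullEmbedder is a hypothetical stand-in for transformer.y_embedder:

    import torch

    class NullEmbedder(torch.nn.Module):
        def __init__(self, num_tokens=16, dim=8):
            super().__init__()
            self.y_embedding = torch.nn.Parameter(torch.zeros(num_tokens, dim))

        def null(self, n, execution_device):
            # the parameter may live on the CPU while sampling runs on the GPU,
            # so move the result to the device where it will actually be used
            return self.y_embedding[None].repeat(n, 1, 1)[:, None].to(execution_device)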