Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[refactor] LoRA tests #9481

Merged
merged 5 commits into from
Sep 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions tests/lora/test_lora_layers_cogvideox.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = CogVideoXPipeline
scheduler_cls = CogVideoXDPMScheduler
scheduler_kwargs = {"timestep_spacing": "trailing"}
scheduler_classes = [CogVideoXDDIMScheduler, CogVideoXDPMScheduler]

transformer_kwargs = {
"num_attention_heads": 4,
Expand Down Expand Up @@ -126,8 +127,7 @@ def get_dummy_inputs(self, with_generator=True):

@skip_mps
def test_lora_fuse_nan(self):
scheduler_classes = [CogVideoXDDIMScheduler, CogVideoXDPMScheduler]
for scheduler_cls in scheduler_classes:
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
Expand Down Expand Up @@ -156,10 +156,22 @@ def test_lora_fuse_nan(self):
self.assertTrue(np.isnan(out).all())

def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=5e-3)
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)

def test_simple_inference_with_text_denoiser_lora_unfused(self):
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=5e-3)
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)

@unittest.skip("Not supported in CogVideoX.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass

@unittest.skip("Not supported in CogVideoX.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass

@unittest.skip("Not supported in CogVideoX.")
def test_modify_padding_mode(self):
pass

@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
def test_simple_inference_with_partial_text_lora(self):
Expand Down
10 changes: 9 additions & 1 deletion tests/lora/test_lora_layers_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = FluxPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler()
scheduler_kwargs = {}
uses_flow_matching = True
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
transformer_kwargs = {
"patch_size": 1,
"in_channels": 4,
Expand Down Expand Up @@ -154,6 +154,14 @@ def test_with_alpha_in_state_dict(self):
)
self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3))

@unittest.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass

@unittest.skip("Not supported in Flux.")
def test_modify_padding_mode(self):
pass


@slow
@require_torch_gpu
Expand Down
18 changes: 17 additions & 1 deletion tests/lora/test_lora_layers_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = StableDiffusion3Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
uses_flow_matching = True
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
transformer_kwargs = {
"sample_size": 32,
"patch_size": 1,
Expand Down Expand Up @@ -92,3 +92,19 @@ def test_sd3_lora(self):

lora_filename = "lora_peft_format.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)

@unittest.skip("Not supported in SD3.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass

@unittest.skip("Not supported in SD3.")
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
pass

@unittest.skip("Not supported in SD3.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass

@unittest.skip("Not supported in SD3.")
def test_modify_padding_mode(self):
pass
Loading
Loading