Skip to content

Commit

Permalink
lycoris initialization testcases
Browse files Browse the repository at this point in the history
  • Loading branch information
yaswanth19 committed Oct 30, 2024
1 parent 0434e13 commit 1113989
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 2 deletions.
2 changes: 0 additions & 2 deletions tests/test_custom_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,6 @@
),
("Vanilla MLP 7 LOKR", "MLP", LoKrConfig, {"target_modules": "lin0", "rank_dropout": 0.5}),
("Vanilla MLP 8 LOKR", "MLP", LoKrConfig, {"target_modules": "lin0", "decompose_both": True, "r": 1, "alpha": 1}),
("Vanilla MLP 9 LOKR", "MLP", LoKrConfig, {"target_modules": "lin0", "init_weights": "lycoris"}),
("Conv2d 1 LOKR", "Conv2d", LoKrConfig, {"target_modules": ["conv2d"]}),
("Conv2d 2 LOKR", "Conv2d", LoKrConfig, {"target_modules": ["conv2d", "lin0"]}),
("Conv2d 3 LOKR", "Conv2d", LoKrConfig, {"target_modules": ["conv2d"], "use_effective_conv2d": True}),
Expand Down Expand Up @@ -264,7 +263,6 @@
"decompose_factor": 4,
},
),
("Conv2d 8 LOKR", "Conv2d", LoKrConfig, {"target_modules": ["conv2d"], "init_weights": "lycoris"}),
########
# OFT #
########
Expand Down
89 changes: 89 additions & 0 deletions tests/test_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from peft import (
AdaLoraConfig,
IA3Config,
LoKrConfig,
LoraConfig,
PeftMixedModel,
PeftModel,
Expand Down Expand Up @@ -1058,6 +1059,94 @@ def test_lora_use_dora_with_megatron_core_raises(self):
LoraConfig(target_modules=["linear"], use_dora=True, megatron_config=megatron_config)


class TestLokrInitialization:
    """Tests for the different LoKr weight-initialization schemes.

    Each test builds a fresh base model, records its output, wraps it with a
    LoKr adapter, and then checks whether the wrapped model's output matches
    the original. With the default and "lycoris" init schemes the adapter is
    initialized as an identity transform, so outputs must match; with
    ``init_weights=False`` the adapter is random, so outputs must differ.
    """

    torch_device = infer_device()

    def get_model(self):
        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                # Large layers so that statistical averages land close to
                # their expected values.
                self.linear = nn.Linear(1000, 1000)
                self.conv2d = nn.Conv2d(100, 100, 3)

            def forward(self, x):
                x_4d = x.flatten().reshape(1, 100, 10, 10)
                return self.linear(x), self.conv2d(x_4d)

        return MyModule().eval().to(self.torch_device)

    @pytest.fixture
    def data(self):
        return torch.rand(10, 1000).to(self.torch_device)

    def _check_init(self, data, output_idx, expect_identity, **config_kwargs):
        """Wrap a fresh model with LoKr and compare outputs before vs. after.

        output_idx selects the linear (0) or conv2d (1) branch of the model's
        forward; expect_identity says whether the adapter should leave the
        output unchanged.
        """
        torch.manual_seed(0)

        model = self.get_model()
        output_before = model(data)[output_idx]
        model = get_peft_model(model, LoKrConfig(**config_kwargs))
        output_after = model(data)[output_idx]

        outputs_match = torch.allclose(output_before, output_after)
        assert outputs_match is expect_identity

    def test_lokr_linear_init_default(self, data):
        self._check_init(data, 0, True, target_modules=["linear"])

    def test_lokr_linear_init_false(self, data):
        self._check_init(data, 0, False, target_modules=["linear"], init_weights=False)

    def test_lokr_linear_init_lycoris(self, data):
        self._check_init(data, 0, True, target_modules=["linear"], init_weights="lycoris")

    def test_lokr_conv2d_init_default(self, data):
        self._check_init(data, 1, True, target_modules=["conv2d"])

    def test_lokr_conv2d_init_false(self, data):
        self._check_init(data, 1, False, target_modules=["conv2d"], init_weights=False)

    def test_lokr_conv2d_init_lycoris(self, data):
        self._check_init(data, 1, True, target_modules=["conv2d"], init_weights="lycoris")

class TestAdaLoraInitialization:
torch_device = infer_device()

Expand Down

0 comments on commit 1113989

Please sign in to comment.