Skip to content

Commit

Permalink
[nnx] LoRAParam inherits from Param
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Jun 27, 2024
1 parent fed7756 commit aad4494
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 4 deletions.
2 changes: 0 additions & 2 deletions flax/nnx/examples/gemma/helpers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@
# ============================================================================
"""Tests for helpers."""

from __future__ import annotations

from typing import Tuple

from absl.testing import absltest
Expand Down
3 changes: 2 additions & 1 deletion flax/nnx/nnx/nn/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,11 @@
default_kernel_init = initializers.lecun_normal()


class LoRAParam(variables.Variable[A]):
class LoRAParam(variables.Param[A]):
pass



class LoRA(Module):
"""A standalone LoRA layer.
Expand Down
2 changes: 1 addition & 1 deletion flax/nnx/tests/nn/lora_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def __call__(self, x):
def test_lora_param_type(self):
rngs = nnx.Rngs(0)
model = nnx.LoRA(3, 4, 2, lora_param_type=nnx.LoRAParam, rngs=rngs)
_, params, lora_params = nnx.split(model, nnx.Param, nnx.LoRAParam)
_, lora_params, params = nnx.split(model, nnx.LoRAParam, nnx.Param)
assert params == {}
assert ('lora_a' in lora_params) and ('lora_b' in lora_params)
np.testing.assert_allclose(lora_params.lora_a.value, model.lora_a.value)
Expand Down

0 comments on commit aad4494

Please sign in to comment.