diff --git a/docs/nnx/nnx_basics.ipynb b/docs/nnx/nnx_basics.ipynb
index ae12ce522e..344b6cc9c2 100644
--- a/docs/nnx/nnx_basics.ipynb
+++ b/docs/nnx/nnx_basics.ipynb
@@ -166,7 +166,7 @@
     "    self.count = Count(jnp.array(0))\n",
     "\n",
     "  def __call__(self):\n",
-    "    self.count.value += 1\n",
+    "    self.count += 1\n",
     "\n",
     "counter = Counter()\n",
     "print(f'{counter.count.value = }')\n",
@@ -480,8 +480,8 @@
     "    self.count = Count(0)\n",
     "\n",
     "  def __call__(self, x: jax.Array):\n",
-    "    self.count.value += 1\n",
-    "    return x @ self.w.value + self.b.value\n",
+    "    self.count += 1\n",
+    "    return x @ self.w + self.b\n",
     "  \n",
     "model = StatefulLinear(din=3, dout=5, rngs=nnx.Rngs(0))\n",
     "nnx.display(model)"
diff --git a/docs/nnx/nnx_basics.md b/docs/nnx/nnx_basics.md
index 1bc9c503ec..6b83888edd 100644
--- a/docs/nnx/nnx_basics.md
+++ b/docs/nnx/nnx_basics.md
@@ -95,7 +95,7 @@ class Counter(nnx.Module):
     self.count = Count(jnp.array(0))
 
   def __call__(self):
-    self.count.value += 1
+    self.count += 1
 
 counter = Counter()
 print(f'{counter.count.value = }')
@@ -282,8 +282,8 @@ class StatefulLinear(nnx.Module):
     self.count = Count(0)
 
   def __call__(self, x: jax.Array):
-    self.count.value += 1
-    return x @ self.w.value + self.b.value
+    self.count += 1
+    return x @ self.w + self.b
 
 model = StatefulLinear(din=3, dout=5, rngs=nnx.Rngs(0))
 nnx.display(model)
diff --git a/flax/nnx/nnx/variables.py b/flax/nnx/nnx/variables.py
index 1bed4ded51..1f1c8e84ae 100644
--- a/flax/nnx/nnx/variables.py
+++ b/flax/nnx/nnx/variables.py
@@ -465,7 +465,15 @@ def on_remove_axis(
   ) -> V:
     raise NotImplementedError
 
-  # operator overloads
+  # overloads
+  @property
+  def shape(self) -> tuple[int, ...]:
+    return self.value.shape  # type: ignore
+
+  @property
+  def dtype(self) -> Any:
+    return self.value.dtype  # type: ignore
+
   def __jax_array__(self):
     return self.value
 
@@ -556,44 +564,103 @@ def __rxor__(self, other) -> A:
   def __ror__(self, other) -> A:
     return self.value.__ror__(other)  # type: ignore
 
-  def __iadd__(self, other) -> A:
-    return self.value.__iadd__(other)  # type: ignore
+  # In-place operators: each runs the matching augmented assignment on the
+  # wrapped value, so Python's full protocol applies (in-place method, then
+  # the binary method, then the reflected method of ``other``) and a
+  # ``TypeError`` is raised instead of ``NotImplemented`` being stored or
+  # dropped. The result is written back only when it is a different object,
+  # so values that mutate themselves in place are not reassigned. ``self`` is
+  # returned so that ``module.var += x`` leaves the Variable bound.
+  def __iadd__(self: V, other) -> V:
+    value = result = self.value
+    result += other  # type: ignore
+    if result is not value:
+      self.value = result
+    return self
 
-  def __isub__(self, other) -> A:
-    return self.value.__isub__(other)  # type: ignore
+  def __isub__(self: V, other) -> V:
+    value = result = self.value
+    result -= other  # type: ignore
+    if result is not value:
+      self.value = result
+    return self
 
-  def __imul__(self, other) -> A:
-    return self.value.__imul__(other)  # type: ignore
+  def __imul__(self: V, other) -> V:
+    value = result = self.value
+    result *= other  # type: ignore
+    if result is not value:
+      self.value = result
+    return self
 
-  def __imatmul__(self, other) -> A:
-    return self.value.__imatmul__(other)  # type: ignore
+  def __imatmul__(self: V, other) -> V:
+    value = result = self.value
+    result @= other  # type: ignore
+    if result is not value:
+      self.value = result
+    return self
 
-  def __itruediv__(self, other) -> A:
-    return self.value.__itruediv__(other)  # type: ignore
+  def __itruediv__(self: V, other) -> V:
+    value = result = self.value
+    result /= other  # type: ignore
+    if result is not value:
+      self.value = result
+    return self
 
-  def __ifloordiv__(self, other) -> A:
-    return self.value.__ifloordiv__(other)  # type: ignore
+  def __ifloordiv__(self: V, other) -> V:
+    value = result = self.value
+    result //= other  # type: ignore
+    if result is not value:
+      self.value = result
+    return self
 
-  def __imod__(self, other) -> A:
-    return self.value.__imod__(other)  # type: ignore
+  def __imod__(self: V, other) -> V:
+    value = result = self.value
+    result %= other  # type: ignore
+    if result is not value:
+      self.value = result
+    return self
 
-  def __ipow__(self, other) -> A:
-    return self.value.__ipow__(other)  # type: ignore
+  def __ipow__(self: V, other) -> V:
+    value = result = self.value
+    result **= other  # type: ignore
+    if result is not value:
+      self.value = result
+    return self
 
-  def __ilshift__(self, other) -> A:
-    return self.value.__ilshift__(other)  # type: ignore
+  def __ilshift__(self: V, other) -> V:
+    value = result = self.value
+    result <<= other  # type: ignore
+    if result is not value:
+      self.value = result
+    return self
 
-  def __irshift__(self, other) -> A:
-    return self.value.__irshift__(other)  # type: ignore
+  def __irshift__(self: V, other) -> V:
+    value = result = self.value
+    result >>= other  # type: ignore
+    if result is not value:
+      self.value = result
+    return self
 
-  def __iand__(self, other) -> A:
-    return self.value.__iand__(other)  # type: ignore
+  def __iand__(self: V, other) -> V:
+    value = result = self.value
+    result &= other  # type: ignore
+    if result is not value:
+      self.value = result
+    return self
 
-  def __ixor__(self, other) -> A:
-    return self.value.__ixor__(other)  # type: ignore
+  def __ixor__(self: V, other) -> V:
+    value = result = self.value
+    result ^= other  # type: ignore
+    if result is not value:
+      self.value = result
+    return self
 
-  def __ior__(self, other) -> A:
-    return self.value.__ior__(other)  # type: ignore
+  def __ior__(self: V, other) -> V:
+    value = result = self.value
+    result |= other  # type: ignore
+    if result is not value:
+      self.value = result
+    return self
 
   def __neg__(self) -> A:
     return self.value.__neg__()  # type: ignore