Skip to content

Commit

Permalink
[nnx] fix Variable overloads and add shape/dtype properties
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Jul 2, 2024
1 parent 15e0e8d commit 61343fa
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 33 deletions.
6 changes: 3 additions & 3 deletions docs/nnx/nnx_basics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@
" self.count = Count(jnp.array(0))\n",
"\n",
" def __call__(self):\n",
" self.count.value += 1\n",
" self.count += 1\n",
"\n",
"counter = Counter()\n",
"print(f'{counter.count.value = }')\n",
Expand Down Expand Up @@ -480,8 +480,8 @@
" self.count = Count(0)\n",
"\n",
" def __call__(self, x: jax.Array):\n",
" self.count.value += 1\n",
" return x @ self.w.value + self.b.value\n",
" self.count += 1\n",
" return x @ self.w + self.b\n",
" \n",
"model = StatefulLinear(din=3, dout=5, rngs=nnx.Rngs(0))\n",
"nnx.display(model)"
Expand Down
6 changes: 3 additions & 3 deletions docs/nnx/nnx_basics.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class Counter(nnx.Module):
self.count = Count(jnp.array(0))
def __call__(self):
self.count.value += 1
self.count += 1
counter = Counter()
print(f'{counter.count.value = }')
Expand Down Expand Up @@ -282,8 +282,8 @@ class StatefulLinear(nnx.Module):
self.count = Count(0)
def __call__(self, x: jax.Array):
self.count.value += 1
return x @ self.w.value + self.b.value
self.count += 1
return x @ self.w + self.b
model = StatefulLinear(din=3, dout=5, rngs=nnx.Rngs(0))
nnx.display(model)
Expand Down
127 changes: 100 additions & 27 deletions flax/nnx/nnx/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,15 @@ def on_remove_axis(
) -> V:
raise NotImplementedError

# operator overloads
# overloads
@property
def shape(self) -> tuple[int, ...]:
return self.value.shape # type: ignore

@property
def dtype(self) -> Any:
return self.value.dtype # type: ignore

def __jax_array__(self):
return self.value

Expand Down Expand Up @@ -556,44 +564,109 @@ def __rxor__(self, other) -> A:
def __ror__(self, other) -> A:
return self.value.__ror__(other) # type: ignore

def __iadd__(self, other) -> A:
return self.value.__iadd__(other) # type: ignore
def __iadd__(self: V, other) -> V:
value = self.value
if hasattr(value, '__iadd__'):
value.__iadd__(other)
else:
self.value = value.__add__(other)
return self

def __isub__(self, other) -> A:
return self.value.__isub__(other) # type: ignore
def __isub__(self: V, other) -> V:
value = self.value
if hasattr(value, '__isub__'):
value.__isub__(other)
else:
self.value = value.__sub__(other)
return self

def __imul__(self, other) -> A:
return self.value.__imul__(other) # type: ignore
def __imul__(self: V, other) -> V:
value = self.value
if hasattr(value, '__imul__'):
value.__imul__(other)
else:
self.value = value.__mul__(other)
return self

def __imatmul__(self, other) -> A:
return self.value.__imatmul__(other) # type: ignore
def __imatmul__(self: V, other) -> V:
value = self.value
if hasattr(value, '__imatmul__'):
value.__imatmul__(other)
else:
self.value = value.__matmul__(other)
return self

def __itruediv__(self, other) -> A:
return self.value.__itruediv__(other) # type: ignore
def __itruediv__(self: V, other) -> V:
value = self.value
if hasattr(value, '__itruediv__'):
value.__itruediv__(other)
else:
self.value = value.__truediv__(other)
return self

def __ifloordiv__(self, other) -> A:
return self.value.__ifloordiv__(other) # type: ignore
def __ifloordiv__(self: V, other) -> V:
value = self.value
if hasattr(value, '__ifloordiv__'):
value.__ifloordiv__(other)
else:
self.value = value.__floordiv__(other)
return self

def __imod__(self, other) -> A:
return self.value.__imod__(other) # type: ignore
def __imod__(self: V, other) -> V:
value = self.value
if hasattr(value, '__imod__'):
value.__imod__(other)
else:
self.value = value.__mod__(other)
return self

def __ipow__(self, other) -> A:
return self.value.__ipow__(other) # type: ignore
def __ipow__(self: V, other) -> V:
value = self.value
if hasattr(value, '__ipow__'):
value.__ipow__(other)
else:
self.value = value.__pow__(other)
return self

def __ilshift__(self, other) -> A:
return self.value.__ilshift__(other) # type: ignore
def __ilshift__(self: V, other) -> V:
value = self.value
if hasattr(value, '__ilshift__'):
value.__ilshift__(other)
else:
self.value = value.__lshift__(other)
return self

def __irshift__(self, other) -> A:
return self.value.__irshift__(other) # type: ignore
def __irshift__(self: V, other) -> V:
value = self.value
if hasattr(value, '__irshift__'):
value.__irshift__(other)
else:
self.value = value.__rshift__(other)
return self

def __iand__(self, other) -> A:
return self.value.__iand__(other) # type: ignore
def __iand__(self: V, other) -> V:
value = self.value
if hasattr(value, '__iand__'):
value.__iand__(other)
else:
self.value = value.__and__(other)
return self

def __ixor__(self, other) -> A:
return self.value.__ixor__(other) # type: ignore
def __ixor__(self: V, other) -> V:
value = self.value
if hasattr(value, '__ixor__'):
value.__ixor__(other)
else:
self.value = value.__xor__(other)
return self

def __ior__(self, other) -> A:
return self.value.__ior__(other) # type: ignore
def __ior__(self: V, other) -> V:
value = self.value
if hasattr(value, '__ior__'):
value.__ior__(other)
else:
self.value = value.__or__(other)
return self

def __neg__(self) -> A:
return self.value.__neg__() # type: ignore
Expand Down

0 comments on commit 61343fa

Please sign in to comment.