Skip to content

Commit

Permalink
fix random.py
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Aug 14, 2024
1 parent bf6f16c commit 789e9ac
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions brainstate/init/_random_inits.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,8 @@ def __call__(self, shape):
denominator = (fan_in + fan_out) / 2
else:
raise ValueError("invalid mode for variance scaling initializer: {}".format(self.mode))
scale = self.scale.value if isinstance(self.scale, bu.Quantity) else self.scale
dim = self.scale.dim if isinstance(self.scale, bu.Quantity) else bu.DIMENSIONLESS
scale = self.scale.mantissa if isinstance(self.scale, bu.Quantity) else self.scale
unit = bu.get_unit(self.scale)
variance = (scale / denominator).astype(self.dtype)
if self.distribution == "truncated_normal":
stddev = (jnp.sqrt(variance) / .87962566103423978).astype(self.dtype)
Expand All @@ -302,7 +302,7 @@ def __call__(self, shape):
jnp.sqrt(3 * variance).astype(self.dtype))
else:
raise ValueError("invalid distribution for variance scaling initializer")
return res if dim == bu.DIMENSIONLESS else res * dim
return res if unit.is_unitless else bu.Quantity(res, unit=unit)

def __repr__(self):
name = self.__class__.__name__
Expand Down Expand Up @@ -445,8 +445,8 @@ def __call__(self, shape):
matrix_shape = (n_rows, n_cols) if n_rows > n_cols else (n_cols, n_rows)
norm_dst = random.normal(size=matrix_shape, dtype=self.dtype)

scale = self.scale.value if isinstance(self.scale, bu.Quantity) else self.scale
dim = self.scale.dim if isinstance(self.scale, bu.Quantity) else bu.DIMENSIONLESS
scale = self.scale.mantissa if isinstance(self.scale, bu.Quantity) else self.scale
unit = bu.get_unit(self.scale)
q_mat, r_mat = jnp.linalg.qr(norm_dst)
# Enforce Q is uniformly distributed
q_mat *= jnp.sign(jnp.diag(r_mat))
Expand All @@ -455,7 +455,7 @@ def __call__(self, shape):
q_mat = jnp.reshape(q_mat, (n_rows,) + tuple(np.delete(shape, self.axis)))
q_mat = jnp.moveaxis(q_mat, 0, self.axis)
r = jnp.asarray(scale, dtype=self.dtype) * q_mat
return r if dim == bu.DIMENSIONLESS else r * dim
return r if unit.is_unitless else bu.Quantity(r, unit=unit)

def __repr__(self):
return f'{self.__class__.__name__}(scale={self.scale}, axis={self.axis}, dtype={self.dtype})'
Expand All @@ -480,8 +480,8 @@ def __call__(self, shape):
raise ValueError("Delta orthogonal initializer requires a 3D, 4D or 5D shape.")
if shape[-1] < shape[-2]:
raise ValueError("`fan_in` must be less or equal than `fan_out`. ")
scale = self.scale.value if isinstance(self.scale, bu.Quantity) else self.scale
dim = self.scale.dim if isinstance(self.scale, bu.Quantity) else bu.DIMENSIONLESS
scale = self.scale.mantissa if isinstance(self.scale, bu.Quantity) else self.scale
unit = bu.get_unit(self.scale)
ortho_matrix = Orthogonal(scale=scale, axis=self.axis, dtype=self.dtype)(*shape[-2:])
W = jnp.zeros(shape, dtype=self.dtype)
if len(shape) == 3:
Expand All @@ -493,7 +493,7 @@ def __call__(self, shape):
else:
k1, k2, k3 = shape[:3]
W = W.at[(k1 - 1) // 2, (k2 - 1) // 2, (k3 - 1) // 2].set(ortho_matrix)
return W if dim == bu.DIMENSIONLESS else W * dim
return W if unit.is_unitless else bu.Quantity(W, unit=unit)

def __repr__(self):
return f'{self.__class__.__name__}(scale={self.scale}, axis={self.axis}, dtype={self.dtype})'

0 comments on commit 789e9ac

Please sign in to comment.