Skip to content

Commit

Permalink
replace references to deprecated KeyArray & PRNGKeyArray (huggingface…
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp authored Oct 9, 2023
1 parent 35952e6 commit a844065
Show file tree
Hide file tree
Showing 15 changed files with 28 additions and 26 deletions.
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@
"importlib_metadata",
"invisible-watermark>=0.2.0",
"isort>=5.5.4",
"jax>=0.2.8,!=0.3.2",
"jaxlib>=0.1.65",
"jax>=0.4.1",
"jaxlib>=0.4.1",
"Jinja2",
"k-diffusion>=0.0.12",
"torchsde",
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/dependency_versions_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
"importlib_metadata": "importlib_metadata",
"invisible-watermark": "invisible-watermark>=0.2.0",
"isort": "isort>=5.5.4",
"jax": "jax>=0.2.8,!=0.3.2",
"jaxlib": "jaxlib>=0.1.65",
"jax": "jax>=0.4.1",
"jaxlib": "jaxlib>=0.4.1",
"Jinja2": "Jinja2",
"k-diffusion": "k-diffusion>=0.0.12",
"torchsde": "torchsde",
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/models/controlnet_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
controlnet_conditioning_channel_order: str = "rgb"
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256)

def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
def init_weights(self, rng: jax.Array) -> FrozenDict:
# init input tensors
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/models/modeling_flax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
```"""
return self._cast_floating_to(params, jnp.float16, mask)

def init_weights(self, rng: jax.random.KeyArray) -> Dict:
def init_weights(self, rng: jax.Array) -> Dict:
raise NotImplementedError(f"init_weights method has to be implemented for {self}")

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/models/unet_2d_condition_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
addition_embed_type_num_heads: int = 64
projection_class_embeddings_input_dim: Optional[int] = None

def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
def init_weights(self, rng: jax.Array) -> FrozenDict:
# init input tensors
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/models/vae_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,7 @@ def setup(self):
dtype=self.dtype,
)

def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
def init_weights(self, rng: jax.Array) -> FrozenDict:
# init input tensors
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def _generate(
prompt_ids: jnp.array,
image: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.KeyArray,
prng_seed: jax.Array,
num_inference_steps: int,
guidance_scale: float,
latents: Optional[jnp.array] = None,
Expand Down Expand Up @@ -351,7 +351,7 @@ def __call__(
prompt_ids: jnp.array,
image: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.KeyArray,
prng_seed: jax.Array,
num_inference_steps: int = 50,
guidance_scale: Union[float, jnp.array] = 7.5,
latents: jnp.array = None,
Expand All @@ -370,7 +370,7 @@ def __call__(
Array representing the ControlNet input condition to provide guidance to the `unet` for generation.
params (`Dict` or `FrozenDict`):
Dictionary containing the model parameters/weights.
prng_seed (`jax.random.KeyArray` or `jax.Array`):
prng_seed (`jax.Array` or `jax.Array`):
Array containing random number generator key.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def _generate(
self,
prompt_ids: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.KeyArray,
prng_seed: jax.Array,
num_inference_steps: int,
height: int,
width: int,
Expand Down Expand Up @@ -312,7 +312,7 @@ def __call__(
self,
prompt_ids: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.KeyArray,
prng_seed: jax.Array,
num_inference_steps: int = 50,
height: Optional[int] = None,
width: Optional[int] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def _generate(
prompt_ids: jnp.array,
image: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.KeyArray,
prng_seed: jax.Array,
start_timestep: int,
num_inference_steps: int,
height: int,
Expand Down Expand Up @@ -340,7 +340,7 @@ def __call__(
prompt_ids: jnp.array,
image: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.KeyArray,
prng_seed: jax.Array,
strength: float = 0.8,
num_inference_steps: int = 50,
height: Optional[int] = None,
Expand All @@ -361,7 +361,7 @@ def __call__(
Array representing an image batch to be used as the starting point.
params (`Dict` or `FrozenDict`):
Dictionary containing the model parameters/weights.
prng_seed (`jax.random.KeyArray` or `jax.Array`):
prng_seed (`jax.Array` or `jax.Array`):
Array containing random number generator key.
strength (`float`, *optional*, defaults to 0.8):
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def _generate(
mask: jnp.array,
masked_image: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.KeyArray,
prng_seed: jax.Array,
num_inference_steps: int,
height: int,
width: int,
Expand Down Expand Up @@ -398,7 +398,7 @@ def __call__(
mask: jnp.array,
masked_image: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.KeyArray,
prng_seed: jax.Array,
num_inference_steps: int = 50,
height: Optional[int] = None,
width: Optional[int] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__(
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

def init_weights(self, rng: jax.random.KeyArray, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
def init_weights(self, rng: jax.Array, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensor
clip_input = jax.random.normal(rng, input_shape)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __call__(
self,
prompt_ids: jax.Array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.KeyArray,
prng_seed: jax.Array,
num_inference_steps: int = 50,
guidance_scale: Union[float, jax.Array] = 7.5,
height: Optional[int] = None,
Expand Down Expand Up @@ -170,7 +170,7 @@ def _generate(
self,
prompt_ids: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.KeyArray,
prng_seed: jax.Array,
num_inference_steps: int,
height: int,
width: int,
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/schedulers/scheduling_ddpm_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def step(
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
key: Optional[jax.random.KeyArray] = None,
key: Optional[jax.Array] = None,
return_dict: bool = True,
) -> Union[FlaxDDPMSchedulerOutput, Tuple]:
"""
Expand All @@ -211,7 +211,7 @@ def step(
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
key (`jax.random.KeyArray`): a PRNG key.
key (`jax.Array`): a PRNG key.
return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class
Returns:
Expand Down
3 changes: 2 additions & 1 deletion src/diffusers/schedulers/scheduling_karras_ve_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Optional, Tuple, Union

import flax
import jax
import jax.numpy as jnp
from jax import random

Expand Down Expand Up @@ -139,7 +140,7 @@ def add_noise_to_input(
state: KarrasVeSchedulerState,
sample: jnp.ndarray,
sigma: float,
key: random.KeyArray,
key: jax.Array,
) -> Tuple[jnp.ndarray, float]:
"""
Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
Expand Down
5 changes: 3 additions & 2 deletions src/diffusers/schedulers/scheduling_sde_ve_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Optional, Tuple, Union

import flax
import jax
import jax.numpy as jnp
from jax import random

Expand Down Expand Up @@ -169,7 +170,7 @@ def step_pred(
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
key: random.KeyArray,
key: jax.Array,
return_dict: bool = True,
) -> Union[FlaxSdeVeOutput, Tuple]:
"""
Expand Down Expand Up @@ -228,7 +229,7 @@ def step_correct(
state: ScoreSdeVeSchedulerState,
model_output: jnp.ndarray,
sample: jnp.ndarray,
key: random.KeyArray,
key: jax.Array,
return_dict: bool = True,
) -> Union[FlaxSdeVeOutput, Tuple]:
"""
Expand Down

0 comments on commit a844065

Please sign in to comment.