diff --git a/.github/analytics/get_repo_metrics.py b/.github/analytics/get_repo_metrics.py index 5778c8c14c..cdd4d279ea 100644 --- a/.github/analytics/get_repo_metrics.py +++ b/.github/analytics/get_repo_metrics.py @@ -15,7 +15,7 @@ import json import os from datetime import datetime -from typing import Callable, List +from collections.abc import Callable import matplotlib.dates as mdates import matplotlib.pyplot as plt @@ -279,7 +279,7 @@ def _rolling_window( last_month = _start_of_month(df.iloc[-1]['created_at']) last_month = _shift_n_months(last_month, 1) - rows: List[pd.Series] = [] + rows: list[pd.Series] = [] while end < last_month: row = f(df[(df['created_at'] >= start) & (df['created_at'] < end)]) row['period_start'] = start diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 1a76e5a678..0e7ad6b95b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -30,7 +30,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: '3.10' - uses: pre-commit/action@v2.0.3 commit-count: name: Check commit count @@ -63,7 +63,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.9', '3.10', '3.11'] + python-version: ['3.10', '3.11'] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} @@ -82,18 +82,16 @@ jobs: runs-on: ubuntu-20.04-16core strategy: matrix: - python-version: ['3.9', '3.10', '3.11'] + python-version: ['3.10', '3.11'] test-type: [doctest, pytest, pytype, mypy] jax-version: [newest] exclude: - - test-type: pytype - python-version: '3.9' - test-type: pytype python-version: '3.10' - test-type: mypy python-version: '3.11' include: - - python-version: '3.9' + - python-version: '3.10' test-type: pytest jax-version: '0.4.27' # keep in sync with jax pin in pyproject.toml steps: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 205494138f..2860a1d920 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -44,3 +44,8 @@ repos: # Disable Ruff formatter for now # # Run the Ruff formatter. # - id: ruff-format +- repo: https://github.com/asottile/pyupgrade + rev: v3.16.0 + hooks: + - id: pyupgrade + args: [--py310-plus] diff --git a/docs/_ext/codediff.py b/docs/_ext/codediff.py index a459137428..3c0a8c0248 100644 --- a/docs/_ext/codediff.py +++ b/docs/_ext/codediff.py @@ -26,7 +26,6 @@ In order to highlight a line of code, append "#!" to it. """ -from typing import List, Optional, Tuple import sphinx from docutils import nodes @@ -40,10 +39,10 @@ class CodeDiffParser: def parse( self, - lines: List[str], + lines: list[str], title: str, - groups: Optional[List[str]] = None, - skip_test: Optional[str] = None, + groups: list[str] | None = None, + skip_test: str | None = None, code_sep: str = '---', sync: object = MISSING, ): @@ -104,7 +103,7 @@ def parse( sync = sync is not MISSING # skip legacy code snippets in upgrade guides if skip_test is not None: - skip_tests = set([index.strip() for index in skip_test.split(',')]) + skip_tests = {index.strip() for index in skip_test.split(',')} else: skip_tests = set() @@ -154,7 +153,7 @@ def _code_block(self, lines): # Indent code and add empty line so the code is picked up by the directive. return directive + [''] + list(map(lambda x: ' ' + x, code)) - def _tabs(self, *contents: Tuple[str, List[str]], sync): + def _tabs(self, *contents: tuple[str, list[str]], sync): output = ['.. 
tab-set::'] + [' '] for title, content in contents: diff --git a/docs/conf_sphinx_patch.py b/docs/conf_sphinx_patch.py index 2c99459f37..a423b79405 100644 --- a/docs/conf_sphinx_patch.py +++ b/docs/conf_sphinx_patch.py @@ -23,7 +23,7 @@ # # We should consider sending a PR to sphinx so we can get rid of this. # Original source: https://github.com/sphinx-doc/sphinx/blob/0aedcc9a916daa92d477226da67d33ce1831822e/sphinx/ext/autosummary/generate.py#L211-L351 -from typing import Any, Dict, List, Set, Tuple +from typing import Any import sphinx.ext.autodoc import sphinx.ext.autosummary.generate as ag @@ -38,7 +38,7 @@ def generate_autosummary_content( imported_members: bool, app: Any, recursive: bool, - context: Dict, + context: dict, modname: str = None, qualname: str = None, ) -> str: @@ -61,13 +61,13 @@ def skip_member(obj: Any, name: str, objtype: str) -> bool: ) return False - def get_class_members(obj: Any) -> Dict[str, Any]: + def get_class_members(obj: Any) -> dict[str, Any]: members = sphinx.ext.autodoc.get_class_members( obj, [qualname], ag.safe_getattr ) return {name: member.object for name, member in members.items()} - def get_module_members(obj: Any) -> Dict[str, Any]: + def get_module_members(obj: Any) -> dict[str, Any]: members = {} for name in ag.members_of(obj, app.config): try: @@ -76,7 +76,7 @@ def get_module_members(obj: Any) -> Dict[str, Any]: continue return members - def get_all_members(obj: Any) -> Dict[str, Any]: + def get_all_members(obj: Any) -> dict[str, Any]: if doc.objtype == 'module': return get_module_members(obj) elif doc.objtype == 'class': @@ -85,12 +85,12 @@ def get_all_members(obj: Any) -> Dict[str, Any]: def get_members( obj: Any, - types: Set[str], - include_public: List[str] = [], + types: set[str], + include_public: list[str] = [], imported: bool = True, - ) -> Tuple[List[str], List[str]]: - items: List[str] = [] - public: List[str] = [] + ) -> tuple[list[str], list[str]]: + items: list[str] = [] + public: list[str] = [] all_members = get_all_members(obj) for name, value in all_members.items(): @@ -112,7 +112,7 @@ def get_members( public.append(name) return public, items - def get_module_attrs(members: Any) -> Tuple[List[str], List[str]]: + def get_module_attrs(members: Any) -> tuple[list[str], list[str]]: """Find module attributes with docstrings.""" attrs, public = [], [] try: @@ -127,8 +127,8 @@ def get_module_attrs(members: Any) -> Tuple[List[str], List[str]]: pass # give up if ModuleAnalyzer fails to parse code return public, attrs - def get_modules(obj: Any) -> Tuple[List[str], List[str]]: - items: List[str] = [] + def get_modules(obj: Any) -> tuple[list[str], list[str]]: + items: list[str] = [] for _, modname, _ispkg in ag.pkgutil.iter_modules(obj.__path__): fullname = name + '.' 
+ modname try: @@ -142,7 +142,7 @@ def get_modules(obj: Any) -> Tuple[List[str], List[str]]: public = [x for x in items if not x.split('.')[-1].startswith('_')] return public, items - ns: Dict[str, Any] = {} + ns: dict[str, Any] = {} ns.update(context) if doc.objtype == 'module': diff --git a/examples/cloud/launch_gce.py b/examples/cloud/launch_gce.py index 9a8f33a78e..eca9b1301f 100644 --- a/examples/cloud/launch_gce.py +++ b/examples/cloud/launch_gce.py @@ -20,7 +20,7 @@ import re import subprocess import time -from typing import Sequence +from collections.abc import Sequence from absl import app from absl import flags diff --git a/examples/imagenet/models.py b/examples/imagenet/models.py index 362f97d52d..f41ed6ad74 100644 --- a/examples/imagenet/models.py +++ b/examples/imagenet/models.py @@ -18,7 +18,8 @@ # pytype: disable=wrong-arg-count from functools import partial -from typing import Any, Callable, Sequence, Tuple +from typing import Any, Tuple +from collections.abc import Callable, Sequence from flax import linen as nn import jax.numpy as jnp @@ -33,7 +34,7 @@ class ResNetBlock(nn.Module): conv: ModuleDef norm: ModuleDef act: Callable - strides: Tuple[int, int] = (1, 1) + strides: tuple[int, int] = (1, 1) @nn.compact def __call__( @@ -63,7 +64,7 @@ class BottleneckResNetBlock(nn.Module): conv: ModuleDef norm: ModuleDef act: Callable - strides: Tuple[int, int] = (1, 1) + strides: tuple[int, int] = (1, 1) @nn.compact def __call__(self, x): diff --git a/examples/linen_design_test/attention_simple.py b/examples/linen_design_test/attention_simple.py index 5d51672234..05d53788b5 100644 --- a/examples/linen_design_test/attention_simple.py +++ b/examples/linen_design_test/attention_simple.py @@ -14,7 +14,8 @@ import functools from pprint import pprint -from typing import Any, Callable, Optional, Sequence +from typing import Any, Optional +from collections.abc import Callable, Sequence from flax.core.frozen_dict import unfreeze from flax.linen import initializers from flax.linen import Module, compact, vmap @@ -112,8 +113,8 @@ def __call__(self, query, key, value, bias=None, dtype=jnp.float32): class DotProductAttention(Module): - qkv_features: Optional[int] = None - out_features: Optional[int] = None + qkv_features: int | None = None + out_features: int | None = None attn_module: Callable = SoftmaxAttn @compact @@ -154,8 +155,8 @@ def concise_vmap(module, in_axes, out_axes, axis_size=None, **var_specs): class MultiHeadDotProductAttention(Module): - qkv_features: Optional[int] = None - out_features: Optional[int] = None + qkv_features: int | None = None + out_features: int | None = None attn_module: Callable = SoftmaxAttn batch_axes: Sequence[int] = (0,) num_heads: int = 1 diff --git a/examples/linen_design_test/autoencoder.py b/examples/linen_design_test/autoencoder.py index d5d3d925f2..63680eca09 100644 --- a/examples/linen_design_test/autoencoder.py +++ b/examples/linen_design_test/autoencoder.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Iterable, Tuple +from typing import Tuple +from collections.abc import Iterable import jax from jax import numpy as jnp, random @@ -37,7 +38,7 @@ def __call__(self, x): class AutoEncoder(Module): encoder_widths: Iterable decoder_widths: Iterable - input_shape: Tuple = None + input_shape: tuple = None def setup(self): # Submodules attached in `setup` get names via attribute assignment diff --git a/examples/linen_design_test/dense.py b/examples/linen_design_test/dense.py index b34e83174c..45c9582155 100644 --- a/examples/linen_design_test/dense.py +++ b/examples/linen_design_test/dense.py @@ -14,7 +14,7 @@ from jax import lax from flax.linen import initializers -from typing import Callable +from collections.abc import Callable from flax.linen import Module, compact diff --git a/examples/linen_design_test/mlp_explicit.py b/examples/linen_design_test/mlp_explicit.py index 533fa54457..adc0e75693 100644 --- a/examples/linen_design_test/mlp_explicit.py +++ b/examples/linen_design_test/mlp_explicit.py @@ -24,7 +24,7 @@ # Add `in_features` to the built-in Dense layer that normally works # via shape inference. class DenseExplicit(Dense): - in_features: Optional[int] = None + in_features: int | None = None def setup(self): # We feed a fake batch through the module, which initialized parameters. diff --git a/examples/linen_design_test/mlp_inline.py b/examples/linen_design_test/mlp_inline.py index 77dbf20a09..4595a36b19 100644 --- a/examples/linen_design_test/mlp_inline.py +++ b/examples/linen_design_test/mlp_inline.py @@ -15,7 +15,7 @@ import jax from jax import numpy as jnp from flax import linen as nn -from typing import Iterable +from collections.abc import Iterable from flax.linen import Module, compact from dense import Dense diff --git a/examples/lm1b/input_pipeline.py b/examples/lm1b/input_pipeline.py index 847ab0d5cb..94d531e306 100644 --- a/examples/lm1b/input_pipeline.py +++ b/examples/lm1b/input_pipeline.py @@ -25,7 +25,7 @@ import tokenizer AUTOTUNE = tf.data.experimental.AUTOTUNE -Features = Dict[str, tf.Tensor] +Features = dict[str, tf.Tensor] class NormalizeFeatureNamesOp: @@ -68,8 +68,8 @@ def get_raw_dataset( def pack_dataset( dataset: tf.data.Dataset, - key2length: Union[int, Dict[str, int]], - keys: Optional[List[str]] = None, + key2length: int | dict[str, int], + keys: list[str] | None = None, ) -> tf.data.Dataset: """Creates a 'packed' version of a dataset on-the-fly. @@ -150,7 +150,7 @@ def my_fn(x): def _pack_with_tf_ops( - dataset: tf.data.Dataset, keys: List[str], key2length: Dict[str, int] + dataset: tf.data.Dataset, keys: list[str], key2length: dict[str, int] ) -> tf.data.Dataset: """Helper-function for packing a dataset which has already been batched. 
@@ -276,7 +276,7 @@ def true_fn(): def preprocess_data( dataset, shuffle: bool, - num_epochs: Optional[int] = 1, + num_epochs: int | None = 1, pack_examples: bool = True, shuffle_buffer_size: int = 1024, max_length: int = 512, @@ -322,7 +322,7 @@ def get_datasets( config: ml_collections.ConfigDict, *, n_devices: int, - vocab_path: Optional[str] = None, + vocab_path: str | None = None, ): """Load and return dataset of batched examples for use during training.""" if vocab_path is None: diff --git a/examples/lm1b/models.py b/examples/lm1b/models.py index 72ec1ac1cf..956c08f37a 100644 --- a/examples/lm1b/models.py +++ b/examples/lm1b/models.py @@ -23,7 +23,8 @@ # pytype: disable=wrong-keyword-args # pytype: disable=attribute-error -from typing import Callable, Any, Optional +from typing import Any, Optional +from collections.abc import Callable from flax import linen as nn from flax import struct @@ -53,7 +54,7 @@ class TransformerConfig: decode: bool = False kernel_init: Callable = nn.initializers.xavier_uniform() bias_init: Callable = nn.initializers.normal(stddev=1e-6) - posemb_init: Optional[Callable] = None + posemb_init: Callable | None = None def shift_right(x, axis=1): @@ -176,7 +177,7 @@ class MlpBlock(nn.Module): """ config: TransformerConfig - out_dim: Optional[int] = None + out_dim: int | None = None @nn.compact def __call__(self, inputs): diff --git a/examples/lm1b/tokenizer.py b/examples/lm1b/tokenizer.py index 4d8d641b3c..6f0c77be97 100644 --- a/examples/lm1b/tokenizer.py +++ b/examples/lm1b/tokenizer.py @@ -17,7 +17,8 @@ import os import tempfile import time -from typing import Any, Dict, Iterable, Tuple +from typing import Any, Dict, Tuple +from collections.abc import Iterable from absl import logging import dataclasses @@ -26,14 +27,14 @@ import tensorflow as tf import tensorflow_text as tftxt -Features = Dict[str, tf.Tensor] +Features = dict[str, tf.Tensor] def _dump_chars_to_textfile( dataset: tf.data.Dataset, maxchars: int = int(1e7), data_keys=('inputs', 'targets'), -) -> Tuple[str, int]: +) -> tuple[str, int]: """Write part of a TFDS sentence dataset to lines in a text file. 
Args: @@ -138,7 +139,7 @@ def load_or_train_tokenizer( vocab_path: str, vocab_size: int, max_corpus_chars: int, - data_keys: Tuple[str, str] = ('inputs', 'targets'), + data_keys: tuple[str, str] = ('inputs', 'targets'), ): """Loads the tokenizer at `vocab_path` or trains a one from `dataset`.""" try: diff --git a/examples/nlp_seq/models.py b/examples/nlp_seq/models.py index a2158323ad..cc0ed5a8d9 100644 --- a/examples/nlp_seq/models.py +++ b/examples/nlp_seq/models.py @@ -14,7 +14,8 @@ """Transformer-based language models.""" -from typing import Callable, Any, Optional +from typing import Any, Optional +from collections.abc import Callable from flax import linen as nn from flax import struct @@ -39,7 +40,7 @@ class TransformerConfig: attention_dropout_rate: float = 0.3 kernel_init: Callable = nn.initializers.xavier_uniform() bias_init: Callable = nn.initializers.normal(stddev=1e-6) - posemb_init: Optional[Callable] = None + posemb_init: Callable | None = None def sinusoidal_init(max_len=2048): @@ -121,7 +122,7 @@ class MlpBlock(nn.Module): """ config: TransformerConfig - out_dim: Optional[int] = None + out_dim: int | None = None @nn.compact def __call__(self, inputs, deterministic=True): diff --git a/examples/ogbg_molpcba/input_pipeline.py b/examples/ogbg_molpcba/input_pipeline.py index be4f1e3643..538cddc700 100644 --- a/examples/ogbg_molpcba/input_pipeline.py +++ b/examples/ogbg_molpcba/input_pipeline.py @@ -31,7 +31,7 @@ class GraphsTupleSize(NamedTuple): n_graph: int -def get_raw_datasets() -> Dict[str, tf.data.Dataset]: +def get_raw_datasets() -> dict[str, tf.data.Dataset]: """Returns datasets as tf.data.Dataset, organized by split.""" ds_builder = tfds.builder('ogbg_molpcba') ds_builder.download_and_prepare() @@ -45,7 +45,7 @@ def get_datasets( add_virtual_node: bool = True, add_undirected_edges: bool = True, add_self_loops: bool = True, -) -> Dict[str, tf.data.Dataset]: +) -> dict[str, tf.data.Dataset]: """Returns datasets of batched GraphsTuples, organized by split.""" if batch_size <= 1: raise ValueError('Batch size must be > 1 to account for padding graphs.') @@ -109,7 +109,7 @@ def get_datasets( def convert_to_graphs_tuple( - graph: Dict[str, tf.Tensor], + graph: dict[str, tf.Tensor], add_virtual_node: bool, add_undirected_edges: bool, add_self_loops: bool, diff --git a/examples/ogbg_molpcba/models.py b/examples/ogbg_molpcba/models.py index da312c6efa..d6b462dd66 100644 --- a/examples/ogbg_molpcba/models.py +++ b/examples/ogbg_molpcba/models.py @@ -14,7 +14,7 @@ """Definition of the GNN model.""" -from typing import Callable, Sequence +from collections.abc import Callable, Sequence from flax import linen as nn import jax.numpy as jnp diff --git a/examples/ogbg_molpcba/train.py b/examples/ogbg_molpcba/train.py index e64b436f4c..ab1c20e873 100644 --- a/examples/ogbg_molpcba/train.py +++ b/examples/ogbg_molpcba/train.py @@ -15,7 +15,8 @@ """Library file for executing training and evaluation on ogbg-molpcba.""" import os -from typing import Any, Dict, Iterable, Tuple, Optional +from typing import Any, Dict, Tuple, Optional +from collections.abc import Iterable from absl import logging from clu import checkpoint @@ -111,7 +112,7 @@ def predictions_match_labels( return (preds == labels).astype(jnp.float32) -def add_prefix_to_keys(result: Dict[str, Any], prefix: str) -> Dict[str, Any]: +def add_prefix_to_keys(result: dict[str, Any], prefix: str) -> dict[str, Any]: """Adds a prefix to the keys of a dict, returning a new dict.""" return {f'{prefix}_{key}': val for key, val in 
result.items()} @@ -172,7 +173,7 @@ def replace_globals(graphs: jraph.GraphsTuple) -> jraph.GraphsTuple: def get_predicted_logits( state: train_state.TrainState, graphs: jraph.GraphsTuple, - rngs: Optional[Dict[str, jnp.ndarray]], + rngs: dict[str, jnp.ndarray] | None, ) -> jnp.ndarray: """Get predicted logits from the network for input graphs.""" pred_graphs = state.apply_fn(state.params, graphs, rngs=rngs) @@ -202,8 +203,8 @@ def get_valid_mask( def train_step( state: train_state.TrainState, graphs: jraph.GraphsTuple, - rngs: Dict[str, jnp.ndarray], -) -> Tuple[train_state.TrainState, metrics.Collection]: + rngs: dict[str, jnp.ndarray], +) -> tuple[train_state.TrainState, metrics.Collection]: """Performs one update step over the current batch of graphs.""" def loss_fn(params, graphs): @@ -264,9 +265,9 @@ def evaluate_step( def evaluate_model( state: train_state.TrainState, - datasets: Dict[str, tf.data.Dataset], + datasets: dict[str, tf.data.Dataset], splits: Iterable[str], -) -> Dict[str, metrics.Collection]: +) -> dict[str, metrics.Collection]: """Evaluates the model on metrics over the specified splits.""" # Loop over each split independently. diff --git a/examples/ogbg_molpcba/train_test.py b/examples/ogbg_molpcba/train_test.py index c33da1b02b..e8200bf9d5 100644 --- a/examples/ogbg_molpcba/train_test.py +++ b/examples/ogbg_molpcba/train_test.py @@ -43,7 +43,7 @@ def average_with_mask(arr: jnp.ndarray, mask: jnp.ndarray): return jnp.sum(arr) / jnp.sum(mask) -def get_dummy_raw_datasets(dataset_length) -> Dict[str, tf.data.Dataset]: +def get_dummy_raw_datasets(dataset_length) -> dict[str, tf.data.Dataset]: """Returns dummy datasets, mocking tfds.DatasetBuilder.as_dataset().""" # The dummy graph. @@ -79,8 +79,8 @@ def get_dummy_graphs(): def get_dummy_datasets( - dataset_length: int, batch_size: Optional[int] = None -) -> Dict[str, tf.data.Dataset]: + dataset_length: int, batch_size: int | None = None +) -> dict[str, tf.data.Dataset]: """Returns dummy datasets, mocking input_pipeline.get_datasets().""" datasets = get_dummy_raw_datasets(dataset_length) diff --git a/examples/ppo/agent.py b/examples/ppo/agent.py index c10d105e66..85dd2838d8 100644 --- a/examples/ppo/agent.py +++ b/examples/ppo/agent.py @@ -17,7 +17,8 @@ import collections import functools import multiprocessing -from typing import Any, Callable +from typing import Any +from collections.abc import Callable import flax import jax diff --git a/examples/ppo/ppo_lib.py b/examples/ppo/ppo_lib.py index 4c127199e0..942dffd594 100644 --- a/examples/ppo/ppo_lib.py +++ b/examples/ppo/ppo_lib.py @@ -15,7 +15,8 @@ """Library file which executes the PPO training.""" import functools -from typing import Any, Callable +from typing import Any +from collections.abc import Callable from absl import logging import flax diff --git a/examples/ppo/seed_rl_atari_preprocessing.py b/examples/ppo/seed_rl_atari_preprocessing.py index 67d03e3ad7..e8519357fb 100644 --- a/examples/ppo/seed_rl_atari_preprocessing.py +++ b/examples/ppo/seed_rl_atari_preprocessing.py @@ -74,7 +74,7 @@ def __init__( """ if frame_skip <= 0: raise ValueError( - 'Frame skip should be strictly positive, got {}'.format(frame_skip) + f'Frame skip should be strictly positive, got {frame_skip}' ) if screen_size <= 0: raise ValueError( diff --git a/examples/ppo/test_episodes.py b/examples/ppo/test_episodes.py index 3715da5bcc..b59bf683ae 100644 --- a/examples/ppo/test_episodes.py +++ b/examples/ppo/test_episodes.py @@ -15,7 +15,8 @@ """Test policy by playing a full Atari game.""" 
import itertools -from typing import Any, Callable +from typing import Any +from collections.abc import Callable import flax import numpy as np diff --git a/examples/seq2seq/input_pipeline.py b/examples/seq2seq/input_pipeline.py index 3d0e72caf9..72794a3969 100644 --- a/examples/seq2seq/input_pipeline.py +++ b/examples/seq2seq/input_pipeline.py @@ -15,7 +15,8 @@ """Input pipeline for seq2seq addition example.""" import random -from typing import Any, Dict, Generator, Optional, Tuple +from typing import Any, Dict, Optional, Tuple +from collections.abc import Generator import jax.numpy as jnp import numpy as np @@ -63,11 +64,11 @@ def max_output_len(self) -> int: return self._max_len_query_digit + 3 @property - def encoder_input_shape(self) -> Tuple[int, int, int]: + def encoder_input_shape(self) -> tuple[int, int, int]: return (1, self.max_input_len, self.vocab_size) @property - def decoder_input_shape(self) -> Tuple[int, int, int]: + def decoder_input_shape(self) -> tuple[int, int, int]: return (1, self.max_output_len, self.vocab_size) def encode(self, inputs: str) -> np.ndarray: @@ -91,7 +92,7 @@ def one_hot(self, tokens: np.ndarray) -> np.ndarray: return vecs def encode_onehot( - self, batch_inputs: Array, max_len: Optional[int] = None + self, batch_inputs: Array, max_len: int | None = None ) -> np.ndarray: """One-hot encodes a string input.""" @@ -115,7 +116,7 @@ def decode_onehot(self, batch_inputs: Array) -> np.ndarray: def generate_examples( self, num_examples: int - ) -> Generator[Tuple[str, str], None, None]: + ) -> Generator[tuple[str, str], None, None]: """Yields `num_examples` examples.""" for _ in range(num_examples): max_digit = pow(10, self._max_len_query_digit) - 1 @@ -126,7 +127,7 @@ def generate_examples( outputs = '=' + str(key[0] + key[1]) yield (inputs, outputs) - def get_batch(self, batch_size: int) -> Dict[str, np.ndarray]: + def get_batch(self, batch_size: int) -> dict[str, np.ndarray]: """Returns a batch of example of size @batch_size.""" inputs, outputs = zip(*self.generate_examples(batch_size)) return { diff --git a/examples/seq2seq/models.py b/examples/seq2seq/models.py index 1082e2a906..00bdbcb561 100644 --- a/examples/seq2seq/models.py +++ b/examples/seq2seq/models.py @@ -25,7 +25,7 @@ Array = jax.Array PRNGKey = jax.Array -LSTMCarry = Tuple[Array, Array] +LSTMCarry = tuple[Array, Array] class DecoderLSTMCell(nn.RNNCellBase): @@ -42,8 +42,8 @@ class DecoderLSTMCell(nn.RNNCellBase): @nn.compact def __call__( - self, carry: Tuple[LSTMCarry, Array], x: Array - ) -> Tuple[Tuple[LSTMCarry, Array], Tuple[Array, Array]]: + self, carry: tuple[LSTMCarry, Array], x: Array + ) -> tuple[tuple[LSTMCarry, Array], tuple[Array, Array]]: """Applies the DecoderLSTM model.""" lstm_state, last_prediction = carry if not self.teacher_force: @@ -87,7 +87,7 @@ class Seq2seq(nn.Module): @nn.compact def __call__( self, encoder_inputs: Array, decoder_inputs: Array - ) -> Tuple[Array, Array]: + ) -> tuple[Array, Array]: """Applies the seq2seq model. 
Args: diff --git a/examples/seq2seq/train.py b/examples/seq2seq/train.py index 09116668cd..3fcfd0497e 100644 --- a/examples/seq2seq/train.py +++ b/examples/seq2seq/train.py @@ -82,7 +82,7 @@ def get_model(ctable: CTable, *, teacher_force: bool = False) -> models.Seq2seq: def get_initial_params( model: models.Seq2seq, rng: PRNGKey, ctable: CTable -) -> Dict[str, Any]: +) -> dict[str, Any]: """Returns the initial parameters of a seq2seq model.""" rng1, rng2 = jax.random.split(rng) variables = model.init( @@ -115,7 +115,7 @@ def cross_entropy_loss( def compute_metrics( logits: Array, labels: Array, eos_id: int -) -> Dict[str, jax.Array]: +) -> dict[str, jax.Array]: """Computes metrics and returns them.""" lengths = get_sequence_lengths(labels, eos_id) loss = cross_entropy_loss(logits, labels, lengths) @@ -136,7 +136,7 @@ def compute_metrics( @jax.jit def train_step( state: train_state.TrainState, batch: Array, lstm_rng: PRNGKey, eos_id: int -) -> Tuple[train_state.TrainState, Dict[str, jax.Array]]: +) -> tuple[train_state.TrainState, dict[str, jax.Array]]: """Trains one step.""" labels = batch['answer'][:, 1:] lstm_key = jax.random.fold_in(lstm_rng, state.step) @@ -171,7 +171,7 @@ def log_decode(question: str, inferred: str, golden: str): @functools.partial(jax.jit, static_argnums=3) def decode( - params: Dict[str, Any], inputs: Array, decode_rng: PRNGKey, ctable: CTable + params: dict[str, Any], inputs: Array, decode_rng: PRNGKey, ctable: CTable ) -> Array: """Decodes inputs.""" init_decoder_input = ctable.one_hot(ctable.encode('=')[0:1]) @@ -187,7 +187,7 @@ def decode( def decode_batch( state: train_state.TrainState, - batch: Dict[str, Array], + batch: dict[str, Array], decode_rng: PRNGKey, ctable: CTable, ): diff --git a/examples/sst2/build_vocabulary.py b/examples/sst2/build_vocabulary.py index 1bea244ef2..2b3059dba6 100755 --- a/examples/sst2/build_vocabulary.py +++ b/examples/sst2/build_vocabulary.py @@ -15,7 +15,7 @@ """A vocabulary builder that generates vocab.txt to be used for training.""" import time -from typing import Iterable, Sequence +from collections.abc import Iterable, Sequence from absl import logging import tensorflow as tf diff --git a/examples/sst2/input_pipeline.py b/examples/sst2/input_pipeline.py index a37a6772e8..dba5ba8dd7 100755 --- a/examples/sst2/input_pipeline.py +++ b/examples/sst2/input_pipeline.py @@ -26,7 +26,7 @@ AUTOTUNE = tf.data.experimental.AUTOTUNE -Example = Dict[str, tf.Tensor] +Example = dict[str, tf.Tensor] def get_bucket_boundaries(bucket_size: int, max_size: int) -> np.ndarray: @@ -64,7 +64,7 @@ def get_bucketed_batches( padded_shapes: Any, example_size_fn: Any, shuffle: bool = False, - shuffle_seed: Optional[int] = None, + shuffle_seed: int | None = None, drop_remainder: bool = False, ) -> tf.data.Dataset: """Returns padded batches of shuffled examples bucketed by length. 
@@ -230,9 +230,9 @@ def get_batches( batch_size: int, drop_remainder: bool = False, shuffle: bool = False, - shuffle_seed: Optional[int] = None, - fixed_pad_length: Optional[int] = None, - dataset: Optional[tf.data.Dataset] = None, + shuffle_seed: int | None = None, + fixed_pad_length: int | None = None, + dataset: tf.data.Dataset | None = None, ): """Returns an iterator with padded batches for the provided dataset.""" if dataset is None: @@ -256,8 +256,8 @@ def get_bucketed_batches( max_input_length: int, drop_remainder: bool = False, shuffle: bool = False, - shuffle_seed: Optional[int] = None, - dataset: Optional[tf.data.Dataset] = None, + shuffle_seed: int | None = None, + dataset: tf.data.Dataset | None = None, ): """Returns an iterator with bucketed batches for the provided dataset.""" if dataset is None: diff --git a/examples/sst2/models.py b/examples/sst2/models.py index 68158ed68c..b7156e62c5 100644 --- a/examples/sst2/models.py +++ b/examples/sst2/models.py @@ -15,7 +15,8 @@ """A text classification model.""" import functools -from typing import Any, Callable, Optional +from typing import Any, Optional +from collections.abc import Callable from flax import linen as nn import jax @@ -88,10 +89,10 @@ class WordDropout(nn.Module): dropout_rate: float unk_idx: int - deterministic: Optional[bool] = None + deterministic: bool | None = None @nn.compact - def __call__(self, inputs: Array, deterministic: Optional[bool] = None): + def __call__(self, inputs: Array, deterministic: bool | None = None): deterministic = nn.module.merge_param( 'deterministic', self.deterministic, deterministic ) @@ -121,8 +122,8 @@ class Embedder(nn.Module): frozen: bool = False dropout_rate: float = 0.0 word_dropout_rate: float = 0.0 - unk_idx: Optional[int] = None - deterministic: Optional[bool] = None + unk_idx: int | None = None + deterministic: bool | None = None dtype: jnp.dtype = jnp.float32 def setup(self): @@ -138,7 +139,7 @@ def setup(self): ) def __call__( - self, inputs: Array, deterministic: Optional[bool] = None + self, inputs: Array, deterministic: bool | None = None ) -> Array: """Embeds the input sequences and applies word dropout and dropout. @@ -232,14 +233,14 @@ class MLP(nn.Module): activation: Callable[..., Any] = nn.tanh dropout_rate: float = 0.0 output_bias: bool = False - deterministic: Optional[bool] = None + deterministic: bool | None = None def setup(self): self.intermediate_layer = nn.Dense(self.hidden_size) self.output_layer = nn.Dense(self.output_size, use_bias=self.output_bias) self.dropout_layer = nn.Dropout(rate=self.dropout_rate) - def __call__(self, inputs: Array, deterministic: Optional[bool] = None): + def __call__(self, inputs: Array, deterministic: bool | None = None): """Applies the MLP to the last dimension of the inputs. Args: @@ -319,7 +320,7 @@ class AttentionClassifier(nn.Module): hidden_size: int output_size: int dropout_rate: float = 0.0 - deterministic: Optional[bool] = None + deterministic: bool | None = None def setup(self): self.dropout_layer = nn.Dropout(rate=self.dropout_rate) @@ -337,7 +338,7 @@ def __call__( self, encoded_inputs: Array, lengths: Array, - deterministic: Optional[bool] = None, + deterministic: bool | None = None, ) -> Array: """Applies model to the encoded inputs. 
@@ -383,7 +384,7 @@ class TextClassifier(nn.Module): dropout_rate: float word_dropout_rate: float unk_idx: int = 1 - deterministic: Optional[bool] = None + deterministic: bool | None = None def setup(self): self.embedder = Embedder( @@ -401,7 +402,7 @@ def setup(self): ) def embed_token_ids( - self, token_ids: Array, deterministic: Optional[bool] = None + self, token_ids: Array, deterministic: bool | None = None ) -> Array: deterministic = nn.module.merge_param( 'deterministic', self.deterministic, deterministic @@ -412,7 +413,7 @@ def logits_from_embedded_inputs( self, embedded_inputs: Array, lengths: Array, - deterministic: Optional[bool] = None, + deterministic: bool | None = None, ) -> Array: deterministic = nn.module.merge_param( 'deterministic', self.deterministic, deterministic @@ -424,7 +425,7 @@ def __call__( self, token_ids: Array, lengths: Array, - deterministic: Optional[bool] = None, + deterministic: bool | None = None, ) -> Array: """Embeds the token IDs, encodes them, and classifies with attention.""" embedded_inputs = self.embed_token_ids( diff --git a/examples/sst2/train.py b/examples/sst2/train.py index 0c6fc76d7e..24543a615b 100644 --- a/examples/sst2/train.py +++ b/examples/sst2/train.py @@ -13,7 +13,8 @@ # limitations under the License. """Trains an SST2 text classifier.""" -from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union +from collections.abc import Callable, Iterable, Sequence from absl import logging from flax import struct @@ -31,7 +32,7 @@ Array = jnp.ndarray -Example = Dict[str, Array] +Example = dict[str, Array] TrainState = train_state.TrainState @@ -40,7 +41,7 @@ class Metrics(struct.PyTreeNode): loss: float accuracy: float - count: Optional[int] = None + count: int | None = None @jax.vmap @@ -102,9 +103,9 @@ def model_from_config(config: ml_collections.ConfigDict): def train_step( state: TrainState, - batch: Dict[str, Array], - rngs: Dict[str, Any], -) -> Tuple[TrainState, Metrics]: + batch: dict[str, Array], + rngs: dict[str, Any], +) -> tuple[TrainState, Metrics]: """Train for a single step.""" # Make sure to get a new RNG at every step. step = state.step @@ -138,7 +139,7 @@ def loss_fn(params): def eval_step( - state: TrainState, batch: Dict[str, Array], rngs: Dict[str, Any] + state: TrainState, batch: dict[str, Array], rngs: dict[str, Any] ) -> Metrics: """Evaluate for a single step. Model should be in deterministic mode.""" variables = {'params': state.params} @@ -165,7 +166,7 @@ def normalize_batch_metrics(batch_metrics: Sequence[Metrics]) -> Metrics: ) -def batch_to_numpy(batch: Dict[str, tf.Tensor]) -> Dict[str, Array]: +def batch_to_numpy(batch: dict[str, tf.Tensor]) -> dict[str, Array]: """Converts a batch with TF tensors to a batch of NumPy arrays.""" # _numpy() reuses memory, does not make a copy. 
# pylint: disable=protected-access @@ -175,9 +176,9 @@ def batch_to_numpy(batch: Dict[str, tf.Tensor]) -> Dict[str, Array]: def evaluate_model( eval_step_fn: Callable[..., Any], state: TrainState, - batches: Union[Iterable[Example], tf.data.Dataset], + batches: Iterable[Example] | tf.data.Dataset, epoch: int, - rngs: Optional[Dict[str, Any]] = None, + rngs: dict[str, Any] | None = None, ) -> Metrics: """Evaluate a model on a dataset.""" batch_metrics = [] @@ -201,12 +202,12 @@ def evaluate_model( def train_epoch( - train_step_fn: Callable[..., Tuple[TrainState, Metrics]], + train_step_fn: Callable[..., tuple[TrainState, Metrics]], state: TrainState, train_batches: tf.data.Dataset, epoch: int, - rngs: Optional[Dict[str, Any]] = None, -) -> Tuple[TrainState, Metrics]: + rngs: dict[str, Any] | None = None, +) -> tuple[TrainState, Metrics]: """Train for a single epoch.""" batch_metrics = [] for batch in train_batches: diff --git a/examples/sst2/vocabulary.py b/examples/sst2/vocabulary.py index 30f34ad13c..93a7a8a8b5 100755 --- a/examples/sst2/vocabulary.py +++ b/examples/sst2/vocabulary.py @@ -15,7 +15,8 @@ """A vocabulary that represents the tokens in a dataset and maps them to indices.""" import collections -from typing import Iterable, Optional, Sequence +from typing import Optional +from collections.abc import Iterable, Sequence from absl import logging @@ -25,8 +26,8 @@ class Vocabulary: def __init__( self, - vocab_path: Optional[str] = None, - tokenized_sequences: Optional[Iterable[Sequence[bytes]]] = None, + vocab_path: str | None = None, + tokenized_sequences: Iterable[Sequence[bytes]] | None = None, min_freq: int = 1, pad_token: bytes = b'', unk_token: bytes = b'', diff --git a/examples/wmt/input_pipeline.py b/examples/wmt/input_pipeline.py index b7178ab247..04d1da9c7c 100644 --- a/examples/wmt/input_pipeline.py +++ b/examples/wmt/input_pipeline.py @@ -26,7 +26,7 @@ AUTOTUNE = tf.data.AUTOTUNE -Features = Dict[str, tf.Tensor] +Features = dict[str, tf.Tensor] class NormalizeFeatureNamesOp: @@ -78,8 +78,8 @@ def get_raw_dataset( def pack_dataset( dataset: tf.data.Dataset, - key2length: Union[int, Dict[str, int]], - keys: Optional[List[str]] = None, + key2length: int | dict[str, int], + keys: list[str] | None = None, ) -> tf.data.Dataset: """Creates a 'packed' version of a dataset on-the-fly. @@ -160,7 +160,7 @@ def my_fn(x): def _pack_with_tf_ops( - dataset: tf.data.Dataset, keys: List[str], key2length: Dict[str, int] + dataset: tf.data.Dataset, keys: list[str], key2length: dict[str, int] ) -> tf.data.Dataset: """Helper-function for packing a dataset which has already been batched. 
@@ -286,7 +286,7 @@ def true_fn(): def preprocess_wmt_data( dataset, shuffle: bool, - num_epochs: Optional[int] = 1, + num_epochs: int | None = 1, pack_examples: bool = True, shuffle_buffer_size: int = 1024, max_length: int = 512, @@ -333,7 +333,7 @@ def get_wmt_datasets( *, n_devices: int, reverse_translation: bool = True, - vocab_path: Optional[str] = None, + vocab_path: str | None = None, ): """Load and return dataset of batched examples for use during training.""" if vocab_path is None: diff --git a/examples/wmt/models.py b/examples/wmt/models.py index cb80ed0099..5da0f70651 100644 --- a/examples/wmt/models.py +++ b/examples/wmt/models.py @@ -20,7 +20,8 @@ # pytype: disable=wrong-keyword-args # pytype: disable=attribute-error -from typing import Callable, Any, Optional +from typing import Any, Optional +from collections.abc import Callable from flax import linen as nn from flax import struct @@ -50,7 +51,7 @@ class TransformerConfig: decode: bool = False kernel_init: Callable = nn.initializers.xavier_uniform() bias_init: Callable = nn.initializers.normal(stddev=1e-6) - posemb_init: Optional[Callable] = None + posemb_init: Callable | None = None def shift_right(x, axis=1): @@ -163,7 +164,7 @@ class MlpBlock(nn.Module): """ config: TransformerConfig - out_dim: Optional[int] = None + out_dim: int | None = None @nn.compact def __call__(self, inputs): diff --git a/examples/wmt/tokenizer.py b/examples/wmt/tokenizer.py index eff381c124..4188a33877 100644 --- a/examples/wmt/tokenizer.py +++ b/examples/wmt/tokenizer.py @@ -18,7 +18,8 @@ import os import tempfile import time -from typing import Any, Dict, Iterable, Tuple +from typing import Any, Dict, Tuple +from collections.abc import Iterable from absl import logging import jax @@ -27,14 +28,14 @@ from sentencepiece import SentencePieceTrainer -Features = Dict[str, tf.Tensor] +Features = dict[str, tf.Tensor] def _dump_chars_to_textfile( dataset: tf.data.Dataset, maxchars: int = int(1e7), data_keys=('inputs', 'targets'), -) -> Tuple[str, int]: +) -> tuple[str, int]: """Write part of a TFDS sentence dataset to lines in a text file. 
Args: @@ -139,7 +140,7 @@ def load_or_train_tokenizer( vocab_path: str, vocab_size: int, max_corpus_chars: int, - data_keys: Tuple[str, str] = ('inputs', 'targets'), + data_keys: tuple[str, str] = ('inputs', 'targets'), ): """Loads the tokenizer at `vocab_path` or trains a one from `dataset`.""" try: diff --git a/flax/configurations.py b/flax/configurations.py index 98667521cb..4f61170f16 100644 --- a/flax/configurations.py +++ b/flax/configurations.py @@ -140,7 +140,7 @@ def static_bool_env(varname: str, default: bool) -> bool: return False else: raise ValueError( - 'invalid truth value {!r} for environment {!r}'.format(val, varname) + f'invalid truth value {val!r} for environment {varname!r}' ) diff --git a/flax/core/axes_scan.py b/flax/core/axes_scan.py index ec0f18d0d5..c495186d72 100644 --- a/flax/core/axes_scan.py +++ b/flax/core/axes_scan.py @@ -14,7 +14,8 @@ """Wrapper around jax.lax.scan with in_axes/out_axes API.""" import functools -from typing import Any, Callable, Optional +from typing import Any, Optional +from collections.abc import Callable import jax import jax.numpy as jnp @@ -37,7 +38,7 @@ def scan( fn: Callable[..., Any], in_axes: Any, out_axes: Any, - length: Optional[int] = None, + length: int | None = None, reverse: bool = False, unroll: int = 1, _split_transpose: bool = False diff --git a/flax/core/frozen_dict.py b/flax/core/frozen_dict.py index b78319a14f..13a195362f 100644 --- a/flax/core/frozen_dict.py +++ b/flax/core/frozen_dict.py @@ -16,7 +16,8 @@ import collections from types import MappingProxyType -from typing import Any, Dict, Hashable, Mapping, Tuple, TypeVar, Union +from typing import Any, TypeVar +from collections.abc import Hashable, Mapping import jax @@ -129,7 +130,7 @@ def items(self): for key in self._dict: yield (key, self[key]) - def pop(self, key: K) -> Tuple['FrozenDict[K, V]', V]: + def pop(self, key: K) -> tuple['FrozenDict[K, V]', V]: """Create a new FrozenDict where one entry is removed. Example:: @@ -149,7 +150,7 @@ def pop(self, key: K) -> Tuple['FrozenDict[K, V]', V]: new_self = type(self)(new_dict) return new_self, value - def unfreeze(self) -> Dict[K, V]: + def unfreeze(self) -> dict[K, V]: """Unfreeze this FrozenDict. Returns: @@ -157,7 +158,7 @@ def unfreeze(self) -> Dict[K, V]: """ return unfreeze(self) - def tree_flatten_with_keys(self) -> Tuple[Tuple[Any, ...], Hashable]: + def tree_flatten_with_keys(self) -> tuple[tuple[Any, ...], Hashable]: """Flattens this FrozenDict. Returns: @@ -201,7 +202,7 @@ def freeze(xs: Mapping[Any, Any]) -> FrozenDict[Any, Any]: return FrozenDict(xs) -def unfreeze(x: Union[FrozenDict, Dict[str, Any]]) -> Dict[Any, Any]: +def unfreeze(x: FrozenDict | dict[str, Any]) -> dict[Any, Any]: """Unfreeze a FrozenDict. Makes a mutable copy of a ``FrozenDict`` mutable by transforming @@ -228,9 +229,9 @@ def unfreeze(x: Union[FrozenDict, Dict[str, Any]]) -> Dict[Any, Any]: def copy( - x: Union[FrozenDict, Dict[str, Any]], - add_or_replace: Union[FrozenDict[str, Any], Dict[str, Any]] = FrozenDict({}), -) -> Union[FrozenDict, Dict[str, Any]]: + x: FrozenDict | dict[str, Any], + add_or_replace: FrozenDict[str, Any] | dict[str, Any] = FrozenDict({}), +) -> FrozenDict | dict[str, Any]: """Create a new dict with additional and/or replaced entries. This is a utility function that can act on either a FrozenDict or regular dict and mimics the behavior of ``FrozenDict.copy``. 
@@ -260,8 +261,8 @@ def copy( def pop( - x: Union[FrozenDict, Dict[str, Any]], key: str -) -> Tuple[Union[FrozenDict, Dict[str, Any]], Any]: + x: FrozenDict | dict[str, Any], key: str +) -> tuple[FrozenDict | dict[str, Any], Any]: """Create a new dict where one entry is removed. This is a utility function that can act on either a FrozenDict or regular dict and mimics the behavior of ``FrozenDict.pop``. diff --git a/flax/core/lift.py b/flax/core/lift.py index d254eb5cbf..e0de001f7a 100644 --- a/flax/core/lift.py +++ b/flax/core/lift.py @@ -18,17 +18,9 @@ import functools from typing import ( Any, - Callable, - Dict, - Iterable, - List, - Mapping, - Optional, - Sequence, - Tuple, TypeVar, - Union, ) +from collections.abc import Callable, Iterable, Mapping, Sequence import warnings from flax import traceback_util @@ -162,7 +154,7 @@ def _partial_pack( inner_rng_counters.append(rng_counters) rng_groups_xs_t = _transpose(rng_groups_xs) - inner_scopes: List[Scope] = [] + inner_scopes: list[Scope] = [] def scope_fn( variable_groups_xs_t, @@ -439,7 +431,7 @@ def vjp( vjp_variables: CollectionFilter = 'params', variables: CollectionFilter = True, rngs: PRNGSequenceFilter = True, -) -> Union[Tuple[Any, Callable[..., Any]], Tuple[Any, Callable[..., Any], Any]]: +) -> tuple[Any, Callable[..., Any]] | tuple[Any, Callable[..., Any], Any]: """A lifted version of ``jax.vjp``. See ``jax.vjp`` for the unlifted vector-Jacobian product (backward gradient). @@ -538,7 +530,7 @@ def value_and_grad( reduce_axes=(), variables: CollectionFilter = True, rngs: PRNGSequenceFilter = True, -) -> Union[Tuple[Any, Callable[..., Any]], Tuple[Any, Callable[..., Any], Any]]: +) -> tuple[Any, Callable[..., Any]] | tuple[Any, Callable[..., Any], Any]: """A limited lifted version of ``jax.value_and_grad``. See ``jax.value_and_grad`` for the unlifted reverse mode gradient. @@ -631,7 +623,7 @@ def jvp( variable_tangents, variables: CollectionFilter = True, rngs: PRNGSequenceFilter = True, -) -> Tuple[Any, Any]: +) -> tuple[Any, Any]: """A lifted version of ``jax.jvp``. See ``jax.jvp`` for the unlifted Jacobian-vector product (forward gradient). @@ -720,10 +712,10 @@ def vmap( split_rngs: Mapping[PRNGSequenceFilter, bool], in_axes=0, out_axes=0, - axis_size: Optional[int] = None, - axis_name: Optional[str] = None, - spmd_axis_name: Optional[str] = None, - metadata_params: Dict[Any, Any] = {}, + axis_size: int | None = None, + axis_name: str | None = None, + spmd_axis_name: str | None = None, + metadata_params: dict[Any, Any] = {}, ) -> Callable[..., Any]: """A lifted version of ``jax.vmap``. @@ -872,12 +864,12 @@ def scan( split_rngs: Mapping[PRNGSequenceFilter, bool] = {}, in_axes=0, out_axes=0, - length: Optional[int] = None, + length: int | None = None, reverse: bool = False, unroll: int = 1, _split_transpose: bool = False, - data_transform: Optional[Callable[..., Any]] = None, - metadata_params: Dict[Any, Any] = {}, + data_transform: Callable[..., Any] | None = None, + metadata_params: dict[Any, Any] = {}, ) -> Callable[..., Any]: """A lifted version of ``jax.lax.scan``. @@ -1416,8 +1408,8 @@ def checkpoint( rngs: PRNGSequenceFilter = True, concrete: bool = False, prevent_cse: bool = True, - static_argnums: Union[int, Tuple[int, ...]] = (), - policy: Optional[Callable[..., bool]] = None, + static_argnums: int | tuple[int, ...] = (), + policy: Callable[..., bool] | None = None, ) -> Callable[..., Any]: """Lifted version of ``jax.checkpoint``. 
@@ -1502,11 +1494,11 @@ def jit( fn: Callable[..., Any], variables: CollectionFilter = True, rngs: PRNGSequenceFilter = True, - static_argnums: Union[int, Iterable[int]] = (), - static_argnames: Union[str, Iterable[str]] = (), - donate_argnums: Union[int, Iterable[int]] = (), + static_argnums: int | Iterable[int] = (), + static_argnames: str | Iterable[str] = (), + donate_argnums: int | Iterable[int] = (), device=None, - backend: Union[str, None] = None, + backend: str | None = None, ) -> Callable[..., Any]: """Lifted version of ``jax.jit``. @@ -1562,8 +1554,8 @@ def jit( # Close over scope_fn & repack_fn to avoid recompilation # this is impure but we use the fingerprint arg to differentiate between cases # where scope_fn or repack_fn actually produce non-identical results. - scope_fn = None # type: Optional[Callable] - repack_fn = None # type: Optional[Callable] + scope_fn = None # type: Callable | None + repack_fn = None # type: Callable | None @functools.partial( jax.jit, @@ -1611,7 +1603,7 @@ def inner( def remat_scan( body_fn: Callable[..., Any], lengths: Sequence[int], - policy: Optional[Callable[..., bool]] = None, + policy: Callable[..., bool] | None = None, variable_broadcast: CollectionFilter = False, variable_carry: CollectionFilter = False, variable_axes: Mapping[CollectionFilter, InOutScanAxis] = {True: 0}, diff --git a/flax/core/meta.py b/flax/core/meta.py index aadeba8041..27686a40b5 100644 --- a/flax/core/meta.py +++ b/flax/core/meta.py @@ -23,7 +23,8 @@ import abc import functools -from typing import Any, Callable, Dict, Generic, Optional, TypeVar +from typing import Any, Generic, TypeVar +from collections.abc import Callable from flax import errors, struct from flax.typing import LogicalNames @@ -85,7 +86,7 @@ def replace_boxed(self, val: B) -> 'AxisMetadata[B]': @abc.abstractmethod def add_axis( - self: TAxisMetadata, index: int, params: Dict[Any, Any] + self: TAxisMetadata, index: int, params: dict[Any, Any] ) -> TAxisMetadata: """Adds a new axis to the axis metadata. @@ -107,7 +108,7 @@ def add_axis( @abc.abstractmethod def remove_axis( - self: TAxisMetadata, index: int, params: Dict[Any, Any] + self: TAxisMetadata, index: int, params: dict[Any, Any] ) -> TAxisMetadata: """Removes an axis from the axis metadata. 
@@ -145,12 +146,12 @@ def wrapper(x): return jax.tree_util.tree_map(wrapper, tree, is_leaf=is_axis_metadata) -def add_axis(tree: Any, index: int, params: Dict[Any, Any]) -> Any: +def add_axis(tree: Any, index: int, params: dict[Any, Any]) -> Any: """Add an axis to each AxisMetadata node in a PyTree.""" return map_axis_meta(lambda x: x.add_axis(index, params), tree) -def remove_axis(tree: Any, index: int, params: Dict[Any, Any]) -> Any: +def remove_axis(tree: Any, index: int, params: dict[Any, Any]) -> Any: """Remove an axis from each AxisMetadata node in a PyTree.""" return map_axis_meta(lambda x: x.remove_axis(index, params), tree) @@ -241,7 +242,7 @@ def body(mdl, c): value: Any names: LogicalNames = struct.field(pytree_node=False) - mesh: Optional[jax.sharding.Mesh] = struct.field( + mesh: jax.sharding.Mesh | None = struct.field( default=None, pytree_node=False ) @@ -259,12 +260,12 @@ def unbox(self, apply_constraint=True) -> A: def replace_boxed(self, val: B) -> 'Partitioned[B]': return self.replace(value=val) # type: ignore - def _get_partition_name(self, params: Dict[Any, Any]) -> str: + def _get_partition_name(self, params: dict[Any, Any]) -> str: if PARTITION_NAME not in params: raise errors.PartitioningUnspecifiedError(self) return params[PARTITION_NAME] - def add_axis(self, index: int, params: Dict[Any, Any]) -> 'Partitioned[A]': + def add_axis(self, index: int, params: dict[Any, Any]) -> 'Partitioned[A]': axis_name = self._get_partition_name(params) names = list(self.names) while len(names) < index: @@ -272,7 +273,7 @@ def add_axis(self, index: int, params: Dict[Any, Any]) -> 'Partitioned[A]': names.insert(index, axis_name) # type: ignore return self.replace(names=tuple(names)) - def remove_axis(self, index: int, params: Dict[Any, Any]) -> 'Partitioned[A]': + def remove_axis(self, index: int, params: dict[Any, Any]) -> 'Partitioned[A]': axis_name = self._get_partition_name(params) names = list(self.names) assert names.pop(index) == axis_name @@ -290,7 +291,7 @@ def get_sharding(self, mesh: jax.sharding.Mesh) -> jax.sharding.Sharding: def with_partitioning( fn: Callable[..., Any], names: LogicalNames, - mesh: Optional[jax.sharding.Mesh] = None, + mesh: jax.sharding.Mesh | None = None, ) -> Callable[..., Partitioned[Any]]: """Wraps a function's return value with Partitioned. 
diff --git a/flax/core/nn/attention.py b/flax/core/nn/attention.py index 82dd18d5aa..7e5e6bc8fc 100644 --- a/flax/core/nn/attention.py +++ b/flax/core/nn/attention.py @@ -16,7 +16,8 @@ import functools from collections.abc import Iterable # pylint: disable=g-importing-member -from typing import Any, Callable, Union +from typing import Any +from collections.abc import Callable import jax import jax.numpy as jnp @@ -272,7 +273,7 @@ def multi_head_dot_product_attention( value = scope.child(dense, 'value')(inputs_kv) if cache: - cache_entry: Union[Callable[[Any], CacheEntry], CacheEntry] + cache_entry: Callable[[Any], CacheEntry] | CacheEntry if not scope.has_variable('cache', 'entry'): ndim, tail_shape = (key.ndim, key.shape[-2:]) diff --git a/flax/core/scope.py b/flax/core/scope.py index 1d7f430ad7..9b71e38e20 100644 --- a/flax/core/scope.py +++ b/flax/core/scope.py @@ -22,21 +22,15 @@ import typing from typing import ( Any, - Callable, - Dict, Generic, - Iterable, Literal, - Mapping, Optional, - Sequence, - Set, - Tuple, TypeVar, Union, cast, overload, ) +from collections.abc import Callable, Iterable, Mapping, Sequence import jax import numpy as np @@ -95,7 +89,7 @@ class LazyRng(struct.PyTreeNode): """Wrapper around JAX PRNGKey that lazily maintains a tuple of static data to be folded into the rng.""" rng: PRNGKey - suffix: Tuple[PRNGFoldable, ...] = struct.field(pytree_node=False) + suffix: tuple[PRNGFoldable, ...] = struct.field(pytree_node=False) def as_jax_rng(self) -> PRNGKey: return _fold_in_static(self.rng, self.suffix) @@ -214,7 +208,7 @@ def in_filter(filter_like: Filter, col: str) -> bool: raise errors.InvalidFilterError(filter_like) -def filter_to_set(x: Filter) -> Set[str]: +def filter_to_set(x: Filter) -> set[str]: """Converts a Filter into a set of collections, fails on the infinite set. Args: @@ -227,7 +221,7 @@ def filter_to_set(x: Filter) -> Set[str]: if x is False: return set() if isinstance(x, str): - return set([x]) + return {x} if isinstance(x, typing.Collection): return set(x) raise errors.InvalidFilterError(x) @@ -419,18 +413,18 @@ class Scope: for a number of examples using ``Scopes``. """ - reservations: Dict[str, Set[Optional[str]]] + reservations: dict[str, set[str | None]] def __init__( self, variables: MutableVariableDict, - rngs: Optional[Union[RNGSequences, Dict[str, LazyRng]]] = None, - name: Optional[str] = None, + rngs: RNGSequences | dict[str, LazyRng] | None = None, + name: str | None = None, mutable: CollectionFilter = False, parent: Optional['Scope'] = None, path: Iterable[str] = (), debug_path: Iterable[str] = (), - flags: Optional[Mapping] = None, + flags: Mapping | None = None, ): """Initializes a Scope. 
@@ -511,7 +505,7 @@ def invalidate(self): """Invalidates the Scope.""" self._invalid = True - def mutable_variables(self) -> Union[VariableDict, Dict[str, Any]]: + def mutable_variables(self) -> VariableDict | dict[str, Any]: """Returns an immutable copy of the mutable variables belonging to this Scope.""" self._populate_collections() xs = { @@ -521,7 +515,7 @@ def mutable_variables(self) -> Union[VariableDict, Dict[str, Any]]: return freeze(xs) return xs - def variables(self) -> Union[VariableDict, Dict[str, Any]]: + def variables(self) -> VariableDict | dict[str, Any]: """Returns an immutable copy of the variables belonging to this Scope.""" self._populate_collections() if config.flax_return_frozendict: @@ -556,7 +550,7 @@ def rewound(self, rewind_rngs: bool = False) -> 'Scope': scope.rng_counters = self.rng_counters return scope - def name_reserved(self, name: str, col: Optional[str] = None) -> bool: + def name_reserved(self, name: str, col: str | None = None) -> bool: """Checks whether a name for a child Scope or Variable is taken. Args: @@ -574,7 +568,7 @@ def name_reserved(self, name: str, col: Optional[str] = None) -> bool: return True return False - def reserve(self, name: str, col: Optional[str] = None): + def reserve(self, name: str, col: str | None = None): """Reserves a name for a child Scope or Variable. Throws an error if the name exists already. @@ -608,7 +602,7 @@ def default_name(self, prefix: str) -> str: i += 1 def push( - self, name: Optional[str] = None, prefix: str = '', reuse=False + self, name: str | None = None, prefix: str = '', reuse=False ) -> 'Scope': """Creates a child Scope. @@ -650,8 +644,8 @@ def push( def child( self, fn: Callable[..., Any], - name: Optional[str] = None, - prefix: Optional[str] = None, + name: str | None = None, + prefix: str | None = None, named_call: bool = True, **partial_kwargs, ) -> Callable[..., Any]: @@ -818,7 +812,7 @@ def variable( self, col: str, name: str, - init_fn: Optional[Callable[..., T]] = None, + init_fn: Callable[..., T] | None = None, *init_args, ) -> Variable[T]: ... @@ -828,7 +822,7 @@ def variable( self, col: str, name: str, - init_fn: Optional[Callable[..., T]] = None, + init_fn: Callable[..., T] | None = None, *init_args, unbox: Literal[True], **init_kwargs, @@ -840,7 +834,7 @@ def variable( self, col: str, name: str, - init_fn: Optional[Callable[..., T]] = None, + init_fn: Callable[..., T] | None = None, *init_args, unbox: Literal[False], **init_kwargs, @@ -852,22 +846,22 @@ def variable( self, col: str, name: str, - init_fn: Optional[Callable[..., T]] = None, + init_fn: Callable[..., T] | None = None, *init_args, unbox: bool = True, **init_kwargs, - ) -> Union[Variable[T], Variable[meta.AxisMetadata[T]]]: + ) -> Variable[T] | Variable[meta.AxisMetadata[T]]: ... def variable( self, col: str, name: str, # pylint: disable=keyword-arg-before-vararg - init_fn: Optional[Callable[..., T]] = None, + init_fn: Callable[..., T] | None = None, *init_args, unbox: bool = True, **init_kwargs, - ) -> Union[Variable[T], Variable[meta.AxisMetadata[T]]]: + ) -> Variable[T] | Variable[meta.AxisMetadata[T]]: """Creates a variable if it doesn't exist yet in this scope and returns it. Args: @@ -934,7 +928,7 @@ def param( *init_args, unbox: bool, **init_kwargs, - ) -> Union[T, meta.AxisMetadata[T]]: + ) -> T | meta.AxisMetadata[T]: ... 
def param( @@ -944,7 +938,7 @@ def param( *init_args, unbox: bool = True, **init_kwargs, - ) -> Union[T, meta.AxisMetadata[T]]: + ) -> T | meta.AxisMetadata[T]: """Creates a parameter if it doesn't exist yet in this scope and returns it. If the parameter exists already, the existing value is simply returned. @@ -1019,9 +1013,9 @@ def _unfreeze_variables(variables, mutable): def bind( variables: VariableDict, - rngs: Optional[RNGSequences] = None, + rngs: RNGSequences | None = None, mutable: CollectionFilter = False, - flags: Optional[Mapping] = None, + flags: Mapping | None = None, ): """Binds variables and rngs to a new ``Scope``. @@ -1057,7 +1051,7 @@ def bind( def apply( fn: Callable[..., Any], mutable: CollectionFilter = False, - flags: Optional[Mapping] = None, + flags: Mapping | None = None, ) -> Callable[..., Any]: """Functionalize a `Scope` function. @@ -1074,9 +1068,9 @@ def apply( def wrapper( variables: VariableDict, *args, - rngs: Optional[Union[PRNGKey, RNGSequences]] = None, + rngs: PRNGKey | RNGSequences | None = None, **kwargs, - ) -> Union[Any, Tuple[Any, Union[VariableDict, Dict[str, Any]]]]: + ) -> Any | tuple[Any, VariableDict | dict[str, Any]]: if rngs is not None: if not _is_valid_rng(rngs) and not _is_valid_rngs(rngs): raise ValueError( @@ -1110,7 +1104,7 @@ def wrapper( def init( fn: Callable[..., Any], mutable: CollectionFilter = True, - flags: Optional[Mapping] = None, + flags: Mapping | None = None, ) -> Callable[..., Any]: """Functionalize a `Scope` function for initialization. @@ -1124,7 +1118,7 @@ def init( """ @functools.wraps(fn) - def wrapper(rngs, *args, **kwargs) -> Tuple[Any, VariableDict]: + def wrapper(rngs, *args, **kwargs) -> tuple[Any, VariableDict]: if not _is_valid_rng(rngs) and not _is_valid_rngs(rngs): raise ValueError( 'First argument passed to an init function should be a ' @@ -1144,7 +1138,7 @@ def wrapper(rngs, *args, **kwargs) -> Tuple[Any, VariableDict]: def lazy_init( fn: Callable[..., Any], mutable: CollectionFilter = True, - flags: Optional[Mapping] = None, + flags: Mapping | None = None, ) -> Callable[..., Any]: """Functionalizes a `Scope` function for lazy initialization. 
@@ -1227,7 +1221,7 @@ def _is_valid_rng(rng: Array): return True -def _is_valid_rngs(rngs: Union[PRNGKey, RNGSequences]): +def _is_valid_rngs(rngs: PRNGKey | RNGSequences): if not isinstance(rngs, (FrozenDict, dict)): return False for key, val in rngs.items(): diff --git a/flax/cursor.py b/flax/cursor.py index dc0f3c5274..334e234569 100644 --- a/flax/cursor.py +++ b/flax/cursor.py @@ -16,16 +16,12 @@ import enum from typing import ( Any, - Callable, - Dict, - Generator, Generic, - Mapping, - Optional, Protocol, TypeVar, runtime_checkable, ) +from collections.abc import Callable, Generator, Mapping from flax.core import FrozenDict from flax.errors import CursorFindError, TraverseTreeError @@ -123,10 +119,10 @@ def _traverse_tree(path, obj, *, update_fn=None, cond_fn=None): class Cursor(Generic[A]): _obj: A - _parent_key: Optional[ParentKey[A]] - _changes: Dict[Any, 'Cursor[A]'] + _parent_key: ParentKey[A] | None + _changes: dict[Any, 'Cursor[A]'] - def __init__(self, obj: A, parent_key: Optional[ParentKey[A]]): + def __init__(self, obj: A, parent_key: ParentKey[A] | None): # NOTE: we use `vars` here to avoid calling `__setattr__` # vars(self) = self.__dict__ vars(self)['_obj'] = obj diff --git a/flax/linen/attention.py b/flax/linen/attention.py index 6e91be996d..9859daefd9 100644 --- a/flax/linen/attention.py +++ b/flax/linen/attention.py @@ -18,7 +18,8 @@ import functools import inspect import warnings -from typing import Any, Callable, Optional, Union, overload +from typing import Any, overload +from collections.abc import Callable import jax import jax.numpy as jnp @@ -46,15 +47,15 @@ def dot_product_attention_weights( query: Array, key: Array, - bias: Optional[Array] = None, - mask: Optional[Array] = None, + bias: Array | None = None, + mask: Array | None = None, broadcast_dropout: bool = True, - dropout_rng: Optional[PRNGKey] = None, + dropout_rng: PRNGKey | None = None, dropout_rate: float = 0.0, deterministic: bool = False, - dtype: Optional[Dtype] = None, + dtype: Dtype | None = None, precision: PrecisionLike = None, - module: Optional[Module] = None, + module: Module | None = None, force_fp32_for_softmax: bool = False, einsum_dot_general: Callable[..., Array] = jax.lax.dot_general, ): @@ -151,15 +152,15 @@ def dot_product_attention( query: Array, key: Array, value: Array, - bias: Optional[Array] = None, - mask: Optional[Array] = None, + bias: Array | None = None, + mask: Array | None = None, broadcast_dropout: bool = True, - dropout_rng: Optional[PRNGKey] = None, + dropout_rng: PRNGKey | None = None, dropout_rate: float = 0.0, deterministic: bool = False, - dtype: Optional[Dtype] = None, + dtype: Dtype | None = None, precision: PrecisionLike = None, - module: Optional[Module] = None, + module: Module | None = None, force_fp32_for_softmax: bool = False, einsum_dot_general: Callable[..., Array] = jax.lax.dot_general, ): @@ -322,13 +323,13 @@ class MultiHeadDotProductAttention(Module): """ num_heads: int - dtype: Optional[Dtype] = None + dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 - qkv_features: Optional[int] = None - out_features: Optional[int] = None + qkv_features: int | None = None + out_features: int | None = None broadcast_dropout: bool = True dropout_rate: float = 0.0 - deterministic: Optional[bool] = None + deterministic: bool | None = None precision: PrecisionLike = None kernel_init: Initializer = default_kernel_init out_kernel_init: Initializer | None = None @@ -340,8 +341,8 @@ class MultiHeadDotProductAttention(Module): normalize_qk: bool = False 
force_fp32_for_softmax: bool = False # Deprecated, will be removed. - qkv_dot_general: Optional[DotGeneralT] = None - out_dot_general: Optional[DotGeneralT] = None + qkv_dot_general: DotGeneralT | None = None + out_dot_general: DotGeneralT | None = None qkv_dot_general_cls: Any = None out_dot_general_cls: Any = None @@ -349,12 +350,12 @@ class MultiHeadDotProductAttention(Module): def __call__( self, inputs_q: Array, - inputs_k: Optional[Array] = None, - inputs_v: Optional[Array] = None, + inputs_k: Array | None = None, + inputs_v: Array | None = None, *, - mask: Optional[Array] = None, - deterministic: Optional[bool] = None, - dropout_rng: Optional[PRNGKey] = None, + mask: Array | None = None, + deterministic: bool | None = None, + dropout_rng: PRNGKey | None = None, sow_weights: bool = False, ): ... @@ -364,10 +365,10 @@ def __call__( self, inputs_q: Array, *, - inputs_kv: Optional[Array] = None, - mask: Optional[Array] = None, - deterministic: Optional[bool] = None, - dropout_rng: Optional[PRNGKey] = None, + inputs_kv: Array | None = None, + mask: Array | None = None, + deterministic: bool | None = None, + dropout_rng: PRNGKey | None = None, sow_weights: bool = False, ): ... @@ -376,13 +377,13 @@ def __call__( def __call__( self, inputs_q: Array, - inputs_k: Optional[Array] = None, - inputs_v: Optional[Array] = None, + inputs_k: Array | None = None, + inputs_v: Array | None = None, *, - inputs_kv: Optional[Array] = None, - mask: Optional[Array] = None, - deterministic: Optional[bool] = None, - dropout_rng: Optional[PRNGKey] = None, + inputs_kv: Array | None = None, + mask: Array | None = None, + deterministic: bool | None = None, + dropout_rng: PRNGKey | None = None, sow_weights: bool = False, ): """Applies multi-head dot product attention on the input data. @@ -539,7 +540,7 @@ def __call__( # update key, value caches with our new 1d spatial slices cur_index = cache_index.value zero = jnp.array(0, dtype=lax.dtype(cur_index.dtype)) - indices: tuple[Union[int, jax.Array], ...] = (zero,) * len( + indices: tuple[int | jax.Array, ...] = (zero,) * len( batch_dims ) + ( cur_index, @@ -712,9 +713,9 @@ class SelfAttention(MultiHeadDotProductAttention): def __call__( # type: ignore self, inputs_q: Array, - mask: Optional[Array] = None, - deterministic: Optional[bool] = None, - dropout_rng: Optional[PRNGKey] = None, + mask: Array | None = None, + deterministic: bool | None = None, + dropout_rng: PRNGKey | None = None, sow_weights: bool = False, ): """Applies multi-head dot product self-attention on the input data. @@ -813,8 +814,8 @@ def make_causal_mask( def combine_masks( - *masks: Optional[Array], dtype: Dtype = jnp.float32 -) -> Optional[Array]: + *masks: Array | None, dtype: Dtype = jnp.float32 +) -> Array | None: """Combine attention masks. 
Args: diff --git a/flax/linen/combinators.py b/flax/linen/combinators.py index cf566dbaba..03aa18a6e1 100644 --- a/flax/linen/combinators.py +++ b/flax/linen/combinators.py @@ -14,7 +14,8 @@ """Combinators of modules, such as a Sequential.""" -from typing import Any, Callable, Dict, Sequence +from typing import Any +from collections.abc import Callable, Sequence from flax.linen.module import Module, compact @@ -106,7 +107,7 @@ def __call__(self, *args, **kwargs): for layer in self.layers[1:]: if isinstance(outputs, tuple): outputs = layer(*outputs) - elif isinstance(outputs, Dict): + elif isinstance(outputs, dict): outputs = layer(**outputs) else: outputs = layer(outputs) diff --git a/flax/linen/dtypes.py b/flax/linen/dtypes.py index df305fd3aa..d88f998180 100644 --- a/flax/linen/dtypes.py +++ b/flax/linen/dtypes.py @@ -27,13 +27,13 @@ # limitations under the License. """APIs for handling dtypes in Linen Modules.""" -from typing import Any, List, Optional +from typing import Any from flax.typing import Dtype from jax import numpy as jnp def canonicalize_dtype( - *args, dtype: Optional[Dtype] = None, inexact: bool = True + *args, dtype: Dtype | None = None, inexact: bool = True ) -> Dtype: """Canonicalize an optional dtype to the definitive dtype. @@ -64,7 +64,7 @@ def canonicalize_dtype( return dtype -def promote_dtype(*args, dtype=None, inexact=True) -> List[Any]: +def promote_dtype(*args, dtype=None, inexact=True) -> list[Any]: """ "Promotes input arguments to a specified or inferred dtype. All args are cast to the same dtype. See ``canonicalize_dtype`` for how diff --git a/flax/linen/experimental/layers_with_named_axes.py b/flax/linen/experimental/layers_with_named_axes.py index 8ff2b91526..795aa845fa 100644 --- a/flax/linen/experimental/layers_with_named_axes.py +++ b/flax/linen/experimental/layers_with_named_axes.py @@ -14,7 +14,8 @@ """Experimental layers with named axes for the partitioning API.""" import dataclasses -from typing import Any, Callable, Iterable, Optional, Tuple, Sequence +from typing import Any +from collections.abc import Callable, Iterable, Sequence import jax.numpy as jnp from jax import lax @@ -66,9 +67,9 @@ class Dense(nn.Module): precision: PrecisionLike = None kernel_init: Initializer = default_kernel_init bias_init: Initializer = initializers.zeros_init() - kernel_axes: Tuple[str, ...] = () + kernel_axes: tuple[str, ...] = () # Deprecated. Will be removed. 
- dot_general: Optional[DotGeneralT] = None + dot_general: DotGeneralT | None = None dot_general_cls: Any = None @nn.compact @@ -135,10 +136,10 @@ class Embed(nn.Module): num_embeddings: int features: int - cast_input_dtype: Optional[Dtype] = None + cast_input_dtype: Dtype | None = None dtype: Dtype = jnp.float32 param_dtype: Dtype = jnp.float32 - attend_dtype: Optional[Dtype] = None + attend_dtype: Dtype | None = None embedding_init: Initializer = default_embed_init one_hot: bool = False embedding: Array = dataclasses.field(init=False) @@ -196,7 +197,7 @@ def _canonicalize_axes(rank: int, axes: Axes) -> Sequence[int]: """Returns a tuple of deduplicated, sorted, and positive axes.""" if not isinstance(axes, Iterable): axes = (axes,) - return tuple(set([rank + axis if axis < 0 else axis for axis in axes])) + return tuple({rank + axis if axis < 0 else axis for axis in axes}) def _abs_sq(x): diff --git a/flax/linen/linear.py b/flax/linen/linear.py index 152fee30f1..babe809af0 100644 --- a/flax/linen/linear.py +++ b/flax/linen/linear.py @@ -16,13 +16,8 @@ from typing import ( Any, - Iterable, - List, - Optional, - Sequence, - Tuple, - Union, ) +from collections.abc import Iterable, Sequence import jax import jax.numpy as jnp @@ -54,12 +49,12 @@ default_kernel_init = initializers.lecun_normal() -def _normalize_axes(axes: Tuple[int, ...], ndim: int) -> Tuple[int, ...]: +def _normalize_axes(axes: tuple[int, ...], ndim: int) -> tuple[int, ...]: # A tuple by convention. len(axes_tuple) then also gives the rank efficiently. return tuple(sorted(ax if ax >= 0 else ndim + ax for ax in axes)) -def _canonicalize_tuple(x: Union[Sequence[int], int]) -> Tuple[int, ...]: +def _canonicalize_tuple(x: Sequence[int] | int) -> tuple[int, ...]: if isinstance(x, Iterable): return tuple(x) else: @@ -101,17 +96,17 @@ class DenseGeneral(Module): for details. """ - features: Union[int, Sequence[int]] - axis: Union[int, Sequence[int]] = -1 + features: int | Sequence[int] + axis: int | Sequence[int] = -1 batch_dims: Sequence[int] = () use_bias: bool = True - dtype: Optional[Dtype] = None + dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 kernel_init: Initializer = default_kernel_init bias_init: Initializer = initializers.zeros_init() precision: PrecisionLike = None # Deprecated. Will be removed. - dot_general: Optional[DotGeneralT] = None + dot_general: DotGeneralT | None = None dot_general_cls: Any = None @compact @@ -234,13 +229,13 @@ class Dense(Module): features: int use_bias: bool = True - dtype: Optional[Dtype] = None + dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 precision: PrecisionLike = None kernel_init: Initializer = default_kernel_init bias_init: Initializer = initializers.zeros_init() # Deprecated. Will be removed. - dot_general: Optional[DotGeneralT] = None + dot_general: DotGeneralT | None = None dot_general_cls: Any = None @compact @@ -314,16 +309,16 @@ class Einsum(Module): """ shape: Shape - einsum_str: Optional[str] = None + einsum_str: str | None = None use_bias: bool = True - dtype: Optional[Dtype] = None + dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 precision: PrecisionLike = None kernel_init: Initializer = default_kernel_init bias_init: Initializer = initializers.zeros_init() @compact - def __call__(self, inputs: Array, einsum_str: Optional[str] = None) -> Array: + def __call__(self, inputs: Array, einsum_str: str | None = None) -> Array: """Applies a linear transformation to the inputs along the last dimension. 
Args: @@ -477,21 +472,21 @@ class _Conv(Module): """ features: int - kernel_size: Union[int, Sequence[int]] - strides: Union[None, int, Sequence[int]] = 1 + kernel_size: int | Sequence[int] + strides: None | int | Sequence[int] = 1 padding: PaddingLike = 'SAME' - input_dilation: Union[None, int, Sequence[int]] = 1 - kernel_dilation: Union[None, int, Sequence[int]] = 1 + input_dilation: None | int | Sequence[int] = 1 + kernel_dilation: None | int | Sequence[int] = 1 feature_group_count: int = 1 use_bias: bool = True - mask: Optional[Array] = None - dtype: Optional[Dtype] = None + mask: Array | None = None + dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 precision: PrecisionLike = None kernel_init: Initializer = default_kernel_init bias_init: Initializer = initializers.zeros_init() # Deprecated. Will be removed. - conv_general_dilated: Optional[ConvGeneralDilatedT] = None + conv_general_dilated: ConvGeneralDilatedT | None = None conv_general_dilated_cls: Any = None @property @@ -533,8 +528,8 @@ def __call__(self, inputs: Array) -> Array: kernel_size = tuple(self.kernel_size) def maybe_broadcast( - x: Optional[Union[int, Sequence[int]]], - ) -> Tuple[int, ...]: + x: int | Sequence[int] | None, + ) -> tuple[int, ...]: if x is None: # backward compatibility with using None as sentinel for # broadcast 1 @@ -563,7 +558,7 @@ def maybe_broadcast( kernel_size_dilated = [ (k - 1) * d + 1 for k, d in zip(kernel_size, kernel_dilation) ] - zero_pad: List[Tuple[int, int]] = [(0, 0)] + zero_pad: list[tuple[int, int]] = [(0, 0)] pads = ( zero_pad + [((k - 1) // 2, k // 2) for k in kernel_size_dilated] @@ -887,13 +882,13 @@ class ConvTranspose(Module): """ features: int - kernel_size: Union[int, Sequence[int]] - strides: Optional[Sequence[int]] = None + kernel_size: int | Sequence[int] + strides: Sequence[int] | None = None padding: PaddingLike = 'SAME' - kernel_dilation: Optional[Sequence[int]] = None + kernel_dilation: Sequence[int] | None = None use_bias: bool = True - mask: Optional[Array] = None - dtype: Optional[Dtype] = None + mask: Array | None = None + dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 precision: PrecisionLike = None kernel_init: Initializer = default_kernel_init @@ -921,15 +916,15 @@ def __call__(self, inputs: Array) -> Array: Returns: The convolved data. """ - kernel_size: Tuple[int, ...] + kernel_size: tuple[int, ...] 
if isinstance(self.kernel_size, int): kernel_size = (self.kernel_size,) else: kernel_size = tuple(self.kernel_size) def maybe_broadcast( - x: Optional[Union[int, Sequence[int]]], - ) -> Tuple[int, ...]: + x: int | Sequence[int] | None, + ) -> tuple[int, ...]: if x is None: # backward compatibility with using None as sentinel for # broadcast 1 @@ -1098,7 +1093,7 @@ class Embed(Module): num_embeddings: int features: int - dtype: Optional[Dtype] = None + dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 embedding_init: Initializer = default_embed_init diff --git a/flax/linen/module.py b/flax/linen/module.py index f8234eb662..277d4c629d 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -26,20 +26,13 @@ from types import MappingProxyType from typing import ( Any, - Callable, - Dict, - Iterable, - Iterator, - List, Literal, - Mapping, Optional, - Tuple, - Type, TypeVar, Union, overload, ) +from collections.abc import Callable, Iterable, Iterator, Mapping import jax import jax.numpy as jnp @@ -154,20 +147,20 @@ def _module_repr(module: 'Module', num_spaces: int = 4): @dataclasses.dataclass class _CallInfo: index: int - path: Tuple[str, ...] + path: tuple[str, ...] module: 'Module' - rngs: Optional[Dict[str, Union[core.scope.PRNGKey, core.scope.LazyRng]]] + rngs: dict[str, core.scope.PRNGKey | core.scope.LazyRng] | None mutable: bool method: str - args: Tuple[Any, ...] - kwargs: Dict[str, Any] + args: tuple[Any, ...] + kwargs: dict[str, Any] outputs: Any @dataclasses.dataclass class _CallInfoContext(threading.local): index: int - calls: List[_CallInfo] + calls: list[_CallInfo] def get_call_index(self) -> int: index = self.index @@ -318,8 +311,8 @@ def __repr__(self) -> str: return f'{self.__class__.__name__}({self._storage})' -Args = Tuple[Any] -Kwargs = Dict[str, Any] +Args = tuple[Any] +Kwargs = dict[str, Any] NextGetter = Callable[..., Any] Interceptor = Callable[[NextGetter, Args, Kwargs, InterceptorContext], Any] _global_interceptor_stack = ThreadLocalStack() @@ -432,7 +425,7 @@ def _sorted_items(x): def _get_suffix_value_pairs( tree_or_leaf: Any, -) -> List[Tuple[str, Type['Module']]]: +) -> list[tuple[str, type['Module']]]: """Helper for naming pytrees of submodules.""" dict_or_leaf = serialization.to_state_dict(tree_or_leaf) if not isinstance(dict_or_leaf, dict) or not dict_or_leaf: @@ -630,7 +623,7 @@ def compact_name_scope_wrapper(self: nn.Module, *args, **kwargs): def _get_local_method_names( cls: Any, exclude: Iterable[str] = () -) -> Tuple[str, ...]: +) -> tuple[str, ...]: """Gets method names of a class, excluding class and static methods. Args: @@ -653,7 +646,7 @@ def _get_local_method_names( def _get_local_descriptor_names( cls: Any, exclude: Iterable[str] = () -) -> Tuple[str, ...]: +) -> tuple[str, ...]: """Gets descriptor names of a class. Args: @@ -799,8 +792,8 @@ class _ModuleInternalState: in_setup: bool = False setup_called: SetupState = SetupState.NEW is_initialized: bool = False - autoname_cursor: Dict[str, int] = dataclasses.field(default_factory=dict) - children: Dict[str, Union[str, 'Module']] = dataclasses.field( + autoname_cursor: dict[str, int] = dataclasses.field(default_factory=dict) + children: dict[str, Union[str, 'Module']] = dataclasses.field( default_factory=dict ) @@ -955,7 +948,7 @@ def __getattr__(self, name): # ----------------------------------------------------------------------------- -def module_field(*, kw_only: bool = False, default: Optional[Any] = ...) 
-> Any: +def module_field(*, kw_only: bool = False, default: Any | None = ...) -> Any: ... @@ -976,10 +969,10 @@ def module_field(*, kw_only: bool = False, default: Optional[Any] = ...) -> Any: @tpe.dataclass_transform(field_specifiers=(module_field,)) # type: ignore[literal-required] class ModuleBase: if typing.TYPE_CHECKING: - scope: Optional[Scope] + scope: Scope | None _state: _ModuleInternalState _parent_ref: Union['Module', weakref.ReferenceType['Module'], None] - __dataclass_fields__: Dict[str, dataclasses.Field] + __dataclass_fields__: dict[str, dataclasses.Field] class Module(ModuleBase): @@ -1019,7 +1012,7 @@ class Module(ModuleBase): """ if typing.TYPE_CHECKING: - name: Optional[str] = module_field(kw_only=True, default=None) + name: str | None = module_field(kw_only=True, default=None) parent: Union['Module', _Sentinel, None] = module_field( kw_only=True, default=None ) @@ -1049,7 +1042,7 @@ def __init_subclass__(cls, kw_only: bool = False, **kwargs: Any) -> None: cls._wrap_module_attributes() # Set empty class defaults. cls._state = _uninitialized_module_internal_state # type: ignore[attr-defined] - cls.scope: Optional[Scope] = None # type: ignore + cls.scope: Scope | None = None # type: ignore # Handles weak referencing of parent Modules to prevent reference cycles. cls._parent_ref = None # type: ignore[attr-defined] cls.parent = ParentDescriptor() # type: ignore[assignment] @@ -1325,7 +1318,7 @@ def __getattr__(self, name: str) -> Any: ) raise AttributeError(msg) - def __dir__(self) -> List[str]: + def __dir__(self) -> list[str]: """Call setup() before listing attributes.""" self._try_setup() return object.__dir__(self) # type: ignore @@ -1355,7 +1348,7 @@ def __post_init__(self) -> None: # When initializing an unnamed Module inside setup() # initialization is deferred until attachment by __setattr__ # i.e. self.mymodule = MyModule(...) - self.name: Optional[str] + self.name: str | None if ( self.parent._state.in_setup and self.name is None ): # pytype: disable=attribute-error @@ -1535,7 +1528,7 @@ def _name_taken( self, name: str, reuse_scopes: bool = False, - collection: Optional[str] = None, + collection: str | None = None, ) -> bool: assert self.scope is not None if reuse_scopes: @@ -1586,8 +1579,8 @@ def path(self): def clone( self: M, *, - parent: Optional[Union[Scope, 'Module', _Sentinel]] = None, - _deep_clone: Union[bool, weakref.WeakValueDictionary] = False, + parent: Union[Scope, 'Module', _Sentinel] | None = None, + _deep_clone: bool | weakref.WeakValueDictionary = False, _reset_names: bool = False, **updates, ) -> M: @@ -1662,8 +1655,8 @@ def clone_fn(m: Module) -> Module: def copy( self: M, *, - parent: Optional[Union[Scope, 'Module', _Sentinel]] = _unspecified_parent, - name: Optional[str] = None, + parent: Union[Scope, 'Module', _Sentinel] | None = _unspecified_parent, + name: str | None = None, **updates, ) -> M: """Creates a copy of this Module, with optionally updated arguments. @@ -1687,7 +1680,7 @@ def variable( self, col: str, name: str, - init_fn: Optional[Callable[..., T]] = None, + init_fn: Callable[..., T] | None = None, *init_args, ) -> Variable[T]: ... 
@@ -1697,7 +1690,7 @@ def variable( self, col: str, name: str, - init_fn: Optional[Callable[..., T]] = None, + init_fn: Callable[..., T] | None = None, *init_args, unbox: Literal[True], **init_kwargs, @@ -1709,7 +1702,7 @@ def variable( self, col: str, name: str, - init_fn: Optional[Callable[..., T]] = None, + init_fn: Callable[..., T] | None = None, *init_args, unbox: Literal[False], **init_kwargs, @@ -1721,22 +1714,22 @@ def variable( self, col: str, name: str, - init_fn: Optional[Callable[..., T]] = None, + init_fn: Callable[..., T] | None = None, *init_args, unbox: bool = True, **init_kwargs, - ) -> Union[Variable[T], Variable[meta.AxisMetadata[T]]]: + ) -> Variable[T] | Variable[meta.AxisMetadata[T]]: ... def variable( self, col: str, name: str, - init_fn: Optional[Callable[..., T]] = None, + init_fn: Callable[..., T] | None = None, *init_args, unbox: bool = True, **init_kwargs, - ) -> Union[Variable[T], Variable[meta.AxisMetadata[T]]]: + ) -> Variable[T] | Variable[meta.AxisMetadata[T]]: """Declares and returns a variable in this Module. See :mod:`flax.core.variables` for more information. See also :meth:`param` @@ -1828,7 +1821,7 @@ def param( *init_args, unbox: bool, **init_kwargs, - ) -> Union[T, meta.AxisMetadata[T]]: + ) -> T | meta.AxisMetadata[T]: ... def param( @@ -1838,7 +1831,7 @@ def param( *init_args, unbox: bool = True, **init_kwargs, - ) -> Union[T, meta.AxisMetadata[T]]: + ) -> T | meta.AxisMetadata[T]: """Declares and returns a parameter in this Module. Parameters are read-only variables in the collection named "params". See @@ -1991,7 +1984,7 @@ def bind( self: M, variables: VariableDict, *args, - rngs: Optional[RNGSequences] = None, + rngs: RNGSequences | None = None, mutable: CollectionFilter = False, ) -> M: """Creates an interactive Module instance by binding variables and RNGs. @@ -2048,7 +2041,7 @@ def bind( scope = core.bind(variables, rngs=rngs, mutable=mutable) return self.clone(parent=scope, _deep_clone=True) - def unbind(self: M) -> Tuple[M, VariableDict]: + def unbind(self: M) -> tuple[M, VariableDict]: """Returns an unbound copy of a Module and its variables. ``unbind`` helps create a stateless version of a bound Module. @@ -2101,12 +2094,12 @@ def apply( self, variables: VariableDict, *args, - rngs: Optional[Union[PRNGKey, RNGSequences]] = None, - method: Union[Callable[..., Any], str, None] = None, + rngs: PRNGKey | RNGSequences | None = None, + method: Callable[..., Any] | str | None = None, mutable: CollectionFilter = False, - capture_intermediates: Union[bool, Callable[['Module', str], bool]] = False, + capture_intermediates: bool | Callable[['Module', str], bool] = False, **kwargs, - ) -> Union[Any, Tuple[Any, Union[FrozenVariableDict, Dict[str, Any]]]]: + ) -> Any | tuple[Any, FrozenVariableDict | dict[str, Any]]: """Applies a module method to variables and returns output and modified variables. 
Note that ``method`` should be set if one would like to call ``apply`` on a @@ -2259,13 +2252,13 @@ def apply( @traceback_util.api_boundary def init_with_output( self, - rngs: Union[PRNGKey, RNGSequences], + rngs: PRNGKey | RNGSequences, *args, - method: Union[Callable[..., Any], str, None] = None, + method: Callable[..., Any] | str | None = None, mutable: CollectionFilter = DenyList('intermediates'), - capture_intermediates: Union[bool, Callable[['Module', str], bool]] = False, + capture_intermediates: bool | Callable[['Module', str], bool] = False, **kwargs, - ) -> Tuple[Any, Union[FrozenVariableDict, Dict[str, Any]]]: + ) -> tuple[Any, FrozenVariableDict | dict[str, Any]]: """Initializes a module method with variables and returns output and modified variables. Args: @@ -2323,13 +2316,13 @@ def init_with_output( @traceback_util.api_boundary def init( self, - rngs: Union[PRNGKey, RNGSequences], + rngs: PRNGKey | RNGSequences, *args, - method: Union[Callable[..., Any], str, None] = None, + method: Callable[..., Any] | str | None = None, mutable: CollectionFilter = DenyList('intermediates'), - capture_intermediates: Union[bool, Callable[['Module', str], bool]] = False, + capture_intermediates: bool | Callable[['Module', str], bool] = False, **kwargs, - ) -> Union[FrozenVariableDict, Dict[str, Any]]: + ) -> FrozenVariableDict | dict[str, Any]: """Initializes a module method with variables and returns modified variables. ``init`` takes as first argument either a single ``PRNGKey``, or a @@ -2474,9 +2467,9 @@ def init( @traceback_util.api_boundary def lazy_init( self, - rngs: Union[PRNGKey, RNGSequences], + rngs: PRNGKey | RNGSequences, *args, - method: Optional[Callable[..., Any]] = None, + method: Callable[..., Any] | None = None, mutable: CollectionFilter = DenyList('intermediates'), **kwargs, ) -> FrozenVariableDict: @@ -2529,7 +2522,7 @@ def variables(self) -> VariableDict: raise ValueError("Can't access variables on unbound modules") return self.scope.variables() - def get_variable(self, col: str, name: str, default: Optional[T] = None) -> T: + def get_variable(self, col: str, name: str, default: T | None = None) -> T: """Retrieves the value of a Variable. Args: @@ -2732,12 +2725,12 @@ def perturb( def tabulate( self, - rngs: Union[PRNGKey, RNGSequences], + rngs: PRNGKey | RNGSequences, *args, - depth: Optional[int] = None, + depth: int | None = None, show_repeated: bool = False, mutable: CollectionFilter = DenyList('intermediates'), - console_kwargs: Optional[Mapping[str, Any]] = None, + console_kwargs: Mapping[str, Any] | None = None, table_kwargs: Mapping[str, Any] = MappingProxyType({}), column_kwargs: Mapping[str, Any] = MappingProxyType({}), compute_flops: bool = False, @@ -2862,7 +2855,7 @@ def tabulate( def module_paths( self, - rngs: Union[PRNGKey, RNGSequences], + rngs: PRNGKey | RNGSequences, *args, show_repeated: bool = False, mutable: CollectionFilter = DenyList('intermediates'), @@ -2923,10 +2916,10 @@ def module_paths( return {'/'.join(row.path): row.module_copy for row in table} -_ParentType = Union[Type[Module], Scope, Type[_Sentinel], None] +_ParentType = Union[type[Module], Scope, type[_Sentinel], None] -def merge_param(name: str, a: Optional[T], b: Optional[T]) -> T: +def merge_param(name: str, a: T | None, b: T | None) -> T: """Merges construction- and call-time argument. 
This is a utility for supporting a pattern where a Module hyperparameter @@ -2975,7 +2968,7 @@ def apply( fn: Callable[..., Any], module: Module, mutable: CollectionFilter = False, - capture_intermediates: Union[bool, Callable[[Module, str], bool]] = False, + capture_intermediates: bool | Callable[[Module, str], bool] = False, ) -> Callable[..., Any]: """Creates an apply function to call ``fn`` with a bound module. @@ -3045,8 +3038,8 @@ def init_with_output( fn: Callable[..., Any], module: Module, mutable: CollectionFilter = DenyList('intermediates'), - capture_intermediates: Union[bool, Callable[[Module, str], bool]] = False, -) -> Callable[..., Tuple[Any, Union[FrozenVariableDict, Dict[str, Any]]]]: + capture_intermediates: bool | Callable[[Module, str], bool] = False, +) -> Callable[..., tuple[Any, FrozenVariableDict | dict[str, Any]]]: """Creates an init function to call ``fn`` with a bound module that also returns the function outputs. Unlike ``Module.init_with_output`` this function returns a new function with @@ -3116,8 +3109,8 @@ def init( fn: Callable[..., Any], module: Module, mutable: CollectionFilter = DenyList('intermediates'), - capture_intermediates: Union[bool, Callable[[Module, str], bool]] = False, -) -> Callable[..., Union[FrozenVariableDict, Dict[str, Any]]]: + capture_intermediates: bool | Callable[[Module, str], bool] = False, +) -> Callable[..., FrozenVariableDict | dict[str, Any]]: """Creates an init function to call ``fn`` with a bound module. Unlike ``Module.init`` this function returns a new function with the signature diff --git a/flax/linen/normalization.py b/flax/linen/normalization.py index 0680737f9d..64ca0da44f 100644 --- a/flax/linen/normalization.py +++ b/flax/linen/normalization.py @@ -16,7 +16,8 @@ import dataclasses import functools -from typing import Any, Iterable, Optional, Tuple +from typing import Any +from collections.abc import Iterable import jax import jax.numpy as jnp @@ -41,11 +42,11 @@ map_variables = transforms.map_variables -def _canonicalize_axes(rank: int, axes: Axes) -> Tuple[int, ...]: +def _canonicalize_axes(rank: int, axes: Axes) -> tuple[int, ...]: """Returns a tuple of deduplicated, sorted, and positive axes.""" if not isinstance(axes, Iterable): axes = (axes,) - return tuple(set([rank + axis if axis < 0 else axis for axis in axes])) + return tuple({rank + axis if axis < 0 else axis for axis in axes}) def _abs_sq(x): @@ -59,12 +60,12 @@ def _abs_sq(x): def _compute_stats( x: Array, axes: Axes, - dtype: Optional[Dtype], - axis_name: Optional[str] = None, + dtype: Dtype | None, + axis_name: str | None = None, axis_index_groups: Any = None, use_mean: bool = True, use_fast_variance: bool = True, - mask: Optional[Array] = None, + mask: Array | None = None, force_float32_reductions=True, ): """Computes mean and variance statistics. @@ -152,7 +153,7 @@ def _normalize( var: Array, reduction_axes: Axes, feature_axes: Axes, - dtype: Optional[Dtype], + dtype: Dtype | None, param_dtype: Dtype, epsilon: float, use_bias: bool, @@ -291,17 +292,17 @@ class BatchNorm(Module): calculation for the variance. 
""" - use_running_average: Optional[bool] = None + use_running_average: bool | None = None axis: int = -1 momentum: float = 0.99 epsilon: float = 1e-5 - dtype: Optional[Dtype] = None + dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 use_bias: bool = True use_scale: bool = True bias_init: Initializer = initializers.zeros scale_init: Initializer = initializers.ones - axis_name: Optional[str] = None + axis_name: str | None = None axis_index_groups: Any = None use_fast_variance: bool = True force_float32_reductions: bool = True @@ -310,9 +311,9 @@ class BatchNorm(Module): def __call__( self, x, - use_running_average: Optional[bool] = None, + use_running_average: bool | None = None, *, - mask: Optional[jax.Array] = None, + mask: jax.Array | None = None, ): """Normalizes the input using batch statistics. @@ -451,7 +452,7 @@ class LayerNorm(Module): """ epsilon: float = 1e-6 - dtype: Optional[Dtype] = None + dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 use_bias: bool = True use_scale: bool = True @@ -459,13 +460,13 @@ class LayerNorm(Module): scale_init: Initializer = initializers.ones reduction_axes: Axes = -1 feature_axes: Axes = -1 - axis_name: Optional[str] = None + axis_name: str | None = None axis_index_groups: Any = None use_fast_variance: bool = True force_float32_reductions: bool = True @compact - def __call__(self, x, *, mask: Optional[jax.Array] = None): + def __call__(self, x, *, mask: jax.Array | None = None): """Applies layer normalization on the input. Args: @@ -552,19 +553,19 @@ class RMSNorm(Module): """ epsilon: float = 1e-6 - dtype: Optional[Dtype] = None + dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 use_scale: bool = True scale_init: Initializer = initializers.ones reduction_axes: Axes = -1 feature_axes: Axes = -1 - axis_name: Optional[str] = None + axis_name: str | None = None axis_index_groups: Any = None use_fast_variance: bool = True force_float32_reductions: bool = True @compact - def __call__(self, x, *, mask: Optional[jax.Array] = None): + def __call__(self, x, *, mask: jax.Array | None = None): """Applies RMS layer normalization on the input. Args: @@ -673,23 +674,23 @@ class GroupNorm(Module): calculation for the variance. """ - num_groups: Optional[int] = 32 - group_size: Optional[int] = None + num_groups: int | None = 32 + group_size: int | None = None epsilon: float = 1e-6 - dtype: Optional[Dtype] = None + dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 use_bias: bool = True use_scale: bool = True bias_init: Initializer = initializers.zeros scale_init: Initializer = initializers.ones - reduction_axes: Optional[Axes] = None - axis_name: Optional[str] = None + reduction_axes: Axes | None = None + axis_name: str | None = None axis_index_groups: Any = None use_fast_variance: bool = True force_float32_reductions: bool = True @compact - def __call__(self, x, *, mask: Optional[jax.Array] = None): + def __call__(self, x, *, mask: jax.Array | None = None): """Applies group normalization to the input (arxiv.org/abs/1803.08494). 
Args: @@ -849,20 +850,20 @@ class InstanceNorm(Module): """ epsilon: float = 1e-6 - dtype: Optional[Dtype] = None + dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 use_bias: bool = True use_scale: bool = True bias_init: Initializer = initializers.zeros scale_init: Initializer = initializers.ones feature_axes: Axes = -1 - axis_name: Optional[str] = None + axis_name: str | None = None axis_index_groups: Any = None use_fast_variance: bool = True force_float32_reductions: bool = True @compact - def __call__(self, x, *, mask: Optional[jax.Array] = None): + def __call__(self, x, *, mask: jax.Array | None = None): """Applies instance normalization on the input. Args: @@ -1023,7 +1024,7 @@ class SpectralNorm(Module): layer_instance: Module n_steps: int = 1 epsilon: float = 1e-12 - dtype: Optional[Dtype] = None + dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 error_on_non_matrix: bool = False collection_name: str = 'batch_stats' @@ -1097,7 +1098,7 @@ def _spectral_normalize(self, path, vs, update_stats): u_var_name = ( self.layer_instance.name + '/' - + '/'.join((dict_key.key for dict_key in path[1:])) + + '/'.join(dict_key.key for dict_key in path[1:]) + '/u' ) u_var = self.variable( @@ -1114,7 +1115,7 @@ def _spectral_normalize(self, path, vs, update_stats): sigma_var_name = ( self.layer_instance.name + '/' - + '/'.join((dict_key.key for dict_key in path[1:])) + + '/'.join(dict_key.key for dict_key in path[1:]) + '/sigma' ) sigma_var = self.variable( @@ -1262,12 +1263,12 @@ class WeightNorm(Module): layer_instance: Module epsilon: float = 1e-12 - dtype: Optional[Dtype] = None + dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 use_scale: bool = True scale_init: Initializer = initializers.ones - feature_axes: Optional[Axes] = -1 - variable_filter: Optional[Iterable] = dataclasses.field( + feature_axes: Axes | None = -1 + variable_filter: Iterable | None = dataclasses.field( default_factory=lambda: {'kernel'} ) @@ -1313,7 +1314,7 @@ def _l2_normalize(self, path, vs): str_path = ( self.layer_instance.name + '/' - + '/'.join((dict_key.key for dict_key in path[1:])) + + '/'.join(dict_key.key for dict_key in path[1:]) ) if self.variable_filter: for variable_name in self.variable_filter: diff --git a/flax/linen/partitioning.py b/flax/linen/partitioning.py index 2fb440c098..71045ba65a 100644 --- a/flax/linen/partitioning.py +++ b/flax/linen/partitioning.py @@ -30,7 +30,8 @@ import functools import re -from typing import (Any, Callable, Mapping, Optional, Tuple) +from typing import (Any, Optional, Tuple) +from collections.abc import Callable, Mapping import flax from flax import linen as nn @@ -125,7 +126,7 @@ def param_with_axes( name: str, init_fn, *init_args, - axes: Optional[Tuple[str, ...]] = None, + axes: tuple[str, ...] | None = None, module: Optional['nn.Module'] = None, **init_kwargs, ): @@ -187,7 +188,7 @@ def __init__( scope, collection: str, name: str, - axes: Optional[Tuple[str, ...]] = None, + axes: tuple[str, ...] | None = None, fallback: RulesFallback = RulesFallback.AXIS_IS_UNSHARDED, ): """Initializes a partitioned variable. @@ -227,7 +228,7 @@ def _core_variable_with_axes( name: str, init_fn: Callable[..., Any], *init_args, - axes: Optional[Tuple[str, ...]] = None, + axes: tuple[str, ...] | None = None, fallback: RulesFallback = RulesFallback.AXIS_IS_UNSHARDED, **init_kwargs, ): @@ -248,7 +249,7 @@ def variable_with_axes( name: str, init_fn, *init_args, - axes: Optional[Tuple[str, ...]] = None, + axes: tuple[str, ...] 
| None = None, module: Optional['nn.Module'] = None, fallback: RulesFallback = RulesFallback.AXIS_IS_UNSHARDED, **init_kwargs, @@ -424,12 +425,12 @@ def scan_with_axes( split_rngs: Mapping[PRNGSequenceFilter, bool] = {}, in_axes=0, out_axes=0, - length: Optional[int] = None, + length: int | None = None, reverse: bool = False, unroll: int = 1, axis_name: str = 'layers', - axes_collections: Tuple[str, ...] = ('params',), - data_transform: Optional[Callable[..., Any]] = None, + axes_collections: tuple[str, ...] = ('params',), + data_transform: Callable[..., Any] | None = None, methods=None, ) -> 'flax.linen.transforms.Target': """Wrapped version of nn.scan that handles logical axis metadata.""" @@ -478,10 +479,10 @@ def vmap_with_axes( split_rngs: Mapping[PRNGSequenceFilter, bool] = {}, in_axes=0, out_axes=0, - axis_size: Optional[int] = None, - axis_name: Optional[str] = None, + axis_size: int | None = None, + axis_name: str | None = None, partitioning_axis_names: Mapping[Any, str] = {}, - spmd_axis_name: Optional[str] = None, + spmd_axis_name: str | None = None, methods=None, ) -> 'flax.linen.transforms.Target': """Wrapped version of nn.vmap that handles logical axis metadata.""" diff --git a/flax/linen/recurrent.py b/flax/linen/recurrent.py index 7f2d4da26a..274546fc52 100644 --- a/flax/linen/recurrent.py +++ b/flax/linen/recurrent.py @@ -21,15 +21,9 @@ from functools import partial # pylint: disable=g-importing-member from typing import ( Any, - Callable, - Dict, - Mapping, - Optional, - Sequence, - Tuple, TypeVar, - Union, ) +from collections.abc import Callable, Mapping, Sequence import jax import numpy as np @@ -65,7 +59,7 @@ class RNNCellBase(Module): @nowrap def initialize_carry( - self, rng: PRNGKey, input_shape: Tuple[int, ...] + self, rng: PRNGKey, input_shape: tuple[int, ...] ) -> Carry: """Initialize the RNN cell carry. @@ -133,7 +127,7 @@ class LSTMCell(RNNCellBase): kernel_init: Initializer = default_kernel_init recurrent_kernel_init: Initializer = initializers.orthogonal() bias_init: Initializer = initializers.zeros_init() - dtype: Optional[Dtype] = None + dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 carry_init: Initializer = initializers.zeros_init() @@ -180,8 +174,8 @@ def __call__(self, carry, inputs): @nowrap def initialize_carry( - self, rng: PRNGKey, input_shape: Tuple[int, ...] - ) -> Tuple[Array, Array]: + self, rng: PRNGKey, input_shape: tuple[int, ...] + ) -> tuple[Array, Array]: """Initialize the RNN cell carry. Args: @@ -213,7 +207,7 @@ class DenseParams(Module): bias_init: Initializer = initializers.zeros_init() @compact - def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: + def __call__(self, inputs: Array) -> tuple[Array, Array | None]: k = self.param( 'kernel', self.kernel_init, @@ -280,14 +274,14 @@ class OptimizedLSTMCell(RNNCellBase): kernel_init: Initializer = default_kernel_init recurrent_kernel_init: Initializer = initializers.orthogonal() bias_init: Initializer = initializers.zeros_init() - dtype: Optional[Dtype] = None + dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 carry_init: Initializer = initializers.zeros_init() @compact def __call__( - self, carry: Tuple[Array, Array], inputs: Array - ) -> Tuple[Tuple[Array, Array], Array]: + self, carry: tuple[Array, Array], inputs: Array + ) -> tuple[tuple[Array, Array], Array]: r"""An optimized long short-term memory (LSTM) cell. 
Args: @@ -304,9 +298,9 @@ def __call__( def _concat_dense( inputs: Array, - params: Mapping[str, Tuple[Array, Optional[Array]]], + params: Mapping[str, tuple[Array, Array | None]], use_bias: bool = True, - ) -> Dict[str, Array]: + ) -> dict[str, Array]: # Concatenates the individual kernels and biases, given in params, into a # single kernel and single bias for efficiency before applying them using # dot_general. @@ -369,8 +363,8 @@ def _concat_dense( @nowrap def initialize_carry( - self, rng: PRNGKey, input_shape: Tuple[int, ...] - ) -> Tuple[Array, Array]: + self, rng: PRNGKey, input_shape: tuple[int, ...] + ) -> tuple[Array, Array]: """Initialize the RNN cell carry. Args: @@ -443,7 +437,7 @@ class SimpleCell(RNNCellBase): kernel_init: Initializer = default_kernel_init recurrent_kernel_init: Initializer = initializers.orthogonal() bias_init: Initializer = initializers.zeros_init() - dtype: Optional[Dtype] = None + dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 carry_init: Initializer = initializers.zeros_init() residual: bool = False @@ -487,7 +481,7 @@ def __call__(self, carry, inputs): return new_carry, new_carry @nowrap - def initialize_carry(self, rng: PRNGKey, input_shape: Tuple[int, ...]): + def initialize_carry(self, rng: PRNGKey, input_shape: tuple[int, ...]): """Initialize the RNN cell carry. Args: @@ -553,7 +547,7 @@ class GRUCell(RNNCellBase): kernel_init: Initializer = default_kernel_init recurrent_kernel_init: Initializer = initializers.orthogonal() bias_init: Initializer = initializers.zeros_init() - dtype: Optional[Dtype] = None + dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 carry_init: Initializer = initializers.zeros_init() @@ -601,7 +595,7 @@ def __call__(self, carry, inputs): return new_h, new_h @nowrap - def initialize_carry(self, rng: PRNGKey, input_shape: Tuple[int, ...]): + def initialize_carry(self, rng: PRNGKey, input_shape: tuple[int, ...]): """Initialize the RNN cell carry. Args: @@ -683,7 +677,7 @@ class MGUCell(RNNCellBase): recurrent_kernel_init: Initializer = initializers.orthogonal() forget_bias_init: Initializer = initializers.ones_init() activation_bias_init: Initializer = initializers.zeros_init() - dtype: Optional[Dtype] = None + dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 carry_init: Initializer = initializers.zeros_init() reset_gate: bool = True @@ -736,7 +730,7 @@ def __call__(self, carry, inputs): return new_h, new_h @nowrap - def initialize_carry(self, rng: PRNGKey, input_shape: Tuple[int, ...]): + def initialize_carry(self, rng: PRNGKey, input_shape: tuple[int, ...]): """Initialize the RNN cell carry. Args: @@ -809,10 +803,10 @@ class ConvLSTMCell(RNNCellBase): features: int kernel_size: Sequence[int] - strides: Optional[Sequence[int]] = None - padding: Union[str, Sequence[Tuple[int, int]]] = 'SAME' + strides: Sequence[int] | None = None + padding: str | Sequence[tuple[int, int]] = 'SAME' use_bias: bool = True - dtype: Optional[Dtype] = None + dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 carry_init: Initializer = initializers.zeros_init() @@ -861,7 +855,7 @@ def __call__(self, carry, inputs): return (new_c, new_h), new_h @nowrap - def initialize_carry(self, rng: PRNGKey, input_shape: Tuple[int, ...]): + def initialize_carry(self, rng: PRNGKey, input_shape: tuple[int, ...]): """Initialize the RNN cell carry. 
Args: @@ -1023,14 +1017,14 @@ def __call__( self, inputs: jax.Array, *, - initial_carry: Optional[Carry] = None, - init_key: Optional[PRNGKey] = None, - seq_lengths: Optional[Array] = None, - return_carry: Optional[bool] = None, - time_major: Optional[bool] = None, - reverse: Optional[bool] = None, - keep_order: Optional[bool] = None, - ) -> Union[Output, Tuple[Carry, Output]]: + initial_carry: Carry | None = None, + init_key: PRNGKey | None = None, + seq_lengths: Array | None = None, + return_carry: bool | None = None, + time_major: bool | None = None, + reverse: bool | None = None, + keep_order: bool | None = None, + ) -> Output | tuple[Carry, Output]: """ Applies the RNN to the inputs. @@ -1116,7 +1110,7 @@ def __call__( def scan_fn( cell: RNNCellBase, carry: Carry, x: Array - ) -> Union[Tuple[Carry, Array], Tuple[Carry, Tuple[Carry, Array]]]: + ) -> tuple[Carry, Array] | tuple[Carry, tuple[Carry, Array]]: carry, y = cell(carry, x) # When we have a segmentation mask we return the carry as an output # so that we can select the last carry for each sequence later. @@ -1185,7 +1179,7 @@ def _expand_dims_like(x, target): def flip_sequences( inputs: Array, - seq_lengths: Optional[Array], + seq_lengths: Array | None, num_batch_dims: int, time_major: bool, ) -> Array: @@ -1254,14 +1248,14 @@ def __call__( self, inputs: jax.Array, *, - initial_carry: Optional[Carry] = None, - init_key: Optional[PRNGKey] = None, - seq_lengths: Optional[Array] = None, - return_carry: Optional[bool] = None, - time_major: Optional[bool] = None, - reverse: Optional[bool] = None, - keep_order: Optional[bool] = None, - ) -> Union[Output, Tuple[Carry, Output]]: + initial_carry: Carry | None = None, + init_key: PRNGKey | None = None, + seq_lengths: Array | None = None, + return_carry: bool | None = None, + time_major: bool | None = None, + reverse: bool | None = None, + keep_order: bool | None = None, + ) -> Output | tuple[Carry, Output]: ... @@ -1289,14 +1283,14 @@ def __call__( self, inputs: jax.Array, *, - initial_carry: Optional[Carry] = None, - init_key: Optional[PRNGKey] = None, - seq_lengths: Optional[Array] = None, - return_carry: Optional[bool] = None, - time_major: Optional[bool] = None, - reverse: Optional[bool] = None, - keep_order: Optional[bool] = None, - ) -> Union[Output, Tuple[Carry, Output]]: + initial_carry: Carry | None = None, + init_key: PRNGKey | None = None, + seq_lengths: Array | None = None, + return_carry: bool | None = None, + time_major: bool | None = None, + reverse: bool | None = None, + keep_order: bool | None = None, + ) -> Output | tuple[Carry, Output]: if time_major is None: time_major = self.time_major if return_carry is None: @@ -1313,10 +1307,8 @@ def __call__( # for the backward pass and does not intend for them to share parameters. if self.forward_rnn is self.backward_rnn: logging.warning( - ( 'forward_rnn and backward_rnn is the same object, so ' 'they will share parameters.' - ) ) # Encode in the forward direction. 
diff --git a/flax/linen/spmd.py b/flax/linen/spmd.py index 56a4b96773..93afab7646 100644 --- a/flax/linen/spmd.py +++ b/flax/linen/spmd.py @@ -30,7 +30,8 @@ import enum import functools import threading -from typing import Any, Callable, List, Optional, Sequence, Tuple, Union +from typing import Any +from collections.abc import Callable, Sequence import jax from jax import lax @@ -107,9 +108,9 @@ def _mesh_assignment_free(new_assignment, existing_assignments): def _logical_to_mesh_axes( - array_dim_names: Optional[Sequence[Optional[str]]], - rules: Optional[LogicalRules] = None, -) -> Optional[List[Union[_UnassignedAxis, None, str, Tuple[str, ...]]]]: + array_dim_names: Sequence[str | None] | None, + rules: LogicalRules | None = None, +) -> list[_UnassignedAxis | None | str | tuple[str, ...]] | None: """Same as logical_to_mesh_axes, but doesn't fill in _unassigned_axis.""" if array_dim_names is None: return None @@ -126,7 +127,7 @@ def _logical_to_mesh_axes( if not isinstance(rules, (tuple, list)): raise ValueError('Unknown axis rule specification type.') # We assign mesh axes using a priority based ruleset over logical axis names. - result: List[Union[_UnassignedAxis, None, str, Tuple[str, ...]]] + result: list[_UnassignedAxis | None | str | tuple[str, ...]] result = [ (_unassigned_axis if isinstance(name, str) else name) for name in array_dim_names @@ -143,9 +144,9 @@ def _logical_to_mesh_axes( def logical_to_mesh_axes( - array_dim_names: Optional[Sequence[Optional[str]]], - rules: Optional[LogicalRules] = None, -) -> Optional[jax.sharding.PartitionSpec]: + array_dim_names: Sequence[str | None] | None, + rules: LogicalRules | None = None, +) -> jax.sharding.PartitionSpec | None: """Compute layout for an array. The rules are in order of precedence, and consist of pairs: @@ -189,7 +190,7 @@ def logical_to_mesh_axes( return jax.sharding.PartitionSpec(*result) -def logical_to_mesh(tree: Any, rules: Optional[LogicalRules] = None) -> Any: +def logical_to_mesh(tree: Any, rules: LogicalRules | None = None) -> Any: """Applies logical_to_mesh_axes to pytrees of logical PartitionSpecs.""" return jax.tree_util.tree_map( lambda x: logical_to_mesh_axes(x, rules), @@ -201,7 +202,7 @@ def logical_to_mesh(tree: Any, rules: Optional[LogicalRules] = None) -> Any: def logical_to_mesh_sharding( tree: Any, mesh: jax.sharding.Mesh, - rules: Optional[LogicalRules] = None, + rules: LogicalRules | None = None, ) -> Any: """Convert pytrees of logical PartitionSpecs to shardings.""" return jax.tree_util.tree_map( @@ -227,8 +228,8 @@ class RulesFallback(enum.Enum): def _with_sharding_constraint( x: Array, - axis_resources: Optional[jax.sharding.PartitionSpec], - mesh: Optional[jax.sharding.Mesh] = None, + axis_resources: jax.sharding.PartitionSpec | None, + mesh: jax.sharding.Mesh | None = None, ): """Wrapper for lax.with_sharding_constraint, no-op on cpu or outside jit.""" if jax.devices()[0].platform == 'cpu' or ( @@ -246,8 +247,8 @@ def _with_sharding_constraint_one_fallback( axis_resources: LogicalPartitionSpec, x: Array, fallback: RulesFallback = RulesFallback.AXIS_IS_UNSHARDED, - rules: Optional[LogicalRules] = None, - mesh: Optional[jax.sharding.Mesh] = None, + rules: LogicalRules | None = None, + mesh: jax.sharding.Mesh | None = None, ): """Either imposes a sharding constraint or applies fallback.""" mesh_axes = _logical_to_mesh_axes(axis_resources, rules) @@ -284,8 +285,8 @@ def _is_logical_spec(x): def with_logical_constraint( x: ArrayPytree, logical_axis_resources: LogicalPartitionSpecPytree, - rules: 
Optional[LogicalRules] = None, - mesh: Optional[jax.sharding.Mesh] = None, + rules: LogicalRules | None = None, + mesh: jax.sharding.Mesh | None = None, fallback: RulesFallback = RulesFallback.AXIS_IS_UNSHARDED, ): """Version of jit's with_sharding_constraint that uses logical axis names.""" @@ -313,7 +314,7 @@ def with_logical_constraint( class LogicallyPartitioned(meta.Partitioned): - rules: Optional[LogicalRules] = struct.field(default=None, pytree_node=False) + rules: LogicalRules | None = struct.field(default=None, pytree_node=False) def unbox(self, apply_constraint=True) -> Any: """Returns the wrapped value with the partitioning constraint applied.""" @@ -331,8 +332,8 @@ def unbox(self, apply_constraint=True) -> Any: def with_logical_partitioning( fn: Callable[..., Any], names: LogicalNames, - mesh: Optional[jax.sharding.Mesh] = None, - rules: Optional[LogicalRules] = None, + mesh: jax.sharding.Mesh | None = None, + rules: LogicalRules | None = None, ) -> Callable[..., LogicallyPartitioned]: """Wraps a function's return value with LogicallyPartitioned. diff --git a/flax/linen/stochastic.py b/flax/linen/stochastic.py index c245ebb2a9..629b5dcf22 100644 --- a/flax/linen/stochastic.py +++ b/flax/linen/stochastic.py @@ -14,7 +14,7 @@ """Stochastic modules.""" -from typing import Optional, Sequence +from collections.abc import Sequence import jax.numpy as jnp from jax import lax, random @@ -62,15 +62,15 @@ class Dropout(Module): rate: float broadcast_dims: Sequence[int] = () - deterministic: Optional[bool] = None + deterministic: bool | None = None rng_collection: str = 'dropout' @compact def __call__( self, inputs, - deterministic: Optional[bool] = None, - rng: Optional[PRNGKey] = None, + deterministic: bool | None = None, + rng: PRNGKey | None = None, ): """Applies a random dropout mask to the input. diff --git a/flax/linen/summary.py b/flax/linen/summary.py index e5b176e458..badfa18178 100644 --- a/flax/linen/summary.py +++ b/flax/linen/summary.py @@ -21,17 +21,8 @@ from types import MappingProxyType from typing import ( Any, - Callable, - Dict, - Iterable, - List, - Mapping, - Optional, - Sequence, - Set, - Tuple, - Union, ) +from collections.abc import Callable, Iterable, Mapping, Sequence import jax import jax.numpy as jnp @@ -68,7 +59,7 @@ def render(self) -> str: @dataclasses.dataclass class _ArrayRepresentation(_ValueRepresentation): - shape: Tuple[int, ...] + shape: tuple[int, ...] dtype: Any @classmethod @@ -130,13 +121,13 @@ class Row: vjp_flops: FLOPs cost of calling the VJP of the module method. """ - path: Tuple[str, ...] + path: tuple[str, ...] module_copy: module_lib.Module method: str inputs: Any outputs: Any - module_variables: Dict[str, Dict[str, Any]] - counted_variables: Dict[str, Dict[str, Any]] + module_variables: dict[str, dict[str, Any]] + counted_variables: dict[str, dict[str, Any]] flops: int vjp_flops: int @@ -148,7 +139,7 @@ def __post_init__(self): def size_and_bytes( self, collections: Iterable[str] - ) -> Dict[str, Tuple[int, int]]: + ) -> dict[str, tuple[int, int]]: return { col: ( _size_and_bytes(self.counted_variables[col]) @@ -159,7 +150,7 @@ def size_and_bytes( } -class Table(List[Row]): +class Table(list[Row]): """A list of Row objects. 
Table inherits from `List[Row]` so it has all the methods of a list, however @@ -182,11 +173,11 @@ def __init__( def tabulate( module: module_lib.Module, - rngs: Union[PRNGKey, RNGSequences], - depth: Optional[int] = None, + rngs: PRNGKey | RNGSequences, + depth: int | None = None, show_repeated: bool = False, mutable: CollectionFilter = DenyList('intermediates'), - console_kwargs: Optional[Mapping[str, Any]] = None, + console_kwargs: Mapping[str, Any] | None = None, table_kwargs: Mapping[str, Any] = MappingProxyType({}), column_kwargs: Mapping[str, Any] = MappingProxyType({}), compute_flops: bool = False, @@ -433,7 +424,7 @@ def apply_vjp(variables, rngs, dynamic_leaves): def _get_module_table( module: module_lib.Module, - depth: Optional[int], + depth: int | None, show_repeated: bool, compute_flops: bool, compute_vjp_flops: bool, @@ -451,10 +442,10 @@ def _get_variables(): calls = module_lib._context.call_info_stack[-1].calls calls.sort(key=lambda c: c.index) - collections: Set[str] = set(variables.keys()) + collections: set[str] = set(variables.keys()) rows = [] - all_paths: Set[Tuple[str, ...]] = set(call.path for call in calls) - visited_paths: Set[Tuple[str, ...]] = set() + all_paths: set[tuple[str, ...]] = {call.path for call in calls} + visited_paths: set[tuple[str, ...]] = set() for c in calls: call_depth = len(c.path) @@ -497,10 +488,10 @@ def _get_variables(): def _get_module_variables( - path: Tuple[str, ...], + path: tuple[str, ...], variables: FrozenVariableDict, - all_paths: Set[Tuple[str, ...]], -) -> Tuple[MutableVariableDict, Any]: + all_paths: set[tuple[str, ...]], +) -> tuple[MutableVariableDict, Any]: """A function that takes a path and variables structure and returns a (module_variables, submodule_variables) tuple for that path. @@ -510,9 +501,9 @@ def _get_module_variables( """ module_variables = _get_path_variables(path, variables) submodule_variables: Any = {collection: {} for collection in module_variables} - all_keys = set( + all_keys = { key for collection in module_variables.values() for key in collection - ) + } for key in all_keys: submodule_path = path + (key,) @@ -527,7 +518,7 @@ def _get_module_variables( def _get_path_variables( - path: Tuple[str, ...], variables: FrozenVariableDict + path: tuple[str, ...], variables: FrozenVariableDict ) -> MutableVariableDict: """A function that takes a path and a variables structure and returns the variable structure at that path. 
@@ -566,10 +557,10 @@ def _process_inputs(args, kwargs) -> Any: def _render_table( table: Table, - console_extras: Optional[Mapping[str, Any]], + console_extras: Mapping[str, Any] | None, table_kwargs: Mapping[str, Any], column_kwargs: Mapping[str, Any], - non_params_cols: List[str], + non_params_cols: list[str], ) -> str: """A function that renders a Table to a string representation using rich.""" console_kwargs = {'force_terminal': True, 'force_jupyter': False} @@ -675,7 +666,7 @@ def _size_and_bytes_repr(size: int, num_bytes: int) -> str: return f'{size:,} [dim]({bytes_repr})[/dim]' -def _size_and_bytes(pytree: Any) -> Tuple[int, int]: +def _size_and_bytes(pytree: Any) -> tuple[int, int]: leaves = jax.tree_util.tree_leaves(pytree) size = sum(x.size for x in leaves if hasattr(x, 'size')) num_bytes = sum( diff --git a/flax/linen/transforms.py b/flax/linen/transforms.py index 4a8843646d..947b64e162 100644 --- a/flax/linen/transforms.py +++ b/flax/linen/transforms.py @@ -42,17 +42,10 @@ import inspect from typing import ( Any, - Callable, - Dict, - Iterable, - Mapping, - Optional, - Sequence, - Tuple, - Type, TypeVar, Union, ) +from collections.abc import Callable, Iterable, Mapping, Sequence from flax import core from flax import errors, struct, traceback_util @@ -109,8 +102,8 @@ class VariablePlaceholder: class InstancePlaceholder: """Marks module instances in a JAX-compatible way when lifting arguments.""" - cls: Type[Any] = struct.field(pytree_node=False) - attrs: Dict[Any, Any] = struct.field(pytree_node=False) + cls: type[Any] = struct.field(pytree_node=False) + attrs: dict[Any, Any] = struct.field(pytree_node=False) id: int = struct.field(pytree_node=False) @@ -739,7 +732,7 @@ def core_fn(scopes, module_hash, *args, **kwargs): # Utility to wrap a class or to use as decorator in def of class method. # ----------------------------------------------------------------------------- -TransformTarget = Union[Type[Module], Callable[..., Any]] +TransformTarget = Union[type[Module], Callable[..., Any]] Target = TypeVar('Target', bound=TransformTarget) @@ -772,7 +765,7 @@ def lift_transform( def lift_direct_transform( transform: Callable[..., Any], - targets: Tuple[Callable[..., Any], ...], + targets: tuple[Callable[..., Any], ...], mdl: Module, *args, multi_scope=True, @@ -803,9 +796,9 @@ def vmap( split_rngs: Mapping[PRNGSequenceFilter, bool] = FrozenDict(), in_axes=0, out_axes=0, - axis_size: Optional[int] = None, - axis_name: Optional[str] = None, - spmd_axis_name: Optional[str] = None, + axis_size: int | None = None, + axis_name: str | None = None, + spmd_axis_name: str | None = None, metadata_params: Mapping[Any, Any] = {}, methods=None, ) -> Target: @@ -894,11 +887,11 @@ def jit( target: Target, variables: CollectionFilter = True, rngs: PRNGSequenceFilter = True, - static_argnums: Union[int, Iterable[int]] = (), - static_argnames: Union[str, Iterable[str]] = (), - donate_argnums: Union[int, Iterable[int]] = (), + static_argnums: int | Iterable[int] = (), + static_argnames: str | Iterable[str] = (), + donate_argnums: int | Iterable[int] = (), device=None, - backend: Union[str, None] = None, + backend: str | None = None, methods=None, ) -> Target: """Lifted version of ``jax.jit``. @@ -980,8 +973,8 @@ def checkpoint( rngs: PRNGSequenceFilter = True, concrete: bool = False, prevent_cse: bool = True, - static_argnums: Union[int, Tuple[int, ...]] = (), - policy: Optional[Callable[..., bool]] = None, + static_argnums: int | tuple[int, ...] 
= (), + policy: Callable[..., bool] | None = None, methods=None, ) -> Target: """Lifted version of ``jax.checkpoint``. @@ -1066,8 +1059,8 @@ def checkpoint( def remat_scan( target: Target, - lengths: Optional[Sequence[int]] = (), - policy: Optional[Callable[..., bool]] = None, + lengths: Sequence[int] | None = (), + policy: Callable[..., bool] | None = None, variable_broadcast: CollectionFilter = False, variable_carry: CollectionFilter = False, variable_axes: Mapping[CollectionFilter, InOutScanAxis] = FrozenDict( @@ -1138,10 +1131,10 @@ def scan( split_rngs: Mapping[PRNGSequenceFilter, bool] = FrozenDict(), in_axes=0, out_axes=0, - length: Optional[int] = None, + length: int | None = None, reverse: bool = False, unroll: int = 1, - data_transform: Optional[Callable[..., Any]] = None, + data_transform: Callable[..., Any] | None = None, metadata_params: Mapping[Any, Any] = {}, methods=None, _split_transpose: bool = False, @@ -1659,7 +1652,7 @@ def jvp( variable_tangents, variables: CollectionFilter = True, rngs: PRNGSequenceFilter = True, -) -> Union[Tuple[Any, Callable[..., Any]], Tuple[Any, Callable[..., Any], Any]]: +) -> tuple[Any, Callable[..., Any]] | tuple[Any, Callable[..., Any], Any]: """A lifted version of ``jax.jvp``. See ``jax.jvp`` for the unlifted Jacobian-vector product (forward gradient). @@ -2110,7 +2103,7 @@ def wrapped_fn(self, *args, **kwargs): def add_metadata_axis( target: Target, variable_axes: Mapping[CollectionFilter, InOutAxis] = FrozenDict(), - metadata_params: Dict[Any, Any] = {}, + metadata_params: dict[Any, Any] = {}, ) -> Target: """A helper to manipulate boxed axis metadata. diff --git a/flax/nnx/examples/gemma/helpers.py b/flax/nnx/examples/gemma/helpers.py index e0570d7d68..37cc800b7f 100644 --- a/flax/nnx/examples/gemma/helpers.py +++ b/flax/nnx/examples/gemma/helpers.py @@ -30,7 +30,8 @@ from __future__ import annotations -from typing import Callable, Optional, Tuple, TypeVar, Union +from typing import TypeVar +from collections.abc import Callable import flax from flax import nnx from flax.typing import VariableDict # pylint: disable=g-importing-member,g-multiple-import @@ -38,7 +39,7 @@ M = TypeVar('M', bound='nnx.Module') -def _flatten_path(path: Tuple[Union[str, int], ...]) -> str: +def _flatten_path(path: tuple[str | int, ...]) -> str: def f(item) -> str: if isinstance(item, str): return f'{item}' @@ -53,9 +54,9 @@ def f(item) -> str: def module_from_linen_variables( module_factory: Callable[[], M], variables: VariableDict, - map_key_fn: Optional[ - Callable[[Tuple[str, ...]], Tuple[Union[str, int], ...]] - ] = None, + map_key_fn: None | ( + Callable[[tuple[str, ...]], tuple[str | int, ...]] + ) = None, ) -> M: """Returns an `nnx.Module` initialized with the `variables` of a linen module. 
@@ -70,7 +71,7 @@ def module_from_linen_variables( """ if map_key_fn is None: - def map_key_fn(path: Tuple[str, ...]) -> Tuple[Union[str, int], ...]: + def map_key_fn(path: tuple[str, ...]) -> tuple[str | int, ...]: return path[1:] if 'params' in variables else path mdl: M = nnx.eval_shape(module_factory) diff --git a/flax/nnx/examples/gemma/helpers_test.py b/flax/nnx/examples/gemma/helpers_test.py index cb7e718222..2bac0ab8c2 100644 --- a/flax/nnx/examples/gemma/helpers_test.py +++ b/flax/nnx/examples/gemma/helpers_test.py @@ -30,7 +30,6 @@ from __future__ import annotations -from typing import Tuple from absl.testing import absltest from absl.testing import parameterized @@ -127,7 +126,7 @@ def test_different_structure(self, inputs_shape, num_features, use_bias): for in_f, out_f, b in zip(in_features, out_features, use_bias) ]) - def _map_key_fn(key: Tuple[str, ...]) -> Tuple[str | int, ...]: + def _map_key_fn(key: tuple[str, ...]) -> tuple[str | int, ...]: new_key = [] for k in key[1:]: if k.startswith('layers_'): diff --git a/flax/nnx/examples/gemma/layers.py b/flax/nnx/examples/gemma/layers.py index da9ee9a865..241bbba52a 100644 --- a/flax/nnx/examples/gemma/layers.py +++ b/flax/nnx/examples/gemma/layers.py @@ -30,7 +30,8 @@ from __future__ import annotations -from typing import Any, Sequence, Union +from typing import Any, Union +from collections.abc import Sequence from flax import nnx import flax.linen as nn diff --git a/flax/nnx/examples/gemma/modules.py b/flax/nnx/examples/gemma/modules.py index cb4a9fc880..bf62faa465 100644 --- a/flax/nnx/examples/gemma/modules.py +++ b/flax/nnx/examples/gemma/modules.py @@ -31,7 +31,8 @@ from __future__ import annotations import enum -from typing import Any, Sequence, Union +from typing import Any, Union +from collections.abc import Sequence from flax import nnx import flax.linen as nn @@ -95,8 +96,8 @@ def __init__( attn_type: AttentionType, *, rngs: nnx.Rngs, - attn_logits_soft_cap: Union[float, None] = None, - sliding_window_size: Union[int, None] = None, + attn_logits_soft_cap: float | None = None, + sliding_window_size: int | None = None, ): if attn_type == AttentionType.LOCAL_SLIDING and sliding_window_size is None: raise ValueError( @@ -130,9 +131,9 @@ def __call__( self, x: Array, segment_pos: Array, - cache: Union[LayerCache, None], + cache: LayerCache | None, attn_mask: Array, - ) -> tuple[Union[LayerCache, None], Array]: + ) -> tuple[LayerCache | None, Array]: seq_len = x.shape[1] if self.use_qkv_einsum: @@ -291,8 +292,8 @@ def __init__( attn_type: AttentionType, *, rngs: nnx.Rngs, - attn_logits_soft_cap: Union[float, None] = None, - sliding_window_size: Union[int, None] = None, + attn_logits_soft_cap: float | None = None, + sliding_window_size: int | None = None, ): self.pre_attention_norm = layers.RMSNorm(embed_dim, rngs=rngs) self.attn = Attention( @@ -321,9 +322,9 @@ def __call__( self, x: jax.Array, segment_pos: jax.Array, - cache: Union[LayerCache, None], + cache: LayerCache | None, attn_mask: jax.Array, - ) -> tuple[Union[LayerCache, None], jax.Array]: + ) -> tuple[LayerCache | None, jax.Array]: inputs_normalized = self.pre_attention_norm(x) cache, attn_output = self.attn( inputs_normalized, diff --git a/flax/nnx/examples/gemma/params.py b/flax/nnx/examples/gemma/params.py index d774d91f16..4072dd3e5e 100644 --- a/flax/nnx/examples/gemma/params.py +++ b/flax/nnx/examples/gemma/params.py @@ -29,7 +29,8 @@ """Utils for loading Gemma params.""" import functools -from typing import Any, Mapping, Optional +from typing import Any 
+from collections.abc import Mapping import jax import jax.numpy as jnp @@ -47,7 +48,7 @@ def load_and_format_params(path: str) -> Params: return nested_params -def load_metadata(path: str) -> Optional[Any]: +def load_metadata(path: str) -> Any | None: """Loads metadata from a checkpoint path.""" checkpointer = orbax.checkpoint.PyTreeCheckpointer() metadata = checkpointer.metadata(path) diff --git a/flax/nnx/examples/gemma/sampler.py b/flax/nnx/examples/gemma/sampler.py index 4f6e4ebf05..5fd065303b 100644 --- a/flax/nnx/examples/gemma/sampler.py +++ b/flax/nnx/examples/gemma/sampler.py @@ -35,7 +35,6 @@ from collections.abc import Sequence import dataclasses -from typing import Union import chex from flax import nnx @@ -100,10 +99,10 @@ class _SamplingState: total_sampling_steps: int # Fixed-size buffer for accumulating the output logits. - logits_buffer: Union[jnp.ndarray, None] = None # [B, L, V] + logits_buffer: jnp.ndarray | None = None # [B, L, V] # List of tokens that are forbidden to be generated. - forbidden_token_ids: Union[Sequence[int], None] = None + forbidden_token_ids: Sequence[int] | None = None @dataclasses.dataclass @@ -211,7 +210,7 @@ def init_sample_state( all_input_ids: list[jax.Array], total_sampling_steps: int, include_logits: bool = False, - forbidden_token_ids: Union[Sequence[int], None] = None, + forbidden_token_ids: Sequence[int] | None = None, ) -> _SamplingState: """Initializes the sampling state given input prompts.""" batch_size = len(all_input_ids) @@ -310,7 +309,7 @@ def __call__( total_generation_steps: int, echo: bool = False, return_logits: bool = True, - forbidden_tokens: Union[Sequence[str], None] = None, + forbidden_tokens: Sequence[str] | None = None, ) -> SamplerOutput: """Samples a completion of the input string. 
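Note on the spelling change in the Gemma example files above (helpers.py, modules.py, params.py, sampler.py): `Optional[X]` and `Union[X, None]` become the PEP 604 form `X | None`. A minimal sketch, not part of the patch, showing the two spellings are interchangeable at runtime on Python 3.10+ and stay lazily evaluated under the `from __future__ import annotations` import these modules already carry:

from __future__ import annotations  # annotations stay as strings, as in the gemma modules

from typing import Optional, Union, get_type_hints

def old_style(x: Optional[int] = None) -> Union[str, None]:
  return None if x is None else str(x)

def new_style(x: int | None = None) -> str | None:
  return None if x is None else str(x)

# On 3.10+ the PEP 604 union compares equal to the typing.Union form, so type
# checkers and runtime introspection see the same annotation either way.
assert Optional[int] == Union[int, None] == (int | None)
assert get_type_hints(old_style) == get_type_hints(new_style)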
diff --git a/flax/nnx/examples/gemma/sampler_test.py b/flax/nnx/examples/gemma/sampler_test.py index 28976c214b..e3993a633e 100644 --- a/flax/nnx/examples/gemma/sampler_test.py +++ b/flax/nnx/examples/gemma/sampler_test.py @@ -28,7 +28,7 @@ # ============================================================================ """Minimal test for sampler.""" -from typing import Iterable +from collections.abc import Iterable from absl.testing import absltest from flax import nnx diff --git a/flax/nnx/examples/gemma/transformer.py b/flax/nnx/examples/gemma/transformer.py index 4a99e24fad..bba583e46d 100644 --- a/flax/nnx/examples/gemma/transformer.py +++ b/flax/nnx/examples/gemma/transformer.py @@ -31,7 +31,7 @@ from __future__ import annotations import dataclasses -from typing import Iterable, Tuple, Union +from collections.abc import Iterable from flax import nnx import helpers @@ -55,15 +55,15 @@ class TransformerConfig: num_heads: int head_dim: int num_kv_heads: int - final_logit_softcap: Union[float, None] + final_logit_softcap: float | None use_post_attn_norm: bool use_post_ffw_norm: bool attention_types: Iterable[modules.AttentionType] - attn_logits_soft_cap: Union[float, None] = None - sliding_window_size: Union[int, None] = None + attn_logits_soft_cap: float | None = None + sliding_window_size: int | None = None @classmethod - def from_path(cls, path: str) -> 'TransformerConfig': + def from_path(cls, path: str) -> TransformerConfig: """Creates a TransformerConfig from loaded parameters.""" metadata = params_lib.load_metadata(path) params = params_lib.load_params(path) @@ -83,7 +83,7 @@ def from_path(cls, path: str) -> 'TransformerConfig': raise ValueError('Verify checkpoint path is a Gemma checkpoint') @classmethod - def from_params(cls, params: params_lib.Params) -> 'TransformerConfig': + def from_params(cls, params: params_lib.Params) -> TransformerConfig: """Creates a TransformerConfig from loaded parameters. Use for V1 models only. @@ -186,7 +186,7 @@ def gemma_9b(cls): ) -def _map_linen_var_names(key: Tuple[str, ...]) -> Tuple[Union[str, int], ...]: +def _map_linen_var_names(key: tuple[str, ...]) -> tuple[str | int, ...]: new_key = [] for k in key: if k.startswith('layer_'): @@ -204,7 +204,7 @@ class Transformer(nnx.Module): """Gemma transformer.""" @classmethod - def from_params(cls, params: params_lib.Params) -> 'Transformer': + def from_params(cls, params: params_lib.Params) -> Transformer: config = TransformerConfig.from_params(params) return helpers.module_from_linen_variables( module_factory=lambda: cls(config, rngs=nnx.Rngs(params=0)), @@ -243,9 +243,9 @@ def __call__( self, last_tokens: Array, # [B, L] positions: Array, # [B, L] - cache: Union[Cache, None], # (sequence length L') + cache: Cache | None, # (sequence length L') attention_mask: Array, # [B, L, L'] - ) -> tuple[Array, Union[Cache, None]]: + ) -> tuple[Array, Cache | None]: """Transformer forward pass. 
You can run this forward pass two ways: with or without an attention kv diff --git a/flax/nnx/examples/lm1b/input_pipeline.py b/flax/nnx/examples/lm1b/input_pipeline.py index e87db94b27..2c265b7b5e 100644 --- a/flax/nnx/examples/lm1b/input_pipeline.py +++ b/flax/nnx/examples/lm1b/input_pipeline.py @@ -15,7 +15,6 @@ """Input pipeline for a LM1B dataset.""" import os -from typing import Dict, List, Optional, Union import tensorflow as tf import tensorflow_datasets as tfds @@ -24,7 +23,7 @@ from configs import default AUTOTUNE = tf.data.experimental.AUTOTUNE -Features = Dict[str, tf.Tensor] +Features = dict[str, tf.Tensor] class NormalizeFeatureNamesOp: @@ -67,8 +66,8 @@ def get_raw_dataset( def pack_dataset( dataset: tf.data.Dataset, - key2length: Union[int, Dict[str, int]], - keys: Optional[List[str]] = None, + key2length: int | dict[str, int], + keys: list[str] | None = None, ) -> tf.data.Dataset: """Creates a 'packed' version of a dataset on-the-fly. @@ -149,7 +148,7 @@ def my_fn(x): def _pack_with_tf_ops( - dataset: tf.data.Dataset, keys: List[str], key2length: Dict[str, int] + dataset: tf.data.Dataset, keys: list[str], key2length: dict[str, int] ) -> tf.data.Dataset: """Helper-function for packing a dataset which has already been batched. @@ -275,7 +274,7 @@ def true_fn(): def preprocess_data( dataset, shuffle: bool, - num_epochs: Optional[int] = 1, + num_epochs: int | None = 1, pack_examples: bool = True, shuffle_buffer_size: int = 1024, max_length: int = 512, @@ -321,7 +320,7 @@ def get_datasets( config: default.Config, *, n_devices: int, - vocab_path: Optional[str] = None, + vocab_path: str | None = None, ): """Load and return dataset of batched examples for use during training.""" if vocab_path is None: diff --git a/flax/nnx/examples/lm1b/models.py b/flax/nnx/examples/lm1b/models.py index bb80e1eee8..1fcf98ab94 100644 --- a/flax/nnx/examples/lm1b/models.py +++ b/flax/nnx/examples/lm1b/models.py @@ -25,7 +25,7 @@ from __future__ import annotations import dataclasses -from typing import Any, Optional +from typing import Any import jax import jax.numpy as jnp @@ -57,7 +57,7 @@ class TransformerConfig: attention_dropout_rate: float = 0.1 kernel_init: nnx.Initializer = nnx.initializers.xavier_uniform() bias_init: nnx.Initializer = nnx.initializers.normal(stddev=1e-6) - posemb_init: Optional[nnx.Initializer] = None + posemb_init: nnx.Initializer | None = None axis_rules: default.MeshRules = dataclasses.field( default_factory=default.MeshRules ) diff --git a/flax/nnx/examples/lm1b/tokenizer.py b/flax/nnx/examples/lm1b/tokenizer.py index 811e445cff..8ca1e29cf2 100644 --- a/flax/nnx/examples/lm1b/tokenizer.py +++ b/flax/nnx/examples/lm1b/tokenizer.py @@ -18,7 +18,8 @@ import os import tempfile import time -from typing import Any, Dict, Iterable, Tuple +from typing import Any +from collections.abc import Iterable import jax import tensorflow as tf @@ -26,14 +27,14 @@ from absl import logging from sentencepiece import SentencePieceTrainer -Features = Dict[str, tf.Tensor] +Features = dict[str, tf.Tensor] def _dump_chars_to_textfile( dataset: tf.data.Dataset, maxchars: int = int(1e7), data_keys=('inputs', 'targets'), -) -> Tuple[str, int]: +) -> tuple[str, int]: """Write part of a TFDS sentence dataset to lines in a text file. 
Args: @@ -140,7 +141,7 @@ def load_or_train_tokenizer( vocab_path: str, vocab_size: int, max_corpus_chars: int, - data_keys: Tuple[str, str] = ('inputs', 'targets'), + data_keys: tuple[str, str] = ('inputs', 'targets'), ): """Loads the tokenizer at `vocab_path` or trains a one from `dataset`.""" try: diff --git a/flax/nnx/examples/lm1b/utils.py b/flax/nnx/examples/lm1b/utils.py index d2afc3c3bc..b18b6e4691 100644 --- a/flax/nnx/examples/lm1b/utils.py +++ b/flax/nnx/examples/lm1b/utils.py @@ -15,7 +15,8 @@ # Copied over from MaxText (https://github.com/google/maxtext/blob/main/MaxText/max_utils.py). import logging -from typing import Any, Callable +from typing import Any +from collections.abc import Callable from typing_extensions import Protocol, runtime_checkable import jax diff --git a/flax/nnx/nnx/compat/module.py b/flax/nnx/nnx/compat/module.py index 0af4d38f58..808d699daf 100644 --- a/flax/nnx/nnx/compat/module.py +++ b/flax/nnx/nnx/compat/module.py @@ -35,9 +35,9 @@ @dataclasses.dataclass class CompactContext: - module: 'Module' + module: Module type_counter: defaultdict[type, int] = dataclasses.field( - default_factory=lambda: defaultdict(lambda: 0) + default_factory=lambda: defaultdict(int) ) diff --git a/flax/nnx/nnx/graph.py b/flax/nnx/nnx/graph.py index 7c878b8066..50d20c94e6 100644 --- a/flax/nnx/nnx/graph.py +++ b/flax/nnx/nnx/graph.py @@ -157,7 +157,7 @@ class PytreeNodeImpl(NodeImplBase[Node, Leaf, AuxData]): ] -_node_impl_for_type: dict[type, 'NodeImpl[tp.Any, tp.Any, tp.Any]'] = {} +_node_impl_for_type: dict[type, NodeImpl[tp.Any, tp.Any, tp.Any]] = {} def register_graph_node_type( @@ -246,7 +246,7 @@ class NodeDef(tp.Generic[Node], reprlib.Representable): type: tp.Type[Node] index: int attributes: tuple[Key, ...] - subgraphs: _HashableMapping[Key, tp.Union['NodeDef[tp.Any]', Index]] + subgraphs: _HashableMapping[Key, tp.Union[NodeDef[tp.Any], Index]] static_fields: _HashableMapping[Key, tp.Any] leaves: _HashableMapping[Key, Index | None] metadata: tp.Any @@ -257,7 +257,7 @@ def create( type: tp.Type[Node], index: int, attributes: tuple[Key, ...], - subgraphs: tp.Iterable[tuple[Key, tp.Union['NodeDef[tp.Any]', Index]]], + subgraphs: tp.Iterable[tuple[Key, tp.Union[NodeDef[tp.Any], Index]]], static_fields: tp.Iterable[tuple[Key, tp.Any]], leaves: tp.Iterable[tuple[Key, Index | None]], metadata: tp.Any, @@ -343,7 +343,7 @@ def __eq__(self, other): def apply( self, state: GraphState, *states: GraphState - ) -> ApplyCaller[tuple['GraphDef[Node]', GraphState]]: + ) -> ApplyCaller[tuple[GraphDef[Node], GraphState]]: accessor = DelayedAccessor() def _apply( diff --git a/flax/nnx/nnx/nn/attention.py b/flax/nnx/nnx/nn/attention.py index 260011f85f..cb531729f1 100644 --- a/flax/nnx/nnx/nn/attention.py +++ b/flax/nnx/nnx/nn/attention.py @@ -17,7 +17,8 @@ from __future__ import annotations import functools -from typing import Any, Callable, Optional +from typing import Any +from collections.abc import Callable import jax import jax.numpy as jnp @@ -47,15 +48,15 @@ def dot_product_attention_weights( query: Array, key: Array, - bias: Optional[Array] = None, - mask: Optional[Array] = None, + bias: Array | None = None, + mask: Array | None = None, broadcast_dropout: bool = True, - dropout_rng: Optional[Array] = None, + dropout_rng: Array | None = None, dropout_rate: float = 0.0, deterministic: bool = False, - dtype: Optional[Dtype] = None, + dtype: Dtype | None = None, precision: PrecisionLike = None, - module: Optional[Module] = None, + module: Module | None = None, ): 
"""Computes dot-product attention weights given query and key. @@ -138,15 +139,15 @@ def dot_product_attention( query: Array, key: Array, value: Array, - bias: Optional[Array] = None, - mask: Optional[Array] = None, + bias: Array | None = None, + mask: Array | None = None, broadcast_dropout: bool = True, - dropout_rng: Optional[Array] = None, + dropout_rng: Array | None = None, dropout_rate: float = 0.0, deterministic: bool = False, - dtype: Optional[Dtype] = None, + dtype: Dtype | None = None, precision: PrecisionLike = None, - module: Optional[Module] = None, + module: Module | None = None, ): """Computes dot-product attention given query, key, and value. @@ -662,7 +663,7 @@ def make_causal_mask( def combine_masks( - *masks: Optional[Array], dtype: Dtype = jnp.float32 + *masks: Array | None, dtype: Dtype = jnp.float32 ) -> Array | None: """Combine attention masks. diff --git a/flax/nnx/nnx/nn/dtypes.py b/flax/nnx/nnx/nn/dtypes.py index 223d401499..a1f60d20eb 100644 --- a/flax/nnx/nnx/nn/dtypes.py +++ b/flax/nnx/nnx/nn/dtypes.py @@ -13,7 +13,6 @@ # limitations under the License. import typing as tp -from typing import Optional from flax.typing import Dtype from jax import numpy as jnp @@ -21,7 +20,7 @@ def canonicalize_dtype( - *args, dtype: Optional[Dtype] = None, inexact: bool = True + *args, dtype: Dtype | None = None, inexact: bool = True ) -> Dtype: """Canonicalize an optional dtype to the definitive dtype. diff --git a/flax/nnx/nnx/nn/normalization.py b/flax/nnx/nnx/nn/normalization.py index c0b349c093..6fcfd3b2fe 100644 --- a/flax/nnx/nnx/nn/normalization.py +++ b/flax/nnx/nnx/nn/normalization.py @@ -34,7 +34,7 @@ def _canonicalize_axes(rank: int, axes: Axes) -> tp.Tuple[int, ...]: """Returns a tuple of deduplicated, sorted, and positive axes.""" if not isinstance(axes, tp.Iterable): axes = (axes,) - return tuple(set([rank + axis if axis < 0 else axis for axis in axes])) + return tuple({rank + axis if axis < 0 else axis for axis in axes}) def _abs_sq(x): diff --git a/flax/nnx/nnx/nn/stochastic.py b/flax/nnx/nnx/nn/stochastic.py index c3f743a4d9..93de8771f9 100644 --- a/flax/nnx/nnx/nn/stochastic.py +++ b/flax/nnx/nnx/nn/stochastic.py @@ -28,7 +28,7 @@ from __future__ import annotations import dataclasses -from typing import Sequence +from collections.abc import Sequence import jax import jax.numpy as jnp diff --git a/flax/nnx/nnx/proxy_caller.py b/flax/nnx/nnx/proxy_caller.py index 9c1886e736..a9eca15ef8 100644 --- a/flax/nnx/nnx/proxy_caller.py +++ b/flax/nnx/nnx/proxy_caller.py @@ -84,18 +84,18 @@ def __init__( def __call__(self, *args, **kwargs): return self._callable(self._accessor, *args, **kwargs) - def __getattr__(self, name) -> 'CallableProxy': + def __getattr__(self, name) -> CallableProxy: return CallableProxy(self._callable, getattr(self._accessor, name)) - def __getitem__(self, key) -> 'CallableProxy': + def __getitem__(self, key) -> CallableProxy: return CallableProxy(self._callable, self._accessor[key]) class ApplyCaller(tp.Protocol, tp.Generic[A]): - def __getattr__(self, __name) -> 'ApplyCaller[A]': + def __getattr__(self, __name) -> ApplyCaller[A]: ... - def __getitem__(self, __name) -> 'ApplyCaller[A]': + def __getitem__(self, __name) -> ApplyCaller[A]: ... 
def __call__(self, *args, **kwargs) -> tuple[tp.Any, A]: diff --git a/flax/nnx/nnx/transforms/looping.py b/flax/nnx/nnx/transforms/looping.py index 3aeafda49c..3c205bf1f1 100644 --- a/flax/nnx/nnx/transforms/looping.py +++ b/flax/nnx/nnx/transforms/looping.py @@ -375,11 +375,11 @@ def scan_apply_wrapper(*args, **kwargs): ) # infer length - lengths: set[int] = set( + lengths: set[int] = { x.shape[0] # type: ignore for x, axis in zip(flat_scan, flatdef.flat_axes) if axis is not None - ) + } if len(lengths) > 1: raise ValueError( @@ -499,7 +499,7 @@ def constructor( split_rngs: filterlib.Filter = ..., transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), scan_output: bool = True, - ) -> tp.Callable[..., 'Scan[MA]']: + ) -> tp.Callable[..., Scan[MA]]: def _create_scan(*args, **kwargs): return Scan( module_constructor=module_constructor, diff --git a/flax/nnx/nnx/transforms/transforms.py b/flax/nnx/nnx/transforms/transforms.py index 7cb04b4550..a9471b48b2 100644 --- a/flax/nnx/nnx/transforms/transforms.py +++ b/flax/nnx/nnx/transforms/transforms.py @@ -392,7 +392,7 @@ def constructor( # nnx specific donate_state: bool = False, constrain_state: bool | tp.Callable[[State], State] = False, - ) -> tp.Callable[..., 'Jit[MA]']: + ) -> tp.Callable[..., Jit[MA]]: def _create_jit(*args, **kwargs): return Jit( module_constructor=module_constructor, @@ -714,7 +714,7 @@ def constructor( return_value: bool = False, *, wrt: filterlib.Filter = variables.Param, - ) -> tp.Callable[..., 'Grad[MA]']: + ) -> tp.Callable[..., Grad[MA]]: def _create_grad(*args, **kwargs): return Grad( module_constructor=module_constructor, @@ -802,7 +802,7 @@ def constructor( prevent_cse: bool = True, static_argnums: int | tuple[int, ...] = (), policy: tp.Callable[..., bool] | None = None, - ) -> tp.Callable[..., 'Remat[MA]']: + ) -> tp.Callable[..., Remat[MA]]: def create_remat(*args, **kwargs): return Remat( module_constructor=module_constructor, diff --git a/flax/nnx/nnx/traversals.py b/flax/nnx/nnx/traversals.py index af56503326..eb9e5896bb 100644 --- a/flax/nnx/nnx/traversals.py +++ b/flax/nnx/nnx/traversals.py @@ -17,7 +17,7 @@ from __future__ import annotations from collections.abc import Callable, Mapping -from typing import Any, Union, overload +from typing import Any, overload from flax import struct @@ -40,7 +40,7 @@ def flatten_mapping(xs: Mapping[Any, Any], /, *, keep_empty_nodes: bool = False, - is_leaf: Union[None, IsLeafCallable] = None, + is_leaf: None | IsLeafCallable = None, sep: None = None ) -> dict[tuple[Any, ...], Any]: ... @@ -50,7 +50,7 @@ def flatten_mapping(xs: Mapping[Any, Any], /, *, keep_empty_nodes: bool = False, - is_leaf: Union[None, IsLeafCallable] = None, + is_leaf: None | IsLeafCallable = None, sep: str, ) -> dict[str, Any]: ... @@ -59,8 +59,8 @@ def flatten_mapping(xs: Mapping[Any, Any], /, *, keep_empty_nodes: bool = False, - is_leaf: Union[None, IsLeafCallable] = None, - sep: Union[None, str] = None + is_leaf: None | IsLeafCallable = None, + sep: None | str = None ) -> dict[Any, Any]: """Flatten a nested mapping. @@ -94,7 +94,7 @@ def flatten_mapping(xs: Mapping[Any, Any], xs, Mapping ), f'expected Mapping; got {type(xs).__qualname__}' - def _key(path: tuple[Any, ...]) -> Union[tuple[Any, ...], str]: + def _key(path: tuple[Any, ...]) -> tuple[Any, ...] 
| str: if sep is None: return path return sep.join(path) @@ -140,7 +140,7 @@ def unflatten_mapping(xs: Mapping[str, Any], def unflatten_mapping(xs: Any, /, *, - sep: Union[str, None] = None + sep: str | None = None ) -> dict[Any, Any]: """Unflatten a mapping. diff --git a/flax/nnx/nnx/variables.py b/flax/nnx/nnx/variables.py index 7153d6f571..1bed4ded51 100644 --- a/flax/nnx/nnx/variables.py +++ b/flax/nnx/nnx/variables.py @@ -50,7 +50,7 @@ AddAxisHook = tp.Callable[[V, AxisName, AxisIndex], None] RemoveAxisHook = tp.Callable[[V, AxisName, AxisIndex], None] -VariableTypeCache: dict[str, tp.Type['Variable[tp.Any]']] = {} +VariableTypeCache: dict[str, tp.Type[Variable[tp.Any]]] = {} class Empty: @@ -86,8 +86,8 @@ class VariableMetadata(tp.Generic[A]): set_value_hooks: tuple[SetValueHook[A], ...] = () get_value_hooks: tuple[GetValueHook[A], ...] = () create_value_hooks: tuple[CreateValueHook[A], ...] = () - add_axis_hooks: tuple[AddAxisHook['Variable[A]'], ...] = () - remove_axis_hooks: tuple[RemoveAxisHook['Variable[A]'], ...] = () + add_axis_hooks: tuple[AddAxisHook[Variable[A]], ...] = () + remove_axis_hooks: tuple[RemoveAxisHook[Variable[A]], ...] = () metadata: tp.Mapping[str, tp.Any] = dataclasses.field(default_factory=dict) @@ -160,8 +160,8 @@ class Variable(tp.Generic[A], reprlib.Representable): set_value_hooks: tuple[SetValueHook[A], ...] get_value_hooks: tuple[GetValueHook[A], ...] create_value_hooks: tuple[CreateValueHook[A], ...] - add_axis_hooks: tuple[AddAxisHook['Variable[A]'], ...] - remove_axis_hooks: tuple[RemoveAxisHook['Variable[A]'], ...] + add_axis_hooks: tuple[AddAxisHook[Variable[A]], ...] + remove_axis_hooks: tuple[RemoveAxisHook[Variable[A]], ...] _trace_state: tracers.TraceState def __init__( @@ -177,11 +177,11 @@ def __init__( CreateValueHook[A], tp.Sequence[CreateValueHook[A]] ] = (), add_axis_hooks: tp.Union[ - AddAxisHook['Variable[A]'], tp.Sequence[AddAxisHook['Variable[A]']] + AddAxisHook[Variable[A]], tp.Sequence[AddAxisHook[Variable[A]]] ] = (), remove_axis_hooks: tp.Union[ - RemoveAxisHook['Variable[A]'], - tp.Sequence[RemoveAxisHook['Variable[A]']], + RemoveAxisHook[Variable[A]], + tp.Sequence[RemoveAxisHook[Variable[A]]], ] = (), **metadata: tp.Any, ): @@ -304,10 +304,10 @@ def _setattr(self, name: str, value: tp.Any): object.__setattr__(self, name, value) @classmethod - def state(cls, value: A, **metadata) -> 'VariableState[A]': + def state(cls, value: A, **metadata) -> VariableState[A]: return cls(value, **metadata).to_state() - def copy_from(self, other: 'Variable[A]') -> None: + def copy_from(self, other: Variable[A]) -> None: if type(self) is not type(other): raise ValueError( f'Cannot copy from incompatible container, ' @@ -322,7 +322,7 @@ def copy_from(self, other: 'Variable[A]') -> None: vars_dict.clear() vars_dict.update(other_vars, _trace_state=trace_state) - def copy_from_state(self, variable_state: 'VariableState[A]'): + def copy_from_state(self, variable_state: VariableState[A]): trace_state = self._trace_state variable_vars = vars(self) variable_vars.clear() @@ -368,12 +368,12 @@ def __eq__(self, other: object) -> bool: return type(self) is type(other) and vars(other) == vars(self) @tp.overload - def replace(self, value: B, **kwargs) -> 'Variable[B]': ... + def replace(self, value: B, **kwargs) -> Variable[B]: ... @tp.overload - def replace(self, **kwargs) -> 'Variable[A]': ... + def replace(self, **kwargs) -> Variable[A]: ... 
- def replace(self, value: tp.Any = MISSING, **kwargs) -> 'Variable[tp.Any]': + def replace(self, value: tp.Any = MISSING, **kwargs) -> Variable[tp.Any]: if value is not MISSING: kwargs['raw_value'] = value @@ -407,14 +407,14 @@ def replace(self, value: tp.Any = MISSING, **kwargs) -> 'Variable[tp.Any]': vars(obj).update(attributes) return obj - def copy(self: 'Variable[A]') -> 'Variable[A]': + def copy(self: Variable[A]) -> Variable[A]: obj = object.__new__(type(self)) attributes = vars(self).copy() attributes['_trace_state'] = tracers.TraceState() vars(obj).update(attributes) return obj - def to_state(self: 'Variable[A]') -> 'VariableState[A]': + def to_state(self: Variable[A]) -> VariableState[A]: metadata = vars(self).copy() del metadata['raw_value'] del metadata['_trace_state'] @@ -794,7 +794,7 @@ def __penzai_repr__(self, path, subtree_renderer): subtree_renderer=subtree_renderer, ) - def replace(self, value: B) -> 'VariableState[B]': + def replace(self, value: B) -> VariableState[B]: return VariableState(self.type, value, **self.get_metadata()) def to_variable(self) -> Variable[A]: @@ -863,11 +863,11 @@ def with_metadata( CreateValueHook[A], tp.Sequence[CreateValueHook[A]] ] = (), add_axis_hooks: tp.Union[ - AddAxisHook['Variable[A]'], tp.Sequence[AddAxisHook['Variable[A]']] + AddAxisHook[Variable[A]], tp.Sequence[AddAxisHook[Variable[A]]] ] = (), remove_axis_hooks: tp.Union[ - RemoveAxisHook['Variable[A]'], - tp.Sequence[RemoveAxisHook['Variable[A]']], + RemoveAxisHook[Variable[A]], + tp.Sequence[RemoveAxisHook[Variable[A]]], ] = (), **metadata: tp.Any, ) -> F: diff --git a/flax/serialization.py b/flax/serialization.py index 485894f486..c4dcb9ba4e 100644 --- a/flax/serialization.py +++ b/flax/serialization.py @@ -20,13 +20,13 @@ import enum import threading from contextlib import contextmanager -from typing import Any, Dict, List +from typing import Any import jax import msgpack import numpy as np -_STATE_DICT_REGISTRY: Dict[Any, Any] = {} +_STATE_DICT_REGISTRY: dict[Any, Any] = {} class _ErrorContext(threading.local): @@ -64,7 +64,7 @@ def _is_namedtuple(x): return isinstance(x, tuple) and hasattr(x, '_fields') -def from_state_dict(target, state: Dict[str, Any], name: str = '.'): +def from_state_dict(target, state: dict[str, Any], name: str = '.'): """Restores the state of the given target using a state dict. This function takes the current target as an argument. 
This @@ -93,7 +93,7 @@ def from_state_dict(target, state: Dict[str, Any], name: str = '.'): return ty_from_state_dict(target, state) -def to_state_dict(target) -> Dict[str, Any]: +def to_state_dict(target) -> dict[str, Any]: """Returns a dictionary with the state of the given target.""" if _is_namedtuple(target): ty = _NamedTuple @@ -137,11 +137,11 @@ def register_serialization_state( _STATE_DICT_REGISTRY[ty] = (ty_to_state_dict, ty_from_state_dict) -def _list_state_dict(xs: List[Any]) -> Dict[str, Any]: +def _list_state_dict(xs: list[Any]) -> dict[str, Any]: return {str(i): to_state_dict(x) for i, x in enumerate(xs)} -def _restore_list(xs, state_dict: Dict[str, Any]) -> List[Any]: +def _restore_list(xs, state_dict: dict[str, Any]) -> list[Any]: if len(state_dict) != len(xs): raise ValueError( 'The size of the list and the state dict do not match,' @@ -155,8 +155,8 @@ def _restore_list(xs, state_dict: Dict[str, Any]) -> List[Any]: return ys -def _dict_state_dict(xs: Dict[str, Any]) -> Dict[str, Any]: - str_keys = set(str(k) for k in xs.keys()) +def _dict_state_dict(xs: dict[str, Any]) -> dict[str, Any]: + str_keys = {str(k) for k in xs.keys()} if len(str_keys) != len(xs): raise ValueError( 'Dict keys do not have a unique string representation: ' @@ -165,7 +165,7 @@ def _dict_state_dict(xs: Dict[str, Any]) -> Dict[str, Any]: return {str(key): to_state_dict(value) for key, value in xs.items()} -def _restore_dict(xs, states: Dict[str, Any]) -> Dict[str, Any]: +def _restore_dict(xs, states: dict[str, Any]) -> dict[str, Any]: diff = set(map(str, xs.keys())).difference(states.keys()) if diff: raise ValueError( @@ -180,11 +180,11 @@ def _restore_dict(xs, states: Dict[str, Any]) -> Dict[str, Any]: } -def _namedtuple_state_dict(nt) -> Dict[str, Any]: +def _namedtuple_state_dict(nt) -> dict[str, Any]: return {key: to_state_dict(getattr(nt, key)) for key in nt._fields} -def _restore_namedtuple(xs, state_dict: Dict[str, Any]): +def _restore_namedtuple(xs, state_dict: dict[str, Any]): """Rebuild namedtuple from serialized dict.""" if set(state_dict.keys()) == {'name', 'fields', 'values'}: # TODO(jheek): remove backward compatible named tuple restoration early 2022 @@ -341,7 +341,7 @@ def _np_convert_in_place(d): _dict_to_tuple = lambda dct: tuple(dct[str(i)] for i in range(len(dct))) -def _chunk(arr) -> Dict[str, Any]: +def _chunk(arr) -> dict[str, Any]: """Convert array to a canonical dictionary of chunked arrays.""" chunksize = max(1, int(MAX_CHUNK_SIZE / arr.dtype.itemsize)) data = {'__msgpack_chunked_array__': True, 'shape': _tuple_to_dict(arr.shape)} @@ -353,7 +353,7 @@ def _chunk(arr) -> Dict[str, Any]: return data -def _unchunk(data: Dict[str, Any]): +def _unchunk(data: dict[str, Any]): """Convert canonical dictionary of chunked arrays back into array.""" assert '__msgpack_chunked_array__' in data shape = _dict_to_tuple(data['shape']) diff --git a/flax/testing/benchmark.py b/flax/testing/benchmark.py index 88ee9a99be..ea55db31ba 100644 --- a/flax/testing/benchmark.py +++ b/flax/testing/benchmark.py @@ -29,7 +29,6 @@ import json import os import tempfile -from typing import Dict from absl import flags, logging from absl.testing import absltest @@ -186,7 +185,7 @@ def report_wall_time(self, wall_time: float): self._update_reported_name() self._reported_wall_time = wall_time - def report_metrics(self, metrics: Dict[str, float]): + def report_metrics(self, metrics: dict[str, float]): """Report metrics for the benchmark.""" self._update_reported_name() self._reported_metrics.update(metrics) @@ -195,7 
+194,7 @@ def report_metric(self, name: str, value: float): """Report a single metric for the benchmark.""" self.report_metrics({name: value}) - def report_extras(self, extras: Dict[str, str]): + def report_extras(self, extras: dict[str, str]): """Report extras for the benchmark.""" self._update_reported_name() self._reported_extras.update(extras) @@ -232,7 +231,7 @@ def _get_test_name(self, prefix='test_'): # Prefix the name with the class name. class_name = type(calling_class).__name__ - name = '{}.{}'.format(class_name, name) + name = f'{class_name}.{name}' return name def _update_reported_name(self): diff --git a/flax/training/checkpoints.py b/flax/training/checkpoints.py index 3650f9163e..747ba63431 100644 --- a/flax/training/checkpoints.py +++ b/flax/training/checkpoints.py @@ -28,15 +28,8 @@ from concurrent.futures import thread from typing import ( Any, - Callable, - Dict, - Iterable, - List, - Optional, - Tuple, - Type, - Union, ) +from collections.abc import Callable, Iterable import jax import orbax.checkpoint as ocp @@ -100,12 +93,12 @@ def _is_multiprocess_array(value: Any) -> bool: def _checkpoint_path( - ckpt_dir: str, step: Union[int, float, str], prefix: str = 'checkpoint_' + ckpt_dir: str, step: int | float | str, prefix: str = 'checkpoint_' ) -> str: return os.path.join(ckpt_dir, f'{prefix}{step}') -def _checkpoint_path_step(path: str) -> Optional[float]: +def _checkpoint_path_step(path: str) -> float | None: """Returns the step number of a checkpoint path.""" for s in SIGNED_FLOAT_RE.split(path)[::-1]: if SIGNED_FLOAT_RE.match(s): @@ -168,8 +161,8 @@ def save_async(self, task: Callable[[], Any]): def _split_mp_arrays( - target: Dict[str, Any] -) -> Tuple[Dict[str, Any], List[Tuple[MultiprocessArrayType, str]]]: + target: dict[str, Any] +) -> tuple[dict[str, Any], list[tuple[MultiprocessArrayType, str]]]: """Split out the multiprocess arrays from the target pytree to save.""" # When target is a single leaf instead of a pytree dict. if not isinstance(target, (core.FrozenDict, dict)): @@ -189,7 +182,7 @@ def _split_mp_arrays( def _make_mpa_dirs( - mpa_targets: List[Tuple[MultiprocessArrayType, str]], tmp_path: str + mpa_targets: list[tuple[MultiprocessArrayType, str]], tmp_path: str ): # Temporary array path is not used in GCS. 
if tmp_path.startswith('gs://'): @@ -207,15 +200,15 @@ def _make_mpa_dirs( def _save_mpas( gda_manager, - mpa_targets: List[Tuple[MultiprocessArrayType, str]], + mpa_targets: list[tuple[MultiprocessArrayType, str]], tmp_path: str, final_path: str, base_path: str, keep: int, overwrite: bool, - keep_every_n_steps: Optional[int], + keep_every_n_steps: int | None, ckpt_start_time: float, - async_manager: Optional[AsyncManager] = None, + async_manager: AsyncManager | None = None, ): """Save the multiprocess arrays given the paths.""" mpa_list, mpa_subpaths = zip(*mpa_targets) @@ -255,10 +248,10 @@ def _save_mpas( def _restore_mpas( state_dict, - target: Optional[Any], + target: Any | None, ckpt_path: str, - step: Optional[Union[int, float]], - gda_manager: Optional[GlobalAsyncCheckpointManager], + step: int | float | None, + gda_manager: GlobalAsyncCheckpointManager | None, allow_partial: bool = False, ): """Restore the multiprocess arrays given the target structure and type.""" @@ -270,9 +263,9 @@ def _check_mpa_errors(): raise errors.MPARestoreTargetRequiredError(ckpt_path, step) def _safe_deserialize( - target_mpas: List[Tuple[Tuple[Any, ...], MultiprocessArrayType, str]], + target_mpas: list[tuple[tuple[Any, ...], MultiprocessArrayType, str]], gda_manager: Any, - ) -> List[MultiprocessArrayType]: + ) -> list[MultiprocessArrayType]: gda_manager.wait_until_finished() # Check if reading from GCS and the array dir is potentially corrupted. @@ -347,7 +340,7 @@ def _safe_deserialize( return state_dict -def natural_sort(file_list: Iterable[str], signed: bool = True) -> List[str]: +def natural_sort(file_list: Iterable[str], signed: bool = True) -> list[str]: """Natural sort for filenames with numerical substrings. Args: @@ -388,12 +381,12 @@ def _remove_invalid_ckpts( base_path: str, keep: int, overwrite: bool, - keep_every_n_steps: Optional[int], + keep_every_n_steps: int | None, has_mpa: bool, ) -> None: """Clean up the checkpoint space according to `overwrite`, `keep`, and `keep_every_n_steps` parameters.""" dir_path, prefix = os.path.split(base_path) - checkpoint_files: List[Any] = [ + checkpoint_files: list[Any] = [ pathlib.PurePath(c) for c in _allowempty_listdir(dir_path) ] checkpoint_files = [ @@ -450,11 +443,11 @@ def _save_commit( base_path: str, keep: int, overwrite: bool, - keep_every_n_steps: Optional[int], + keep_every_n_steps: int | None, ckpt_start_time: float, has_mpa: bool, write_commit_success: bool, - async_manager: Optional[AsyncManager] = None, + async_manager: AsyncManager | None = None, ) -> None: """Commit changes after saving checkpoints to disk. 
@@ -510,7 +503,7 @@ def _check_overwrite_error( ): """Throw error if a ckpt file of this step or higher already exists.""" dir_path, prefix = os.path.split(base_path) - checkpoint_files: List[Any] = [ + checkpoint_files: list[Any] = [ pathlib.PurePath(c) for c in _allowempty_listdir(dir_path) ] checkpoint_files = [ @@ -534,12 +527,12 @@ def _check_overwrite_error( def _save_main_ckpt_file( target: bytes, has_mpa: bool, - paths: Tuple[str, str], + paths: tuple[str, str], base_path: str, step: int, keep: int, overwrite: bool, - keep_every_n_steps: Optional[int], + keep_every_n_steps: int | None, ckpt_start_time: float, ): """Save the main checkpoint file via file system.""" @@ -565,10 +558,10 @@ def _save_main_ckpt_file( def _get_checkpoint_paths( - ckpt_dir: Union[str, os.PathLike], - step: Union[int, float], + ckpt_dir: str | os.PathLike, + step: int | float, prefix: str = 'checkpoint_', -) -> Tuple[str, str, str]: +) -> tuple[str, str, str]: """Generate the checkpoint paths used in this save operation.""" ckpt_dir = os.fspath(ckpt_dir) # Pathlib -> str logging.info('Saving checkpoint at step: %s', step) @@ -581,15 +574,15 @@ def _get_checkpoint_paths( def save_checkpoint( - ckpt_dir: Union[str, os.PathLike], + ckpt_dir: str | os.PathLike, target: PyTree, - step: Union[int, float], + step: int | float, prefix: str = 'checkpoint_', keep: int = 1, overwrite: bool = False, - keep_every_n_steps: Optional[int] = None, - async_manager: Optional[AsyncManager] = None, - orbax_checkpointer: Optional[ocp.Checkpointer] = None, + keep_every_n_steps: int | None = None, + async_manager: AsyncManager | None = None, + orbax_checkpointer: ocp.Checkpointer | None = None, ) -> str: """Save a checkpoint of the model. Suitable for single-host. @@ -750,16 +743,16 @@ def save_main_ckpt_task(): def save_checkpoint_multiprocess( - ckpt_dir: Union[str, os.PathLike], + ckpt_dir: str | os.PathLike, target: PyTree, - step: Union[int, float], + step: int | float, prefix: str = 'checkpoint_', keep: int = 1, overwrite: bool = False, - keep_every_n_steps: Optional[int] = None, - async_manager: Optional[AsyncManager] = None, - gda_manager: Optional[GlobalAsyncCheckpointManager] = None, - orbax_checkpointer: Optional[ocp.Checkpointer] = None, + keep_every_n_steps: int | None = None, + async_manager: AsyncManager | None = None, + gda_manager: GlobalAsyncCheckpointManager | None = None, + orbax_checkpointer: ocp.Checkpointer | None = None, ) -> str: """Save a checkpoint of the model in multi-process environment. @@ -927,8 +920,8 @@ def save_main_ckpt_task(): def _all_checkpoints( - ckpt_dir: Union[str, os.PathLike], prefix: str = 'checkpoint_' -) -> List[str]: + ckpt_dir: str | os.PathLike, prefix: str = 'checkpoint_' +) -> list[str]: """Retrieve all checkpoint paths in directory. Args: @@ -939,7 +932,7 @@ def _all_checkpoints( Sorted list of checkpoint paths or empty list if no checkpoints were found. """ ckpt_dir = os.fspath(ckpt_dir) # Pathlib -> str - checkpoint_files: List[Any] = [ + checkpoint_files: list[Any] = [ pathlib.PurePath(c) for c in _allowempty_listdir(ckpt_dir) ] checkpoint_files = [ @@ -958,8 +951,8 @@ def _all_checkpoints( def latest_checkpoint( - ckpt_dir: Union[str, os.PathLike], prefix: str = 'checkpoint_' -) -> Optional[str]: + ckpt_dir: str | os.PathLike, prefix: str = 'checkpoint_' +) -> str | None: """Retrieve the path of the latest checkpoint in a directory. 
Args: @@ -977,10 +970,10 @@ def latest_checkpoint( def available_steps( - ckpt_dir: Union[str, os.PathLike], + ckpt_dir: str | os.PathLike, prefix: str = 'checkpoint_', - step_type: Type = int, -) -> List[Union[int, float]]: + step_type: type = int, +) -> list[int | float]: """Return step numbers of available checkpoints in a directory. @@ -1004,15 +997,15 @@ def available_steps( def restore_checkpoint( - ckpt_dir: Union[str, os.PathLike], - target: Optional[Any], - step: Optional[Union[int, float]] = None, + ckpt_dir: str | os.PathLike, + target: Any | None, + step: int | float | None = None, prefix: str = 'checkpoint_', parallel: bool = True, - gda_manager: Optional[GlobalAsyncCheckpointManager] = None, + gda_manager: GlobalAsyncCheckpointManager | None = None, allow_partial_mpa_restoration: bool = False, - orbax_checkpointer: Optional[ocp.Checkpointer] = None, - orbax_transforms: Optional[Dict] = None, + orbax_checkpointer: ocp.Checkpointer | None = None, + orbax_transforms: dict | None = None, ) -> PyTree: """Restore last/best checkpoint from checkpoints in path. @@ -1241,7 +1234,7 @@ def __call__(self, x): if not isinstance(params, (dict, core.FrozenDict)): return params params_renamed = {} - counts: Dict[Any, Any] = {} + counts: dict[Any, Any] = {} names = natural_sort(params.keys()) for name in names: value = params[name] diff --git a/flax/training/dynamic_scale.py b/flax/training/dynamic_scale.py index 1effc6b24f..60012748bc 100644 --- a/flax/training/dynamic_scale.py +++ b/flax/training/dynamic_scale.py @@ -15,7 +15,8 @@ """Dynamic loss scaling for mixed precision gradients.""" import functools -from typing import Any, Callable, NamedTuple, Optional, Sequence, Union +from typing import Any, NamedTuple +from collections.abc import Callable, Sequence import jax import jax.numpy as jnp @@ -83,16 +84,16 @@ def loss_fn(p): growth_interval: int = struct.field(pytree_node=False, default=2000) fin_steps: int = 0 scale: float = 65536.0 - minimum_scale: Optional[float] = struct.field( + minimum_scale: float | None = struct.field( pytree_node=False, default=jnp.finfo(jnp.float32).tiny ) def value_and_grad( self, fun: Callable[..., Any], - argnums: Union[int, Sequence[int]] = 0, + argnums: int | Sequence[int] = 0, has_aux: bool = False, - axis_name: Optional[str] = None, + axis_name: str | None = None, ) -> Callable[..., DynamicScaleResult]: """Wrapper around `jax.value_and_grad`. diff --git a/flax/training/orbax_utils.py b/flax/training/orbax_utils.py index d873ecdcf8..d30c124771 100644 --- a/flax/training/orbax_utils.py +++ b/flax/training/orbax_utils.py @@ -17,7 +17,7 @@ import dataclasses import inspect import warnings -from typing import Any, Optional +from typing import Any import jax import numpy as np @@ -42,7 +42,7 @@ def save_args_from_target(target: Any) -> Any: def maybe_construct_transformations( - target: Any, transforms: Optional[Any] + target: Any, transforms: Any | None ) -> Any: if transforms is not None: return transforms @@ -54,7 +54,7 @@ def maybe_construct_transformations( return flat_transforms -def restore_args_from_target(target: Any, mesh: Optional[Mesh] = None) -> Any: +def restore_args_from_target(target: Any, mesh: Mesh | None = None) -> Any: """Creates Orbax `restore_args` given a target Pytree. 
Args: diff --git a/flax/training/train_state.py b/flax/training/train_state.py index 1188dedc0f..bbce765c1e 100644 --- a/flax/training/train_state.py +++ b/flax/training/train_state.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Union +from typing import Any +from collections.abc import Callable import optax @@ -71,7 +72,7 @@ class TrainState(struct.PyTreeNode): opt_state: The state for ``tx``. """ - step: Union[int, jax.Array] + step: int | jax.Array apply_fn: Callable = struct.field(pytree_node=False) params: core.FrozenDict[str, Any] = struct.field(pytree_node=True) tx: optax.GradientTransformation = struct.field(pytree_node=False) diff --git a/flax/traverse_util.py b/flax/traverse_util.py index 1d66144d98..007f425ac2 100644 --- a/flax/traverse_util.py +++ b/flax/traverse_util.py @@ -57,7 +57,8 @@ import copy import dataclasses import warnings -from typing import Any, Callable +from typing import Any +from collections.abc import Callable import jax diff --git a/flax/typing.py b/flax/typing.py index 8d0fc5855f..aa4cc00cd3 100644 --- a/flax/typing.py +++ b/flax/typing.py @@ -14,19 +14,14 @@ from typing import ( Any, - Callable, - Dict, Generic, - Hashable, - Mapping, Optional, Protocol, - Sequence, - Tuple, TypeVar, Union, runtime_checkable, ) +from collections.abc import Callable, Hashable, Mapping, Sequence import jax from flax.core import FrozenDict @@ -38,7 +33,7 @@ Array = Union[jax.Array, Any] PRNGKey = jax.Array -RNGSequences = Dict[str, PRNGKey] +RNGSequences = dict[str, PRNGKey] Dtype = Union[jax.typing.DTypeLike, Any] Shape = Sequence[int] K = TypeVar('K') @@ -50,7 +45,7 @@ def __lt__(self: K, value: K, /) -> bool: Path = str -PathParts = Tuple[Key, ...] +PathParts = tuple[Key, ...] Leaf = Any @@ -61,14 +56,14 @@ def __lt__(self: K, value: K, /) -> bool: None, str, jax.lax.Precision, - Tuple[str, str], - Tuple[jax.lax.Precision, jax.lax.Precision], + tuple[str, str], + tuple[jax.lax.Precision, jax.lax.Precision], ] DotGeneralT = Callable[..., Array] ConvGeneralDilatedT = Callable[..., Array] -PaddingLike = Union[str, int, Sequence[Union[int, Tuple[int, int]]]] -LaxPadding = Union[str, Sequence[Tuple[int, int]]] +PaddingLike = Union[str, int, Sequence[Union[int, tuple[int, int]]]] +LaxPadding = Union[str, Sequence[tuple[int, int]]] # Initializers @@ -79,14 +74,14 @@ def __lt__(self: K, value: K, /) -> bool: # Collections Collection = Mapping[str, Any] -MutableCollection = Dict[str, Any] +MutableCollection = dict[str, Any] # Dicts VariableDict = Mapping[str, Collection] FrozenVariableDict = FrozenDict[str, Collection] -MutableVariableDict = Dict[str, MutableCollection] +MutableVariableDict = dict[str, MutableCollection] PRNGFoldable = Union[int, str] @@ -118,14 +113,14 @@ class Out(Generic[T]): # SPMD -LogicalNames = Tuple[Union[str, None], ...] +LogicalNames = tuple[Union[str, None], ...] # Maps each logical axis to physical mesh, can be either None (replicated), # one physical axis or a tuple of physical axes. -LogicalRules = Sequence[Tuple[str, Union[str, Tuple[str, ...], None]]] +LogicalRules = Sequence[tuple[str, Union[str, tuple[str, ...], None]]] ArrayPytree = Any # pylint: disable=invalid-name LogicalPartitionSpec = Any # pylint: disable=invalid-name LogicalPartitionSpecPytree = Any # pylint: disable=invalid-name PartitionSpecPytree = Any # pylint: disable=invalid-name -Sharding = Tuple[Optional[str], ...] +Sharding = tuple[Optional[str], ...] 
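The flax/typing.py aliases above switch from typing.Dict/Tuple/Mapping to builtin generics (PEP 585) while keeping Union/Optional inside a few aliases. A minimal sketch, not part of the patch, of the same alias style; shard_names is a hypothetical helper for illustration only. The `str | None` form inside a subscript needs Python 3.10+, which matches the requires-python bump in pyproject.toml that follows:

from collections.abc import Mapping
from typing import Any

Collection = Mapping[str, Any]         # was typing.Mapping[str, Any]
MutableCollection = dict[str, Any]     # was typing.Dict[str, Any]
LogicalNames = tuple[str | None, ...]  # the patch keeps Union[str, None] here

def shard_names(names: LogicalNames) -> list[str]:
  # Drop replicated (None) axes; purely illustrative.
  return [n for n in names if n is not None]

assert shard_names(('data', None, 'model')) == ['data', 'model']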
diff --git a/pyproject.toml b/pyproject.toml index 8227432156..8507626bb9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "flax" -requires-python = ">=3.9" +requires-python = ">=3.10" description = "Flax: A neural network library for JAX designed for flexibility" keywords = [] authors = [ diff --git a/tests/core/core_scope_test.py b/tests/core/core_scope_test.py index 2376a6b9ec..208dd3dd57 100644 --- a/tests/core/core_scope_test.py +++ b/tests/core/core_scope_test.py @@ -54,17 +54,17 @@ def union_check(a, b, ans): self.assertEqual(scope.union_filters(a, b), ans) self.assertEqual(scope.union_filters(b, a), ans) - union_check(['a', 'b'], ['b', 'c'], set(['a', 'b', 'c'])) + union_check(['a', 'b'], ['b', 'c'], {'a', 'b', 'c'}) union_check(True, False, True) union_check(False, False, set()) union_check(True, True, True) union_check( scope.DenyList(['a', 'b']), scope.DenyList(['b', 'c']), - scope.DenyList(set(['b'])), + scope.DenyList({'b'}), ) union_check( - scope.DenyList(['a', 'b']), ['b', 'c'], scope.DenyList(set(['a'])) + scope.DenyList(['a', 'b']), ['b', 'c'], scope.DenyList({'a'}) ) def test_intersect_filter(self): @@ -72,33 +72,33 @@ def intersect_check(a, b, ans): self.assertEqual(scope.intersect_filters(a, b), ans) self.assertEqual(scope.intersect_filters(b, a), ans) - intersect_check(['a', 'b'], ['b', 'c'], set(['b'])) + intersect_check(['a', 'b'], ['b', 'c'], {'b'}) intersect_check(True, False, False) intersect_check(False, False, set()) intersect_check(True, True, True) intersect_check( scope.DenyList(['a', 'b']), scope.DenyList(['b', 'c']), - scope.DenyList(set(['a', 'b', 'c'])), + scope.DenyList({'a', 'b', 'c'}), ) - intersect_check(scope.DenyList(['a', 'b']), ['b', 'c'], set(['c'])) + intersect_check(scope.DenyList(['a', 'b']), ['b', 'c'], {'c'}) def test_subtract_filter(self): def subtract_check(a, b, ans): self.assertEqual(scope.subtract_filters(a, b), ans) - subtract_check(['a', 'b'], ['b', 'c'], set(['a'])) + subtract_check(['a', 'b'], ['b', 'c'], {'a'}) subtract_check(True, False, scope.DenyList(False)) subtract_check(False, False, set()) subtract_check(True, True, False) subtract_check(True, 'a', scope.DenyList('a')) subtract_check( - scope.DenyList(['a', 'b']), scope.DenyList(['b', 'c']), set(['c']) + scope.DenyList(['a', 'b']), scope.DenyList(['b', 'c']), {'c'} ) subtract_check( scope.DenyList(['a', 'b']), ['b', 'c'], - scope.DenyList(set(['a', 'b', 'c'])), + scope.DenyList({'a', 'b', 'c'}), ) def test_group_collections(self): diff --git a/tests/core/design/core_attention_test.py b/tests/core/design/core_attention_test.py index eb9f590c37..7c49f518f9 100644 --- a/tests/core/design/core_attention_test.py +++ b/tests/core/design/core_attention_test.py @@ -13,7 +13,7 @@ # limitations under the License. 
from functools import partial -from typing import Callable, Optional, Sequence +from collections.abc import Callable, Sequence import jax from absl.testing import absltest @@ -47,7 +47,7 @@ def _dot_product_attention( query: Array, key: Array, value: Array, - bias: Optional[Array] = None, + bias: Array | None = None, attn_fn: Callable = softmax_attn, dtype=jnp.float32, ): @@ -73,9 +73,9 @@ def dot_product_attention( scope: Scope, inputs_q: Array, inputs_kv: Array, - bias: Optional[Array] = None, - qkv_features: Optional[int] = None, - out_features: Optional[int] = None, + bias: Array | None = None, + qkv_features: int | None = None, + out_features: int | None = None, attn_fn: Callable = softmax_attn, dtype=jnp.float32, ): @@ -100,9 +100,9 @@ def multi_head_dot_product_attention( scope: Scope, inputs_q: Array, inputs_kv: Array, - bias: Optional[Array] = None, - qkv_features: Optional[int] = None, - out_features: Optional[int] = None, + bias: Array | None = None, + qkv_features: int | None = None, + out_features: int | None = None, attn_fn: Callable = softmax_attn, batch_axes: Sequence[int] = (0,), num_heads: int = 1, diff --git a/tests/core/design/core_auto_encoder_test.py b/tests/core/design/core_auto_encoder_test.py index 4de77e4844..5b7e1ca899 100644 --- a/tests/core/design/core_auto_encoder_test.py +++ b/tests/core/design/core_auto_encoder_test.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass -from typing import Callable +from collections.abc import Callable import jax from absl.testing import absltest diff --git a/tests/core/design/core_custom_vjp_test.py b/tests/core/design/core_custom_vjp_test.py index 7f841f06b1..ce040b772a 100644 --- a/tests/core/design/core_custom_vjp_test.py +++ b/tests/core/design/core_custom_vjp_test.py @@ -13,7 +13,7 @@ # limitations under the License. from functools import partial -from typing import Callable, Sequence +from collections.abc import Callable, Sequence import jax import numpy as np diff --git a/tests/core/design/core_dense_test.py b/tests/core/design/core_dense_test.py index 1d608b1ee3..b4fce9aae9 100644 --- a/tests/core/design/core_dense_test.py +++ b/tests/core/design/core_dense_test.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass -from typing import Any, Optional +from typing import Any import jax from absl.testing import absltest @@ -45,7 +45,7 @@ def __call__(self, scope, x): @struct.dataclass class ExplicitDense: kernel: Array - bias: Optional[Array] + bias: Array | None # a fully explicit "scope free" version @staticmethod diff --git a/tests/core/design/core_flow_test.py b/tests/core/design/core_flow_test.py index 90a85c9b67..d8a55441b2 100644 --- a/tests/core/design/core_flow_test.py +++ b/tests/core/design/core_flow_test.py @@ -13,7 +13,8 @@ # limitations under the License. from dataclasses import dataclass -from typing import Any, Sequence +from typing import Any +from collections.abc import Sequence import jax from absl.testing import absltest diff --git a/tests/core/design/core_vmap_test.py b/tests/core/design/core_vmap_test.py index 14156119cc..6e8f049c1c 100644 --- a/tests/core/design/core_vmap_test.py +++ b/tests/core/design/core_vmap_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Callable, Sequence +from collections.abc import Callable, Sequence import jax from absl.testing import absltest diff --git a/tests/core/design/core_weight_std_test.py b/tests/core/design/core_weight_std_test.py index db1a6550f3..b247e843c0 100644 --- a/tests/core/design/core_weight_std_test.py +++ b/tests/core/design/core_weight_std_test.py @@ -13,7 +13,7 @@ # limitations under the License. from functools import partial -from typing import Sequence +from collections.abc import Sequence import jax from absl.testing import absltest diff --git a/tests/linen/linen_combinators_test.py b/tests/linen/linen_combinators_test.py index 70d8eb0d1a..96437644e2 100644 --- a/tests/linen/linen_combinators_test.py +++ b/tests/linen/linen_combinators_test.py @@ -14,7 +14,8 @@ """Tests for flax.linen.combinators.""" -from typing import Any, Optional, Sequence +from typing import Any +from collections.abc import Sequence import jax import numpy as np @@ -30,8 +31,8 @@ class MLP(nn.Module): layer_sizes: Sequence[int] - activation: Optional[Any] = None - activation_final: Optional[Any] = None + activation: Any | None = None + activation_final: Any | None = None @nn.compact def __call__(self, inputs): diff --git a/tests/linen/linen_module_test.py b/tests/linen/linen_module_test.py index 21a9358697..bb53ec1046 100644 --- a/tests/linen/linen_module_test.py +++ b/tests/linen/linen_module_test.py @@ -26,16 +26,12 @@ from tempfile import TemporaryDirectory from typing import ( Any, - Callable, Generic, - Mapping, NamedTuple, - Optional, - Sequence, - Tuple, TypeVar, get_type_hints, ) +from collections.abc import Callable, Mapping, Sequence from unittest.mock import patch import jax @@ -296,7 +292,7 @@ def test_param_in_setup(self): rngkey = jax.random.key(0) class DummyModuleWithoutCompact(nn.Module): - xshape: Tuple[int, ...] + xshape: tuple[int, ...] def setup(self): self.bias = self.param('bias', initializers.ones, self.xshape) @@ -348,7 +344,7 @@ def test_setup_call_var_collision(self): rngkey = jax.random.key(0) class Dummy(nn.Module): - xshape: Tuple[int, ...] + xshape: tuple[int, ...] def setup(self): self.bias = self.param('bias', initializers.ones, self.xshape) @@ -368,7 +364,7 @@ def test_call_var_collision(self): rngkey = jax.random.key(0) class Dummy(nn.Module): - xshape: Tuple[int, ...] + xshape: tuple[int, ...] @compact def __call__(self, x): @@ -386,7 +382,7 @@ def test_setup_var_collision(self): rngkey = jax.random.key(0) class Dummy(nn.Module): - xshape: Tuple[int, ...] + xshape: tuple[int, ...] def setup(self): self.bias = self.param('bias', initializers.ones, self.xshape) @@ -405,7 +401,7 @@ def test_setattr_name_var_disagreement_allowed_in_lists(self): rngkey = jax.random.key(0) class Dummy(nn.Module): - xshape: Tuple[int, ...] + xshape: tuple[int, ...] def setup(self): self.biases = [ @@ -425,7 +421,7 @@ def test_setattr_name_var_disagreement_allowed_in_dicts(self): rngkey = jax.random.key(0) class Dummy(nn.Module): - xshape: Tuple[int, ...] + xshape: tuple[int, ...] def setup(self): self.biases = { @@ -450,7 +446,7 @@ def test_submodule_var_collision_with_scope(self): rngkey = jax.random.key(0) class Dummy(nn.Module): - xshape: Tuple[int, ...] + xshape: tuple[int, ...] def setup(self): self.bias = self.param('bias', initializers.ones, self.xshape) @@ -469,7 +465,7 @@ def test_submodule_var_collision_with_submodule(self): rngkey = jax.random.key(0) class Dummy(nn.Module): - xshape: Tuple[int, ...] + xshape: tuple[int, ...] 
def setup(self): self.bias = self.param('bias', initializers.ones, self.xshape) @@ -490,7 +486,7 @@ def test_submodule_var_collision_with_params(self): rngkey = jax.random.key(0) class Dummy(nn.Module): - xshape: Tuple[int, ...] + xshape: tuple[int, ...] def setup(self): self.bias = DummyModule() @@ -2243,7 +2239,7 @@ class Foo(nn.Module): def test_kw_only(self): def create_kw_layers(): class BaseLayer(nn.Module, kw_only=True): - base_multiplier: Optional[int] = -1 + base_multiplier: int | None = -1 class ChildLayer(BaseLayer): child_multiplier: int # Don't want to have to set a default argument! diff --git a/tests/linen/linen_transforms_test.py b/tests/linen/linen_transforms_test.py index 90430ebd01..281a728819 100644 --- a/tests/linen/linen_transforms_test.py +++ b/tests/linen/linen_transforms_test.py @@ -16,7 +16,8 @@ from functools import partial import operator -from typing import Any, Callable, Dict, Sequence +from typing import Any +from collections.abc import Callable, Sequence import unittest from absl.testing import absltest, parameterized @@ -2541,10 +2542,10 @@ def unbox(self): return self.value def replace_boxed(self, val): return self.replace(value=val) - def add_axis(self, index: int, params: Dict[Any, Any]): + def add_axis(self, index: int, params: dict[Any, Any]): value = jnp.mean(self.value, axis=index) return self.replace(value=value) - def remove_axis(self, index: int, params: Dict[Any, Any]): + def remove_axis(self, index: int, params: dict[Any, Any]): value_shape = list(self.value.shape) value_shape.insert(index, params['axis_size']) value = jnp.broadcast_to(self.value, value_shape) diff --git a/tests/linen/summary_test.py b/tests/linen/summary_test.py index 7daf4f7ceb..36515dd554 100644 --- a/tests/linen/summary_test.py +++ b/tests/linen/summary_test.py @@ -13,7 +13,6 @@ # limitations under the License. import enum -from typing import List import jax import jax.numpy as jnp @@ -46,7 +45,7 @@ def _get_obj_repr_value(x): class ConvBlock(nn.Module): features: int - kernel_size: List[int] + kernel_size: list[int] test_sow: bool def setup(self) -> None: diff --git a/tests/traverse_util_test.py b/tests/traverse_util_test.py index 93d3ebd90d..dc5ad3b47c 100644 --- a/tests/traverse_util_test.py +++ b/tests/traverse_util_test.py @@ -31,7 +31,7 @@ jax.config.parse_flags_with_absl() -class Foo(object): +class Foo: def __init__(self, foo, bar=None): self.foo = foo self.bar = bar @@ -268,7 +268,7 @@ def filter_fn(name, _): for model, expected_model in configs: self.assertEqual(values, [1, 3]) self.assertEqual( - set(names), set(['/x/kernel', '/x/bias', '/x/y/kernel', '/x/y/bias']) + set(names), {'/x/kernel', '/x/bias', '/x/y/kernel', '/x/y/bias'} ) new_model = traversal.update(lambda x: x + x, model) self.assertEqual(new_model, expected_model)
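The test and benchmark hunks above also carry pyupgrade-style rewrites: `set([...])` calls become set literals or comprehensions, `'{}.{}'.format(...)` becomes an f-string, and `class Foo(object)` drops the redundant base. A minimal sketch, not part of the patch (the class and metric names are illustrative), checking that these spellings are equivalent at runtime:

# All spellings below are interchangeable; only the syntax is newer.
names = ['/x/kernel', '/x/bias', '/x/y/kernel', '/x/y/bias']
assert set(names) == {'/x/kernel', '/x/bias', '/x/y/kernel', '/x/y/bias'}
assert set(str(k) for k in range(3)) == {str(k) for k in range(3)}

class_name, name = 'SummaryTest', 'test_tabulate'  # illustrative values
assert '{}.{}'.format(class_name, name) == f'{class_name}.{name}'

class OldStyle(object):  # pre-patch spelling
  pass

class NewStyle:  # post-patch spelling; identical semantics on Python 3
  pass

assert OldStyle.__mro__[1:] == NewStyle.__mro__[1:] == (object,)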