From 37e0e88caf1ad974755c062ebe1afb64a4fbe429 Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Fri, 22 Nov 2024 10:47:02 -0800
Subject: [PATCH] improve docs

---
 docs/source/conf.py                       | 22 ++++++--
 src/olmo_core/nn/attention.py             | 26 ++++++---
 src/olmo_core/nn/feed_forward.py          | 24 ++++++--
 src/olmo_core/nn/layer_norm.py            | 46 ++++++++-------
 src/olmo_core/nn/lm_head.py               | 32 +++++++++--
 src/olmo_core/nn/rope.py                  | 68 ++++++++++++-----------
 src/olmo_core/nn/transformer/__init__.py  |  6 +-
 src/olmo_core/nn/transformer/block.py     | 10 ++--
 src/olmo_core/nn/transformer/config.py    |  5 +-
 src/olmo_core/train/__init__.py           |  5 +-
 src/olmo_core/train/callbacks/__init__.py |  9 +++
 src/test/nn/attention_test.py             | 27 ++++++++-
 src/test/nn/rope_test.py                  | 12 ++--
 13 files changed, 194 insertions(+), 98 deletions(-)

diff --git a/docs/source/conf.py b/docs/source/conf.py
index 9118f6ae..91503b01 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -129,12 +129,22 @@ def filter(self, record: logging.LogRecord) -> bool:
 
 
 def autodoc_skip_member(app, what, name, obj, skip, options):
-    """
-    Skip documenting these Pydantic-specific attributes.
-    """
-    del app, what, obj, skip, options
-    exclude = name in {"model_config", "model_fields", "model_computed_fields"}
-    return True if exclude else None
+    import inspect
+
+    del app, name, options
+
+    module = inspect.getmodule(obj)
+    module_name = None if module is None else module.__name__
+    if (
+        what == "class"
+        and module_name is not None
+        and module_name.startswith("olmo_core.train.callbacks")
+        and module_name != "olmo_core.train.callbacks.callback"
+    ):
+        if inspect.isfunction(obj) or inspect.ismethod(obj):
+            return True
+
+    return skip
 
 
 def setup(app):
diff --git a/src/olmo_core/nn/attention.py b/src/olmo_core/nn/attention.py
index c424deb2..634ac8d7 100644
--- a/src/olmo_core/nn/attention.py
+++ b/src/olmo_core/nn/attention.py
@@ -25,15 +25,20 @@
 class AttentionType(StrEnum):
     """
     An enumeration of the different attention implementations.
-
-    - "default" ➡️ :class:`Attention`
-    - "fused" ➡️ :class:`FusedAttention`
-    - "normalized" ➡️ :class:`NormalizedAttention`
     """
 
     default = "default"
+    """
+    ➡️ :class:`Attention`
+    """
     fused = "fused"
+    """
+    ➡️ :class:`FusedAttention`
+    """
     normalized = "normalized"
+    """
+    ➡️ :class:`NormalizedAttention`
+    """
 
 
 @dataclass
@@ -41,13 +46,12 @@ class AttentionConfig(Config):
     """
     A configuration class for easily building any of the different attention modules.
 
-    See :class:`Attention` for a description of the parameters.
+    See the individual :class:`Attention` subclasses for a description of the configuration options.
     """
 
     name: AttentionType = AttentionType.default
     """
-    - "default" ➡️ :class:`Attention`
-    - "fused" ➡️ :class:`FusedAttention`
+    The name of the implementation.
     """
     n_heads: int = 16
     n_kv_heads: Optional[int] = None
@@ -60,6 +64,11 @@ class AttentionConfig(Config):
     dtype: DType = DType.float32
 
     def num_params(self, d_model: int) -> int:
+        """
+        The number of params that the attention implementation will have once built.
+
+        :param d_model: The model dimensionality.
+        """
         n_heads = self.n_heads
         n_kv_heads = self.n_kv_heads or n_heads
         head_dim = d_model // n_heads
@@ -104,7 +113,8 @@ def build(
         """
         Build the corresponding attention module.
 
-        See :class:`Attention` for a description of the parameters.
+        :param d_model: The model dimensionality.
+        :param init_device: The device to initialize the parameters on, e.g. "cpu", "meta".
""" kwargs = self.as_dict(exclude_none=True, recurse=False) kwargs.pop("name") diff --git a/src/olmo_core/nn/feed_forward.py b/src/olmo_core/nn/feed_forward.py index 38d29273..a6313858 100644 --- a/src/olmo_core/nn/feed_forward.py +++ b/src/olmo_core/nn/feed_forward.py @@ -11,7 +11,7 @@ from ..exceptions import OLMoConfigurationError from .functional import l2_normalize -__all__ = ["FeedForwardConfig", "FeedForwardType", "FeedForward", "NormalizedFeedForward"] +__all__ = ["FeedForwardType", "FeedForwardConfig", "FeedForward", "NormalizedFeedForward"] class FeedForwardType(StrEnum): @@ -21,12 +21,12 @@ class FeedForwardType(StrEnum): default = "default" """ - :class:`FeedForward`. + ➡️ :class:`FeedForward` """ normalized = "normalized" """ - :class:`NormalizedFeedForward`. + ➡️ :class:`NormalizedFeedForward` """ @@ -34,16 +34,22 @@ class FeedForwardType(StrEnum): class FeedForwardConfig(Config): """ A config for building :class:`FeedForward` modules. - - See :class:`FeedForward` for parameter descriptions. """ hidden_size: int name: FeedForwardType = FeedForwardType.default + """ + The name of the implementation. + """ bias: Optional[bool] = None dtype: DType = DType.float32 def num_params(self, d_model: int) -> int: + """ + The number of params that the module will have once built. + + :param d_model: The model dimensionality. + """ bias = self.bias if self.bias is not None else self.name != FeedForwardType.normalized params = 0 @@ -58,7 +64,13 @@ def num_params(self, d_model: int) -> int: return params - def build(self, d_model: int, init_device: str = "cpu") -> "FeedForward": + def build(self, d_model: int, *, init_device: str = "cpu") -> "FeedForward": + """ + Build the corresponding feed-forward module. + + :param d_model: The model dimensionality. + :param init_device: The device initialize the parameters on, e.g. "cpu", "meta". + """ kwargs = self.as_dict(exclude_none=True) kwargs.pop("name") kwargs.update(d_model=d_model, init_device=init_device, dtype=kwargs.pop("dtype").as_pt()) diff --git a/src/olmo_core/nn/layer_norm.py b/src/olmo_core/nn/layer_norm.py index e4a3c234..c778a4b6 100644 --- a/src/olmo_core/nn/layer_norm.py +++ b/src/olmo_core/nn/layer_norm.py @@ -15,16 +15,24 @@ class LayerNormType(StrEnum): """ An enumeration of the different layer norm implementations. - - - "default" ➡️ :class:`LayerNorm` - - "rms" ➡️ :class:`RMSNorm` - - "fused_rms" ➡️ :class:`FusedRMSNorm` """ default = "default" + """ + ➡️ :class:`LayerNorm` + """ rms = "rms" + """ + ➡️ :class:`RMSNorm` + """ fused_rms = "fused_rms" + """ + ➡️ :class:`FusedRMSNorm` + """ l2_norm = "l2_norm" + """ + ➡️ :class:`L2Norm` + """ @dataclass @@ -32,14 +40,12 @@ class LayerNormConfig(Config): """ A config for conveniently building any one of the different layer norm classes. - See :class:`LayerNorm` for a description of the parameters. + See the :class:`LayerNorm` subclasses to learn which fields are valid for each implementation. """ name: LayerNormType = LayerNormType.default """ - - "default" ➡️ :class:`LayerNorm` - - "rms" ➡️ :class:`RMSNorm` - - "fused_rms" ➡️ :class:`FusedRMSNorm` + The name of the implementation. """ eps: Optional[float] = None elementwise_affine: Optional[bool] = None @@ -48,6 +54,11 @@ class LayerNormConfig(Config): dtype: Optional[DType] = None def num_params(self, size: int) -> int: + """ + The number of parameters in the module once built. + + :param size: The size of the input along the dimension to be normalized. 
+ """ elementwise_affine = ( self.elementwise_affine if self.elementwise_affine is not None @@ -65,7 +76,8 @@ def build(self, size: int, init_device: str = "cpu") -> "LayerNorm": """ Construct the corresponding LayerNorm class. - See :class:`LayerNorm` for a description of the parameters. + :param size: The size of the input along the dimension to be normalized. + :param init_device: The device initialize the parameters on, e.g. "cpu", "meta". """ kwargs = self.as_dict(exclude_none=True) kwargs.pop("name") @@ -93,11 +105,7 @@ class LayerNorm(nn.Module): """ Layer normalization. - .. seealso:: - - :class:`RMSNorm` - - :class:`FusedRMSNorm` - - :param size: The hidden size / dimensionality of the input. + :param size: The size of the input along the dimension to be normalized. :param eps: The epsilon used for numerical stability. :param elementwise_affine: Whether to include an element-wise affine transform. :param bias: Whether the element-wise affine should include an element-wise bias. @@ -178,16 +186,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class RMSNorm(LayerNorm): """ - RMS norm, a simplified layer norm implementation. - - .. seealso:: - - :class:`LayerNorm` - - :class:`FusedRMSNorm` + RMSNorm, a simplified layer norm implementation. """ def forward(self, x: torch.Tensor) -> torch.Tensor: """ - Apply RMS norm. + Apply RMSNorm. :param x: The input. """ @@ -270,6 +274,8 @@ class L2Norm(LayerNorm): """ A variant of layer norm that just normalizes the last dimension of the input by its L2 norm, as done in nGPT. + + :param size: The size of the input along the dimension to be normalized. """ def __init__( diff --git a/src/olmo_core/nn/lm_head.py b/src/olmo_core/nn/lm_head.py index b96bc999..e8543c42 100644 --- a/src/olmo_core/nn/lm_head.py +++ b/src/olmo_core/nn/lm_head.py @@ -11,37 +11,45 @@ from .functional import l2_normalize from .layer_norm import LayerNormConfig -__all__ = ["LMHeadConfig", "LMHeadType", "LMHead", "NormalizedLMHead"] +__all__ = ["LMHeadType", "LMHeadConfig", "LMHead", "NormalizedLMHead"] class LMHeadType(StrEnum): """ - An enumeration of LM head types. + An enumeration of the different LM head types. """ default = "default" """ - :class:`LMHead` + ➡️ :class:`LMHead` """ normalized = "normalized" """ - :class:`NormalizedLMHead` + ➡️ :class:`NormalizedLMHead` """ @dataclass class LMHeadConfig(Config): """ - A configuration class for building an :class:`LMHead`. + A configuration class for building any of the :class:`LMHead` implementations. + + See the :class:`LMHead` subclasses to learn which fields are valid for each implementation. """ name: LMHeadType = LMHeadType.default + """ + The name of the implementation. + """ layer_norm: Optional[LayerNormConfig] = None bias: Optional[bool] = None dtype: DType = DType.float32 def num_params(self, d_model: int, vocab_size: int) -> int: + """ + The number of parameters in the module once built. + """ bias = self.bias if self.bias is not None else self.name != LMHeadType.normalized params = 0 @@ -59,6 +67,12 @@ def num_params(self, d_model: int, vocab_size: int) -> int: return params def build(self, *, d_model: int, vocab_size: int, init_device: str = "cpu") -> "LMHead": + """ + Construct the corresponding LM head implementation. + + :param d_model: The model dimensionality. + :param init_device: The device initialize the parameters on, e.g. "cpu", "meta". 
+ """ kwargs = self.as_dict(exclude_none=True, recurse=False) kwargs.pop("name") kwargs.update( @@ -83,7 +97,7 @@ def build(self, *, d_model: int, vocab_size: int, init_device: str = "cpu") -> " class LMHead(nn.Module): """ - The default LM head implementation. + The default language modeling head implementation. """ def __init__( @@ -103,6 +117,9 @@ def __init__( self.w_out = nn.Linear(d_model, vocab_size, bias=bias, dtype=dtype, device=init_device) def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Apply the LM head to the hidden state ``x``, returning the logits. + """ h = self.norm(x) if self.norm is not None else x return self.w_out(h) @@ -136,6 +153,9 @@ def __init__( ) def reset_parameters(self): + """ + Reset the scaling parameter. + """ nn.init.ones_(self.sz) self.sz.mul_(self.sz_init_scaling) diff --git a/src/olmo_core/nn/rope.py b/src/olmo_core/nn/rope.py index 26a184a4..2dd9e03a 100644 --- a/src/olmo_core/nn/rope.py +++ b/src/olmo_core/nn/rope.py @@ -24,15 +24,20 @@ class RoPEType(StrEnum): """ An enumeration of the different RoPE implementations. - - - "default" ➡️ :class:`RotaryEmbedding` - - "fused" ➡️ :class:`FusedRotaryEmbedding` - - "complex" ➡️ :class:`ComplexRotaryEmbedding` """ default = "default" + """ + ➡️ :class:`RotaryEmbedding` + """ fused = "fused" + """ + ➡️ :class:`FusedRotaryEmbedding` + """ complex = "complex" + """ + ➡️ :class:`ComplexRotaryEmbedding` + """ @dataclass @@ -69,16 +74,15 @@ def scale_inv_freq( @dataclass class RoPEConfig(Config): """ - A config for conveniently building any one of the different RoPE classes. + A config for conveniently building any of the different RoPE classes. - See :class:`RotaryEmbedding` for a description of the parameters. + See the individual :class:`RotaryEmbedding` subclasses for a description of the + configuration options. """ name: RoPEType = RoPEType.default """ - - "default" ➡️ :class:`RotaryEmbedding` - - "fused" ➡️ :class:`FusedRotaryEmbedding` - - "complex" ➡️ :class:`ComplexRotaryEmbedding` + The name of the implementation. """ theta: int = 500_000 full_precision: bool = True @@ -86,17 +90,17 @@ class RoPEConfig(Config): def build( self, - head_shape: int, + head_size: int, cache: Optional[BufferCache] = None, ) -> "RotaryEmbeddingBase": """ Construct the corresponding RoPE class. - See :class:`RotaryEmbedding` for a description of the parameters. + :param head_size: The size of the attention heads. """ kwargs = self.as_dict(exclude_none=True, recurse=False) kwargs.pop("name") - kwargs.update(head_shape=head_shape, cache=cache) + kwargs.update(head_size=head_size, cache=cache) try: if self.name == "default": @@ -121,14 +125,14 @@ class RotaryEmbeddingBase(nn.Module): def __init__( self, *, - head_shape: int, + head_size: int, theta: int = 500_000, full_precision: bool = True, cache: Optional[BufferCache] = None, scaling: Optional[RoPEScalingConfig] = None, ): super().__init__() - self.dim = head_shape + self.dim = head_size self.theta = theta self.full_precision = full_precision self.scaling = scaling @@ -150,9 +154,10 @@ class RotaryEmbedding(RotaryEmbeddingBase): - :class:`ComplexRotaryEmbedding` - :class:`FusedRotaryEmbedding` - :param head_shape: The dimensionality of the attention heads. + :param head_size: The size of the attention heads. :param theta: The theta base value to use. :param full_precision: Always apply RoPE in full precision regardless of the input data type. + :param scaling: The scaling config. 
""" def warmup_cache(self, max_seq_len: int, device: torch.device): @@ -209,11 +214,11 @@ def forward( """ Apply RoPE to query (``q``) and key (``k``) matrices. - :param q: The query matrix of shape ``(batch_size, num_heads, seq_len, head_shape)`` - if ``head_first`` (the default) otherwise ``(batch_size, seq_len, num_heads, head_shape)``. - :param k: The key matrix of shape ``(batch_size, num_kv_heads, seq_len, head_shape)`` + :param q: The query matrix of shape ``(batch_size, num_heads, seq_len, head_size)`` + if ``head_first`` (the default) otherwise ``(batch_size, seq_len, num_heads, head_size)``. + :param k: The key matrix of shape ``(batch_size, num_kv_heads, seq_len, head_size)`` if ``head_first`` (the default) otherwise - ``(batch_size, seq_len, num_kv_heads, head_shape)``. + ``(batch_size, seq_len, num_kv_heads, head_size)``. :param head_first: If the head dim comes before the sequence dim. :returns: The query and key matrices after RoPE has been applied. @@ -231,7 +236,7 @@ def forward( q_, k_ = q, k with torch.autocast(q.device.type, enabled=False): - # shape: (T, head_shape), (T, head_shape) + # shape: (T, head_size), (T, head_size) pos_sin, pos_cos = self._get_rotary_embedding(k_len, q_.device) pos_sin, pos_cos = pos_sin.type_as(q_), pos_cos.type_as(q_) if head_first: @@ -262,15 +267,16 @@ class FusedRotaryEmbedding(RotaryEmbeddingBase): .. warning:: This requires `flash-attn `_ to be installed. - :param head_shape: The dimensionality of the attention heads. + :param head_size: The size of the attention heads. :param theta: The theta base value to use. :param full_precision: Always apply RoPE in full precision regardless of the input data type. + :param scaling: The scaling config. """ def __init__( self, *, - head_shape: int, + head_size: int, theta: int = 500_000, full_precision: bool = True, cache: Optional[BufferCache] = None, @@ -279,7 +285,7 @@ def __init__( from flash_attn.layers.rotary import apply_rotary_emb_qkv_ # type: ignore super().__init__( - head_shape=head_shape, + head_size=head_size, theta=theta, full_precision=full_precision, cache=cache, @@ -330,7 +336,7 @@ def forward(self, qkv: torch.Tensor) -> torch.Tensor: is not in full precision. :param qkv: The query, key, and value matrix of shape - ``(batch_size, seq_len, 3, n_heads, head_shape)``. + ``(batch_size, seq_len, 3, n_heads, head_size)``. """ if self.full_precision: qkv_ = qkv.float() @@ -349,7 +355,7 @@ class ComplexRotaryEmbedding(RotaryEmbeddingBase): """ An implementation of `RoPE `_ as a rotation in complex space. - :param head_shape: The dimensionality of the attention heads. + :param head_size: The dimensionality of the attention heads. :param theta: The theta base value to use. :param full_precision: Always apply RoPE in full precision regardless of the input data type. """ @@ -357,13 +363,13 @@ class ComplexRotaryEmbedding(RotaryEmbeddingBase): def __init__( self, *, - head_shape: int, + head_size: int, theta: int = 500_000, full_precision: bool = True, cache: Optional[BufferCache] = None, ): super().__init__( - head_shape=head_shape, + head_size=head_size, theta=theta, full_precision=full_precision, cache=cache, @@ -406,11 +412,11 @@ def forward( """ Apply RoPE to query (``q``) and key (``k``) matrices. - :param q: The query matrix of shape ``(batch_size, num_heads, seq_len, head_shape)`` - if ``head_first`` (the default) otherwise ``(batch_size, seq_len, num_heads, head_shape)``. 
-        :param k: The key matrix of shape ``(batch_size, num_kv_heads, seq_len, head_shape)``
+        :param q: The query matrix of shape ``(batch_size, num_heads, seq_len, head_size)``
+            if ``head_first`` (the default) otherwise ``(batch_size, seq_len, num_heads, head_size)``.
+        :param k: The key matrix of shape ``(batch_size, num_kv_heads, seq_len, head_size)``
             if ``head_first`` (the default) otherwise
-            ``(batch_size, seq_len, num_kv_heads, head_shape)``.
+            ``(batch_size, seq_len, num_kv_heads, head_size)``.
         :param head_first: If the head dim comes before the sequence dim.
 
         :returns: The query and key matrices after RoPE has been applied.
diff --git a/src/olmo_core/nn/transformer/__init__.py b/src/olmo_core/nn/transformer/__init__.py
index ffd0be28..ddbc2183 100644
--- a/src/olmo_core/nn/transformer/__init__.py
+++ b/src/olmo_core/nn/transformer/__init__.py
@@ -1,7 +1,3 @@
-"""
-Transformer building blocks.
-"""
-
 from .block import (
     MoEReorderedNormTransformerBlock,
     MoETransformerBlock,
@@ -27,8 +23,8 @@
 )
 
 __all__ = [
-    "TransformerConfig",
     "TransformerType",
+    "TransformerConfig",
     "Transformer",
     "NormalizedTransformer",
     "TransformerBlockType",
diff --git a/src/olmo_core/nn/transformer/block.py b/src/olmo_core/nn/transformer/block.py
index 81e3da3d..da74a7c4 100644
--- a/src/olmo_core/nn/transformer/block.py
+++ b/src/olmo_core/nn/transformer/block.py
@@ -25,27 +25,27 @@ class TransformerBlockType(StrEnum):
 
     default = "default"
     """
-    :class:`TransformerBlock`
+    ➡️ :class:`TransformerBlock`
     """
 
     reordered_norm = "reordered_norm"
     """
-    :class:`ReorderedNormTransformerBlock`
+    ➡️ :class:`ReorderedNormTransformerBlock`
     """
 
     normalized = "normalized"
     """
-    :class:`NormalizedTransformerBlock`
+    ➡️ :class:`NormalizedTransformerBlock`
     """
 
     moe = "moe"
     """
-    :class:`MoETransformerBlock`
+    ➡️ :class:`MoETransformerBlock`
     """
 
     moe_reordered_norm = "moe"
     """
-    :class:`MoEReorderedNormTransformerBlock`
+    ➡️ :class:`MoEReorderedNormTransformerBlock`
     """
diff --git a/src/olmo_core/nn/transformer/config.py b/src/olmo_core/nn/transformer/config.py
index 4e344859..6a48e508 100644
--- a/src/olmo_core/nn/transformer/config.py
+++ b/src/olmo_core/nn/transformer/config.py
@@ -89,12 +89,12 @@ class TransformerType(StrEnum):
 
     default = "default"
     """
-    :class:`Transformer`
+    ➡️ :class:`Transformer`
     """
 
     normalized = "normalized"
     """
-    :class:`NormalizedTransformer`
+    ➡️ :class:`NormalizedTransformer` (nGPT)
     """
 
 
@@ -103,6 +103,7 @@ class TransformerConfig(Config):
     """
     A config for easily building transformer models.
 
+    :param name: The name of the implementation.
     :param compile: Whether to compile the model with ``torch.compile``.
     :param dp_config: Data parallel configuration.
     :param ac_config: Activation checkpointing configuration.
diff --git a/src/olmo_core/train/__init__.py b/src/olmo_core/train/__init__.py
index 56c56933..d0f7bf83 100644
--- a/src/olmo_core/train/__init__.py
+++ b/src/olmo_core/train/__init__.py
@@ -8,8 +8,9 @@
 - Supports any type of parallel strategy.
 - Async metric logging, with support for custom metrics, even those that need to be reduced
   across ranks.
-- Flexible callback system for extending/modifying the training loop behavior.
-- A powerful set of built-in callbacks.
+- Flexible :class:`~olmo_core.train.callbacks.Callback` system for extending/modifying the training
+  loop behavior.
+- A powerful set of built-in callbacks (:mod:`olmo_core.train.callbacks`).
 
 Overview
 --------
diff --git a/src/olmo_core/train/callbacks/__init__.py b/src/olmo_core/train/callbacks/__init__.py
index c393422e..f37129f2 100644
--- a/src/olmo_core/train/callbacks/__init__.py
+++ b/src/olmo_core/train/callbacks/__init__.py
@@ -1,3 +1,7 @@
+"""
+Trainer :class:`Callback` implementations.
+"""
+
 from .callback import Callback, CallbackConfig
 from .checkpointer import CheckpointerCallback, CheckpointRemovalStrategy
 from .comet import CometCallback, CometNotificationSetting
@@ -44,3 +48,8 @@
     "SpeedMonitorCallback",
     "WandBCallback",
 ]
+
+__doc__ += "\n"
+for name in __all__[2:]:
+    if name.endswith("Callback"):
+        __doc__ += f"- :class:`{name}`\n"
diff --git a/src/test/nn/attention_test.py b/src/test/nn/attention_test.py
index 3d18d266..ab59e471 100644
--- a/src/test/nn/attention_test.py
+++ b/src/test/nn/attention_test.py
@@ -3,7 +3,12 @@
 import pytest
 import torch
 
-from olmo_core.nn.attention import Attention, FusedAttention
+from olmo_core.nn.attention import (
+    Attention,
+    AttentionConfig,
+    AttentionType,
+    FusedAttention,
+)
 from olmo_core.nn.layer_norm import LayerNormConfig
 from olmo_core.nn.rope import RoPEConfig, RoPEType
 
@@ -185,3 +190,23 @@ def test_attention_with_intra_document_masking():
     torch.testing.assert_close(y1_fused, y2_fused)
     torch.testing.assert_close(y1, y1_fused)
     torch.testing.assert_close(y2, y2_fused)
+
+
+@pytest.mark.parametrize(
+    "attn_config",
+    [
+        AttentionConfig(name=AttentionType.default, n_heads=8, n_kv_heads=1, bias=True),
+        AttentionConfig(name=AttentionType.default, n_heads=8, n_kv_heads=1, bias=False),
+        AttentionConfig(
+            name=AttentionType.default, n_heads=8, bias=False, qk_norm=LayerNormConfig()
+        ),
+    ],
+)
+def test_attention_builder_config(attn_config: AttentionConfig):
+    d_model = 64
+
+    attn = attn_config.build(d_model)
+
+    # Make sure the estimated number of params matches the actual number of params.
+    n_params = sum(p.numel() for p in attn.parameters())
+    assert attn_config.num_params(d_model) == n_params
diff --git a/src/test/nn/rope_test.py b/src/test/nn/rope_test.py
index 8592e90e..9d00efb0 100644
--- a/src/test/nn/rope_test.py
+++ b/src/test/nn/rope_test.py
@@ -13,7 +13,7 @@
 @pytest.mark.parametrize("device", DEVICES)
 def test_rope_head_first_vs_seq_first(device):
     B, T, d_model, n_heads = 2, 12, 16, 4
-    rope = RotaryEmbedding(head_shape=d_model // n_heads)
+    rope = RotaryEmbedding(head_size=d_model // n_heads)
 
     with torch.no_grad():
         q = torch.rand(B, n_heads, T, d_model // n_heads, device=device)
@@ -39,7 +39,7 @@ def test_rope_head_first_vs_seq_first(device):
 )
 def test_rope_with_past_key_values(device, head_first):
     B, T, d_model, n_heads = 2, 12, 16, 4
-    rope = RotaryEmbedding(head_shape=d_model // n_heads)
+    rope = RotaryEmbedding(head_size=d_model // n_heads)
 
     with torch.no_grad():
         q = torch.rand(B, n_heads, T, d_model // n_heads, device=device)
@@ -68,8 +68,8 @@ def test_rope_with_past_key_values(device, head_first):
 )
 def test_fused_rope(dtype):
     B, T, d_model, n_heads = 2, 12, 32, 4
-    fused_rope = FusedRotaryEmbedding(head_shape=d_model // n_heads)
-    rope = RotaryEmbedding(head_shape=d_model // n_heads)
+    fused_rope = FusedRotaryEmbedding(head_size=d_model // n_heads)
+    rope = RotaryEmbedding(head_size=d_model // n_heads)
 
     with torch.no_grad(), torch.autocast("cuda", dtype=dtype, enabled=dtype != torch.float32):
         qkv = torch.rand(B, T, 3, n_heads, d_model // n_heads, device="cuda", dtype=dtype)
@@ -84,7 +84,7 @@ def test_fused_rope(dtype):
 @pytest.mark.parametrize("device", DEVICES)
 def test_complex_rope_head_first_vs_seq_first(device):
     B, T, d_model, n_heads = 2, 12, 16, 4
-    rope = ComplexRotaryEmbedding(head_shape=d_model // n_heads)
+    rope = ComplexRotaryEmbedding(head_size=d_model // n_heads)
 
     with torch.no_grad():
         q = torch.rand(B, n_heads, T, d_model // n_heads, device=device)
@@ -110,7 +110,7 @@ def test_complex_rope_head_first_vs_seq_first(device):
 )
 def test_complex_rope_with_past_key_values(device, head_first):
     B, T, d_model, n_heads = 2, 12, 16, 4
-    rope = ComplexRotaryEmbedding(head_shape=d_model // n_heads)
+    rope = ComplexRotaryEmbedding(head_size=d_model // n_heads)
 
     with torch.no_grad():
         q = torch.rand(B, n_heads, T, d_model // n_heads, device=device)
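-- 
Usage sketch of the config/build API these docs describe. It assumes only the
signatures and the new test shown in this patch; the values (``d_model = 64``,
``hidden_size = 256``) are arbitrary illustration values::

    from olmo_core.nn.attention import AttentionConfig, AttentionType
    from olmo_core.nn.feed_forward import FeedForwardConfig

    d_model = 64

    # Build an attention module from its config. num_params() should agree
    # with the parameters of the built module, as the new test asserts.
    attn_config = AttentionConfig(name=AttentionType.default, n_heads=8, n_kv_heads=1, bias=True)
    attn = attn_config.build(d_model)
    assert attn_config.num_params(d_model) == sum(p.numel() for p in attn.parameters())

    # The same pattern applies to the other config classes, e.g. FeedForwardConfig,
    # whose build() now takes init_device as a keyword-only argument.
    ff_config = FeedForwardConfig(hidden_size=256)
    ff = ff_config.build(d_model, init_device="cpu")
    assert ff_config.num_params(d_model) == sum(p.numel() for p in ff.parameters())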