Fixes attention_test.
jax-ml/jax#20094 changes RNG behavior under vmap, so we can no longer rely on layer params being initialized identically with vs. without vmap. This affects RepeatedTransformerLayer and fused QKV layers.

The fix is to convert the reference layer's params into the test layer's format instead of relying on identical initialization.
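As a quick illustration (editor's sketch, not part of the commit): the snippet below shows why the same seed can yield different parameters depending on whether initialization runs under vmap. jax.random.normal and the shapes here stand in for the real per-layer initializers.

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
layer_keys = jax.random.split(key, 3)  # one key per layer

# Sequential initialization: call the initializer once per layer key.
seq_params = jnp.stack([jax.random.normal(k, (4,)) for k in layer_keys])

# vmap'ed initialization over the same keys, roughly what RepeatedTransformerLayer does.
vmap_params = jax.vmap(lambda k: jax.random.normal(k, (4,)))(layer_keys)

# Depending on the JAX version (before/after jax-ml/jax#20094), these may or may not be
# bitwise identical, which is why the test now converts the reference params instead of
# relying on matching fresh initializations.
print(jnp.allclose(seq_params, vmap_params))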
ruomingp committed May 10, 2024
1 parent 840194b commit 8aab11e
Showing 1 changed file with 117 additions and 26 deletions.
143 changes: 117 additions & 26 deletions axlearn/common/attention_test.py
@@ -14,6 +14,7 @@

"""Tests attention layers."""
import contextlib
import copy

# pylint: disable=too-many-lines,duplicate-code,no-self-use
import math
@@ -90,7 +91,15 @@
)
from axlearn.common.test_utils import TestCase, assert_allclose, dummy_segments_positions
from axlearn.common.torch_utils import parameters_from_torch_layer
from axlearn.common.utils import PartitionSpec, Tensor, as_tensor, flatten_items, shapes
from axlearn.common.utils import (
Nested,
PartitionSpec,
Tensor,
VDict,
as_tensor,
flatten_items,
shapes,
)


def _random_mask(prng_key, tgt_len, src_len):
@@ -1514,6 +1523,36 @@ def _scale_kwargs(
return kwargs


def _convert_to_qkv_linear(
base_state: Nested[Tensor], *, input_linear_layer_class: type
) -> Nested[Tensor]:
"""Converts the params of a MultiheadAttention layer
... to params of a MultiheadAttention layer with input_linear of the given type."""
test_state = copy.deepcopy(base_state)

if issubclass(
input_linear_layer_class, (attention.FusedQKVLinear, attention.FusedGroupedQKVLinear)
):

def combine_qkv(param_name: str) -> Tensor:
qkv_params = [
utils.get_recursively(base_state, f"i_proj/{proj}/{param_name}")
for proj in ("q_proj", "k_proj", "v_proj")
]
if issubclass(input_linear_layer_class, attention.FusedQKVLinear):
return jnp.stack(qkv_params)
else:
return jnp.concatenate(qkv_params, axis=-2)

qkv_proj = {"weight": combine_qkv("weight")}
if "bias" in base_state["i_proj"]["q_proj"]:
qkv_proj["bias"] = combine_qkv("bias")
test_state["i_proj"] = VDict({"qkv_proj": qkv_proj})

return test_state


class MultiheadAttentionTest(TestCase):
"""Tests MultiheadAttention, GroupedQueryAttention, and associated layers."""

@@ -1829,16 +1868,14 @@ def test_gqa_forward(
cfg.set(input_linear=input_linear)
set_bias_recursively(cfg, bias=bias)
test_layer = cfg.set(name="test").instantiate(parent=None)
test_state = test_layer.initialize_parameters_recursively(prng_key=init_key)

if input_linear and issubclass(input_linear.klass, attention.FusedGroupedQKVLinear):
test_state["i_proj"]["qkv_proj"]["weight"] = jnp.concatenate(
[
utils.get_recursively(base_state, f"i_proj/{proj}/weight")
for proj in ("q_proj", "k_proj", "v_proj")
],
axis=-2,
)
logging.info("base_state=%s", shapes(base_state))
# We convert 'base_state' to 'test_state' because JAX does not ensure that RNG behavior
# remains the same with vs. without vmap. So test_layer initialization may behave
# differently even with the same seed.
test_state = _convert_to_qkv_linear(
base_state, input_linear_layer_class=cfg.input_linear.klass
)
logging.info("transformed_test_state=%s", shapes(test_state))

# Dummy inputs.
batch_size, tgt_len = 2, 6
@@ -2434,17 +2471,17 @@ def test_per_dim_scale(self, per_dim_scale, scale_position):
)
expected_vals = {
str(None): {
MultiheadAttentionXL.ScalePosition.LOGIT.value: 48.06005,
MultiheadAttentionXL.ScalePosition.QUERY.value: 48.08012,
MultiheadAttentionXL.ScalePosition.LOGIT.value: 48.683887,
MultiheadAttentionXL.ScalePosition.QUERY.value: 48.598305,
},
str(PerDimScale.default_config()): {
MultiheadAttentionXL.ScalePosition.LOGIT.value: 47.321579,
MultiheadAttentionXL.ScalePosition.QUERY.value: 47.870319,
MultiheadAttentionXL.ScalePosition.LOGIT.value: 48.790010,
MultiheadAttentionXL.ScalePosition.QUERY.value: 48.858986,
},
}
assert_allclose(
jnp.abs(layer_outputs.data).sum(),
expected_vals[str(per_dim_scale)][scale_position.value],
jnp.abs(layer_outputs.data).sum(),
)

def test_multihead_attention_xl(self):
@@ -2799,6 +2836,52 @@ def forward(self, data, self_attention_logit_biases):
return loss, {"mean": x_mean}


def _recursive_stack(inputs: Nested[Tensor], axis=0):
def stack(*xs):
return jnp.stack(xs, axis=axis)

return {"layer": utils.vectorized_tree_map(stack, *inputs.values())}


def _convert_from_stacked_params(
layer_params: Nested[Tensor], *, target_stack_cfg: BaseStackedTransformerLayer.Config
) -> Nested[Tensor]:
"""Converts params of a StackedTransformerLayer to params for `target_stack_cfg`."""
# First stack to params of a RepeatedTransformerLayer.
layer_params = {"stack": {"repeat": VDict(_recursive_stack(layer_params["stack"]))}}
if target_stack_cfg.klass == RepeatedTransformerLayer:
return layer_params
elif target_stack_cfg.klass == PipelinedTransformerLayer:
pipeline_stage_cfg = target_stack_cfg.stage
num_layers_per_stage = target_stack_cfg.num_layers // target_stack_cfg.num_stages

def reshape(x):
"""Reshapes x from [num_layers, ...] to [num_stages, num_layers_per_stage, ...]."""
x_shape = list(x.shape)
return jnp.reshape(x, [target_stack_cfg.num_stages, num_layers_per_stage] + x_shape[1:])

pipeline_params = jax.tree_util.tree_map(reshape, layer_params["stack"].pop("repeat"))

if pipeline_stage_cfg.klass == RepeatedTransformerLayer:
layer_params["stack"]["pipeline"] = VDict({"layer": {"repeat": pipeline_params}})
elif pipeline_stage_cfg.klass == StackedTransformerLayer:
layer_params["stack"]["pipeline"] = VDict(
{
"layer": {
f"layer{i}": jax.tree_util.tree_map(
lambda x, i=i: x[:, i], pipeline_params["layer"]
)
for i in range(num_layers_per_stage)
}
}
)
else:
raise NotImplementedError(target_stack_cfg)
return layer_params
else:
raise NotImplementedError(target_stack_cfg)


class StackedTransformerTest(TestCase):
"""Tests StackedTransformerLayer."""

@@ -3154,6 +3237,7 @@ def test_stack_vs_pipeline_remat_everything_saveable(self):

# pylint: disable-next=too-many-statements,too-many-branches
def _compare_layers(self, *stack_configs, dtype=jnp.float32, remat_spec=None):
assert stack_configs[0] == StackedTransformerLayer, stack_configs[0]
with utils.numeric_checks(False):
batch_size, tgt_len = 10, 5
num_layers, model_dim, num_heads = 6, 8, 4
@@ -3168,6 +3252,7 @@ def _compare_layers(self, *stack_configs, dtype=jnp.float32, remat_spec=None):
all_outputs = []
all_gradients = []
all_updates = []
stacked_layer_params = None
for stack_cfg in stack_configs:
cfg = self._stack_config(
stack_cfg,
@@ -3200,6 +3285,20 @@
for path, value in flatten_items(layer_params)
],
)
if cls == StackedTransformerLayer:
stacked_layer_params = copy.deepcopy(layer_params)
else:
layer_params = _convert_from_stacked_params(
stacked_layer_params, target_stack_cfg=cfg.stack
)
logging.info(
"Converted: %s.params=%s",
cls,
[
f"{path}={value.dtype}({value.shape})"
for path, value in flatten_items(layer_params)
],
)

def _loss(layer_params, data, mask, layer=layer):
layer_outputs, layer_output_collection = F(
@@ -3270,17 +3369,9 @@ def rms_norm(x):
dict(utils.flatten_items(update_norms)),
)

def recursive_stack(stacked, axis=0):
return {
"layer": utils.vectorized_tree_map(
lambda *xs: jnp.stack(xs, axis=axis),
*stacked.values(),
)
}

if cls == StackedTransformerLayer:
for x in (layer_params, grads, updates):
x["stack"] = recursive_stack(x["stack"])
x["stack"] = _recursive_stack(x["stack"])

if cls == RepeatedTransformerLayer:
for x in (layer_params, grads, updates):
@@ -3291,7 +3382,7 @@ def recursive_stack(stacked, axis=0):
logging.info("x=%s", shapes(x))
if cfg.stack.stage.klass == StackedTransformerLayer:
# First stack within each stage.
x["stack"]["pipeline"]["layer"] = recursive_stack(
x["stack"]["pipeline"]["layer"] = _recursive_stack(
x["stack"]["pipeline"]["layer"], axis=1
)
logging.info("x=%s", shapes(x))
