Causal Mask HLO for jnp.tril(jnp.ones()) can be simplified #19905
-
Quite often we see the following JAX code used to create and apply a causal mask in a self-attention layer:

```python
import jax
import jax.numpy as jnp
from jax import Array, random

key = random.PRNGKey(42)
qk = random.uniform(key, shape=(4, 4))

def apply_causal_mask(qk: Array):
    seq_len = qk.shape[-1]
    # Lower-triangular matrix of float ones, then cast to bool to form the mask.
    mask = jnp.tril(jnp.ones((seq_len, seq_len))).astype('bool')
    # Keep scores on and below the diagonal; replace the rest with -inf.
    return jnp.where(mask, qk, -jnp.inf)
```

If we jit, lower, compile, and dump this function as text:

```python
print(jax.jit(apply_causal_mask).lower(qk).compile().as_text())
```

we get the following HLO:

```
HloModule jit_apply_causal_mask, entry_computation_layout={(f32[4,4]{1,0})->f32[4,4]{1,0}}, allow_spmd_sharding_propagation_to_output={true}

%fused_computation (param_0.1: f32[4,4]) -> f32[4,4] {
  %iota.5 = s32[4,4]{1,0} iota(), iota_dimension=0, metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]" source_file="<stdin>" source_line=5}
  %iota.4 = s32[4,4]{1,0} iota(), iota_dimension=1, metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]" source_file="<stdin>" source_line=5}
  %compare.2 = pred[4,4]{1,0} compare(s32[4,4]{1,0} %iota.5, s32[4,4]{1,0} %iota.4), direction=GE, metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(tril)/ge" source_file="<stdin>" source_line=5}
  %constant.5 = f32[] constant(1)
  %broadcast.8 = f32[4,4]{1,0} broadcast(f32[] %constant.5), dimensions={}
  %constant.3 = f32[] constant(0)
  %broadcast.7 = f32[4,4]{1,0} broadcast(f32[] %constant.3), dimensions={}
  %select.3 = f32[4,4]{1,0} select(pred[4,4]{1,0} %compare.2, f32[4,4]{1,0} %broadcast.8, f32[4,4]{1,0} %broadcast.7), metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(tril)/select_n" source_file="<stdin>" source_line=5}
  %compare.1 = pred[4,4]{1,0} compare(f32[4,4]{1,0} %select.3, f32[4,4]{1,0} %broadcast.7), direction=NE, metadata={op_name="jit(apply_causal_mask)/jit(main)/convert_element_type[new_dtype=bool weak_type=False]" source_file="<stdin>" source_line=5}
  %param_0.1 = f32[4,4]{1,0} parameter(0)
  %constant.1 = f32[] constant(-inf)
  %broadcast.6 = f32[4,4]{1,0} broadcast(f32[] %constant.1), dimensions={}, metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(_where)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]" source_file="<stdin>" source_line=8}
  ROOT %select.2 = f32[4,4]{1,0} select(pred[4,4]{1,0} %compare.1, f32[4,4]{1,0} %param_0.1, f32[4,4]{1,0} %broadcast.6), metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(_where)/select_n" source_file="<stdin>" source_line=8}
}

ENTRY %main.26 (Arg_0.1: f32[4,4]) -> f32[4,4] {
  %Arg_0.1 = f32[4,4]{1,0} parameter(0), sharding={replicated}
  ROOT %fusion = f32[4,4]{1,0} fusion(f32[4,4]{1,0} %Arg_0.1), kind=kLoop, calls=%fused_computation, metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(_where)/select_n" source_file="<stdin>" source_line=8}
}
```

In particular, we noticed that the mask-related code can be simplified to just iota + iota + compare:

```
%fused_computation (param_0.1: f32[4,4]) -> f32[4,4] {
  %iota.5 = s32[4,4]{1,0} iota(), iota_dimension=0, metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]" source_file="<stdin>" source_line=5}
  %iota.4 = s32[4,4]{1,0} iota(), iota_dimension=1, metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]" source_file="<stdin>" source_line=5}
  %compare.2 = pred[4,4]{1,0} compare(s32[4,4]{1,0} %iota.5, s32[4,4]{1,0} %iota.4), direction=GE, metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(tril)/ge" source_file="<stdin>" source_line=5}
  %param_0.1 = f32[4,4]{1,0} parameter(0)
  %constant.1 = f32[] constant(-inf)
  %broadcast.6 = f32[4,4]{1,0} broadcast(f32[] %constant.1), dimensions={}, metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(_where)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]" source_file="<stdin>" source_line=8}
  ROOT %select.2 = f32[4,4]{1,0} select(pred[4,4]{1,0} %compare.2, f32[4,4]{1,0} %param_0.1, f32[4,4]{1,0} %broadcast.6), metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(_where)/select_n" source_file="<stdin>" source_line=8}
}
```

As a result, the following ops can be removed:

```
%constant.5 = f32[] constant(1)
%broadcast.8 = f32[4,4]{1,0} broadcast(f32[] %constant.5), dimensions={}
%constant.3 = f32[] constant(0)
%broadcast.7 = f32[4,4]{1,0} broadcast(f32[] %constant.3), dimensions={}
%select.3 = f32[4,4]{1,0} select(pred[4,4]{1,0} %compare.2, f32[4,4]{1,0} %broadcast.8, f32[4,4]{1,0} %broadcast.7), metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(tril)/select_n" source_file="<stdin>" source_line=5}
%compare.1 = pred[4,4]{1,0} compare(f32[4,4]{1,0} %select.3, f32[4,4]{1,0} %broadcast.7), direction=NE, metadata={op_name="jit(apply_causal_mask)/jit(main)/convert_element_type[new_dtype=bool weak_type=False]" source_file="<stdin>" source_line=5}
```

What do you think about recognizing this pattern and applying the described simplification to it? It seems like a very common use case in LLMs. Our team will be happy to work on it.

Link to a similar XLA discussion: openxla/xla#9709
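For illustration, the mask computed by the simplified HLO (two iotas and a `GE` compare) can be written directly in JAX with `jax.lax.broadcasted_iota`. The sketch below uses hypothetical helper names (`causal_mask_tril`, `causal_mask_iota`) and checks two things: that the iota formulation matches the `tril`/`ones`/`astype` one, and that the removed ops, which compute `compare(select(p, 1.0, 0.0), 0.0, NE)`, simply reproduce the predicate `p`.

```python
import jax
import jax.numpy as jnp
from jax import lax

def causal_mask_tril(seq_len: int) -> jax.Array:
    # Original formulation: float ones -> lower triangle -> cast to bool.
    return jnp.tril(jnp.ones((seq_len, seq_len))).astype(bool)

def causal_mask_iota(seq_len: int) -> jax.Array:
    # Mirrors the simplified HLO: row index >= column index.
    rows = lax.broadcasted_iota(jnp.int32, (seq_len, seq_len), 0)
    cols = lax.broadcasted_iota(jnp.int32, (seq_len, seq_len), 1)
    return rows >= cols

# Both formulations agree for a few sizes.
for n in (1, 4, 7):
    assert bool(jnp.array_equal(causal_mask_tril(n), causal_mask_iota(n)))

# The removed ops: select(p, 1.0, 0.0) followed by "!= 0.0" round-trips to p.
p = causal_mask_iota(4)
roundtrip = jnp.where(p, 1.0, 0.0) != 0.0
assert bool(jnp.array_equal(roundtrip, p))
```

The round-trip check is the algebraic fact that makes `%select.3` and `%compare.1` redundant: converting a predicate to `1.0`/`0.0` and comparing it against `0.0` gives back the predicate.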
Replies: 2 comments 1 reply
-
I think we identified a common JAX code issue encountered when generating a causal mask. The user either neglected to specify …

I opened XLA PR-9867 to simplify the potentially suboptimal causal-mask HLO.
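The pattern under discussion boils down to the identity `NE(select(p, a, b), b) == p`, which holds whenever `b` is not NaN and `a != b`. As a rough illustration only, here is a toy rewrite rule over a made-up mini-IR; the classes `Const`, `Var`, `Select`, `CompareNE` and the function `simplify` are hypothetical and are not the code in XLA's algebraic simplifier or in PR-9867.

```python
from dataclasses import dataclass
from typing import Union

@dataclass(frozen=True)
class Const:
    value: float

@dataclass(frozen=True)
class Var:
    name: str  # stands in for an arbitrary value, e.g. %compare.2

@dataclass(frozen=True)
class Select:
    pred: "Expr"
    on_true: "Expr"
    on_false: "Expr"

@dataclass(frozen=True)
class CompareNE:
    lhs: "Expr"
    rhs: "Expr"

Expr = Union[Const, Var, Select, CompareNE]

def simplify(e: Expr) -> Expr:
    """Bottom-up rewrite of NE(select(p, a, b), b) -> p for constants a != b."""
    if isinstance(e, Select):
        return Select(simplify(e.pred), simplify(e.on_true), simplify(e.on_false))
    if isinstance(e, CompareNE):
        lhs, rhs = simplify(e.lhs), simplify(e.rhs)
        if (isinstance(lhs, Select)
                and isinstance(lhs.on_true, Const)
                and isinstance(lhs.on_false, Const)
                and isinstance(rhs, Const)
                and lhs.on_false.value == rhs.value  # False for NaN, so NaN never matches
                and lhs.on_true.value != rhs.value):
            return lhs.pred
        return CompareNE(lhs, rhs)
    return e  # Const and Var are already leaves

# The causal-mask graph from the original post, in toy form.
p = Var("compare.2")
mask = CompareNE(Select(p, Const(1.0), Const(0.0)), Const(0.0))
masked = Select(mask, Var("qk"), Const(float("-inf")))
print(simplify(masked))  # the mask collapses to Var('compare.2')
```

A real compiler pass would also need to handle broadcasted constants (as in `%broadcast.7` and `%broadcast.8`), shapes, and element types; the toy ignores all of that to isolate the rewrite itself.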
-
Rather than `jnp.tril(jnp.ones((seq_len, seq_len))).astype('bool')`, you might try writing `jnp.tri(seq_len, dtype=bool)`, which is much more direct.

Edit: `tri` also generates the desired HLO (see openxla/xla#9709 (comment)).
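A minimal sketch comparing the two formulations end to end, assuming a JAX version where `jnp.tri` accepts `dtype=bool`; the function names are illustrative. It checks that both produce identical masked scores and prints the compiled HLO of the `tri` version so it can be compared with the dump in the original post (the exact HLO text varies across JAX/XLA versions).

```python
import jax
import jax.numpy as jnp
from jax import Array, random

def apply_causal_mask_tril(qk: Array) -> Array:
    seq_len = qk.shape[-1]
    # Float ones -> lower triangle -> bool cast (the pattern from the original post).
    mask = jnp.tril(jnp.ones((seq_len, seq_len))).astype(bool)
    return jnp.where(mask, qk, -jnp.inf)

def apply_causal_mask_tri(qk: Array) -> Array:
    seq_len = qk.shape[-1]
    # Boolean lower-triangular mask built directly, with no float round-trip.
    mask = jnp.tri(seq_len, dtype=bool)
    return jnp.where(mask, qk, -jnp.inf)

qk = random.uniform(random.PRNGKey(42), shape=(4, 4))
out_tril = jax.jit(apply_causal_mask_tril)(qk)
out_tri = jax.jit(apply_causal_mask_tri)(qk)

# Same values, including -inf above the diagonal.
assert bool(jnp.array_equal(out_tril, out_tri))

# Optimized HLO for the tri version; compare against the dump in the original post.
print(jax.jit(apply_causal_mask_tri).lower(qk).compile().as_text())
```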