let _where = { lambda ; a:bool[239696] b:f32[239696] c:f32[]. let d:f32[] = convert_element_type[new_dtype=float32 weak_type=False] c e:f32[239696] = broadcast_in_dim[broadcast_dimensions=() shape=(239696,)] d f:f32[239696] = select_n a e b in (f,) } in let _where1 = { lambda ; g:bool[422220] h:f32[422220] i:f32[]. let j:f32[] = convert_element_type[new_dtype=float32 weak_type=False] i k:f32[422220] = broadcast_in_dim[broadcast_dimensions=() shape=(422220,)] j l:f32[422220] = select_n g k h in (l,) } in let norm = { lambda ; m:f32[239696,2000]. let n:f32[239696,2000] = mul m m o:f32[239696] = reduce_sum[axes=(1,)] n p:f32[239696] = sqrt o in (p,) } in let norm1 = { lambda ; q:f32[422220,2000]. let r:f32[422220,2000] = mul q q s:f32[422220] = reduce_sum[axes=(1,)] r t:f32[422220] = sqrt s in (t,) } in let norm2 = { lambda ; u:f32[4,2000]. let v:f32[4,2000] = mul u u w:f32[4] = reduce_sum[axes=(1,)] v x:f32[4] = sqrt w in (x,) } in let _where2 = { lambda ; y:bool[422220] z:f32[422220] ba:i32[]. let bb:f32[] = convert_element_type[new_dtype=float32 weak_type=False] ba bc:f32[422220] = broadcast_in_dim[broadcast_dimensions=() shape=(422220,)] bb bd:f32[422220] = select_n y bc z in (bd,) } in let norm3 = { lambda ; be:f32[8,2000]. let bf:f32[8,2000] = mul be be bg:f32[8] = reduce_sum[axes=(1,)] bf bh:f32[8] = sqrt bg in (bh,) } in let floor_divide = { lambda ; bi:i32[] bj:i32[]. let bk:i32[] = div bi bj bl:i32[] = sign bi bm:i32[] = sign bj bn:bool[] = ne bl bm bo:i32[] = rem bi bj bp:bool[] = ne bo 0 bq:bool[] = convert_element_type[new_dtype=bool weak_type=False] bn br:bool[] = convert_element_type[new_dtype=bool weak_type=False] bp bs:bool[] = and bq br bt:i32[] = sub bk 1 bu:i32[] = pjit[ name=_where jaxpr={ lambda ; bv:bool[] bw:i32[] bx:i32[]. let by:i32[] = select_n bv bx bw in (by,) } ] bs bt bk in (bu,) } in let _where3 = { lambda ; bv:bool[] bw:i32[] bx:i32[]. let by:i32[] = select_n bv bx bw in (by,) } in let _where4 = { lambda ; bz:bool[239696] ca:f32[239696] cb:i32[]. let cc:f32[] = convert_element_type[new_dtype=float32 weak_type=False] cb cd:f32[239696] = broadcast_in_dim[broadcast_dimensions=() shape=(239696,)] cc ce:f32[239696] = select_n bz cd ca in (ce,) } in { lambda cf:f32[30,42] cg:i32[1]; ch:f32[239696,2001] ci:f32[422220,2001] cj:f32[] ck:f32[]. let cl:key[] = random_seed[impl=fry] 0 cm:u32[2] = random_unwrap cl cn:key[] = random_wrap[impl=fry] cm co:key[2] = random_split[shape=(2,)] cn cp:u32[2,2] = random_unwrap co cq:u32[1,2] = slice[ limit_indices=(1, 2) start_indices=(0, 0) strides=(1, 1) ] cp _:u32[2] = squeeze[dimensions=(0,)] cq cr:u32[1,2] = slice[ limit_indices=(2, 2) start_indices=(1, 0) strides=(1, 1) ] cp _:u32[2] = squeeze[dimensions=(0,)] cr cs:f32[239696] = broadcast_in_dim[broadcast_dimensions=() shape=(239696,)] 1.0 _:f32[239696] = div cs 239696.0 ct:f32[239696] = broadcast_in_dim[broadcast_dimensions=() shape=(239696,)] 0.0 cu:f32[422220] = broadcast_in_dim[broadcast_dimensions=() shape=(422220,)] 1.0 _:f32[422220] = div cu 422220.0 cv:f32[422220] = broadcast_in_dim[broadcast_dimensions=() shape=(422220,)] 0.0 cw:f32[239696] = broadcast_in_dim[broadcast_dimensions=() shape=(239696,)] 1.0 cx:f32[239696] = div cw 239696.0 cy:bool[239696] = gt cx 0.0 cz:f32[239696] = pjit[name=_where jaxpr=_where] cy ct -inf da:f32[422220] = broadcast_in_dim[broadcast_dimensions=() shape=(422220,)] 1.0 db:f32[422220] = div da 422220.0 dc:bool[422220] = gt db 0.0 dd:f32[422220] = pjit[name=_where jaxpr=_where1] dc cv -inf de:f32[239696] df:f32[422220] dg:f32[200] dh:f32[] di:bool[] dj:i32[] = custom_vjp_call_jaxpr[ bwd=. at 0x7f8d1012c940> fun_jaxpr={ lambda ; dk:f32[30,42] dl:i32[1] dm:f32[239696,2001] dn:f32[422220,2001] do:f32[] dp:f32[] dq:f32[] dr:f32[239696] ds:f32[422220]. let dt:f32[200,1] = broadcast_in_dim[ broadcast_dimensions=() shape=(200, 1) ] 1.0 du:f32[200,1] = neg dt dv:i32[10] = iota[dimension=0 dtype=int32 shape=(10,)] dw:bool[10] = eq dv 9 _:i32[] dx:f32[200,1] dy:f32[239696] dz:f32[422220] = while[ body_jaxpr={ lambda ; ea:f32[239696,2001] eb:f32[422220,2001] ec:f32[30,42] ed:f32[] ee:f32[] ef:i32[1] eg:bool[10] eh:i32[] ei:f32[200,1] ej:f32[239696] ek:f32[422220]. let el:i32[] em:f32[200,1] en:f32[239696] eo:f32[422220] = scan[ jaxpr={ lambda ; ep:f32[239696,2001] eq:f32[422220,2001] er:f32[30,42] es:f32[] et:f32[] eu:i32[1] ev:i32[] ew:f32[200,1] ex:f32[239696] ey:f32[422220] ez:bool[]. let fa:f32[422220] = broadcast_in_dim[ broadcast_dimensions=() shape=(422220,) ] 1.0 fb:f32[422220] = div fa 422220.0 fc:f32[422220] = log fb fd:f32[239696] = broadcast_in_dim[ broadcast_dimensions=() shape=(239696,) ] 1.0 fe:f32[] = reduce_sum[axes=(0,)] fd ff:f32[239696] = div fd fe fg:f32[239696,1] = reshape[ dimensions=None new_sizes=(239696, 1) ] ff fh:f32[239696,2000] = slice[ limit_indices=(239696, 2000) start_indices=(0, 0) strides=None ] ep fi:i32[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) ] 2000 fj:f32[239696,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(1,)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(239696, 1) unique_indices=True ] ep fi fk:f32[239696] = squeeze[dimensions=(1,)] fj fl:i32[239696] = convert_element_type[ new_dtype=int32 weak_type=False ] fk fm:f32[422220,2000] = slice[ limit_indices=(422220, 2000) start_indices=(0, 0) strides=None ] eq fn:i32[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) ] 2000 fo:f32[422220,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(1,)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(422220, 1) unique_indices=True ] eq fn fp:f32[422220] = squeeze[dimensions=(1,)] fo fq:i32[422220] = convert_element_type[ new_dtype=int32 weak_type=False ] fp fr:f32[239696] = pjit[name=norm jaxpr=norm] fh fs:f32[422220] = pjit[name=norm jaxpr=norm1] fm ft:f32[239696,422220] = dot_general[ dimension_numbers=(([1], [1]), ([], [])) preferred_element_type=float32 ] fh fm fu:f32[1,239696] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 239696) ] fr fv:f32[422220,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(422220, 1) ] fs fw:f32[422220,239696] = mul fu fv fx:f32[422220,239696] = add fw 9.99999993922529e-09 fy:f32[422220,239696] = transpose[permutation=(1, 0)] ft fz:f32[422220,239696] = div fy fx ga:f32[422220,239696] = sub 1.0 fz gb:f32[422220,239696] = mul 1.0 ga gc:bool[239696] = lt fl 0 gd:i32[239696] = add fl 30 ge:i32[239696] = select_n gc fl gd gf:bool[422220] = lt fq 0 gg:i32[422220] = add fq 42 gh:i32[422220] = select_n gf fq gg gi:i32[239696,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(239696, 1) ] ge gj:i32[422220,239696,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(422220, 239696, 1) ] gh gk:i32[422220,239696,1] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(422220, 239696, 1) ] gi gl:i32[422220,239696,2] = concatenate[dimension=2] gk gj gm:f32[422220,239696,1,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(2, 3), collapsed_slice_dims=(), start_index_map=(0, 1)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 1) unique_indices=False ] er gl gn:f32[422220,239696] = squeeze[dimensions=(2, 3)] gm go:f32[422220,239696] = mul 0.0010000000474974513 gn gp:f32[422220,239696] = add gb go gq:f32[422220,239696] = add 0.0 gp gr:f32[422220,239696] = mul gq 1.0 gs:f32[422220,1] = dot_general[ dimension_numbers=(([1], [0]), ([], [])) preferred_element_type=float32 ] gr fg gt:f32[422220] = squeeze[dimensions=(1,)] gs gu:f32[422220] = broadcast_in_dim[ broadcast_dimensions=() shape=(422220,) ] 1.0 gv:f32[] = reduce_sum[axes=(0,)] gu gw:f32[422220] = div gu gv gx:f32[422220] = mul gt gw gy:f32[] = reduce_sum[axes=(0,)] gx gz:f32[] = stop_gradient gy ha:f32[] = min es 1.0 _:f32[] = convert_element_type[ new_dtype=float32 weak_type=True ] ev hb:f32[] = pow ha ev hc:f32[] = mul et hb hd:f32[] = max hc 1.0 he:f32[] = mul gz 0.05000000074505806 hf:f32[] = convert_element_type[ new_dtype=float32 weak_type=False ] hd hg:f32[] = mul hf he hh:i32[52777] = iota[ dimension=0 dtype=int32 shape=(52777,) ] _:f32[239696] _:f32[422220] _:f32[] hi:f32[52777,8] hj:f32[52777,8] = scan[ jaxpr={ lambda ; hk:f32[422220,2001] hl:f32[239696,2001] hm:f32[30,42] hn:f32[239696] ho:f32[422220] hp:f32[] hq:i32[]. let hr:i32[] = mul hq 8 hs:bool[] = lt hr 0 ht:i32[] = add hr 422220 hu:i32[] = select_n hs hr ht hv:f32[8,2001] = dynamic_slice[ slice_sizes=(8, 2001) ] hk hu 0 hw:i32[] = mul hq 8 hx:bool[] = lt hw 0 hy:i32[] = add hw 422220 hz:i32[] = select_n hx hw hy ia:f32[8] = dynamic_slice[slice_sizes=(8,)] ho hz ib:f32[239696,2000] = slice[ limit_indices=(239696, 2000) start_indices=(0, 0) strides=None ] hl ic:i32[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) ] 2000 id:f32[239696,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(1,)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(239696, 1) unique_indices=True ] hl ic ie:f32[239696] = squeeze[dimensions=(1,)] id if:i32[239696] = convert_element_type[ new_dtype=int32 weak_type=False ] ie ig:f32[8,2000] = slice[ limit_indices=(8, 2000) start_indices=(0, 0) strides=None ] hv ih:i32[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) ] 2000 ii:f32[8,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(1,)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(8, 1) unique_indices=True ] hv ih ij:f32[8] = squeeze[dimensions=(1,)] ii ik:i32[8] = convert_element_type[ new_dtype=int32 weak_type=False ] ij il:f32[239696] = pjit[name=norm jaxpr=norm] ib im:f32[8] = pjit[name=norm jaxpr=norm3] ig in:f32[239696,8] = dot_general[ dimension_numbers=(([1], [1]), ([], [])) preferred_element_type=float32 ] ib ig io:f32[1,239696] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 239696) ] il ip:f32[8,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(8, 1) ] im iq:f32[8,239696] = mul io ip ir:f32[8,239696] = add iq 9.99999993922529e-09 is:f32[8,239696] = transpose[permutation=(1, 0)] in it:f32[8,239696] = div is ir iu:f32[8,239696] = sub 1.0 it iv:f32[8,239696] = mul 1.0 iu iw:bool[239696] = lt if 0 ix:i32[239696] = add if 30 iy:i32[239696] = select_n iw if ix iz:bool[8] = lt ik 0 ja:i32[8] = add ik 42 jb:i32[8] = select_n iz ik ja jc:i32[239696,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(239696, 1) ] iy jd:i32[8,239696,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(8, 239696, 1) ] jb je:i32[8,239696,1] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(8, 239696, 1) ] jc jf:i32[8,239696,2] = concatenate[dimension=2] je jd jg:f32[8,239696,1,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(2, 3), collapsed_slice_dims=(), start_index_map=(0, 1)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 1) unique_indices=False ] hm jf jh:f32[8,239696] = squeeze[dimensions=(2, 3)] jg ji:f32[8,239696] = mul 0.0010000000474974513 jh jj:f32[8,239696] = add iv ji jk:f32[8,239696] = add 0.0 jj jl:f32[8,239696] = mul jk 1.0 jm:f32[1,239696] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 239696) ] hn jn:f32[8,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(8, 1) ] ia jo:f32[8,239696] = add jm jn jp:f32[8,239696] = sub jo jl jq:f32[8,239696] = div jp hp jr:f32[8] js:f32[8] = custom_jvp_call[ call_jaxpr={ lambda ; jt:f32[8,239696]. let ju:f32[8] = reduce_max[axes=(1,)] jt jv:bool[8] = is_finite ju jw:f32[8] = broadcast_in_dim[ broadcast_dimensions=() shape=(8,) ] 0.0 jx:f32[8] = select_n jv jw ju jy:f32[8] = stop_gradient jx jz:f32[8,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(8, 1) ] jy ka:f32[8,239696] = sub jt jz kb:f32[8,239696] = exp ka kc:f32[8] = reduce_sum[axes=(1,)] kb kd:f32[8] = sign kc ke:f32[8] = abs kc kf:f32[8] = log ke kg:f32[8] = add kf jy in (kg, kd) } jvp_jaxpr_thunk=.memoized at 0x7f8d10572dd0> num_consts=0 symbolic_zeros=False ] jq in (hn, ho, hp, jr, js) } length=52777 linear=(False, False, False, False, False, False, False) num_carry=3 num_consts=3 reverse=False unroll=1 ] eq ep er ex ey hg hh kh:f32[422216] = reshape[ dimensions=None new_sizes=(422216,) ] hi ki:f32[422216] = reshape[ dimensions=None new_sizes=(422216,) ] hj kj:f32[4,2001] = slice[ limit_indices=(422220, 2001) start_indices=(422216, 0) strides=None ] eq kk:f32[4] = slice[ limit_indices=(422220,) start_indices=(422216,) strides=None ] ey kl:f32[239696,2000] = slice[ limit_indices=(239696, 2000) start_indices=(0, 0) strides=None ] ep km:i32[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) ] 2000 kn:f32[239696,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(1,)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(239696, 1) unique_indices=True ] ep km ko:f32[239696] = squeeze[dimensions=(1,)] kn kp:i32[239696] = convert_element_type[ new_dtype=int32 weak_type=False ] ko kq:f32[4,2000] = slice[ limit_indices=(4, 2000) start_indices=(0, 0) strides=None ] kj kr:i32[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) ] 2000 ks:f32[4,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(1,)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(4, 1) unique_indices=True ] kj kr kt:f32[4] = squeeze[dimensions=(1,)] ks ku:i32[4] = convert_element_type[ new_dtype=int32 weak_type=False ] kt kv:f32[239696] = pjit[name=norm jaxpr=norm] kl kw:f32[4] = pjit[name=norm jaxpr=norm2] kq kx:f32[239696,4] = dot_general[ dimension_numbers=(([1], [1]), ([], [])) preferred_element_type=float32 ] kl kq ky:f32[1,239696] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 239696) ] kv kz:f32[4,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(4, 1) ] kw la:f32[4,239696] = mul ky kz lb:f32[4,239696] = add la 9.99999993922529e-09 lc:f32[4,239696] = transpose[permutation=(1, 0)] kx ld:f32[4,239696] = div lc lb le:f32[4,239696] = sub 1.0 ld lf:f32[4,239696] = mul 1.0 le lg:bool[239696] = lt kp 0 lh:i32[239696] = add kp 30 li:i32[239696] = select_n lg kp lh lj:bool[4] = lt ku 0 lk:i32[4] = add ku 42 ll:i32[4] = select_n lj ku lk lm:i32[239696,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(239696, 1) ] li ln:i32[4,239696,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(4, 239696, 1) ] ll lo:i32[4,239696,1] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(4, 239696, 1) ] lm lp:i32[4,239696,2] = concatenate[dimension=2] lo ln lq:f32[4,239696,1,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(2, 3), collapsed_slice_dims=(), start_index_map=(0, 1)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 1) unique_indices=False ] er lp lr:f32[4,239696] = squeeze[dimensions=(2, 3)] lq ls:f32[4,239696] = mul 0.0010000000474974513 lr lt:f32[4,239696] = add lf ls lu:f32[4,239696] = add 0.0 lt lv:f32[4,239696] = mul lu 1.0 lw:f32[1,239696] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 239696) ] ex lx:f32[4,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(4, 1) ] kk ly:f32[4,239696] = add lw lx lz:f32[4,239696] = sub ly lv ma:f32[4,239696] = div lz hg mb:f32[4] mc:f32[4] = custom_jvp_call[ call_jaxpr={ lambda ; md:f32[4,239696]. let me:f32[4] = reduce_max[axes=(1,)] md mf:bool[4] = is_finite me mg:f32[4] = broadcast_in_dim[ broadcast_dimensions=() shape=(4,) ] 0.0 mh:f32[4] = select_n mf mg me mi:f32[4] = stop_gradient mh mj:f32[4,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(4, 1) ] mi mk:f32[4,239696] = sub md mj ml:f32[4,239696] = exp mk mm:f32[4] = reduce_sum[axes=(1,)] ml mn:f32[4] = sign mm mo:f32[4] = abs mm mp:f32[4] = log mo mq:f32[4] = add mp mi in (mq, mn) } jvp_jaxpr_thunk=.memoized at 0x7f8d105720e0> num_consts=0 symbolic_zeros=False ] ma mr:f32[422220] = concatenate[dimension=0] kh mb _:f32[422220] = concatenate[dimension=0] ki mc ms:f32[422220] = mul hg mr mt:bool[422220] = is_finite ey mu:f32[422220] = pjit[name=_where jaxpr=_where2] mt ey 0 mv:f32[422220] = sub ms mu mw:f32[422220] = mul hg fc mx:bool[422220] = is_finite mv my:f32[422220] = pjit[name=_where jaxpr=_where2] mx mv 0 mz:f32[422220] = sub mw my na:f32[422220] = mul 1.0 mz nb:bool[422220] = is_finite ey nc:f32[422220] = pjit[name=_where jaxpr=_where1] nb ey 0.0 nd:f32[422220] = mul 0.0 nc ne:f32[422220] = mul 1.0 na nf:f32[422220] = add nd ne ng:f32[239696] = broadcast_in_dim[ broadcast_dimensions=() shape=(239696,) ] 1.0 nh:f32[239696] = div ng 239696.0 ni:f32[239696] = log nh nj:f32[239696] = broadcast_in_dim[ broadcast_dimensions=() shape=(239696,) ] 1.0 nk:f32[] = reduce_sum[axes=(0,)] nj nl:f32[239696] = div nj nk nm:f32[239696,1] = reshape[ dimensions=None new_sizes=(239696, 1) ] nl nn:f32[239696,2000] = slice[ limit_indices=(239696, 2000) start_indices=(0, 0) strides=None ] ep no:i32[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) ] 2000 np:f32[239696,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(1,)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(239696, 1) unique_indices=True ] ep no nq:f32[239696] = squeeze[dimensions=(1,)] np nr:i32[239696] = convert_element_type[ new_dtype=int32 weak_type=False ] nq ns:f32[422220,2000] = slice[ limit_indices=(422220, 2000) start_indices=(0, 0) strides=None ] eq nt:i32[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) ] 2000 nu:f32[422220,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(1,)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(422220, 1) unique_indices=True ] eq nt nv:f32[422220] = squeeze[dimensions=(1,)] nu nw:i32[422220] = convert_element_type[ new_dtype=int32 weak_type=False ] nv nx:f32[239696] = pjit[name=norm jaxpr=norm] nn ny:f32[422220] = pjit[name=norm jaxpr=norm1] ns nz:f32[239696,422220] = dot_general[ dimension_numbers=(([1], [1]), ([], [])) preferred_element_type=float32 ] nn ns oa:f32[1,239696] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 239696) ] nx ob:f32[422220,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(422220, 1) ] ny oc:f32[422220,239696] = mul oa ob od:f32[422220,239696] = add oc 9.99999993922529e-09 oe:f32[422220,239696] = transpose[permutation=(1, 0)] nz of:f32[422220,239696] = div oe od og:f32[422220,239696] = sub 1.0 of oh:f32[422220,239696] = mul 1.0 og oi:bool[239696] = lt nr 0 oj:i32[239696] = add nr 30 ok:i32[239696] = select_n oi nr oj ol:bool[422220] = lt nw 0 om:i32[422220] = add nw 42 on:i32[422220] = select_n ol nw om oo:i32[239696,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(239696, 1) ] ok op:i32[422220,239696,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(422220, 239696, 1) ] on oq:i32[422220,239696,1] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(422220, 239696, 1) ] oo or:i32[422220,239696,2] = concatenate[dimension=2] oq op os:f32[422220,239696,1,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(2, 3), collapsed_slice_dims=(), start_index_map=(0, 1)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 1) unique_indices=False ] er or ot:f32[422220,239696] = squeeze[dimensions=(2, 3)] os ou:f32[422220,239696] = mul 0.0010000000474974513 ot ov:f32[422220,239696] = add oh ou ow:f32[422220,239696] = add 0.0 ov ox:f32[422220,239696] = mul ow 1.0 oy:f32[422220,1] = dot_general[ dimension_numbers=(([1], [0]), ([], [])) preferred_element_type=float32 ] ox nm oz:f32[422220] = squeeze[dimensions=(1,)] oy pa:f32[422220] = broadcast_in_dim[ broadcast_dimensions=() shape=(422220,) ] 1.0 pb:f32[] = reduce_sum[axes=(0,)] pa pc:f32[422220] = div pa pb pd:f32[422220] = mul oz pc pe:f32[] = reduce_sum[axes=(0,)] pd pf:f32[] = stop_gradient pe pg:f32[] = min es 1.0 _:f32[] = convert_element_type[ new_dtype=float32 weak_type=True ] ev ph:f32[] = pow pg ev pi:f32[] = mul et ph pj:f32[] = max pi 1.0 pk:f32[] = mul pf 0.05000000074505806 pl:f32[] = convert_element_type[ new_dtype=float32 weak_type=False ] pj pm:f32[] = mul pl pk pn:i32[29962] = iota[ dimension=0 dtype=int32 shape=(29962,) ] _:f32[239696] _:f32[422220] _:f32[] po:f32[29962,8] pp:f32[29962,8] = scan[ jaxpr={ lambda ; pq:f32[239696,2001] pr:f32[422220,2001] ps:f32[30,42] pt:f32[239696] pu:f32[422220] pv:f32[] pw:i32[]. let px:i32[] = mul pw 8 py:bool[] = lt px 0 pz:i32[] = add px 239696 qa:i32[] = select_n py px pz qb:f32[8,2001] = dynamic_slice[ slice_sizes=(8, 2001) ] pq qa 0 qc:i32[] = mul pw 8 qd:bool[] = lt qc 0 qe:i32[] = add qc 239696 qf:i32[] = select_n qd qc qe qg:f32[8] = dynamic_slice[slice_sizes=(8,)] pt qf qh:f32[422220,2000] = slice[ limit_indices=(422220, 2000) start_indices=(0, 0) strides=None ] pr qi:i32[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) ] 2000 qj:f32[422220,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(1,)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(422220, 1) unique_indices=True ] pr qi qk:f32[422220] = squeeze[dimensions=(1,)] qj ql:i32[422220] = convert_element_type[ new_dtype=int32 weak_type=False ] qk qm:f32[8,2000] = slice[ limit_indices=(8, 2000) start_indices=(0, 0) strides=None ] qb qn:i32[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) ] 2000 qo:f32[8,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(1,)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(8, 1) unique_indices=True ] qb qn qp:f32[8] = squeeze[dimensions=(1,)] qo qq:i32[8] = convert_element_type[ new_dtype=int32 weak_type=False ] qp qr:f32[422220] = pjit[name=norm jaxpr=norm1] qh qs:f32[8] = pjit[name=norm jaxpr=norm3] qm qt:f32[422220,8] = dot_general[ dimension_numbers=(([1], [1]), ([], [])) preferred_element_type=float32 ] qh qm qu:f32[1,422220] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 422220) ] qr qv:f32[8,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(8, 1) ] qs qw:f32[8,422220] = mul qu qv qx:f32[8,422220] = add qw 9.99999993922529e-09 qy:f32[8,422220] = transpose[permutation=(1, 0)] qt qz:f32[8,422220] = div qy qx ra:f32[8,422220] = sub 1.0 qz rb:f32[8,422220] = mul 1.0 ra rc:bool[422220] = lt ql 0 rd:i32[422220] = add ql 30 re:i32[422220] = select_n rc ql rd rf:bool[8] = lt qq 0 rg:i32[8] = add qq 42 rh:i32[8] = select_n rf qq rg ri:i32[422220,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(422220, 1) ] re rj:i32[8,422220,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(8, 422220, 1) ] rh rk:i32[8,422220,1] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(8, 422220, 1) ] ri rl:i32[8,422220,2] = concatenate[dimension=2] rk rj rm:f32[8,422220,1,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(2, 3), collapsed_slice_dims=(), start_index_map=(0, 1)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 1) unique_indices=False ] ps rl rn:f32[8,422220] = squeeze[dimensions=(2, 3)] rm ro:f32[8,422220] = mul 0.0010000000474974513 rn rp:f32[8,422220] = add rb ro rq:f32[8,422220] = add 0.0 rp rr:f32[8,422220] = mul rq 1.0 rs:f32[1,422220] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 422220) ] pu rt:f32[8,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(8, 1) ] qg ru:f32[8,422220] = add rs rt rv:f32[8,422220] = sub ru rr rw:f32[8,422220] = div rv pv rx:f32[8] ry:f32[8] = custom_jvp_call[ call_jaxpr={ lambda ; rz:f32[8,422220]. let sa:f32[8] = reduce_max[axes=(1,)] rz sb:bool[8] = is_finite sa sc:f32[8] = broadcast_in_dim[ broadcast_dimensions=() shape=(8,) ] 0.0 sd:f32[8] = select_n sb sc sa se:f32[8] = stop_gradient sd sf:f32[8,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(8, 1) ] se sg:f32[8,422220] = sub rz sf sh:f32[8,422220] = exp sg si:f32[8] = reduce_sum[axes=(1,)] sh sj:f32[8] = sign si sk:f32[8] = abs si sl:f32[8] = log sk sm:f32[8] = add sl se in (sm, sj) } jvp_jaxpr_thunk=.memoized at 0x7f8d10572560> num_consts=0 symbolic_zeros=False ] rw in (pt, pu, pv, rx, ry) } length=29962 linear=(False, False, False, False, False, False, False) num_carry=3 num_consts=3 reverse=False unroll=1 ] ep eq er ex nf pm pn sn:f32[239696] = reshape[ dimensions=None new_sizes=(239696,) ] po so:f32[239696] = reshape[ dimensions=None new_sizes=(239696,) ] pp sp:f32[0,2001] = slice[ limit_indices=(239696, 2001) start_indices=(239696, 0) strides=None ] ep sq:f32[0] = slice[ limit_indices=(239696,) start_indices=(239696,) strides=None ] ex sr:f32[422220,2000] = slice[ limit_indices=(422220, 2000) start_indices=(0, 0) strides=None ] eq ss:i32[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) ] 2000 st:f32[422220,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(1,)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(422220, 1) unique_indices=True ] eq ss su:f32[422220] = squeeze[dimensions=(1,)] st sv:i32[422220] = convert_element_type[ new_dtype=int32 weak_type=False ] su sw:f32[0,2000] = slice[ limit_indices=(0, 2000) start_indices=(0, 0) strides=None ] sp sx:i32[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) ] 2000 sy:f32[0,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(1,)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(0, 1) unique_indices=True ] sp sx sz:f32[0] = squeeze[dimensions=(1,)] sy ta:i32[0] = convert_element_type[ new_dtype=int32 weak_type=False ] sz tb:f32[422220] = pjit[name=norm jaxpr=norm1] sr tc:f32[0] = pjit[ name=norm jaxpr={ lambda ; td:f32[0,2000]. let te:f32[0,2000] = mul td td tf:f32[0] = reduce_sum[axes=(1,)] te tg:f32[0] = sqrt tf in (tg,) } ] sw th:f32[422220,0] = dot_general[ dimension_numbers=(([1], [1]), ([], [])) preferred_element_type=float32 ] sr sw ti:f32[1,422220] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 422220) ] tb tj:f32[0,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(0, 1) ] tc tk:f32[0,422220] = mul ti tj tl:f32[0,422220] = add tk 9.99999993922529e-09 tm:f32[0,422220] = transpose[permutation=(1, 0)] th tn:f32[0,422220] = div tm tl to:f32[0,422220] = sub 1.0 tn tp:f32[0,422220] = mul 1.0 to tq:bool[422220] = lt sv 0 tr:i32[422220] = add sv 30 ts:i32[422220] = select_n tq sv tr tt:bool[0] = lt ta 0 tu:i32[0] = add ta 42 tv:i32[0] = select_n tt ta tu tw:i32[422220,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(422220, 1) ] ts tx:i32[0,422220,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(0, 422220, 1) ] tv ty:i32[0,422220,1] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(0, 422220, 1) ] tw tz:i32[0,422220,2] = concatenate[dimension=2] ty tx ua:f32[0,422220,1,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(2, 3), collapsed_slice_dims=(), start_index_map=(0, 1)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 1) unique_indices=False ] er tz ub:f32[0,422220] = squeeze[dimensions=(2, 3)] ua uc:f32[0,422220] = mul 0.0010000000474974513 ub ud:f32[0,422220] = add tp uc ue:f32[0,422220] = add 0.0 ud uf:f32[0,422220] = mul ue 1.0 ug:f32[1,422220] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 422220) ] nf uh:f32[0,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(0, 1) ] sq ui:f32[0,422220] = add ug uh uj:f32[0,422220] = sub ui uf uk:f32[0,422220] = div uj pm ul:f32[0] um:f32[0] = custom_jvp_call[ call_jaxpr={ lambda ; un:f32[0,422220]. let uo:f32[0] = reduce_max[axes=(1,)] un up:bool[0] = is_finite uo uq:f32[0] = broadcast_in_dim[ broadcast_dimensions=() shape=(0,) ] 0.0 ur:f32[0] = select_n up uq uo us:f32[0] = stop_gradient ur ut:f32[0,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(0, 1) ] us uu:f32[0,422220] = sub un ut uv:f32[0,422220] = exp uu uw:f32[0] = reduce_sum[axes=(1,)] uv ux:f32[0] = sign uw uy:f32[0] = abs uw uz:f32[0] = log uy va:f32[0] = add uz us in (va, ux) } jvp_jaxpr_thunk=.memoized at 0x7f8d10572d40> num_consts=0 symbolic_zeros=False ] uk vb:f32[239696] = concatenate[dimension=0] sn ul _:f32[239696] = concatenate[dimension=0] so um vc:f32[239696] = mul pm vb vd:bool[239696] = is_finite ex ve:f32[239696] = pjit[name=_where jaxpr=_where4] vd ex 0 vf:f32[239696] = sub vc ve vg:f32[239696] = mul pm ni vh:bool[239696] = is_finite vf vi:f32[239696] = pjit[name=_where jaxpr=_where4] vh vf 0 vj:f32[239696] = sub vg vi vk:f32[239696] = mul 1.0 vj vl:bool[239696] = is_finite ex vm:f32[239696] = pjit[name=_where jaxpr=_where] vl ex 0.0 vn:f32[239696] = mul 0.0 vm vo:f32[239696] = mul 1.0 vk vp:f32[239696] = add vn vo vq:bool[] = eq ev 1999 vr:bool[] = ge ev 0 vs:bool[] = convert_element_type[ new_dtype=bool weak_type=False ] vr vt:bool[] = and ez vs vu:bool[] = convert_element_type[ new_dtype=bool weak_type=False ] vq vv:bool[] = or vu vt vw:i32[] = convert_element_type[ new_dtype=int32 weak_type=False ] vv vx:f32[] = cond[ branches=( { lambda ; vy_:f32[30,42] vz_:i32[1] wa:f32[200,1] wb:f32[239696] wc:f32[422220] wd:f32[239696,2001] we:f32[422220,2001] wf:f32[] wg:f32[]. let in (inf,) } { lambda ; wh:f32[30,42] wi:i32[1] wj:f32[200,1] wk:f32[239696] wl:f32[422220] wm:f32[239696,2001] wn:f32[422220,2001] wo:f32[] wp:f32[]. let wq:f32[422220] = broadcast_in_dim[ broadcast_dimensions=() shape=(422220,) ] 1.0 wr:f32[422220] = div wq 422220.0 ws:f32[239696] = broadcast_in_dim[ broadcast_dimensions=() shape=(239696,) ] 1.0 wt:f32[] = reduce_sum[axes=(0,)] ws wu:f32[239696] = div ws wt wv:f32[239696,1] = reshape[ dimensions=None new_sizes=(239696, 1) ] wu ww:f32[239696,2000] = slice[ limit_indices=(239696, 2000) start_indices=(0, 0) strides=None ] wm wx:i32[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) ] 2000 wy:f32[239696,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(1,)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(239696, 1) unique_indices=True ] wm wx wz:f32[239696] = squeeze[dimensions=(1,)] wy xa:i32[239696] = convert_element_type[ new_dtype=int32 weak_type=False ] wz xb:f32[422220,2000] = slice[ limit_indices=(422220, 2000) start_indices=(0, 0) strides=None ] wn xc:i32[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) ] 2000 xd:f32[422220,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(1,)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(422220, 1) unique_indices=True ] wn xc xe:f32[422220] = squeeze[dimensions=(1,)] xd xf:i32[422220] = convert_element_type[ new_dtype=int32 weak_type=False ] xe xg:f32[239696] = pjit[name=norm jaxpr=norm] ww xh:f32[422220] = pjit[name=norm jaxpr=norm1] xb xi:f32[239696,422220] = dot_general[ dimension_numbers=(([1], [1]), ([], [])) preferred_element_type=float32 ] ww xb xj:f32[1,239696] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 239696) ] xg xk:f32[422220,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(422220, 1) ] xh xl:f32[422220,239696] = mul xj xk xm:f32[422220,239696] = add xl 9.99999993922529e-09 xn:f32[422220,239696] = transpose[ permutation=(1, 0) ] xi xo:f32[422220,239696] = div xn xm xp:f32[422220,239696] = sub 1.0 xo xq:f32[422220,239696] = mul 1.0 xp xr:bool[239696] = lt xa 0 xs:i32[239696] = add xa 30 xt:i32[239696] = select_n xr xa xs xu:bool[422220] = lt xf 0 xv:i32[422220] = add xf 42 xw:i32[422220] = select_n xu xf xv xx:i32[239696,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(239696, 1) ] xt xy:i32[422220,239696,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(422220, 239696, 1) ] xw xz:i32[422220,239696,1] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(422220, 239696, 1) ] xx ya:i32[422220,239696,2] = concatenate[dimension=2] xz xy yb:f32[422220,239696,1,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(2, 3), collapsed_slice_dims=(), start_index_map=(0, 1)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 1) unique_indices=False ] wh ya yc:f32[422220,239696] = squeeze[dimensions=(2, 3)] yb yd:f32[422220,239696] = mul 0.0010000000474974513 yc ye:f32[422220,239696] = add xq yd yf:f32[422220,239696] = add 0.0 ye yg:f32[422220,239696] = mul yf 1.0 yh:f32[422220,1] = dot_general[ dimension_numbers=(([1], [0]), ([], [])) preferred_element_type=float32 ] yg wv yi:f32[422220] = squeeze[dimensions=(1,)] yh yj:f32[422220] = broadcast_in_dim[ broadcast_dimensions=() shape=(422220,) ] 1.0 yk:f32[] = reduce_sum[axes=(0,)] yj yl:f32[422220] = div yj yk ym:f32[422220] = mul yi yl yn:f32[] = reduce_sum[axes=(0,)] ym yo:f32[] = stop_gradient yn yp:f32[] = mul yo 0.05000000074505806 yq:i32[52777] = iota[ dimension=0 dtype=int32 shape=(52777,) ] _:f32[239696] _:f32[422220] _:f32[] yr:f32[52777,8] ys:f32[52777,8] = scan[ jaxpr={ lambda ; yt:f32[422220,2001] yu:f32[239696,2001] yv:f32[30,42] yw:f32[239696] yx:f32[422220] yy:f32[] yz:i32[]. let za:i32[] = mul yz 8 zb:bool[] = lt za 0 zc:i32[] = add za 422220 zd:i32[] = select_n zb za zc ze:f32[8,2001] = dynamic_slice[ slice_sizes=(8, 2001) ] yt zd 0 zf:i32[] = mul yz 8 zg:bool[] = lt zf 0 zh:i32[] = add zf 422220 zi:i32[] = select_n zg zf zh zj:f32[8] = dynamic_slice[slice_sizes=(8,)] yx zi zk:f32[239696,2000] = slice[ limit_indices=(239696, 2000) start_indices=(0, 0) strides=None ] yu zl:i32[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) ] 2000 zm:f32[239696,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(1,)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(239696, 1) unique_indices=True ] yu zl zn:f32[239696] = squeeze[dimensions=(1,)] zm zo:i32[239696] = convert_element_type[ new_dtype=int32 weak_type=False ] zn zp:f32[8,2000] = slice[ limit_indices=(8, 2000) start_indices=(0, 0) strides=None ] ze zq:i32[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) ] 2000 zr:f32[8,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(1,)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(8, 1) unique_indices=True ] ze zq zs:f32[8] = squeeze[dimensions=(1,)] zr zt:i32[8] = convert_element_type[ new_dtype=int32 weak_type=False ] zs zu:f32[239696] = pjit[name=norm jaxpr=norm] zk zv:f32[8] = pjit[name=norm jaxpr=norm3] zp zw:f32[239696,8] = dot_general[ dimension_numbers=(([1], [1]), ([], [])) preferred_element_type=float32 ] zk zp zx:f32[1,239696] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 239696) ] zu zy:f32[8,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(8, 1) ] zv zz:f32[8,239696] = mul zx zy baa:f32[8,239696] = add zz 9.99999993922529e-09 bab:f32[8,239696] = transpose[ permutation=(1, 0) ] zw bac:f32[8,239696] = div bab baa bad:f32[8,239696] = sub 1.0 bac bae:f32[8,239696] = mul 1.0 bad baf:bool[239696] = lt zo 0 bag:i32[239696] = add zo 30 bah:i32[239696] = select_n baf zo bag bai:bool[8] = lt zt 0 baj:i32[8] = add zt 42 bak:i32[8] = select_n bai zt baj bal:i32[239696,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(239696, 1) ] bah bam:i32[8,239696,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(8, 239696, 1) ] bak ban:i32[8,239696,1] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(8, 239696, 1) ] bal bao:i32[8,239696,2] = concatenate[ dimension=2 ] ban bam bap:f32[8,239696,1,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(2, 3), collapsed_slice_dims=(), start_index_map=(0, 1)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 1) unique_indices=False ] yv bao baq:f32[8,239696] = squeeze[ dimensions=(2, 3) ] bap bar:f32[8,239696] = mul 0.0010000000474974513 baq bas:f32[8,239696] = add bae bar bat:f32[8,239696] = add 0.0 bas bau:f32[8,239696] = mul bat 1.0 bav:f32[1,239696] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 239696) ] yw baw:f32[8,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(8, 1) ] zj bax:f32[8,239696] = add bav baw bay:f32[8,239696] = sub bax bau baz:f32[8,239696] = div bay yy bba:f32[8] bbb:f32[8] = custom_jvp_call[ call_jaxpr={ lambda ; bbc:f32[8,239696]. let bbd:f32[8] = reduce_max[axes=(1,)] bbc bbe:bool[8] = is_finite bbd bbf:f32[8] = broadcast_in_dim[ broadcast_dimensions=() shape=(8,) ] 0.0 bbg:f32[8] = select_n bbe bbf bbd bbh:f32[8] = stop_gradient bbg bbi:f32[8,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(8, 1) ] bbh bbj:f32[8,239696] = sub bbc bbi bbk:f32[8,239696] = exp bbj bbl:f32[8] = reduce_sum[axes=(1,)] bbk bbm:f32[8] = sign bbl bbn:f32[8] = abs bbl bbo:f32[8] = log bbn bbp:f32[8] = add bbo bbh in (bbp, bbm) } jvp_jaxpr_thunk=.memoized at 0x7f8d10573ac0> num_consts=0 symbolic_zeros=False ] baz in (yw, yx, yy, bba, bbb) } length=52777 linear=(False, False, False, False, False, False, False) num_carry=3 num_consts=3 reverse=False unroll=1 ] wn wm wh wk wl yp yq bbq:f32[422216] = reshape[ dimensions=None new_sizes=(422216,) ] yr bbr:f32[422216] = reshape[ dimensions=None new_sizes=(422216,) ] ys bbs:f32[4,2001] = slice[ limit_indices=(422220, 2001) start_indices=(422216, 0) strides=None ] wn bbt:f32[4] = slice[ limit_indices=(422220,) start_indices=(422216,) strides=None ] wl bbu:f32[239696,2000] = slice[ limit_indices=(239696, 2000) start_indices=(0, 0) strides=None ] wm bbv:i32[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) ] 2000 bbw:f32[239696,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(1,)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(239696, 1) unique_indices=True ] wm bbv bbx:f32[239696] = squeeze[dimensions=(1,)] bbw bby:i32[239696] = convert_element_type[ new_dtype=int32 weak_type=False ] bbx bbz:f32[4,2000] = slice[ limit_indices=(4, 2000) start_indices=(0, 0) strides=None ] bbs bca:i32[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) ] 2000 bcb:f32[4,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(1,)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(4, 1) unique_indices=True ] bbs bca bcc:f32[4] = squeeze[dimensions=(1,)] bcb bcd:i32[4] = convert_element_type[ new_dtype=int32 weak_type=False ] bcc bce:f32[239696] = pjit[name=norm jaxpr=norm] bbu bcf:f32[4] = pjit[name=norm jaxpr=norm2] bbz bcg:f32[239696,4] = dot_general[ dimension_numbers=(([1], [1]), ([], [])) preferred_element_type=float32 ] bbu bbz bch:f32[1,239696] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 239696) ] bce bci:f32[4,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(4, 1) ] bcf bcj:f32[4,239696] = mul bch bci bck:f32[4,239696] = add bcj 9.99999993922529e-09 bcl:f32[4,239696] = transpose[permutation=(1, 0)] bcg bcm:f32[4,239696] = div bcl bck bcn:f32[4,239696] = sub 1.0 bcm bco:f32[4,239696] = mul 1.0 bcn bcp:bool[239696] = lt bby 0 bcq:i32[239696] = add bby 30 bcr:i32[239696] = select_n bcp bby bcq bcs:bool[4] = lt bcd 0 bct:i32[4] = add bcd 42 bcu:i32[4] = select_n bcs bcd bct bcv:i32[239696,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(239696, 1) ] bcr bcw:i32[4,239696,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(4, 239696, 1) ] bcu bcx:i32[4,239696,1] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(4, 239696, 1) ] bcv bcy:i32[4,239696,2] = concatenate[dimension=2] bcx bcw bcz:f32[4,239696,1,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(2, 3), collapsed_slice_dims=(), start_index_map=(0, 1)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 1) unique_indices=False ] wh bcy bda:f32[4,239696] = squeeze[dimensions=(2, 3)] bcz bdb:f32[4,239696] = mul 0.0010000000474974513 bda bdc:f32[4,239696] = add bco bdb bdd:f32[4,239696] = add 0.0 bdc bde:f32[4,239696] = mul bdd 1.0 bdf:f32[1,239696] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 239696) ] wk bdg:f32[4,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(4, 1) ] bbt bdh:f32[4,239696] = add bdf bdg bdi:f32[4,239696] = sub bdh bde bdj:f32[4,239696] = div bdi yp bdk:f32[4] bdl:f32[4] = custom_jvp_call[ call_jaxpr={ lambda ; bdm:f32[4,239696]. let bdn:f32[4] = reduce_max[axes=(1,)] bdm bdo:bool[4] = is_finite bdn bdp:f32[4] = broadcast_in_dim[ broadcast_dimensions=() shape=(4,) ] 0.0 bdq:f32[4] = select_n bdo bdp bdn bdr:f32[4] = stop_gradient bdq bds:f32[4,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(4, 1) ] bdr bdt:f32[4,239696] = sub bdm bds bdu:f32[4,239696] = exp bdt bdv:f32[4] = reduce_sum[axes=(1,)] bdu bdw:f32[4] = sign bdv bdx:f32[4] = abs bdv bdy:f32[4] = log bdx bdz:f32[4] = add bdy bdr in (bdz, bdw) } jvp_jaxpr_thunk=.memoized at 0x7f8d10573a30> num_consts=0 symbolic_zeros=False ] bdj bea:f32[422220] = concatenate[dimension=0] bbq bdk _:f32[422220] = concatenate[dimension=0] bbr bdl beb:f32[422220] = mul yp bea bec:bool[422220] = is_finite wl bed:f32[422220] = pjit[name=_where jaxpr=_where2] bec wl 0 bee:f32[422220] = sub beb bed bef:f32[422220] = add bee wl beg:f32[239696] = broadcast_in_dim[ broadcast_dimensions=() shape=(239696,) ] 1.0 beh:f32[] = reduce_sum[axes=(0,)] beg bei:f32[239696] = div beg beh bej:f32[239696,1] = reshape[ dimensions=None new_sizes=(239696, 1) ] bei bek:f32[239696,2000] = slice[ limit_indices=(239696, 2000) start_indices=(0, 0) strides=None ] wm bel:i32[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) ] 2000 bem:f32[239696,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(1,)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(239696, 1) unique_indices=True ] wm bel ben:f32[239696] = squeeze[dimensions=(1,)] bem beo:i32[239696] = convert_element_type[ new_dtype=int32 weak_type=False ] ben bep:f32[422220,2000] = slice[ limit_indices=(422220, 2000) start_indices=(0, 0) strides=None ] wn beq:i32[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) ] 2000 ber:f32[422220,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(1,)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(422220, 1) unique_indices=True ] wn beq bes:f32[422220] = squeeze[dimensions=(1,)] ber bet:i32[422220] = convert_element_type[ new_dtype=int32 weak_type=False ] bes beu:f32[239696] = pjit[name=norm jaxpr=norm] bek bev:f32[422220] = pjit[name=norm jaxpr=norm1] bep bew:f32[239696,422220] = dot_general[ dimension_numbers=(([1], [1]), ([], [])) preferred_element_type=float32 ] bek bep bex:f32[1,239696] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 239696) ] beu bey:f32[422220,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(422220, 1) ] bev bez:f32[422220,239696] = mul bex bey bfa:f32[422220,239696] = add bez 9.99999993922529e-09 bfb:f32[422220,239696] = transpose[ permutation=(1, 0) ] bew bfc:f32[422220,239696] = div bfb bfa bfd:f32[422220,239696] = sub 1.0 bfc bfe:f32[422220,239696] = mul 1.0 bfd bff:bool[239696] = lt beo 0 bfg:i32[239696] = add beo 30 bfh:i32[239696] = select_n bff beo bfg bfi:bool[422220] = lt bet 0 bfj:i32[422220] = add bet 42 bfk:i32[422220] = select_n bfi bet bfj bfl:i32[239696,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(239696, 1) ] bfh bfm:i32[422220,239696,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(422220, 239696, 1) ] bfk bfn:i32[422220,239696,1] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(422220, 239696, 1) ] bfl bfo:i32[422220,239696,2] = concatenate[ dimension=2 ] bfn bfm bfp:f32[422220,239696,1,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(2, 3), collapsed_slice_dims=(), start_index_map=(0, 1)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 1) unique_indices=False ] wh bfo bfq:f32[422220,239696] = squeeze[ dimensions=(2, 3) ] bfp bfr:f32[422220,239696] = mul 0.0010000000474974513 bfq bfs:f32[422220,239696] = add bfe bfr bft:f32[422220,239696] = add 0.0 bfs bfu:f32[422220,239696] = mul bft 1.0 bfv:f32[422220,1] = dot_general[ dimension_numbers=(([1], [0]), ([], [])) preferred_element_type=float32 ] bfu bej bfw:f32[422220] = squeeze[dimensions=(1,)] bfv bfx:f32[422220] = broadcast_in_dim[ broadcast_dimensions=() shape=(422220,) ] 1.0 bfy:f32[] = reduce_sum[axes=(0,)] bfx bfz:f32[422220] = div bfx bfy bga:f32[422220] = mul bfw bfz bgb:f32[] = reduce_sum[axes=(0,)] bga bgc:f32[] = stop_gradient bgb bgd:f32[] = mul bgc 0.05000000074505806 bge:f32[422220] = div bef bgd bgf:f32[422220] = exp bge bgg:f32[422220] = sub bgf wr bgh:f32[422220] = abs bgg bgi:i32[1,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(1, 1) ] wi bgj:f32[1,422220] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 422220) ] bgh _:f32[1,1] = convert_element_type[ new_dtype=float32 weak_type=False ] bgi bgk:f32[1,422220] = pow bgj bgi bgl:f32[1] = reduce_sum[axes=(1,)] bgk bgm:f32[1] = convert_element_type[ new_dtype=float32 weak_type=True ] wi bgn:f32[1] = div 1.0 bgm bgo:f32[1] = convert_element_type[ new_dtype=float32 weak_type=False ] bgn bgp:f32[1] = pow bgl bgo bgq:f32[1] = slice[ limit_indices=(1,) start_indices=(0,) strides=None ] bgp bgr:f32[] = squeeze[dimensions=(0,)] bgq in (bgr,) } ) linear=(False, False, False, False, False, False, False, False, False) ] vw er eu ew vp nf ep eq et es bgs:i32[] = pjit[name=floor_divide jaxpr=floor_divide] ev 10 bgt:bool[] = lt bgs 0 bgu:i32[] = add bgs 200 bgv:i32[] = select_n bgt bgs bgu bgw:i32[] = convert_element_type[ new_dtype=int32 weak_type=False ] bgv bgx:i32[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) ] bgw bgy:f32[] = convert_element_type[ new_dtype=float32 weak_type=False ] vx bgz:f32[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) ] bgy bha:f32[200,1] = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True mode=GatherScatterMode.FILL_OR_DROP unique_indices=True update_consts=() update_jaxpr=None ] ew bgx bgz bhb:i32[] = add ev 1 in (bhb, bha, vp, nf) } length=10 linear=(False, False, False, False, False, False, False, False, False, False, False) num_carry=4 num_consts=6 reverse=False unroll=1 ] ea eb ec ed ee ef eh ei ej ek eg in (el, em, en, eo) } body_nconsts=7 cond_jaxpr={ lambda ; bhc:f32[] bhd:i32[] bhe:f32[200,1] bhf:f32[239696] bhg:f32[422220]. let bhh:bool[] = lt bhd 2000 bhi:bool[] = lt bhd 0 bhj:i32[] = pjit[name=floor_divide jaxpr=floor_divide] bhd 10 bhk:i32[] = sub bhj 1 bhl:bool[] = lt bhk 0 bhm:i32[] = convert_element_type[ new_dtype=int32 weak_type=False ] bhk bhn:i32[] = add bhm 200 bho:i32[] = select_n bhl bhk bhn bhp:bool[] = lt 0 0 bhq:i32[] = add 0 1 bhr:i32[] = select_n bhp 0 bhq bhs:f32[1,1] = dynamic_slice[slice_sizes=(1, 1)] bhe bho bhr bht:f32[] = squeeze[dimensions=(0, 1)] bhs bhu:bool[] = is_finite bht bhv:bool[] = not bhu bhw:bool[] = not bhv bhx:i32[] = pjit[name=floor_divide jaxpr=floor_divide] bhd 10 bhy:i32[] = sub bhx 1 bhz:bool[] = lt bhy 0 bia:i32[] = convert_element_type[ new_dtype=int32 weak_type=False ] bhy bib:i32[] = add bia 200 bic:i32[] = select_n bhz bhy bib bid:bool[] = lt 0 0 bie:i32[] = add 0 1 bif:i32[] = select_n bid 0 bie big:f32[1,1] = dynamic_slice[slice_sizes=(1, 1)] bhe bic bif bih:f32[] = squeeze[dimensions=(0, 1)] big bii:bool[] = gt bhd 0 bij:f32[] = convert_element_type[ new_dtype=float32 weak_type=False ] bhc bik:bool[] = lt bih bij bil:bool[] = convert_element_type[ new_dtype=bool weak_type=False ] bii bim:bool[] = and bil bik bin:bool[] = not bim bio:bool[] = and bhw bin bip:bool[] = convert_element_type[ new_dtype=bool weak_type=False ] bhi biq:bool[] = or bip bio bir:bool[] = convert_element_type[ new_dtype=bool weak_type=False ] bhh bis:bool[] = and bir biq in (bis,) } cond_nconsts=1 ] dq dm dn dk dp do dl dw 0 du dr ds bit:bool[200,1] = ne dx dx biu:bool[] = reduce_or[axes=(0, 1)] bit biv:bool[] = not biu biw:bool[] = lt -1 0 bix:i32[] = add -1 200 biy:i32[] = select_n biw -1 bix biz:bool[] = lt 0 0 bja:i32[] = add 0 1 bjb:i32[] = select_n biz 0 bja bjc:f32[1,1] = dynamic_slice[slice_sizes=(1, 1)] dx biy bjb bjd:f32[1] = squeeze[dimensions=(0,)] bjc bje:f32[] = convert_element_type[new_dtype=float32 weak_type=False] dq bjf:bool[1] = lt bjd bje bjg:bool[1] = and biv bjf bjh:bool[1] = slice[ limit_indices=(1,) start_indices=(0,) strides=None ] bjg bji:bool[] = squeeze[dimensions=(0,)] bjh bjj:f32[200,1] = slice[ limit_indices=(200, 1) start_indices=(0, 0) strides=None ] dx bjk:f32[200] = squeeze[dimensions=(1,)] bjj bjl:f32[] = copy dq in (dy, dz, bjk, bjl, bji, 10) } fwd_jaxpr_thunk=.memoized at 0x7f8d1012d120> num_consts=2 out_trees=. at 0x7f8d184e7640> symbolic_zeros=False ] cf cg ch ci cj ck 0.001 cz dd bjm:f32[239696] = stop_gradient de bjn:f32[422220] = stop_gradient df bjo:f32[239696] = broadcast_in_dim[broadcast_dimensions=() shape=(239696,)] 1.0 bjp:f32[239696] = div bjo 239696.0 bjq:bool[239696] = gt bjp 0.0 bjr:f32[422220] = broadcast_in_dim[broadcast_dimensions=() shape=(422220,)] 1.0 bjs:f32[422220] = div bjr 422220.0 bjt:bool[422220] = gt bjs 0.0 bju:f32[239696] = broadcast_in_dim[broadcast_dimensions=() shape=(239696,)] 1.0 bjv:f32[239696] = div bju 239696.0 bjw:f32[239696] = broadcast_in_dim[broadcast_dimensions=() shape=(239696,)] 1.0 bjx:f32[] = reduce_sum[axes=(0,)] bjw bjy:f32[239696] = div bjw bjx bjz:f32[239696,1] = reshape[dimensions=None new_sizes=(239696, 1)] bjy bka:f32[239696,2000] = slice[ limit_indices=(239696, 2000) start_indices=(0, 0) strides=None ] ch bkb:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2000 bkc:f32[239696,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(1,)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(239696, 1) unique_indices=True ] ch bkb bkd:f32[239696] = squeeze[dimensions=(1,)] bkc bke:i32[239696] = convert_element_type[new_dtype=int32 weak_type=False] bkd bkf:f32[422220,2000] = slice[ limit_indices=(422220, 2000) start_indices=(0, 0) strides=None ] ci bkg:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2000 bkh:f32[422220,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(1,)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(422220, 1) unique_indices=True ] ci bkg bki:f32[422220] = squeeze[dimensions=(1,)] bkh bkj:i32[422220] = convert_element_type[new_dtype=int32 weak_type=False] bki bkk:f32[239696] = pjit[name=norm jaxpr=norm] bka bkl:f32[422220] = pjit[name=norm jaxpr=norm1] bkf bkm:f32[239696,422220] = dot_general[ dimension_numbers=(([1], [1]), ([], [])) preferred_element_type=float32 ] bka bkf bkn:f32[1,239696] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 239696) ] bkk bko:f32[422220,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(422220, 1) ] bkl bkp:f32[422220,239696] = mul bkn bko bkq:f32[422220,239696] = add bkp 9.99999993922529e-09 bkr:f32[422220,239696] = transpose[permutation=(1, 0)] bkm bks:f32[422220,239696] = div bkr bkq bkt:f32[422220,239696] = sub 1.0 bks bku:f32[422220,239696] = mul 1.0 bkt bkv:bool[239696] = lt bke 0 bkw:i32[239696] = add bke 30 bkx:i32[239696] = select_n bkv bke bkw bky:bool[422220] = lt bkj 0 bkz:i32[422220] = add bkj 42 bla:i32[422220] = select_n bky bkj bkz blb:i32[239696,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(239696, 1) ] bkx blc:i32[422220,239696,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(422220, 239696, 1) ] bla bld:i32[422220,239696,1] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(422220, 239696, 1) ] blb ble:i32[422220,239696,2] = concatenate[dimension=2] bld blc blf:f32[422220,239696,1,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(2, 3), collapsed_slice_dims=(), start_index_map=(0, 1)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 1) unique_indices=False ] cf ble blg:f32[422220,239696] = squeeze[dimensions=(2, 3)] blf blh:f32[422220,239696] = mul 0.0010000000474974513 blg bli:f32[422220,239696] = add bku blh blj:f32[422220,239696] = add 0.0 bli blk:f32[422220,239696] = mul blj 1.0 bll:f32[422220,1] = dot_general[ dimension_numbers=(([1], [0]), ([], [])) preferred_element_type=float32 ] blk bjz blm:f32[422220] = squeeze[dimensions=(1,)] bll bln:f32[422220] = broadcast_in_dim[broadcast_dimensions=() shape=(422220,)] 1.0 blo:f32[] = reduce_sum[axes=(0,)] bln blp:f32[422220] = div bln blo blq:f32[422220] = mul blm blp blr:f32[] = reduce_sum[axes=(0,)] blq bls:f32[] = stop_gradient blr blt:f32[] = mul bls 0.05000000074505806 blu:f32[239696] = log bjv blv:f32[239696] = mul blt blu blw:f32[239696] = broadcast_in_dim[broadcast_dimensions=() shape=(239696,)] 1.0 blx:f32[239696] = div blw 239696.0 bly:f32[239696] = sub bjm blv blz:f32[239696] = mul blx bly bma:f32[239696] = pjit[name=_where jaxpr=_where] bjq blz 0.0 bmb:f32[] = reduce_sum[axes=(0,)] bma bmc:f32[422220] = broadcast_in_dim[broadcast_dimensions=() shape=(422220,)] 1.0 bmd:f32[422220] = div bmc 422220.0 bme:f32[239696] = broadcast_in_dim[broadcast_dimensions=() shape=(239696,)] 1.0 bmf:f32[] = reduce_sum[axes=(0,)] bme bmg:f32[239696] = div bme bmf bmh:f32[239696,1] = reshape[dimensions=None new_sizes=(239696, 1)] bmg bmi:f32[239696,2000] = slice[ limit_indices=(239696, 2000) start_indices=(0, 0) strides=None ] ch bmj:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2000 bmk:f32[239696,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(1,)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(239696, 1) unique_indices=True ] ch bmj bml:f32[239696] = squeeze[dimensions=(1,)] bmk bmm:i32[239696] = convert_element_type[new_dtype=int32 weak_type=False] bml bmn:f32[422220,2000] = slice[ limit_indices=(422220, 2000) start_indices=(0, 0) strides=None ] ci bmo:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2000 bmp:f32[422220,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(1,)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(422220, 1) unique_indices=True ] ci bmo bmq:f32[422220] = squeeze[dimensions=(1,)] bmp bmr:i32[422220] = convert_element_type[new_dtype=int32 weak_type=False] bmq bms:f32[239696] = pjit[name=norm jaxpr=norm] bmi bmt:f32[422220] = pjit[name=norm jaxpr=norm1] bmn bmu:f32[239696,422220] = dot_general[ dimension_numbers=(([1], [1]), ([], [])) preferred_element_type=float32 ] bmi bmn bmv:f32[1,239696] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 239696) ] bms bmw:f32[422220,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(422220, 1) ] bmt bmx:f32[422220,239696] = mul bmv bmw bmy:f32[422220,239696] = add bmx 9.99999993922529e-09 bmz:f32[422220,239696] = transpose[permutation=(1, 0)] bmu bna:f32[422220,239696] = div bmz bmy bnb:f32[422220,239696] = sub 1.0 bna bnc:f32[422220,239696] = mul 1.0 bnb bnd:bool[239696] = lt bmm 0 bne:i32[239696] = add bmm 30 bnf:i32[239696] = select_n bnd bmm bne bng:bool[422220] = lt bmr 0 bnh:i32[422220] = add bmr 42 bni:i32[422220] = select_n bng bmr bnh bnj:i32[239696,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(239696, 1) ] bnf bnk:i32[422220,239696,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(422220, 239696, 1) ] bni bnl:i32[422220,239696,1] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(422220, 239696, 1) ] bnj bnm:i32[422220,239696,2] = concatenate[dimension=2] bnl bnk bnn:f32[422220,239696,1,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(2, 3), collapsed_slice_dims=(), start_index_map=(0, 1)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 1) unique_indices=False ] cf bnm bno:f32[422220,239696] = squeeze[dimensions=(2, 3)] bnn bnp:f32[422220,239696] = mul 0.0010000000474974513 bno bnq:f32[422220,239696] = add bnc bnp bnr:f32[422220,239696] = add 0.0 bnq bns:f32[422220,239696] = mul bnr 1.0 bnt:f32[422220,1] = dot_general[ dimension_numbers=(([1], [0]), ([], [])) preferred_element_type=float32 ] bns bmh bnu:f32[422220] = squeeze[dimensions=(1,)] bnt bnv:f32[422220] = broadcast_in_dim[broadcast_dimensions=() shape=(422220,)] 1.0 bnw:f32[] = reduce_sum[axes=(0,)] bnv bnx:f32[422220] = div bnv bnw bny:f32[422220] = mul bnu bnx bnz:f32[] = reduce_sum[axes=(0,)] bny boa:f32[] = stop_gradient bnz bob:f32[] = mul boa 0.05000000074505806 boc:f32[422220] = log bmd bod:f32[422220] = mul bob boc boe:f32[422220] = broadcast_in_dim[broadcast_dimensions=() shape=(422220,)] 1.0 bof:f32[422220] = div boe 422220.0 bog:f32[422220] = sub bjn bod boh:f32[422220] = mul bof bog boi:f32[422220] = pjit[name=_where jaxpr=_where1] bjt boh 0.0 boj:f32[] = reduce_sum[axes=(0,)] boi bok:f32[239696] = broadcast_in_dim[broadcast_dimensions=() shape=(239696,)] 1.0 bol:f32[] = reduce_sum[axes=(0,)] bok bom:f32[239696] = div bok bol bon:f32[239696,1] = reshape[dimensions=None new_sizes=(239696, 1)] bom boo:f32[239696,2000] = slice[ limit_indices=(239696, 2000) start_indices=(0, 0) strides=None ] ch bop:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2000 boq:f32[239696,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(1,)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(239696, 1) unique_indices=True ] ch bop bor:f32[239696] = squeeze[dimensions=(1,)] boq bos:i32[239696] = convert_element_type[new_dtype=int32 weak_type=False] bor bot:f32[422220,2000] = slice[ limit_indices=(422220, 2000) start_indices=(0, 0) strides=None ] ci bou:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2000 bov:f32[422220,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(1,)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(422220, 1) unique_indices=True ] ci bou bow:f32[422220] = squeeze[dimensions=(1,)] bov box:i32[422220] = convert_element_type[new_dtype=int32 weak_type=False] bow boy:f32[239696] = pjit[name=norm jaxpr=norm] boo boz:f32[422220] = pjit[name=norm jaxpr=norm1] bot bpa:f32[239696,422220] = dot_general[ dimension_numbers=(([1], [1]), ([], [])) preferred_element_type=float32 ] boo bot bpb:f32[1,239696] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 239696) ] boy bpc:f32[422220,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(422220, 1) ] boz bpd:f32[422220,239696] = mul bpb bpc bpe:f32[422220,239696] = add bpd 9.99999993922529e-09 bpf:f32[422220,239696] = transpose[permutation=(1, 0)] bpa bpg:f32[422220,239696] = div bpf bpe bph:f32[422220,239696] = sub 1.0 bpg bpi:f32[422220,239696] = mul 1.0 bph bpj:bool[239696] = lt bos 0 bpk:i32[239696] = add bos 30 bpl:i32[239696] = select_n bpj bos bpk bpm:bool[422220] = lt box 0 bpn:i32[422220] = add box 42 bpo:i32[422220] = select_n bpm box bpn bpp:i32[239696,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(239696, 1) ] bpl bpq:i32[422220,239696,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(422220, 239696, 1) ] bpo bpr:i32[422220,239696,1] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(422220, 239696, 1) ] bpp bps:i32[422220,239696,2] = concatenate[dimension=2] bpr bpq bpt:f32[422220,239696,1,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(2, 3), collapsed_slice_dims=(), start_index_map=(0, 1)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 1) unique_indices=False ] cf bps bpu:f32[422220,239696] = squeeze[dimensions=(2, 3)] bpt bpv:f32[422220,239696] = mul 0.0010000000474974513 bpu bpw:f32[422220,239696] = add bpi bpv bpx:f32[422220,239696] = add 0.0 bpw bpy:f32[422220,239696] = mul bpx 1.0 bpz:f32[422220,1] = dot_general[ dimension_numbers=(([1], [0]), ([], [])) preferred_element_type=float32 ] bpy bon bqa:f32[422220] = squeeze[dimensions=(1,)] bpz bqb:f32[422220] = broadcast_in_dim[broadcast_dimensions=() shape=(422220,)] 1.0 bqc:f32[] = reduce_sum[axes=(0,)] bqb bqd:f32[422220] = div bqb bqc bqe:f32[422220] = mul bqa bqd bqf:f32[] = reduce_sum[axes=(0,)] bqe bqg:f32[] = stop_gradient bqf bqh:f32[] = mul bqg 0.05000000074505806 bqi:i32[52777] = iota[dimension=0 dtype=int32 shape=(52777,)] _:f32[239696] _:f32[422220] _:f32[] bqj:f32[52777,8] bqk:f32[52777,8] = scan[ jaxpr={ lambda ; bql:f32[422220,2001] bqm:f32[239696,2001] bqn:f32[30,42] bqo:f32[239696] bqp:f32[422220] bqq:f32[] bqr:i32[]. let bqs:i32[] = mul bqr 8 bqt:bool[] = lt bqs 0 bqu:i32[] = add bqs 422220 bqv:i32[] = select_n bqt bqs bqu bqw:f32[8,2001] = dynamic_slice[slice_sizes=(8, 2001)] bql bqv 0 bqx:i32[] = mul bqr 8 bqy:bool[] = lt bqx 0 bqz:i32[] = add bqx 422220 bra:i32[] = select_n bqy bqx bqz brb:f32[8] = dynamic_slice[slice_sizes=(8,)] bqp bra brc:f32[239696,2000] = slice[ limit_indices=(239696, 2000) start_indices=(0, 0) strides=None ] bqm brd:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2000 bre:f32[239696,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(1,)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(239696, 1) unique_indices=True ] bqm brd brf:f32[239696] = squeeze[dimensions=(1,)] bre brg:i32[239696] = convert_element_type[ new_dtype=int32 weak_type=False ] brf brh:f32[8,2000] = slice[ limit_indices=(8, 2000) start_indices=(0, 0) strides=None ] bqw bri:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2000 brj:f32[8,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(1,)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(8, 1) unique_indices=True ] bqw bri brk:f32[8] = squeeze[dimensions=(1,)] brj brl:i32[8] = convert_element_type[new_dtype=int32 weak_type=False] brk brm:f32[239696] = pjit[name=norm jaxpr=norm] brc brn:f32[8] = pjit[name=norm jaxpr=norm3] brh bro:f32[239696,8] = dot_general[ dimension_numbers=(([1], [1]), ([], [])) preferred_element_type=float32 ] brc brh brp:f32[1,239696] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 239696) ] brm brq:f32[8,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(8, 1) ] brn brr:f32[8,239696] = mul brp brq brs:f32[8,239696] = add brr 9.99999993922529e-09 brt:f32[8,239696] = transpose[permutation=(1, 0)] bro bru:f32[8,239696] = div brt brs brv:f32[8,239696] = sub 1.0 bru brw:f32[8,239696] = mul 1.0 brv brx:bool[239696] = lt brg 0 bry:i32[239696] = add brg 30 brz:i32[239696] = select_n brx brg bry bsa:bool[8] = lt brl 0 bsb:i32[8] = add brl 42 bsc:i32[8] = select_n bsa brl bsb bsd:i32[239696,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(239696, 1) ] brz bse:i32[8,239696,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(8, 239696, 1) ] bsc bsf:i32[8,239696,1] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(8, 239696, 1) ] bsd bsg:i32[8,239696,2] = concatenate[dimension=2] bsf bse bsh:f32[8,239696,1,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(2, 3), collapsed_slice_dims=(), start_index_map=(0, 1)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 1) unique_indices=False ] bqn bsg bsi:f32[8,239696] = squeeze[dimensions=(2, 3)] bsh bsj:f32[8,239696] = mul 0.0010000000474974513 bsi bsk:f32[8,239696] = add brw bsj bsl:f32[8,239696] = add 0.0 bsk bsm:f32[8,239696] = mul bsl 1.0 bsn:f32[1,239696] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 239696) ] bqo bso:f32[8,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(8, 1) ] brb bsp:f32[8,239696] = add bsn bso bsq:f32[8,239696] = sub bsp bsm bsr:f32[8,239696] = div bsq bqq bss:f32[8] bst:f32[8] = custom_jvp_call[ call_jaxpr={ lambda ; bsu:f32[8,239696]. let bsv:f32[8] = reduce_max[axes=(1,)] bsu bsw:bool[8] = is_finite bsv bsx:f32[8] = broadcast_in_dim[ broadcast_dimensions=() shape=(8,) ] 0.0 bsy:f32[8] = select_n bsw bsx bsv bsz:f32[8] = stop_gradient bsy bta:f32[8,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(8, 1) ] bsz btb:f32[8,239696] = sub bsu bta btc:f32[8,239696] = exp btb btd:f32[8] = reduce_sum[axes=(1,)] btc bte:f32[8] = sign btd btf:f32[8] = abs btd btg:f32[8] = log btf bth:f32[8] = add btg bsz in (bth, bte) } jvp_jaxpr_thunk=.memoized at 0x7f8d10571cf0> num_consts=0 symbolic_zeros=False ] bsr in (bqo, bqp, bqq, bss, bst) } length=52777 linear=(False, False, False, False, False, False, False) num_carry=3 num_consts=3 reverse=False unroll=1 ] ci ch cf bjm bjn bqh bqi bti:f32[422216] = reshape[dimensions=None new_sizes=(422216,)] bqj btj:f32[422216] = reshape[dimensions=None new_sizes=(422216,)] bqk btk:f32[4,2001] = slice[ limit_indices=(422220, 2001) start_indices=(422216, 0) strides=None ] ci btl:f32[4] = slice[ limit_indices=(422220,) start_indices=(422216,) strides=None ] bjn btm:f32[239696,2000] = slice[ limit_indices=(239696, 2000) start_indices=(0, 0) strides=None ] ch btn:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2000 bto:f32[239696,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(1,)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(239696, 1) unique_indices=True ] ch btn btp:f32[239696] = squeeze[dimensions=(1,)] bto btq:i32[239696] = convert_element_type[new_dtype=int32 weak_type=False] btp btr:f32[4,2000] = slice[ limit_indices=(4, 2000) start_indices=(0, 0) strides=None ] btk bts:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2000 btt:f32[4,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(1,)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(4, 1) unique_indices=True ] btk bts btu:f32[4] = squeeze[dimensions=(1,)] btt btv:i32[4] = convert_element_type[new_dtype=int32 weak_type=False] btu btw:f32[239696] = pjit[name=norm jaxpr=norm] btm btx:f32[4] = pjit[name=norm jaxpr=norm2] btr bty:f32[239696,4] = dot_general[ dimension_numbers=(([1], [1]), ([], [])) preferred_element_type=float32 ] btm btr btz:f32[1,239696] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 239696) ] btw bua:f32[4,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(4, 1)] btx bub:f32[4,239696] = mul btz bua buc:f32[4,239696] = add bub 9.99999993922529e-09 bud:f32[4,239696] = transpose[permutation=(1, 0)] bty bue:f32[4,239696] = div bud buc buf:f32[4,239696] = sub 1.0 bue bug:f32[4,239696] = mul 1.0 buf buh:bool[239696] = lt btq 0 bui:i32[239696] = add btq 30 buj:i32[239696] = select_n buh btq bui buk:bool[4] = lt btv 0 bul:i32[4] = add btv 42 bum:i32[4] = select_n buk btv bul bun:i32[239696,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(239696, 1) ] buj buo:i32[4,239696,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(4, 239696, 1) ] bum bup:i32[4,239696,1] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(4, 239696, 1) ] bun buq:i32[4,239696,2] = concatenate[dimension=2] bup buo bur:f32[4,239696,1,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(2, 3), collapsed_slice_dims=(), start_index_map=(0, 1)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 1) unique_indices=False ] cf buq bus:f32[4,239696] = squeeze[dimensions=(2, 3)] bur but:f32[4,239696] = mul 0.0010000000474974513 bus buu:f32[4,239696] = add bug but buv:f32[4,239696] = add 0.0 buu buw:f32[4,239696] = mul buv 1.0 bux:f32[1,239696] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 239696) ] bjm buy:f32[4,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(4, 1)] btl buz:f32[4,239696] = add bux buy bva:f32[4,239696] = sub buz buw bvb:f32[4,239696] = div bva bqh bvc:f32[4] bvd:f32[4] = custom_jvp_call[ call_jaxpr={ lambda ; bve:f32[4,239696]. let bvf:f32[4] = reduce_max[axes=(1,)] bve bvg:bool[4] = is_finite bvf bvh:f32[4] = broadcast_in_dim[broadcast_dimensions=() shape=(4,)] 0.0 bvi:f32[4] = select_n bvg bvh bvf bvj:f32[4] = stop_gradient bvi bvk:f32[4,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(4, 1) ] bvj bvl:f32[4,239696] = sub bve bvk bvm:f32[4,239696] = exp bvl bvn:f32[4] = reduce_sum[axes=(1,)] bvm bvo:f32[4] = sign bvn bvp:f32[4] = abs bvn bvq:f32[4] = log bvp bvr:f32[4] = add bvq bvj in (bvr, bvo) } jvp_jaxpr_thunk=.memoized at 0x7f8d10571000> num_consts=0 symbolic_zeros=False ] bvb bvs:f32[422220] = concatenate[dimension=0] bti bvc _:f32[422220] = concatenate[dimension=0] btj bvd bvt:f32[422220] = mul bqh bvs bvu:bool[422220] = is_finite bjn bvv:f32[422220] = pjit[name=_where jaxpr=_where2] bvu bjn 0 bvw:f32[422220] = sub bvt bvv bvx:f32[422220] = add bvw bjn bvy:f32[239696] = broadcast_in_dim[broadcast_dimensions=() shape=(239696,)] 1.0 bvz:f32[] = reduce_sum[axes=(0,)] bvy bwa:f32[239696] = div bvy bvz bwb:f32[239696,1] = reshape[dimensions=None new_sizes=(239696, 1)] bwa bwc:f32[239696,2000] = slice[ limit_indices=(239696, 2000) start_indices=(0, 0) strides=None ] ch bwd:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2000 bwe:f32[239696,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(1,)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(239696, 1) unique_indices=True ] ch bwd bwf:f32[239696] = squeeze[dimensions=(1,)] bwe bwg:i32[239696] = convert_element_type[new_dtype=int32 weak_type=False] bwf bwh:f32[422220,2000] = slice[ limit_indices=(422220, 2000) start_indices=(0, 0) strides=None ] ci bwi:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2000 bwj:f32[422220,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(1,)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(422220, 1) unique_indices=True ] ci bwi bwk:f32[422220] = squeeze[dimensions=(1,)] bwj bwl:i32[422220] = convert_element_type[new_dtype=int32 weak_type=False] bwk bwm:f32[239696] = pjit[name=norm jaxpr=norm] bwc bwn:f32[422220] = pjit[name=norm jaxpr=norm1] bwh bwo:f32[239696,422220] = dot_general[ dimension_numbers=(([1], [1]), ([], [])) preferred_element_type=float32 ] bwc bwh bwp:f32[1,239696] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 239696) ] bwm bwq:f32[422220,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(422220, 1) ] bwn bwr:f32[422220,239696] = mul bwp bwq bws:f32[422220,239696] = add bwr 9.99999993922529e-09 bwt:f32[422220,239696] = transpose[permutation=(1, 0)] bwo bwu:f32[422220,239696] = div bwt bws bwv:f32[422220,239696] = sub 1.0 bwu bww:f32[422220,239696] = mul 1.0 bwv bwx:bool[239696] = lt bwg 0 bwy:i32[239696] = add bwg 30 bwz:i32[239696] = select_n bwx bwg bwy bxa:bool[422220] = lt bwl 0 bxb:i32[422220] = add bwl 42 bxc:i32[422220] = select_n bxa bwl bxb bxd:i32[239696,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(239696, 1) ] bwz bxe:i32[422220,239696,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(422220, 239696, 1) ] bxc bxf:i32[422220,239696,1] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(422220, 239696, 1) ] bxd bxg:i32[422220,239696,2] = concatenate[dimension=2] bxf bxe bxh:f32[422220,239696,1,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(2, 3), collapsed_slice_dims=(), start_index_map=(0, 1)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 1) unique_indices=False ] cf bxg bxi:f32[422220,239696] = squeeze[dimensions=(2, 3)] bxh bxj:f32[422220,239696] = mul 0.0010000000474974513 bxi bxk:f32[422220,239696] = add bww bxj bxl:f32[422220,239696] = add 0.0 bxk bxm:f32[422220,239696] = mul bxl 1.0 bxn:f32[422220,1] = dot_general[ dimension_numbers=(([1], [0]), ([], [])) preferred_element_type=float32 ] bxm bwb bxo:f32[422220] = squeeze[dimensions=(1,)] bxn bxp:f32[422220] = broadcast_in_dim[broadcast_dimensions=() shape=(422220,)] 1.0 bxq:f32[] = reduce_sum[axes=(0,)] bxp bxr:f32[422220] = div bxp bxq bxs:f32[422220] = mul bxo bxr bxt:f32[] = reduce_sum[axes=(0,)] bxs bxu:f32[] = stop_gradient bxt bxv:f32[] = mul bxu 0.05000000074505806 bxw:f32[422220] = div bvx bxv bxx:f32[422220] = exp bxw bxy:f32[] = reduce_sum[axes=(0,)] bxx bxz:f32[] = add bmb boj bya:f32[239696] = broadcast_in_dim[broadcast_dimensions=() shape=(239696,)] 1.0 byb:f32[] = reduce_sum[axes=(0,)] bya byc:f32[239696] = div bya byb byd:f32[239696,1] = reshape[dimensions=None new_sizes=(239696, 1)] byc bye:f32[239696,2000] = slice[ limit_indices=(239696, 2000) start_indices=(0, 0) strides=None ] ch byf:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2000 byg:f32[239696,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(1,)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(239696, 1) unique_indices=True ] ch byf byh:f32[239696] = squeeze[dimensions=(1,)] byg byi:i32[239696] = convert_element_type[new_dtype=int32 weak_type=False] byh byj:f32[422220,2000] = slice[ limit_indices=(422220, 2000) start_indices=(0, 0) strides=None ] ci byk:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2000 byl:f32[422220,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(1,)) fill_value=None indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(422220, 1) unique_indices=True ] ci byk bym:f32[422220] = squeeze[dimensions=(1,)] byl byn:i32[422220] = convert_element_type[new_dtype=int32 weak_type=False] bym byo:f32[239696] = pjit[name=norm jaxpr=norm] bye byp:f32[422220] = pjit[name=norm jaxpr=norm1] byj byq:f32[239696,422220] = dot_general[ dimension_numbers=(([1], [1]), ([], [])) preferred_element_type=float32 ] bye byj byr:f32[1,239696] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 239696) ] byo bys:f32[422220,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(422220, 1) ] byp byt:f32[422220,239696] = mul byr bys byu:f32[422220,239696] = add byt 9.99999993922529e-09 byv:f32[422220,239696] = transpose[permutation=(1, 0)] byq byw:f32[422220,239696] = div byv byu byx:f32[422220,239696] = sub 1.0 byw byy:f32[422220,239696] = mul 1.0 byx byz:bool[239696] = lt byi 0 bza:i32[239696] = add byi 30 bzb:i32[239696] = select_n byz byi bza bzc:bool[422220] = lt byn 0 bzd:i32[422220] = add byn 42 bze:i32[422220] = select_n bzc byn bzd bzf:i32[239696,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(239696, 1) ] bzb bzg:i32[422220,239696,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(422220, 239696, 1) ] bze bzh:i32[422220,239696,1] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(422220, 239696, 1) ] bzf bzi:i32[422220,239696,2] = concatenate[dimension=2] bzh bzg bzj:f32[422220,239696,1,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(2, 3), collapsed_slice_dims=(), start_index_map=(0, 1)) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 1) unique_indices=False ] cf bzi bzk:f32[422220,239696] = squeeze[dimensions=(2, 3)] bzj bzl:f32[422220,239696] = mul 0.0010000000474974513 bzk bzm:f32[422220,239696] = add byy bzl bzn:f32[422220,239696] = add 0.0 bzm bzo:f32[422220,239696] = mul bzn 1.0 bzp:f32[422220,1] = dot_general[ dimension_numbers=(([1], [0]), ([], [])) preferred_element_type=float32 ] bzo byd bzq:f32[422220] = squeeze[dimensions=(1,)] bzp bzr:f32[422220] = broadcast_in_dim[broadcast_dimensions=() shape=(422220,)] 1.0 bzs:f32[] = reduce_sum[axes=(0,)] bzr bzt:f32[422220] = div bzr bzs bzu:f32[422220] = mul bzq bzt bzv:f32[] = reduce_sum[axes=(0,)] bzu bzw:f32[] = stop_gradient bzv bzx:f32[] = mul bzw 0.05000000074505806 bzy:f32[239696] = broadcast_in_dim[broadcast_dimensions=() shape=(239696,)] 1.0 bzz:f32[239696] = div bzy 239696.0 caa:f32[] = reduce_sum[axes=(0,)] bzz cab:f32[422220] = broadcast_in_dim[broadcast_dimensions=() shape=(422220,)] 1.0 cac:f32[422220] = div cab 422220.0 cad:f32[] = reduce_sum[axes=(0,)] cac cae:f32[] = mul caa cad caf:f32[] = sub cae bxy cag:f32[] = mul bzx caf cah:f32[] = add bxz cag in (de, df, dg, cah, ch, ci, cj, ck, dh, di, dj) }