From b05d6edb52a0612feedcbdb35a0116e0b91c06df Mon Sep 17 00:00:00 2001 From: Olivier Nicole Date: Wed, 4 Dec 2024 18:31:03 +0100 Subject: [PATCH] CR --- compiler/lib/effects.ml | 135 +++++++----------- compiler/lib/subst.ml | 3 +- .../double-translation/direct_calls.ml | 6 +- .../effects_continuations.ml | 20 +-- .../double-translation/effects_exceptions.ml | 34 ++--- 5 files changed, 81 insertions(+), 117 deletions(-) diff --git a/compiler/lib/effects.ml b/compiler/lib/effects.ml index 08b5d48044..9ecc2c60c9 100644 --- a/compiler/lib/effects.ml +++ b/compiler/lib/effects.ml @@ -818,24 +818,6 @@ let rewrite_direct_block ~st ~cps_needed ~closure_info ~pc block = { block with body } else { block with body = List.map ~f:(rewrite_direct_instr ~st) block.body } -(* Apply a substitution in a set of blocks *) -let subst_in_blocks blocks s = - Addr.Map.mapi - (fun pc block -> - if debug () - then ( - debug_print "@[block before first subst: @,"; - Code.Print.block (fun _ _ -> "") pc block; - debug_print "@]"); - let res = Subst.Excluding_Binders.block s block in - if debug () - then ( - debug_print "@[block after first subst: @,"; - Code.Print.block (fun _ _ -> "") pc res; - debug_print "@]"); - res) - blocks - (* Apply a substitution in a set of blocks, including to bound variables *) let subst_bound_in_blocks blocks s = Addr.Map.mapi @@ -854,20 +836,21 @@ let subst_bound_in_blocks blocks s = res) blocks +let subst_add array v v' = + if 0 <= Var.idx v && Var.idx v < Array.length array then array.(Var.idx v) <- v' + let cps_transform ~live_vars ~flow_info ~cps_needed p = (* Define an identity function, needed for the boilerplate around "resume" *) let closure_info = Hashtbl.create 16 in let trampolined_calls = ref Var.Set.empty in let in_cps = ref Var.Set.empty in let cps_pc_of_direct = Hashtbl.create 512 in - let p, bound_subst, param_subst, new_blocks = + let cloned_vars = Array.init (Var.count ()) ~f:Var.of_idx in + let cloned_subst = Subst.from_array cloned_vars in + let p, new_blocks = Code.fold_closures_innermost_first p - (fun name_opt - params - (start, args) - (({ blocks; free_pc; _ } as p), bound_subst, param_subst, new_blocks) - -> + (fun name_opt params (start, args) (({ blocks; free_pc; _ } as p), new_blocks) -> Option.iter name_opt ~f:(fun v -> debug_print "@[cname = %s@,@]" @@ Var.to_string v); (* We speculatively add a block at the beginning of the @@ -957,7 +940,7 @@ let cps_transform ~live_vars ~flow_info ~cps_needed p = start blocks ()); - let blocks, free_pc, bound_subst, param_subst, new_blocks = + let blocks, free_pc, new_blocks = (* For every block in the closure, 1. CPS-translate it if needed. If we double-translate, add its CPS translation to the block map at a fresh address. Otherwise, @@ -965,49 +948,41 @@ let cps_transform ~live_vars ~flow_info ~cps_needed p = 2. If we double-translate, keep the direct-style block but modify function definitions to add the CPS version where needed, and turn uses of %resume and %perform into switchings to CPS. *) - let param_subst, transform_block = + let transform_block = if function_needs_cps && double_translate () then ( let k = Var.fresh_n "cont" in let cps_start = mk_cps_pc_of_direct ~st start in let params' = List.map ~f:Var.fork params in - let param_subst = - List.fold_left2 - ~f:(fun m p p' -> Var.Map.add p p' m) - ~init:param_subst - params - params' - in - let cps_args = List.map ~f:(Subst.from_map param_subst) args in + List.iter2 params params' ~f:(fun x x' -> cloned_vars.(Var.idx x) <- x'); + let cps_args = List.map ~f:cloned_subst args in Hashtbl.add st.closure_info initial_start (params' @ [ k ], (cps_start, cps_args)); - ( param_subst - , fun pc block -> - let cps_block = cps_block ~st ~k ~orig_pc:pc block in - ( rewrite_direct_block - ~st - ~cps_needed - ~closure_info:st.closure_info - ~pc - block - , Some cps_block ) )) + fun pc block -> + let cps_block = cps_block ~st ~k ~orig_pc:pc block in + ( rewrite_direct_block + ~st + ~cps_needed + ~closure_info:st.closure_info + ~pc + block + , Some cps_block )) else if function_needs_cps && not (double_translate ()) then ( let k = Var.fresh_n "cont" in Hashtbl.add st.closure_info initial_start (params @ [ k ], (start, args)); - param_subst, fun pc block -> cps_block ~st ~k ~orig_pc:pc block, None) + fun pc block -> cps_block ~st ~k ~orig_pc:pc block, None) else - ( param_subst - , fun pc block -> - ( rewrite_direct_block - ~st - ~cps_needed - ~closure_info:st.closure_info - ~pc - block - , None ) ) + fun pc block -> + ( rewrite_direct_block + ~st + ~cps_needed + ~closure_info:st.closure_info + ~pc + block + , None ) in let blocks = Code.traverse @@ -1030,45 +1005,33 @@ let cps_transform ~live_vars ~flow_info ~cps_needed p = (* If double-translating, all variables bound in the CPS version will have to be subst with fresh ones to avoid clashing with the definitions in the original blocks (the actual substitution is done later). *) - let bound_subst = - if double_translate () + if double_translate () + then + if function_needs_cps && double_translate () then - let bound = - Addr.Map.fold - (fun _ block bound -> - Var.Set.union - bound - (Freevars.block_bound_vars ~closure_params:true block)) - new_blocks_this_clos - Var.Set.empty - in - Var.Set.fold (fun v m -> Var.Map.add v (Var.fork v) m) bound bound_subst - else bound_subst - in + Code.traverse + Code.{ fold = fold_children } + (fun pc () -> + let block = Addr.Map.find pc blocks in + Freevars.iter_block_bound_vars + (fun v -> subst_add cloned_vars v (Var.fork v)) + block) + start + st.blocks + (); let blocks = Addr.Map.fold Addr.Map.add new_blocks_this_clos blocks in ( blocks , free_pc - , bound_subst - , param_subst , Addr.Map.union (fun _ _ -> assert false) new_blocks new_blocks_this_clos ) in - { p with blocks; free_pc }, bound_subst, param_subst, new_blocks) - (p, Var.Map.empty, Var.Map.empty, Addr.Map.empty) + { p with blocks; free_pc }, new_blocks) + (p, Addr.Map.empty) in - let bound_subst = Subst.from_map bound_subst in - let new_blocks = subst_bound_in_blocks new_blocks bound_subst in - (* Also apply that substitution to the sets of trampolined calls, - single-version closures and cps call sites *) - trampolined_calls := Var.Set.map bound_subst !trampolined_calls; - in_cps := Var.Set.map bound_subst !in_cps; - (* All variables that were a closure parameter in a direct-style block must be - substituted by a fresh name. *) - let param_subst = Subst.from_map param_subst in - let new_blocks = subst_in_blocks new_blocks param_subst in - (* Also apply that 2nd substitution to the sets of trampolined calls, - single-version closures and cps call sites *) - trampolined_calls := Var.Set.map param_subst !trampolined_calls; - in_cps := Var.Set.map param_subst !in_cps; + let new_blocks = subst_bound_in_blocks new_blocks cloned_subst in + (* Also apply that substitution to the sets of trampolined calls, and cps + call sites *) + trampolined_calls := Var.Set.map cloned_subst !trampolined_calls; + in_cps := Var.Set.map cloned_subst !in_cps; let p = { p with blocks = diff --git a/compiler/lib/subst.ml b/compiler/lib/subst.ml index 30f06d38da..ca7fbbd267 100644 --- a/compiler/lib/subst.ml +++ b/compiler/lib/subst.ml @@ -97,7 +97,8 @@ end (****) -let from_array s x = s.(Var.idx x) +let from_array s x = + if 0 <= Var.idx x && Var.idx x < Array.length s then s.(Var.idx x) else x (****) diff --git a/compiler/tests-compiler/double-translation/direct_calls.ml b/compiler/tests-compiler/double-translation/direct_calls.ml index b21e3230e7..de6c6f24fb 100644 --- a/compiler/tests-compiler/double-translation/direct_calls.ml +++ b/compiler/tests-compiler/double-translation/direct_calls.ml @@ -155,9 +155,9 @@ let%expect_test "direct calls with --enable effects,doubletranslate" = } function f$1(g, x, cont){ runtime.caml_push_trap - (function(e){ - var raise = caml_pop_trap(), e$0 = caml_maybe_attach_backtrace(e, 0); - return raise(e$0); + (function(e$0){ + var raise = caml_pop_trap(), e = caml_maybe_attach_backtrace(e$0, 0); + return raise(e); }); return caml_exact_trampoline_cps_call (g, x, function(_P_){caml_pop_trap(); return cont();}); diff --git a/compiler/tests-compiler/double-translation/effects_continuations.ml b/compiler/tests-compiler/double-translation/effects_continuations.ml index 3ff2035b45..6fcaa8eb25 100644 --- a/compiler/tests-compiler/double-translation/effects_continuations.ml +++ b/compiler/tests-compiler/double-translation/effects_continuations.ml @@ -132,25 +132,25 @@ let%expect_test "test-compiler/lib-effects/test1.ml" = } //end function exceptions$1(s, cont){ - try{var _A_ = caml_int_of_string(s), n = _A_;} + try{var _z_ = caml_int_of_string(s), n = _z_;} catch(_E_){ - var _w_ = caml_wrap_exception(_E_); - if(_w_[1] !== Stdlib[7]){ + var _A_ = caml_wrap_exception(_E_); + if(_A_[1] !== Stdlib[7]){ var raise$1 = caml_pop_trap(); - return raise$1(caml_maybe_attach_backtrace(_w_, 0)); + return raise$1(caml_maybe_attach_backtrace(_A_, 0)); } var n = 0; } try{ if(caml_string_equal(s, cst$0)) throw caml_maybe_attach_backtrace(Stdlib[8], 1); - var _z_ = 7, m = _z_; + var _x_ = 7, m = _x_; } catch(_D_){ - var _x_ = caml_wrap_exception(_D_); - if(_x_ !== Stdlib[8]){ + var _y_ = caml_wrap_exception(_D_); + if(_y_ !== Stdlib[8]){ var raise$0 = caml_pop_trap(); - return raise$0(caml_maybe_attach_backtrace(_x_, 0)); + return raise$0(caml_maybe_attach_backtrace(_y_, 0)); } var m = 0; } @@ -165,8 +165,8 @@ let%expect_test "test-compiler/lib-effects/test1.ml" = (Stdlib[79], cst_toto, function(_B_){caml_pop_trap(); return cont([0, [0, _B_, n, m]]);}); - var _y_ = Stdlib[8], raise = caml_pop_trap(); - return raise(caml_maybe_attach_backtrace(_y_, 1)); + var _w_ = Stdlib[8], raise = caml_pop_trap(); + return raise(caml_maybe_attach_backtrace(_w_, 1)); } //end var exceptions = caml_cps_closure(exceptions$0, exceptions$1); diff --git a/compiler/tests-compiler/double-translation/effects_exceptions.ml b/compiler/tests-compiler/double-translation/effects_exceptions.ml index 6870ed6094..9a920e14e2 100644 --- a/compiler/tests-compiler/double-translation/effects_exceptions.ml +++ b/compiler/tests-compiler/double-translation/effects_exceptions.ml @@ -87,25 +87,25 @@ let%expect_test "test-compiler/lib-effects/test1.ml" = } //end function exceptions$1(s, cont){ - try{var _r_ = caml_int_of_string(s), n = _r_;} + try{var _q_ = caml_int_of_string(s), n = _q_;} catch(_v_){ - var _n_ = caml_wrap_exception(_v_); - if(_n_[1] !== Stdlib[7]){ + var _r_ = caml_wrap_exception(_v_); + if(_r_[1] !== Stdlib[7]){ var raise$1 = caml_pop_trap(); - return raise$1(caml_maybe_attach_backtrace(_n_, 0)); + return raise$1(caml_maybe_attach_backtrace(_r_, 0)); } var n = 0; } try{ if(caml_string_equal(s, cst$0)) throw caml_maybe_attach_backtrace(Stdlib[8], 1); - var _q_ = 7, m = _q_; + var _o_ = 7, m = _o_; } catch(_u_){ - var _o_ = caml_wrap_exception(_u_); - if(_o_ !== Stdlib[8]){ + var _p_ = caml_wrap_exception(_u_); + if(_p_ !== Stdlib[8]){ var raise$0 = caml_pop_trap(); - return raise$0(caml_maybe_attach_backtrace(_o_, 0)); + return raise$0(caml_maybe_attach_backtrace(_p_, 0)); } var m = 0; } @@ -120,8 +120,8 @@ let%expect_test "test-compiler/lib-effects/test1.ml" = (Stdlib[79], cst_toto, function(_s_){caml_pop_trap(); return cont([0, [0, _s_, n, m]]);}); - var _p_ = Stdlib[8], raise = caml_pop_trap(); - return raise(caml_maybe_attach_backtrace(_p_, 1)); + var _n_ = Stdlib[8], raise = caml_pop_trap(); + return raise(caml_maybe_attach_backtrace(_n_, 1)); } //end var exceptions = caml_cps_closure(exceptions$0, exceptions$1); @@ -148,24 +148,24 @@ let%expect_test "test-compiler/lib-effects/test1.ml" = //end function handler_is_loop$1(f, g, l, cont){ caml_push_trap - (function(_j_){ - function _k_(l){ + (function(_k_){ + function _j_(l){ return caml_trampoline_cps_call2 (g, l, function(match){ if(72330306 <= match[1]){ var l = match[2]; - return caml_exact_trampoline_call1(_k_, l); + return caml_exact_trampoline_call1(_j_, l); } var - exn = match[2], + exn$0 = match[2], raise = caml_pop_trap(), - exn$0 = caml_maybe_attach_backtrace(exn, 1); - return raise(exn$0); + exn = caml_maybe_attach_backtrace(exn$0, 1); + return raise(exn); }); } - return _k_(l); + return _j_(l); }); return caml_trampoline_cps_call2 (f, 0, function(_i_){caml_pop_trap(); return cont(_i_);});