diff --git a/lib/beaver/application.ex b/lib/beaver/application.ex index e90c9def..4013e5a3 100644 --- a/lib/beaver/application.ex +++ b/lib/beaver/application.ex @@ -3,7 +3,7 @@ defmodule Beaver.Application do require Logger @moduledoc false def start(_type, _args) do - [Beaver.MLIR.Pass.global_registrar_child_specs(), Beaver.Composer.pass_runner_child_specs()] + [Beaver.MLIR.Pass.global_registrar_child_specs()] |> List.flatten() |> Supervisor.start_link(strategy: :one_for_one) end diff --git a/lib/beaver/composer.ex b/lib/beaver/composer.ex index 3d1729d5..fa5d8b8e 100644 --- a/lib/beaver/composer.ex +++ b/lib/beaver/composer.ex @@ -4,7 +4,7 @@ defmodule Beaver.Composer do require Logger @moduledoc """ - This module provide functions to compose passes. + This module provide functions to compose and run passes. """ @enforce_keys [:op] defstruct passes: [], op: nil @@ -38,7 +38,7 @@ defmodule Beaver.Composer do end # Create an external pass. - defp do_create_pass(pid, argument, description, op) do + defp do_create_pass(pid, argument, description, op, run) do argument_ref = MLIR.StringRef.create(argument).ref MLIR.CAPI.beaver_raw_create_mlir_pass( @@ -46,28 +46,14 @@ defmodule Beaver.Composer do argument_ref, MLIR.StringRef.create(description).ref, MLIR.StringRef.create(op).ref, - pid + pid, + run ) - |> then(&%MLIR.Pass{ref: &1, handler: pid}) + |> Beaver.Native.check!() end - @supervisor __MODULE__.DynamicSupervisor - @registry __MODULE__.Registry defp create_pass(argument, desc, op, run) do - spec = - {Beaver.PassRunner, [run, name: {:via, Registry, {@registry, :"#{argument}-#{op}"}}]} - - case DynamicSupervisor.start_child(@supervisor, spec) do - {:ok, pid} -> - pid - - {:error, {:already_started, pid}} -> - pid - - {:error, e} -> - raise Application.format_error(e) - end - |> do_create_pass(argument, desc, op) + do_create_pass(self(), argument, desc, op, run) end def create_pass(%MLIR.Pass{} = pass) do @@ -163,6 +149,31 @@ defmodule Beaver.Composer do @spec run!(composer) :: operation @spec run!(composer, [run_option]) :: operation + defp dispatch_pass_action() do + receive do + {:run, op_ref, token_ref, run} -> + spawn_link(fn -> + try do + run.(%MLIR.Operation{ref: op_ref}) + MLIR.CAPI.beaver_raw_logical_mutex_token_signal_success(token_ref) + rescue + exception -> + MLIR.CAPI.beaver_raw_logical_mutex_token_signal_failure(token_ref) + Logger.error(Exception.format(:error, exception, __STACKTRACE__)) + end + end) + + dispatch_pass_action() + + {{:kind, MLIR.LogicalResult, _}, diagnostics} = ret when is_list(diagnostics) -> + Beaver.Native.check!(ret) + + other -> + Logger.error("Unexpected message: #{inspect(other)}") + dispatch_pass_action() + end + end + def run!( composer, opts \\ @run_default_opts @@ -215,16 +226,7 @@ defmodule Beaver.Composer do txt |> Logger.info() end - {status, diagnostics} = - case beaver_raw_run_pm_on_op_async(pm.ref, MLIR.Operation.from_module(op).ref) do - :ok -> - receive do - ret -> Beaver.Native.check!(ret) - end - - ret -> - Beaver.Native.check!(ret) - end + {status, diagnostics} = run_pm_async(pm, op) if print do mlirContextEnableMultithreading(ctx, true) @@ -240,10 +242,13 @@ defmodule Beaver.Composer do end @doc false - def pass_runner_child_specs() do - [ - {DynamicSupervisor, name: @supervisor, strategy: :one_for_one}, - {Registry, keys: :unique, name: @registry} - ] + def run_pm_async(%MLIR.PassManager{ref: pm_ref}, op) do + case beaver_raw_run_pm_on_op_async(pm_ref, MLIR.Operation.from_module(op).ref) do + :ok -> + dispatch_pass_action() + + ret -> + Beaver.Native.check!(ret) + end end end diff --git a/lib/beaver/mlir/capi.ex b/lib/beaver/mlir/capi.ex index 2f8016cf..96647675 100644 --- a/lib/beaver/mlir/capi.ex +++ b/lib/beaver/mlir/capi.ex @@ -44,7 +44,8 @@ defmodule Beaver.MLIR.CAPI do _argument, _description, _op_name, - _handler + _handler, + _run ), do: :erlang.nif_error(:not_loaded) diff --git a/lib/beaver/mlir/pass.ex b/lib/beaver/mlir/pass.ex index 9d8cdec8..54ddb317 100644 --- a/lib/beaver/mlir/pass.ex +++ b/lib/beaver/mlir/pass.ex @@ -3,7 +3,7 @@ defmodule Beaver.MLIR.Pass do This module defines functions working with MLIR #{__MODULE__ |> Module.split() |> List.last()}. """ alias Beaver.MLIR - use Kinda.ResourceKind, fields: [handler: nil], forward_module: Beaver.Native + use Kinda.ResourceKind, forward_module: Beaver.Native @callback run(MLIR.Operation.t()) :: any() defmacro __using__(opts) do diff --git a/lib/beaver/pass_runner.ex b/lib/beaver/pass_runner.ex deleted file mode 100644 index f6c6a687..00000000 --- a/lib/beaver/pass_runner.ex +++ /dev/null @@ -1,32 +0,0 @@ -defmodule Beaver.PassRunner do - alias Beaver.MLIR - require Logger - - @moduledoc """ - `GenServer` to run an MLIR pass implemented in Elixir - """ - use GenServer - - def start_link([run | opts]) do - GenServer.start_link(__MODULE__, run, opts) - end - - @impl true - def init(run) do - {:ok, %{run: run}} - end - - @impl true - def handle_info({:run, op_ref, token_ref}, %{run: run} = state) do - try do - run.(%MLIR.Operation{ref: op_ref}) - MLIR.CAPI.beaver_raw_logical_mutex_token_signal_success(token_ref) - rescue - exception -> - MLIR.CAPI.beaver_raw_logical_mutex_token_signal_failure(token_ref) - Logger.error("#{Exception.format(:error, exception, __STACKTRACE__)}") - end - - {:noreply, state} - end -end diff --git a/native/src/diagnostic.zig b/native/src/diagnostic.zig index 7ad76846..f044c58c 100644 --- a/native/src/diagnostic.zig +++ b/native/src/diagnostic.zig @@ -66,9 +66,9 @@ pub fn call_with_diagnostics(env: beam.env, ctx: mlir_capi.Context.T, f: anytype const id = c.mlirContextAttachDiagnosticHandler(ctx, DiagnosticAggregator.errorHandler, @ptrCast(@alignCast(userData)), DiagnosticAggregator.deleteUserData); defer c.mlirContextDetachDiagnosticHandler(ctx, id); var res_slice: []beam.term = try beam.allocator.alloc(beam.term, 2); + defer beam.allocator.free(res_slice); res_slice[0] = try @call(.auto, f, args); res_slice[1] = try DiagnosticAggregator.to_list(userData); - defer beam.allocator.free(res_slice); return beam.make_tuple(env, res_slice); } diff --git a/native/src/logical_mutex.zig b/native/src/logical_mutex.zig index 098d45da..0a258079 100644 --- a/native/src/logical_mutex.zig +++ b/native/src/logical_mutex.zig @@ -29,12 +29,12 @@ pub const Token = struct { self.done = true; self.cond.signal(); } - pub fn pass_token_signal_logical_success(env: beam.env, _: c_int, args: [*c]const beam.term) !beam.term { + pub fn signal_logical_success(env: beam.env, _: c_int, args: [*c]const beam.term) !beam.term { var token = try beam.fetch_ptr_resource_wrapped(@This(), env, args[0]); token.signal(true); return beam.make_ok(env); } - pub fn pass_token_signal_logical_failure(env: beam.env, _: c_int, args: [*c]const beam.term) !beam.term { + pub fn signal_logical_failure(env: beam.env, _: c_int, args: [*c]const beam.term) !beam.term { var token = try beam.fetch_ptr_resource_wrapped(@This(), env, args[0]); token.signal(false); return beam.make_ok(env); @@ -42,8 +42,8 @@ pub const Token = struct { }; pub const nifs = .{ - result.nif("beaver_raw_logical_mutex_token_signal_success", 1, Token.pass_token_signal_logical_success).entry, - result.nif("beaver_raw_logical_mutex_token_signal_failure", 1, Token.pass_token_signal_logical_failure).entry, + result.nif("beaver_raw_logical_mutex_token_signal_success", 1, Token.signal_logical_success).entry, + result.nif("beaver_raw_logical_mutex_token_signal_failure", 1, Token.signal_logical_failure).entry, }; pub fn open_all(env: beam.env) void { beam.open_resource_wrapped(env, Token); diff --git a/native/src/pass.zig b/native/src/pass.zig index c78ed4dc..909ea557 100644 --- a/native/src/pass.zig +++ b/native/src/pass.zig @@ -1,7 +1,7 @@ const std = @import("std"); const beam = @import("beam"); const mlir_capi = @import("mlir_capi.zig"); -pub const c = @import("prelude.zig"); +const c = @import("prelude.zig"); const e = @import("erl_nif"); const debug_print = @import("std").debug.print; const kinda = @import("kinda"); @@ -9,94 +9,84 @@ const result = @import("kinda").result; const diagnostic = @import("diagnostic.zig"); const Token = @import("logical_mutex.zig").Token; -const BeaverPass = extern struct { +const beaverPassCreateWrap = kinda.BangFunc(c.K, c, "beaverPassCreate").wrap_ret_call; +threadlocal var typeIDAllocator: ?mlir_capi.TypeIDAllocator.T = null; +const CallbackDispatcher = extern struct { handler: beam.pid, + env: beam.env, + run_fn: beam.term, fn construct(_: ?*anyopaque) callconv(.C) void {} fn destruct(userData: ?*anyopaque) callconv(.C) void { - const ptr: *@This() = @ptrCast(@alignCast(userData)); - beam.allocator.destroy(ptr); + const this: *@This() = @ptrCast(@alignCast(userData)); + e.enif_free_env(this.env); + beam.allocator.destroy(this); } - fn initialize(_: mlir_capi.Context.T, _: ?*anyopaque) callconv(.C) mlir_capi.LogicalResult.T { + fn initialize(_: mlir_capi.Context.T, userData: ?*anyopaque) callconv(.C) mlir_capi.LogicalResult.T { + const this: *@This() = @ptrCast(@alignCast(userData)); + this.*.env = e.enif_alloc_env() orelse return c.mlirLogicalResultFailure(); return c.mlirLogicalResultSuccess(); } fn clone(userData: ?*anyopaque) callconv(.C) ?*anyopaque { const old: *@This() = @ptrCast(@alignCast(userData)); const new = beam.allocator.create(@This()) catch unreachable; new.* = old.*; + new.*.env = e.enif_alloc_env() orelse unreachable; + new.*.run_fn = e.enif_make_copy(new.*.env, old.run_fn); return new; } - const Error = error{ @"Fail to allocate BEAM environment", @"Fail to send message to pass server", @"Fail to run a pass implemented in Elixir" }; - fn do_run(op: mlir_capi.Operation.T, userData: ?*anyopaque) !void { - const ud: *@This() = @ptrCast(@alignCast(userData)); + const Error = error{ @"Fail to allocate BEAM environment", @"Fail to send message to pass server", @"Fail to run a pass implemented in Elixir", @"External pass must be run on non-scheduler thread to prevent deadlock" }; + fn forward_cb(op: mlir_capi.Operation.T, this: *@This()) !void { + if (e.enif_thread_type() != e.ERL_NIF_THR_UNDEFINED) { + return Error.@"External pass must be run on non-scheduler thread to prevent deadlock"; + } const env = e.enif_alloc_env() orelse return Error.@"Fail to allocate BEAM environment"; - defer e.enif_clear_env(env); - const handler = ud.*.handler; - var tuple_slice: []beam.term = try beam.allocator.alloc(beam.term, 3); - defer beam.allocator.free(tuple_slice); - tuple_slice[0] = beam.make_atom(env, "run"); - tuple_slice[1] = try mlir_capi.Operation.resource.make(env, op); var token = Token{}; - tuple_slice[2] = try beam.make_ptr_resource_wrapped(env, &token); - if (!beam.send_advanced(env, handler, env, beam.make_tuple(env, tuple_slice))) { + const tuple_slice: []const beam.term = &.{ beam.make_atom(env, "run"), try mlir_capi.Operation.resource.make(env, op), try beam.make_ptr_resource_wrapped(env, &token), e.enif_make_copy(env, this.run_fn) }; + const msg = beam.make_tuple(env, @constCast(tuple_slice)); + if (!beam.send_advanced(env, this.*.handler, env, msg)) { return Error.@"Fail to send message to pass server"; } if (c.beaverLogicalResultIsFailure(token.wait_logical())) return Error.@"Fail to run a pass implemented in Elixir"; } fn run(op: mlir_capi.Operation.T, pass: c.MlirExternalPass, userData: ?*anyopaque) callconv(.C) void { - if (do_run(op, userData)) |_| {} else |err| { + if (forward_cb(op, @ptrCast(@alignCast(userData)))) |_| {} else |err| { c.mlirEmitError(c.mlirOperationGetLocation(op), @errorName(err)); c.mlirExternalPassSignalFailure(pass); } } -}; - -pub fn do_create(env: beam.env, _: c_int, args: [*c]const beam.term) !beam.term { - const name = try mlir_capi.StringRef.resource.fetch(env, args[0]); - const argument = try mlir_capi.StringRef.resource.fetch(env, args[1]); - const description = try mlir_capi.StringRef.resource.fetch(env, args[2]); - const op_name = try mlir_capi.StringRef.resource.fetch(env, args[3]); - const handler: beam.pid = try beam.get_pid(env, args[4]); - - const typeIDAllocator = c.mlirTypeIDAllocatorCreate(); - defer c.mlirTypeIDAllocatorDestroy(typeIDAllocator); - const passID = c.mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator); - const nDependentDialects = 0; - const dependentDialects = null; - const bp: *BeaverPass = try beam.allocator.create(BeaverPass); - bp.* = BeaverPass{ .handler = handler }; // use this function to avoid ABI issue - const ep = c.beaverPassCreate( - BeaverPass.construct, - BeaverPass.destruct, - BeaverPass.initialize, - BeaverPass.clone, - BeaverPass.run, - passID, - name, - argument, - description, - op_name, - nDependentDialects, - dependentDialects, - bp, - ); - return try mlir_capi.Pass.resource.make(env, ep); -} - -const WorkerError = error{ @"fail to allocate BEAM environment", @"fail to send message to pm caller", @"fail get caller's self pid" }; + fn create_mlir_pass(env: beam.env, _: c_int, args: [*c]const beam.term) !beam.term { + const name = try mlir_capi.StringRef.resource.fetch(env, args[0]); + const argument = try mlir_capi.StringRef.resource.fetch(env, args[1]); + const description = try mlir_capi.StringRef.resource.fetch(env, args[2]); + const op_name = try mlir_capi.StringRef.resource.fetch(env, args[3]); + const handler: beam.pid = try beam.get_pid(env, args[4]); + if (typeIDAllocator == null) { + typeIDAllocator = c.mlirTypeIDAllocatorCreate(); + } + const passID = c.mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator.?); + const nDependentDialects = 0; + const dependentDialects = null; + const bp: *@This() = try beam.allocator.create(@This()); + const bp_env = e.enif_alloc_env() orelse return Error.@"Fail to allocate BEAM environment"; + bp.* = @This(){ .handler = handler, .env = bp_env, .run_fn = e.enif_make_copy(bp_env, args[5]) }; + return beaverPassCreateWrap(env, .{ construct, destruct, initialize, clone, run, passID, name, argument, description, op_name, nDependentDialects, dependentDialects, bp }); + } +}; // we only use the return functionality of BangFunc here because we are not fetching resources here const mlirPassManagerRunOnOpWrap = kinda.BangFunc(c.K, c, "mlirPassManagerRunOnOp").wrap_ret_call; -const PassManagerRunner = extern struct { +const ManagerRunner = extern struct { + const Error = error{ @"fail to allocate BEAM environment", @"fail to send message to pm caller", @"fail get caller's self pid" }; pid: beam.pid, pm: mlir_capi.PassManager.T, op: mlir_capi.Operation.T, fn run_with_diagnostics(this: @This()) !void { - const env = e.enif_alloc_env() orelse return WorkerError.@"fail to allocate BEAM environment"; + const env = e.enif_alloc_env() orelse return Error.@"fail to allocate BEAM environment"; const ctx = c.mlirOperationGetContext(this.op); const args = .{ this.pm, this.op }; if (!beam.send_advanced(env, this.pid, env, try diagnostic.call_with_diagnostics(env, ctx, mlirPassManagerRunOnOpWrap, .{ env, args }))) { - return WorkerError.@"fail to send message to pm caller"; + return Error.@"fail to send message to pm caller"; } } fn run_and_send(worker: ?*anyopaque) callconv(.C) void { @@ -106,25 +96,24 @@ const PassManagerRunner = extern struct { c.mlirEmitError(c.mlirOperationGetLocation(this.?.*.op), @errorName(err)); } } -}; - -pub fn run_pm_on_op(env: beam.env, _: c_int, args: [*c]const beam.term) !beam.term { - const w = try beam.allocator.create(PassManagerRunner); - if (e.enif_self(env, &w.*.pid) == null) { - return WorkerError.@"fail get caller's self pid"; - } - w.*.pm = try mlir_capi.PassManager.resource.fetch(env, args[0]); - w.*.op = try mlir_capi.Operation.resource.fetch(env, args[1]); - const ctx = c.mlirOperationGetContext(w.op); - if (c.beaverContextAddWork(ctx, PassManagerRunner.run_and_send, @ptrCast(@constCast(w)))) { - return beam.make_ok(env); - } else { - defer beam.allocator.destroy(w); - return try diagnostic.call_with_diagnostics(env, ctx, mlirPassManagerRunOnOpWrap, .{ env, .{ w.*.pm, w.*.op } }); + fn run_pm_on_op(env: beam.env, _: c_int, args: [*c]const beam.term) !beam.term { + const w = try beam.allocator.create(@This()); + if (e.enif_self(env, &w.*.pid) == null) { + return Error.@"fail get caller's self pid"; + } + w.*.pm = try mlir_capi.PassManager.resource.fetch(env, args[0]); + w.*.op = try mlir_capi.Operation.resource.fetch(env, args[1]); + const ctx = c.mlirOperationGetContext(w.op); + if (c.beaverContextAddWork(ctx, @This().run_and_send, @ptrCast(@constCast(w)))) { + return beam.make_ok(env); + } else { + defer beam.allocator.destroy(w); + return try diagnostic.call_with_diagnostics(env, ctx, mlirPassManagerRunOnOpWrap, .{ env, .{ w.*.pm, w.*.op } }); + } } -} +}; pub const nifs = .{ - result.nif("beaver_raw_create_mlir_pass", 5, do_create).entry, - result.nif("beaver_raw_run_pm_on_op_async", 2, run_pm_on_op).entry, + result.nif("beaver_raw_create_mlir_pass", 6, CallbackDispatcher.create_mlir_pass).entry, + result.nif("beaver_raw_run_pm_on_op_async", 2, ManagerRunner.run_pm_on_op).entry, }; diff --git a/test/capi_test.exs b/test/capi_test.exs index 859c9b50..207b9481 100644 --- a/test/capi_test.exs +++ b/test/capi_test.exs @@ -191,7 +191,7 @@ defmodule MlirTest do pm = mlirPassManagerCreate(ctx) mlirPassManagerAddOwnedPass(pm, external) mlirPassManagerAddOwnedPass(pm, mlirCreateTransformsCSE()) - success = mlirPassManagerRunOnOp(pm, MLIR.Operation.from_module(module)) + {success, _} = Beaver.Composer.run_pm_async(pm, module) assert Beaver.MLIR.LogicalResult.success?(success) mlirPassManagerDestroy(pm) mlirModuleDestroy(module) @@ -205,7 +205,7 @@ defmodule MlirTest do pm = mlirPassManagerCreate(ctx) npm = mlirPassManagerGetNestedUnder(pm, MLIR.StringRef.create("func.func")) mlirOpPassManagerAddOwnedPass(npm, external) - success = mlirPassManagerRunOnOp(pm, MLIR.Operation.from_module(module)) + {success, _} = Beaver.Composer.run_pm_async(pm, module) assert Beaver.MLIR.LogicalResult.success?(success) mlirPassManagerDestroy(pm) mlirModuleDestroy(module) @@ -217,7 +217,7 @@ defmodule MlirTest do pm = mlirPassManagerCreate(ctx) npm = mlirPassManagerGetNestedUnder(pm, MLIR.StringRef.create("func.func")) mlirOpPassManagerAddOwnedPass(npm, external) - success = mlirPassManagerRunOnOp(pm, MLIR.Operation.from_module(module)) + {success, _} = Beaver.Composer.run_pm_async(pm, module) assert Beaver.MLIR.LogicalResult.success?(success) mlirPassManagerDestroy(pm) mlirModuleDestroy(module) diff --git a/test/pass_test.exs b/test/pass_test.exs index e8812182..e886baca 100644 --- a/test/pass_test.exs +++ b/test/pass_test.exs @@ -91,4 +91,10 @@ defmodule PassTest do |> Beaver.Composer.run!() end end + + test "parallel processing func.func", %{ctx: ctx} do + Beaver.Dummy.gigantic(ctx) + |> Beaver.Composer.nested("func.func", {"DoNothingHere", "func.func", fn _ -> :ok end}) + |> Beaver.Composer.run!() + end end