diff --git a/lib/beaver/composer.ex b/lib/beaver/composer.ex index ad11df72..920655b2 100644 --- a/lib/beaver/composer.ex +++ b/lib/beaver/composer.ex @@ -171,14 +171,23 @@ defmodule Beaver.Composer do {:ok, op} -> op - {:error, msg} -> - raise msg + {:error, diagnostics} -> + raise ArgumentError, + (for {_severity, loc, d, _num} <- diagnostics, + reduce: "Unexpected failure running passes" do + acc -> "#{acc}\n#{to_string(loc)}: #{d}" + end) end end @spec run(composer) :: run_result @spec run(composer, [run_option]) :: run_result + @doc """ + Run the passes on the operation. + + Note that it can be more expensive than a C/C++ implementation because ENIF Thread will be created to run the CAPI. + """ def run( %__MODULE__{op: op} = composer, opts \\ @run_default_opts @@ -207,7 +216,12 @@ defmodule Beaver.Composer do txt |> Logger.info() end - status = mlirPassManagerRunOnOp(pm, MLIR.Operation.from_module(op)) + :ok = beaver_raw_run_pm_on_op_async(pm.ref, MLIR.Operation.from_module(op).ref) + + {status, diagnostics} = + receive do + ret -> Beaver.Native.check!(ret) + end if print do mlirContextEnableMultithreading(ctx, true) @@ -218,7 +232,7 @@ defmodule Beaver.Composer do if MLIR.LogicalResult.success?(status) do {:ok, op} else - {:error, "Unexpected failure running passes"} + {:error, diagnostics} end end diff --git a/lib/beaver/mlir/capi.ex b/lib/beaver/mlir/capi.ex index abf5c200..2f8016cf 100644 --- a/lib/beaver/mlir/capi.ex +++ b/lib/beaver/mlir/capi.ex @@ -48,6 +48,7 @@ defmodule Beaver.MLIR.CAPI do ), do: :erlang.nif_error(:not_loaded) + def beaver_raw_run_pm_on_op_async(_pm, _op), do: :erlang.nif_error(:not_loaded) def beaver_raw_logical_mutex_token_signal_success(_), do: :erlang.nif_error(:not_loaded) def beaver_raw_logical_mutex_token_signal_failure(_), do: :erlang.nif_error(:not_loaded) def beaver_raw_registered_ops(_ctx), do: :erlang.nif_error(:not_loaded) diff --git a/lib/beaver/native.ex b/lib/beaver/native.ex index 61b51208..4415d7ae 100644 --- a/lib/beaver/native.ex +++ b/lib/beaver/native.ex @@ -79,7 +79,7 @@ defmodule Beaver.Native do end, for {severity_i, loc_ref, note, num} <- diagnostics do {Beaver.MLIR.Diagnostic.severity(severity_i), %Beaver.MLIR.Location{ref: loc_ref}, - Enum.join(note), num} + to_string(note), num} end} ret -> diff --git a/native/gen_wrapper.exs b/native/gen_wrapper.exs index d3e02c2a..7ca2bbaf 100644 --- a/native/gen_wrapper.exs +++ b/native/gen_wrapper.exs @@ -5,12 +5,10 @@ defmodule Updater do System.argv() |> Enum.chunk_every(2) end - @dirty_io ~w{mlirPassManagerRunOnOp} - |> Enum.map(&String.to_atom/1) @with_diagnostics ~w{mlirAttributeParseGet mlirOperationVerify mlirTypeParseGet mlirModuleCreateParse beaverModuleApplyPatternsAndFoldGreedily mlirExecutionEngineCreate} |> Enum.map(&String.to_atom/1) - @regular_and_dirty ~w{mlirExecutionEngineInvokePacked} - |> Enum.map(&String.to_atom/1) + @normal_and_dirty ~w{mlirExecutionEngineInvokePacked} + |> Enum.map(&String.to_atom/1) defp dirty_io(name), do: "#{name}_dirty_io" |> String.to_atom() defp dirty_cpu(name), do: "#{name}_dirty_cpu" |> String.to_atom() @@ -22,7 +20,7 @@ defmodule Updater do name in @with_diagnostics or String.ends_with?(Atom.to_string(name), "GetChecked") -> [{name, arity}, {with_diagnostics(name), arity + 1}] - name in @regular_and_dirty -> + name in @normal_and_dirty -> [{name, arity}, {dirty_io(name), arity}, {dirty_cpu(name), arity}] true -> @@ -49,10 +47,7 @@ defmodule Updater do name in @with_diagnostics or String.ends_with?(Atom.to_string(name), "GetChecked") -> [~s{N(K, c, "#{name}"),}, ~s{diagnostic.WithDiagnosticsNIF(K, c, "#{name}"),}] - name in @dirty_io -> - ~s{D_CPU(K, c, "#{name}", null),} - - name in @regular_and_dirty -> + name in @normal_and_dirty -> [ ~s{N(K, c, "#{name}"),}, ~s{D_IO(K, c, "#{name}", "#{dirty_io(name)}"),}, diff --git a/native/include/mlir-c/Beaver/Context.h b/native/include/mlir-c/Beaver/Context.h index 716d4ab9..d59159c2 100644 --- a/native/include/mlir-c/Beaver/Context.h +++ b/native/include/mlir-c/Beaver/Context.h @@ -11,6 +11,8 @@ MLIR_CAPI_EXPORTED void beaverContextEnterMultiThreadedExecution(MlirContext context); MLIR_CAPI_EXPORTED void beaverContextExitMultiThreadedExecution(MlirContext context); +MLIR_CAPI_EXPORTED bool beaverContextAddWork(MlirContext context, + void (*task)(void *), void *arg); #ifdef __cplusplus } diff --git a/native/lib/CAPI/Beaver.cpp b/native/lib/CAPI/Beaver.cpp index 8d5e166f..cdbfd6e2 100644 --- a/native/lib/CAPI/Beaver.cpp +++ b/native/lib/CAPI/Beaver.cpp @@ -5,6 +5,7 @@ #include "mlir/Dialect/IRDL/IRDLLoading.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/IR/ExtensibleDialect.h" +#include "llvm/Support/ThreadPool.h" using namespace mlir; @@ -277,3 +278,13 @@ MLIR_CAPI_EXPORTED MlirLogicalResult beaverModuleApplyPatternsAndFoldGreedily( MlirModule module, MlirFrozenRewritePatternSet patterns) { return mlirApplyPatternsAndFoldGreedily(module, patterns, {}); } + +MLIR_CAPI_EXPORTED bool beaverContextAddWork(MlirContext context, + void (*task)(void *), void *arg) { + if (unwrap(context)->isMultithreadingEnabled()) { + unwrap(context)->getThreadPool().async([task, arg]() { task(arg); }); + return true; + } else { + return false; + } +} diff --git a/native/src/diagnostic.zig b/native/src/diagnostic.zig index 858254f8..8eac4579 100644 --- a/native/src/diagnostic.zig +++ b/native/src/diagnostic.zig @@ -10,38 +10,8 @@ const kinda = @import("kinda"); const result = @import("kinda").result; const StringRefCollector = @import("string_ref.zig").StringRefCollector; -const BeaverDiagnostic = struct { - handler: ?beam.pid = null, - const Error = error{ - EnvAllocFailure, - MsgSendFailure, - }; - pub fn sendDiagnostic(diagnostic: c.MlirDiagnostic, userData: ?*anyopaque) !mlir_capi.LogicalResult.T { - const ud: ?*@This() = @ptrCast(@alignCast(userData)); - const h = ud.?.*.handler.?; - const env = e.enif_alloc_env() orelse return Error.EnvAllocFailure; - var tuple_slice: []beam.term = try beam.allocator.alloc(beam.term, 3); - defer beam.allocator.free(tuple_slice); - tuple_slice[0] = beam.make_atom(env, "diagnostic"); - tuple_slice[1] = try mlir_capi.Diagnostic.resource.make(env, diagnostic); - var token = MutexToken{}; - tuple_slice[2] = try beam.make_ptr_resource_wrapped(env, &token); - if (!beam.send(env, h, beam.make_tuple(env, tuple_slice))) { - return Error.MsgSendFailure; - } - return token.wait_logical(); - } - pub fn deleteUserData(userData: ?*anyopaque) callconv(.C) void { - const ud: ?*@This() = @ptrCast(@alignCast(userData)); - beam.allocator.destroy(ud.?); - } - pub fn errorHandler(diagnostic: c.MlirDiagnostic, userData: ?*anyopaque) callconv(.C) mlir_capi.LogicalResult.T { - return sendDiagnostic(diagnostic, userData) catch return c.mlirLogicalResultFailure(); - } -}; - // collect diagnostic as {severity, loc, message, num_notes} -const DiagnosticAggregator = struct { +pub const DiagnosticAggregator = struct { const Container = std.ArrayList(beam.term); env: beam.env, container: Container = undefined, @@ -77,13 +47,13 @@ const DiagnosticAggregator = struct { const ud: ?*@This() = @ptrCast(@alignCast(userData)); beam.allocator.destroy(ud.?); } - fn init(env: beam.env) !*@This() { + pub fn init(env: beam.env) !*@This() { var userData = try beam.allocator.create(DiagnosticAggregator); userData.env = env; userData.container = Container.init(beam.allocator); return userData; } - fn collect_and_destroy(this: *@This()) !beam.term { + pub fn collect_and_destroy(this: *@This()) !beam.term { defer this.container.deinit(); return beam.make_term_list(this.env, this.container.items); } @@ -109,11 +79,4 @@ pub fn WithDiagnosticsNIF(comptime Kinds: anytype, c_: anytype, comptime name: a return result.nif(nifPrefix ++ name ++ nifSuffix, 1 + bang.arity, AttachAndRun.with_diagnostics).entry; } -fn do_attach(env: beam.env, _: c_int, args: [*c]const beam.term) !beam.term { - var userData: ?*BeaverDiagnostic = try beam.allocator.create(BeaverDiagnostic); - userData.?.handler = beam.get_pid(env, args[1]) catch null; - const id = c.mlirContextAttachDiagnosticHandler(try mlir_capi.Context.resource.fetch(env, args[0]), BeaverDiagnostic.errorHandler, userData, BeaverDiagnostic.deleteUserData); - return try mlir_capi.DiagnosticHandlerID.resource.make(env, id); -} - -pub const nifs = .{result.nif("beaver_raw_context_attach_diagnostic_handler", 2, do_attach).entry}; +pub const nifs = .{}; diff --git a/native/src/pass.zig b/native/src/pass.zig index 679d1fee..b8659f89 100644 --- a/native/src/pass.zig +++ b/native/src/pass.zig @@ -4,8 +4,10 @@ const mlir_capi = @import("mlir_capi.zig"); pub const c = @import("prelude.zig"); const e = @import("erl_nif"); const debug_print = @import("std").debug.print; +const kinda = @import("kinda"); const result = @import("kinda").result; const diagnostic = @import("diagnostic.zig"); +const DiagnosticAggregator = diagnostic.DiagnosticAggregator; const Token = @import("logical_mutex.zig").Token; const BeaverPass = extern struct { @@ -36,7 +38,7 @@ const BeaverPass = extern struct { tuple_slice[1] = try mlir_capi.Operation.resource.make(env, op); var token = Token{}; tuple_slice[2] = try beam.make_ptr_resource_wrapped(env, &token); - if (!beam.send(env, handler, beam.make_tuple(env, tuple_slice))) { + if (!beam.send_advanced(env, handler, env, beam.make_tuple(env, tuple_slice))) { return Error.@"Fail to send message to pass server"; } if (c.beaverLogicalResultIsFailure(token.wait_logical())) return Error.@"Fail to run a pass implemented in Elixir"; @@ -82,6 +84,53 @@ pub fn do_create(env: beam.env, _: c_int, args: [*c]const beam.term) !beam.term return try mlir_capi.Pass.resource.make(env, ep); } +const WorkerError = error{ @"failed to add work", @"Fail to allocate BEAM environment", @"Fail to send message to pm caller" }; + +const PassManagerRunner = extern struct { + pid: beam.pid, + pm: mlir_capi.PassManager.T, + op: mlir_capi.Operation.T, + fn run_with_diagnostics(this: @This()) !void { + const env = e.enif_alloc_env() orelse return WorkerError.@"Fail to allocate BEAM environment"; + const userData = try DiagnosticAggregator.init(env); + const ctx = c.mlirOperationGetContext(this.op); + const id = c.mlirContextAttachDiagnosticHandler(ctx, DiagnosticAggregator.errorHandler, @ptrCast(@alignCast(userData)), DiagnosticAggregator.deleteUserData); + defer c.mlirContextDetachDiagnosticHandler(ctx, id); + const res = c.mlirPassManagerRunOnOp(this.pm, this.op); + var res_slice: []beam.term = try beam.allocator.alloc(beam.term, 2); + // we only use the return functionality of BangFunc here because we are not fetching resources here + const bang = kinda.BangFunc(c.K, c, "mlirPassManagerRunOnOp"); + res_slice[0] = try bang.make_return(env, res); + res_slice[1] = try DiagnosticAggregator.collect_and_destroy(userData); + defer beam.allocator.free(res_slice); + if (!beam.send_advanced(env, this.pid, env, beam.make_tuple(env, res_slice))) { + return WorkerError.@"Fail to send message to pm caller"; + } + } + fn run_and_send(worker: ?*anyopaque) callconv(.C) void { + const this: ?*@This() = @ptrCast(@alignCast(worker)); + defer beam.allocator.destroy(this.?); + run_with_diagnostics(this.?.*) catch @panic("Fail to run pass on operation"); + } +}; + +pub fn run_pm_on_op(env: beam.env, _: c_int, args: [*c]const beam.term) !beam.term { + const w = try beam.allocator.create(PassManagerRunner); + if (e.enif_self(env, &w.*.pid) == null) { + @panic("Fail to get self pid"); + } + w.*.pm = try mlir_capi.PassManager.resource.fetch(env, args[0]); + w.*.op = try mlir_capi.Operation.resource.fetch(env, args[1]); + const ctx = c.mlirOperationGetContext(w.op); + if (c.beaverContextAddWork(ctx, PassManagerRunner.run_and_send, @ptrCast(@constCast(w)))) { + return beam.make_ok(env); + } else { + beam.allocator.destroy(w); + return WorkerError.@"failed to add work"; + } +} + pub const nifs = .{ result.nif("beaver_raw_create_mlir_pass", 5, do_create).entry, + result.nif("beaver_raw_run_pm_on_op_async", 2, run_pm_on_op).entry, }; diff --git a/test/pass_test.exs b/test/pass_test.exs index 29b77bd9..e8812182 100644 --- a/test/pass_test.exs +++ b/test/pass_test.exs @@ -35,7 +35,7 @@ defmodule PassTest do test "exception in run/1", %{ctx: ctx} do ir = example_ir(ctx) - assert_raise RuntimeError, ~r"Unexpected failure running passes", fn -> + assert_raise ArgumentError, ~r"Fail to run a pass implemented in Elixir", fn -> ir |> Beaver.Composer.nested("func.func", [ PassRaisingException diff --git a/test/test_helper.exs b/test/test_helper.exs index 8a15c075..b9b204e1 100644 --- a/test/test_helper.exs +++ b/test/test_helper.exs @@ -3,8 +3,7 @@ ExUnit.configure( stderr: true, cuda: :os.type() == {:unix, :darwin}, cuda_runtime: :os.type() == {:unix, :darwin} or System.get_env("CI") == "true" - ], - timeout: 10_000 + ] ) ExUnit.start()