Skip to content

Commit

Permalink
beaver_raw_run_pm_on_op_async
Browse files Browse the repository at this point in the history
  • Loading branch information
jackalcooper committed Dec 30, 2024
1 parent b3bc8b6 commit 7f0e408
Show file tree
Hide file tree
Showing 10 changed files with 93 additions and 59 deletions.
22 changes: 18 additions & 4 deletions lib/beaver/composer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,23 @@ defmodule Beaver.Composer do
{:ok, op} ->
op

{:error, msg} ->
raise msg
{:error, diagnostics} ->
raise ArgumentError,
(for {_severity, loc, d, _num} <- diagnostics,
reduce: "Unexpected failure running passes" do
acc -> "#{acc}\n#{to_string(loc)}: #{d}"
end)
end
end

@spec run(composer) :: run_result
@spec run(composer, [run_option]) :: run_result

@doc """
Run the passes on the operation.
Note that it can be more expensive than a C/C++ implementation because ENIF Thread will be created to run the CAPI.
"""
def run(
%__MODULE__{op: op} = composer,
opts \\ @run_default_opts
Expand Down Expand Up @@ -207,7 +216,12 @@ defmodule Beaver.Composer do
txt |> Logger.info()
end

status = mlirPassManagerRunOnOp(pm, MLIR.Operation.from_module(op))
:ok = beaver_raw_run_pm_on_op_async(pm.ref, MLIR.Operation.from_module(op).ref)

{status, diagnostics} =
receive do
ret -> Beaver.Native.check!(ret)
end

if print do
mlirContextEnableMultithreading(ctx, true)
Expand All @@ -218,7 +232,7 @@ defmodule Beaver.Composer do
if MLIR.LogicalResult.success?(status) do
{:ok, op}
else
{:error, "Unexpected failure running passes"}
{:error, diagnostics}
end
end

Expand Down
1 change: 1 addition & 0 deletions lib/beaver/mlir/capi.ex
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ defmodule Beaver.MLIR.CAPI do
),
do: :erlang.nif_error(:not_loaded)

def beaver_raw_run_pm_on_op_async(_pm, _op), do: :erlang.nif_error(:not_loaded)
def beaver_raw_logical_mutex_token_signal_success(_), do: :erlang.nif_error(:not_loaded)
def beaver_raw_logical_mutex_token_signal_failure(_), do: :erlang.nif_error(:not_loaded)
def beaver_raw_registered_ops(_ctx), do: :erlang.nif_error(:not_loaded)
Expand Down
2 changes: 1 addition & 1 deletion lib/beaver/native.ex
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ defmodule Beaver.Native do
end,
for {severity_i, loc_ref, note, num} <- diagnostics do
{Beaver.MLIR.Diagnostic.severity(severity_i), %Beaver.MLIR.Location{ref: loc_ref},
Enum.join(note), num}
to_string(note), num}
end}

ret ->
Expand Down
13 changes: 4 additions & 9 deletions native/gen_wrapper.exs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@ defmodule Updater do
System.argv() |> Enum.chunk_every(2)
end

@dirty_io ~w{mlirPassManagerRunOnOp}
|> Enum.map(&String.to_atom/1)
@with_diagnostics ~w{mlirAttributeParseGet mlirOperationVerify mlirTypeParseGet mlirModuleCreateParse beaverModuleApplyPatternsAndFoldGreedily mlirExecutionEngineCreate}
|> Enum.map(&String.to_atom/1)
@regular_and_dirty ~w{mlirExecutionEngineInvokePacked}
|> Enum.map(&String.to_atom/1)
@normal_and_dirty ~w{mlirExecutionEngineInvokePacked}
|> Enum.map(&String.to_atom/1)

defp dirty_io(name), do: "#{name}_dirty_io" |> String.to_atom()
defp dirty_cpu(name), do: "#{name}_dirty_cpu" |> String.to_atom()
Expand All @@ -22,7 +20,7 @@ defmodule Updater do
name in @with_diagnostics or String.ends_with?(Atom.to_string(name), "GetChecked") ->
[{name, arity}, {with_diagnostics(name), arity + 1}]

name in @regular_and_dirty ->
name in @normal_and_dirty ->
[{name, arity}, {dirty_io(name), arity}, {dirty_cpu(name), arity}]

true ->
Expand All @@ -49,10 +47,7 @@ defmodule Updater do
name in @with_diagnostics or String.ends_with?(Atom.to_string(name), "GetChecked") ->
[~s{N(K, c, "#{name}"),}, ~s{diagnostic.WithDiagnosticsNIF(K, c, "#{name}"),}]

name in @dirty_io ->
~s{D_CPU(K, c, "#{name}", null),}

name in @regular_and_dirty ->
name in @normal_and_dirty ->
[
~s{N(K, c, "#{name}"),},
~s{D_IO(K, c, "#{name}", "#{dirty_io(name)}"),},
Expand Down
2 changes: 2 additions & 0 deletions native/include/mlir-c/Beaver/Context.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ MLIR_CAPI_EXPORTED void
beaverContextEnterMultiThreadedExecution(MlirContext context);
MLIR_CAPI_EXPORTED void
beaverContextExitMultiThreadedExecution(MlirContext context);
MLIR_CAPI_EXPORTED bool beaverContextAddWork(MlirContext context,
void (*task)(void *), void *arg);

#ifdef __cplusplus
}
Expand Down
11 changes: 11 additions & 0 deletions native/lib/CAPI/Beaver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "mlir/Dialect/IRDL/IRDLLoading.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/ExtensibleDialect.h"
#include "llvm/Support/ThreadPool.h"

using namespace mlir;

Expand Down Expand Up @@ -277,3 +278,13 @@ MLIR_CAPI_EXPORTED MlirLogicalResult beaverModuleApplyPatternsAndFoldGreedily(
MlirModule module, MlirFrozenRewritePatternSet patterns) {
return mlirApplyPatternsAndFoldGreedily(module, patterns, {});
}

MLIR_CAPI_EXPORTED bool beaverContextAddWork(MlirContext context,
void (*task)(void *), void *arg) {
if (unwrap(context)->isMultithreadingEnabled()) {
unwrap(context)->getThreadPool().async([task, arg]() { task(arg); });
return true;
} else {
return false;
}
}
45 changes: 4 additions & 41 deletions native/src/diagnostic.zig
Original file line number Diff line number Diff line change
Expand Up @@ -10,38 +10,8 @@ const kinda = @import("kinda");
const result = @import("kinda").result;
const StringRefCollector = @import("string_ref.zig").StringRefCollector;

const BeaverDiagnostic = struct {
handler: ?beam.pid = null,
const Error = error{
EnvAllocFailure,
MsgSendFailure,
};
pub fn sendDiagnostic(diagnostic: c.MlirDiagnostic, userData: ?*anyopaque) !mlir_capi.LogicalResult.T {
const ud: ?*@This() = @ptrCast(@alignCast(userData));
const h = ud.?.*.handler.?;
const env = e.enif_alloc_env() orelse return Error.EnvAllocFailure;
var tuple_slice: []beam.term = try beam.allocator.alloc(beam.term, 3);
defer beam.allocator.free(tuple_slice);
tuple_slice[0] = beam.make_atom(env, "diagnostic");
tuple_slice[1] = try mlir_capi.Diagnostic.resource.make(env, diagnostic);
var token = MutexToken{};
tuple_slice[2] = try beam.make_ptr_resource_wrapped(env, &token);
if (!beam.send(env, h, beam.make_tuple(env, tuple_slice))) {
return Error.MsgSendFailure;
}
return token.wait_logical();
}
pub fn deleteUserData(userData: ?*anyopaque) callconv(.C) void {
const ud: ?*@This() = @ptrCast(@alignCast(userData));
beam.allocator.destroy(ud.?);
}
pub fn errorHandler(diagnostic: c.MlirDiagnostic, userData: ?*anyopaque) callconv(.C) mlir_capi.LogicalResult.T {
return sendDiagnostic(diagnostic, userData) catch return c.mlirLogicalResultFailure();
}
};

// collect diagnostic as {severity, loc, message, num_notes}
const DiagnosticAggregator = struct {
pub const DiagnosticAggregator = struct {
const Container = std.ArrayList(beam.term);
env: beam.env,
container: Container = undefined,
Expand Down Expand Up @@ -77,13 +47,13 @@ const DiagnosticAggregator = struct {
const ud: ?*@This() = @ptrCast(@alignCast(userData));
beam.allocator.destroy(ud.?);
}
fn init(env: beam.env) !*@This() {
pub fn init(env: beam.env) !*@This() {
var userData = try beam.allocator.create(DiagnosticAggregator);
userData.env = env;
userData.container = Container.init(beam.allocator);
return userData;
}
fn collect_and_destroy(this: *@This()) !beam.term {
pub fn collect_and_destroy(this: *@This()) !beam.term {
defer this.container.deinit();
return beam.make_term_list(this.env, this.container.items);
}
Expand All @@ -109,11 +79,4 @@ pub fn WithDiagnosticsNIF(comptime Kinds: anytype, c_: anytype, comptime name: a
return result.nif(nifPrefix ++ name ++ nifSuffix, 1 + bang.arity, AttachAndRun.with_diagnostics).entry;
}

fn do_attach(env: beam.env, _: c_int, args: [*c]const beam.term) !beam.term {
var userData: ?*BeaverDiagnostic = try beam.allocator.create(BeaverDiagnostic);
userData.?.handler = beam.get_pid(env, args[1]) catch null;
const id = c.mlirContextAttachDiagnosticHandler(try mlir_capi.Context.resource.fetch(env, args[0]), BeaverDiagnostic.errorHandler, userData, BeaverDiagnostic.deleteUserData);
return try mlir_capi.DiagnosticHandlerID.resource.make(env, id);
}

pub const nifs = .{result.nif("beaver_raw_context_attach_diagnostic_handler", 2, do_attach).entry};
pub const nifs = .{};
51 changes: 50 additions & 1 deletion native/src/pass.zig
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ const mlir_capi = @import("mlir_capi.zig");
pub const c = @import("prelude.zig");
const e = @import("erl_nif");
const debug_print = @import("std").debug.print;
const kinda = @import("kinda");
const result = @import("kinda").result;
const diagnostic = @import("diagnostic.zig");
const DiagnosticAggregator = diagnostic.DiagnosticAggregator;
const Token = @import("logical_mutex.zig").Token;

const BeaverPass = extern struct {
Expand Down Expand Up @@ -36,7 +38,7 @@ const BeaverPass = extern struct {
tuple_slice[1] = try mlir_capi.Operation.resource.make(env, op);
var token = Token{};
tuple_slice[2] = try beam.make_ptr_resource_wrapped(env, &token);
if (!beam.send(env, handler, beam.make_tuple(env, tuple_slice))) {
if (!beam.send_advanced(env, handler, env, beam.make_tuple(env, tuple_slice))) {
return Error.@"Fail to send message to pass server";
}
if (c.beaverLogicalResultIsFailure(token.wait_logical())) return Error.@"Fail to run a pass implemented in Elixir";
Expand Down Expand Up @@ -82,6 +84,53 @@ pub fn do_create(env: beam.env, _: c_int, args: [*c]const beam.term) !beam.term
return try mlir_capi.Pass.resource.make(env, ep);
}

const WorkerError = error{ @"failed to add work", @"Fail to allocate BEAM environment", @"Fail to send message to pm caller" };

const PassManagerRunner = extern struct {
pid: beam.pid,
pm: mlir_capi.PassManager.T,
op: mlir_capi.Operation.T,
fn run_with_diagnostics(this: @This()) !void {
const env = e.enif_alloc_env() orelse return WorkerError.@"Fail to allocate BEAM environment";
const userData = try DiagnosticAggregator.init(env);
const ctx = c.mlirOperationGetContext(this.op);
const id = c.mlirContextAttachDiagnosticHandler(ctx, DiagnosticAggregator.errorHandler, @ptrCast(@alignCast(userData)), DiagnosticAggregator.deleteUserData);
defer c.mlirContextDetachDiagnosticHandler(ctx, id);
const res = c.mlirPassManagerRunOnOp(this.pm, this.op);
var res_slice: []beam.term = try beam.allocator.alloc(beam.term, 2);
// we only use the return functionality of BangFunc here because we are not fetching resources here
const bang = kinda.BangFunc(c.K, c, "mlirPassManagerRunOnOp");
res_slice[0] = try bang.make_return(env, res);
res_slice[1] = try DiagnosticAggregator.collect_and_destroy(userData);
defer beam.allocator.free(res_slice);
if (!beam.send_advanced(env, this.pid, env, beam.make_tuple(env, res_slice))) {
return WorkerError.@"Fail to send message to pm caller";
}
}
fn run_and_send(worker: ?*anyopaque) callconv(.C) void {
const this: ?*@This() = @ptrCast(@alignCast(worker));
defer beam.allocator.destroy(this.?);
run_with_diagnostics(this.?.*) catch @panic("Fail to run pass on operation");
}
};

pub fn run_pm_on_op(env: beam.env, _: c_int, args: [*c]const beam.term) !beam.term {
const w = try beam.allocator.create(PassManagerRunner);
if (e.enif_self(env, &w.*.pid) == null) {
@panic("Fail to get self pid");
}
w.*.pm = try mlir_capi.PassManager.resource.fetch(env, args[0]);
w.*.op = try mlir_capi.Operation.resource.fetch(env, args[1]);
const ctx = c.mlirOperationGetContext(w.op);
if (c.beaverContextAddWork(ctx, PassManagerRunner.run_and_send, @ptrCast(@constCast(w)))) {
return beam.make_ok(env);
} else {
beam.allocator.destroy(w);
return WorkerError.@"failed to add work";
}
}

pub const nifs = .{
result.nif("beaver_raw_create_mlir_pass", 5, do_create).entry,
result.nif("beaver_raw_run_pm_on_op_async", 2, run_pm_on_op).entry,
};
2 changes: 1 addition & 1 deletion test/pass_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ defmodule PassTest do
test "exception in run/1", %{ctx: ctx} do
ir = example_ir(ctx)

assert_raise RuntimeError, ~r"Unexpected failure running passes", fn ->
assert_raise ArgumentError, ~r"Fail to run a pass implemented in Elixir", fn ->
ir
|> Beaver.Composer.nested("func.func", [
PassRaisingException
Expand Down
3 changes: 1 addition & 2 deletions test/test_helper.exs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ ExUnit.configure(
stderr: true,
cuda: :os.type() == {:unix, :darwin},
cuda_runtime: :os.type() == {:unix, :darwin} or System.get_env("CI") == "true"
],
timeout: 10_000
]
)

ExUnit.start()

0 comments on commit 7f0e408

Please sign in to comment.