Skip to content

Commit

Permalink
Remove PassRunner
Browse files Browse the repository at this point in the history
  • Loading branch information
jackalcooper committed Jan 3, 2025
1 parent 6d4a4e6 commit f117658
Show file tree
Hide file tree
Showing 9 changed files with 83 additions and 97 deletions.
2 changes: 1 addition & 1 deletion lib/beaver/application.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ defmodule Beaver.Application do
require Logger
@moduledoc false
def start(_type, _args) do
[Beaver.MLIR.Pass.global_registrar_child_specs(), Beaver.Composer.pass_runner_child_specs()]
[Beaver.MLIR.Pass.global_registrar_child_specs()]
|> List.flatten()
|> Supervisor.start_link(strategy: :one_for_one)
end
Expand Down
73 changes: 39 additions & 34 deletions lib/beaver/composer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ defmodule Beaver.Composer do
require Logger

@moduledoc """
This module provide functions to compose passes.
This module provide functions to compose and run passes.
"""
@enforce_keys [:op]
defstruct passes: [], op: nil
Expand Down Expand Up @@ -38,36 +38,22 @@ defmodule Beaver.Composer do
end

# Create an external pass.
defp do_create_pass(pid, argument, description, op) do
defp do_create_pass(pid, argument, description, op, run) do
argument_ref = MLIR.StringRef.create(argument).ref

MLIR.CAPI.beaver_raw_create_mlir_pass(
argument_ref,
argument_ref,
MLIR.StringRef.create(description).ref,
MLIR.StringRef.create(op).ref,
pid
pid,
run
)
|> then(&%MLIR.Pass{ref: &1, handler: pid})
end

@supervisor __MODULE__.DynamicSupervisor
@registry __MODULE__.Registry
defp create_pass(argument, desc, op, run) do
spec =
{Beaver.PassRunner, [run, name: {:via, Registry, {@registry, :"#{argument}-#{op}"}}]}

case DynamicSupervisor.start_child(@supervisor, spec) do
{:ok, pid} ->
pid

{:error, {:already_started, pid}} ->
pid

{:error, e} ->
raise Application.format_error(e)
end
|> do_create_pass(argument, desc, op)
do_create_pass(self(), argument, desc, op, run)
end

def create_pass(%MLIR.Pass{} = pass) do
Expand Down Expand Up @@ -163,6 +149,31 @@ defmodule Beaver.Composer do
@spec run!(composer) :: operation
@spec run!(composer, [run_option]) :: operation

defp dispatch_pass_action() do
receive do
{:run, op_ref, token_ref, run} ->
spawn_link(fn ->
try do
run.(%MLIR.Operation{ref: op_ref})
MLIR.CAPI.beaver_raw_logical_mutex_token_signal_success(token_ref)
rescue
exception ->
MLIR.CAPI.beaver_raw_logical_mutex_token_signal_failure(token_ref)
Logger.error(Exception.format(:error, exception, __STACKTRACE__))
end
end)

dispatch_pass_action()

{{:kind, MLIR.LogicalResult, _}, diagnostics} = ret when is_list(diagnostics) ->
Beaver.Native.check!(ret)

other ->
Logger.error("Unexpected message: #{inspect(other)}")
dispatch_pass_action()
end
end

def run!(
composer,
opts \\ @run_default_opts
Expand Down Expand Up @@ -215,16 +226,7 @@ defmodule Beaver.Composer do
txt |> Logger.info()
end

{status, diagnostics} =
case beaver_raw_run_pm_on_op_async(pm.ref, MLIR.Operation.from_module(op).ref) do
:ok ->
receive do
ret -> Beaver.Native.check!(ret)
end

ret ->
Beaver.Native.check!(ret)
end
{status, diagnostics} = run_pm_async(pm, op)

if print do
mlirContextEnableMultithreading(ctx, true)
Expand All @@ -240,10 +242,13 @@ defmodule Beaver.Composer do
end

@doc false
def pass_runner_child_specs() do
[
{DynamicSupervisor, name: @supervisor, strategy: :one_for_one},
{Registry, keys: :unique, name: @registry}
]
def run_pm_async(%MLIR.PassManager{ref: pm_ref}, op) do
case beaver_raw_run_pm_on_op_async(pm_ref, MLIR.Operation.from_module(op).ref) do
:ok ->
dispatch_pass_action()

ret ->
Beaver.Native.check!(ret)
end
end
end
3 changes: 2 additions & 1 deletion lib/beaver/mlir/capi.ex
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ defmodule Beaver.MLIR.CAPI do
_argument,
_description,
_op_name,
_handler
_handler,
_run
),
do: :erlang.nif_error(:not_loaded)

Expand Down
32 changes: 0 additions & 32 deletions lib/beaver/pass_runner.ex

This file was deleted.

2 changes: 1 addition & 1 deletion native/src/diagnostic.zig
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ pub fn call_with_diagnostics(env: beam.env, ctx: mlir_capi.Context.T, f: anytype
const id = c.mlirContextAttachDiagnosticHandler(ctx, DiagnosticAggregator.errorHandler, @ptrCast(@alignCast(userData)), DiagnosticAggregator.deleteUserData);
defer c.mlirContextDetachDiagnosticHandler(ctx, id);
var res_slice: []beam.term = try beam.allocator.alloc(beam.term, 2);
defer beam.allocator.free(res_slice);
res_slice[0] = try @call(.auto, f, args);
res_slice[1] = try DiagnosticAggregator.to_list(userData);
defer beam.allocator.free(res_slice);
return beam.make_tuple(env, res_slice);
}

Expand Down
8 changes: 4 additions & 4 deletions native/src/logical_mutex.zig
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,21 @@ pub const Token = struct {
self.done = true;
self.cond.signal();
}
pub fn pass_token_signal_logical_success(env: beam.env, _: c_int, args: [*c]const beam.term) !beam.term {
pub fn signal_logical_success(env: beam.env, _: c_int, args: [*c]const beam.term) !beam.term {
var token = try beam.fetch_ptr_resource_wrapped(@This(), env, args[0]);
token.signal(true);
return beam.make_ok(env);
}
pub fn pass_token_signal_logical_failure(env: beam.env, _: c_int, args: [*c]const beam.term) !beam.term {
pub fn signal_logical_failure(env: beam.env, _: c_int, args: [*c]const beam.term) !beam.term {
var token = try beam.fetch_ptr_resource_wrapped(@This(), env, args[0]);
token.signal(false);
return beam.make_ok(env);
}
};

pub const nifs = .{
result.nif("beaver_raw_logical_mutex_token_signal_success", 1, Token.pass_token_signal_logical_success).entry,
result.nif("beaver_raw_logical_mutex_token_signal_failure", 1, Token.pass_token_signal_logical_failure).entry,
result.nif("beaver_raw_logical_mutex_token_signal_success", 1, Token.signal_logical_success).entry,
result.nif("beaver_raw_logical_mutex_token_signal_failure", 1, Token.signal_logical_failure).entry,
};
pub fn open_all(env: beam.env) void {
beam.open_resource_wrapped(env, Token);
Expand Down
48 changes: 27 additions & 21 deletions native/src/pass.zig
Original file line number Diff line number Diff line change
Expand Up @@ -11,59 +11,65 @@ const Token = @import("logical_mutex.zig").Token;

const BeaverPass = extern struct {
handler: beam.pid,
env: beam.env,
run_fn: beam.term,
fn construct(_: ?*anyopaque) callconv(.C) void {}
fn destruct(userData: ?*anyopaque) callconv(.C) void {
const ptr: *@This() = @ptrCast(@alignCast(userData));
beam.allocator.destroy(ptr);
const this: *@This() = @ptrCast(@alignCast(userData));
e.enif_free_env(this.env);
beam.allocator.destroy(this);
}
fn initialize(_: mlir_capi.Context.T, _: ?*anyopaque) callconv(.C) mlir_capi.LogicalResult.T {
fn initialize(_: mlir_capi.Context.T, userData: ?*anyopaque) callconv(.C) mlir_capi.LogicalResult.T {
const this: *@This() = @ptrCast(@alignCast(userData));
this.*.env = e.enif_alloc_env() orelse return c.mlirLogicalResultFailure();
return c.mlirLogicalResultSuccess();
}
fn clone(userData: ?*anyopaque) callconv(.C) ?*anyopaque {
const old: *@This() = @ptrCast(@alignCast(userData));
const new = beam.allocator.create(@This()) catch unreachable;
new.* = old.*;
new.*.env = e.enif_alloc_env();
new.*.run_fn = e.enif_make_copy(new.*.env, old.run_fn);
return new;
}
const Error = error{ @"Fail to allocate BEAM environment", @"Fail to send message to pass server", @"Fail to run a pass implemented in Elixir" };
fn do_run(op: mlir_capi.Operation.T, userData: ?*anyopaque) !void {
const ud: *@This() = @ptrCast(@alignCast(userData));
const Error = error{ @"Fail to allocate BEAM environment", @"Fail to send message to pass server", @"Fail to run a pass implemented in Elixir", @"External pass must be run on non-scheduler thread to prevent deadlock" };
fn do_run(op: mlir_capi.Operation.T, this: *@This()) !void {
if (e.enif_thread_type() != e.ERL_NIF_THR_UNDEFINED) {
return Error.@"External pass must be run on non-scheduler thread to prevent deadlock";
}
const env = e.enif_alloc_env() orelse return Error.@"Fail to allocate BEAM environment";
defer e.enif_clear_env(env);
const handler = ud.*.handler;
var tuple_slice: []beam.term = try beam.allocator.alloc(beam.term, 3);
defer beam.allocator.free(tuple_slice);
tuple_slice[0] = beam.make_atom(env, "run");
tuple_slice[1] = try mlir_capi.Operation.resource.make(env, op);
var token = Token{};
tuple_slice[2] = try beam.make_ptr_resource_wrapped(env, &token);
if (!beam.send_advanced(env, handler, env, beam.make_tuple(env, tuple_slice))) {
const tuple_slice: []const beam.term = &.{ beam.make_atom(env, "run"), try mlir_capi.Operation.resource.make(env, op), try beam.make_ptr_resource_wrapped(env, &token), e.enif_make_copy(env, this.run_fn) };
const msg = beam.make_tuple(env, @constCast(tuple_slice));
if (!beam.send_advanced(env, this.*.handler, env, msg)) {
return Error.@"Fail to send message to pass server";
}
if (c.beaverLogicalResultIsFailure(token.wait_logical())) return Error.@"Fail to run a pass implemented in Elixir";
}
fn run(op: mlir_capi.Operation.T, pass: c.MlirExternalPass, userData: ?*anyopaque) callconv(.C) void {
if (do_run(op, userData)) |_| {} else |err| {
if (do_run(op, @ptrCast(@alignCast(userData)))) |_| {} else |err| {
c.mlirEmitError(c.mlirOperationGetLocation(op), @errorName(err));
c.mlirExternalPassSignalFailure(pass);
}
}
};

threadlocal var typeIDAllocator: ?mlir_capi.TypeIDAllocator.T = null;
pub fn do_create(env: beam.env, _: c_int, args: [*c]const beam.term) !beam.term {
const name = try mlir_capi.StringRef.resource.fetch(env, args[0]);
const argument = try mlir_capi.StringRef.resource.fetch(env, args[1]);
const description = try mlir_capi.StringRef.resource.fetch(env, args[2]);
const op_name = try mlir_capi.StringRef.resource.fetch(env, args[3]);
const handler: beam.pid = try beam.get_pid(env, args[4]);

const typeIDAllocator = c.mlirTypeIDAllocatorCreate();
defer c.mlirTypeIDAllocatorDestroy(typeIDAllocator);
const passID = c.mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
if (typeIDAllocator == null) {
typeIDAllocator = c.mlirTypeIDAllocatorCreate();
}
const passID = c.mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator.?);
const nDependentDialects = 0;
const dependentDialects = null;
const bp: *BeaverPass = try beam.allocator.create(BeaverPass);
bp.* = BeaverPass{ .handler = handler };
const bp_env = e.enif_alloc_env();
bp.* = BeaverPass{ .handler = handler, .env = bp_env, .run_fn = e.enif_make_copy(bp_env, args[5]) };
// use this function to avoid ABI issue
const ep = c.beaverPassCreate(
BeaverPass.construct,
Expand Down Expand Up @@ -125,6 +131,6 @@ pub fn run_pm_on_op(env: beam.env, _: c_int, args: [*c]const beam.term) !beam.te
}

pub const nifs = .{
result.nif("beaver_raw_create_mlir_pass", 5, do_create).entry,
result.nif("beaver_raw_create_mlir_pass", 6, do_create).entry,
result.nif("beaver_raw_run_pm_on_op_async", 2, run_pm_on_op).entry,
};
6 changes: 3 additions & 3 deletions test/capi_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ defmodule MlirTest do
pm = mlirPassManagerCreate(ctx)
mlirPassManagerAddOwnedPass(pm, external)
mlirPassManagerAddOwnedPass(pm, mlirCreateTransformsCSE())
success = mlirPassManagerRunOnOp(pm, MLIR.Operation.from_module(module))
{success, _} = Beaver.Composer.run_pm_async(pm, module)
assert Beaver.MLIR.LogicalResult.success?(success)
mlirPassManagerDestroy(pm)
mlirModuleDestroy(module)
Expand All @@ -205,7 +205,7 @@ defmodule MlirTest do
pm = mlirPassManagerCreate(ctx)
npm = mlirPassManagerGetNestedUnder(pm, MLIR.StringRef.create("func.func"))
mlirOpPassManagerAddOwnedPass(npm, external)
success = mlirPassManagerRunOnOp(pm, MLIR.Operation.from_module(module))
{success, _} = Beaver.Composer.run_pm_async(pm, module)
assert Beaver.MLIR.LogicalResult.success?(success)
mlirPassManagerDestroy(pm)
mlirModuleDestroy(module)
Expand All @@ -217,7 +217,7 @@ defmodule MlirTest do
pm = mlirPassManagerCreate(ctx)
npm = mlirPassManagerGetNestedUnder(pm, MLIR.StringRef.create("func.func"))
mlirOpPassManagerAddOwnedPass(npm, external)
success = mlirPassManagerRunOnOp(pm, MLIR.Operation.from_module(module))
{success, _} = Beaver.Composer.run_pm_async(pm, module)
assert Beaver.MLIR.LogicalResult.success?(success)
mlirPassManagerDestroy(pm)
mlirModuleDestroy(module)
Expand Down
6 changes: 6 additions & 0 deletions test/pass_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,10 @@ defmodule PassTest do
|> Beaver.Composer.run!()
end
end

test "parallel processing func.func", %{ctx: ctx} do
Beaver.Dummy.gigantic(ctx)
|> Beaver.Composer.nested("func.func", {"DoNothingHere", "func.func", fn _ -> :ok end})
|> Beaver.Composer.run!()
end
end

0 comments on commit f117658

Please sign in to comment.