Skip to content

Commit

Permalink
Remove PassRunner (#385)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackalcooper authored Jan 3, 2025
1 parent 6d4a4e6 commit 4d89373
Show file tree
Hide file tree
Showing 10 changed files with 120 additions and 151 deletions.
2 changes: 1 addition & 1 deletion lib/beaver/application.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ defmodule Beaver.Application do
require Logger
@moduledoc false
def start(_type, _args) do
[Beaver.MLIR.Pass.global_registrar_child_specs(), Beaver.Composer.pass_runner_child_specs()]
[Beaver.MLIR.Pass.global_registrar_child_specs()]
|> List.flatten()
|> Supervisor.start_link(strategy: :one_for_one)
end
Expand Down
75 changes: 40 additions & 35 deletions lib/beaver/composer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ defmodule Beaver.Composer do
require Logger

@moduledoc """
This module provide functions to compose passes.
This module provide functions to compose and run passes.
"""
@enforce_keys [:op]
defstruct passes: [], op: nil
Expand Down Expand Up @@ -38,36 +38,22 @@ defmodule Beaver.Composer do
end

# Create an external pass.
defp do_create_pass(pid, argument, description, op) do
defp do_create_pass(pid, argument, description, op, run) do
argument_ref = MLIR.StringRef.create(argument).ref

MLIR.CAPI.beaver_raw_create_mlir_pass(
argument_ref,
argument_ref,
MLIR.StringRef.create(description).ref,
MLIR.StringRef.create(op).ref,
pid
pid,
run
)
|> then(&%MLIR.Pass{ref: &1, handler: pid})
|> Beaver.Native.check!()
end

@supervisor __MODULE__.DynamicSupervisor
@registry __MODULE__.Registry
defp create_pass(argument, desc, op, run) do
spec =
{Beaver.PassRunner, [run, name: {:via, Registry, {@registry, :"#{argument}-#{op}"}}]}

case DynamicSupervisor.start_child(@supervisor, spec) do
{:ok, pid} ->
pid

{:error, {:already_started, pid}} ->
pid

{:error, e} ->
raise Application.format_error(e)
end
|> do_create_pass(argument, desc, op)
do_create_pass(self(), argument, desc, op, run)
end

def create_pass(%MLIR.Pass{} = pass) do
Expand Down Expand Up @@ -163,6 +149,31 @@ defmodule Beaver.Composer do
@spec run!(composer) :: operation
@spec run!(composer, [run_option]) :: operation

defp dispatch_pass_action() do
receive do
{:run, op_ref, token_ref, run} ->
spawn_link(fn ->
try do
run.(%MLIR.Operation{ref: op_ref})
MLIR.CAPI.beaver_raw_logical_mutex_token_signal_success(token_ref)
rescue
exception ->
MLIR.CAPI.beaver_raw_logical_mutex_token_signal_failure(token_ref)
Logger.error(Exception.format(:error, exception, __STACKTRACE__))
end
end)

dispatch_pass_action()

{{:kind, MLIR.LogicalResult, _}, diagnostics} = ret when is_list(diagnostics) ->
Beaver.Native.check!(ret)

other ->
Logger.error("Unexpected message: #{inspect(other)}")
dispatch_pass_action()
end
end

def run!(
composer,
opts \\ @run_default_opts
Expand Down Expand Up @@ -215,16 +226,7 @@ defmodule Beaver.Composer do
txt |> Logger.info()
end

{status, diagnostics} =
case beaver_raw_run_pm_on_op_async(pm.ref, MLIR.Operation.from_module(op).ref) do
:ok ->
receive do
ret -> Beaver.Native.check!(ret)
end

ret ->
Beaver.Native.check!(ret)
end
{status, diagnostics} = run_pm_async(pm, op)

if print do
mlirContextEnableMultithreading(ctx, true)
Expand All @@ -240,10 +242,13 @@ defmodule Beaver.Composer do
end

@doc false
def pass_runner_child_specs() do
[
{DynamicSupervisor, name: @supervisor, strategy: :one_for_one},
{Registry, keys: :unique, name: @registry}
]
def run_pm_async(%MLIR.PassManager{ref: pm_ref}, op) do
case beaver_raw_run_pm_on_op_async(pm_ref, MLIR.Operation.from_module(op).ref) do
:ok ->
dispatch_pass_action()

ret ->
Beaver.Native.check!(ret)
end
end
end
3 changes: 2 additions & 1 deletion lib/beaver/mlir/capi.ex
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ defmodule Beaver.MLIR.CAPI do
_argument,
_description,
_op_name,
_handler
_handler,
_run
),
do: :erlang.nif_error(:not_loaded)

Expand Down
2 changes: 1 addition & 1 deletion lib/beaver/mlir/pass.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ defmodule Beaver.MLIR.Pass do
This module defines functions working with MLIR #{__MODULE__ |> Module.split() |> List.last()}.
"""
alias Beaver.MLIR
use Kinda.ResourceKind, fields: [handler: nil], forward_module: Beaver.Native
use Kinda.ResourceKind, forward_module: Beaver.Native
@callback run(MLIR.Operation.t()) :: any()

defmacro __using__(opts) do
Expand Down
32 changes: 0 additions & 32 deletions lib/beaver/pass_runner.ex

This file was deleted.

2 changes: 1 addition & 1 deletion native/src/diagnostic.zig
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ pub fn call_with_diagnostics(env: beam.env, ctx: mlir_capi.Context.T, f: anytype
const id = c.mlirContextAttachDiagnosticHandler(ctx, DiagnosticAggregator.errorHandler, @ptrCast(@alignCast(userData)), DiagnosticAggregator.deleteUserData);
defer c.mlirContextDetachDiagnosticHandler(ctx, id);
var res_slice: []beam.term = try beam.allocator.alloc(beam.term, 2);
defer beam.allocator.free(res_slice);
res_slice[0] = try @call(.auto, f, args);
res_slice[1] = try DiagnosticAggregator.to_list(userData);
defer beam.allocator.free(res_slice);
return beam.make_tuple(env, res_slice);
}

Expand Down
8 changes: 4 additions & 4 deletions native/src/logical_mutex.zig
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,21 @@ pub const Token = struct {
self.done = true;
self.cond.signal();
}
pub fn pass_token_signal_logical_success(env: beam.env, _: c_int, args: [*c]const beam.term) !beam.term {
pub fn signal_logical_success(env: beam.env, _: c_int, args: [*c]const beam.term) !beam.term {
var token = try beam.fetch_ptr_resource_wrapped(@This(), env, args[0]);
token.signal(true);
return beam.make_ok(env);
}
pub fn pass_token_signal_logical_failure(env: beam.env, _: c_int, args: [*c]const beam.term) !beam.term {
pub fn signal_logical_failure(env: beam.env, _: c_int, args: [*c]const beam.term) !beam.term {
var token = try beam.fetch_ptr_resource_wrapped(@This(), env, args[0]);
token.signal(false);
return beam.make_ok(env);
}
};

pub const nifs = .{
result.nif("beaver_raw_logical_mutex_token_signal_success", 1, Token.pass_token_signal_logical_success).entry,
result.nif("beaver_raw_logical_mutex_token_signal_failure", 1, Token.pass_token_signal_logical_failure).entry,
result.nif("beaver_raw_logical_mutex_token_signal_success", 1, Token.signal_logical_success).entry,
result.nif("beaver_raw_logical_mutex_token_signal_failure", 1, Token.signal_logical_failure).entry,
};
pub fn open_all(env: beam.env) void {
beam.open_resource_wrapped(env, Token);
Expand Down
Loading

0 comments on commit 4d89373

Please sign in to comment.