Skip to content

Commit

Permalink
Run diagnostic callback on dirty scheduler (#376)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackalcooper authored Dec 22, 2024
1 parent b637a32 commit b928c71
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 30 deletions.
5 changes: 3 additions & 2 deletions lib/beaver/capturer.ex
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
defmodule Beaver.Capturer do
alias Beaver.MLIR
use GenServer
require Logger

@moduledoc """
`GenServer` to run MLIR diagnostic error handler.
Expand Down Expand Up @@ -52,9 +53,9 @@ defmodule Beaver.Capturer do
MLIR.CAPI.beaver_raw_logical_mutex_token_signal_success(token_ref)
end)
rescue
e ->
exception ->
MLIR.CAPI.beaver_raw_logical_mutex_token_signal_failure(token_ref)
reraise e, __STACKTRACE__
Logger.error("#{Exception.format(:error, exception, __STACKTRACE__)}")
end
|> then(&{:noreply, %__MODULE__{state | return: &1}})
end
Expand Down
40 changes: 19 additions & 21 deletions lib/beaver/mlir/execution_engine.ex
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@ defmodule Beaver.MLIR.ExecutionEngine do
Composer.run!(composer_or_op) |> create!()
end

@type dirty :: nil | :io_bound | :cpu_bound

@type opt_level :: 0 | 1 | 2 | 3
@type shared_lib_path :: String.t()
@type object_dump :: boolean()
@type opts :: [
{:shared_lib_paths, [shared_lib_path]},
{:opt_level, opt_level},
{:object_dump, object_dump}
{:object_dump, object_dump},
{:dirty, dirty}
]
@spec create!(MLIR.Module.t(), opts()) :: t()
def create!(module, opts \\ []) do
Expand All @@ -37,12 +40,16 @@ defmodule Beaver.MLIR.ExecutionEngine do
require MLIR.Context

jit =
mlirExecutionEngineCreate(
module,
opt_level,
length(shared_lib_paths),
shared_lib_paths_ptr,
object_dump
Beaver.Native.apply_dirty(
:mlirExecutionEngineCreate,
[
module,
opt_level,
length(shared_lib_paths),
shared_lib_paths_ptr,
object_dump
],
opts[:dirty]
)

if MLIR.null?(jit) do
Expand All @@ -55,7 +62,6 @@ defmodule Beaver.MLIR.ExecutionEngine do
@doc """
invoke a function by symbol name.
"""
@type dirty :: nil | :io_bound | :cpu_bound
@type invoke_opts :: [
{:dirty, dirty}
]
Expand All @@ -71,22 +77,14 @@ defmodule Beaver.MLIR.ExecutionEngine do
end
|> List.wrap()

case opts[:dirty] do
:io_bound ->
:mlirExecutionEngineInvokePacked_dirty_io

:cpu_bound ->
:mlirExecutionEngineInvokePacked_dirty_cpu

nil ->
:mlirExecutionEngineInvokePacked
end
|> then(
&apply(MLIR.CAPI, &1, [
Beaver.Native.apply_dirty(
:mlirExecutionEngineInvokePacked,
[
jit,
MLIR.StringRef.create(symbol),
Beaver.Native.array(arg_ptr_list ++ return_ptr, Beaver.Native.OpaquePtr, mut: true)
])
],
opts[:dirty]
)
|> then(
&if MLIR.LogicalResult.success?(&1) do
Expand Down
16 changes: 16 additions & 0 deletions lib/beaver/native.ex
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,20 @@ defmodule Beaver.Native do
def dump(%kind{ref: ref}) do
Beaver.Native.forward(kind, "dump", [ref])
end

def apply_dirty(fun, args, dirty_flag) do
f =
case dirty_flag do
:io_bound ->
:"#{fun}_dirty_io"

:cpu_bound ->
:"#{fun}_dirty_cpu"

nil ->
fun
end

apply(CAPI, f, args)
end
end
13 changes: 7 additions & 6 deletions native/gen_wrapper.exs
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@ defmodule Updater do
System.argv() |> Enum.chunk_every(2)
end

@io_only ~w{mlirPassManagerRunOnOp mlirOperationVerify mlirAttributeParseGet mlirTypeParseGet mlirModuleCreateParse}
|> Enum.map(&String.to_atom/1)
@regular_io_cpu ~w{mlirExecutionEngineInvokePacked} |> Enum.map(&String.to_atom/1)
@dirty_io ~w{mlirPassManagerRunOnOp mlirOperationVerify mlirAttributeParseGet mlirTypeParseGet mlirModuleCreateParse}
|> Enum.map(&String.to_atom/1)
@regular_and_dirty ~w{mlirExecutionEngineInvokePacked mlirExecutionEngineCreate}
|> Enum.map(&String.to_atom/1)

defp dirty_io(name), do: "#{name}_dirty_io" |> String.to_atom()
defp dirty_cpu(name), do: "#{name}_dirty_cpu" |> String.to_atom()

def gen(functions, :elixir) do
for {name, arity} <- functions do
if name in @regular_io_cpu do
if name in @regular_and_dirty do
[{name, arity}, {dirty_io(name), arity}, {dirty_cpu(name), arity}]
else
{name, arity}
Expand All @@ -37,10 +38,10 @@ defmodule Updater do
entries =
for {name, _arity} <- functions do
cond do
name in @io_only ->
name in @dirty_io ->
~s{D_CPU(K, c, "#{name}", null),}

name in @regular_io_cpu ->
name in @regular_and_dirty ->
[
~s{N(K, c, "#{name}"),},
~s{D_IO(K, c, "#{name}", "#{dirty_io(name)}"),},
Expand Down
41 changes: 40 additions & 1 deletion test/diagnostic_test.exs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
defmodule DiagnosticTest do
use Beaver.Case, async: true, diagnostic: :server
alias Beaver.MLIR.Attribute
alias Beaver.MLIR
use Beaver
alias MLIR.{Type, Attribute}
alias MLIR.Dialect.{Func, Builtin}
require Func

defmodule DiagnosticTestHelper do
def start_and_attach(ctx, cb) do
Expand Down Expand Up @@ -73,5 +76,41 @@ defmodule DiagnosticTest do

assert txt == @collected
end

defp unrealized_conversion_cast_f(ctx) do
import MLIR.Conversion

mlir ctx: ctx do
module do
Func.func some_func(
function_type: Type.function([Type.i64()], [Type.i32()]),
sym_name: MLIR.Attribute.string("f#{System.unique_integer([:positive])}")
) do
region do
block _(a >>> Type.i64()) do
v0 = Builtin.unrealized_conversion_cast(a) >>> Type.i32()
Func.return(v0) >>> []
end
end
end
end
end
|> convert_func_to_llvm()
|> Beaver.Composer.run!()
|> MLIR.ExecutionEngine.create!(dirty: :io_bound)
end

test "large invalid llvm ir", %{ctx: ctx} do
{_, d_str} =
MLIR.Context.with_diagnostics(
ctx,
fn ->
assert_raise RuntimeError, fn -> unrealized_conversion_cast_f(ctx) end
end,
fn d, _acc -> MLIR.to_string(d) end
)

assert d_str =~ "LLVM Translation failed for operation: builtin.unrealized_conversion_cast"
end
end
end

0 comments on commit b928c71

Please sign in to comment.