Skip to content

Commit

Permalink
feat: bind vars in auto_assert patterns in the current scope (#76)
Browse files Browse the repository at this point in the history
Closes #1

---------

Co-authored-by: Zach Allaun <[email protected]>
  • Loading branch information
Aleksei Matiushkin and zachallaun authored May 31, 2024
1 parent 15e5cd3 commit 58d9478
Show file tree
Hide file tree
Showing 7 changed files with 265 additions and 35 deletions.
6 changes: 0 additions & 6 deletions lib/mneme.ex
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,6 @@ defmodule Mneme do
auto_assert pid when is_pid(pid) <- self()
* Bindings created are only available inside guards, not outside the
assertion.
auto_assert pid when is_pid(pid) <- self()
pid # ERROR: pid is not bound
"""
@doc section: :assertion
defmacro _pattern <- _expression do
Expand Down
98 changes: 77 additions & 21 deletions lib/mneme/assertion.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ defmodule Mneme.Assertion do
alias __MODULE__
alias Mneme.Assertion.Pattern
alias Mneme.Assertion.PatternBuilder
alias Mneme.Utils

defstruct [
:kind,
Expand All @@ -15,7 +16,8 @@ defmodule Mneme.Assertion do
:patterns,
:pattern_idx,
:context,
:options
:options,
vars_bound_in_pattern: []
]

@type t :: %Assertion{
Expand All @@ -28,7 +30,8 @@ defmodule Mneme.Assertion do
patterns: [Pattern.t()],
pattern_idx: non_neg_integer(),
context: context,
options: map()
options: map(),
vars_bound_in_pattern: [macro_var]
}

@type kind ::
Expand All @@ -43,14 +46,16 @@ defmodule Mneme.Assertion do
module: module(),
test: atom(),
aliases: list(),
binding: list(),
binding: Code.binding(),
original_pattern: Macro.t() | nil
}

@type target :: :mneme | :ex_unit

@type macro_var :: {atom(), Macro.metadata(), atom()}

@doc false
def new({kind, _, args} = macro_ast, value, ctx, opts \\ Mneme.Options.options()) do
def new({kind, _, args} = macro_ast, value, ctx, vars \\ [], opts \\ Mneme.Options.options()) do
{stage, original_pattern} = get_stage(kind, args)
context = Enum.into(ctx, %{original_pattern: original_pattern})

Expand All @@ -60,7 +65,8 @@ defmodule Mneme.Assertion do
macro_ast: macro_ast,
value: value,
context: context,
options: Map.new(opts)
options: Map.new(opts),
vars_bound_in_pattern: vars
}
end

Expand All @@ -71,16 +77,10 @@ defmodule Mneme.Assertion do
def build(kind, args, caller, opts) do
macro_ast = {kind, Macro.Env.location(caller), args}
context = assertion_context(caller)

# Prevents warnings about unused aliases when their only usage is in
# an auto-assertion
silence_used_aliases =
macro_ast
|> extract_used_aliases(context[:aliases])
|> Enum.map(&quoted_dummy_assign/1)
vars = maybe_collect_vars(kind, args, caller)

quote do
unquote_splicing(silence_used_aliases)
unquote_splicing(silence_used_aliases(macro_ast, context[:aliases]))

ast = unquote(Macro.escape(macro_ast))

Expand All @@ -89,13 +89,62 @@ defmodule Mneme.Assertion do
ast,
unquote(value_eval_expr(macro_ast)),
Keyword.put(unquote(context), :binding, binding()),
unquote(Macro.escape(vars)),
unquote(opts)
)

Mneme.Assertion.run(assertion, __ENV__, Mneme.Server.started?())
{assertion, binding} = Mneme.Assertion.run!(assertion, __ENV__, Mneme.Server.started?())

unquote(vars) = Mneme.Assertion.ensure_vars!(assertion, binding)

assertion.value
end
end

defp maybe_collect_vars(:auto_assert, [{matcher, _, [left, _right]}], caller)
when matcher in [:<-, :=] do
left
|> Utils.expand_pattern(caller)
|> Utils.collect_vars_from_pattern()
|> Enum.uniq()
end

defp maybe_collect_vars(_, _, _), do: []

@doc false
def ensure_vars!(%Assertion{} = assertion, binding) do
binding_map = Map.new(binding)

result =
assertion.vars_bound_in_pattern
|> Enum.map(fn
{var, _, nil} -> var
{var, _, context} -> {var, context}
end)
|> Enum.reduce(%{values: [], missing: []}, fn var, acc ->
if value = binding_map[var] do
update_in(acc.values, &[value | &1])
else
update_in(acc.missing, &[var | &1])
end
end)

case result do
%{values: rev_values, missing: []} ->
Enum.reverse(rev_values)

%{missing: missing_vars} ->
raise Mneme.UnboundVariableError, vars: missing_vars
end
end

# Prevents warnings about unused aliases when their only usage is in an auto-assertion
defp silence_used_aliases(macro_ast, context_aliases) do
macro_ast
|> extract_used_aliases(context_aliases)
|> Enum.map(&quoted_dummy_assign/1)
end

defp extract_used_aliases(quoted, aliases) do
quoted
|> Macro.prewalker()
Expand Down Expand Up @@ -127,9 +176,13 @@ defmodule Mneme.Assertion do
end

@doc """
Run an auto-assertion, potentially patching the code.
Run an assertion, potentially patching the code.
Returns the assertion and binding resulting from running it. Raises if
the assertion fails.
"""
def run(assertion, env, interactive? \\ true) do
@spec run!(t, Macro.Env.t(), boolean()) :: {t, Code.binding()}
def run!(assertion, env, interactive? \\ true) do
do_run(assertion, env, interactive?)
rescue
error in [ExUnit.AssertionError] ->
Expand All @@ -152,8 +205,6 @@ defmodule Mneme.Assertion do
patch(assertion, env, error)
end
end

assertion.value
end

defp do_run(assertion, env, false) do
Expand All @@ -180,9 +231,14 @@ defmodule Mneme.Assertion do
end

defp handle_assertion(result, assertion, env, existing_error \\ nil)
defp handle_assertion({:ok, assertion}, _, env, _), do: eval(assertion, env)
defp handle_assertion({:error, :skipped}, _, _, _), do: :ok
defp handle_assertion({:error, :file_changed}, _, _, _), do: :ok

defp handle_assertion({:ok, assertion}, _, env, _) do
{_, binding} = eval(assertion, env)
{assertion, binding}
end

defp handle_assertion({:error, :skipped}, assertion, _, _), do: {assertion, []}
defp handle_assertion({:error, :file_changed}, assertion, _, _), do: {assertion, []}
defp handle_assertion({:error, :rejected}, _, _, nil), do: assertion_error!()

defp handle_assertion({:error, :rejected}, assertion, _, error) do
Expand Down
19 changes: 14 additions & 5 deletions lib/mneme/assertion/pattern_builder.ex
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@ defmodule Mneme.Assertion.PatternBuilder do
keyset =
Enum.reduce(kvs, %{keys: [], ignore_values_for: []}, fn
{key, {:_, _, nil}}, %{keys: keys, ignore_values_for: ignore} ->
%{keys: [key | keys], ignore_values_for: [key | ignore]}
%{keys: [key | keys], ignore_values_for: [{key, :_} | ignore]}

{key, {name, _, ctx} = var}, %{keys: keys, ignore_values_for: ignore}
when is_atom(name) and is_atom(ctx) ->
%{keys: [key | keys], ignore_values_for: [{key, var} | ignore]}

{key, _}, keyset ->
update_in(keyset[:keys], &[key | &1])
Expand Down Expand Up @@ -222,10 +226,10 @@ defmodule Mneme.Assertion.PatternBuilder do
{k_patterns, vars} = to_patterns(k, %{context | map_key_pattern?: true}, vars)

{v_patterns, vars} =
if k in ignore_values_for do
{[Pattern.new({:_, [], nil})], []}
else
to_patterns(v, context, vars)
case fetch(ignore_values_for, k) do
:error -> to_patterns(v, context, vars)
{:ok, :_} -> {[Pattern.new({:_, [], nil})], []}
{:ok, value} -> {[Pattern.new(value)], []}
end

tuples =
Expand All @@ -249,6 +253,11 @@ defmodule Mneme.Assertion.PatternBuilder do
{maybe_bad_map_key_notes(map_patterns, bad_map_keys), vars}
end

# the analogue of `Keyword.fetch/2` for proplists having non-atom keys
defp fetch([], _key), do: :error
defp fetch([{key, value} | _rest], key), do: {:ok, value}
defp fetch([_ | rest], key), do: fetch(rest, key)

defp to_pairs(%{} = map) do
map
# we fetch by key instead of using `Enum.to_list/1` because
Expand Down
31 changes: 31 additions & 0 deletions lib/mneme/errors.ex
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,34 @@ defmodule Mneme.InternalError do
"""
end
end

defmodule Mneme.UnboundVariableError do
@moduledoc false
defexception [:vars, :message]

@impl true
def message(%{message: nil} = exception) do
%{vars: vars} = exception

"""
Updated auto-assertion is missing at least one previously bound variable:
#{format_vars(vars)}
Re-run this test to ensure it still passes.
"""
end

def message(%{message: message}) do
message
end

defp format_vars(vars) do
vars
|> Enum.map(fn
{name, _context} -> name
name -> name
end)
|> Enum.map_join(", ", &Atom.to_string/1)
end
end
104 changes: 104 additions & 0 deletions lib/mneme/utils.ex
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,108 @@ defmodule Mneme.Utils do
defp occurrences(<<char, rest::binary>>, char, acc), do: occurrences(rest, char, acc + 1)
defp occurrences(<<_, rest::binary>>, char, acc), do: occurrences(rest, char, acc)
defp occurrences(<<>>, _char, acc), do: acc

@doc """
Macro expands an expression in a pattern match context.
"""
@spec expand_pattern(Macro.t(), Macro.Env.t()) :: Macro.t()
def expand_pattern({:when, meta, [left, right]}, caller) do
left = do_expand_pattern(left, Macro.Env.to_match(caller))
right = do_expand_pattern(right, %{caller | context: :guard})
{:when, meta, [left, right]}
end

def expand_pattern(expr, caller) do
do_expand_pattern(expr, Macro.Env.to_match(caller))
end

defp do_expand_pattern({:quote, _, [_]} = expr, _caller), do: expr
defp do_expand_pattern({:quote, _, [_, _]} = expr, _caller), do: expr
defp do_expand_pattern({:__aliases__, _, _} = expr, caller), do: Macro.expand(expr, caller)

defp do_expand_pattern({:@, _, [{attribute, _, _}]}, caller) do
caller.module |> Module.get_attribute(attribute) |> Macro.escape()
end

defp do_expand_pattern({left, meta, right} = expr, caller) do
case Macro.expand(expr, caller) do
^expr ->
{do_expand_pattern(left, caller), meta, do_expand_pattern(right, caller)}

{left, meta, right} ->
{do_expand_pattern(left, caller), [original: expr] ++ meta,
do_expand_pattern(right, caller)}

other ->
other
end
end

defp do_expand_pattern({left, right}, caller) do
{do_expand_pattern(left, caller), do_expand_pattern(right, caller)}
end

defp do_expand_pattern([_ | _] = list, caller) do
Enum.map(list, &do_expand_pattern(&1, caller))
end

defp do_expand_pattern(other, _caller), do: other

@doc """
Collects variables bound in the given pattern.
"""
@spec collect_vars_from_pattern(Macro.t()) :: [var] when var: {atom(), Macro.metadata(), atom()}
def collect_vars_from_pattern({:when, _, [left, right]}) do
pattern = collect_vars_from_pattern(left)

vars =
for {name, _, context} = var <- collect_vars_from_pattern(right),
has_var?(pattern, name, context),
do: var

pattern ++ vars
end

def collect_vars_from_pattern(expr) do
expr
|> Macro.prewalk([], fn
{:"::", _, [left, right]}, acc ->
{[left], collect_vars_from_binary(right, acc)}

{skip, _, [_]}, acc when skip in [:^, :@, :quote] ->
{:ok, acc}

{skip, _, [_, _]}, acc when skip in [:quote] ->
{:ok, acc}

{:_, _, context}, acc when is_atom(context) ->
{:ok, acc}

{name, meta, context}, acc when is_atom(name) and is_atom(context) ->
{:ok, [{name, meta, context} | acc]}

node, acc ->
{node, acc}
end)
|> elem(1)
end

defp collect_vars_from_binary(right, original_acc) do
right
|> Macro.prewalk(original_acc, fn
{mode, _, [{name, meta, context}]}, acc
when is_atom(mode) and is_atom(name) and is_atom(context) ->
if has_var?(original_acc, name, context) do
{:ok, [{name, meta, context} | acc]}
else
{:ok, acc}
end

node, acc ->
{node, acc}
end)
|> elem(1)
end

defp has_var?(pattern, name, context), do: Enum.any?(pattern, &match?({^name, _, ^context}, &1))
end
Loading

0 comments on commit 58d9478

Please sign in to comment.