Skip to content

Commit

Permalink
Broaden the scope
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 29, 2023
1 parent 6a92f2a commit ac1ce79
Showing 1 changed file with 48 additions and 13 deletions.
61 changes: 48 additions & 13 deletions src/MaybeInplace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ following operations are supported:
5. `x .= <expr>`
6. `@. <expr>`
7. `x = copy(y)`
8. `axpy!(a, x, y)`
This macro also allows some custom operators:
Expand Down Expand Up @@ -77,11 +78,31 @@ all operations on the list.
"""

## Main Function
function __bangbang__(M, iip::Symbol, expr)
new_expr = nothing
if @capture(expr, f_(a_, args__))
new_expr = quote
if $(iip)
$(expr)

Check warning on line 86 in src/MaybeInplace.jl

View check run for this annotation

Codecov / codecov/patch

src/MaybeInplace.jl#L81-L86

Added lines #L81 - L86 were not covered by tests
else
$(a) = $(f)($(a), $(args...))

Check warning on line 88 in src/MaybeInplace.jl

View check run for this annotation

Codecov / codecov/patch

src/MaybeInplace.jl#L88

Added line #L88 was not covered by tests
end
end
end
if new_expr !== nothing
return esc(new_expr)

Check warning on line 93 in src/MaybeInplace.jl

View check run for this annotation

Codecov / codecov/patch

src/MaybeInplace.jl#L92-L93

Added lines #L92 - L93 were not covered by tests
end
error("`$(iip) $(expr)` cannot be handled. Check the documentation for allowed \

Check warning on line 95 in src/MaybeInplace.jl

View check run for this annotation

Codecov / codecov/patch

src/MaybeInplace.jl#L95

Added line #L95 was not covered by tests
expressions.")
end

function __bangbang__(M, expr; depth::Int = 1)
new_expr = nothing
if @capture(expr, a_Symbol=copy(b_))
if @capture(expr, a_=copy(b_))
new_expr = :($(a) = $(__copy)($(setindex_trait)($(b)), $(b)))
elseif @capture(expr, f_(a_Symbol, args__))
elseif @capture(expr, axpy!(α_, x_, y_))
new_expr = __handle_axpy(M, α, x, y, depth)
elseif @capture(expr, f_(a_, args__))

Check warning on line 105 in src/MaybeInplace.jl

View check run for this annotation

Codecov / codecov/patch

src/MaybeInplace.jl#L103-L105

Added lines #L103 - L105 were not covered by tests
g = get(OP_MAPPING, f, nothing)
if g !== nothing
new_expr = :($(a) = $(g)($(setindex_trait)($(a)), $(a), $(args...)))
Expand Down Expand Up @@ -184,6 +205,16 @@ function __handle_dot_macro(M, a, f, depth)
end
end

function __handle_axpy(M, α, x, y, depth)
return quote
if $(setindex_trait)($(y)) === $(CanSetindex())
$(__safe_axpy!)($(α), $(x), $(y))

Check warning on line 211 in src/MaybeInplace.jl

View check run for this annotation

Codecov / codecov/patch

src/MaybeInplace.jl#L208-L211

Added lines #L208 - L211 were not covered by tests
else
$(y) = @. $(α) * $(x) + $(y)

Check warning on line 213 in src/MaybeInplace.jl

View check run for this annotation

Codecov / codecov/patch

src/MaybeInplace.jl#L213

Added line #L213 was not covered by tests
end
end
end

## Traits
abstract type AbstractMaybeSetindex end
struct CannotSetindex <: AbstractMaybeSetindex end
Expand Down Expand Up @@ -220,20 +251,24 @@ setindex_trait(A) = ifelse(can_setindex(A), CanSetindex(), CannotSetindex())
const OP_MAPPING = Dict{Symbol, Function}(:copyto! => __copyto!!, :.-= => __sub!!,
:.+= => __add!!, :.*= => __mul!!, :./= => __div!!, :copy => __copy)

## Macros
@doc __bangbang__docs
macro bangbang(expr)
return __bangbang__(__module__, expr)
@inline @generated function __safe_axpy!(α, x, y)
hasmethod(axpy!, Tuple{typeof(α), typeof(x), typeof(y)}) || return :(axpy!(α, x, y))
return :(@. y += α * x)

Check warning on line 256 in src/MaybeInplace.jl

View check run for this annotation

Codecov / codecov/patch

src/MaybeInplace.jl#L254-L256

Added lines #L254 - L256 were not covered by tests
end

@doc __bangbang__docs
macro bb(expr)
return __bangbang__(__module__, expr)
end
## Macros
for m in (:bangbang, :bb, :❗)
@eval begin
@doc __bangbang__docs
macro $m(expr)
return __bangbang__(__module__, expr)
end

@doc __bangbang__docs
macro (expr)
return __bangbang__(__module__, expr)
@doc __bangbang__docs
macro $m(iip::Symbol, expr)
return __bangbang__(__module__, iip, expr)
end
end
end

@inline _vec(v) = v
Expand Down

0 comments on commit ac1ce79

Please sign in to comment.