Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use simple assert for Nout check
Browse files Browse the repository at this point in the history
This is an internal sanity check that shouldn't need to throw an error
back to the user
danielwe committed Jan 22, 2025
1 parent 25b7ee3 commit af87875
Showing 1 changed file with 4 additions and 10 deletions.
14 changes: 4 additions & 10 deletions src/typeutils/recursive_maps.jl
Original file line number Diff line number Diff line change
@@ -206,7 +206,7 @@ const YS{Nout,T} = Union{Val{Nout},NTuple{Nout,T}}
function recursive_map(
f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, config::InactiveConfig=InactiveConfig()
) where {F,Nout,Nin,T}
check_nout(ys)
@assert (Nout == 1) || (Nout == 2)
newys = if isinactivetype(T, config)
recursive_map_inactive(nothing, ys, xs, config)
elseif isvectortype(T) || isbitstype(T)
@@ -226,7 +226,7 @@ function recursive_map(
config::InactiveConfig=InactiveConfig(),
) where {F,Nout,Nin,T}
# determine whether to continue recursion, copy/share, or retrieve from cache
check_nout(ys)
@assert (Nout == 1) || (Nout == 2)
newys = if isinactivetype(T, config)
recursive_map_inactive(seen, ys, xs, config)
elseif isbitstype(T) # no object identity to to track in this branch
@@ -581,19 +581,13 @@ end
else # may not show in coverage but is covered via accumulate_into! TODO: ensure coverage via VectorSpace once implemented
cache = seen[x1]::NTuple{(Nout + Nin - 1),T}
cachedtail = cache[(Nout+1):end]
check_identical(cachedtail, xtail) # check compatible layout
check_identical(cachedtail, xtail) # check compatible structure
cache[1:Nout]
end
return newys::NTuple{Nout,T}
end

## argument validation
@inline function check_nout(::YS{Nout}) where {Nout}
if Nout > 2
throw_nout()
end
end

Base.@propagate_inbounds function check_initialized(x, i, initialized=true)
if isinitialized(x, i) != initialized
throw_initialized() # TODO: hit this when VectorSpace implemented
@@ -644,7 +638,7 @@ end
end

@noinline function throw_identical()
msg = "recursive_map(!) called on objects whose layout don't match"
msg = "recursive_map(!) called on objects whose structure don't match"
throw(ArgumentError(msg))
end

0 comments on commit af87875

Please sign in to comment.