Skip to content

Commit

Permalink
Disallow mixing offset and non-offset axes in conv input
Browse files Browse the repository at this point in the history
  • Loading branch information
martinholters committed Nov 14, 2024
1 parent 8902bef commit 97e6c31
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 16 deletions.
2 changes: 1 addition & 1 deletion ext/OffsetArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ module OffsetArraysExt
import DSP
import OffsetArrays

DSP.conv_with_offset(::OffsetArrays.IdOffsetRange) = true
DSP.conv_axis_with_offset(::OffsetArrays.IdOffsetRange) = true

end
32 changes: 20 additions & 12 deletions src/dspbase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -660,8 +660,16 @@ function _conv_td!(out, output_indices, u::AbstractArray{<:Number, N}, v::Abstra
end

# whether the given axis are to be considered to carry an offset for `conv!` and `conv`
conv_with_offset(::Base.OneTo) = false
conv_with_offset(a::Any) = throw(ArgumentError("unsupported axis type $(typeof(a))"))
conv_axis_with_offset(::Base.OneTo) = false
conv_axis_with_offset(a::Any) = throw(ArgumentError("unsupported axis type $(typeof(a))"))

Check warning on line 664 in src/dspbase.jl

View check run for this annotation

Codecov / codecov/patch

src/dspbase.jl#L664

Added line #L664 was not covered by tests

function conv_axes_with_offset(as::Tuple...)
with_offset = ((map(a -> map(conv_axis_with_offset, a), as)...)...,)
if !allequal(with_offset)
throw(ArgumentError("cannot mix offset and non-offset axes"))
end
return first(with_offset)
end

const FFTTypes = Union{Float32, Float64, ComplexF32, ComplexF64}

Expand Down Expand Up @@ -704,12 +712,8 @@ function conv!(
v::AbstractArray{<:Number, N};
algorithm=:auto
) where {T<:Number, N}
offset = conv_axes_with_offset(axes(out), axes(u), axes(v)) ? 0 : 1
output_indices = CartesianIndices(map(axes(out), axes(u), axes(v)) do ao, au, av
input_has_offset = conv_with_offset(au) || conv_with_offset(av)
if input_has_offset !== conv_with_offset(ao)
throw(ArgumentError("output must have offset axes if and only if the input has"))
end
offset = input_has_offset ? 0 : 1
return (first(au)+first(av) : last(au)+last(av)) .- offset
end)

Expand Down Expand Up @@ -752,9 +756,13 @@ function conv!(
end
end

conv_output_axis(au, av) =
conv_with_offset(au) || conv_with_offset(av) ?
(first(au)+first(av):last(au)+last(av)) : Base.OneTo(last(au) + last(av) - 1)
function conv_output_axes(au::Tuple, av::Tuple)
if conv_axes_with_offset(au, av)
return map((au, av) -> first(au)+first(av):last(au)+last(av), au, av)
else
return map((au, av) -> Base.OneTo(last(au) + last(av) - 1), au, av)
end
end

"""
conv(u, v; algorithm)
Expand All @@ -768,7 +776,7 @@ function conv(
u::AbstractArray{Tu, N}, v::AbstractArray{Tv, N}; kwargs...
) where {Tu<:Number, Tv<:Number, N}
T = promote_type(Tu, Tv)
out_axes = map(conv_output_axis, axes(u), axes(v))
out_axes = conv_output_axes(axes(u), axes(v))
out = similar(u, T, out_axes)
return conv!(out, u, v; kwargs...)
end
Expand All @@ -792,7 +800,7 @@ Uses 2-D FFT algorithm.
"""
function conv(u::AbstractVector{T}, v::Transpose{T,<:AbstractVector}, A::AbstractMatrix{T}) where T
# Arbitrary indexing offsets not implemented
if any(conv_with_offset, (axes(u)..., axes(v)..., axes(A)...))
if any(conv_axis_with_offset, (axes(u)..., axes(v)..., axes(A)...))
throw(ArgumentError("offset axes not supported"))
end
m = length(u)+size(A,1)-1
Expand Down
9 changes: 6 additions & 3 deletions test/dsp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,12 @@ end

offset_arr = OffsetVector{Int}(undef, -1:2)
offset_arr[:] = a
@test conv(offset_arr, 1:3) == OffsetVector(expectation, 0:5)
@test_throws ArgumentError conv(offset_arr, 1:3)
@test conv(offset_arr, OffsetArray(1:3)) == OffsetVector(expectation, 0:5)
offset_arr_f = OffsetVector{Float64}(undef, -1:2)
offset_arr_f[:] = fa
@test conv(offset_arr_f, 1:3) OffsetVector(fexp, 0:5)
@test_throws ArgumentError conv(offset_arr_f, 1:3)
@test conv(offset_arr_f, OffsetArray(1:3)) OffsetVector(fexp, 0:5)
@test_throws ArgumentError conv!(zeros(6), offset_arr, 1:3) # output needs to be OA, too
@test_throws ArgumentError conv!(OffsetVector{Int}(undef, 1:6), 1:4, 1:3) # output mustn't be OA

Expand Down Expand Up @@ -156,7 +158,8 @@ end

offset_arr = OffsetMatrix{Int}(undef, -1:1, -1:1)
offset_arr[:] = a
@test conv(offset_arr, b) == OffsetArray(expectation, 0:3, 0:3)
@test_throws ArgumentError conv(offset_arr, b)
@test conv(offset_arr, OffsetArray(b)) == OffsetArray(expectation, 0:3, 0:3)

for (M1, M2) in [(10, 20), (190, 200)], (N1, N2) in [(20, 10), (210, 200)], T in [Float64, ComplexF64]
u = rand(T, M1, M2)
Expand Down

0 comments on commit 97e6c31

Please sign in to comment.