From f2e563aa72945f1f4d26534d49363e3145979b7c Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sat, 13 Jul 2024 14:49:34 -0700
Subject: [PATCH] fix: eltype fix for wrapper types

---
 Project.toml                         |  2 +-
 ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl |  1 +
 src/impl/fused_conv.jl               | 24 ++++++++++++------------
 src/utils.jl                         |  7 ++++---
 test/others/qa_tests.jl              |  5 ++++-
 5 files changed, 22 insertions(+), 17 deletions(-)

diff --git a/Project.toml b/Project.toml
index 01ab63ea..d6f79c5d 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "LuxLib"
 uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
 authors = ["Avik Pal and contributors"]
-version = "0.3.30"
+version = "0.3.31-DEV"
 
 [deps]
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
diff --git a/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl b/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl
index bd2b4e2e..537c43c1 100644
--- a/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl
+++ b/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl
@@ -35,6 +35,7 @@
 end
 function CRC.rrule(::typeof(LuxLib.batchnorm_cudnn), running_mean, running_var, scale,
         bias, x, momentum, epsilon, t::Val{training}) where {training}
+    # TODO: Transition this to an error in the future
     !training && @warn "`training=Val(false)` but gradient was called." maxlog=1
     y, xmean, xivar = LuxLib.batchnorm_cudnn(
         running_mean, running_var, scale, bias, x, momentum, epsilon, t)
diff --git a/src/impl/fused_conv.jl b/src/impl/fused_conv.jl
index 29c747e0..9b413f0b 100644
--- a/src/impl/fused_conv.jl
+++ b/src/impl/fused_conv.jl
@@ -48,28 +48,28 @@ function __conv!(::Type{<:AbstractLuxGPUDevice}, y::AbstractArray{yT, N},
         __materialize_subarray(_ofeltype_array(yT, weight)), cdims)
 end
 
-function __conv(x_::AbstractArray, weight_::AbstractArray, cdims::ConvDims)
-    x, weight = __get_conv_input_weight(
-        get_device_type((x_, weight_)), eltype(x_), eltype(weight_), x_, weight_)
+function __conv(
+        x_::AbstractArray{xT}, weight_::AbstractArray{wT}, cdims::ConvDims) where {xT, wT}
+    x, weight = __get_conv_input_weight(get_device_type((x_, weight_)), xT, wT, x_, weight_)
     return conv(x, weight, cdims)
 end
 
-function __∇conv_data(x_::AbstractArray, weight_::AbstractArray, cdims::ConvDims)
-    x, weight = __get_conv_input_weight(
-        get_device_type((x_, weight_)), eltype(x_), eltype(weight_), x_, weight_)
+function __∇conv_data(
+        x_::AbstractArray{xT}, weight_::AbstractArray{wT}, cdims::ConvDims) where {xT, wT}
+    x, weight = __get_conv_input_weight(get_device_type((x_, weight_)), xT, wT, x_, weight_)
     return ∇conv_data(x, weight, cdims)
 end
 
-function __∇conv_filter(x_::AbstractArray, y_::AbstractArray, cdims::ConvDims)
-    x, y = __get_conv_input_weight(
-        get_device_type((x_, y_)), eltype(x_), eltype(y_), x_, y_)
+function __∇conv_filter(
+        x_::AbstractArray{xT}, y_::AbstractArray{yT}, cdims::ConvDims) where {xT, yT}
+    x, y = __get_conv_input_weight(get_device_type((x_, y_)), xT, yT, x_, y_)
     return ∇conv_filter(x, y, cdims)
 end
 
-function __conv_bias_act(x_::AbstractArray, weight_::AbstractArray, cdims::ConvDims,
-        bias_::Optional{<:AbstractArray}, act::F) where {F}
+function __conv_bias_act(x_::AbstractArray{xT}, weight_::AbstractArray{wT}, cdims::ConvDims,
+        bias_::Optional{<:AbstractArray}, act::F) where {xT, wT, F}
     dev = get_device_type((x_, weight_, bias_))
-    x, weight = __get_conv_input_weight(dev, eltype(x_), eltype(weight_), x_, weight_)
+    x, weight = __get_conv_input_weight(dev, xT, wT, x_, weight_)
     bias = _ofeltype_array(eltype(x), bias_)
     return __conv_bias_act_impl(dev, x, weight, cdims, bias, act)
 end
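The fused_conv.jl hunks replace runtime `eltype(x_)`/`eltype(weight_)` calls with element types `xT`/`wT` captured as type parameters, so the correct eltype is read off the static array type even when the argument is a wrapper type (per the commit subject). A self-contained sketch of the eltype promotion these helpers feed into; only NNlib's public `conv` and `DenseConvDims` are used, the shapes and mixed precisions are illustrative, and `__get_conv_input_weight` (defined outside this patch) is assumed to perform this kind of promotion:

    using NNlib

    # A wrapper-typed input (SubArray) with a different eltype than the weight:
    x = view(rand(Float64, 6, 6, 3, 2), 1:5, 1:5, :, :)
    w = rand(Float32, 3, 3, 3, 4)
    cdims = DenseConvDims(x, w)

    # Promote both arguments to a common element type before convolving:
    T = promote_type(eltype(x), eltype(w))   # Float64
    y = conv(T.(x), T.(w), cdims)            # 3x3x4x2 result
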
diff --git a/src/utils.jl b/src/utils.jl
index 12eeae4f..e5519d7c 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -5,9 +5,8 @@
         return ntuple(i -> i == N - 1 ? ly : 1, N)
     elseif N > 2 && ly == sx[N - 1] * sx[N - 2]
         return ntuple(i -> i == (N - 1) || i == (N - 2) ? sx[i] : 1, N)
-    else
-        throw(ArgumentError("Invalid Dimensions!"))
     end
+    throw(ArgumentError("Invalid Dimensions!"))
 end
 
 CRC.@non_differentiable _get_reshape_dims(::Any...)
@@ -194,6 +193,8 @@
 __value(::Type{T}) where {T <: Number} = T
 __value(x::ForwardDiff.Dual) = ForwardDiff.value(x)
 __value(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x)
-__value(::Type{<:ForwardDiff.Dual{Tag, T}}) where {Tag, T} = LuxLib.__value(T)
+__value(::Type{<:ForwardDiff.Dual{Tag, T}}) where {Tag, T} = __value(T)
+
+__value(::Nothing) = nothing
 
 __aos_to_soa(x::AbstractArray) = x # FIXME: Upstream this to ArrayInterface.jl
diff --git a/test/others/qa_tests.jl b/test/others/qa_tests.jl
index f49ea740..c975375b 100644
--- a/test/others/qa_tests.jl
+++ b/test/others/qa_tests.jl
@@ -1,7 +1,10 @@
 @testitem "Aqua: Quality Assurance" tags=[:others] begin
     using Aqua
 
-    Aqua.test_all(LuxLib)
+    Aqua.test_all(LuxLib; ambiguities=false, piracies=false)
+    Aqua.test_ambiguities(
+        LuxLib; recursive=false, exclude=[conv, ∇conv_data, ∇conv_filter, depthwiseconv])
+    Aqua.test_piracies(LuxLib; treat_as_own=[conv, ∇conv_data, ∇conv_filter, depthwiseconv])
 end
 
 @testitem "Explicit Imports" tags=[:others] begin
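Two notes on the remaining hunks. In src/utils.jl, `LuxLib.__value(T)` becomes the unqualified recursive call `__value(T)` (the method is defined in that same module, so the qualification was redundant), and `__value(::Nothing) = nothing` is added, presumably so optional arguments that are `nothing` pass through instead of erroring. A standalone sketch of that dispatch pattern, with the methods re-created outside LuxLib purely for illustration:

    using ForwardDiff

    __value(x::Number) = x
    __value(x::ForwardDiff.Dual) = ForwardDiff.value(x)
    __value(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x)
    __value(::Nothing) = nothing  # the new method: `nothing` passes through

    __value(ForwardDiff.Dual(2.0, 1.0))  # 2.0
    __value(nothing)                     # nothing (previously a MethodError)

In test/others/qa_tests.jl, `Aqua.test_all` now runs with `ambiguities=false, piracies=false`, and those two checks are invoked separately so the NNlib functions that LuxLib overloads (`conv`, `∇conv_data`, `∇conv_filter`, `depthwiseconv`) can be excluded from the ambiguity check and treated as owned in the piracy check.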