From d39b29f10cb81882ea3d08741d27ac6e552d2d40 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Fri, 21 Apr 2023 10:33:38 +1200 Subject: [PATCH 1/2] bump compat LossFunctions = "0.9" and address breakage --- Project.toml | 2 +- src/measures/loss_functions_interface.jl | 6 +++--- test/measures/loss_functions_interface.jl | 17 ++++++----------- 3 files changed, 10 insertions(+), 15 deletions(-) diff --git a/Project.toml b/Project.toml index cfa765dd..5fea7dce 100644 --- a/Project.toml +++ b/Project.toml @@ -35,7 +35,7 @@ CategoricalDistributions = "0.1" ComputationalResources = "0.3" Distributions = "0.25.3" InvertedIndices = "1" -LossFunctions = "0.5, 0.6, 0.7, 0.8" +LossFunctions = "0.9" MLJModelInterface = "1.7" Missings = "0.4, 1" OrderedCollections = "1.1" diff --git a/src/measures/loss_functions_interface.jl b/src/measures/loss_functions_interface.jl index 3ee9ed6c..e18298cb 100644 --- a/src/measures/loss_functions_interface.jl +++ b/src/measures/loss_functions_interface.jl @@ -44,7 +44,7 @@ err_wrap(n) = ArgumentError("Bad @wrap syntax: $n. ") # We define amacro to wrap a concrete `LossFunctions.SupervisedLoss` # type and define its constructor, and to define property access in -# case of paramters; the macro also defined calling behaviour: +# case of paramters; the macro also defines calling behaviour: macro wrap_loss(ex) ex.head == :call || throw(err_wrap(1)) Loss_ex = ex.args[1] @@ -130,7 +130,7 @@ MMI.prediction_type(::Type{<:DistanceLoss}) = :deterministic MMI.target_scitype(::Type{<:DistanceLoss}) = Union{Vec{Continuous},Vec{Count}} call(measure::DistanceLoss, yhat, y) = - LossFunctions.value(getfield(measure, :loss), y, yhat) + LossFunctions.value(getfield(measure, :loss), yhat, y) function call(measure::DistanceLoss, yhat, y, w::AbstractArray) return w .* call(measure, yhat, y) @@ -148,7 +148,7 @@ _scale(p) = 2p - 1 function call(measure::MarginLoss, yhat, y) probs_of_observed = broadcast(pdf, yhat, y) return (LossFunctions.value).(getfield(measure, :loss), - 1, _scale.(probs_of_observed)) + _scale.(probs_of_observed), 1) end call(measure::MarginLoss, yhat, y, w::AbstractArray) = diff --git a/test/measures/loss_functions_interface.jl b/test/measures/loss_functions_interface.jl index 30fe057e..d5894eb8 100644 --- a/test/measures/loss_functions_interface.jl +++ b/test/measures/loss_functions_interface.jl @@ -42,12 +42,9 @@ end for M_ex in MARGIN_LOSSES m = eval(:(MLJBase.$M_ex())) - @test m(yhat, y) ≈ LossFunctions.value(getfield(m, :loss), ym, yhatm) - @test MLJBase.Mean()(m(yhat, y, w)) ≈ - LossFunctions.value(getfield(m, :loss), - ym, - yhatm, - WeightedSum(w))/N + @test m(yhat, y) ≈ LossFunctions.value(getfield(m, :loss), yhatm, ym) + @test m(yhat, y, w) ≈ + w .* LossFunctions.value(getfield(m, :loss), yhatm, ym) end end @@ -64,10 +61,8 @@ end m_ex = MLJBase.snakecase(M_ex) @test m == eval(:(MLJBase.$m_ex)) @test m(yhat, y) ≈ - LossFunctions.value(getfield(m, :loss), y, yhat) - @test mean(m(yhat ,y, w)) ≈ - LossFunctions.value(getfield(m, :loss), y, yhat, - WeightedSum(w))/N - + LossFunctions.value(getfield(m, :loss), yhat, y) + @test m(yhat ,y, w) ≈ + w .* LossFunctions.value(getfield(m, :loss), yhat, y) end end From 9603c64350b38b3098976ff65ee82062fdbfbf56 Mon Sep 17 00:00:00 2001 From: "Anthony Blaom, PhD" Date: Fri, 21 Apr 2023 11:03:46 +1200 Subject: [PATCH 2/2] Update src/measures/loss_functions_interface.jl Co-authored-by: Brian Chen --- src/measures/loss_functions_interface.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/measures/loss_functions_interface.jl b/src/measures/loss_functions_interface.jl index e18298cb..42e9bdbf 100644 --- a/src/measures/loss_functions_interface.jl +++ b/src/measures/loss_functions_interface.jl @@ -44,7 +44,7 @@ err_wrap(n) = ArgumentError("Bad @wrap syntax: $n. ") # We define amacro to wrap a concrete `LossFunctions.SupervisedLoss` # type and define its constructor, and to define property access in -# case of paramters; the macro also defines calling behaviour: +# case of parameters; the macro also defines calling behaviour: macro wrap_loss(ex) ex.head == :call || throw(err_wrap(1)) Loss_ex = ex.args[1]