diff --git a/Project.toml b/Project.toml index 2137a4c..e77db9a 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" +MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" NaturalSort = "c020b1a1-e9b0-503a-9c33-f039bfc54a85" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" diff --git a/src/conformal_models/conformal_models.jl b/src/conformal_models/conformal_models.jl index 1804588..e6373db 100644 --- a/src/conformal_models/conformal_models.jl +++ b/src/conformal_models/conformal_models.jl @@ -101,7 +101,6 @@ const available_models = Dict( :transductive => Dict(:naive => NaiveClassifier), :inductive => Dict( :simple_inductive => SimpleInductiveClassifier, - :trainable_simple_inductive => TrainableSimpleInductiveClassifier, :adaptive_inductive => AdaptiveInductiveClassifier, ), ), @@ -120,6 +119,7 @@ const tested_atomic_models = Dict( :evo_tree => :(@load EvoTreeClassifier pkg = EvoTrees), :nearest_neighbor => :(@load KNNClassifier pkg = NearestNeighborModels), :light_gbm => :(@load LGBMClassifier pkg = LightGBM), + :flux => :(@load NeuralNetworkClassifier pkg = MLJFlux), ), ) diff --git a/src/conformal_models/inductive_classification.jl b/src/conformal_models/inductive_classification.jl index be3af5d..24d3d1b 100644 --- a/src/conformal_models/inductive_classification.jl +++ b/src/conformal_models/inductive_classification.jl @@ -18,17 +18,18 @@ function SimpleInductiveClassifier( end function score(conf_model::SimpleInductiveClassifier, fitresult, X, y::Union{Nothing,AbstractArray}=nothing) - # X = isa(X, Matrix) ? table(X) : X - # p̂ = reformat_mlj_prediction(MMI.predict(conf_model.model, fitresult, X)) - # L = p̂.decoder.classes - # probas = pdf(p̂, L) - X = size(X,2) == 1 ? X : permutedims(X) - probas = permutedims(fitresult[1](X)) + score(conf_model, typeof(conf_model.model), fitresult, X, y) +end + +function score(conf_model::SimpleInductiveClassifier, ::Type{<:Supervised}, fitresult, X, y::Union{Nothing,AbstractArray}=nothing) + p̂ = reformat_mlj_prediction(MMI.predict(conf_model.model, fitresult, X)) + L = p̂.decoder.classes + probas = pdf(p̂, L) scores = @.(conf_model.heuristic(probas)) if isnothing(y) return scores else - cal_scores = getindex.(Ref(scores), 1:size(scores,1), levelcode.(y)) + cal_scores = getindex.(Ref(scores), 1:size(scores, 1), levelcode.(y)) return cal_scores, scores end end @@ -59,7 +60,7 @@ function MMI.fit(conf_model::SimpleInductiveClassifier, verbosity, X, y) fitresult, cache, report = MMI.fit(conf_model.model, verbosity, Xtrain, ytrain) # Nonconformity Scores: - cal_scores, scores = score(conf_model, fitresult, matrix(Xcal), ycal) + cal_scores, scores = score(conf_model, fitresult, Xcal, ycal) conf_model.scores = Dict( :calibration => cal_scores, :all => scores, diff --git a/src/conformal_models/training/inductive_classification.jl b/src/conformal_models/training/inductive_classification.jl index 7bcd3e8..5c5248a 100644 --- a/src/conformal_models/training/inductive_classification.jl +++ b/src/conformal_models/training/inductive_classification.jl @@ -1,25 +1,7 @@ using MLJFlux: MLJFluxModel -"The `TrainableSimpleInductiveClassifier` is the simplest approach to Inductive Conformal Classification. Contrary to the [`NaiveClassifier`](@ref) it computes nonconformity scores using a designated calibration dataset." -mutable struct TrainableSimpleInductiveClassifier{Model<:MLJFluxModel} <: ConformalProbabilisticSet - model::Model - coverage::AbstractFloat - scores::Union{Nothing,Dict{Any,Any}} - heuristic::Function - train_ratio::AbstractFloat -end - -function TrainableSimpleInductiveClassifier( - model::MLJFluxModel; - coverage::AbstractFloat = 0.95, - heuristic::Function = f(p̂) = 1.0 - p̂, - train_ratio::AbstractFloat = 0.5, -) - return TrainableSimpleInductiveClassifier(model, coverage, nothing, heuristic, train_ratio) -end - -function score(conf_model::TrainableSimpleInductiveClassifier, fitresult, X, y::Union{Nothing,AbstractArray}=nothing) - X = size(X,2) == 1 ? X : permutedims(X) +function score(conf_model::SimpleInductiveClassifier, ::Type{<:MLJFluxModel}, fitresult, X, y::Union{Nothing,AbstractArray}=nothing) + X = permutedims(matrix(X)) probas = permutedims(fitresult[1](X)) scores = @.(conf_model.heuristic(probas)) if isnothing(y) @@ -28,158 +10,4 @@ function score(conf_model::TrainableSimpleInductiveClassifier, fitresult, X, y:: cal_scores = getindex.(Ref(scores), 1:size(scores,1), levelcode.(y)) return cal_scores, scores end -end - -@doc raw""" - MMI.fit(conf_model::TrainableSimpleInductiveClassifier, verbosity, X, y) - -For the [`TrainableSimpleInductiveClassifier`](@ref) nonconformity scores are computed as follows: - -`` -S_i^{\text{CAL}} = s(X_i, Y_i) = h(\hat\mu(X_i), Y_i), \ i \in \mathcal{D}_{\text{calibration}} -`` - -A typical choice for the heuristic function is ``h(\hat\mu(X_i), Y_i)=1-\hat\mu(X_i)_{Y_i}`` where ``\hat\mu(X_i)_{Y_i}`` denotes the softmax output of the true class and ``\hat\mu`` denotes the model fitted on training data ``\mathcal{D}_{\text{train}}``. The simple approach only takes the softmax probability of the true label into account. -""" -function MMI.fit(conf_model::TrainableSimpleInductiveClassifier, verbosity, X, y) - - # Data Splitting: - train, calibration = partition(eachindex(y), conf_model.train_ratio) - Xtrain = selectrows(X, train) - ytrain = y[train] - Xtrain, ytrain = MMI.reformat(conf_model.model, Xtrain, ytrain) - Xcal = selectrows(X, calibration) - ycal = y[calibration] - Xcal, ycal = MMI.reformat(conf_model.model, Xcal, ycal) - - # Training: - fitresult, cache, report = MMI.fit(conf_model.model, verbosity, Xtrain, ytrain) - - # Nonconformity Scores: - cal_scores, scores = score(conf_model, fitresult, matrix(Xcal), ycal) - conf_model.scores = Dict( - :calibration => cal_scores, - :all => scores, - ) - - return (fitresult, cache, report) -end - -@doc raw""" - MMI.predict(conf_model::TrainableSimpleInductiveClassifier, fitresult, Xnew) - -For the [`TrainableSimpleInductiveClassifier`](@ref) prediction sets are computed as follows, - -`` -\hat{C}_{n,\alpha}(X_{n+1}) = \left\{y: s(X_{n+1},y) \le \hat{q}_{n, \alpha}^{+} \{S_i^{\text{CAL}}\} \right\}, \ i \in \mathcal{D}_{\text{calibration}} -`` - -where ``\mathcal{D}_{\text{calibration}}`` denotes the designated calibration data. -""" -function MMI.predict(conf_model::TrainableSimpleInductiveClassifier, fitresult, Xnew) - p̂ = reformat_mlj_prediction( - MMI.predict(conf_model.model, fitresult, MMI.reformat(conf_model.model, Xnew)...), - ) - v = conf_model.scores[:calibration] - q̂ = StatsBase.quantile(v, conf_model.coverage) - p̂ = map(p̂) do pp - L = p̂.decoder.classes - probas = pdf.(pp, L) - is_in_set = 1.0 .- probas .<= q̂ - if !all(is_in_set .== false) - pp = UnivariateFinite(L[is_in_set], probas[is_in_set]) - else - pp = missing - end - return pp - end - return p̂ -end - -# Adaptive -"The `TrainableAdaptiveInductiveClassifier` is an improvement to the [`TrainableSimpleInductiveClassifier`](@ref) and the [`NaiveClassifier`](@ref). Contrary to the [`NaiveClassifier`](@ref) it computes nonconformity scores using a designated calibration dataset like the [`TrainableSimpleInductiveClassifier`](@ref). Contrary to the [`TrainableSimpleInductiveClassifier`](@ref) it utilizes the softmax output of all classes." -mutable struct TrainableAdaptiveInductiveClassifier{Model<:MLJFluxModel} <: ConformalProbabilisticSet - model::Model - coverage::AbstractFloat - scores::Union{Nothing,AbstractArray} - heuristic::Function - train_ratio::AbstractFloat -end - -function TrainableAdaptiveInductiveClassifier( - model::MLJFluxModel; - coverage::AbstractFloat = 0.95, - heuristic::Function = f(y, ŷ) = 1.0 - ŷ, - train_ratio::AbstractFloat = 0.5, -) - return TrainableAdaptiveInductiveClassifier(model, coverage, nothing, heuristic, train_ratio) -end - -@doc raw""" - MMI.fit(conf_model::TrainableAdaptiveInductiveClassifier, verbosity, X, y) - -For the [`TrainableAdaptiveInductiveClassifier`](@ref) nonconformity scores are computed by cumulatively summing the ranked scores of each label in descending order until reaching the true label ``Y_i``: - -`` -S_i^{\text{CAL}} = s(X_i,Y_i) = \sum_{j=1}^k \hat\mu(X_i)_{\pi_j} \ \text{where } \ Y_i=\pi_k, i \in \mathcal{D}_{\text{calibration}} -`` -""" -function MMI.fit(conf_model::TrainableAdaptiveInductiveClassifier, verbosity, X, y) - - # Data Splitting: - train, calibration = partition(eachindex(y), conf_model.train_ratio) - Xtrain = selectrows(X, train) - ytrain = y[train] - Xtrain, ytrain = MMI.reformat(conf_model.model, Xtrain, ytrain) - Xcal = selectrows(X, calibration) - ycal = y[calibration] - Xcal, ycal = MMI.reformat(conf_model.model, Xcal, ycal) - - # Training: - fitresult, cache, report = MMI.fit(conf_model.model, verbosity, Xtrain, ytrain) - - # Nonconformity Scores: - p̂ = reformat_mlj_prediction(MMI.predict(conf_model.model, fitresult, Xcal)) - L = p̂.decoder.classes - ŷ = pdf(p̂, L) # compute probabilities for all classes - scores = map(eachrow(ŷ), eachrow(ycal)) do ŷᵢ, ycalᵢ - ranks = sortperm(.-ŷᵢ) # rank in descending order - index_y = findall(L[ranks] .== ycalᵢ)[1] # index of true y in sorted array - scoreᵢ = last(cumsum(ŷᵢ[ranks][1:index_y])) # sum up until true y is reached - return scoreᵢ - end - conf_model.scores = scores - - return (fitresult, cache, report) -end - -@doc raw""" - MMI.predict(conf_model::TrainableAdaptiveInductiveClassifier, fitresult, Xnew) - -For the [`TrainableAdaptiveInductiveClassifier`](@ref) prediction sets are computed as follows, - -`` -\hat{C}_{n,\alpha}(X_{n+1}) = \left\{y: s(X_{n+1},y) \le \hat{q}_{n, \alpha}^{+} \{S_i^{\text{CAL}}\} \right\}, i \in \mathcal{D}_{\text{calibration}} -`` - -where ``\mathcal{D}_{\text{calibration}}`` denotes the designated calibration data. -""" -function MMI.predict(conf_model::TrainableAdaptiveInductiveClassifier, fitresult, Xnew) - p̂ = reformat_mlj_prediction( - MMI.predict(conf_model.model, fitresult, MMI.reformat(conf_model.model, Xnew)...), - ) - v = conf_model.scores - q̂ = StatsBase.quantile(v, conf_model.coverage) - p̂ = map(p̂) do pp - L = p̂.decoder.classes - probas = pdf.(pp, L) - is_in_set = 1.0 .- probas .<= q̂ - if !all(is_in_set .== false) - pp = UnivariateFinite(L[is_in_set], probas[is_in_set]) - else - pp = missing - end - return pp - end - return p̂ -end +end \ No newline at end of file diff --git a/test/Manifest.toml b/test/Manifest.toml index 4d76a13..0b03d20 100644 --- a/test/Manifest.toml +++ b/test/Manifest.toml @@ -1,8 +1,8 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.7.2" +julia_version = "1.8.5" manifest_format = "2.0" -project_hash = "927d1b7a82bc93ab3d48f603d6d16919bd6f0dc1" +project_hash = "35130c42d0ed70ece3ae50bdfeedecf590b3fb1d" [[deps.ANSIColoredPrinters]] git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c" @@ -26,14 +26,26 @@ git-tree-sha1 = "faa260e4cb5aba097a73fab382dd4b5819d8ec8c" uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" version = "0.4.4" +[[deps.Accessors]] +deps = ["Compat", "CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Requires", "StaticArrays", "Test"] +git-tree-sha1 = "beabc31fa319f9de4d16372bff31b4801e43d32c" +uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" +version = "0.1.28" + [[deps.Adapt]] deps = ["LinearAlgebra", "Requires"] git-tree-sha1 = "cc37d689f599e8df4f464b2fa3870ff7db7492ef" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" version = "3.6.1" +[[deps.ArgCheck]] +git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" +uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" +version = "2.3.0" + [[deps.ArgTools]] uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" +version = "1.1.1" [[deps.ArrayInterface]] deps = ["Adapt", "LinearAlgebra", "Requires", "SnoopPrecompile", "SparseArrays", "SuiteSparse"] @@ -61,9 +73,20 @@ git-tree-sha1 = "86e9781ac28f4e80e9b98f7f96eae21891332ac2" uuid = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" version = "0.3.6" +[[deps.BangBang]] +deps = ["Compat", "ConstructionBase", "Future", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables", "ZygoteRules"] +git-tree-sha1 = "7fe6d92c4f281cf4ca6f2fba0ce7b299742da7ca" +uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" +version = "0.3.37" + [[deps.Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" +[[deps.Baselet]] +git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e" +uuid = "9718e550-a3fa-408a-8086-8db961cd8217" +version = "0.1.1" + [[deps.BitFlags]] git-tree-sha1 = "43b1a4a8f797c1cddadf60499a8a077d4af2cd2d" uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35" @@ -122,6 +145,12 @@ git-tree-sha1 = "da68989f027dcefa74d44a452c9e36af9730a70d" uuid = "af321ab8-2d2e-40a6-b165-3d674595d28e" version = "0.1.10" +[[deps.ChainRules]] +deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics", "StructArrays"] +git-tree-sha1 = "7d20c2fb8ab838e41069398685e7b6b5f89ed85b" +uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" +version = "1.48.0" + [[deps.ChainRulesCore]] deps = ["Compat", "LinearAlgebra", "SparseArrays"] git-tree-sha1 = "c6d890a52d2c4d55d326439580c3b8d0875a77d9" @@ -196,6 +225,12 @@ version = "0.1.25" [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "1.0.1+0" + +[[deps.CompositionsBase]] +git-tree-sha1 = "455419f7e328a1a2493cabc6428d79e951349769" +uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" +version = "0.1.1" [[deps.ComputationalResources]] git-tree-sha1 = "52cb3ec90e8a8bea0e62e275ba577ad0f74821f7" @@ -208,6 +243,12 @@ git-tree-sha1 = "89a9db8d28102b094992472d333674bd1a83ce2a" uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" version = "1.5.1" +[[deps.ContextVariablesX]] +deps = ["Compat", "Logging", "UUIDs"] +git-tree-sha1 = "25cc3803f1030ab855e383129dcd3dc294e322cc" +uuid = "6add18c4-b38d-439d-96f6-d6bc489c04c5" +version = "0.1.3" + [[deps.Contour]] git-tree-sha1 = "d05d9e7b7aedff4e5b51a029dced05cfb6125781" uuid = "d38c429a-6771-53c6-b99e-75d170b6e991" @@ -250,6 +291,11 @@ git-tree-sha1 = "c6475a3ccad06cb1c2ebc0740c1bb4fe5a0731b7" uuid = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb" version = "0.12.3" +[[deps.DefineSingletons]] +git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c" +uuid = "244e2a9f-e319-4986-a169-4d1fe445cd52" +version = "0.1.2" + [[deps.DelimitedFiles]] deps = ["Mmap"] uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" @@ -301,8 +347,9 @@ uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" version = "0.27.23" [[deps.Downloads]] -deps = ["ArgTools", "LibCURL", "NetworkOptions"] +deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" +version = "1.6.0" [[deps.DualNumbers]] deps = ["Calculus", "NaNMath", "SpecialFunctions"] @@ -356,6 +403,21 @@ git-tree-sha1 = "74faea50c1d007c85837327f6775bea60b5492dd" uuid = "b22a6f82-2f65-5046-a5b2-351ab43fb4e5" version = "4.4.2+2" +[[deps.FLoops]] +deps = ["BangBang", "Compat", "FLoopsBase", "InitialValues", "JuliaVariables", "MLStyle", "Serialization", "Setfield", "Transducers"] +git-tree-sha1 = "ffb97765602e3cbe59a0589d237bf07f245a8576" +uuid = "cc61a311-1640-44b5-9fba-1b764f453329" +version = "0.2.1" + +[[deps.FLoopsBase]] +deps = ["ContextVariablesX"] +git-tree-sha1 = "656f7a6859be8673bf1f35da5670246b923964f7" +uuid = "b9860ae5-e623-471e-878b-f6a53c775ea6" +version = "0.1.1" + +[[deps.FileWatching]] +uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" + [[deps.FillArrays]] deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] git-tree-sha1 = "d3ba08ab64bdfd27234d3f61956c966266757fe6" @@ -374,6 +436,18 @@ git-tree-sha1 = "335bfdceacc84c5cdf16aadc768aa5ddfc5383cc" uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" version = "0.8.4" +[[deps.Flux]] +deps = ["Adapt", "CUDA", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NNlibCUDA", "OneHotArrays", "Optimisers", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "Zygote"] +git-tree-sha1 = "4ff3a1d7b0dd38f2fc38e813bc801f817639c1f2" +uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" +version = "0.13.13" + +[[deps.FoldsThreads]] +deps = ["Accessors", "FunctionWrappers", "InitialValues", "SplittablesBase", "Transducers"] +git-tree-sha1 = "eb8e1989b9028f7e0985b4268dabe94682249025" +uuid = "9c68100b-dfe1-47cf-94c8-95104e173443" +version = "0.1.1" + [[deps.Fontconfig_jll]] deps = ["Artifacts", "Bzip2_jll", "Expat_jll", "FreeType2_jll", "JLLWrappers", "Libdl", "Libuuid_jll", "Pkg", "Zlib_jll"] git-tree-sha1 = "21efd19106a55620a188615da6d3d06cd7f6ee03" @@ -404,6 +478,17 @@ git-tree-sha1 = "aa31987c2ba8704e23c6c8ba8a4f769d5d7e4f91" uuid = "559328eb-81f9-559d-9380-de523a88c83c" version = "1.0.10+0" +[[deps.FunctionWrappers]] +git-tree-sha1 = "d62485945ce5ae9c0c48f124a84998d755bae00e" +uuid = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" +version = "1.1.3" + +[[deps.Functors]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "478f8c3145bb91d82c2cf20433e8c1b30df454cc" +uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +version = "0.4.4" + [[deps.Future]] deps = ["Random"] uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" @@ -509,6 +594,12 @@ git-tree-sha1 = "f7be53659ab06ddc986428d3a9dcc95f6fa6705a" uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89" version = "0.2.2" +[[deps.IRTools]] +deps = ["InteractiveUtils", "MacroTools", "Test"] +git-tree-sha1 = "0ade27f0c49cebd8db2523c4eeccf779407cf12c" +uuid = "7869d1d1-7146-5819-86e3-90919afe41df" +version = "0.4.9" + [[deps.IfElse]] git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1" uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" @@ -519,6 +610,11 @@ git-tree-sha1 = "f550e6e32074c939295eb5ea6de31849ac2c9625" uuid = "83e8ac13-25f8-5344-8a64-a9f2b223428f" version = "0.5.1" +[[deps.InitialValues]] +git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" +uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c" +version = "0.3.1" + [[deps.InteractiveUtils]] deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" @@ -585,6 +681,12 @@ git-tree-sha1 = "6f2675ef130a300a112286de91973805fcc5ffbc" uuid = "aacddb02-875f-59d6-b918-886e6ef4fbf8" version = "2.1.91+0" +[[deps.JuliaVariables]] +deps = ["MLStyle", "NameResolution"] +git-tree-sha1 = "49fb3cb53362ddadb4415e9b73926d6b40709e70" +uuid = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec" +version = "0.2.4" + [[deps.LAME_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] git-tree-sha1 = "f6250b16881adf048549549fba48b1161acdac8c" @@ -645,10 +747,12 @@ uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" [[deps.LibCURL]] deps = ["LibCURL_jll", "MozillaCACerts_jll"] uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" +version = "0.6.3" [[deps.LibCURL_jll]] deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" +version = "7.84.0+0" [[deps.LibGit2]] deps = ["Base64", "NetworkOptions", "Printf", "SHA"] @@ -657,6 +761,7 @@ uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" [[deps.LibSSH2_jll]] deps = ["Artifacts", "Libdl", "MbedTLS_jll"] uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" +version = "1.10.2+0" [[deps.Libdl]] uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" @@ -782,6 +887,12 @@ git-tree-sha1 = "bb8a1056b1d8b40f2f27167fc3ef6412a6719fbf" uuid = "50ed68f4-41fd-4504-931a-ed422449fee0" version = "0.3.2" +[[deps.MLJFlux]] +deps = ["CategoricalArrays", "ColorTypes", "ComputationalResources", "Flux", "MLJModelInterface", "Metalhead", "ProgressMeter", "Random", "Statistics", "Tables"] +git-tree-sha1 = "2ecdce4dd9214789ee1796103d29eaee7619ebd0" +uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845" +version = "0.2.9" + [[deps.MLJIteration]] deps = ["IterationControl", "MLJBase", "Random", "Serialization"] git-tree-sha1 = "be6d5c71ab499a59e82d65e00a89ceba8732fcd5" @@ -812,6 +923,17 @@ git-tree-sha1 = "02688098bd77827b64ed8ad747c14f715f98cfc4" uuid = "03970b2e-30c4-11ea-3135-d1576263f10f" version = "0.7.4" +[[deps.MLStyle]] +git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8" +uuid = "d8e11817-5142-5d16-987a-aa16d5891078" +version = "0.4.17" + +[[deps.MLUtils]] +deps = ["ChainRulesCore", "Compat", "DataAPI", "DelimitedFiles", "FLoops", "FoldsThreads", "NNlib", "Random", "ShowCases", "SimpleTraits", "Statistics", "StatsBase", "Tables", "Transducers"] +git-tree-sha1 = "f69cdbb5b7c630c02481d81d50eac43697084fe0" +uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" +version = "0.4.1" + [[deps.MacroTools]] deps = ["Markdown", "Random"] git-tree-sha1 = "42324d08725e200c23d4dfb549e0d5d89dede2d2" @@ -842,12 +964,25 @@ version = "1.1.7" [[deps.MbedTLS_jll]] deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" +version = "2.28.0+0" [[deps.Measures]] git-tree-sha1 = "c13304c81eec1ed3af7fc20e75fb6b26092a1102" uuid = "442fdcdd-2543-5da2-b0f3-8c86c306513e" version = "0.3.2" +[[deps.Metalhead]] +deps = ["Artifacts", "BSON", "Flux", "Functors", "LazyArtifacts", "MLUtils", "NNlib", "Random", "Statistics"] +git-tree-sha1 = "0e95f91cc5f23610f8f270d7397f307b21e19d2b" +uuid = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" +version = "0.7.4" + +[[deps.MicroCollections]] +deps = ["BangBang", "InitialValues", "Setfield"] +git-tree-sha1 = "629afd7d10dbc6935ec59b32daeb33bc4460a42e" +uuid = "128add7d-3638-4c79-886c-908ea0c25c34" +version = "0.1.4" + [[deps.Missings]] deps = ["DataAPI"] git-tree-sha1 = "f66bdc5de519e8f8ae43bdc598782d35a25b1272" @@ -859,6 +994,7 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804" [[deps.MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" +version = "2022.2.1" [[deps.NLSolversBase]] deps = ["DiffResults", "Distributed", "FiniteDiff", "ForwardDiff"] @@ -866,12 +1002,30 @@ git-tree-sha1 = "a0b464d183da839699f4c79e7606d9d186ec172c" uuid = "d41bc354-129a-5804-8e4c-c37616107c6c" version = "7.8.3" +[[deps.NNlib]] +deps = ["Adapt", "ChainRulesCore", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] +git-tree-sha1 = "33ad5a19dc6730d592d8ce91c14354d758e53b0e" +uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +version = "0.8.19" + +[[deps.NNlibCUDA]] +deps = ["Adapt", "CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics"] +git-tree-sha1 = "b05a082b08a3af0e5c576883bc6dfb6513e7e478" +uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d" +version = "0.2.6" + [[deps.NaNMath]] deps = ["OpenLibm_jll"] git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" version = "1.0.2" +[[deps.NameResolution]] +deps = ["PrettyPrint"] +git-tree-sha1 = "1a0fa0e9613f46c9b8c11eee38ebb4f590013c5e" +uuid = "71a1bf82-56d0-4bbc-8a3c-48b961074391" +version = "0.1.5" + [[deps.NearestNeighborModels]] deps = ["Distances", "FillArrays", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "NearestNeighbors", "Statistics", "StatsBase", "Tables"] git-tree-sha1 = "727b8f1c3f9fec6b1a805ba9bef72c73758eda02" @@ -892,6 +1046,7 @@ version = "0.4.4" [[deps.NetworkOptions]] uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +version = "1.2.0" [[deps.OffsetArrays]] deps = ["Adapt"] @@ -905,13 +1060,21 @@ git-tree-sha1 = "887579a3eb005446d514ab7aeac5d1d027658b8f" uuid = "e7412a2a-1a6e-54c0-be00-318e2571c051" version = "1.3.5+1" +[[deps.OneHotArrays]] +deps = ["Adapt", "ChainRulesCore", "Compat", "GPUArraysCore", "LinearAlgebra", "NNlib"] +git-tree-sha1 = "f511fca956ed9e70b80cd3417bb8c2dde4b68644" +uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" +version = "0.2.3" + [[deps.OpenBLAS_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" +version = "0.3.20+0" [[deps.OpenLibm_jll]] deps = ["Artifacts", "Libdl"] uuid = "05823500-19ac-5b8b-9628-191a04bc5112" +version = "0.8.1+0" [[deps.OpenML]] deps = ["ARFFFiles", "HTTP", "JSON", "Markdown", "Pkg", "Scratch"] @@ -943,6 +1106,12 @@ git-tree-sha1 = "1903afc76b7d01719d9c30d3c7d501b61db96721" uuid = "429524aa-4258-5aef-a3af-852621145aeb" version = "1.7.4" +[[deps.Optimisers]] +deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "4b214125921ec010160ddb39931885e0a6585639" +uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" +version = "0.2.17" + [[deps.Opus_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] git-tree-sha1 = "51a08fb14ec28da2ec7a927c4337e4332c2a4720" @@ -957,6 +1126,7 @@ version = "1.4.1" [[deps.PCRE2_jll]] deps = ["Artifacts", "Libdl"] uuid = "efcefdf7-47ab-520b-bdef-62a2eaa19f15" +version = "10.40.0+0" [[deps.PDMats]] deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] @@ -990,6 +1160,7 @@ version = "0.40.1+0" [[deps.Pkg]] deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +version = "1.8.0" [[deps.PlotThemes]] deps = ["PlotUtils", "Statistics"] @@ -1027,6 +1198,11 @@ git-tree-sha1 = "47e5f437cc0e7ef2ce8406ce1e7e24d44915f88d" uuid = "21216c6a-2e73-6563-6e65-726566657250" version = "1.3.0" +[[deps.PrettyPrint]] +git-tree-sha1 = "632eb4abab3449ab30c5e1afaa874f0b98b586e4" +uuid = "8162dcfd-2161-5ef2-ae6c-7681170c5f98" +version = "0.2.0" + [[deps.PrettyPrinting]] git-tree-sha1 = "4be53d093e9e37772cc89e1009e8f6ad10c4681b" uuid = "54e16d92-306c-5ea0-a30b-337be88ac337" @@ -1042,6 +1218,12 @@ version = "2.2.2" deps = ["Unicode"] uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" +[[deps.ProgressLogging]] +deps = ["Logging", "SHA", "UUIDs"] +git-tree-sha1 = "80d919dee55b9c50e8d9e2da5eeafff3fe58b539" +uuid = "33c8b6b6-d38a-422a-b730-caa89a2f386c" +version = "0.1.4" + [[deps.ProgressMeter]] deps = ["Distributed", "Printf"] git-tree-sha1 = "d7a7aef8f8f2d537104f170139553b14dfe39fe9" @@ -1080,6 +1262,12 @@ git-tree-sha1 = "043da614cc7e95c703498a491e2c21f58a2b8111" uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143" version = "1.5.3" +[[deps.RealDot]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" +uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" +version = "0.1.0" + [[deps.RecipesBase]] deps = ["SnoopPrecompile"] git-tree-sha1 = "261dddd3b862bd2c940cf6ca4d1c8fe593e457c8" @@ -1123,6 +1311,7 @@ version = "0.4.0+0" [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" +version = "0.7.0" [[deps.SIMDTypes]] git-tree-sha1 = "330289636fb8107c5f32088d2741e9fd7a061a5c" @@ -1167,6 +1356,11 @@ git-tree-sha1 = "e2cc6d8c88613c05e1defb55170bf5ff211fbeac" uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" version = "1.1.1" +[[deps.ShowCases]] +git-tree-sha1 = "7f534ad62ab2bd48591bdeac81994ea8c445e4a5" +uuid = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" +version = "0.1.0" + [[deps.Showoff]] deps = ["Dates", "Grisu"] git-tree-sha1 = "91eddf657aca81df9ae6ceb20b959ae5653ad1de" @@ -1178,6 +1372,12 @@ git-tree-sha1 = "874e8867b33a00e784c8a7e4b60afe9e037b74e1" uuid = "777ac1f9-54b0-4bf8-805c-2214025038e7" version = "1.1.0" +[[deps.SimpleTraits]] +deps = ["InteractiveUtils", "MacroTools"] +git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231" +uuid = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" +version = "0.9.4" + [[deps.SnoopPrecompile]] deps = ["Preferences"] git-tree-sha1 = "e760a70afdcd461cf01a575947738d359234665c" @@ -1203,6 +1403,12 @@ git-tree-sha1 = "ef28127915f4229c971eb43f3fc075dd3fe91880" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" version = "2.2.0" +[[deps.SplittablesBase]] +deps = ["Setfield", "Test"] +git-tree-sha1 = "e08a62abc517eb79667d0a29dc08a3b589516bb5" +uuid = "171d559e-b47b-412a-8079-5efa626c420e" +version = "0.1.15" + [[deps.StableRNGs]] deps = ["Random", "Test"] git-tree-sha1 = "3be7d49667040add7ee151fefaf1f8c04c8c8276" @@ -1278,6 +1484,7 @@ uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [[deps.TOML]] deps = ["Dates"] uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" +version = "1.0.0" [[deps.TableTraits]] deps = ["IteratorInterfaceExtensions"] @@ -1294,6 +1501,7 @@ version = "1.10.0" [[deps.Tar]] deps = ["ArgTools", "SHA"] uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" +version = "1.10.1" [[deps.TensorCore]] deps = ["LinearAlgebra"] @@ -1323,6 +1531,12 @@ git-tree-sha1 = "94f38103c984f89cf77c402f2a68dbd870f8165f" uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" version = "0.9.11" +[[deps.Transducers]] +deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"] +git-tree-sha1 = "c42fa452a60f022e9e087823b47e5a5f8adc53d5" +uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" +version = "0.4.75" + [[deps.URIs]] git-tree-sha1 = "074f993b0ca030848b897beff716d93aca60f06a" uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" @@ -1516,6 +1730,7 @@ version = "1.4.0+3" [[deps.Zlib_jll]] deps = ["Libdl"] uuid = "83775a58-1f1d-513f-b197-d71354ab007a" +version = "1.2.12+3" [[deps.Zstd_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] @@ -1523,6 +1738,18 @@ git-tree-sha1 = "c6edfe154ad7b313c01aceca188c05c835c67360" uuid = "3161d3a3-bdf6-5164-811a-617609db77b4" version = "1.5.4+0" +[[deps.Zygote]] +deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] +git-tree-sha1 = "e1af683167eea952684188f5e1e29b9cabc2e5f9" +uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" +version = "0.6.55" + +[[deps.ZygoteRules]] +deps = ["ChainRulesCore", "MacroTools"] +git-tree-sha1 = "977aed5d006b840e2e40c0b48984f7463109046d" +uuid = "700de1a5-db45-46bc-99cf-38207098b444" +version = "0.2.3" + [[deps.fzf_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] git-tree-sha1 = "868e669ccb12ba16eaf50cb2957ee2ff61261c56" @@ -1544,6 +1771,7 @@ version = "0.15.1+0" [[deps.libblastrampoline_jll]] deps = ["Artifacts", "Libdl", "OpenBLAS_jll"] uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" +version = "5.1.1+0" [[deps.libfdk_aac_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -1566,10 +1794,12 @@ version = "1.3.7+1" [[deps.nghttp2_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" +version = "1.48.0+0" [[deps.p7zip_jll]] deps = ["Artifacts", "Libdl"] uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" +version = "17.4.0+0" [[deps.x264_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] diff --git a/test/Project.toml b/test/Project.toml index dc9215c..3dd471b 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -5,6 +5,7 @@ EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5" LightGBM = "7acf609c-83a4-11e9-1ffb-b912bcd3b04a" MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7" MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661" +MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" NearestNeighborModels = "636a865e-7cf4-491e-846c-de09b730eb36"