From ad11d87882067166962925ea45f2ee3309f876f3 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Wed, 8 Nov 2023 15:59:58 +0100 Subject: [PATCH 01/11] hmm --- Project.toml | 8 +- docs/Manifest.toml | 564 +++++++++++------- docs/setup_docs.jl | 1 + docs/src/tutorials/training.qmd | 25 +- .../ConformalTraining/classifier.jl | 6 + .../ConformalTraining/losses.jl | 15 +- .../ConformalTraining/regressor.jl | 6 + .../ConformalTraining/training.jl | 5 +- 8 files changed, 393 insertions(+), 237 deletions(-) diff --git a/Project.toml b/Project.toml index dfe7dfae..e639ddb9 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +InferOpt = "4846b161-c94e-4150-8dac-c7ae193c601f" LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" @@ -21,15 +22,16 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" +cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [compat] CategoricalArrays = "0.10" ChainRules = "1.49.0" ComputationalResources = "0.3" Flux = "0.13.16, 0.14" -MLJBase = "0.20, 0.21" -MLJEnsembles = "0.3.3" -MLJFlux = "0.2.10, 0.3" +MLJBase = "0.20, 0.21, 1" +MLJEnsembles = "0.3.3, 0.4" +MLJFlux = "0.2.10, 0.3, 0.4" MLJModelInterface = "1" MLUtils = "0.4.2" NaturalSort = "1" diff --git a/docs/Manifest.toml b/docs/Manifest.toml index b33bf6fa..c944b474 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -39,9 +39,9 @@ version = "0.4.4" [[deps.Adapt]] deps = ["LinearAlgebra", "Requires"] -git-tree-sha1 = "76289dc51920fdc6e0013c872ba9551d54961c24" +git-tree-sha1 = "02f731463748db57cc2ebfbd9fbc9ce8280d3433" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "3.6.2" +version = "3.7.1" weakdeps = ["StaticArrays"] [deps.Adapt.extensions] @@ -76,9 +76,9 @@ version = "3.5.1+1" [[deps.ArrayInterface]] deps = ["Adapt", "LinearAlgebra", "Requires", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "f83ec24f76d4c8f525099b2ac475fc098138ec31" +git-tree-sha1 = "16267cf279190ca7c1b30d020758ced95db89cd0" uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -version = "7.4.11" +version = "7.5.1" [deps.ArrayInterface.extensions] ArrayInterfaceBandedMatricesExt = "BandedMatrices" @@ -96,12 +96,6 @@ version = "7.4.11" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" -[[deps.ArrayInterfaceCore]] -deps = ["LinearAlgebra", "SnoopPrecompile", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "e5f08b5689b1aad068e01751889f2f615c7db36d" -uuid = "30b0a656-2188-435a-8636-2ec0e6a096e2" -version = "0.1.29" - [[deps.Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" @@ -113,9 +107,9 @@ version = "0.1.0" [[deps.AtomsBase]] deps = ["LinearAlgebra", "PeriodicTable", "Printf", "Requires", "StaticArrays", "Unitful", "UnitfulAtomic"] -git-tree-sha1 = "c9804781ca49261c8eb6ce4b62f171cfa3d900f0" +git-tree-sha1 = "995c2b6b17840cd87b722ce9c6cdd72f47bab545" uuid = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a" -version = "0.3.4" +version = "0.3.5" [[deps.AxisAlgorithms]] deps = ["LinearAlgebra", "Random", "SparseArrays", "WoodburyMatrices"] @@ -191,10 +185,10 @@ uuid = "e1450e63-4bb3-523b-b2a4-4ffa8c0fd77d" version = "1.2.1" [[deps.BytePairEncoding]] -deps = ["DoubleArrayTries", "StructWalk", "TextEncodeBase", "Unicode"] -git-tree-sha1 = "91752c465dfbdd55837a18f9aa9e6d20899658e9" +deps = ["Artifacts", "Base64", "DataStructures", "DoubleArrayTries", "LazyArtifacts", "StructWalk", "TextEncodeBase", "Unicode"] +git-tree-sha1 = "295253961b9bcb1020bfd8711c7b51311dbfa102" uuid = "a4280ba5-8788-555a-8ca8-4a8c3d966a71" -version = "0.3.2" +version = "0.4.1" [[deps.Bzip2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -282,25 +276,34 @@ weakdeps = ["JSON", "RecipesBase", "SentinelArrays", "StructTypes"] [[deps.CategoricalDistributions]] deps = ["CategoricalArrays", "Distributions", "Missings", "OrderedCollections", "Random", "ScientificTypes"] -git-tree-sha1 = "ed760a4fde49997ff9360a780abe6e20175162aa" +git-tree-sha1 = "3124343a1b0c9a2f5fdc1d9bcc633ba11735a4c4" uuid = "af321ab8-2d2e-40a6-b165-3d674595d28e" -version = "0.1.11" +version = "0.1.13" weakdeps = ["UnicodePlots"] [deps.CategoricalDistributions.extensions] UnivariateFiniteDisplayExt = "UnicodePlots" +[[deps.Chain]] +git-tree-sha1 = "8c4920235f6c561e401dfe569beb8b924adad003" +uuid = "8be319e6-bccf-4806-a6f7-6fae938471bc" +version = "0.5.0" + [[deps.ChainRules]] deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"] -git-tree-sha1 = "dbeca245b0680f5393b4e6c40dcead7230ab0b3b" +git-tree-sha1 = "710940598100496ad6cbb707e481c28186354197" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.54.0" +version = "1.57.0" [[deps.ChainRulesCore]] -deps = ["Compat", "LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "e30f2f4e20f7f186dc36529910beaedc60cfa644" +deps = ["Compat", "LinearAlgebra"] +git-tree-sha1 = "e0af648f0692ec1691b5d094b8724ba1346281cf" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.16.0" +version = "1.18.0" +weakdeps = ["SparseArrays"] + + [deps.ChainRulesCore.extensions] + ChainRulesCoreSparseArraysExt = "SparseArrays" [[deps.Chemfiles]] deps = ["AtomsBase", "Chemfiles_jll", "DocStringExtensions", "PeriodicTable", "Unitful", "UnitfulAtomic"] @@ -322,21 +325,21 @@ version = "0.1.12" [[deps.Clustering]] deps = ["Distances", "LinearAlgebra", "NearestNeighbors", "Printf", "Random", "SparseArrays", "Statistics", "StatsBase"] -git-tree-sha1 = "b86ac2c5543660d238957dbde5ac04520ae977a7" +git-tree-sha1 = "05f9816a77231b07e634ab8715ba50e5249d6f76" uuid = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5" -version = "0.15.4" +version = "0.15.5" [[deps.CodecZlib]] deps = ["TranscodingStreams", "Zlib_jll"] -git-tree-sha1 = "02aa26a4cf76381be7f66e020a3eddeb27b0a092" +git-tree-sha1 = "cd67fc487743b2f0fd4380d4cbd3a24660d0eec8" uuid = "944b1d66-785c-5afd-91f1-9de20f533193" -version = "0.7.2" +version = "0.7.3" [[deps.ColorSchemes]] deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"] -git-tree-sha1 = "d9a8f86737b665e15a9641ecbac64deef9ce6724" +git-tree-sha1 = "67c1f244b991cad9b0aa4b7540fb758c2488b129" uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4" -version = "3.23.0" +version = "3.24.0" [[deps.ColorTypes]] deps = ["FixedPointNumbers", "Random"] @@ -369,9 +372,9 @@ version = "0.3.0" [[deps.Compat]] deps = ["UUIDs"] -git-tree-sha1 = "e460f044ca8b99be31d35fe54fc33a5c33dd8ed7" +git-tree-sha1 = "8a62af3e248a8c4bad6b32cbbe663ae02275e32c" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.9.0" +version = "4.10.0" weakdeps = ["Dates", "LinearAlgebra"] [deps.Compat.extensions] @@ -386,13 +389,11 @@ version = "1.0.5+0" git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" version = "0.1.2" +weakdeps = ["InverseFunctions"] [deps.CompositionsBase.extensions] CompositionsBaseInverseFunctionsExt = "InverseFunctions" - [deps.CompositionsBase.weakdeps] - InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" - [[deps.ComputationalResources]] git-tree-sha1 = "52cb3ec90e8a8bea0e62e275ba577ad0f74821f7" uuid = "ed09eef8-17a6-5b46-8889-db040fac31e3" @@ -400,15 +401,15 @@ version = "0.3.2" [[deps.ConcurrentUtilities]] deps = ["Serialization", "Sockets"] -git-tree-sha1 = "5372dbbf8f0bdb8c700db5367132925c0771ef7e" +git-tree-sha1 = "8cfa272e8bdedfa88b6aefbbca7c19f1befac519" uuid = "f0e56b4a-5159-44fe-b623-3e5288b988bb" -version = "2.2.1" +version = "2.3.0" [[deps.ConformalPrediction]] -deps = ["CategoricalArrays", "ChainRules", "ComputationalResources", "Flux", "LazyArtifacts", "LinearAlgebra", "MLJBase", "MLJEnsembles", "MLJFlux", "MLJModelInterface", "MLUtils", "NaturalSort", "Plots", "ProgressMeter", "Random", "Serialization", "StatsBase", "Tables"] +deps = ["CategoricalArrays", "ChainRules", "ComputationalResources", "Flux", "InferOpt", "LazyArtifacts", "LinearAlgebra", "MLJBase", "MLJEnsembles", "MLJFlux", "MLJModelInterface", "MLUtils", "NaturalSort", "ProgressMeter", "Random", "Serialization", "StatsBase", "Tables", "cuDNN"] path = ".." uuid = "98bfc277-1877-43dc-819b-a3e38c30242f" -version = "0.1.9" +version = "0.1.12" [[deps.ConstructionBase]] deps = ["LinearAlgebra"] @@ -439,10 +440,10 @@ uuid = "150eb455-5306-5404-9cee-2592286d6298" version = "0.6.3" [[deps.CounterfactualExplanations]] -deps = ["CSV", "CUDA", "CategoricalArrays", "ChainRulesCore", "DataFrames", "DecisionTree", "Distributions", "EvoTrees", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "Logging", "MLDatasets", "MLJBase", "MLJDecisionTreeInterface", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "PackageExtensionCompat", "Parameters", "PrecompileTools", "ProgressMeter", "Random", "Serialization", "Statistics", "StatsBase", "Tables", "UUIDs", "cuDNN"] -git-tree-sha1 = "6cd46eda67b19a577a14eb8a7287197fcb8e9954" +deps = ["CSV", "CUDA", "CategoricalArrays", "Chain", "ChainRulesCore", "DataFrames", "DecisionTree", "Distributions", "EvoTrees", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "MLDatasets", "MLJBase", "MLJDecisionTreeInterface", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "PackageExtensionCompat", "Parameters", "Plots", "ProgressMeter", "Random", "Serialization", "SnoopPrecompile", "Statistics", "StatsBase", "Tables", "UMAP"] +git-tree-sha1 = "8393721ffa3c9be209a93eb154d0d9fe9ca187d5" uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0" -version = "0.1.20" +version = "0.1.15" [deps.CounterfactualExplanations.extensions] MPIExt = "MPI" @@ -504,9 +505,9 @@ uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" [[deps.DecisionTree]] deps = ["AbstractTrees", "DelimitedFiles", "LinearAlgebra", "Random", "ScikitLearnBase", "Statistics"] -git-tree-sha1 = "c6475a3ccad06cb1c2ebc0740c1bb4fe5a0731b7" +git-tree-sha1 = "526ca14aaaf2d5a0e242f3a8a7966eb9065d7d78" uuid = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb" -version = "0.12.3" +version = "0.12.4" [[deps.DefineSingletons]] git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c" @@ -519,6 +520,12 @@ git-tree-sha1 = "9e2f36d3c96a820c678f2f1f1782582fcf685bae" uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" version = "1.9.1" +[[deps.DensityInterface]] +deps = ["InverseFunctions", "Test"] +git-tree-sha1 = "80c3e8639e3353e5d2912fb3a1916b8455e2494b" +uuid = "b429d917-457f-4dbc-8f4c-0cc954292b1d" +version = "0.4.0" + [[deps.DiffResults]] deps = ["StaticArraysCore"] git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" @@ -533,12 +540,13 @@ version = "1.15.1" [[deps.Distances]] deps = ["LinearAlgebra", "Statistics", "StatsAPI"] -git-tree-sha1 = "b6def76ffad15143924a2199f72a5cd883a2e8a9" +git-tree-sha1 = "5225c965635d8c21168e32a12954675e7bea1151" uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" -version = "0.10.9" -weakdeps = ["SparseArrays"] +version = "0.10.10" +weakdeps = ["ChainRulesCore", "SparseArrays"] [deps.Distances.extensions] + DistancesChainRulesCoreExt = "ChainRulesCore" DistancesSparseArraysExt = "SparseArrays" [[deps.Distributed]] @@ -546,18 +554,16 @@ deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[deps.Distributions]] -deps = ["FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns", "Test"] -git-tree-sha1 = "938fe2981db009f531b6332e31c58e9584a2f9bd" +deps = ["FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"] +git-tree-sha1 = "a6c00f894f24460379cb7136633cef54ac9f6f4a" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.25.100" +version = "0.25.103" +weakdeps = ["ChainRulesCore", "DensityInterface", "Test"] [deps.Distributions.extensions] DistributionsChainRulesCoreExt = "ChainRulesCore" DistributionsDensityInterfaceExt = "DensityInterface" - - [deps.Distributions.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" + DistributionsTestExt = "Test" [[deps.DocStringExtensions]] deps = ["LibGit2"] @@ -566,10 +572,10 @@ uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" version = "0.9.3" [[deps.Documenter]] -deps = ["ANSIColoredPrinters", "Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"] -git-tree-sha1 = "39fd748a73dce4c05a9655475e437170d8fb1b67" +deps = ["ANSIColoredPrinters", "AbstractTrees", "Base64", "Dates", "DocStringExtensions", "Downloads", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "MarkdownAST", "Pkg", "PrecompileTools", "REPL", "RegistryInstances", "SHA", "Test", "Unicode"] +git-tree-sha1 = "662fb21ae7fad33e044c2b59ece832fdce32c171" uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -version = "0.27.25" +version = "1.1.2" [[deps.DoubleArrayTries]] deps = ["OffsetArrays", "Preferences", "StringViews"] @@ -625,11 +631,21 @@ git-tree-sha1 = "98fdf08b707aaf69f524a6cd0a67858cefe0cfb6" uuid = "792122b4-ca99-40de-a6bc-6742525f08b6" version = "0.3.0" +[[deps.EpollShim_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "8e9441ee83492030ace98f9789a654a6d0b1f643" +uuid = "2702e6a9-849d-5ed8-8c21-79e8b8f9ee43" +version = "0.0.20230411+0" + [[deps.EvoTrees]] -deps = ["BSON", "CUDA", "CategoricalArrays", "Distributions", "MLJModelInterface", "NetworkLayout", "Random", "RecipesBase", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "a1fa1d1743478394a0a7188d054b67546e4ca143" +deps = ["BSON", "CategoricalArrays", "Distributions", "MLJModelInterface", "NetworkLayout", "Random", "RecipesBase", "Statistics", "StatsBase", "Tables"] +git-tree-sha1 = "6bae99c964218fcb9af8b0cca80f9bd278d59dcb" uuid = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5" -version = "0.16.1" +version = "0.16.4" +weakdeps = ["CUDA"] + + [deps.EvoTrees.extensions] + EvoTreesCUDAExt = "CUDA" [[deps.ExceptionUnwrapping]] deps = ["Test"] @@ -649,9 +665,9 @@ uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" version = "0.1.10" [[deps.Extents]] -git-tree-sha1 = "5e1e4c53fa39afe63a7d356e30452249365fba99" +git-tree-sha1 = "2140cd04483da90b2da7f99b2add0750504fc39c" uuid = "411431e0-e8b7-467b-b5e0-f676ba4f2910" -version = "0.1.1" +version = "0.1.2" [[deps.FFMPEG]] deps = ["FFMPEG_jll"] @@ -709,18 +725,18 @@ version = "1.16.1" [[deps.FilePathsBase]] deps = ["Compat", "Dates", "Mmap", "Printf", "Test", "UUIDs"] -git-tree-sha1 = "e27c4ebe80e8699540f2d6c805cc12203b614f12" +git-tree-sha1 = "9f00e42f8d99fdde64d40c8ea5d14269a2e2c1aa" uuid = "48062228-2e41-5def-b9a4-89aafe57970f" -version = "0.9.20" +version = "0.9.21" [[deps.FileWatching]] uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" [[deps.FillArrays]] deps = ["LinearAlgebra", "Random"] -git-tree-sha1 = "a20eaa3ad64254c61eeb5f230d9306e937405434" +git-tree-sha1 = "35f0c0f345bff2c6d636f95fdb136323b5a796ef" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.6.1" +version = "1.7.0" weakdeps = ["SparseArrays", "Statistics"] [deps.FillArrays.extensions] @@ -826,9 +842,9 @@ version = "3.3.8+0" [[deps.GLM]] deps = ["Distributions", "LinearAlgebra", "Printf", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns", "StatsModels"] -git-tree-sha1 = "97829cfda0df99ddaeaafb5b370d6cab87b7013e" +git-tree-sha1 = "273bd1cd30768a2fddfa3fd63bbc746ed7249e5f" uuid = "38e38edf-8417-5370-95a0-9cbb8c7f171a" -version = "1.8.3" +version = "1.9.0" [[deps.GPUArrays]] deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] @@ -868,9 +884,9 @@ version = "0.6.1" [[deps.GeoInterface]] deps = ["Extents"] -git-tree-sha1 = "bb198ff907228523f3dee1070ceee63b9359b6ab" +git-tree-sha1 = "d53480c0793b13341c40199190f92c611aa2e93c" uuid = "cf35fbd7-0cd7-5166-be24-54bfbe79505f" -version = "1.3.1" +version = "1.3.2" [[deps.GeometryBasics]] deps = ["EarCut_jll", "Extents", "GeoInterface", "IterTools", "LinearAlgebra", "StaticArrays", "StructArrays", "Tables"] @@ -909,9 +925,9 @@ version = "1.3.14+0" [[deps.Graphs]] deps = ["ArnoldiMethod", "Compat", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"] -git-tree-sha1 = "1cf1d7dcb4bc32d7b4a5add4232db3750c27ecb4" +git-tree-sha1 = "899050ace26649433ef1af25bc17a815b3db52b7" uuid = "86223c79-3864-5bf0-83f7-82e725a168b6" -version = "1.8.0" +version = "1.9.0" [[deps.Grisu]] git-tree-sha1 = "53bb909d1151e57e2484c3d1b53e19552b887fb2" @@ -938,9 +954,9 @@ version = "1.0.1" [[deps.HTTP]] deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "ExceptionUnwrapping", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] -git-tree-sha1 = "19e974eced1768fb46fd6020171f2cec06b1edb5" +git-tree-sha1 = "5eab648309e2e060198b45820af1a37182de3cce" uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" -version = "1.9.15" +version = "1.10.0" [[deps.HarfBuzz_jll]] deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "Graphite2_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg"] @@ -974,9 +990,9 @@ version = "0.0.4" [[deps.HypertextLiteral]] deps = ["Tricks"] -git-tree-sha1 = "c47c5fa4c5308f27ccaac35504858d8914e102f9" +git-tree-sha1 = "7134810b1afce04bbc1045ca1985fbe81ce17653" uuid = "ac1192a8-f4b3-4bfe-ba22-af5b92cd3ab2" -version = "0.9.4" +version = "0.9.5" [[deps.IOCapture]] deps = ["Logging", "Random"] @@ -986,9 +1002,9 @@ version = "0.2.3" [[deps.IRTools]] deps = ["InteractiveUtils", "MacroTools", "Test"] -git-tree-sha1 = "eac00994ce3229a464c2847e956d77a2c64ad3a5" +git-tree-sha1 = "8aa91235360659ca7560db43a7d57541120aa31d" uuid = "7869d1d1-7146-5819-86e3-90919afe41df" -version = "0.4.10" +version = "0.4.11" [[deps.IfElse]] git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1" @@ -1102,10 +1118,22 @@ git-tree-sha1 = "012e604e1c7458645cb8b436f8fba789a51b257f" uuid = "9b13fd28-a010-5f03-acff-a1bbcff69959" version = "1.0.0" +[[deps.InferOpt]] +deps = ["ChainRulesCore", "DensityInterface", "LinearAlgebra", "Random", "RequiredInterfaces", "Statistics", "StatsBase", "StatsFuns", "ThreadsX"] +git-tree-sha1 = "13f3f8e166390e31f45f989653674a76d14ab252" +uuid = "4846b161-c94e-4150-8dac-c7ae193c601f" +version = "0.6.0" + + [deps.InferOpt.extensions] + InferOptFrankWolfeExt = "DifferentiableFrankWolfe" + + [deps.InferOpt.weakdeps] + DifferentiableFrankWolfe = "b383313e-5450-4164-a800-befbd27b574d" + [[deps.Inflate]] -git-tree-sha1 = "5cd07aab533df5170988219191dfad0519391428" +git-tree-sha1 = "ea8031dea4aff6bd41f1df8f2fdfb25b33626381" uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9" -version = "0.1.3" +version = "0.1.4" [[deps.InitialValues]] git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" @@ -1148,14 +1176,20 @@ version = "0.14.7" [[deps.IntervalSets]] deps = ["Dates", "Random"] -git-tree-sha1 = "8e59ea773deee525c99a8018409f64f19fb719e6" +git-tree-sha1 = "3d8866c029dd6b16e69e0d4a939c4dfcb98fac47" uuid = "8197267c-284f-5f27-9208-e0e47529a953" -version = "0.7.7" +version = "0.7.8" weakdeps = ["Statistics"] [deps.IntervalSets.extensions] IntervalSetsStatisticsExt = "Statistics" +[[deps.InverseFunctions]] +deps = ["Test"] +git-tree-sha1 = "68772f49f54b479fa88ace904f6127f0a3bb2e46" +uuid = "3587e190-3f89-42d0-90ee-14403ec27112" +version = "0.1.12" + [[deps.InvertedIndices]] git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038" uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" @@ -1179,9 +1213,9 @@ version = "0.5.3" [[deps.IterativeSolvers]] deps = ["LinearAlgebra", "Printf", "Random", "RecipesBase", "SparseArrays"] -git-tree-sha1 = "1169632f425f79429f245113b775a0e3d121457c" +git-tree-sha1 = "b435d190ef8369cf4d79cc9dd5fba88ba0165307" uuid = "42fd0dbc-a981-5370-80f2-aaf504508153" -version = "0.9.2" +version = "0.9.3" [[deps.IteratorInterfaceExtensions]] git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" @@ -1189,16 +1223,16 @@ uuid = "82899510-4779-5014-852e-03e436cf321d" version = "1.0.0" [[deps.JLD2]] -deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "Pkg", "Printf", "Reexport", "Requires", "TranscodingStreams", "UUIDs"] -git-tree-sha1 = "aa6ffef1fd85657f4999030c52eaeec22a279738" +deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "Pkg", "PrecompileTools", "Printf", "Reexport", "Requires", "TranscodingStreams", "UUIDs"] +git-tree-sha1 = "9bbb5130d3b4fa52846546bca4791ecbdfb52730" uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819" -version = "0.4.33" +version = "0.4.38" [[deps.JLFzf]] deps = ["Pipe", "REPL", "Random", "fzf_jll"] -git-tree-sha1 = "f377670cda23b6b7c1c0b3893e37451c5c1a2185" +git-tree-sha1 = "9fb0b890adab1c0a4a475d4210d51f228bfc250d" uuid = "1019f520-868f-41f5-a6de-eb00f4b6a39c" -version = "0.1.5" +version = "0.1.6" [[deps.JLLWrappers]] deps = ["Artifacts", "Preferences"] @@ -1220,9 +1254,9 @@ version = "1.13.2" [[deps.JpegTurbo]] deps = ["CEnum", "FileIO", "ImageCore", "JpegTurbo_jll", "TOML"] -git-tree-sha1 = "327713faef2a3e5c80f96bf38d1fa26f7a6ae29e" +git-tree-sha1 = "d65930fa2bc96b07d7691c652d701dcbe7d9cf0b" uuid = "b835a17e-a41a-41e7-81f0-2f016b05efe0" -version = "0.1.3" +version = "0.1.4" [[deps.JpegTurbo_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] @@ -1238,9 +1272,9 @@ version = "0.2.4" [[deps.KernelAbstractions]] deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] -git-tree-sha1 = "4c5875e4c228247e1c2b087669846941fb6e0118" +git-tree-sha1 = "95063c5bc98ba0c47e75e05ae71f1fed4deac6f6" uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" -version = "0.9.8" +version = "0.9.12" [deps.KernelAbstractions.extensions] EnzymeExt = "EnzymeCore" @@ -1267,16 +1301,20 @@ uuid = "88015f11-f218-50d7-93a8-a6af411a945d" version = "3.0.0+1" [[deps.LLVM]] -deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"] -git-tree-sha1 = "a9d2ce1d5007b1e8f6c5b89c5a31ff8bd146db5c" +deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"] +git-tree-sha1 = "c879e47398a7ab671c782e02b51a4456794a7fa3" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "6.2.1" +version = "6.4.0" +weakdeps = ["BFloat16s"] + + [deps.LLVM.extensions] + BFloat16sExt = "BFloat16s" [[deps.LLVMExtra_jll]] deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "7ca6850ae880cc99b59b88517545f91a52020afa" +git-tree-sha1 = "a84f8f1e8caaaa4e3b4c101306b9e801d3883ace" uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.25+0" +version = "0.0.27+0" [[deps.LLVMOpenMP_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -1284,6 +1322,11 @@ git-tree-sha1 = "f689897ccbe049adb19a065c495e75f372ecd42b" uuid = "1d63c593-3942-5779-bab2-d838dc0a180e" version = "15.0.4+0" +[[deps.LRUCache]] +git-tree-sha1 = "d36130483e3b6e4cd88d81633b596563264f15db" +uuid = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637" +version = "1.5.0" + [[deps.LZO_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] git-tree-sha1 = "e5b909bcf985c5e2605737d2ce278ed791b89be6" @@ -1291,9 +1334,9 @@ uuid = "dd4b983a-f0e5-5f8d-a1b7-129d4a5fb1ac" version = "2.10.1+0" [[deps.LaTeXStrings]] -git-tree-sha1 = "f2355693d6778a178ade15952b7ac47a4ff97996" +git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec" uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" -version = "1.3.0" +version = "1.3.1" [[deps.LaplaceRedux]] deps = ["CSV", "Compat", "ComputationalResources", "DataFrames", "Flux", "LinearAlgebra", "MLJ", "MLJBase", "MLJFlux", "MLJModelInterface", "MLUtils", "Parameters", "ProgressMeter", "Random", "Statistics", "Tables", "Tullio", "Zygote"] @@ -1323,9 +1366,14 @@ version = "1.9.0" [[deps.LayoutPointers]] deps = ["ArrayInterface", "LinearAlgebra", "ManualMemory", "SIMDTypes", "Static", "StaticArrayInterface"] -git-tree-sha1 = "88b8f66b604da079a627b6fb2860d3704a6729a1" +git-tree-sha1 = "62edfee3211981241b57ff1cedf4d74d79519277" uuid = "10f19ff3-798f-405d-979b-55457f8fc047" -version = "0.1.14" +version = "0.1.15" + +[[deps.LazilyInitializedFields]] +git-tree-sha1 = "410fe4739a4b092f2ffe36fcb0dcc3ab12648ce1" +uuid = "0e77f7df-68c5-4e49-93ce-4cd80f5598bf" +version = "1.2.1" [[deps.LazyArtifacts]] deps = ["Artifacts", "Pkg"] @@ -1430,9 +1478,9 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [[deps.LinearMaps]] deps = ["LinearAlgebra"] -git-tree-sha1 = "6698ab5e662b47ffc63a82b2f43c1cee015cf80d" +git-tree-sha1 = "9df2ab050ffefe870a09c7b6afdb0cde381703f2" uuid = "7a12625a-238d-50fd-b39a-03d52299707e" -version = "3.11.0" +version = "3.11.1" weakdeps = ["ChainRulesCore", "SparseArrays", "Statistics"] [deps.LinearMaps.extensions] @@ -1461,15 +1509,15 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" [[deps.LoggingExtras]] deps = ["Dates", "Logging"] -git-tree-sha1 = "0d097476b6c381ab7906460ef1ef1638fbce1d91" +git-tree-sha1 = "c1dd6d7978c12545b4179fb6153b9250c96b0075" uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" -version = "1.0.2" +version = "1.0.3" [[deps.LoopVectorization]] -deps = ["ArrayInterface", "ArrayInterfaceCore", "CPUSummary", "CloseOpenIntervals", "DocStringExtensions", "HostCPUFeatures", "IfElse", "LayoutPointers", "LinearAlgebra", "OffsetArrays", "PolyesterWeave", "PrecompileTools", "SIMDTypes", "SLEEFPirates", "Static", "StaticArrayInterface", "ThreadingUtilities", "UnPack", "VectorizationBase"] -git-tree-sha1 = "c88a4afe1703d731b1c4fdf4e3c7e77e3b176ea2" +deps = ["ArrayInterface", "CPUSummary", "CloseOpenIntervals", "DocStringExtensions", "HostCPUFeatures", "IfElse", "LayoutPointers", "LinearAlgebra", "OffsetArrays", "PolyesterWeave", "PrecompileTools", "SIMDTypes", "SLEEFPirates", "Static", "StaticArrayInterface", "ThreadingUtilities", "UnPack", "VectorizationBase"] +git-tree-sha1 = "0f5648fbae0d015e3abe5867bca2b362f67a5894" uuid = "bdcacae8-1622-11e9-2a5c-532679323890" -version = "0.12.165" +version = "0.12.166" weakdeps = ["ChainRulesCore", "ForwardDiff", "SpecialFunctions"] [deps.LoopVectorization.extensions] @@ -1486,11 +1534,17 @@ weakdeps = ["CategoricalArrays"] [deps.LossFunctions.extensions] LossFunctionsCategoricalArraysExt = "CategoricalArrays" +[[deps.LsqFit]] +deps = ["Distributions", "ForwardDiff", "LinearAlgebra", "NLSolversBase", "OptimBase", "Random", "StatsBase"] +git-tree-sha1 = "00f475f85c50584b12268675072663dfed5594b2" +uuid = "2fda8390-95c7-5789-9bda-21331edee243" +version = "0.13.0" + [[deps.MAT]] deps = ["BufferedStreams", "CodecZlib", "HDF5", "SparseArrays"] -git-tree-sha1 = "79fd0b5ee384caf8ebba6c8fb3f365ca3e2c5493" +git-tree-sha1 = "ed1cf0a322d78cee07718bed5fd945e2218c35a1" uuid = "23992714-dd62-5051-b70f-ba57cb901cac" -version = "0.10.5" +version = "0.10.6" [[deps.MIMEs]] git-tree-sha1 = "65f28ad4b594aebe22157d6fac869786a255b7eb" @@ -1505,9 +1559,9 @@ version = "2023.2.0+0" [[deps.MLDatasets]] deps = ["CSV", "Chemfiles", "DataDeps", "DataFrames", "DelimitedFiles", "FileIO", "FixedPointNumbers", "GZip", "Glob", "HDF5", "ImageShow", "JLD2", "JSON3", "LazyModules", "MAT", "MLUtils", "NPZ", "Pickle", "Printf", "Requires", "SparseArrays", "Statistics", "Tables"] -git-tree-sha1 = "10bc70e4c875f1b2ca65cef3ef9ebe705ef936b5" +git-tree-sha1 = "aab72207b3c687086a400be710650a57494992bd" uuid = "eb30cadb-4394-5ae3-aed4-317e484a6458" -version = "0.7.13" +version = "0.7.14" [[deps.MLFlowClient]] deps = ["Dates", "FilePathsBase", "HTTP", "JSON", "ShowCases", "URIs", "UUIDs"] @@ -1565,27 +1619,27 @@ version = "0.5.1" [[deps.MLJLinearModels]] deps = ["DocStringExtensions", "IterativeSolvers", "LinearAlgebra", "LinearMaps", "MLJModelInterface", "Optim", "Parameters"] -git-tree-sha1 = "c92bf0ea37bf51e1ef0160069c572825819748b8" +git-tree-sha1 = "7f517fd840ca433a8fae673edb31678ff55d969c" uuid = "6ee0df7b-362f-4a72-a706-9e79364fb692" -version = "0.9.2" +version = "0.10.0" [[deps.MLJModelInterface]] deps = ["Random", "ScientificTypesBase", "StatisticalTraits"] -git-tree-sha1 = "03ae109be87f460fe3c96b8a0dbbf9c7bf840bd5" +git-tree-sha1 = "381d99f0af76d98f50bd5512dcf96a99c13f8223" uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" -version = "1.9.2" +version = "1.9.3" [[deps.MLJModels]] deps = ["CategoricalArrays", "CategoricalDistributions", "Combinatorics", "Dates", "Distances", "Distributions", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "Markdown", "OrderedCollections", "Parameters", "Pkg", "PrettyPrinting", "REPL", "Random", "RelocatableFolders", "ScientificTypes", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "2b49f04f70266a2b040eb46ece157c4f5c1b0c13" +git-tree-sha1 = "10d221910fc3f3eedad567178ddbca3cc0f776a3" uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7" -version = "0.16.10" +version = "0.16.12" [[deps.MLJMultivariateStatsInterface]] -deps = ["CategoricalDistributions", "Distances", "LinearAlgebra", "MLJModelInterface", "MultivariateStats", "StatsBase"] -git-tree-sha1 = "0d76e36bf83926235dcd3eaeafa7f47d3e7f32ea" +deps = ["Distances", "LinearAlgebra", "MLJModelInterface", "MultivariateStats", "StatsBase"] +git-tree-sha1 = "a282960828015daf766b4d66ba75445b0c909099" uuid = "1b6a4a23-ba22-4f51-9698-8599985d3728" -version = "0.5.3" +version = "0.4.0" [[deps.MLJNaiveBayesInterface]] deps = ["LogExpFunctions", "MLJModelInterface", "NaiveBayes"] @@ -1616,6 +1670,12 @@ git-tree-sha1 = "9ee1618cbf5240e6d4e0371d6f24065083f60c48" uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" version = "0.5.11" +[[deps.ManifoldLearning]] +deps = ["Combinatorics", "Graphs", "LinearAlgebra", "MultivariateStats", "Random", "SparseArrays", "Statistics", "StatsAPI"] +git-tree-sha1 = "4c5564c707899c3b6bc6d324b05e43eb7f277f2b" +uuid = "06eb3307-b2af-5a2a-abea-d33192699d32" +version = "0.9.0" + [[deps.ManualMemory]] git-tree-sha1 = "bcaef4fc7a0cfe2cba636d84cda54b5e4e4ca3cd" uuid = "d125e4d3-2237-4719-b19c-fa641b8a4667" @@ -1636,6 +1696,12 @@ version = "0.1.8" deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" +[[deps.MarkdownAST]] +deps = ["AbstractTrees", "Markdown"] +git-tree-sha1 = "465a70f0fc7d443a00dcdc3267a497397b8a3899" +uuid = "d0879d2d-cac2-40c8-9cee-1863dc0c7391" +version = "0.1.2" + [[deps.MbedTLS]] deps = ["Dates", "MbedTLS_jll", "MozillaCACerts_jll", "Random", "Sockets"] git-tree-sha1 = "03a9b9718f5682ecb107ac9f7308991db4ce395b" @@ -1660,9 +1726,9 @@ version = "0.7.2" [[deps.Metalhead]] deps = ["Artifacts", "BSON", "CUDA", "ChainRulesCore", "Flux", "Functors", "JLD2", "LazyArtifacts", "MLUtils", "NNlib", "PartialFunctions", "Random", "Statistics"] -git-tree-sha1 = "c093734078e92a4edcf54e850af68ef8cc2c9e03" +git-tree-sha1 = "4bbdb628c60c5f473148292df3ecb87058ba515f" uuid = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" -version = "0.8.2" +version = "0.8.3" [[deps.MicroCollections]] deps = ["BangBang", "InitialValues", "Setfield"] @@ -1690,10 +1756,10 @@ uuid = "14a3606d-f60d-562e-9121-12d972cd8159" version = "2022.10.11" [[deps.MultivariateStats]] -deps = ["Arpack", "LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI", "StatsBase"] -git-tree-sha1 = "68bf5103e002c44adfd71fea6bd770b3f0586843" +deps = ["Arpack", "LinearAlgebra", "SparseArrays", "Statistics", "StatsBase"] +git-tree-sha1 = "6d019f5a0465522bbfdd68ecfad7f86b535d6935" uuid = "6f286f6a-111f-5878-ab1e-185364afe411" -version = "0.10.2" +version = "0.9.0" [[deps.NLSolversBase]] deps = ["DiffResults", "Distributed", "FiniteDiff", "ForwardDiff"] @@ -1748,6 +1814,12 @@ git-tree-sha1 = "eda490d06b9f7c00752ee81cfa451efe55521e21" uuid = "c020b1a1-e9b0-503a-9c33-f039bfc54a85" version = "1.0.0" +[[deps.NearestNeighborDescent]] +deps = ["DataStructures", "Distances", "Graphs", "Random", "Reexport", "SparseArrays"] +git-tree-sha1 = "b7d4bd2ab58f0c3a001fd6eedc2e0aac8e278152" +uuid = "dd2c4c9e-a32f-5b2f-b342-08c2f244fce8" +version = "0.3.6" + [[deps.NearestNeighborModels]] deps = ["Distances", "FillArrays", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "NearestNeighbors", "Statistics", "StatsBase", "Tables"] git-tree-sha1 = "e411143a8362926e4284a54e745972e939fbab78" @@ -1767,10 +1839,14 @@ uuid = "f09324ee-3d7c-5217-9330-fc30815ba969" version = "1.1.1" [[deps.NetworkLayout]] -deps = ["GeometryBasics", "LinearAlgebra", "Random", "Requires", "SparseArrays", "StaticArrays"] -git-tree-sha1 = "2bfd8cd7fba3e46ce48139ae93904ee848153660" +deps = ["GeometryBasics", "LinearAlgebra", "Random", "Requires", "StaticArrays"] +git-tree-sha1 = "91bb2fedff8e43793650e7a677ccda6e6e6e166b" uuid = "46757867-2c16-5918-afeb-47bfcb05e46a" -version = "0.4.5" +version = "0.4.6" +weakdeps = ["Graphs"] + + [deps.NetworkLayout.extensions] + NetworkLayoutGraphsExt = "Graphs" [[deps.NetworkOptions]] uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" @@ -1783,9 +1859,9 @@ uuid = "12afc1b8-fad6-47e1-9132-84abc478905f" version = "0.2.12" [[deps.Observables]] -git-tree-sha1 = "6862738f9796b3edc1c09d0890afce4eca9e7e93" +git-tree-sha1 = "7438a59546cf62428fc9d1bc94729146d37a7225" uuid = "510215fc-4207-5dde-b226-833fc4488ee2" -version = "0.5.4" +version = "0.5.5" [[deps.OffsetArrays]] deps = ["Adapt"] @@ -1863,6 +1939,12 @@ git-tree-sha1 = "e3a6546c1577bfd701771b477b794a52949e7594" uuid = "429524aa-4258-5aef-a3af-852621145aeb" version = "1.7.6" +[[deps.OptimBase]] +deps = ["NLSolversBase", "Printf", "Reexport"] +git-tree-sha1 = "9cb1fee807b599b5f803809e85c81b582d2009d6" +uuid = "87e2bd06-a317-5318-96d9-3ecbac512eee" +version = "2.0.2" + [[deps.Optimisers]] deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] git-tree-sha1 = "c1fc26bab5df929a5172f296f25d7d08688fd25b" @@ -1887,20 +1969,20 @@ version = "10.42.0+0" [[deps.PDMats]] deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "67eae2738d63117a196f497d7db789821bce61d1" +git-tree-sha1 = "66b2fcd977db5329aa35cac121e5b94dd6472198" uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" -version = "0.11.17" +version = "0.11.28" [[deps.PNGFiles]] deps = ["Base64", "CEnum", "ImageCore", "IndirectArrays", "OffsetArrays", "libpng_jll"] -git-tree-sha1 = "9b02b27ac477cad98114584ff964e3052f656a0f" +git-tree-sha1 = "5ded86ccaf0647349231ed6c0822c10886d4a1ee" uuid = "f57f5aa1-a3ce-4bc8-8ab9-96f992907883" -version = "0.4.0" +version = "0.4.1" [[deps.PackageExtensionCompat]] -git-tree-sha1 = "f9b1e033c2b1205cf30fd119f4e50881316c1923" +git-tree-sha1 = "fb28e33b8a95c4cee25ce296c817d89cc2e53518" uuid = "65ce6f38-6b18-4e1d-a461-8949797d7930" -version = "1.0.1" +version = "1.0.2" weakdeps = ["Requires", "TOML"] [[deps.PaddedViews]] @@ -1922,9 +2004,10 @@ uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" version = "2.7.2" [[deps.PartialFunctions]] -git-tree-sha1 = "b3901ea034cfd8aae57a2fa0dde0b0ea18bad1cb" +deps = ["MacroTools"] +git-tree-sha1 = "47b49a4dbc23b76682205c646252c0f9e1eb75af" uuid = "570af359-4316-4cb7-8c74-252c00c2016b" -version = "1.1.1" +version = "1.2.0" [[deps.PeriodicTable]] deps = ["Base64", "Test", "Unitful"] @@ -2000,9 +2083,9 @@ version = "1.39.0" [[deps.PlutoUI]] deps = ["AbstractPlutoDingetjes", "Base64", "ColorTypes", "Dates", "FixedPointNumbers", "Hyperscript", "HypertextLiteral", "IOCapture", "InteractiveUtils", "JSON", "Logging", "MIMEs", "Markdown", "Random", "Reexport", "URIs", "UUIDs"] -git-tree-sha1 = "e47cd150dbe0443c3a3651bc5b9cbd5576ab75b7" +git-tree-sha1 = "db8ec28846dbf846228a32de5a6912c63e2052e3" uuid = "7f904dfe-b85e-4ff6-b463-dae2292396a8" -version = "0.7.52" +version = "0.7.53" [[deps.PolyesterWeave]] deps = ["BitTwiddlingConvenienceFunctions", "CPUSummary", "IfElse", "Static", "ThreadingUtilities"] @@ -2030,9 +2113,9 @@ version = "4.0.4" [[deps.PooledArrays]] deps = ["DataAPI", "Future"] -git-tree-sha1 = "a6062fe4063cdafe78f4a0a81cfffb89721b30e7" +git-tree-sha1 = "36d8b4b899628fb92c2749eb488d884a926614d3" uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" -version = "1.4.2" +version = "1.4.3" [[deps.PositiveFactorizations]] deps = ["LinearAlgebra"] @@ -2048,9 +2131,9 @@ version = "1.2.0" [[deps.Preferences]] deps = ["TOML"] -git-tree-sha1 = "7eb1686b4f04b82f96ed7a4ea5890a4f0c7a09f1" +git-tree-sha1 = "00805cd429dcb4870060ff49ef443486c262e38e" uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.4.0" +version = "1.4.1" [[deps.PrettyPrint]] git-tree-sha1 = "632eb4abab3449ab30c5e1afaa874f0b98b586e4" @@ -2064,9 +2147,9 @@ version = "0.4.1" [[deps.PrettyTables]] deps = ["Crayons", "LaTeXStrings", "Markdown", "Printf", "Reexport", "StringManipulation", "Tables"] -git-tree-sha1 = "ee094908d720185ddbdc58dbe0c1cbe35453ec7a" +git-tree-sha1 = "6842ce83a836fbbc0cfeca0b5a4de1a4dcbdb8d1" uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" -version = "2.2.7" +version = "2.2.8" [[deps.PrimitiveOneHot]] deps = ["Adapt", "ChainRulesCore", "NNlib", "Requires"] @@ -2110,9 +2193,9 @@ version = "5.15.3+2" [[deps.QuadGK]] deps = ["DataStructures", "LinearAlgebra"] -git-tree-sha1 = "6ec7ac8412e83d57e313393220879ede1740f9ee" +git-tree-sha1 = "9ebcd48c498668c7fa0e97a9cae873fbee7bfee1" uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" -version = "2.8.2" +version = "2.9.1" [[deps.Quaternions]] deps = ["LinearAlgebra", "Random", "RealDot"] @@ -2178,17 +2261,35 @@ git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" uuid = "189a3867-3050-52da-a836-e630ba90ab69" version = "1.2.2" +[[deps.Referenceables]] +deps = ["Adapt"] +git-tree-sha1 = "e681d3bfa49cd46c3c161505caddf20f0e62aaa9" +uuid = "42d2dcc6-99eb-4e98-b66c-637b7d73030e" +version = "0.1.2" + [[deps.RegionTrees]] deps = ["IterTools", "LinearAlgebra", "StaticArrays"] git-tree-sha1 = "4618ed0da7a251c7f92e869ae1a19c74a7d2a7f9" uuid = "dee08c22-ab7f-5625-9660-a9af2021b33f" version = "0.3.2" +[[deps.RegistryInstances]] +deps = ["LazilyInitializedFields", "Pkg", "TOML", "Tar"] +git-tree-sha1 = "ffd19052caf598b8653b99404058fce14828be51" +uuid = "2792f1a3-b283-48e8-9a74-f99dce5104f3" +version = "0.1.0" + [[deps.RelocatableFolders]] deps = ["SHA", "Scratch"] -git-tree-sha1 = "90bc7a7c96410424509e4263e277e43250c05691" +git-tree-sha1 = "ffdaf70d81cf6ff22c2b6e733c900c3321cab864" uuid = "05181044-ff0b-4ac5-8273-598c1e38db00" -version = "1.0.0" +version = "1.0.1" + +[[deps.RequiredInterfaces]] +deps = ["InteractiveUtils", "Logging", "Test"] +git-tree-sha1 = "deb5b451248bbe5ce37cb639a546ac13d07b791f" +uuid = "97f35ef4-7bc5-4ec1-a41a-dcc69c7308c6" +version = "0.1.4" [[deps.Requires]] deps = ["UUIDs"] @@ -2210,9 +2311,15 @@ version = "0.4.0+0" [[deps.Rotations]] deps = ["LinearAlgebra", "Quaternions", "Random", "StaticArrays"] -git-tree-sha1 = "54ccb4dbab4b1f69beb255a2c0ca5f65a9c82f08" +git-tree-sha1 = "0783924e4a332493f72490253ba4e668aeba1d73" uuid = "6038ab10-8711-5258-84ad-4b1120ba62dc" -version = "1.5.1" +version = "1.6.0" + +[[deps.RustRegex]] +deps = ["rure_jll"] +git-tree-sha1 = "16be5e710d7b980678ec0d8c61d4c00e9a5591e3" +uuid = "cdf36688-0c6d-42c6-a883-5d2df16e9e88" +version = "0.1.0" [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" @@ -2225,9 +2332,9 @@ version = "0.1.0" [[deps.SLEEFPirates]] deps = ["IfElse", "Static", "VectorizationBase"] -git-tree-sha1 = "4b8586aece42bee682399c4c4aee95446aa5cd19" +git-tree-sha1 = "3aac6d68c5e57449f5b9b865c9ba50ac2970c4cf" uuid = "476501e8-09a2-5ece-8869-fb82de89a1fa" -version = "0.6.39" +version = "0.6.42" [[deps.ScientificTypes]] deps = ["CategoricalArrays", "ColorTypes", "Dates", "Distributions", "PrettyTables", "Reexport", "ScientificTypesBase", "StatisticalTraits", "Tables"] @@ -2248,15 +2355,15 @@ version = "0.5.0" [[deps.Scratch]] deps = ["Dates"] -git-tree-sha1 = "30449ee12237627992a99d5e30ae63e4d78cd24a" +git-tree-sha1 = "3bac05bc7e74a75fd9cba4295cde4045d9fe2386" uuid = "6c6a2e73-6563-6170-7368-637461726353" -version = "1.2.0" +version = "1.2.1" [[deps.SentinelArrays]] deps = ["Dates", "Random"] -git-tree-sha1 = "04bdff0b09c65ff3e06a05e3eb7b120223da3d39" +git-tree-sha1 = "0e7508ff27ba32f26cd459474ca2ede1bc10991f" uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c" -version = "1.4.0" +version = "1.4.1" [[deps.Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" @@ -2321,9 +2428,9 @@ uuid = "6462fe0b-24de-5631-8697-dd941f90decc" [[deps.SortingAlgorithms]] deps = ["DataStructures"] -git-tree-sha1 = "c60ec5c62180f27efea3ba2908480f8055e17cee" +git-tree-sha1 = "5165dfb9fd131cf0c6957a3a7605dede376e7b63" uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "1.1.1" +version = "1.2.0" [[deps.SparseArrays]] deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] @@ -2382,9 +2489,9 @@ weakdeps = ["OffsetArrays", "StaticArrays"] [[deps.StaticArrays]] deps = ["LinearAlgebra", "Random", "StaticArraysCore"] -git-tree-sha1 = "51621cca8651d9e334a659443a74ce50a3b6dfab" +git-tree-sha1 = "0adf069a2a490c47273727e029371b31d44b72b2" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.6.3" +version = "1.6.5" weakdeps = ["Statistics"] [deps.StaticArrays.extensions] @@ -2414,24 +2521,21 @@ version = "1.7.0" [[deps.StatsBase]] deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "75ebe04c5bed70b91614d684259b661c9e6274a4" +git-tree-sha1 = "d1bf48bfcc554a3761a133fe3a9bb01488e06916" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.34.0" +version = "0.33.21" [[deps.StatsFuns]] deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] git-tree-sha1 = "f625d686d5a88bcd2b15cd81f18f98186fdc0c9a" uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" version = "1.3.0" +weakdeps = ["ChainRulesCore", "InverseFunctions"] [deps.StatsFuns.extensions] StatsFunsChainRulesCoreExt = "ChainRulesCore" StatsFunsInverseFunctionsExt = "InverseFunctions" - [deps.StatsFuns.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" - [[deps.StatsModels]] deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Printf", "REPL", "ShiftedArrays", "SparseArrays", "StatsAPI", "StatsBase", "StatsFuns", "Tables"] git-tree-sha1 = "5cf6c4583533ee38639f73b880f35fc85f2941e0" @@ -2532,16 +2636,16 @@ uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" version = "1.0.1" [[deps.Tables]] -deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits", "Test"] -git-tree-sha1 = "1544b926975372da01227b382066ab70e574a3ec" +deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits"] +git-tree-sha1 = "cb76cf677714c095e535e3501ac7954732aeea2d" uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -version = "1.10.1" +version = "1.11.1" [[deps.TaijaPlotting]] -deps = ["CategoricalArrays", "ConformalPrediction", "CounterfactualExplanations", "Flux", "LaplaceRedux", "LinearAlgebra", "MLJBase", "MultivariateStats", "NaturalSort", "Plots"] -git-tree-sha1 = "d5d1c9fccd05c4ff9793394c56fc07c81db40eda" +deps = ["CategoricalArrays", "ConformalPrediction", "CounterfactualExplanations", "DataAPI", "Distributions", "Flux", "LaplaceRedux", "LinearAlgebra", "MLJBase", "ManifoldLearning", "MultivariateStats", "NaturalSort", "NearestNeighborModels", "Plots"] +path = "../../TaijaPlotting.jl" uuid = "bd7198b4-c7d6-400c-9bab-9a24614b0240" -version = "1.0.2" +version = "1.0.3" [[deps.Tar]] deps = ["ArgTools", "SHA"] @@ -2559,10 +2663,10 @@ deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [[deps.TextEncodeBase]] -deps = ["FuncPipelines", "PartialFunctions", "PrimitiveOneHot", "StaticArrays", "StructWalk", "Unicode", "WordTokenizers"] -git-tree-sha1 = "1304ca2c65d9b28c1e2a78cdf5032348c0c405e5" +deps = ["FuncPipelines", "PartialFunctions", "PrimitiveOneHot", "RustRegex", "StaticArrays", "StructWalk", "Unicode", "WordTokenizers"] +git-tree-sha1 = "4753ea70646cb276a4db65952e59103fbc2b0576" uuid = "f92c20c0-9f2a-4705-8116-881385faba05" -version = "0.6.0" +version = "0.7.0" [[deps.ThreadingUtilities]] deps = ["ManualMemory"] @@ -2570,11 +2674,17 @@ git-tree-sha1 = "eda08f7e9818eb53661b3deb74e3159460dfbc27" uuid = "8290d209-cae3-49c0-8002-c8c24d57dab5" version = "0.5.2" +[[deps.ThreadsX]] +deps = ["ArgCheck", "BangBang", "ConstructionBase", "InitialValues", "MicroCollections", "Referenceables", "Setfield", "SplittablesBase", "Transducers"] +git-tree-sha1 = "34e6bcf36b9ed5d56489600cf9f3c16843fa2aa2" +uuid = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d" +version = "0.1.11" + [[deps.TiffImages]] deps = ["ColorTypes", "DataStructures", "DocStringExtensions", "FileIO", "FixedPointNumbers", "IndirectArrays", "Inflate", "Mmap", "OffsetArrays", "PkgVersion", "ProgressMeter", "UUIDs"] -git-tree-sha1 = "8621f5c499a8aa4aa970b1ae381aae0ef1576966" +git-tree-sha1 = "34cc045dd0aaa59b8bbe86c644679bc57f1d5bd0" uuid = "731e570b-9d59-4bfa-96dc-6df516fadf69" -version = "0.6.4" +version = "0.6.8" [[deps.TiledIteration]] deps = ["OffsetArrays"] @@ -2589,16 +2699,19 @@ uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" version = "0.5.23" [[deps.TranscodingStreams]] -deps = ["Random", "Test"] -git-tree-sha1 = "9a6ae7ed916312b41236fcef7e0af564ef934769" +git-tree-sha1 = "1fbeaaca45801b4ba17c251dd8603ef24801dd84" uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.9.13" +version = "0.10.2" +weakdeps = ["Random", "Test"] + + [deps.TranscodingStreams.extensions] + TestExt = ["Test", "Random"] [[deps.Transducers]] deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"] -git-tree-sha1 = "53bd5978b182fa7c57577bdb452c35e5b4fb73a5" +git-tree-sha1 = "e579d3c991938fecbb225699e8f611fa3fbf2141" uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" -version = "0.4.78" +version = "0.4.79" [deps.Transducers.extensions] TransducersBlockArraysExt = "BlockArrays" @@ -2615,31 +2728,49 @@ version = "0.4.78" Referenceables = "42d2dcc6-99eb-4e98-b66c-637b7d73030e" [[deps.Transformers]] -deps = ["Base64", "BytePairEncoding", "CUDA", "ChainRulesCore", "DataDeps", "DataStructures", "Dates", "DelimitedFiles", "DoubleArrayTries", "Fetch", "FillArrays", "Flux", "FuncPipelines", "Functors", "HTTP", "HuggingFaceApi", "JSON3", "LightXML", "LinearAlgebra", "Mmap", "NNlib", "NNlibCUDA", "NeuralAttentionlib", "Pickle", "Pkg", "PrimitiveOneHot", "Random", "SHA", "Static", "Statistics", "StringViews", "StructWalk", "TextEncodeBase", "Unicode", "ValSplit", "WordTokenizers", "Zygote"] -git-tree-sha1 = "efa5d0441f3f9e6f3e9e5f64e5dacb736edd13ad" +deps = ["Base64", "BytePairEncoding", "CUDA", "ChainRulesCore", "DataDeps", "DataStructures", "Dates", "DelimitedFiles", "DoubleArrayTries", "Fetch", "FillArrays", "Flux", "FuncPipelines", "Functors", "HTTP", "HuggingFaceApi", "JSON3", "LRUCache", "LightXML", "LinearAlgebra", "Mmap", "NNlib", "NNlibCUDA", "NeuralAttentionlib", "Pickle", "Pkg", "PrimitiveOneHot", "Random", "SHA", "Static", "Statistics", "StringViews", "StructWalk", "TextEncodeBase", "Unicode", "ValSplit", "WordTokenizers", "Zygote"] +git-tree-sha1 = "77d831ef0b378cd59c97fd52f1b6e420c972960b" uuid = "21ca0261-441d-5938-ace7-c90938fde4d4" -version = "0.2.7" +version = "0.2.8" [[deps.Tricks]] -git-tree-sha1 = "aadb748be58b492045b4f56166b5188aa63ce549" +git-tree-sha1 = "eae1bb484cd63b36999ee58be2de6c178105112f" uuid = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775" -version = "0.1.7" +version = "0.1.8" [[deps.Tullio]] -deps = ["ChainRulesCore", "DiffRules", "LinearAlgebra", "Requires"] -git-tree-sha1 = "7871a39eac745697ee512a87eeff06a048a7905b" +deps = ["DiffRules", "LinearAlgebra", "Requires"] +git-tree-sha1 = "6d476962ba4e435d7f4101a403b1d3d72afe72f3" uuid = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc" -version = "0.3.5" +version = "0.3.7" + + [deps.Tullio.extensions] + TullioCUDAExt = "CUDA" + TullioChainRulesCoreExt = "ChainRulesCore" + TullioFillArraysExt = "FillArrays" + TullioTrackerExt = "Tracker" + + [deps.Tullio.weakdeps] + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [[deps.TupleTools]] -git-tree-sha1 = "c8cdc29448afa1a306419f5d1c7af0854c171c80" +git-tree-sha1 = "155515ed4c4236db30049ac1495e2969cc06be9d" uuid = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" -version = "1.4.1" +version = "1.4.3" + +[[deps.UMAP]] +deps = ["Arpack", "Distances", "LinearAlgebra", "LsqFit", "NearestNeighborDescent", "Random", "SparseArrays"] +git-tree-sha1 = "accad220f075445f68caa6488be728957a5d82d6" +uuid = "c4f8c510-2410-5be4-91d7-4fbaeb39457e" +version = "0.1.10" [[deps.URIs]] -git-tree-sha1 = "b7a5e99f24892b6824a954199a45e9ffcc1c70f0" +git-tree-sha1 = "67db6cc7b3821e19ebe75791a9dd19c9b1188f2b" uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" -version = "1.5.0" +version = "1.5.1" [[deps.UUIDs]] deps = ["Random", "SHA"] @@ -2685,15 +2816,12 @@ deps = ["Dates", "LinearAlgebra", "Random"] git-tree-sha1 = "a72d22c7e13fe2de562feda8645aa134712a87ee" uuid = "1986cc42-f94f-5a68-af5c-568840ba703d" version = "1.17.0" +weakdeps = ["ConstructionBase", "InverseFunctions"] [deps.Unitful.extensions] ConstructionBaseUnitfulExt = "ConstructionBase" InverseFunctionsUnitfulExt = "InverseFunctions" - [deps.Unitful.weakdeps] - ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" - InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" - [[deps.UnitfulAtomic]] deps = ["Unitful"] git-tree-sha1 = "903be579194534af1c4b4778d1ace676ca042238" @@ -2724,9 +2852,9 @@ version = "0.2.0" [[deps.ValSplit]] deps = ["ExprTools", "Tricks"] -git-tree-sha1 = "0d087f8ddc8eced370cc968eeb3b01db32cb2c01" +git-tree-sha1 = "3e1d94627f9276c40034c80dc5ab29ac1a3b06c0" uuid = "0625e100-946b-11ec-09cd-6328dd093154" -version = "0.1.0" +version = "0.1.1" [[deps.VectorizationBase]] deps = ["ArrayInterface", "CPUSummary", "HostCPUFeatures", "IfElse", "LayoutPointers", "Libdl", "LinearAlgebra", "SIMDTypes", "Static", "StaticArrayInterface"] @@ -2735,10 +2863,10 @@ uuid = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f" version = "0.21.64" [[deps.Wayland_jll]] -deps = ["Artifacts", "Expat_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg", "XML2_jll"] -git-tree-sha1 = "ed8d92d9774b077c53e1da50fd81a36af3744c1c" +deps = ["Artifacts", "EpollShim_jll", "Expat_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg", "XML2_jll"] +git-tree-sha1 = "7558e29847e99bc3f04d6569e82d0f5c54460703" uuid = "a2964d1f-97da-50d4-b82a-358c7fce9d89" -version = "1.21.0+0" +version = "1.21.0+1" [[deps.Wayland_protocols_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -2777,9 +2905,9 @@ version = "1.6.1" [[deps.XML2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libiconv_jll", "Zlib_jll"] -git-tree-sha1 = "04a51d15436a572301b5abbb9d099713327e9fc4" +git-tree-sha1 = "24b81b59bd35b3c42ab84fa589086e19be919916" uuid = "02c8fc9c-b97f-50b9-bbe4-9be30ff0a78a" -version = "2.10.4+0" +version = "2.11.5+0" [[deps.XSLT_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgcrypt_jll", "Libgpg_error_jll", "Libiconv_jll", "Pkg", "XML2_jll", "Zlib_jll"] @@ -2932,9 +3060,9 @@ version = "1.5.5+0" [[deps.Zygote]] deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "b97c927497c1de55a78dc9030f6068be5d83ef80" +git-tree-sha1 = "5ded212acd815612df112bb895ef3910c5a03f57" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.64" +version = "0.6.67" [deps.Zygote.extensions] ZygoteColorsExt = "Colors" @@ -2948,9 +3076,9 @@ version = "0.6.64" [[deps.ZygoteRules]] deps = ["ChainRulesCore", "MacroTools"] -git-tree-sha1 = "977aed5d006b840e2e40c0b48984f7463109046d" +git-tree-sha1 = "9d749cd449fb448aeca4feee9a2f4186dbb5d184" uuid = "700de1a5-db45-46bc-99cf-38207098b444" -version = "0.2.3" +version = "0.2.4" [[deps.cuDNN]] deps = ["CEnum", "CUDA", "CUDNN_jll"] @@ -2960,9 +3088,9 @@ version = "1.1.1" [[deps.fzf_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "868e669ccb12ba16eaf50cb2957ee2ff61261c56" +git-tree-sha1 = "47cf33e62e138b920039e8ff9f9841aafe1b733e" uuid = "214eeab7-80f7-51ab-84ad-2988db7cef09" -version = "0.29.0+0" +version = "0.35.1+0" [[deps.libaom_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -3015,6 +3143,12 @@ deps = ["Artifacts", "Libdl"] uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" version = "17.4.0+0" +[[deps.rure_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "a24449573502225e7833277f99a8e2c19801f5a7" +uuid = "2a13b4fb-3cbe-5d55-9db2-86fcb16976f1" +version = "0.2.2+0" + [[deps.x264_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] git-tree-sha1 = "4fea590b89e6ec504593146bf8b988b2c00922b2" @@ -3029,6 +3163,6 @@ version = "3.5.0+0" [[deps.xkbcommon_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Wayland_jll", "Wayland_protocols_jll", "Xorg_libxcb_jll", "Xorg_xkeyboard_config_jll"] -git-tree-sha1 = "9ebfc140cc56e8c2156a15ceac2f0302e327ac0a" +git-tree-sha1 = "9c304562909ab2bab0262639bd4f444d7bc2be37" uuid = "d8fb68d0-12a3-5cfd-a85a-d49703b185fd" -version = "1.4.1+0" +version = "1.4.1+1" diff --git a/docs/setup_docs.jl b/docs/setup_docs.jl index 7f10f8a5..7c995f3f 100644 --- a/docs/setup_docs.jl +++ b/docs/setup_docs.jl @@ -19,6 +19,7 @@ setup_docs = quote using SharedArrays using StatsBase using StatsPlots + using TaijaPlotting using Transformers using Transformers.TextEncoders using Transformers.HuggingFace diff --git a/docs/src/tutorials/training.qmd b/docs/src/tutorials/training.qmd index 209c89e1..e51cf41f 100644 --- a/docs/src/tutorials/training.qmd +++ b/docs/src/tutorials/training.qmd @@ -6,12 +6,8 @@ CurrentModule = ConformalPrediction ```{julia} #| echo: false -using Pkg; Pkg.activate("docs") -using Plots -theme(:wong) -using Random -Random.seed!(2022) -www_path = "docs/src/www" # output path for files don't get automatically saved in auto-generated path (e.g. GIFs) +include("$(pwd())/docs/setup_docs.jl") +eval(setup_docs) ``` @@ -21,7 +17,7 @@ using Random Random.seed!(123) # Data: -X, y = make_blobs(500, centers=4, cluster_std=1.0) +X, y = make_blobs(1000, centers=2, cluster_std=1.0) X = MLJ.table(Float32.(MLJ.matrix(X))) train, test = partition(eachindex(y), 0.8, shuffle=true) ``` @@ -34,9 +30,10 @@ using ConformalPrediction using ConformalPrediction.ConformalTraining: ConformalNNClassifier # Model: -builder = MLJFlux.MLP(hidden=(32, 32, 32,), σ=Flux.relu) -# clf = ConformalNNClassifier(epochs=250, builder=builder, batch_size=50) -clf = NeuralNetworkClassifier(epochs=250, builder=builder, batch_size=50) +hidden_dim = 32 +builder = MLJFlux.MLP(hidden=ntuple(x -> hidden_dim, 3), σ=Flux.relu) +clf = ConformalNNClassifier(epochs=25, builder=builder, batch_size=10, reg_strength_size=0.5, epsilon=0.1) +# clf = NeuralNetworkClassifier(epochs=250, builder=builder, batch_size=50) ``` @@ -45,7 +42,7 @@ using ConformalPrediction conf_model = conformal_model(clf; method=:simple_inductive) mach = machine(conf_model, X, y) -fit!(mach, rows=train) +fit!(mach, rows=train, verbosity=0) ``` ```{julia} @@ -55,7 +52,9 @@ using Plots p_proba = contourf(mach.model, mach.fitresult, X, y) p_set_size = contourf(mach.model, mach.fitresult, X, y; plot_set_size=true) p_smooth = contourf(mach.model, mach.fitresult, X, y; plot_set_loss=true) -plot(p_proba, p_set_size, p_smooth, layout=(1,3), size=(1200,250)) +plt = plot(p_proba, p_set_size, p_smooth, layout=(1,3), size=(1200,250)) +display(plt) +ineff(MLJBase.predict(mach)) ``` ```{julia} @@ -63,7 +62,7 @@ plot(p_proba, p_set_size, p_smooth, layout=(1,3), size=(1200,250)) _eval = evaluate!( mach, - operation=predict, + operation=MLJBase.predict, measure=[emp_coverage, ssc, ineff] ) diff --git a/src/conformal_models/ConformalTraining/classifier.jl b/src/conformal_models/ConformalTraining/classifier.jl index 2e07ccd4..2b68c6ec 100644 --- a/src/conformal_models/ConformalTraining/classifier.jl +++ b/src/conformal_models/ConformalTraining/classifier.jl @@ -19,6 +19,8 @@ mutable struct ConformalNNClassifier{B,F,O,L} <: MLJFlux.MLJFluxProbabilistic rng::Union{AbstractRNG,Int64} optimiser_changes_trigger_retraining::Bool acceleration::AbstractResource # eg, `CPU1()` or `CUDALibs()` + reg_strength_size::Float64 # regularization strength for size loss + epsilon::Float64 # epsilon for soft sorting end function ConformalNNClassifier(; @@ -33,6 +35,8 @@ function ConformalNNClassifier(; rng::Union{AbstractRNG,Int64}=Random.GLOBAL_RNG, optimiser_changes_trigger_retraining::Bool=false, acceleration::AbstractResource=CPU1(), + reg_strength_size::Float64=5.0, + epsilon::Float64=0.1, ) where {B,F,O,L} # Initialise the MLJFlux wrapper: @@ -48,6 +52,8 @@ function ConformalNNClassifier(; rng, optimiser_changes_trigger_retraining, acceleration, + reg_strength_size, + epsilon, ) return mod diff --git a/src/conformal_models/ConformalTraining/losses.jl b/src/conformal_models/ConformalTraining/losses.jl index 8fc67131..1d5f100e 100644 --- a/src/conformal_models/ConformalTraining/losses.jl +++ b/src/conformal_models/ConformalTraining/losses.jl @@ -1,5 +1,6 @@ using ConformalPrediction: ConformalProbabilisticSet using Flux +using InferOpt: soft_sort_kl using LinearAlgebra using MLJBase using StatsBase @@ -10,10 +11,11 @@ using StatsBase Computes soft assignment scores for each label and sample. That is, the probability of label `k` being included in the confidence set. This implementation follows Stutz et al. (2022): https://openreview.net/pdf?id=t8O-4LKFVx. Contrary to the paper, we use non-conformity scores instead of conformity scores, hence the sign swap. """ function soft_assignment( - conf_model::ConformalProbabilisticSet; temp::Union{Nothing,Real}=nothing + conf_model::ConformalProbabilisticSet; temp::Union{Nothing,Real}=nothing, ε::Real=1e-6 ) temp = isnothing(temp) ? 0.5 : temp - v = sort(conf_model.scores[:calibration]) + ε = hasfield(typeof(conf_model.model), :epsilon) ? conf_model.model.epsilon : ε + v = soft_sort_kl(conf_model.scores[:calibration]; ε=ε,) q̂ = qplus(v, conf_model.coverage; sorted=true) scores = conf_model.scores[:all] return @.(σ((q̂ - scores) / temp)) @@ -25,10 +27,15 @@ end This function can be used to compute soft assigment probabilities for new data `X` as in [`soft_assignment(conf_model::ConformalProbabilisticSet; temp::Real=0.5)`](@ref). When a fitted model $\mu$ (`fitresult`) and new samples `X` are supplied, non-conformity scores are first computed for the new data points. Then the existing threshold/quantile `q̂` is used to compute the final soft assignments. """ function soft_assignment( - conf_model::ConformalProbabilisticSet, fitresult, X; temp::Union{Nothing,Real}=nothing + conf_model::ConformalProbabilisticSet, + fitresult, + X; + temp::Union{Nothing,Real}=nothing, + ε::Real=1e-6, ) temp = isnothing(temp) ? 0.5 : temp - v = sort(conf_model.scores[:calibration]) + ε = hasfield(typeof(conf_model.model), :epsilon) ? conf_model.model.epsilon : ε + v = soft_sort_kl(conf_model.scores[:calibration]; ε=ε) q̂ = StatsBase.quantile(v, conf_model.coverage; sorted=true) scores = ConformalPrediction.score(conf_model, fitresult, X) return @.(σ((q̂ - scores) / temp)) diff --git a/src/conformal_models/ConformalTraining/regressor.jl b/src/conformal_models/ConformalTraining/regressor.jl index 3522750e..5f5a6f1a 100644 --- a/src/conformal_models/ConformalTraining/regressor.jl +++ b/src/conformal_models/ConformalTraining/regressor.jl @@ -18,6 +18,8 @@ mutable struct ConformalNNRegressor{B,O,L} <: MLJFlux.MLJFluxDeterministic rng::Union{AbstractRNG,Integer} optimiser_changes_trigger_retraining::Bool acceleration::AbstractResource # eg, `CPU1()` or `CUDALibs()` + reg_strength_size::Float64 # regularization strength for size loss + epsilon::Float64 # epsilon for soft sorting end function ConformalNNRegressor(; @@ -31,6 +33,8 @@ function ConformalNNRegressor(; rng::Union{AbstractRNG,Int64}=Random.GLOBAL_RNG, optimiser_changes_trigger_retraining::Bool=false, acceleration::AbstractResource=CPU1(), + reg_strength_size::Float64=5.0, + epsilon::Float64=0.1, ) where {B,O,L} # Initialise the MLJFlux wrapper: @@ -45,6 +49,8 @@ function ConformalNNRegressor(; rng, optimiser_changes_trigger_retraining, acceleration, + reg_strength_size, + epsilon, ) return mod diff --git a/src/conformal_models/ConformalTraining/training.jl b/src/conformal_models/ConformalTraining/training.jl index ae8a82ee..019e07f1 100644 --- a/src/conformal_models/ConformalTraining/training.jl +++ b/src/conformal_models/ConformalTraining/training.jl @@ -13,6 +13,7 @@ function MLJFlux.train!(model::ConformalNN, penalty, chain, optimiser, X, y) training_loss = zero(Float32) size_loss = zero(Float32) fitresult = (chain, nothing) + λ = model.reg_strength_size # Training loop: for i in 1:n_batches @@ -42,8 +43,8 @@ function MLJFlux.train!(model::ConformalNN, penalty, chain, optimiser, X, y) gs = Flux.gradient(parameters) do Ω = smooth_size_loss(conf_model, fitresult, Xpred') yhat = chain(X_batch) - batch_loss = loss(yhat, y_batch) + penalty(parameters) / n_batches - batch_loss += 0.5 * sum(Ω) / length(Ω) # add size loss + batch_loss = + (loss(yhat, y_batch) + penalty(parameters) + λ * sum(Ω) / length(Ω)) / n_batches training_loss += batch_loss size_loss += sum(Ω) / length(Ω) return batch_loss From 14f1df10f51044ef4429c82b6d3ed60633073a04 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Wed, 8 Nov 2023 16:02:39 +0100 Subject: [PATCH 02/11] compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index e639ddb9..1d4ea40e 100644 --- a/Project.toml +++ b/Project.toml @@ -38,7 +38,7 @@ NaturalSort = "1" ProgressMeter = "1" StatsBase = "0.33, 0.34.0" Tables = "1" -julia = "1.7, 1.8, 1.9" +julia = "1.6, 1.7, 1.8, 1.9" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" From 3137ec9811c34bd48726027d8689c3371529bd2d Mon Sep 17 00:00:00 2001 From: pat-alt Date: Wed, 8 Nov 2023 16:06:18 +0100 Subject: [PATCH 03/11] formatting and compat helper --- Project.toml | 1 + docs/Manifest.toml | 69 ++++--------------- .../ConformalTraining/losses.jl | 2 +- .../ConformalTraining/training.jl | 3 +- 4 files changed, 18 insertions(+), 57 deletions(-) diff --git a/Project.toml b/Project.toml index 1d4ea40e..7c116eda 100644 --- a/Project.toml +++ b/Project.toml @@ -29,6 +29,7 @@ CategoricalArrays = "0.10" ChainRules = "1.49.0" ComputationalResources = "0.3" Flux = "0.13.16, 0.14" +LazyArtifacts = "1" MLJBase = "0.20, 0.21, 1" MLJEnsembles = "0.3.3, 0.4" MLJFlux = "0.2.10, 0.3, 0.4" diff --git a/docs/Manifest.toml b/docs/Manifest.toml index c944b474..31d1cd5a 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -284,11 +284,6 @@ weakdeps = ["UnicodePlots"] [deps.CategoricalDistributions.extensions] UnivariateFiniteDisplayExt = "UnicodePlots" -[[deps.Chain]] -git-tree-sha1 = "8c4920235f6c561e401dfe569beb8b924adad003" -uuid = "8be319e6-bccf-4806-a6f7-6fae938471bc" -version = "0.5.0" - [[deps.ChainRules]] deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"] git-tree-sha1 = "710940598100496ad6cbb707e481c28186354197" @@ -440,10 +435,10 @@ uuid = "150eb455-5306-5404-9cee-2592286d6298" version = "0.6.3" [[deps.CounterfactualExplanations]] -deps = ["CSV", "CUDA", "CategoricalArrays", "Chain", "ChainRulesCore", "DataFrames", "DecisionTree", "Distributions", "EvoTrees", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "MLDatasets", "MLJBase", "MLJDecisionTreeInterface", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "PackageExtensionCompat", "Parameters", "Plots", "ProgressMeter", "Random", "Serialization", "SnoopPrecompile", "Statistics", "StatsBase", "Tables", "UMAP"] -git-tree-sha1 = "8393721ffa3c9be209a93eb154d0d9fe9ca187d5" +deps = ["CSV", "CUDA", "CategoricalArrays", "ChainRulesCore", "DataFrames", "DecisionTree", "Distributions", "EvoTrees", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "Logging", "MLDatasets", "MLJBase", "MLJDecisionTreeInterface", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "PackageExtensionCompat", "Parameters", "PrecompileTools", "ProgressMeter", "Random", "Serialization", "Statistics", "StatsBase", "Tables", "UUIDs", "cuDNN"] +git-tree-sha1 = "3e51ccd2c65c7e455621bd3bda6bf92ce10da4d4" uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0" -version = "0.1.15" +version = "0.1.30" [deps.CounterfactualExplanations.extensions] MPIExt = "MPI" @@ -1534,12 +1529,6 @@ weakdeps = ["CategoricalArrays"] [deps.LossFunctions.extensions] LossFunctionsCategoricalArraysExt = "CategoricalArrays" -[[deps.LsqFit]] -deps = ["Distributions", "ForwardDiff", "LinearAlgebra", "NLSolversBase", "OptimBase", "Random", "StatsBase"] -git-tree-sha1 = "00f475f85c50584b12268675072663dfed5594b2" -uuid = "2fda8390-95c7-5789-9bda-21331edee243" -version = "0.13.0" - [[deps.MAT]] deps = ["BufferedStreams", "CodecZlib", "HDF5", "SparseArrays"] git-tree-sha1 = "ed1cf0a322d78cee07718bed5fd945e2218c35a1" @@ -1636,10 +1625,10 @@ uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7" version = "0.16.12" [[deps.MLJMultivariateStatsInterface]] -deps = ["Distances", "LinearAlgebra", "MLJModelInterface", "MultivariateStats", "StatsBase"] -git-tree-sha1 = "a282960828015daf766b4d66ba75445b0c909099" +deps = ["CategoricalDistributions", "Distances", "LinearAlgebra", "MLJModelInterface", "MultivariateStats", "StatsBase"] +git-tree-sha1 = "0d76e36bf83926235dcd3eaeafa7f47d3e7f32ea" uuid = "1b6a4a23-ba22-4f51-9698-8599985d3728" -version = "0.4.0" +version = "0.5.3" [[deps.MLJNaiveBayesInterface]] deps = ["LogExpFunctions", "MLJModelInterface", "NaiveBayes"] @@ -1670,12 +1659,6 @@ git-tree-sha1 = "9ee1618cbf5240e6d4e0371d6f24065083f60c48" uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" version = "0.5.11" -[[deps.ManifoldLearning]] -deps = ["Combinatorics", "Graphs", "LinearAlgebra", "MultivariateStats", "Random", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "4c5564c707899c3b6bc6d324b05e43eb7f277f2b" -uuid = "06eb3307-b2af-5a2a-abea-d33192699d32" -version = "0.9.0" - [[deps.ManualMemory]] git-tree-sha1 = "bcaef4fc7a0cfe2cba636d84cda54b5e4e4ca3cd" uuid = "d125e4d3-2237-4719-b19c-fa641b8a4667" @@ -1756,10 +1739,10 @@ uuid = "14a3606d-f60d-562e-9121-12d972cd8159" version = "2022.10.11" [[deps.MultivariateStats]] -deps = ["Arpack", "LinearAlgebra", "SparseArrays", "Statistics", "StatsBase"] -git-tree-sha1 = "6d019f5a0465522bbfdd68ecfad7f86b535d6935" +deps = ["Arpack", "LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI", "StatsBase"] +git-tree-sha1 = "68bf5103e002c44adfd71fea6bd770b3f0586843" uuid = "6f286f6a-111f-5878-ab1e-185364afe411" -version = "0.9.0" +version = "0.10.2" [[deps.NLSolversBase]] deps = ["DiffResults", "Distributed", "FiniteDiff", "ForwardDiff"] @@ -1814,12 +1797,6 @@ git-tree-sha1 = "eda490d06b9f7c00752ee81cfa451efe55521e21" uuid = "c020b1a1-e9b0-503a-9c33-f039bfc54a85" version = "1.0.0" -[[deps.NearestNeighborDescent]] -deps = ["DataStructures", "Distances", "Graphs", "Random", "Reexport", "SparseArrays"] -git-tree-sha1 = "b7d4bd2ab58f0c3a001fd6eedc2e0aac8e278152" -uuid = "dd2c4c9e-a32f-5b2f-b342-08c2f244fce8" -version = "0.3.6" - [[deps.NearestNeighborModels]] deps = ["Distances", "FillArrays", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "NearestNeighbors", "Statistics", "StatsBase", "Tables"] git-tree-sha1 = "e411143a8362926e4284a54e745972e939fbab78" @@ -1939,12 +1916,6 @@ git-tree-sha1 = "e3a6546c1577bfd701771b477b794a52949e7594" uuid = "429524aa-4258-5aef-a3af-852621145aeb" version = "1.7.6" -[[deps.OptimBase]] -deps = ["NLSolversBase", "Printf", "Reexport"] -git-tree-sha1 = "9cb1fee807b599b5f803809e85c81b582d2009d6" -uuid = "87e2bd06-a317-5318-96d9-3ecbac512eee" -version = "2.0.2" - [[deps.Optimisers]] deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] git-tree-sha1 = "c1fc26bab5df929a5172f296f25d7d08688fd25b" @@ -2417,12 +2388,6 @@ git-tree-sha1 = "2da10356e31327c7096832eb9cd86307a50b1eb6" uuid = "45858cf5-a6b0-47a3-bbea-62219f50df47" version = "0.1.3" -[[deps.SnoopPrecompile]] -deps = ["Preferences"] -git-tree-sha1 = "e760a70afdcd461cf01a575947738d359234665c" -uuid = "66db9d55-30c0-4569-8b51-7e840670fc0c" -version = "1.0.3" - [[deps.Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" @@ -2521,9 +2486,9 @@ version = "1.7.0" [[deps.StatsBase]] deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "d1bf48bfcc554a3761a133fe3a9bb01488e06916" +git-tree-sha1 = "1d77abd07f617c4868c33d4f5b9e1dbb2643c9cf" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.33.21" +version = "0.34.2" [[deps.StatsFuns]] deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] @@ -2642,10 +2607,10 @@ uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" version = "1.11.1" [[deps.TaijaPlotting]] -deps = ["CategoricalArrays", "ConformalPrediction", "CounterfactualExplanations", "DataAPI", "Distributions", "Flux", "LaplaceRedux", "LinearAlgebra", "MLJBase", "ManifoldLearning", "MultivariateStats", "NaturalSort", "NearestNeighborModels", "Plots"] -path = "../../TaijaPlotting.jl" +deps = ["CategoricalArrays", "ConformalPrediction", "CounterfactualExplanations", "Flux", "LaplaceRedux", "LinearAlgebra", "MLJBase", "MultivariateStats", "NaturalSort", "Plots"] +git-tree-sha1 = "d5d1c9fccd05c4ff9793394c56fc07c81db40eda" uuid = "bd7198b4-c7d6-400c-9bab-9a24614b0240" -version = "1.0.3" +version = "1.0.2" [[deps.Tar]] deps = ["ArgTools", "SHA"] @@ -2761,12 +2726,6 @@ git-tree-sha1 = "155515ed4c4236db30049ac1495e2969cc06be9d" uuid = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" version = "1.4.3" -[[deps.UMAP]] -deps = ["Arpack", "Distances", "LinearAlgebra", "LsqFit", "NearestNeighborDescent", "Random", "SparseArrays"] -git-tree-sha1 = "accad220f075445f68caa6488be728957a5d82d6" -uuid = "c4f8c510-2410-5be4-91d7-4fbaeb39457e" -version = "0.1.10" - [[deps.URIs]] git-tree-sha1 = "67db6cc7b3821e19ebe75791a9dd19c9b1188f2b" uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" diff --git a/src/conformal_models/ConformalTraining/losses.jl b/src/conformal_models/ConformalTraining/losses.jl index 1d5f100e..cb0cda0e 100644 --- a/src/conformal_models/ConformalTraining/losses.jl +++ b/src/conformal_models/ConformalTraining/losses.jl @@ -15,7 +15,7 @@ function soft_assignment( ) temp = isnothing(temp) ? 0.5 : temp ε = hasfield(typeof(conf_model.model), :epsilon) ? conf_model.model.epsilon : ε - v = soft_sort_kl(conf_model.scores[:calibration]; ε=ε,) + v = soft_sort_kl(conf_model.scores[:calibration]; ε=ε) q̂ = qplus(v, conf_model.coverage; sorted=true) scores = conf_model.scores[:all] return @.(σ((q̂ - scores) / temp)) diff --git a/src/conformal_models/ConformalTraining/training.jl b/src/conformal_models/ConformalTraining/training.jl index 019e07f1..ffeb2397 100644 --- a/src/conformal_models/ConformalTraining/training.jl +++ b/src/conformal_models/ConformalTraining/training.jl @@ -44,7 +44,8 @@ function MLJFlux.train!(model::ConformalNN, penalty, chain, optimiser, X, y) Ω = smooth_size_loss(conf_model, fitresult, Xpred') yhat = chain(X_batch) batch_loss = - (loss(yhat, y_batch) + penalty(parameters) + λ * sum(Ω) / length(Ω)) / n_batches + (loss(yhat, y_batch) + penalty(parameters) + λ * sum(Ω) / length(Ω)) / + n_batches training_loss += batch_loss size_loss += sum(Ω) / length(Ω) return batch_loss From 7051ba659edb93ba71d06dca8cf7b2f5e499a4c7 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Wed, 8 Nov 2023 16:10:08 +0100 Subject: [PATCH 04/11] remove 1.6 from testing --- .github/workflows/CI.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index e9b4a8d8..223ccc05 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -19,7 +19,6 @@ jobs: fail-fast: false matrix: version: - - '1.6' - '1.7' - '1.8' - '1.9' From 989a7a019607af269dd081c9851bdab9d73a026b Mon Sep 17 00:00:00 2001 From: pat-alt Date: Wed, 8 Nov 2023 16:45:41 +0100 Subject: [PATCH 05/11] damn --- docs/Manifest.toml | 69 +++++++++++++++---- docs/src/tutorials/training.qmd | 4 +- .../ConformalTraining/losses.jl | 23 +++---- 3 files changed, 67 insertions(+), 29 deletions(-) diff --git a/docs/Manifest.toml b/docs/Manifest.toml index 31d1cd5a..c944b474 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -284,6 +284,11 @@ weakdeps = ["UnicodePlots"] [deps.CategoricalDistributions.extensions] UnivariateFiniteDisplayExt = "UnicodePlots" +[[deps.Chain]] +git-tree-sha1 = "8c4920235f6c561e401dfe569beb8b924adad003" +uuid = "8be319e6-bccf-4806-a6f7-6fae938471bc" +version = "0.5.0" + [[deps.ChainRules]] deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"] git-tree-sha1 = "710940598100496ad6cbb707e481c28186354197" @@ -435,10 +440,10 @@ uuid = "150eb455-5306-5404-9cee-2592286d6298" version = "0.6.3" [[deps.CounterfactualExplanations]] -deps = ["CSV", "CUDA", "CategoricalArrays", "ChainRulesCore", "DataFrames", "DecisionTree", "Distributions", "EvoTrees", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "Logging", "MLDatasets", "MLJBase", "MLJDecisionTreeInterface", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "PackageExtensionCompat", "Parameters", "PrecompileTools", "ProgressMeter", "Random", "Serialization", "Statistics", "StatsBase", "Tables", "UUIDs", "cuDNN"] -git-tree-sha1 = "3e51ccd2c65c7e455621bd3bda6bf92ce10da4d4" +deps = ["CSV", "CUDA", "CategoricalArrays", "Chain", "ChainRulesCore", "DataFrames", "DecisionTree", "Distributions", "EvoTrees", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "MLDatasets", "MLJBase", "MLJDecisionTreeInterface", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "PackageExtensionCompat", "Parameters", "Plots", "ProgressMeter", "Random", "Serialization", "SnoopPrecompile", "Statistics", "StatsBase", "Tables", "UMAP"] +git-tree-sha1 = "8393721ffa3c9be209a93eb154d0d9fe9ca187d5" uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0" -version = "0.1.30" +version = "0.1.15" [deps.CounterfactualExplanations.extensions] MPIExt = "MPI" @@ -1529,6 +1534,12 @@ weakdeps = ["CategoricalArrays"] [deps.LossFunctions.extensions] LossFunctionsCategoricalArraysExt = "CategoricalArrays" +[[deps.LsqFit]] +deps = ["Distributions", "ForwardDiff", "LinearAlgebra", "NLSolversBase", "OptimBase", "Random", "StatsBase"] +git-tree-sha1 = "00f475f85c50584b12268675072663dfed5594b2" +uuid = "2fda8390-95c7-5789-9bda-21331edee243" +version = "0.13.0" + [[deps.MAT]] deps = ["BufferedStreams", "CodecZlib", "HDF5", "SparseArrays"] git-tree-sha1 = "ed1cf0a322d78cee07718bed5fd945e2218c35a1" @@ -1625,10 +1636,10 @@ uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7" version = "0.16.12" [[deps.MLJMultivariateStatsInterface]] -deps = ["CategoricalDistributions", "Distances", "LinearAlgebra", "MLJModelInterface", "MultivariateStats", "StatsBase"] -git-tree-sha1 = "0d76e36bf83926235dcd3eaeafa7f47d3e7f32ea" +deps = ["Distances", "LinearAlgebra", "MLJModelInterface", "MultivariateStats", "StatsBase"] +git-tree-sha1 = "a282960828015daf766b4d66ba75445b0c909099" uuid = "1b6a4a23-ba22-4f51-9698-8599985d3728" -version = "0.5.3" +version = "0.4.0" [[deps.MLJNaiveBayesInterface]] deps = ["LogExpFunctions", "MLJModelInterface", "NaiveBayes"] @@ -1659,6 +1670,12 @@ git-tree-sha1 = "9ee1618cbf5240e6d4e0371d6f24065083f60c48" uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" version = "0.5.11" +[[deps.ManifoldLearning]] +deps = ["Combinatorics", "Graphs", "LinearAlgebra", "MultivariateStats", "Random", "SparseArrays", "Statistics", "StatsAPI"] +git-tree-sha1 = "4c5564c707899c3b6bc6d324b05e43eb7f277f2b" +uuid = "06eb3307-b2af-5a2a-abea-d33192699d32" +version = "0.9.0" + [[deps.ManualMemory]] git-tree-sha1 = "bcaef4fc7a0cfe2cba636d84cda54b5e4e4ca3cd" uuid = "d125e4d3-2237-4719-b19c-fa641b8a4667" @@ -1739,10 +1756,10 @@ uuid = "14a3606d-f60d-562e-9121-12d972cd8159" version = "2022.10.11" [[deps.MultivariateStats]] -deps = ["Arpack", "LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI", "StatsBase"] -git-tree-sha1 = "68bf5103e002c44adfd71fea6bd770b3f0586843" +deps = ["Arpack", "LinearAlgebra", "SparseArrays", "Statistics", "StatsBase"] +git-tree-sha1 = "6d019f5a0465522bbfdd68ecfad7f86b535d6935" uuid = "6f286f6a-111f-5878-ab1e-185364afe411" -version = "0.10.2" +version = "0.9.0" [[deps.NLSolversBase]] deps = ["DiffResults", "Distributed", "FiniteDiff", "ForwardDiff"] @@ -1797,6 +1814,12 @@ git-tree-sha1 = "eda490d06b9f7c00752ee81cfa451efe55521e21" uuid = "c020b1a1-e9b0-503a-9c33-f039bfc54a85" version = "1.0.0" +[[deps.NearestNeighborDescent]] +deps = ["DataStructures", "Distances", "Graphs", "Random", "Reexport", "SparseArrays"] +git-tree-sha1 = "b7d4bd2ab58f0c3a001fd6eedc2e0aac8e278152" +uuid = "dd2c4c9e-a32f-5b2f-b342-08c2f244fce8" +version = "0.3.6" + [[deps.NearestNeighborModels]] deps = ["Distances", "FillArrays", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "NearestNeighbors", "Statistics", "StatsBase", "Tables"] git-tree-sha1 = "e411143a8362926e4284a54e745972e939fbab78" @@ -1916,6 +1939,12 @@ git-tree-sha1 = "e3a6546c1577bfd701771b477b794a52949e7594" uuid = "429524aa-4258-5aef-a3af-852621145aeb" version = "1.7.6" +[[deps.OptimBase]] +deps = ["NLSolversBase", "Printf", "Reexport"] +git-tree-sha1 = "9cb1fee807b599b5f803809e85c81b582d2009d6" +uuid = "87e2bd06-a317-5318-96d9-3ecbac512eee" +version = "2.0.2" + [[deps.Optimisers]] deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] git-tree-sha1 = "c1fc26bab5df929a5172f296f25d7d08688fd25b" @@ -2388,6 +2417,12 @@ git-tree-sha1 = "2da10356e31327c7096832eb9cd86307a50b1eb6" uuid = "45858cf5-a6b0-47a3-bbea-62219f50df47" version = "0.1.3" +[[deps.SnoopPrecompile]] +deps = ["Preferences"] +git-tree-sha1 = "e760a70afdcd461cf01a575947738d359234665c" +uuid = "66db9d55-30c0-4569-8b51-7e840670fc0c" +version = "1.0.3" + [[deps.Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" @@ -2486,9 +2521,9 @@ version = "1.7.0" [[deps.StatsBase]] deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "1d77abd07f617c4868c33d4f5b9e1dbb2643c9cf" +git-tree-sha1 = "d1bf48bfcc554a3761a133fe3a9bb01488e06916" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.34.2" +version = "0.33.21" [[deps.StatsFuns]] deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] @@ -2607,10 +2642,10 @@ uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" version = "1.11.1" [[deps.TaijaPlotting]] -deps = ["CategoricalArrays", "ConformalPrediction", "CounterfactualExplanations", "Flux", "LaplaceRedux", "LinearAlgebra", "MLJBase", "MultivariateStats", "NaturalSort", "Plots"] -git-tree-sha1 = "d5d1c9fccd05c4ff9793394c56fc07c81db40eda" +deps = ["CategoricalArrays", "ConformalPrediction", "CounterfactualExplanations", "DataAPI", "Distributions", "Flux", "LaplaceRedux", "LinearAlgebra", "MLJBase", "ManifoldLearning", "MultivariateStats", "NaturalSort", "NearestNeighborModels", "Plots"] +path = "../../TaijaPlotting.jl" uuid = "bd7198b4-c7d6-400c-9bab-9a24614b0240" -version = "1.0.2" +version = "1.0.3" [[deps.Tar]] deps = ["ArgTools", "SHA"] @@ -2726,6 +2761,12 @@ git-tree-sha1 = "155515ed4c4236db30049ac1495e2969cc06be9d" uuid = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" version = "1.4.3" +[[deps.UMAP]] +deps = ["Arpack", "Distances", "LinearAlgebra", "LsqFit", "NearestNeighborDescent", "Random", "SparseArrays"] +git-tree-sha1 = "accad220f075445f68caa6488be728957a5d82d6" +uuid = "c4f8c510-2410-5be4-91d7-4fbaeb39457e" +version = "0.1.10" + [[deps.URIs]] git-tree-sha1 = "67db6cc7b3821e19ebe75791a9dd19c9b1188f2b" uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" diff --git a/docs/src/tutorials/training.qmd b/docs/src/tutorials/training.qmd index e51cf41f..f49bade9 100644 --- a/docs/src/tutorials/training.qmd +++ b/docs/src/tutorials/training.qmd @@ -17,7 +17,7 @@ using Random Random.seed!(123) # Data: -X, y = make_blobs(1000, centers=2, cluster_std=1.0) +X, y = make_blobs(100, centers=10, cluster_std=1.0) X = MLJ.table(Float32.(MLJ.matrix(X))) train, test = partition(eachindex(y), 0.8, shuffle=true) ``` @@ -32,7 +32,7 @@ using ConformalPrediction.ConformalTraining: ConformalNNClassifier # Model: hidden_dim = 32 builder = MLJFlux.MLP(hidden=ntuple(x -> hidden_dim, 3), σ=Flux.relu) -clf = ConformalNNClassifier(epochs=25, builder=builder, batch_size=10, reg_strength_size=0.5, epsilon=0.1) +clf = ConformalNNClassifier(epochs=25, builder=builder, batch_size=10, reg_strength_size=1.0, epsilon=0.1) # clf = NeuralNetworkClassifier(epochs=250, builder=builder, batch_size=50) ``` diff --git a/src/conformal_models/ConformalTraining/losses.jl b/src/conformal_models/ConformalTraining/losses.jl index cb0cda0e..93014851 100644 --- a/src/conformal_models/ConformalTraining/losses.jl +++ b/src/conformal_models/ConformalTraining/losses.jl @@ -6,14 +6,13 @@ using MLJBase using StatsBase """ - soft_assignment(conf_model::ConformalProbabilisticSet; temp::Real=0.5) + soft_assignment(conf_model::ConformalProbabilisticSet; temp::Real=0.1) Computes soft assignment scores for each label and sample. That is, the probability of label `k` being included in the confidence set. This implementation follows Stutz et al. (2022): https://openreview.net/pdf?id=t8O-4LKFVx. Contrary to the paper, we use non-conformity scores instead of conformity scores, hence the sign swap. """ function soft_assignment( - conf_model::ConformalProbabilisticSet; temp::Union{Nothing,Real}=nothing, ε::Real=1e-6 + conf_model::ConformalProbabilisticSet; temp::Real=0.1, ε::Real=1e-6 ) - temp = isnothing(temp) ? 0.5 : temp ε = hasfield(typeof(conf_model.model), :epsilon) ? conf_model.model.epsilon : ε v = soft_sort_kl(conf_model.scores[:calibration]; ε=ε) q̂ = qplus(v, conf_model.coverage; sorted=true) @@ -22,18 +21,17 @@ function soft_assignment( end @doc raw""" - soft_assignment(conf_model::ConformalProbabilisticSet, fitresult, X; temp::Real=0.5) + soft_assignment(conf_model::ConformalProbabilisticSet, fitresult, X; temp::Real=0.1) -This function can be used to compute soft assigment probabilities for new data `X` as in [`soft_assignment(conf_model::ConformalProbabilisticSet; temp::Real=0.5)`](@ref). When a fitted model $\mu$ (`fitresult`) and new samples `X` are supplied, non-conformity scores are first computed for the new data points. Then the existing threshold/quantile `q̂` is used to compute the final soft assignments. +This function can be used to compute soft assigment probabilities for new data `X` as in [`soft_assignment(conf_model::ConformalProbabilisticSet; temp::Real=0.1)`](@ref). When a fitted model $\mu$ (`fitresult`) and new samples `X` are supplied, non-conformity scores are first computed for the new data points. Then the existing threshold/quantile `q̂` is used to compute the final soft assignments. """ function soft_assignment( conf_model::ConformalProbabilisticSet, fitresult, X; - temp::Union{Nothing,Real}=nothing, + temp::Real=0.1, ε::Real=1e-6, ) - temp = isnothing(temp) ? 0.5 : temp ε = hasfield(typeof(conf_model.model), :epsilon) ? conf_model.model.epsilon : ε v = soft_sort_kl(conf_model.scores[:calibration]; ε=ε) q̂ = StatsBase.quantile(v, conf_model.coverage; sorted=true) @@ -44,7 +42,7 @@ end @doc raw""" function smooth_size_loss( conf_model::ConformalProbabilisticSet, fitresult, X; - temp::Real=0.5, κ::Real=1.0 + temp::Real=0.1, κ::Real=1.0 ) Computes the smooth (differentiable) size loss following Stutz et al. (2022): https://openreview.net/pdf?id=t8O-4LKFVx. First, soft assignment probabilities are computed for new data `X`. Then (following the notation in the paper) the loss is computed as, @@ -59,10 +57,10 @@ function smooth_size_loss( conf_model::ConformalProbabilisticSet, fitresult, X; - temp::Union{Nothing,Real}=nothing, + temp::Real=0.1, κ::Real=1.0, ) - temp = isnothing(temp) ? 0.1 : temp + C = soft_assignment(conf_model, fitresult, X; temp=temp) is_empty_set = all( x -> x .== 0, soft_assignment(conf_model, fitresult, X; temp=0.0); dims=2 @@ -86,7 +84,7 @@ end classification_loss( conf_model::ConformalProbabilisticSet, fitresult, X, y; loss_matrix::Union{AbstractMatrix,UniformScaling}=UniformScaling(1.0), - temp::Real=0.5 + temp::Real=0.1 ) Computes the calibration loss following Stutz et al. (2022): https://openreview.net/pdf?id=t8O-4LKFVx. Following the notation in the paper, the loss is computed as, @@ -103,10 +101,9 @@ function classification_loss( X, y; loss_matrix::Union{AbstractMatrix,UniformScaling}=UniformScaling(1.0), - temp::Union{Nothing,Real}=nothing, + temp::Real=0.1, ) # Setup: - temp = isnothing(temp) ? 0.5 : temp if typeof(y) <: CategoricalArray L = levels(y) yenc = permutedims(Flux.onehotbatch(levelcode.(y), L)) From 52f079fe3a367dfd2805325c6ad35b4ba1b4fa9f Mon Sep 17 00:00:00 2001 From: pat-alt Date: Wed, 8 Nov 2023 16:48:28 +0100 Subject: [PATCH 06/11] formatting --- src/conformal_models/ConformalTraining/losses.jl | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/src/conformal_models/ConformalTraining/losses.jl b/src/conformal_models/ConformalTraining/losses.jl index 93014851..bced8404 100644 --- a/src/conformal_models/ConformalTraining/losses.jl +++ b/src/conformal_models/ConformalTraining/losses.jl @@ -26,11 +26,7 @@ end This function can be used to compute soft assigment probabilities for new data `X` as in [`soft_assignment(conf_model::ConformalProbabilisticSet; temp::Real=0.1)`](@ref). When a fitted model $\mu$ (`fitresult`) and new samples `X` are supplied, non-conformity scores are first computed for the new data points. Then the existing threshold/quantile `q̂` is used to compute the final soft assignments. """ function soft_assignment( - conf_model::ConformalProbabilisticSet, - fitresult, - X; - temp::Real=0.1, - ε::Real=1e-6, + conf_model::ConformalProbabilisticSet, fitresult, X; temp::Real=0.1, ε::Real=1e-6 ) ε = hasfield(typeof(conf_model.model), :epsilon) ? conf_model.model.epsilon : ε v = soft_sort_kl(conf_model.scores[:calibration]; ε=ε) @@ -54,13 +50,8 @@ Computes the smooth (differentiable) size loss following Stutz et al. (2022): ht where $\tau$ is just the quantile `q̂` and $\kappa$ is the target set size (defaults to $1$). For empty sets, the loss is computed as $K - \kappa$, that is the maximum set size minus the target set size. """ function smooth_size_loss( - conf_model::ConformalProbabilisticSet, - fitresult, - X; - temp::Real=0.1, - κ::Real=1.0, + conf_model::ConformalProbabilisticSet, fitresult, X; temp::Real=0.1, κ::Real=1.0 ) - C = soft_assignment(conf_model, fitresult, X; temp=temp) is_empty_set = all( x -> x .== 0, soft_assignment(conf_model, fitresult, X; temp=0.0); dims=2 From 2012366a80c83e64552ba1ab68306554d87b002f Mon Sep 17 00:00:00 2001 From: pat-alt Date: Thu, 9 Nov 2023 11:55:05 +0100 Subject: [PATCH 07/11] sorting out issue with docs --- docs/Manifest.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/Manifest.toml b/docs/Manifest.toml index c944b474..c151080d 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -2643,7 +2643,7 @@ version = "1.11.1" [[deps.TaijaPlotting]] deps = ["CategoricalArrays", "ConformalPrediction", "CounterfactualExplanations", "DataAPI", "Distributions", "Flux", "LaplaceRedux", "LinearAlgebra", "MLJBase", "ManifoldLearning", "MultivariateStats", "NaturalSort", "NearestNeighborModels", "Plots"] -path = "../../TaijaPlotting.jl" +git-tree-sha1 = "f86ed2cbb9e9a08b2fe19f44d6c0a1266d05a2f4" uuid = "bd7198b4-c7d6-400c-9bab-9a24614b0240" version = "1.0.3" From 7b1113f61f18342b25ee7546211e4d5cac322767 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Thu, 9 Nov 2023 12:30:48 +0100 Subject: [PATCH 08/11] trying to sort out docs issue --- .../plotting/execute-results/md.json | 7 +- .../figure-commonmark/cell-10-output-1.svg | 64 +- .../figure-commonmark/cell-12-output-1.svg | 86 +- .../figure-commonmark/cell-16-output-1.svg | 326 +- .../figure-commonmark/cell-19-output-1.svg | 718 ++-- .../figure-commonmark/cell-22-output-1.svg | 1240 +++---- .../figure-commonmark/cell-24-output-1.svg | 1180 +++--- .../figure-commonmark/cell-6-output-1.svg | 272 +- .../figure-commonmark/cell-9-output-1.svg | 3228 ++++++++--------- docs/src/tutorials/plotting.md | 13 +- docs/src/tutorials/plotting.qmd | 13 +- .../figure-commonmark/cell-10-output-1.svg | 64 +- .../figure-commonmark/cell-12-output-1.svg | 86 +- .../figure-commonmark/cell-16-output-1.svg | 326 +- .../figure-commonmark/cell-19-output-1.svg | 718 ++-- .../figure-commonmark/cell-22-output-1.svg | 1240 +++---- .../figure-commonmark/cell-24-output-1.svg | 1180 +++--- .../figure-commonmark/cell-6-output-1.svg | 272 +- .../figure-commonmark/cell-9-output-1.svg | 3228 ++++++++--------- .../ConformalTraining/ConformalTraining.jl | 1 + .../ConformalTraining/losses.jl | 9 +- .../ConformalTraining/smooth_quantile.jl | 32 + 22 files changed, 7115 insertions(+), 7188 deletions(-) create mode 100644 src/conformal_models/ConformalTraining/smooth_quantile.jl diff --git a/_freeze/docs/src/tutorials/plotting/execute-results/md.json b/_freeze/docs/src/tutorials/plotting/execute-results/md.json index e0cfbb2e..5a83ad23 100644 --- a/_freeze/docs/src/tutorials/plotting/execute-results/md.json +++ b/_freeze/docs/src/tutorials/plotting/execute-results/md.json @@ -1,9 +1,10 @@ { - "hash": "c6e3193657b1eee4f55df45354477754", + "hash": "b9cf54ba3e24e6647a3f7153438e30a5", "result": { - "markdown": "---\ntitle: Visualization using `Plots.jl` recipes\n---\n\n\n\n```@meta\nCurrentModule = ConformalPrediction\n```\n\n\n\n\nThis tutorial demonstrates how various custom `Plots.jl` recipes can be used to visually analyze conformal predictors.\n\n::: {.cell execution_count=2}\n``` {.julia .cell-code}\nusing ConformalPrediction\n```\n:::\n\n\n## Regression\n\n### Visualizing Prediction Intervals\n\nFor conformal regressors, the [`Plots.plot(conf_model::ConformalPrediction.ConformalInterval, fitresult, X, y; kwrgs...)`](@ref) can be used to visualize the prediction intervals for given data points.\n\n#### Univariate Input\n\n::: {.cell execution_count=3}\n``` {.julia .cell-code}\nusing MLJ\nX, y = make_regression(100, 1; noise=0.3)\n```\n:::\n\n\n::: {.cell execution_count=4}\n``` {.julia .cell-code}\nEvoTreeRegressor = @load EvoTreeRegressor pkg=EvoTrees\nmodel = EvoTreeRegressor() \nconf_model = conformal_model(model)\nmach = machine(conf_model, X, y)\nfit!(mach)\n```\n:::\n\n\n::: {.cell execution_count=5}\n``` {.julia .cell-code}\nplot(mach.model, mach.fitresult, X, y; input_var=1)\n```\n\n::: {.cell-output .cell-output-display execution_count=6}\n![](plotting_files/figure-commonmark/cell-6-output-1.svg){}\n:::\n:::\n\n\n#### Multivariate Input\n\n::: {.cell execution_count=6}\n``` {.julia .cell-code}\nusing MLJ\nX, y = @load_boston\nschema(X)\n```\n:::\n\n\n::: {.cell execution_count=7}\n``` {.julia .cell-code}\nEvoTreeRegressor = @load EvoTreeRegressor pkg=EvoTrees\nmodel = EvoTreeRegressor() \nconf_model = conformal_model(model)\nmach = machine(conf_model, X, y)\nfit!(mach)\n```\n:::\n\n\n::: {.cell execution_count=8}\n``` {.julia .cell-code}\ninput_vars = [:Crim, :Age, :Tax]\nnvars = length(input_vars)\nplt_list = []\nfor input_var in input_vars\n plt = plot(mach.model, mach.fitresult, X, y; input_var=input_var, title=input_var)\n push!(plt_list, plt)\nend\nplot(plt_list..., layout=(1,nvars), size=(nvars*200, 200))\n```\n\n::: {.cell-output .cell-output-display execution_count=9}\n![](plotting_files/figure-commonmark/cell-9-output-1.svg){}\n:::\n:::\n\n\n### Visualizing Set Size\n\nTo visualize the set size distribution, the [`Plots.bar(conf_model::ConformalPrediction.ConformalModel, fitresult, X; label=\"\", xtickfontsize=6, kwrgs...)`](@ref) can be used. For regression models the prediction interval widths are stratified into discrete bins.a\n\n::: {.cell execution_count=9}\n``` {.julia .cell-code}\nbar(mach.model, mach.fitresult, X)\n```\n\n::: {.cell-output .cell-output-display execution_count=10}\n![](plotting_files/figure-commonmark/cell-10-output-1.svg){}\n:::\n:::\n\n\n::: {.cell execution_count=10}\n``` {.julia .cell-code}\nEvoTreeRegressor = @load EvoTreeRegressor pkg=EvoTrees\nmodel = EvoTreeRegressor() \nconf_model = conformal_model(model, method=:jackknife_plus)\nmach = machine(conf_model, X, y)\nfit!(mach)\n```\n:::\n\n\n::: {.cell execution_count=11}\n``` {.julia .cell-code}\nbar(mach.model, mach.fitresult, X)\n```\n\n::: {.cell-output .cell-output-display execution_count=12}\n![](plotting_files/figure-commonmark/cell-12-output-1.svg){}\n:::\n:::\n\n\n## Classification\n\n::: {.cell execution_count=12}\n``` {.julia .cell-code}\nKNNClassifier = @load KNNClassifier pkg=NearestNeighborModels\nmodel = KNNClassifier(;K=3)\n```\n:::\n\n\n### Visualizing Predictions\n\n#### Stacked Area Charts\n\nStacked area charts can be used to visualize prediction sets for any conformal classifier.a\n\n::: {.cell execution_count=13}\n``` {.julia .cell-code}\nusing MLJ\nn_input = 4\nX, y = make_blobs(100, n_input)\n```\n:::\n\n\n::: {.cell execution_count=14}\n``` {.julia .cell-code}\nconf_model = conformal_model(model)\nmach = machine(conf_model, X, y)\nfit!(mach)\n```\n:::\n\n\n::: {.cell execution_count=15}\n``` {.julia .cell-code}\nplt_list = []\nfor i in 1:n_input\n plt = areaplot(mach.model, mach.fitresult, X, y; input_var=i, title=\"Input $i\")\n push!(plt_list, plt)\nend\nplot(plt_list..., size=(220*n_input,200), layout=(1, n_input))\n```\n\n::: {.cell-output .cell-output-display execution_count=16}\n![](plotting_files/figure-commonmark/cell-16-output-1.svg){}\n:::\n:::\n\n\n#### Contour Plots for Two-Dimensional Inputs\n\nFor conformal classifiers with exactly two input variables, the [`Plots.contourf(conf_model::ConformalPrediction.ConformalProbabilisticSet, fitresult, X, y; kwrgs...)`](@ref) method can be used to visualize conformal predictions in the two-dimensional feature space.a\n\n::: {.cell execution_count=16}\n``` {.julia .cell-code}\nusing MLJ\nX, y = make_blobs(100, 2)\n```\n:::\n\n\n::: {.cell execution_count=17}\n``` {.julia .cell-code}\nconf_model = conformal_model(model)\nmach = machine(conf_model, X, y)\nfit!(mach)\n```\n:::\n\n\n::: {.cell execution_count=18}\n``` {.julia .cell-code}\np1 = contourf(mach.model, mach.fitresult, X, y)\np2 = contourf(mach.model, mach.fitresult, X, y; plot_set_size=true)\nplot(p1, p2, size=(700,300))\n```\n\n::: {.cell-output .cell-output-display execution_count=19}\n![](plotting_files/figure-commonmark/cell-19-output-1.svg){}\n:::\n:::\n\n\n### Visualizing Set Size\n\nTo visualize the set size distribution, the [`Plots.bar(conf_model::ConformalPrediction.ConformalModel, fitresult, X; label=\"\", xtickfontsize=6, kwrgs...)`](@ref) can be used. Recall that for more adaptive predictors the distribution of set sizes is typically spread out more widely, which reflects that “the procedure is effectively distinguishing between easy and hard inputs” [@angelopoulos2021gentle]. This is desirable: when for a given sample it is difficult to make predictions, this should be reflected in the set size (or interval width in the regression case). Since ‘difficult’ lies on some spectrum that ranges from ‘very easy’ to ‘very difficult’ the set size should very across the spectrum of ‘empty set’ to ‘all labels included’.\n\n::: {.cell execution_count=19}\n``` {.julia .cell-code}\nX, y = make_moons(500; noise=0.15)\nKNNClassifier = @load KNNClassifier pkg=NearestNeighborModels\nmodel = KNNClassifier(;K=50) \n```\n:::\n\n\n::: {.cell execution_count=20}\n``` {.julia .cell-code}\nconf_model = conformal_model(model)\nmach = machine(conf_model, X, y)\nfit!(mach)\n```\n:::\n\n\n::: {.cell execution_count=21}\n``` {.julia .cell-code}\np1 = contourf(mach.model, mach.fitresult, X, y; plot_set_size=true)\np2 = bar(mach.model, mach.fitresult, X)\nplot(p1, p2, size=(700,300))\n```\n\n::: {.cell-output .cell-output-display execution_count=22}\n![](plotting_files/figure-commonmark/cell-22-output-1.svg){}\n:::\n:::\n\n\n::: {.cell execution_count=22}\n``` {.julia .cell-code}\nconf_model = conformal_model(model, method=:adaptive_inductive)\nmach = machine(conf_model, X, y)\nfit!(mach)\n```\n:::\n\n\n::: {.cell execution_count=23}\n``` {.julia .cell-code}\np1 = contourf(mach.model, mach.fitresult, X, y; plot_set_size=true)\np2 = bar(mach.model, mach.fitresult, X)\nplot(p1, p2, size=(700,300))\n```\n\n::: {.cell-output .cell-output-display execution_count=24}\n![](plotting_files/figure-commonmark/cell-24-output-1.svg){}\n:::\n:::\n\n\n", + "engine": "jupyter", + "markdown": "---\ntitle: Visualization using `TaijaPlotting.jl`\n---\n\n\n\n```@meta\nCurrentModule = ConformalPrediction\n```\n\n\n\n\nThis tutorial demonstrates how various custom plotting methods can be used to visually analyze conformal predictors.\n\n::: {.cell execution_count=2}\n``` {.julia .cell-code}\nusing ConformalPrediction\nusing Plots, TaijaPlotting\n```\n:::\n\n\n## Regression\n\n### Visualizing Prediction Intervals\n\nFor conformal regressors, the [`TaijaPlotting.plot(conf_model::ConformalPrediction.ConformalInterval, fitresult, X, y; kwrgs...)`](@ref) can be used to visualize the prediction intervals for given data points.\n\n#### Univariate Input\n\n::: {.cell execution_count=3}\n``` {.julia .cell-code}\nusing MLJ\nX, y = make_regression(100, 1; noise=0.3)\n```\n:::\n\n\n::: {.cell execution_count=4}\n``` {.julia .cell-code}\nEvoTreeRegressor = @load EvoTreeRegressor pkg=EvoTrees\nmodel = EvoTreeRegressor() \nconf_model = conformal_model(model)\nmach = machine(conf_model, X, y)\nfit!(mach)\n```\n:::\n\n\n::: {.cell execution_count=5}\n``` {.julia .cell-code}\nplot(mach.model, mach.fitresult, X, y; input_var=1)\n```\n\n::: {.cell-output .cell-output-display execution_count=6}\n![](plotting_files/figure-commonmark/cell-6-output-1.svg){}\n:::\n:::\n\n\n#### Multivariate Input\n\n::: {.cell execution_count=6}\n``` {.julia .cell-code}\nusing MLJ\nX, y = @load_boston\nschema(X)\n```\n:::\n\n\n::: {.cell execution_count=7}\n``` {.julia .cell-code}\nEvoTreeRegressor = @load EvoTreeRegressor pkg=EvoTrees\nmodel = EvoTreeRegressor() \nconf_model = conformal_model(model)\nmach = machine(conf_model, X, y)\nfit!(mach)\n```\n:::\n\n\n::: {.cell execution_count=8}\n``` {.julia .cell-code}\ninput_vars = [:Crim, :Age, :Tax]\nnvars = length(input_vars)\nplt_list = []\nfor input_var in input_vars\n plt = plot(mach.model, mach.fitresult, X, y; input_var=input_var, title=input_var)\n push!(plt_list, plt)\nend\nplot(plt_list..., layout=(1,nvars), size=(nvars*200, 200))\n```\n\n::: {.cell-output .cell-output-display execution_count=9}\n![](plotting_files/figure-commonmark/cell-9-output-1.svg){}\n:::\n:::\n\n\n### Visualizing Set Size\n\nTo visualize the set size distribution, the [`TaijaPlotting.bar(conf_model::ConformalPrediction.ConformalModel, fitresult, X; label=\"\", xtickfontsize=6, kwrgs...)`](@ref) can be used. For regression models the prediction interval widths are stratified into discrete bins.a\n\n::: {.cell execution_count=9}\n``` {.julia .cell-code}\nbar(mach.model, mach.fitresult, X)\n```\n\n::: {.cell-output .cell-output-display execution_count=10}\n![](plotting_files/figure-commonmark/cell-10-output-1.svg){}\n:::\n:::\n\n\n::: {.cell execution_count=10}\n``` {.julia .cell-code}\nEvoTreeRegressor = @load EvoTreeRegressor pkg=EvoTrees\nmodel = EvoTreeRegressor() \nconf_model = conformal_model(model, method=:jackknife_plus)\nmach = machine(conf_model, X, y)\nfit!(mach)\n```\n:::\n\n\n::: {.cell execution_count=11}\n``` {.julia .cell-code}\nbar(mach.model, mach.fitresult, X)\n```\n\n::: {.cell-output .cell-output-display execution_count=12}\n![](plotting_files/figure-commonmark/cell-12-output-1.svg){}\n:::\n:::\n\n\n## Classification\n\n::: {.cell execution_count=12}\n``` {.julia .cell-code}\nKNNClassifier = @load KNNClassifier pkg=NearestNeighborModels\nmodel = KNNClassifier(;K=3)\n```\n:::\n\n\n### Visualizing Predictions\n\n#### Stacked Area Charts\n\nStacked area charts can be used to visualize prediction sets for any conformal classifier.a\n\n::: {.cell execution_count=13}\n``` {.julia .cell-code}\nusing MLJ\nn_input = 4\nX, y = make_blobs(100, n_input)\n```\n:::\n\n\n::: {.cell execution_count=14}\n``` {.julia .cell-code}\nconf_model = conformal_model(model)\nmach = machine(conf_model, X, y)\nfit!(mach)\n```\n:::\n\n\n::: {.cell execution_count=15}\n``` {.julia .cell-code}\nplt_list = []\nfor i in 1:n_input\n plt = areaplot(mach.model, mach.fitresult, X, y; input_var=i, title=\"Input $i\")\n push!(plt_list, plt)\nend\nplot(plt_list..., size=(220*n_input,200), layout=(1, n_input))\n```\n\n::: {.cell-output .cell-output-display execution_count=16}\n![](plotting_files/figure-commonmark/cell-16-output-1.svg){}\n:::\n:::\n\n\n#### Contour Plots for Two-Dimensional Inputs\n\nFor conformal classifiers with exactly two input variables, the [`TaijaPlotting.contourf(conf_model::ConformalPrediction.ConformalProbabilisticSet, fitresult, X, y; kwrgs...)`](@ref) method can be used to visualize conformal predictions in the two-dimensional feature space.a\n\n::: {.cell execution_count=16}\n``` {.julia .cell-code}\nusing MLJ\nX, y = make_blobs(100, 2)\n```\n:::\n\n\n::: {.cell execution_count=17}\n``` {.julia .cell-code}\nconf_model = conformal_model(model)\nmach = machine(conf_model, X, y)\nfit!(mach)\n```\n:::\n\n\n::: {.cell execution_count=18}\n``` {.julia .cell-code}\np1 = contourf(mach.model, mach.fitresult, X, y)\np2 = contourf(mach.model, mach.fitresult, X, y; plot_set_size=true)\nplot(p1, p2, size=(700,300))\n```\n\n::: {.cell-output .cell-output-display execution_count=19}\n![](plotting_files/figure-commonmark/cell-19-output-1.svg){}\n:::\n:::\n\n\n### Visualizing Set Size\n\nTo visualize the set size distribution, the [`TaijaPlotting.bar(conf_model::ConformalPrediction.ConformalModel, fitresult, X; label=\"\", xtickfontsize=6, kwrgs...)`](@ref) can be used. Recall that for more adaptive predictors the distribution of set sizes is typically spread out more widely, which reflects that “the procedure is effectively distinguishing between easy and hard inputs” [@angelopoulos2021gentle]. This is desirable: when for a given sample it is difficult to make predictions, this should be reflected in the set size (or interval width in the regression case). Since ‘difficult’ lies on some spectrum that ranges from ‘very easy’ to ‘very difficult’ the set size should very across the spectrum of ‘empty set’ to ‘all labels included’.\n\n::: {.cell execution_count=19}\n``` {.julia .cell-code}\nX, y = make_moons(500; noise=0.15)\nKNNClassifier = @load KNNClassifier pkg=NearestNeighborModels\nmodel = KNNClassifier(;K=50) \n```\n:::\n\n\n::: {.cell execution_count=20}\n``` {.julia .cell-code}\nconf_model = conformal_model(model)\nmach = machine(conf_model, X, y)\nfit!(mach)\n```\n:::\n\n\n::: {.cell execution_count=21}\n``` {.julia .cell-code}\np1 = contourf(mach.model, mach.fitresult, X, y; plot_set_size=true)\np2 = bar(mach.model, mach.fitresult, X)\nplot(p1, p2, size=(700,300))\n```\n\n::: {.cell-output .cell-output-display execution_count=22}\n![](plotting_files/figure-commonmark/cell-22-output-1.svg){}\n:::\n:::\n\n\n::: {.cell execution_count=22}\n``` {.julia .cell-code}\nconf_model = conformal_model(model, method=:adaptive_inductive)\nmach = machine(conf_model, X, y)\nfit!(mach)\n```\n:::\n\n\n::: {.cell execution_count=23}\n``` {.julia .cell-code}\np1 = contourf(mach.model, mach.fitresult, X, y; plot_set_size=true)\np2 = bar(mach.model, mach.fitresult, X)\nplot(p1, p2, size=(700,300))\n```\n\n::: {.cell-output .cell-output-display execution_count=24}\n![](plotting_files/figure-commonmark/cell-24-output-1.svg){}\n:::\n:::\n\n\n", "supporting": [ - "plotting_files/figure-commonmark" + "plotting_files" ], "filters": [] } diff --git a/_freeze/docs/src/tutorials/plotting/figure-commonmark/cell-10-output-1.svg b/_freeze/docs/src/tutorials/plotting/figure-commonmark/cell-10-output-1.svg index c4544867..b79e9b4c 100644 --- a/_freeze/docs/src/tutorials/plotting/figure-commonmark/cell-10-output-1.svg +++ b/_freeze/docs/src/tutorials/plotting/figure-commonmark/cell-10-output-1.svg @@ -1,47 +1,47 @@ - + - + - + - + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/_freeze/docs/src/tutorials/plotting/figure-commonmark/cell-12-output-1.svg b/_freeze/docs/src/tutorials/plotting/figure-commonmark/cell-12-output-1.svg index de825bc7..d8b00cbd 100644 --- a/_freeze/docs/src/tutorials/plotting/figure-commonmark/cell-12-output-1.svg +++ b/_freeze/docs/src/tutorials/plotting/figure-commonmark/cell-12-output-1.svg @@ -1,59 +1,55 @@ - + - + - + - + - - + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/_freeze/docs/src/tutorials/plotting/figure-commonmark/cell-16-output-1.svg b/_freeze/docs/src/tutorials/plotting/figure-commonmark/cell-16-output-1.svg index a0f8a3fb..451f679a 100644 --- a/_freeze/docs/src/tutorials/plotting/figure-commonmark/cell-16-output-1.svg +++ b/_freeze/docs/src/tutorials/plotting/figure-commonmark/cell-16-output-1.svg @@ -1,188 +1,190 @@ - + - + - + - + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/_freeze/docs/src/tutorials/plotting/figure-commonmark/cell-19-output-1.svg b/_freeze/docs/src/tutorials/plotting/figure-commonmark/cell-19-output-1.svg index cc3f0dcd..1d67165d 100644 --- a/_freeze/docs/src/tutorials/plotting/figure-commonmark/cell-19-output-1.svg +++ b/_freeze/docs/src/tutorials/plotting/figure-commonmark/cell-19-output-1.svg @@ -1,386 +1,380 @@ - + - + - + - + - - + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + - - + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + diff --git a/_freeze/docs/src/tutorials/plotting/figure-commonmark/cell-22-output-1.svg b/_freeze/docs/src/tutorials/plotting/figure-commonmark/cell-22-output-1.svg index a40a0b91..090c250a 100644 --- a/_freeze/docs/src/tutorials/plotting/figure-commonmark/cell-22-output-1.svg +++ b/_freeze/docs/src/tutorials/plotting/figure-commonmark/cell-22-output-1.svg @@ -1,665 +1,621 @@ - + - + - + - + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + diff --git a/_freeze/docs/src/tutorials/plotting/figure-commonmark/cell-24-output-1.svg b/_freeze/docs/src/tutorials/plotting/figure-commonmark/cell-24-output-1.svg index e0e8bfb9..1ddedb12 100644 --- a/_freeze/docs/src/tutorials/plotting/figure-commonmark/cell-24-output-1.svg +++ b/_freeze/docs/src/tutorials/plotting/figure-commonmark/cell-24-output-1.svg @@ -1,613 +1,613 @@ - + - + - + - + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + diff --git a/_freeze/docs/src/tutorials/plotting/figure-commonmark/cell-6-output-1.svg b/_freeze/docs/src/tutorials/plotting/figure-commonmark/cell-6-output-1.svg index 8d1f2630..cd470688 100644 --- a/_freeze/docs/src/tutorials/plotting/figure-commonmark/cell-6-output-1.svg +++ b/_freeze/docs/src/tutorials/plotting/figure-commonmark/cell-6-output-1.svg @@ -1,150 +1,148 @@ - + - + - + - + - - + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/_freeze/docs/src/tutorials/plotting/figure-commonmark/cell-9-output-1.svg b/_freeze/docs/src/tutorials/plotting/figure-commonmark/cell-9-output-1.svg index 39871d2a..61e488a3 100644 --- a/_freeze/docs/src/tutorials/plotting/figure-commonmark/cell-9-output-1.svg +++ b/_freeze/docs/src/tutorials/plotting/figure-commonmark/cell-9-output-1.svg @@ -1,1636 +1,1636 @@ - + - + - + - + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/src/tutorials/plotting.md b/docs/src/tutorials/plotting.md index 625b9946..40cd9437 100644 --- a/docs/src/tutorials/plotting.md +++ b/docs/src/tutorials/plotting.md @@ -1,20 +1,21 @@ -# Visualization using `Plots.jl` recipes +# Visualization using `TaijaPlotting.jl` ``` @meta CurrentModule = ConformalPrediction ``` -This tutorial demonstrates how various custom `Plots.jl` recipes can be used to visually analyze conformal predictors. +This tutorial demonstrates how various custom plotting methods can be used to visually analyze conformal predictors. ``` julia using ConformalPrediction +using Plots, TaijaPlotting ``` ## Regression ### Visualizing Prediction Intervals -For conformal regressors, the [`Plots.plot(conf_model::ConformalPrediction.ConformalInterval, fitresult, X, y; kwrgs...)`](@ref) can be used to visualize the prediction intervals for given data points. +For conformal regressors, the [`TaijaPlotting.plot(conf_model::ConformalPrediction.ConformalInterval, fitresult, X, y; kwrgs...)`](@ref) can be used to visualize the prediction intervals for given data points. #### Univariate Input @@ -68,7 +69,7 @@ plot(plt_list..., layout=(1,nvars), size=(nvars*200, 200)) ### Visualizing Set Size -To visualize the set size distribution, the [`Plots.bar(conf_model::ConformalPrediction.ConformalModel, fitresult, X; label="", xtickfontsize=6, kwrgs...)`](@ref) can be used. For regression models the prediction interval widths are stratified into discrete bins.a +To visualize the set size distribution, the [`TaijaPlotting.bar(conf_model::ConformalPrediction.ConformalModel, fitresult, X; label="", xtickfontsize=6, kwrgs...)`](@ref) can be used. For regression models the prediction interval widths are stratified into discrete bins.a ``` julia bar(mach.model, mach.fitresult, X) @@ -128,7 +129,7 @@ plot(plt_list..., size=(220*n_input,200), layout=(1, n_input)) #### Contour Plots for Two-Dimensional Inputs -For conformal classifiers with exactly two input variables, the [`Plots.contourf(conf_model::ConformalPrediction.ConformalProbabilisticSet, fitresult, X, y; kwrgs...)`](@ref) method can be used to visualize conformal predictions in the two-dimensional feature space.a +For conformal classifiers with exactly two input variables, the [`TaijaPlotting.contourf(conf_model::ConformalPrediction.ConformalProbabilisticSet, fitresult, X, y; kwrgs...)`](@ref) method can be used to visualize conformal predictions in the two-dimensional feature space.a ``` julia using MLJ @@ -151,7 +152,7 @@ plot(p1, p2, size=(700,300)) ### Visualizing Set Size -To visualize the set size distribution, the [`Plots.bar(conf_model::ConformalPrediction.ConformalModel, fitresult, X; label="", xtickfontsize=6, kwrgs...)`](@ref) can be used. Recall that for more adaptive predictors the distribution of set sizes is typically spread out more widely, which reflects that “the procedure is effectively distinguishing between easy and hard inputs” (Angelopoulos and Bates 2021). This is desirable: when for a given sample it is difficult to make predictions, this should be reflected in the set size (or interval width in the regression case). Since ‘difficult’ lies on some spectrum that ranges from ‘very easy’ to ‘very difficult’ the set size should very across the spectrum of ‘empty set’ to ‘all labels included’. +To visualize the set size distribution, the [`TaijaPlotting.bar(conf_model::ConformalPrediction.ConformalModel, fitresult, X; label="", xtickfontsize=6, kwrgs...)`](@ref) can be used. Recall that for more adaptive predictors the distribution of set sizes is typically spread out more widely, which reflects that “the procedure is effectively distinguishing between easy and hard inputs” (Angelopoulos and Bates 2021). This is desirable: when for a given sample it is difficult to make predictions, this should be reflected in the set size (or interval width in the regression case). Since ‘difficult’ lies on some spectrum that ranges from ‘very easy’ to ‘very difficult’ the set size should very across the spectrum of ‘empty set’ to ‘all labels included’. ``` julia X, y = make_moons(500; noise=0.15) diff --git a/docs/src/tutorials/plotting.qmd b/docs/src/tutorials/plotting.qmd index aeb79a96..6786623c 100644 --- a/docs/src/tutorials/plotting.qmd +++ b/docs/src/tutorials/plotting.qmd @@ -1,4 +1,4 @@ -# Visualization using `Plots.jl` recipes +# Visualization using `TaijaPlotting.jl` ```@meta CurrentModule = ConformalPrediction @@ -13,17 +13,18 @@ using Random Random.seed!(2022) ``` -This tutorial demonstrates how various custom `Plots.jl` recipes can be used to visually analyze conformal predictors. +This tutorial demonstrates how various custom plotting methods can be used to visually analyze conformal predictors. ```{julia} using ConformalPrediction +using Plots, TaijaPlotting ``` ## Regression ### Visualizing Prediction Intervals -For conformal regressors, the [`Plots.plot(conf_model::ConformalPrediction.ConformalInterval, fitresult, X, y; kwrgs...)`](@ref) can be used to visualize the prediction intervals for given data points. +For conformal regressors, the [`TaijaPlotting.plot(conf_model::ConformalPrediction.ConformalInterval, fitresult, X, y; kwrgs...)`](@ref) can be used to visualize the prediction intervals for given data points. #### Univariate Input @@ -77,7 +78,7 @@ plot(plt_list..., layout=(1,nvars), size=(nvars*200, 200)) ### Visualizing Set Size -To visualize the set size distribution, the [`Plots.bar(conf_model::ConformalPrediction.ConformalModel, fitresult, X; label="", xtickfontsize=6, kwrgs...)`](@ref) can be used. For regression models the prediction interval widths are stratified into discrete bins.a +To visualize the set size distribution, the [`TaijaPlotting.bar(conf_model::ConformalPrediction.ConformalModel, fitresult, X; label="", xtickfontsize=6, kwrgs...)`](@ref) can be used. For regression models the prediction interval widths are stratified into discrete bins.a ```{julia} #| output: true @@ -137,7 +138,7 @@ plot(plt_list..., size=(220*n_input,200), layout=(1, n_input)) #### Contour Plots for Two-Dimensional Inputs -For conformal classifiers with exactly two input variables, the [`Plots.contourf(conf_model::ConformalPrediction.ConformalProbabilisticSet, fitresult, X, y; kwrgs...)`](@ref) method can be used to visualize conformal predictions in the two-dimensional feature space.a +For conformal classifiers with exactly two input variables, the [`TaijaPlotting.contourf(conf_model::ConformalPrediction.ConformalProbabilisticSet, fitresult, X, y; kwrgs...)`](@ref) method can be used to visualize conformal predictions in the two-dimensional feature space.a ```{julia} using MLJ @@ -160,7 +161,7 @@ plot(p1, p2, size=(700,300)) ### Visualizing Set Size -To visualize the set size distribution, the [`Plots.bar(conf_model::ConformalPrediction.ConformalModel, fitresult, X; label="", xtickfontsize=6, kwrgs...)`](@ref) can be used. Recall that for more adaptive predictors the distribution of set sizes is typically spread out more widely, which reflects that “the procedure is effectively distinguishing between easy and hard inputs” [@angelopoulos2021gentle]. This is desirable: when for a given sample it is difficult to make predictions, this should be reflected in the set size (or interval width in the regression case). Since ‘difficult’ lies on some spectrum that ranges from ‘very easy’ to ‘very difficult’ the set size should very across the spectrum of ‘empty set’ to ‘all labels included’. +To visualize the set size distribution, the [`TaijaPlotting.bar(conf_model::ConformalPrediction.ConformalModel, fitresult, X; label="", xtickfontsize=6, kwrgs...)`](@ref) can be used. Recall that for more adaptive predictors the distribution of set sizes is typically spread out more widely, which reflects that “the procedure is effectively distinguishing between easy and hard inputs” [@angelopoulos2021gentle]. This is desirable: when for a given sample it is difficult to make predictions, this should be reflected in the set size (or interval width in the regression case). Since ‘difficult’ lies on some spectrum that ranges from ‘very easy’ to ‘very difficult’ the set size should very across the spectrum of ‘empty set’ to ‘all labels included’. ```{julia} X, y = make_moons(500; noise=0.15) diff --git a/docs/src/tutorials/plotting_files/figure-commonmark/cell-10-output-1.svg b/docs/src/tutorials/plotting_files/figure-commonmark/cell-10-output-1.svg index c4544867..b79e9b4c 100644 --- a/docs/src/tutorials/plotting_files/figure-commonmark/cell-10-output-1.svg +++ b/docs/src/tutorials/plotting_files/figure-commonmark/cell-10-output-1.svg @@ -1,47 +1,47 @@ - + - + - + - + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/src/tutorials/plotting_files/figure-commonmark/cell-12-output-1.svg b/docs/src/tutorials/plotting_files/figure-commonmark/cell-12-output-1.svg index de825bc7..d8b00cbd 100644 --- a/docs/src/tutorials/plotting_files/figure-commonmark/cell-12-output-1.svg +++ b/docs/src/tutorials/plotting_files/figure-commonmark/cell-12-output-1.svg @@ -1,59 +1,55 @@ - + - + - + - + - - + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/src/tutorials/plotting_files/figure-commonmark/cell-16-output-1.svg b/docs/src/tutorials/plotting_files/figure-commonmark/cell-16-output-1.svg index a0f8a3fb..451f679a 100644 --- a/docs/src/tutorials/plotting_files/figure-commonmark/cell-16-output-1.svg +++ b/docs/src/tutorials/plotting_files/figure-commonmark/cell-16-output-1.svg @@ -1,188 +1,190 @@ - + - + - + - + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/src/tutorials/plotting_files/figure-commonmark/cell-19-output-1.svg b/docs/src/tutorials/plotting_files/figure-commonmark/cell-19-output-1.svg index cc3f0dcd..1d67165d 100644 --- a/docs/src/tutorials/plotting_files/figure-commonmark/cell-19-output-1.svg +++ b/docs/src/tutorials/plotting_files/figure-commonmark/cell-19-output-1.svg @@ -1,386 +1,380 @@ - + - + - + - + - - + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + - - + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/src/tutorials/plotting_files/figure-commonmark/cell-22-output-1.svg b/docs/src/tutorials/plotting_files/figure-commonmark/cell-22-output-1.svg index a40a0b91..090c250a 100644 --- a/docs/src/tutorials/plotting_files/figure-commonmark/cell-22-output-1.svg +++ b/docs/src/tutorials/plotting_files/figure-commonmark/cell-22-output-1.svg @@ -1,665 +1,621 @@ - + - + - + - + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + diff --git a/docs/src/tutorials/plotting_files/figure-commonmark/cell-24-output-1.svg b/docs/src/tutorials/plotting_files/figure-commonmark/cell-24-output-1.svg index e0e8bfb9..1ddedb12 100644 --- a/docs/src/tutorials/plotting_files/figure-commonmark/cell-24-output-1.svg +++ b/docs/src/tutorials/plotting_files/figure-commonmark/cell-24-output-1.svg @@ -1,613 +1,613 @@ - + - + - + - + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + diff --git a/docs/src/tutorials/plotting_files/figure-commonmark/cell-6-output-1.svg b/docs/src/tutorials/plotting_files/figure-commonmark/cell-6-output-1.svg index 8d1f2630..cd470688 100644 --- a/docs/src/tutorials/plotting_files/figure-commonmark/cell-6-output-1.svg +++ b/docs/src/tutorials/plotting_files/figure-commonmark/cell-6-output-1.svg @@ -1,150 +1,148 @@ - + - + - + - + - - + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/src/tutorials/plotting_files/figure-commonmark/cell-9-output-1.svg b/docs/src/tutorials/plotting_files/figure-commonmark/cell-9-output-1.svg index 39871d2a..61e488a3 100644 --- a/docs/src/tutorials/plotting_files/figure-commonmark/cell-9-output-1.svg +++ b/docs/src/tutorials/plotting_files/figure-commonmark/cell-9-output-1.svg @@ -1,1636 +1,1636 @@ - + - + - + - + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/conformal_models/ConformalTraining/ConformalTraining.jl b/src/conformal_models/ConformalTraining/ConformalTraining.jl index 405082eb..04885fb5 100644 --- a/src/conformal_models/ConformalTraining/ConformalTraining.jl +++ b/src/conformal_models/ConformalTraining/ConformalTraining.jl @@ -6,6 +6,7 @@ using MLJFlux const default_builder = MLJFlux.MLP(; hidden=(32, 32, 32), σ=Flux.relu) +include("smooth_quantile.jl") include("losses.jl") include("inductive_classification.jl") include("classifier.jl") diff --git a/src/conformal_models/ConformalTraining/losses.jl b/src/conformal_models/ConformalTraining/losses.jl index bced8404..c14b231c 100644 --- a/src/conformal_models/ConformalTraining/losses.jl +++ b/src/conformal_models/ConformalTraining/losses.jl @@ -1,6 +1,5 @@ using ConformalPrediction: ConformalProbabilisticSet using Flux -using InferOpt: soft_sort_kl using LinearAlgebra using MLJBase using StatsBase @@ -14,8 +13,8 @@ function soft_assignment( conf_model::ConformalProbabilisticSet; temp::Real=0.1, ε::Real=1e-6 ) ε = hasfield(typeof(conf_model.model), :epsilon) ? conf_model.model.epsilon : ε - v = soft_sort_kl(conf_model.scores[:calibration]; ε=ε) - q̂ = qplus(v, conf_model.coverage; sorted=true) + v = conf_model.scores[:calibration] + q̂ = qplus_smooth(v, conf_model.coverage; ε=ε) scores = conf_model.scores[:all] return @.(σ((q̂ - scores) / temp)) end @@ -29,8 +28,8 @@ function soft_assignment( conf_model::ConformalProbabilisticSet, fitresult, X; temp::Real=0.1, ε::Real=1e-6 ) ε = hasfield(typeof(conf_model.model), :epsilon) ? conf_model.model.epsilon : ε - v = soft_sort_kl(conf_model.scores[:calibration]; ε=ε) - q̂ = StatsBase.quantile(v, conf_model.coverage; sorted=true) + v = conf_model.scores[:calibration] + q̂ = qplus_smooth(v, conf_model.coverage; ε=ε) scores = ConformalPrediction.score(conf_model, fitresult, X) return @.(σ((q̂ - scores) / temp)) end diff --git a/src/conformal_models/ConformalTraining/smooth_quantile.jl b/src/conformal_models/ConformalTraining/smooth_quantile.jl new file mode 100644 index 00000000..52116d38 --- /dev/null +++ b/src/conformal_models/ConformalTraining/smooth_quantile.jl @@ -0,0 +1,32 @@ +using InferOpt: soft_sort_kl +using StatsBase + +@doc raw""" + qplus_smooth(v::AbstractArray, coverage::AbstractFloat=0.9) + +Implements the ``\hat{q}_{n,\alpha}^{+}`` finite-sample corrected quantile function as defined in Barber et al. (2020): https://arxiv.org/pdf/1905.02928.pdf. To allow for differentiability, we use the soft sort function from InferOpt.jl. +""" +function qplus_smooth(v::AbstractArray, coverage::AbstractFloat=0.9; ε::Real=1e-6, kwrgs...) + n = length(v) + p̂ = ceil(((n + 1) * coverage)) / n + p̂ = clamp(p̂, 0.0, 1.0) + v = soft_sort_kl(v; ε=ε) # soft sort (differentiable) + q̂ = quantile(v, p̂; sorted=true, kwrgs...) + return q̂ +end + +@doc raw""" + qminus_smooth(v::AbstractArray, coverage::AbstractFloat=0.9) + +Implements the ``\hat{q}_{n,\alpha}^{-}`` finite-sample corrected quantile function as defined in Barber et al. (2020): https://arxiv.org/pdf/1905.02928.pdf. To allow for differentiability, we use the soft sort function from InferOpt.jl. +""" +function qminus_smooth( + v::AbstractArray, coverage::AbstractFloat=0.9; ε::Real=1e-6, kwrgs... +) + n = length(v) + p̂ = floor(((n + 1) * coverage)) / n + p̂ = clamp(p̂, 0.0, 1.0) + v = soft_sort_kl(v; ε=ε) # soft sort (differentiable) + q̂ = quantile(v, p̂; sorted=true, kwrgs...) + return q̂ +end \ No newline at end of file From 236ce7500544bafc8dce23967effe3b90b158f54 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Thu, 9 Nov 2023 12:33:12 +0100 Subject: [PATCH 09/11] formatter --- src/conformal_models/ConformalTraining/smooth_quantile.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/conformal_models/ConformalTraining/smooth_quantile.jl b/src/conformal_models/ConformalTraining/smooth_quantile.jl index 52116d38..499ec7bb 100644 --- a/src/conformal_models/ConformalTraining/smooth_quantile.jl +++ b/src/conformal_models/ConformalTraining/smooth_quantile.jl @@ -29,4 +29,4 @@ function qminus_smooth( v = soft_sort_kl(v; ε=ε) # soft sort (differentiable) q̂ = quantile(v, p̂; sorted=true, kwrgs...) return q̂ -end \ No newline at end of file +end From 1066a98a74807342ce52ac92aaa73bedcb6c30ef Mon Sep 17 00:00:00 2001 From: pat-alt Date: Thu, 9 Nov 2023 13:04:59 +0100 Subject: [PATCH 10/11] docs issue sorted out --- docs/make.jl | 1 - docs/src/reference.md | 21 +- docs/src/reference.qmd | 6 +- docs/src/tutorials/plotting.md | 191 ------------------ .../transductive_regression.jl | 10 +- 5 files changed, 21 insertions(+), 208 deletions(-) delete mode 100644 docs/src/tutorials/plotting.md diff --git a/docs/make.jl b/docs/make.jl index 1e0425d8..b03c1856 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -36,7 +36,6 @@ makedocs(; "Overview" => "tutorials/index.md", "Classification" => "tutorials/classification.md", "Regression" => "tutorials/regression.md", - "Visualizations" => "tutorials/plotting.md", ], "🫡 How-To Guides" => [ "Overview" => "how_to_guides/index.md", diff --git a/docs/src/reference.md b/docs/src/reference.md index 2defbf1a..d87890af 100644 --- a/docs/src/reference.md +++ b/docs/src/reference.md @@ -1,4 +1,5 @@ -```@meta + +``` @meta CurrentModule = ConformalPrediction ``` @@ -8,35 +9,37 @@ In this reference you will find a detailed overview of the package API. > Reference guides are technical descriptions of the machinery and how to operate it. Reference material is information-oriented. > -> --- [Diátaxis](https://diataxis.fr/reference/) +> — [Diátaxis](https://diataxis.fr/reference/) In other words, you come here because you want to take a very close look at the code 🧐 ## Content -```@contents +``` @contents Pages = ["_reference.md"] ``` ## Index -```@index +``` @index ``` ## Public Interface -```@autodocs +``` @autodocs Modules = [ - ConformalPrediction + ConformalPrediction, + ConformalPrediction.ConformalTraining, ] Private = false ``` ## Internal functions -```@autodocs +``` @autodocs Modules = [ - ConformalPrediction + ConformalPrediction, + ConformalPrediction.ConformalTraining, ] Public = false -``` \ No newline at end of file +``` diff --git a/docs/src/reference.qmd b/docs/src/reference.qmd index 2defbf1a..6875112b 100644 --- a/docs/src/reference.qmd +++ b/docs/src/reference.qmd @@ -27,7 +27,8 @@ Pages = ["_reference.md"] ```@autodocs Modules = [ - ConformalPrediction + ConformalPrediction, + ConformalPrediction.ConformalTraining, ] Private = false ``` @@ -36,7 +37,8 @@ Private = false ```@autodocs Modules = [ - ConformalPrediction + ConformalPrediction, + ConformalPrediction.ConformalTraining, ] Public = false ``` \ No newline at end of file diff --git a/docs/src/tutorials/plotting.md b/docs/src/tutorials/plotting.md deleted file mode 100644 index 40cd9437..00000000 --- a/docs/src/tutorials/plotting.md +++ /dev/null @@ -1,191 +0,0 @@ -# Visualization using `TaijaPlotting.jl` - -``` @meta -CurrentModule = ConformalPrediction -``` - -This tutorial demonstrates how various custom plotting methods can be used to visually analyze conformal predictors. - -``` julia -using ConformalPrediction -using Plots, TaijaPlotting -``` - -## Regression - -### Visualizing Prediction Intervals - -For conformal regressors, the [`TaijaPlotting.plot(conf_model::ConformalPrediction.ConformalInterval, fitresult, X, y; kwrgs...)`](@ref) can be used to visualize the prediction intervals for given data points. - -#### Univariate Input - -``` julia -using MLJ -X, y = make_regression(100, 1; noise=0.3) -``` - -``` julia -EvoTreeRegressor = @load EvoTreeRegressor pkg=EvoTrees -model = EvoTreeRegressor() -conf_model = conformal_model(model) -mach = machine(conf_model, X, y) -fit!(mach) -``` - -``` julia -plot(mach.model, mach.fitresult, X, y; input_var=1) -``` - -![](plotting_files/figure-commonmark/cell-6-output-1.svg) - -#### Multivariate Input - -``` julia -using MLJ -X, y = @load_boston -schema(X) -``` - -``` julia -EvoTreeRegressor = @load EvoTreeRegressor pkg=EvoTrees -model = EvoTreeRegressor() -conf_model = conformal_model(model) -mach = machine(conf_model, X, y) -fit!(mach) -``` - -``` julia -input_vars = [:Crim, :Age, :Tax] -nvars = length(input_vars) -plt_list = [] -for input_var in input_vars - plt = plot(mach.model, mach.fitresult, X, y; input_var=input_var, title=input_var) - push!(plt_list, plt) -end -plot(plt_list..., layout=(1,nvars), size=(nvars*200, 200)) -``` - -![](plotting_files/figure-commonmark/cell-9-output-1.svg) - -### Visualizing Set Size - -To visualize the set size distribution, the [`TaijaPlotting.bar(conf_model::ConformalPrediction.ConformalModel, fitresult, X; label="", xtickfontsize=6, kwrgs...)`](@ref) can be used. For regression models the prediction interval widths are stratified into discrete bins.a - -``` julia -bar(mach.model, mach.fitresult, X) -``` - -![](plotting_files/figure-commonmark/cell-10-output-1.svg) - -``` julia -EvoTreeRegressor = @load EvoTreeRegressor pkg=EvoTrees -model = EvoTreeRegressor() -conf_model = conformal_model(model, method=:jackknife_plus) -mach = machine(conf_model, X, y) -fit!(mach) -``` - -``` julia -bar(mach.model, mach.fitresult, X) -``` - -![](plotting_files/figure-commonmark/cell-12-output-1.svg) - -## Classification - -``` julia -KNNClassifier = @load KNNClassifier pkg=NearestNeighborModels -model = KNNClassifier(;K=3) -``` - -### Visualizing Predictions - -#### Stacked Area Charts - -Stacked area charts can be used to visualize prediction sets for any conformal classifier.a - -``` julia -using MLJ -n_input = 4 -X, y = make_blobs(100, n_input) -``` - -``` julia -conf_model = conformal_model(model) -mach = machine(conf_model, X, y) -fit!(mach) -``` - -``` julia -plt_list = [] -for i in 1:n_input - plt = areaplot(mach.model, mach.fitresult, X, y; input_var=i, title="Input $i") - push!(plt_list, plt) -end -plot(plt_list..., size=(220*n_input,200), layout=(1, n_input)) -``` - -![](plotting_files/figure-commonmark/cell-16-output-1.svg) - -#### Contour Plots for Two-Dimensional Inputs - -For conformal classifiers with exactly two input variables, the [`TaijaPlotting.contourf(conf_model::ConformalPrediction.ConformalProbabilisticSet, fitresult, X, y; kwrgs...)`](@ref) method can be used to visualize conformal predictions in the two-dimensional feature space.a - -``` julia -using MLJ -X, y = make_blobs(100, 2) -``` - -``` julia -conf_model = conformal_model(model) -mach = machine(conf_model, X, y) -fit!(mach) -``` - -``` julia -p1 = contourf(mach.model, mach.fitresult, X, y) -p2 = contourf(mach.model, mach.fitresult, X, y; plot_set_size=true) -plot(p1, p2, size=(700,300)) -``` - -![](plotting_files/figure-commonmark/cell-19-output-1.svg) - -### Visualizing Set Size - -To visualize the set size distribution, the [`TaijaPlotting.bar(conf_model::ConformalPrediction.ConformalModel, fitresult, X; label="", xtickfontsize=6, kwrgs...)`](@ref) can be used. Recall that for more adaptive predictors the distribution of set sizes is typically spread out more widely, which reflects that “the procedure is effectively distinguishing between easy and hard inputs” (Angelopoulos and Bates 2021). This is desirable: when for a given sample it is difficult to make predictions, this should be reflected in the set size (or interval width in the regression case). Since ‘difficult’ lies on some spectrum that ranges from ‘very easy’ to ‘very difficult’ the set size should very across the spectrum of ‘empty set’ to ‘all labels included’. - -``` julia -X, y = make_moons(500; noise=0.15) -KNNClassifier = @load KNNClassifier pkg=NearestNeighborModels -model = KNNClassifier(;K=50) -``` - -``` julia -conf_model = conformal_model(model) -mach = machine(conf_model, X, y) -fit!(mach) -``` - -``` julia -p1 = contourf(mach.model, mach.fitresult, X, y; plot_set_size=true) -p2 = bar(mach.model, mach.fitresult, X) -plot(p1, p2, size=(700,300)) -``` - -![](plotting_files/figure-commonmark/cell-22-output-1.svg) - -``` julia -conf_model = conformal_model(model, method=:adaptive_inductive) -mach = machine(conf_model, X, y) -fit!(mach) -``` - -``` julia -p1 = contourf(mach.model, mach.fitresult, X, y; plot_set_size=true) -p2 = bar(mach.model, mach.fitresult, X) -plot(p1, p2, size=(700,300)) -``` - -![](plotting_files/figure-commonmark/cell-24-output-1.svg) - -Angelopoulos, Anastasios N., and Stephen Bates. 2021. “A Gentle Introduction to Conformal Prediction and Distribution-Free Uncertainty Quantification.” . diff --git a/src/conformal_models/transductive_regression.jl b/src/conformal_models/transductive_regression.jl index 4fed2328..23f9f01e 100644 --- a/src/conformal_models/transductive_regression.jl +++ b/src/conformal_models/transductive_regression.jl @@ -663,7 +663,7 @@ function MMI.predict(conf_model::JackknifePlusAbRegressor, fitresult, Xnew) end # Jackknife_plus_after_bootstrapping_minmax -"Constructor for `JackknifePlusAbPlusMinMaxRegressor`." +"Constructor for `JackknifePlusAbMinMaxRegressor`." mutable struct JackknifePlusAbMinMaxRegressor{Model<:Supervised} <: ConformalInterval model::Model coverage::AbstractFloat @@ -692,7 +692,7 @@ end @doc raw""" MMI.fit(conf_model::JackknifePlusMinMaxAbRegressor, verbosity, X, y) -For the [`JackknifePlusABMinMaxRegressor`](@ref) nonconformity scores are as, +For the [`JackknifePlusAbMinMaxRegressor`](@ref) nonconformity scores are as, `` S_i^{\text{J+MinMax}} = s(X_i, Y_i) = h(agg(\hat\mu_{B_{K(-i)}}(X_i)), Y_i), \ i \in \mathcal{D}_{\text{train}} @@ -767,7 +767,7 @@ function MMI.predict(conf_model::JackknifePlusAbMinMaxRegressor, fitresult, Xnew end # TimeSeries_Regressor_Ensemble_Batch_Prediction_Interval -"Constructor for `TimeSeriesRegressorEnsemble`." +"Constructor for `TimeSeriesRegressorEnsembleBatch`." mutable struct TimeSeriesRegressorEnsembleBatch{Model<:Supervised} <: ConformalInterval model::Model coverage::AbstractFloat @@ -868,9 +868,9 @@ end # Prediction @doc raw""" - MMI.predict(conf_model::TimeSeriesRegressorEnsemble, fitresult, Xnew) + MMI.predict(conf_model::TimeSeriesRegressorEnsembleBatch, fitresult, Xnew) -For the [`TimeSeriesRegressorEnsemble`](@ref) prediction intervals are computed as follows, +For the [`TimeSeriesRegressorEnsembleBatch`](@ref) prediction intervals are computed as follows, `` \hat{C}_{n,\alpha, B}^{J+ab}(X_{n+1}) = \left[ \hat{q}_{n, \alpha}^{-} \{\hat\mu_{agg(-i)}(X_{n+1}) - S_i^{\text{J+ab}} \}, \hat{q}_{n, \alpha}^{+} \{\hat\mu_{agg(-i)}(X_{n+1}) + S_i^{\text{J+ab}}\} \right] , i \in \mathcal{D}_{\text{train}} From 3a67c8ffed2a240ef86621890367931ffb55fbe1 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Thu, 9 Nov 2023 13:15:22 +0100 Subject: [PATCH 11/11] package updates --- docs/Manifest.toml | 134 ++++++++++++++++++--------------------------- 1 file changed, 54 insertions(+), 80 deletions(-) diff --git a/docs/Manifest.toml b/docs/Manifest.toml index c151080d..e7488d27 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -284,11 +284,6 @@ weakdeps = ["UnicodePlots"] [deps.CategoricalDistributions.extensions] UnivariateFiniteDisplayExt = "UnicodePlots" -[[deps.Chain]] -git-tree-sha1 = "8c4920235f6c561e401dfe569beb8b924adad003" -uuid = "8be319e6-bccf-4806-a6f7-6fae938471bc" -version = "0.5.0" - [[deps.ChainRules]] deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"] git-tree-sha1 = "710940598100496ad6cbb707e481c28186354197" @@ -440,10 +435,10 @@ uuid = "150eb455-5306-5404-9cee-2592286d6298" version = "0.6.3" [[deps.CounterfactualExplanations]] -deps = ["CSV", "CUDA", "CategoricalArrays", "Chain", "ChainRulesCore", "DataFrames", "DecisionTree", "Distributions", "EvoTrees", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "MLDatasets", "MLJBase", "MLJDecisionTreeInterface", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "PackageExtensionCompat", "Parameters", "Plots", "ProgressMeter", "Random", "Serialization", "SnoopPrecompile", "Statistics", "StatsBase", "Tables", "UMAP"] -git-tree-sha1 = "8393721ffa3c9be209a93eb154d0d9fe9ca187d5" +deps = ["CSV", "CUDA", "CategoricalArrays", "ChainRulesCore", "DataFrames", "DecisionTree", "Distributions", "EvoTrees", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "Logging", "MLDatasets", "MLJBase", "MLJDecisionTreeInterface", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "PackageExtensionCompat", "Parameters", "PrecompileTools", "ProgressMeter", "Random", "Serialization", "Statistics", "StatsBase", "Tables", "UUIDs", "cuDNN"] +git-tree-sha1 = "af4687806d81a3265173fad6250e3902eb659f37" uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0" -version = "0.1.15" +version = "0.1.31" [deps.CounterfactualExplanations.extensions] MPIExt = "MPI" @@ -1339,10 +1334,10 @@ uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" version = "1.3.1" [[deps.LaplaceRedux]] -deps = ["CSV", "Compat", "ComputationalResources", "DataFrames", "Flux", "LinearAlgebra", "MLJ", "MLJBase", "MLJFlux", "MLJModelInterface", "MLUtils", "Parameters", "ProgressMeter", "Random", "Statistics", "Tables", "Tullio", "Zygote"] -git-tree-sha1 = "ca7a96bd2be5066bb2378b42c0191c672811bfaa" +deps = ["Flux", "LinearAlgebra", "Parameters", "Plots", "Zygote"] +git-tree-sha1 = "a4adebbeafb96d0864b4833c254013f66dc6e0ee" uuid = "c52c1a26-f7c5-402b-80be-ba1e638ad478" -version = "0.1.3" +version = "0.1.2" [[deps.Latexify]] deps = ["Formatting", "InteractiveUtils", "LaTeXStrings", "MacroTools", "Markdown", "OrderedCollections", "Printf", "Requires"] @@ -1384,6 +1379,12 @@ git-tree-sha1 = "a560dd966b386ac9ae60bdd3a3d3a326062d3c3e" uuid = "8cdb02fc-e678-4876-92c5-9defec4f444e" version = "0.3.1" +[[deps.LearnAPI]] +deps = ["InteractiveUtils", "Statistics"] +git-tree-sha1 = "ec695822c1faaaa64cee32d0b21505e1977b4809" +uuid = "92ad9a40-7767-427a-9ee6-6e577f1266cb" +version = "0.1.0" + [[deps.LibCURL]] deps = ["LibCURL_jll", "MozillaCACerts_jll"] uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" @@ -1534,12 +1535,6 @@ weakdeps = ["CategoricalArrays"] [deps.LossFunctions.extensions] LossFunctionsCategoricalArraysExt = "CategoricalArrays" -[[deps.LsqFit]] -deps = ["Distributions", "ForwardDiff", "LinearAlgebra", "NLSolversBase", "OptimBase", "Random", "StatsBase"] -git-tree-sha1 = "00f475f85c50584b12268675072663dfed5594b2" -uuid = "2fda8390-95c7-5789-9bda-21331edee243" -version = "0.13.0" - [[deps.MAT]] deps = ["BufferedStreams", "CodecZlib", "HDF5", "SparseArrays"] git-tree-sha1 = "ed1cf0a322d78cee07718bed5fd945e2218c35a1" @@ -1570,16 +1565,20 @@ uuid = "64a0f543-368b-4a9a-827a-e71edb2a0b83" version = "0.4.4" [[deps.MLJ]] -deps = ["CategoricalArrays", "ComputationalResources", "Distributed", "Distributions", "LinearAlgebra", "MLJBase", "MLJEnsembles", "MLJFlow", "MLJIteration", "MLJModels", "MLJTuning", "OpenML", "Pkg", "ProgressMeter", "Random", "Reexport", "ScientificTypes", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "193f1f1ac77d91eabe1ac81ff48646b378270eef" +deps = ["CategoricalArrays", "ComputationalResources", "Distributed", "Distributions", "LinearAlgebra", "MLJBase", "MLJEnsembles", "MLJFlow", "MLJIteration", "MLJModels", "MLJTuning", "OpenML", "Pkg", "ProgressMeter", "Random", "Reexport", "ScientificTypes", "StatisticalMeasures", "Statistics", "StatsBase", "Tables"] +git-tree-sha1 = "58d17a367ee211ade6e53f83a9cc5adf9d26f833" uuid = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7" -version = "0.19.5" +version = "0.20.0" [[deps.MLJBase]] -deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LinearAlgebra", "LossFunctions", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "ScientificTypes", "Serialization", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "0b7307d1a7214ec3c0ba305571e713f9492ea984" +deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LearnAPI", "LinearAlgebra", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "Reexport", "ScientificTypes", "Serialization", "StatisticalMeasuresBase", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] +git-tree-sha1 = "6d433d34a1764324cf37a1ddc47dcc42ec05340f" uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d" -version = "0.21.14" +version = "1.0.1" +weakdeps = ["StatisticalMeasures"] + + [deps.MLJBase.extensions] + DefaultMeasuresExt = "StatisticalMeasures" [[deps.MLJDecisionTreeInterface]] deps = ["CategoricalArrays", "DecisionTree", "MLJModelInterface", "Random", "Tables"] @@ -1588,16 +1587,16 @@ uuid = "c6f25543-311c-4c74-83dc-3ea6d1015661" version = "0.4.0" [[deps.MLJEnsembles]] -deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Distributed", "Distributions", "MLJBase", "MLJModelInterface", "ProgressMeter", "Random", "ScientificTypesBase", "StatsBase"] -git-tree-sha1 = "95b306ef8108067d26dfde9ff3457d59911cc0d6" +deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Distributed", "Distributions", "MLJModelInterface", "ProgressMeter", "Random", "ScientificTypesBase", "StatisticalMeasuresBase", "StatsBase"] +git-tree-sha1 = "94403b2c8f692011df6731913376e0e37f6c0fe9" uuid = "50ed68f4-41fd-4504-931a-ed422449fee0" -version = "0.3.3" +version = "0.4.0" [[deps.MLJFlow]] deps = ["MLFlowClient", "MLJBase", "MLJModelInterface"] -git-tree-sha1 = "bceeeb648c9aa2fc6f65f957c688b164d30f2905" +git-tree-sha1 = "dc0de70a794c6d4c1aa4bde8196770c6b6e6b550" uuid = "7b7b8358-b45c-48ea-a8ef-7ca328ad328f" -version = "0.1.1" +version = "0.2.0" [[deps.MLJFlux]] deps = ["CategoricalArrays", "ColorTypes", "ComputationalResources", "Flux", "MLJModelInterface", "Metalhead", "ProgressMeter", "Random", "Statistics", "Tables"] @@ -1613,9 +1612,9 @@ version = "0.3.5" [[deps.MLJIteration]] deps = ["IterationControl", "MLJBase", "Random", "Serialization"] -git-tree-sha1 = "be6d5c71ab499a59e82d65e00a89ceba8732fcd5" +git-tree-sha1 = "991e10d4c8da49d534e312e8a4fbe56b7ac6f70c" uuid = "614be32b-d00c-4edb-bd02-1eb411ab5e55" -version = "0.5.1" +version = "0.6.0" [[deps.MLJLinearModels]] deps = ["DocStringExtensions", "IterativeSolvers", "LinearAlgebra", "LinearMaps", "MLJModelInterface", "Optim", "Parameters"] @@ -1648,10 +1647,10 @@ uuid = "33e4bacb-b9e2-458e-9a13-5d9a90b235fa" version = "0.1.6" [[deps.MLJTuning]] -deps = ["ComputationalResources", "Distributed", "Distributions", "LatinHypercubeSampling", "MLJBase", "ProgressMeter", "Random", "RecipesBase"] -git-tree-sha1 = "02688098bd77827b64ed8ad747c14f715f98cfc4" +deps = ["ComputationalResources", "Distributed", "Distributions", "LatinHypercubeSampling", "MLJBase", "ProgressMeter", "Random", "RecipesBase", "StatisticalMeasuresBase"] +git-tree-sha1 = "44dc126646a15018d7829f020d121b85b4def9bc" uuid = "03970b2e-30c4-11ea-3135-d1576263f10f" -version = "0.7.4" +version = "0.8.0" [[deps.MLStyle]] git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8" @@ -1703,10 +1702,10 @@ uuid = "d0879d2d-cac2-40c8-9cee-1863dc0c7391" version = "0.1.2" [[deps.MbedTLS]] -deps = ["Dates", "MbedTLS_jll", "MozillaCACerts_jll", "Random", "Sockets"] -git-tree-sha1 = "03a9b9718f5682ecb107ac9f7308991db4ce395b" +deps = ["Dates", "MbedTLS_jll", "MozillaCACerts_jll", "NetworkOptions", "Random", "Sockets"] +git-tree-sha1 = "f512dc13e64e96f703fd92ce617755ee6b5adf0f" uuid = "739be429-bea8-5141-9913-cc70e7f3736d" -version = "1.1.7" +version = "1.1.8" [[deps.MbedTLS_jll]] deps = ["Artifacts", "Libdl"] @@ -1814,12 +1813,6 @@ git-tree-sha1 = "eda490d06b9f7c00752ee81cfa451efe55521e21" uuid = "c020b1a1-e9b0-503a-9c33-f039bfc54a85" version = "1.0.0" -[[deps.NearestNeighborDescent]] -deps = ["DataStructures", "Distances", "Graphs", "Random", "Reexport", "SparseArrays"] -git-tree-sha1 = "b7d4bd2ab58f0c3a001fd6eedc2e0aac8e278152" -uuid = "dd2c4c9e-a32f-5b2f-b342-08c2f244fce8" -version = "0.3.6" - [[deps.NearestNeighborModels]] deps = ["Distances", "FillArrays", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "NearestNeighbors", "Statistics", "StatsBase", "Tables"] git-tree-sha1 = "e411143a8362926e4284a54e745972e939fbab78" @@ -1939,12 +1932,6 @@ git-tree-sha1 = "e3a6546c1577bfd701771b477b794a52949e7594" uuid = "429524aa-4258-5aef-a3af-852621145aeb" version = "1.7.6" -[[deps.OptimBase]] -deps = ["NLSolversBase", "Printf", "Reexport"] -git-tree-sha1 = "9cb1fee807b599b5f803809e85c81b582d2009d6" -uuid = "87e2bd06-a317-5318-96d9-3ecbac512eee" -version = "2.0.2" - [[deps.Optimisers]] deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] git-tree-sha1 = "c1fc26bab5df929a5172f296f25d7d08688fd25b" @@ -2417,12 +2404,6 @@ git-tree-sha1 = "2da10356e31327c7096832eb9cd86307a50b1eb6" uuid = "45858cf5-a6b0-47a3-bbea-62219f50df47" version = "0.1.3" -[[deps.SnoopPrecompile]] -deps = ["Preferences"] -git-tree-sha1 = "e760a70afdcd461cf01a575947738d359234665c" -uuid = "66db9d55-30c0-4569-8b51-7e840670fc0c" -version = "1.0.3" - [[deps.Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" @@ -2502,6 +2483,23 @@ git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d" uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" version = "1.4.2" +[[deps.StatisticalMeasures]] +deps = ["CategoricalArrays", "CategoricalDistributions", "Distributions", "LearnAPI", "LinearAlgebra", "MacroTools", "OrderedCollections", "PrecompileTools", "ScientificTypesBase", "StatisticalMeasuresBase", "Statistics", "StatsBase"] +git-tree-sha1 = "b58c7cc3d7de6c0d75d8437b81481af924970123" +uuid = "a19d573c-0a75-4610-95b3-7071388c7541" +version = "0.1.3" +weakdeps = ["LossFunctions", "ScientificTypes"] + + [deps.StatisticalMeasures.extensions] + LossFunctionsExt = "LossFunctions" + ScientificTypesExt = "ScientificTypes" + +[[deps.StatisticalMeasuresBase]] +deps = ["CategoricalArrays", "InteractiveUtils", "MLUtils", "MacroTools", "OrderedCollections", "PrecompileTools", "ScientificTypesBase", "Statistics"] +git-tree-sha1 = "17dfb22e2e4ccc9cd59b487dce52883e0151b4d3" +uuid = "c062fc1d-0d66-479b-b6ac-8b44719de4cc" +version = "0.1.1" + [[deps.StatisticalTraits]] deps = ["ScientificTypesBase"] git-tree-sha1 = "30b9236691858e13f167ce829490a68e1a597782" @@ -2643,9 +2641,9 @@ version = "1.11.1" [[deps.TaijaPlotting]] deps = ["CategoricalArrays", "ConformalPrediction", "CounterfactualExplanations", "DataAPI", "Distributions", "Flux", "LaplaceRedux", "LinearAlgebra", "MLJBase", "ManifoldLearning", "MultivariateStats", "NaturalSort", "NearestNeighborModels", "Plots"] -git-tree-sha1 = "f86ed2cbb9e9a08b2fe19f44d6c0a1266d05a2f4" +git-tree-sha1 = "1202acdbf670f1682f0f5a3abdfcd8f5ce3e0df4" uuid = "bd7198b4-c7d6-400c-9bab-9a24614b0240" -version = "1.0.3" +version = "1.0.4" [[deps.Tar]] deps = ["ArgTools", "SHA"] @@ -2738,35 +2736,11 @@ git-tree-sha1 = "eae1bb484cd63b36999ee58be2de6c178105112f" uuid = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775" version = "0.1.8" -[[deps.Tullio]] -deps = ["DiffRules", "LinearAlgebra", "Requires"] -git-tree-sha1 = "6d476962ba4e435d7f4101a403b1d3d72afe72f3" -uuid = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc" -version = "0.3.7" - - [deps.Tullio.extensions] - TullioCUDAExt = "CUDA" - TullioChainRulesCoreExt = "ChainRulesCore" - TullioFillArraysExt = "FillArrays" - TullioTrackerExt = "Tracker" - - [deps.Tullio.weakdeps] - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" - Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" - [[deps.TupleTools]] git-tree-sha1 = "155515ed4c4236db30049ac1495e2969cc06be9d" uuid = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" version = "1.4.3" -[[deps.UMAP]] -deps = ["Arpack", "Distances", "LinearAlgebra", "LsqFit", "NearestNeighborDescent", "Random", "SparseArrays"] -git-tree-sha1 = "accad220f075445f68caa6488be728957a5d82d6" -uuid = "c4f8c510-2410-5be4-91d7-4fbaeb39457e" -version = "0.1.10" - [[deps.URIs]] git-tree-sha1 = "67db6cc7b3821e19ebe75791a9dd19c9b1188f2b" uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4"