diff --git a/Project.toml b/Project.toml index f243bfa..040764f 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.1.8" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" +ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -16,8 +17,11 @@ MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" NaturalSort = "c020b1a1-e9b0-503a-9c33-f039bfc54a85" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] CategoricalArrays = "0.10" diff --git a/docs/Manifest.toml b/docs/Manifest.toml index 37d10a6..2d0743e 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -1,8 +1,8 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.9.2" +julia_version = "1.9.0" manifest_format = "2.0" -project_hash = "e1685a29d6d370eab88233cc5ac9d849e2f3994f" +project_hash = "006313002b3cecd4f4a1095930b3c91deb25115f" [[deps.ANSIColoredPrinters]] git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c" @@ -17,13 +17,14 @@ version = "1.4.1" [[deps.AbstractFFTs]] deps = ["LinearAlgebra"] -git-tree-sha1 = "cad4c758c0038eea30394b1b671526921ca85b21" +git-tree-sha1 = "d92ad398961a3ed262d8bf04a1a2b8340f915fef" uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" -version = "1.4.0" -weakdeps = ["ChainRulesCore"] +version = "1.5.0" +weakdeps = ["ChainRulesCore", "Test"] [deps.AbstractFFTs.extensions] AbstractFFTsChainRulesCoreExt = "ChainRulesCore" + AbstractFFTsTestExt = "Test" 
[[deps.AbstractPlutoDingetjes]] deps = ["Pkg"] @@ -184,10 +185,10 @@ uuid = "e1450e63-4bb3-523b-b2a4-4ffa8c0fd77d" version = "1.2.0" [[deps.BytePairEncoding]] -deps = ["StructWalk", "TextEncodeBase", "Unicode"] -git-tree-sha1 = "40ee2783de5efc5b478e1bb828b750ad8ce1714f" +deps = ["DoubleArrayTries", "StructWalk", "TextEncodeBase", "Unicode"] +git-tree-sha1 = "91752c465dfbdd55837a18f9aa9e6d20899658e9" uuid = "a4280ba5-8788-555a-8ca8-4a8c3d966a71" -version = "0.3.1" +version = "0.3.2" [[deps.Bzip2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -315,9 +316,9 @@ version = "0.1.12" [[deps.Clustering]] deps = ["Distances", "LinearAlgebra", "NearestNeighbors", "Printf", "Random", "SparseArrays", "Statistics", "StatsBase"] -git-tree-sha1 = "42fe66dbc8f1d09a44aa87f18d26926d06a35f84" +git-tree-sha1 = "b86ac2c5543660d238957dbde5ac04520ae977a7" uuid = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5" -version = "0.15.3" +version = "0.15.4" [[deps.CodecZlib]] deps = ["TranscodingStreams", "Zlib_jll"] @@ -327,9 +328,9 @@ version = "0.7.2" [[deps.ColorSchemes]] deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"] -git-tree-sha1 = "dd3000d954d483c1aad05fe1eb9e6a715c97013e" +git-tree-sha1 = "d9a8f86737b665e15a9641ecbac64deef9ce6724" uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4" -version = "3.22.0" +version = "3.23.0" [[deps.ColorTypes]] deps = ["FixedPointNumbers", "Random"] @@ -366,9 +367,9 @@ version = "0.3.0" [[deps.Compat]] deps = ["UUIDs"] -git-tree-sha1 = "5ce999a19f4ca23ea484e92a1774a61b8ca4cf8e" +git-tree-sha1 = "e460f044ca8b99be31d35fe54fc33a5c33dd8ed7" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.8.0" +version = "4.9.0" weakdeps = ["Dates", "LinearAlgebra"] [deps.Compat.extensions] @@ -377,7 +378,7 @@ weakdeps = ["Dates", "LinearAlgebra"] [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.0.5+0" +version = "1.0.2+0" 
[[deps.CompositionsBase]] git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" @@ -403,7 +404,7 @@ version = "2.2.1" [[deps.ConformalPrediction]] deps = ["CategoricalArrays", "ChainRules", "Flux", "LazyArtifacts", "LinearAlgebra", "MLJBase", "MLJEnsembles", "MLJFlux", "MLJModelInterface", "MLUtils", "NaturalSort", "Plots", "Serialization", "StatsBase"] -path = ".." +git-tree-sha1 = "d4ce78a2a13fa4880daf7051a39f888509788d7d" uuid = "98bfc277-1877-43dc-819b-a3e38c30242f" version = "0.1.8" @@ -470,9 +471,9 @@ version = "1.6.1" [[deps.DataStructures]] deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "cf25ccb972fec4e4817764d01c82386ae94f77b4" +git-tree-sha1 = "3dbd312d370723b6bb43ba9d02fc36abade4518d" uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.14" +version = "0.18.15" [[deps.DataValueInterfaces]] git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" @@ -527,10 +528,10 @@ deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[deps.Distributions]] -deps = ["FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns", "Test"] -git-tree-sha1 = "e76a3281de2719d7c81ed62c6ea7057380c87b1d" +deps = ["FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns", "Test"] +git-tree-sha1 = "938fe2981db009f531b6332e31c58e9584a2f9bd" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.25.98" +version = "0.25.100" [deps.Distributions.extensions] DistributionsChainRulesCoreExt = "ChainRulesCore" @@ -571,9 +572,9 @@ version = "0.6.8" [[deps.DynamicExpressions]] deps = ["Compat", "LinearAlgebra", "LoopVectorization", "MacroTools", "PackageExtensionCompat", "PrecompileTools", "Printf", "Random", "Reexport", "TOML"] -git-tree-sha1 = "a98488649931b24f320dd737a9adf6eb49b5deda" +git-tree-sha1 = 
"4f59922b1b80847959c5b1987437d2fcd83c8788" uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b" -version = "0.11.0" +version = "0.12.3" [deps.DynamicExpressions.extensions] DynamicExpressionsSymbolicUtilsExt = "SymbolicUtils" @@ -584,10 +585,10 @@ version = "0.11.0" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [[deps.DynamicQuantities]] -deps = ["Compat", "LinearAlgebra", "Requires", "SparseArrays", "Tricks"] -git-tree-sha1 = "c9c36325b0b29b479725d0d8b368081e4b053933" +deps = ["Compat", "LinearAlgebra", "PackageExtensionCompat", "SparseArrays", "Tricks"] +git-tree-sha1 = "2a8346673552625beb26680d20b4cfd9285167d0" uuid = "06fc5a27-2a28-4c7c-a15d-362465fb6821" -version = "0.6.2" +version = "0.6.3" weakdeps = ["ScientificTypes", "ScientificTypesBase", "Unitful"] [deps.DynamicQuantities.extensions] @@ -608,9 +609,9 @@ version = "0.3.0" [[deps.EvoTrees]] deps = ["BSON", "CUDA", "CategoricalArrays", "Distributions", "MLJModelInterface", "NetworkLayout", "Random", "RecipesBase", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "1b63fdc0acad47c3203398171c138835c1c40d69" +git-tree-sha1 = "1b418518c0eb1fd1ef0a6d0bfc8051e6abb1232b" uuid = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5" -version = "0.15.0" +version = "0.15.2" [[deps.ExceptionUnwrapping]] deps = ["Test"] @@ -625,9 +626,9 @@ uuid = "2e619515-83b5-522b-bb60-26c02a35a201" version = "2.5.0+0" [[deps.ExprTools]] -git-tree-sha1 = "c1d06d129da9f55715c6c212866f5b1bddc5fa00" +git-tree-sha1 = "27415f162e6028e81c72b82ef756bf321213b6ec" uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" -version = "0.1.9" +version = "0.1.10" [[deps.Extents]] git-tree-sha1 = "5e1e4c53fa39afe63a7d356e30452249365fba99" @@ -759,9 +760,9 @@ version = "0.4.2" [[deps.ForwardDiff]] deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] -git-tree-sha1 = "00e252f4d706b3d55a8863432e742bf5717b498d" +git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" 
uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.35" +version = "0.10.36" weakdeps = ["StaticArrays"] [deps.ForwardDiff.extensions] @@ -932,9 +933,9 @@ version = "0.3.1" [[deps.HostCPUFeatures]] deps = ["BitTwiddlingConvenienceFunctions", "IfElse", "Libdl", "Static"] -git-tree-sha1 = "d38bd0d9759e3c6cfa19bdccc314eccf8ce596cc" +git-tree-sha1 = "eb8fed28f4994600e29beef49744639d985a04b2" uuid = "3e5b6fbb-0976-4d2c-9146-d79de83f2fb0" -version = "0.1.15" +version = "0.1.16" [[deps.HuggingFaceApi]] deps = ["Dates", "Downloads", "JSON3", "LibGit2", "OhMyArtifacts", "Pkg", "SHA"] @@ -944,9 +945,9 @@ version = "0.1.0" [[deps.HypergeometricFunctions]] deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] -git-tree-sha1 = "83e95aaab9dc184a6dcd9c4c52aa0dc26cd14a1d" +git-tree-sha1 = "f218fe3736ddf977e0e772bc9a586b2383da2685" uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a" -version = "0.3.21" +version = "0.3.23" [[deps.Hyperscript]] deps = ["Test"] @@ -1021,9 +1022,9 @@ version = "0.2.17" [[deps.ImageFiltering]] deps = ["CatIndices", "ComputationalResources", "DataStructures", "FFTViews", "FFTW", "ImageBase", "ImageCore", "LinearAlgebra", "OffsetArrays", "PrecompileTools", "Reexport", "SparseArrays", "StaticArrays", "Statistics", "TiledIteration"] -git-tree-sha1 = "c371a39622dc3b941ffd7c00e6b519d63b3f3f06" +git-tree-sha1 = "432ae2b430a18c58eb7eca9ef8d0f2db90bc749c" uuid = "6a3955dd-da59-5b1f-98d4-e7296123deb5" -version = "0.7.7" +version = "0.7.8" [[deps.ImageIO]] deps = ["FileIO", "IndirectArrays", "JpegTurbo", "LazyModules", "Netpbm", "OpenEXR", "PNGFiles", "QOI", "Sixel", "TiffImages", "UUIDs"] @@ -1120,9 +1121,9 @@ version = "0.1.5" [[deps.IntelOpenMP_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "0cb9352ef2e01574eeebdb102948a58740dcaf83" +git-tree-sha1 = "ad37c091f7d7daf900963171600d7c1c5c3ede32" uuid = "1d5cc7b8-4909-519e-a0f8-d0f5ad9712d0" -version = "2023.1.0+0" +version = "2023.2.0+0" 
[[deps.InteractiveUtils]] deps = ["Markdown"] @@ -1141,10 +1142,14 @@ uuid = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" version = "0.14.7" [[deps.IntervalSets]] -deps = ["Dates", "Random", "Statistics"] -git-tree-sha1 = "16c0cc91853084cb5f58a78bd209513900206ce6" +deps = ["Dates", "Random"] +git-tree-sha1 = "8e59ea773deee525c99a8018409f64f19fb719e6" uuid = "8197267c-284f-5f27-9208-e0e47529a953" -version = "0.7.4" +version = "0.7.7" +weakdeps = ["Statistics"] + + [deps.IntervalSets.extensions] + IntervalSetsStatisticsExt = "Statistics" [[deps.InvertedIndices]] git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038" @@ -1204,9 +1209,9 @@ version = "0.21.4" [[deps.JSON3]] deps = ["Dates", "Mmap", "Parsers", "PrecompileTools", "StructTypes", "UUIDs"] -git-tree-sha1 = "5b62d93f2582b09e469b3099d839c2d2ebf5066d" +git-tree-sha1 = "95220473901735a0f4df9d1ca5b171b568b2daa3" uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" -version = "1.13.1" +version = "1.13.2" [[deps.JpegTurbo]] deps = ["CEnum", "FileIO", "ImageCore", "JpegTurbo_jll", "TOML"] @@ -1462,9 +1467,9 @@ weakdeps = ["ChainRulesCore", "ForwardDiff", "SpecialFunctions"] [[deps.LossFunctions]] deps = ["Markdown", "Requires", "Statistics"] -git-tree-sha1 = "065084a6e63bb30b622b46c613a8f61116787174" +git-tree-sha1 = "c2b72b61d2e3489b1f9cae3403226b21ec90c943" uuid = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" -version = "0.10.1" +version = "0.11.0" weakdeps = ["CategoricalArrays"] [deps.LossFunctions.extensions] @@ -1483,9 +1488,9 @@ version = "0.1.4" [[deps.MKL_jll]] deps = ["Artifacts", "IntelOpenMP_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] -git-tree-sha1 = "154d7aaa82d24db6d8f7e4ffcfe596f40bff214b" +git-tree-sha1 = "eb006abbd7041c28e0d16260e50a24f8f9104913" uuid = "856f044c-d86e-5d09-b602-aeab76dc8ba7" -version = "2023.1.0+0" +version = "2023.2.0+0" [[deps.MLDatasets]] deps = ["CSV", "Chemfiles", "DataDeps", "DataFrames", "DelimitedFiles", "FileIO", "FixedPointNumbers", "GZip", "Glob", "HDF5", "ImageShow", "JLD2", 
"JSON3", "LazyModules", "MAT", "MLUtils", "NPZ", "Pickle", "Printf", "Requires", "SparseArrays", "Statistics", "Tables"] @@ -1501,9 +1506,9 @@ version = "0.19.2" [[deps.MLJBase]] deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LinearAlgebra", "LossFunctions", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "ScientificTypes", "Serialization", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "9094381ad079dde43c4c74a2f71926232f11cb12" +git-tree-sha1 = "2c9d6b9c627a80f6e6acbc6193026f455581fd04" uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d" -version = "0.21.12" +version = "0.21.13" [[deps.MLJDecisionTreeInterface]] deps = ["CategoricalArrays", "DecisionTree", "MLJModelInterface", "Random", "Tables"] @@ -1549,9 +1554,9 @@ version = "1.8.0" [[deps.MLJModels]] deps = ["CategoricalArrays", "CategoricalDistributions", "Combinatorics", "Dates", "Distances", "Distributions", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "Markdown", "OrderedCollections", "Parameters", "Pkg", "PrettyPrinting", "REPL", "Random", "RelocatableFolders", "ScientificTypes", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "8f2cf0a7147d370d0de402d43f6de0d3473fcd5e" +git-tree-sha1 = "2b49f04f70266a2b040eb46ece157c4f5c1b0c13" uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7" -version = "0.16.9" +version = "0.16.10" [[deps.MLJMultivariateStatsInterface]] deps = ["CategoricalDistributions", "Distances", "LinearAlgebra", "MLJModelInterface", "MultivariateStats", "StatsBase"] @@ -1750,9 +1755,9 @@ version = "1.2.0" [[deps.NeuralAttentionlib]] deps = ["Adapt", "CUDA", "ChainRulesCore", "GPUArrays", "GPUArraysCore", "LinearAlgebra", "NNlib", "NNlibCUDA", "Requires", "Static"] -git-tree-sha1 = "5ee110f3d54e0f29daacc3bdde01b638bf05b9bc" +git-tree-sha1 = 
"dab54e810d7d9159c73d3f8b43de9e4b98286517" uuid = "12afc1b8-fad6-47e1-9132-84abc478905f" -version = "0.2.10" +version = "0.2.11" [[deps.Observables]] git-tree-sha1 = "6862738f9796b3edc1c09d0890afce4eca9e7e93" @@ -1819,9 +1824,9 @@ version = "1.4.1" [[deps.OpenSSL_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "1aa4b74f80b01c6bc2b89992b861b5f210e665b5" +git-tree-sha1 = "bbb5c2115d63c2f1451cb70e5ef75e8fe4707019" uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" -version = "1.1.21+0" +version = "1.1.22+0" [[deps.OpenSpecFun_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] @@ -1870,9 +1875,9 @@ uuid = "f57f5aa1-a3ce-4bc8-8ab9-96f992907883" version = "0.4.0" [[deps.PackageExtensionCompat]] -git-tree-sha1 = "32f3d52212a8d1c5d589a58851b1f04c97339110" +git-tree-sha1 = "f9b1e033c2b1205cf30fd119f4e50881316c1923" uuid = "65ce6f38-6b18-4e1d-a461-8949797d7930" -version = "1.0.0" +version = "1.0.1" weakdeps = ["Requires", "TOML"] [[deps.PaddedViews]] @@ -1889,9 +1894,9 @@ version = "0.12.3" [[deps.Parsers]] deps = ["Dates", "PrecompileTools", "UUIDs"] -git-tree-sha1 = "4b2e829ee66d4218e0cef22c0a64ee37cf258c29" +git-tree-sha1 = "716e24b21538abc91f6205fd1d8363f39b442851" uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "2.7.1" +version = "2.7.2" [[deps.PartialFunctions]] git-tree-sha1 = "b3901ea034cfd8aae57a2fa0dde0b0ea18bad1cb" @@ -1924,7 +1929,7 @@ version = "0.42.2+0" [[deps.Pkg]] deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.9.2" +version = "1.9.0" [[deps.PkgVersion]] deps = ["Pkg"] @@ -2459,9 +2464,9 @@ version = "5.10.1+6" [[deps.SymbolicRegression]] deps = ["Compat", "Dates", "Distributed", "DynamicExpressions", "DynamicQuantities", "LineSearches", "LossFunctions", "MLJModelInterface", "MacroTools", "Optim", 
"PackageExtensionCompat", "Pkg", "PrecompileTools", "Printf", "ProgressBars", "Random", "Reexport", "SpecialFunctions", "StatsBase", "TOML", "Tricks"] -git-tree-sha1 = "c6eb021000ff7bf49e96d62e6d75070149cf2323" +git-tree-sha1 = "83d12323cc7cd5b9800cb0c27e5a7b0fdda58438" uuid = "8254be44-1295-4e6a-a16d-46603ac705cb" -version = "0.21.2" +version = "0.22.2" [deps.SymbolicRegression.extensions] SymbolicRegressionJSON3Ext = "JSON3" @@ -2567,9 +2572,9 @@ version = "0.4.78" [[deps.Transformers]] deps = ["Base64", "BytePairEncoding", "CUDA", "ChainRulesCore", "DataDeps", "DataStructures", "Dates", "DelimitedFiles", "DoubleArrayTries", "Fetch", "FillArrays", "Flux", "FuncPipelines", "Functors", "HTTP", "HuggingFaceApi", "JSON3", "LightXML", "LinearAlgebra", "Mmap", "NNlib", "NNlibCUDA", "NeuralAttentionlib", "Pickle", "Pkg", "PrimitiveOneHot", "Random", "SHA", "Static", "Statistics", "StringViews", "StructWalk", "TextEncodeBase", "Unicode", "ValSplit", "WordTokenizers", "Zygote"] -git-tree-sha1 = "35b63543a154cea7e9068f45e67c5fdb7467f2ed" +git-tree-sha1 = "efa5d0441f3f9e6f3e9e5f64e5dacb736edd13ad" uuid = "21ca0261-441d-5938-ace7-c90938fde4d4" -version = "0.2.6" +version = "0.2.7" [[deps.Tricks]] git-tree-sha1 = "aadb748be58b492045b4f56166b5188aa63ce549" @@ -2582,9 +2587,9 @@ uuid = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" version = "1.3.0" [[deps.URIs]] -git-tree-sha1 = "074f993b0ca030848b897beff716d93aca60f06a" +git-tree-sha1 = "b7a5e99f24892b6824a954199a45e9ffcc1c70f0" uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" -version = "1.4.2" +version = "1.5.0" [[deps.UUIDs]] deps = ["Random", "SHA"] @@ -2627,9 +2632,9 @@ version = "3.6.0" [[deps.Unitful]] deps = ["Dates", "LinearAlgebra", "Random"] -git-tree-sha1 = "c4d2a349259c8eba66a00a540d550f122a3ab228" +git-tree-sha1 = "64eb17acef1d9734cf09967539818f38093d9b35" uuid = "1986cc42-f94f-5a68-af5c-568840ba703d" -version = "1.15.0" +version = "1.16.2" [deps.Unitful.extensions] ConstructionBaseUnitfulExt = "ConstructionBase" @@ 
-2871,9 +2876,9 @@ version = "1.5.5+0" [[deps.Zygote]] deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "5be3ddb88fc992a7d8ea96c3f10a49a7e98ebc7b" +git-tree-sha1 = "e2fe78907130b521619bc88408c859a472c4172b" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.62" +version = "0.6.63" [deps.Zygote.extensions] ZygoteColorsExt = "Colors" @@ -2918,7 +2923,7 @@ version = "0.15.1+0" [[deps.libblastrampoline_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.8.0+0" +version = "5.7.0+0" [[deps.libfdk_aac_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] diff --git a/docs/Project.toml b/docs/Project.toml index 0db2a5b..c8b9ecf 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -23,6 +23,7 @@ MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" MLJMultivariateStatsInterface = "1b6a4a23-ba22-4f51-9698-8599985d3728" MLJNaiveBayesInterface = "33e4bacb-b9e2-458e-9a13-5d9a90b235fa" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Measures = "442fdcdd-2543-5da2-b0f3-8c86c306513e" NaiveBayes = "9bbee03b-0db5-5f46-924f-b5c9c21b8c60" NaturalSort = "c020b1a1-e9b0-503a-9c33-f039bfc54a85" diff --git a/docs/pluto/intro.jl b/docs/pluto/intro.jl index 996cf91..f664571 100644 --- a/docs/pluto/intro.jl +++ b/docs/pluto/intro.jl @@ -7,7 +7,14 @@ using InteractiveUtils # This Pluto notebook uses @bind for interactivity. When running this notebook outside of Pluto, the following 'mock version' of @bind gives bound variables a default value (instead of an error). 
macro bind(def, element) quote - local iv = try Base.loaded_modules[Base.PkgId(Base.UUID("6e696c72-6542-2067-7265-42206c756150"), "AbstractPlutoDingetjes")].Bonds.initial_value catch; b -> missing; end + local iv = try + Base.loaded_modules[Base.PkgId( + Base.UUID("6e696c72-6542-2067-7265-42206c756150"), + "AbstractPlutoDingetjes", + )].Bonds.initial_value + catch + b -> missing + end local el = $(esc(element)) global $(esc(def)) = Core.applicable(Base.get, el) ? Base.get(el) : iv(el) el @@ -18,19 +25,19 @@ end # ╠═╡ show_logs = false begin using ConformalPrediction - using ConformalPrediction: set_size + using ConformalPrediction: set_size using Distributions using EvoTrees: EvoTreeRegressor - using LinearAlgebra: norm + using LinearAlgebra: norm using MLJBase - using MLJFlux: NeuralNetworkRegressor + using MLJFlux: NeuralNetworkRegressor using MLJLinearModels using MLJModels using NearestNeighborModels: KNNRegressor using Plots using PlutoUI - using Random - using StatsBase: quantile + using Random + using StatsBase: quantile end; # ╔═╡ bc0d7575-dabd-472d-a0ce-db69d242ced8 @@ -46,7 +53,7 @@ Let's start by loading the necessary packages: # helper functions begin - # UI stuff + # UI stuff function multi_slider(vals::Dict; title="") return PlutoUI.combine() do Child inputs = [ @@ -64,53 +71,57 @@ begin end end - # SCP illustration: - function illustrate_scp(;cov=0.9, xcord=0.0, ycord=0.0) - # Data: - n = 250 - D = 5 - X, y = make_blobs(n, 2; centers=D, cluster_std=2.0) - train, test = partition(eachindex(y), 0.8, shuffle=true) - - # Model: - KNNClassifier = @load KNNClassifier pkg=NearestNeighborModels - model = KNNClassifier(;K=10) - - # Training: - conf_model = conformal_model(model; coverage=cov) - mach = machine(conf_model, X, y) - fit!(mach, rows=train) - - # Test set: - s_test = predict(mach, rows=test) - Xtest = MLJBase.matrix(X)[test,:] - i_test = argmin(map(X -> norm(X - [xcord,ycord]), eachrow(Xtest))) - x1, x2 = Xtest[i_test,:] - y_test = y[test][i_test] 
- - # Plotting: - p1 = contourf(mach.model, mach.fitresult, X, y, - cbar=false) - scatter!([x1], [x2], label="Test point", color=:yellow, ms=10, marker=:star) - s = mach.model.scores[:calibration] - n_cal = length(s) - cov = mach.model.coverage - x = 1:n_cal - p2 = bar(x, s, title="Cal. scores: 1 - p̂[y]", label="") - - p3 = bar(x, sort(s), label="", title="(1-α) quantile") - q̂ = quantile(s, cov) - hline!([q̂], label="q̂", lw=3, ls=:dash, color=2) - - p̂ = pdf.(predict(mach.model.model, mach.fitresult, [x1 x2])[1], classes(y)) - p4 = bar(1:D, 1 .- p̂, label="", - title="Test scores: 1-p̂",ylim=(0.0,1.0), - alpha=map(x -> x < q̂ ? 1.0 : 0.2, 1 .- p̂), - linecolor=map(y -> y == y_test ? :yellow : :black ,1:D), - lw = map(y -> y == y_test ? 5 : 1 ,1:D),) - hline!([q̂], label="q̂", lw=3, ls=:dash, color=2) - plot(p1,p2,p3,p4,layout=(2,2),size=(550,500),dpi=300) - end + # SCP illustration: + function illustrate_scp(; cov=0.9, xcord=0.0, ycord=0.0) + # Data: + n = 250 + D = 5 + X, y = make_blobs(n, 2; centers=D, cluster_std=2.0) + train, test = partition(eachindex(y), 0.8; shuffle=true) + + # Model: + KNNClassifier = @load KNNClassifier pkg = NearestNeighborModels + model = KNNClassifier(; K=10) + + # Training: + conf_model = conformal_model(model; coverage=cov) + mach = machine(conf_model, X, y) + fit!(mach; rows=train) + + # Test set: + s_test = predict(mach; rows=test) + Xtest = MLJBase.matrix(X)[test, :] + i_test = argmin(map(X -> norm(X - [xcord, ycord]), eachrow(Xtest))) + x1, x2 = Xtest[i_test, :] + y_test = y[test][i_test] + + # Plotting: + p1 = contourf(mach.model, mach.fitresult, X, y; cbar=false) + scatter!([x1], [x2]; label="Test point", color=:yellow, ms=10, marker=:star) + s = mach.model.scores[:calibration] + n_cal = length(s) + cov = mach.model.coverage + x = 1:n_cal + p2 = bar(x, s; title="Cal. 
scores: 1 - p̂[y]", label="") + + p3 = bar(x, sort(s); label="", title="(1-α) quantile") + q̂ = quantile(s, cov) + hline!([q̂]; label="q̂", lw=3, ls=:dash, color=2) + + p̂ = pdf.(predict(mach.model.model, mach.fitresult, [x1 x2])[1], classes(y)) + p4 = bar( + 1:D, + 1 .- p̂; + label="", + title="Test scores: 1-p̂", + ylim=(0.0, 1.0), + alpha=map(x -> x < q̂ ? 1.0 : 0.2, 1 .- p̂), + linecolor=map(y -> y == y_test ? :yellow : :black, 1:D), + lw=map(y -> y == y_test ? 5 : 1, 1:D), + ) + hline!([q̂]; label="q̂", lw=3, ls=:dash, color=2) + return plot(p1, p2, p3, p4; layout=(2, 2), size=(550, 500), dpi=300) + end end; # ╔═╡ be8b2fbb-3b3d-496e-9041-9b8f50872350 @@ -145,20 +156,17 @@ The test point is chosen based on the $x$- and $y$-coordinate that can be specif # ╔═╡ ebd25f7a-7cd7-4578-8bd0-1332dd5bc47e begin illu_dict = Dict( - "Coverage" => (0.0:0.05:1.0, 0.9), - "x" => (-20.0:1.0:20.0, 0.0), - "y" => (-20.0:1.0:20.0, 0.0) + "Coverage" => (0.0:0.05:1.0, 0.9), + "x" => (-20.0:1.0:20.0, 0.0), + "y" => (-20.0:1.0:20.0, 0.0), ) @bind illu_specs multi_slider(illu_dict, title="") end # ╔═╡ e1115c04-42c8-4b8a-8845-fc55f63defbf begin - Random.seed!(123) - illustrate_scp(; - cov=illu_specs.Coverage, - xcord=illu_specs.x, - ycord=illu_specs.y) + Random.seed!(123) + illustrate_scp(; cov=illu_specs.Coverage, xcord=illu_specs.x, ycord=illu_specs.y) end # ╔═╡ b47b9bd5-f4c1-439c-b1a8-ef042aa6adc6 @@ -239,8 +247,8 @@ train, test = partition(eachindex(y), 0.4, 0.4; shuffle=true); # ╔═╡ 698b1429-f478-45ac-8799-d73f6a0ab869 begin - model_dict = tested_atomic_models[:regression] - model_dict[:neural_network] = :(@load NeuralNetworkRegressor pkg = MLJFlux) + model_dict = tested_atomic_models[:regression] + model_dict[:neural_network] = :(@load NeuralNetworkRegressor pkg = MLJFlux) end; # ╔═╡ a34b8c07-08e0-4a0e-a0f9-8054b41b038b @@ -359,13 +367,13 @@ In particular, we will call `evaluate!` on our conformal model using `emp_covera # ╔═╡ d1140af9-608a-4669-9595-aee72ffbaa46 begin - holdout = 
Holdout(fraction_train=0.8) # pro tip: change this to CV + holdout = Holdout(; fraction_train=0.8) # pro tip: change this to CV model_evaluation = evaluate!( - _mach; - operation=MLJBase.predict, - measure=emp_coverage, - verbosity=0, - resampling=holdout, + _mach; + operation=MLJBase.predict, + measure=emp_coverage, + verbosity=0, + resampling=holdout, ) println("Empirical coverage: $(round(model_evaluation.measurement[1], digits=3))") end @@ -409,19 +417,25 @@ Below you can test out other available regression methods. While you may not be # ╔═╡ 45212d6a-2a09-4a6e-aa39-16cc659d1e18 begin - # Predictions: - new_conf_model = conformal_model(model; method=conf_model_name, coverage=new_cov) - new_mach = machine(new_conf_model, X, y) - MLJBase.fit!(new_mach; rows=train, verbosity=0) - p1 = plot(new_mach.model, new_mach.fitresult, Xtest, ytest; - zoom=0, observed_lab="Test points", title="Predictions") - plot!(p1, xrange, @.(f(xrange)); - lw=2, ls=:dash, colour=:black, label="Ground truth") - - # Interval width: - p2 = bar(new_mach.model, new_mach.fitresult, Xtest, title="Interval Width") - - plot(p1, p2, size=(1000,400)) + # Predictions: + new_conf_model = conformal_model(model; method=conf_model_name, coverage=new_cov) + new_mach = machine(new_conf_model, X, y) + MLJBase.fit!(new_mach; rows=train, verbosity=0) + p1 = plot( + new_mach.model, + new_mach.fitresult, + Xtest, + ytest; + zoom=0, + observed_lab="Test points", + title="Predictions", + ) + plot!(p1, xrange, @.(f(xrange)); lw=2, ls=:dash, colour=:black, label="Ground truth") + + # Interval width: + p2 = bar(new_mach.model, new_mach.fitresult, Xtest; title="Interval Width") + + plot(p1, p2; size=(1000, 400)) end # ╔═╡ 74444c01-1a0a-47a7-9b14-749946614f07 diff --git a/docs/pluto/jcon2023.jl b/docs/pluto/jcon2023.jl index 0dd781c..6175335 100644 --- a/docs/pluto/jcon2023.jl +++ b/docs/pluto/jcon2023.jl @@ -10,7 +10,7 @@ begin using Distributions using EvoTrees: EvoTreeRegressor using MLJBase - using MLJFlux: 
NeuralNetworkRegressor + using MLJFlux: NeuralNetworkRegressor using MLJLinearModels using MLJModels using NearestNeighborModels: KNNRegressor diff --git a/docs/pluto/understanding_coverage.jl b/docs/pluto/understanding_coverage.jl index 9132887..5c38505 100644 --- a/docs/pluto/understanding_coverage.jl +++ b/docs/pluto/understanding_coverage.jl @@ -7,7 +7,14 @@ using InteractiveUtils # This Pluto notebook uses @bind for interactivity. When running this notebook outside of Pluto, the following 'mock version' of @bind gives bound variables a default value (instead of an error). macro bind(def, element) quote - local iv = try Base.loaded_modules[Base.PkgId(Base.UUID("6e696c72-6542-2067-7265-42206c756150"), "AbstractPlutoDingetjes")].Bonds.initial_value catch; b -> missing; end + local iv = try + Base.loaded_modules[Base.PkgId( + Base.UUID("6e696c72-6542-2067-7265-42206c756150"), + "AbstractPlutoDingetjes", + )].Bonds.initial_value + catch + b -> missing + end local el = $(esc(element)) global $(esc(def)) = Core.applicable(Base.get, el) ? Base.get(el) : iv(el) el @@ -22,7 +29,7 @@ begin using EvoTrees: EvoTreeRegressor using LightGBM.MLJInterface: LGBMRegressor using MLJBase - using MLJLinearModels + using MLJLinearModels using MLJModels using NearestNeighborModels: KNNRegressor using Plots diff --git a/docs/src/tutorials/training.qmd b/docs/src/tutorials/training.qmd new file mode 100644 index 0000000..209c89e --- /dev/null +++ b/docs/src/tutorials/training.qmd @@ -0,0 +1,73 @@ +# ConformalTraining + +```@meta +CurrentModule = ConformalPrediction +``` + +```{julia} +#| echo: false +using Pkg; Pkg.activate("docs") +using Plots +theme(:wong) +using Random +Random.seed!(2022) +www_path = "docs/src/www" # output path for files don't get automatically saved in auto-generated path (e.g. 
GIFs) +``` + + +```{julia} +using MLJ +using Random +Random.seed!(123) + +# Data: +X, y = make_blobs(500, centers=4, cluster_std=1.0) +X = MLJ.table(Float32.(MLJ.matrix(X))) +train, test = partition(eachindex(y), 0.8, shuffle=true) +``` + + +```{julia} +using Flux +using MLJFlux +using ConformalPrediction +using ConformalPrediction.ConformalTraining: ConformalNNClassifier + +# Model: +builder = MLJFlux.MLP(hidden=(32, 32, 32,), σ=Flux.relu) +# clf = ConformalNNClassifier(epochs=250, builder=builder, batch_size=50) +clf = NeuralNetworkClassifier(epochs=250, builder=builder, batch_size=50) +``` + + +```{julia} +using ConformalPrediction + +conf_model = conformal_model(clf; method=:simple_inductive) +mach = machine(conf_model, X, y) +fit!(mach, rows=train) +``` + +```{julia} +#| output: true + +using Plots +p_proba = contourf(mach.model, mach.fitresult, X, y) +p_set_size = contourf(mach.model, mach.fitresult, X, y; plot_set_size=true) +p_smooth = contourf(mach.model, mach.fitresult, X, y; plot_set_loss=true) +plot(p_proba, p_set_size, p_smooth, layout=(1,3), size=(1200,250)) +``` + +```{julia} +#| output: true + +_eval = evaluate!( + mach, + operation=predict, + measure=[emp_coverage, ssc, ineff] +) + +println("Empirical coverage: $(round(_eval.measurement[1], digits=3))") +println("SSC: $(round(_eval.measurement[2], digits=3))") +println("Inefficiency: $(round(_eval.measurement[3], digits=3))") +``` \ No newline at end of file diff --git a/src/ConformalPrediction.jl b/src/ConformalPrediction.jl index 4d6abec..a4c4c00 100644 --- a/src/ConformalPrediction.jl +++ b/src/ConformalPrediction.jl @@ -10,7 +10,7 @@ export soft_assignment # Evaluation: include("evaluation/evaluation.jl") -export emp_coverage, size_stratified_coverage, ssc +export emp_coverage, size_stratified_coverage, ssc, ineff # Artifacts: include("artifacts/core.jl") diff --git a/src/conformal_models/ConformalTraining/ConformalTraining.jl b/src/conformal_models/ConformalTraining/ConformalTraining.jl new 
file mode 100644 index 0000000..405082e --- /dev/null +++ b/src/conformal_models/ConformalTraining/ConformalTraining.jl @@ -0,0 +1,15 @@ +module ConformalTraining + +using ConformalPrediction +using Flux +using MLJFlux + +const default_builder = MLJFlux.MLP(; hidden=(32, 32, 32), σ=Flux.relu) + +include("losses.jl") +include("inductive_classification.jl") +include("classifier.jl") +include("regressor.jl") +include("training.jl") + +end diff --git a/src/conformal_models/ConformalTraining/classifier.jl b/src/conformal_models/ConformalTraining/classifier.jl new file mode 100644 index 0000000..2e07ccd --- /dev/null +++ b/src/conformal_models/ConformalTraining/classifier.jl @@ -0,0 +1,91 @@ +using ComputationalResources +using Flux +using MLJFlux +import MLJModelInterface as MMI +using ProgressMeter +using Random +using Tables + +"The `ConformalNNClassifier` struct is a wrapper for a `ConformalModel` that can be used with MLJFlux.jl." +mutable struct ConformalNNClassifier{B,F,O,L} <: MLJFlux.MLJFluxProbabilistic + builder::B + finaliser::F + optimiser::O # mutable struct from Flux/src/optimise/optimisers.jl + loss::L # can be called as in `loss(yhat, y)` + epochs::Int # number of epochs + batch_size::Int # size of a batch + lambda::Float64 # regularization strength + alpha::Float64 # regularizaton mix (0 for all l2, 1 for all l1) + rng::Union{AbstractRNG,Int64} + optimiser_changes_trigger_retraining::Bool + acceleration::AbstractResource # eg, `CPU1()` or `CUDALibs()` +end + +function ConformalNNClassifier(; + builder::B=default_builder, + finaliser::F=Flux.softmax, + optimiser::O=Flux.Optimise.Adam(), + loss::L=Flux.crossentropy, + epochs::Int=100, + batch_size::Int=100, + lambda::Float64=0.0, + alpha::Float64=0.0, + rng::Union{AbstractRNG,Int64}=Random.GLOBAL_RNG, + optimiser_changes_trigger_retraining::Bool=false, + acceleration::AbstractResource=CPU1(), +) where {B,F,O,L} + + # Initialise the MLJFlux wrapper: + mod = ConformalNNClassifier( + builder, + finaliser, + 
optimiser, + loss, + epochs, + batch_size, + lambda, + alpha, + rng, + optimiser_changes_trigger_retraining, + acceleration, + ) + + return mod +end + +# if `b` is a builder, then `b(model, rng, shape...)` is called to make a +# new chain, where `shape` is the return value of this method: +function MLJFlux.shape(model::ConformalNNClassifier, X, y) + levels = MMI.classes(y[1]) + n_output = length(levels) + n_input = Tables.schema(X).names |> length + return (n_input, n_output) +end + +# builds the end-to-end Flux chain needed, given the `model` and `shape`: +function MLJFlux.build(model::ConformalNNClassifier, rng, shape) + + # Chain: + chain = Flux.Chain(MLJFlux.build(model.builder, rng, shape...), model.finaliser) + + return chain +end + +# returns the model `fitresult` (see "Adding Models for General Use" +# section of the MLJ manual) which must always have the form `(chain, +# metadata)`, where `metadata` is anything extra needed by `predict`: +MLJFlux.fitresult(model::ConformalNNClassifier, chain, y) = (chain, MMI.classes(y[1])) + +function MMI.predict(model::ConformalNNClassifier, fitresult, Xnew) + chain, levels = fitresult + X = MLJFlux.reformat(Xnew) + probs = vcat([chain(MLJFlux.tomat(X[:, i]))' for i in 1:size(X, 2)]...) 
+ return MMI.UnivariateFinite(levels, probs) +end + +MMI.metadata_model( + ConformalNNClassifier; + input=Union{AbstractArray,MMI.Table(MMI.Continuous)}, + target=AbstractVector{<:MMI.Finite}, + path="MLJFlux.ConformalNNClassifier", +) diff --git a/src/conformal_models/training/inductive_classification.jl b/src/conformal_models/ConformalTraining/inductive_classification.jl similarity index 80% rename from src/conformal_models/training/inductive_classification.jl rename to src/conformal_models/ConformalTraining/inductive_classification.jl index 990cf78..8a76eee 100644 --- a/src/conformal_models/training/inductive_classification.jl +++ b/src/conformal_models/ConformalTraining/inductive_classification.jl @@ -1,21 +1,22 @@ +using CategoricalArrays +using ConformalPrediction: SimpleInductiveClassifier, AdaptiveInductiveClassifier using MLJEnsembles: EitherEnsembleModel using MLJFlux: MLJFluxModel, reformat using MLUtils """ - score(conf_model::InductiveModel, model::MLJFluxModel, fitresult, X, y::Union{Nothing,AbstractArray}=nothing) + ConformalPrediction.score(conf_model::InductiveModel, model::MLJFluxModel, fitresult, X, y::Union{Nothing,AbstractArray}=nothing) Overloads the `score` function for the `MLJFluxModel` type. """ -function score( +function ConformalPrediction.score( conf_model::SimpleInductiveClassifier, ::Type{<:MLJFluxModel}, fitresult, X, y::Union{Nothing,AbstractArray}=nothing, ) - X = reformat(X) - X = typeof(X) <: AbstractArray ? 
X : permutedims(matrix(X)) + X = permutedims(matrix(X)) probas = permutedims(fitresult[1](X)) scores = @.(conf_model.heuristic(probas)) if isnothing(y) @@ -27,19 +28,18 @@ function score( end """ - score(conf_model::SimpleInductiveClassifier, ::Type{<:EitherEnsembleModel{<:MLJFluxModel}}, fitresult, X, y::Union{Nothing,AbstractArray}=nothing) + ConformalPrediction.score(conf_model::SimpleInductiveClassifier, ::Type{<:EitherEnsembleModel{<:MLJFluxModel}}, fitresult, X, y::Union{Nothing,AbstractArray}=nothing) Overloads the `score` function for ensembles of `MLJFluxModel` types. """ -function score( +function ConformalPrediction.score( conf_model::SimpleInductiveClassifier, ::Type{<:EitherEnsembleModel{<:MLJFluxModel}}, fitresult, X, y::Union{Nothing,AbstractArray}=nothing, ) - X = reformat(X) - X = typeof(X) <: AbstractArray ? X : permutedims(matrix(X)) + X = permutedims(matrix(X)) _chains = map(res -> res[1], fitresult.ensemble) probas = MLUtils.stack(map(chain -> chain(X), _chains)) |> @@ -56,7 +56,7 @@ function score( end """ - score(conf_model::AdaptiveInductiveClassifier, ::Type{<:MLJFluxModel}, fitresult, X, y::Union{Nothing,AbstractArray}=nothing) + ConformalPrediction.score(conf_model::AdaptiveInductiveClassifier, ::Type{<:MLJFluxModel}, fitresult, X, y::Union{Nothing,AbstractArray}=nothing) Overloads the `score` function for the `MLJFluxModel` type. """ @@ -87,7 +87,7 @@ function score( end """ - score(conf_model::AdaptiveInductiveClassifier, ::Type{<:EitherEnsembleModel{<:MLJFluxModel}}, fitresult, X, y::Union{Nothing,AbstractArray}=nothing) + ConformalPrediction.score(conf_model::AdaptiveInductiveClassifier, ::Type{<:EitherEnsembleModel{<:MLJFluxModel}}, fitresult, X, y::Union{Nothing,AbstractArray}=nothing) Overloads the `score` function for ensembles of `MLJFluxModel` types. 
""" diff --git a/src/conformal_models/ConformalTraining/inductive_regression.jl b/src/conformal_models/ConformalTraining/inductive_regression.jl new file mode 100644 index 0000000..455af5f --- /dev/null +++ b/src/conformal_models/ConformalTraining/inductive_regression.jl @@ -0,0 +1,42 @@ +using CategoricalArrays +using ConformalPrediction: SimpleInductiveRegressor +using MLJEnsembles: EitherEnsembleModel +using MLJFlux: MLJFluxModel +using MLUtils + +""" + ConformalPrediction.score(conf_model::SimpleInductiveRegressor, model::MLJFluxModel, fitresult, X, y) + +Overloads the `score` function for the `MLJFluxModel` type. +""" +function ConformalPrediction.score( + conf_model::SimpleInductiveRegressor, ::Type{<:MLJFluxModel}, fitresult, X, y +) + X = permutedims(matrix(X)) + ŷ = permutedims(fitresult[1](X)) + scores = @.(conf_model.heuristic(y, ŷ)) + return scores +end + +""" + ConformalPrediction.score(conf_model::SimpleInductiveRegressor, ::Type{<:EitherEnsembleModel{<:MLJFluxModel}}, fitresult, X, y) + +Overloads the `score` function for ensembles of `MLJFluxModel` types. 
+""" +function ConformalPrediction.score( + conf_model::SimpleInductiveRegressor, + ::Type{<:EitherEnsembleModel{<:MLJFluxModel}}, + fitresult, + X, + y, +) + X = permutedims(matrix(X)) + _chains = map(res -> res[1], fitresult.ensemble) + ŷ = + MLUtils.stack(map(chain -> chain(X), _chains)) |> + y -> + mean(y; dims=ndims(y)) |> + y -> MLUtils.unstack(y; dims=ndims(y))[1] |> y -> permutedims(y) + scores = @.(conf_model.heuristic(y, ŷ)) + return scores +end diff --git a/src/conformal_models/training/losses.jl b/src/conformal_models/ConformalTraining/losses.jl similarity index 91% rename from src/conformal_models/training/losses.jl rename to src/conformal_models/ConformalTraining/losses.jl index 6a76271..8fc6713 100644 --- a/src/conformal_models/training/losses.jl +++ b/src/conformal_models/ConformalTraining/losses.jl @@ -1,6 +1,8 @@ +using ConformalPrediction: ConformalProbabilisticSet using Flux using LinearAlgebra using MLJBase +using StatsBase """ soft_assignment(conf_model::ConformalProbabilisticSet; temp::Real=0.5) @@ -27,8 +29,8 @@ function soft_assignment( ) temp = isnothing(temp) ? 0.5 : temp v = sort(conf_model.scores[:calibration]) - q̂ = qplus(v, conf_model.coverage; sorted=true) - scores = score(conf_model, fitresult, X) + q̂ = StatsBase.quantile(v, conf_model.coverage; sorted=true) + scores = ConformalPrediction.score(conf_model, fitresult, X) return @.(σ((q̂ - scores) / temp)) end @@ -53,7 +55,7 @@ function smooth_size_loss( temp::Union{Nothing,Real}=nothing, κ::Real=1.0, ) - temp = isnothing(temp) ? 0.5 : temp + temp = isnothing(temp) ? 
0.1 : temp C = soft_assignment(conf_model, fitresult, X; temp=temp) is_empty_set = all( x -> x .== 0, soft_assignment(conf_model, fitresult, X; temp=0.0); dims=2 @@ -70,14 +72,6 @@ function smooth_size_loss( Ω = [Ω..., ω] end Ω = permutedims(permutedims(Ω)) - # Ω = map(sum(C; dims=2), is_empty_set) do x, is_empty - # if is_empty #&& κ > 0 - # ω = maximum([x - κ, full_set_size - κ]) - # else - # ω = maximum([0, x - κ]) - # end - # return ω - # end return Ω end diff --git a/src/conformal_models/ConformalTraining/regressor.jl b/src/conformal_models/ConformalTraining/regressor.jl new file mode 100644 index 0000000..3522750 --- /dev/null +++ b/src/conformal_models/ConformalTraining/regressor.jl @@ -0,0 +1,82 @@ +using ComputationalResources +using Flux +using MLJFlux +import MLJModelInterface as MMI +using ProgressMeter +using Random +using Tables + +"The `ConformalNNRegressor` struct is a wrapper for a `ConformalModel` that can be used with MLJFlux.jl." +mutable struct ConformalNNRegressor{B,O,L} <: MLJFlux.MLJFluxDeterministic + builder::B + optimiser::O # mutable struct from Flux/src/optimise/optimisers.jl + loss::L # can be called as in `loss(yhat, y)` + epochs::Int # number of epochs + batch_size::Int # size of a batch + lambda::Float64 # regularization strength + alpha::Float64 # regularization mix (0 for all l2, 1 for all l1) + rng::Union{AbstractRNG,Integer} + optimiser_changes_trigger_retraining::Bool + acceleration::AbstractResource # eg, `CPU1()` or `CUDALibs()` +end + +function ConformalNNRegressor(; + builder::B=default_builder, + optimiser::O=Flux.Optimise.Adam(), + loss::L=Flux.mse, + epochs::Int=100, + batch_size::Int=100, + lambda::Float64=0.0, + alpha::Float64=0.0, + rng::Union{AbstractRNG,Int64}=Random.GLOBAL_RNG, + optimiser_changes_trigger_retraining::Bool=false, + acceleration::AbstractResource=CPU1(), +) where {B,O,L} + + # Initialise the MLJFlux wrapper: + mod = ConformalNNRegressor( + builder, + optimiser, + loss, + epochs, + batch_size, + 
lambda, + alpha, + rng, + optimiser_changes_trigger_retraining, + acceleration, + ) + + return mod +end + +""" + shape(model::ConformalNNRegressor, X, y) + +A private method that returns the shape of the input and output of the model for given data `X` and `y`. +""" +function MLJFlux.shape(model::ConformalNNRegressor, X, y) + X = X isa Matrix ? Tables.table(X) : X + n_input = Tables.schema(X).names |> length + n_ouput = 1 + return (n_input, 1) +end + +function MLJFlux.build(model::ConformalNNRegressor, rng, shape) + return MLJFlux.build(model.builder, rng, shape...) +end + +MLJFlux.fitresult(model::ConformalNNRegressor, chain, y) = (chain, nothing) + +function MMI.predict(model::ConformalNNRegressor, fitresult, Xnew) + chain = fitresult[1] + Xnew_ = MLJFlux.reformat(Xnew) + return [chain(values.(MLJFlux.tomat(Xnew_[:, i])))[1] for i in 1:size(Xnew_, 2)] +end + +MMI.metadata_model( + ConformalNNRegressor; + input=Union{AbstractMatrix{Continuous},MMI.Table(MMI.Continuous)}, + target=AbstractVector{<:MMI.Continuous}, + path="MLJFlux.ConformalNNRegressor", +) diff --git a/src/conformal_models/ConformalTraining/training.jl b/src/conformal_models/ConformalTraining/training.jl new file mode 100644 index 0000000..ae8a82e --- /dev/null +++ b/src/conformal_models/ConformalTraining/training.jl @@ -0,0 +1,55 @@ +const ConformalNN = Union{ConformalNNClassifier,ConformalNNRegressor} + +@doc raw""" + MLJFlux.train!(model::ConformalNN, penalty, chain, optimiser, X, y) + +Implements the conformal training procedure for the `ConformalNN` type. 
+""" +function MLJFlux.train!(model::ConformalNN, penalty, chain, optimiser, X, y) + + # Setup: + loss = model.loss + n_batches = length(y) + training_loss = zero(Float32) + size_loss = zero(Float32) + fitresult = (chain, nothing) + + # Training loop: + for i in 1:n_batches + parameters = Flux.params(chain) + + # Data Splitting: + X_batch, y_batch = X[i], y[i] + conf_model = ConformalPrediction.conformal_model( + model; method=:simple_inductive, coverage=0.95 + ) + calibration, pred = partition( + 1:size(y_batch, 2), conf_model.train_ratio; shuffle=true + ) + Xcal = X_batch[:, calibration] + ycal = y_batch[:, calibration] + Xcal, ycal = MMI.reformat(conf_model.model, Xcal, ycal) + Xpred = X_batch[:, pred] + ypred = y_batch[:, pred] + Xpred, ypred = MMI.reformat(conf_model.model, Xpred, ypred) + + # On-the-fly calibration: + cal_scores, scores = ConformalPrediction.score( + conf_model, fitresult, Xcal', categorical(Flux.onecold(ycal)) + ) + conf_model.scores = Dict(:calibration => cal_scores, :all => scores) + + gs = Flux.gradient(parameters) do + Ω = smooth_size_loss(conf_model, fitresult, Xpred') + yhat = chain(X_batch) + batch_loss = loss(yhat, y_batch) + penalty(parameters) / n_batches + batch_loss += 0.5 * sum(Ω) / length(Ω) # add size loss + training_loss += batch_loss + size_loss += sum(Ω) / length(Ω) + return batch_loss + end + Flux.update!(optimiser, parameters, gs) + end + + return training_loss / n_batches +end diff --git a/src/conformal_models/conformal_models.jl b/src/conformal_models/conformal_models.jl index ce5cd19..7c1068d 100644 --- a/src/conformal_models/conformal_models.jl +++ b/src/conformal_models/conformal_models.jl @@ -57,7 +57,8 @@ include("inductive_classification.jl") include("transductive_classification.jl") # Training: -include("training/training.jl") +include("ConformalTraining/ConformalTraining.jl") +using .ConformalTraining # Type unions: const InductiveModel = Union{ diff --git a/src/conformal_models/plotting.jl 
b/src/conformal_models/plotting.jl index 3461499..5e4f55d 100644 --- a/src/conformal_models/plotting.jl +++ b/src/conformal_models/plotting.jl @@ -159,7 +159,6 @@ function Plots.contourf( kwargs..., ) else - clim = @isdefined(clim) ? clim : (0, 1) plt = contourf( x1range, x2range, @@ -167,7 +166,6 @@ function Plots.contourf( title=title, xlims=xlims, ylims=ylims, - clim=clim, c=cgrad(:blues), linewidth=0, kwargs..., diff --git a/src/conformal_models/training/training.jl b/src/conformal_models/training/training.jl deleted file mode 100644 index 8fd1da3..0000000 --- a/src/conformal_models/training/training.jl +++ /dev/null @@ -1,2 +0,0 @@ -include("losses.jl") -include("inductive_classification.jl") diff --git a/src/evaluation/measures.jl b/src/evaluation/measures.jl index 09ba9f4..f96a4e3 100644 --- a/src/evaluation/measures.jl +++ b/src/evaluation/measures.jl @@ -38,3 +38,14 @@ function size_stratified_coverage(ŷ, y) return C̄ end + +""" + ineff(ŷ) + +Computes the inefficiency (average set size) for conformal predictions `ŷ`. +""" +function ineff(ŷ, y=missing) + R = length(ŷ) + ineff = sum(set_size.(ŷ)) / R + return ineff +end