diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index a1da50a..303018d 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -3,6 +3,7 @@ on: push: branches: - main + - develop tags: ['*'] pull_request: concurrency: diff --git a/Manifest.toml b/Manifest.toml new file mode 100644 index 0000000..703ba85 --- /dev/null +++ b/Manifest.toml @@ -0,0 +1,777 @@ +# This file is machine-generated - editing it directly is not advised + +julia_version = "1.8.1" +manifest_format = "2.0" +project_hash = "0dfb98cfb849f2685c915fd274d6d54878794749" + +[[deps.ARFFFiles]] +deps = ["CategoricalArrays", "Dates", "Parsers", "Tables"] +git-tree-sha1 = "e8c8e0a2be6eb4f56b1672e46004463033daa409" +uuid = "da404889-ca92-49ff-9e8b-0aa6b4d38dc8" +version = "1.4.1" + +[[deps.ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" +version = "1.1.1" + +[[deps.Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" + +[[deps.Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[deps.BitFlags]] +git-tree-sha1 = "84259bb6172806304b9101094a7cc4bc6f56dbc6" +uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35" +version = "0.1.5" + +[[deps.Bzip2_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "19a35467a82e236ff51bc17a3a44b69ef35185a2" +uuid = "6e34b625-4abd-537c-b88f-471c36dfa7a0" +version = "1.0.8+0" + +[[deps.CEnum]] +git-tree-sha1 = "eb4cb44a499229b3b8426dcfb5dd85333951ff90" +uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" +version = "0.4.2" + +[[deps.Calculus]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "f641eb0a4f00c343bbc32346e1217b86f3ce9dad" +uuid = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" +version = "0.5.1" + +[[deps.CategoricalArrays]] +deps = ["DataAPI", "Future", "Missings", "Printf", "Requires", "Statistics", "Unicode"] +git-tree-sha1 = "5084cc1a28976dd1642c9f337b28a3cb03e0f7d2" +uuid = "324d7699-5711-5eae-9e2f-1d82baa6b597" +version = "0.10.7" + +[[deps.CategoricalDistributions]] +deps = ["CategoricalArrays", "Distributions", "Missings", "OrderedCollections", "Random", "ScientificTypes", "UnicodePlots"] +git-tree-sha1 = "23fe4c6668776fedfd3747c545cd0d1a5190eb15" +uuid = "af321ab8-2d2e-40a6-b165-3d674595d28e" +version = "0.1.9" + +[[deps.ChainRulesCore]] +deps = ["Compat", "LinearAlgebra", "SparseArrays"] +git-tree-sha1 = "e7ff6cadf743c098e08fca25c91103ee4303c9bb" +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "1.15.6" + +[[deps.ChangesOfVariables]] +deps = ["ChainRulesCore", "LinearAlgebra", "Test"] +git-tree-sha1 = "38f7a08f19d8810338d4f5085211c7dfa5d5bdd8" +uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" +version = "0.1.4" + +[[deps.CodecZlib]] +deps = ["TranscodingStreams", "Zlib_jll"] +git-tree-sha1 = "ded953804d019afa9a3f98981d99b33e3db7b6da" +uuid = "944b1d66-785c-5afd-91f1-9de20f533193" +version = "0.7.0" + +[[deps.ColorSchemes]] +deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "Random"] +git-tree-sha1 = "1fd869cc3875b57347f7027521f561cf46d1fcd8" +uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4" +version = "3.19.0" + +[[deps.ColorTypes]] +deps = ["FixedPointNumbers", "Random"] +git-tree-sha1 = "eb7f0f8307f71fac7c606984ea5fb2817275d6e4" +uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" +version = "0.11.4" + +[[deps.ColorVectorSpace]] +deps = ["ColorTypes", "FixedPointNumbers", "LinearAlgebra", "SpecialFunctions", "Statistics", "TensorCore"] +git-tree-sha1 = "d08c20eef1f2cbc6e60fd3612ac4340b89fea322" +uuid = "c3611d14-8923-5661-9e6a-0046d554d3a4" +version = "0.9.9" + +[[deps.Colors]] +deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] +git-tree-sha1 = "417b0ed7b8b838aa6ca0a87aadf1bb9eb111ce40" +uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" +version = "0.12.8" + +[[deps.Combinatorics]] +git-tree-sha1 = "08c8b6831dc00bfea825826be0bc8336fc369860" +uuid = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" +version = "1.0.2" + +[[deps.Compat]] +deps = ["Dates", "LinearAlgebra", "UUIDs"] +git-tree-sha1 = "3ca828fe1b75fa84b021a7860bd039eaea84d2f2" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "4.3.0" + +[[deps.CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "0.5.2+0" + +[[deps.ComputationalResources]] +git-tree-sha1 = "52cb3ec90e8a8bea0e62e275ba577ad0f74821f7" +uuid = "ed09eef8-17a6-5b46-8889-db040fac31e3" +version = "0.3.2" + +[[deps.ConstructionBase]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "fb21ddd70a051d882a1686a5a550990bbe371a95" +uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" +version = "1.4.1" + +[[deps.Contour]] +git-tree-sha1 = "d05d9e7b7aedff4e5b51a029dced05cfb6125781" +uuid = "d38c429a-6771-53c6-b99e-75d170b6e991" +version = "0.6.2" + +[[deps.Crayons]] +git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" +uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" +version = "4.1.1" + +[[deps.DataAPI]] +git-tree-sha1 = "46d2680e618f8abd007bce0c3026cb0c4a8f2032" +uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +version = "1.12.0" + +[[deps.DataStructures]] +deps = ["Compat", "InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "d1fff3a548102f48987a52a2e0d114fa97d730f0" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.18.13" + +[[deps.DataValueInterfaces]] +git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" +uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" +version = "1.0.0" + +[[deps.Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[deps.DelimitedFiles]] +deps = ["Mmap"] +uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" + +[[deps.DensityInterface]] +deps = ["InverseFunctions", "Test"] +git-tree-sha1 = "80c3e8639e3353e5d2912fb3a1916b8455e2494b" +uuid = "b429d917-457f-4dbc-8f4c-0cc954292b1d" +version = "0.4.0" + +[[deps.Distances]] +deps = ["LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI"] +git-tree-sha1 = "3258d0659f812acde79e8a74b11f17ac06d0ca04" +uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" +version = "0.10.7" + +[[deps.Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[deps.Distributions]] +deps = ["ChainRulesCore", "DensityInterface", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test"] +git-tree-sha1 = "04db820ebcfc1e053bd8cbb8d8bccf0ff3ead3f7" +uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" +version = "0.25.76" + +[[deps.DocStringExtensions]] +deps = ["LibGit2"] +git-tree-sha1 = "5158c2b41018c5f7eb1470d558127ac274eca0c9" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.9.1" + +[[deps.Downloads]] +deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" +version = "1.6.0" + +[[deps.DualNumbers]] +deps = ["Calculus", "NaNMath", "SpecialFunctions"] +git-tree-sha1 = "5837a837389fccf076445fce071c8ddaea35a566" +uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" +version = "0.6.8" + +[[deps.EarlyStopping]] +deps = ["Dates", "Statistics"] +git-tree-sha1 = "98fdf08b707aaf69f524a6cd0a67858cefe0cfb6" +uuid = "792122b4-ca99-40de-a6bc-6742525f08b6" +version = "0.3.0" + +[[deps.FileIO]] +deps = ["Pkg", "Requires", "UUIDs"] +git-tree-sha1 = "94f5101b96d2d968ace56f7f2db19d0a5f592e28" +uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" +version = "1.15.0" + +[[deps.FileWatching]] +uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" + +[[deps.FillArrays]] +deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] +git-tree-sha1 = "87519eb762f85534445f5cda35be12e32759ee14" +uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" +version = "0.13.4" + +[[deps.FixedPointNumbers]] +deps = ["Statistics"] +git-tree-sha1 = "335bfdceacc84c5cdf16aadc768aa5ddfc5383cc" +uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" +version = "0.8.4" + +[[deps.Formatting]] +deps = ["Printf"] +git-tree-sha1 = "8339d61043228fdd3eb658d86c926cb282ae72a8" +uuid = "59287772-0a20-5a39-b81b-1366585eb4c0" +version = "0.4.2" + +[[deps.FreeType]] +deps = ["CEnum", "FreeType2_jll"] +git-tree-sha1 = "cabd77ab6a6fdff49bfd24af2ebe76e6e018a2b4" +uuid = "b38be410-82b0-50bf-ab77-7b57e271db43" +version = "4.0.0" + +[[deps.FreeType2_jll]] +deps = ["Artifacts", "Bzip2_jll", "JLLWrappers", "Libdl", "Pkg", "Zlib_jll"] +git-tree-sha1 = "87eb71354d8ec1a96d4a7636bd57a7347dde3ef9" +uuid = "d7e528f0-a631-5988-bf34-fe36492bcfd7" +version = "2.10.4+0" + +[[deps.Future]] +deps = ["Random"] +uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" + +[[deps.HTTP]] +deps = ["Base64", "CodecZlib", "Dates", "IniFile", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] +git-tree-sha1 = "e8c58d5f03b9d9eb9ed7067a2f34c7c371ab130b" +uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" +version = "1.4.1" + +[[deps.HypergeometricFunctions]] +deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions", "Test"] +git-tree-sha1 = "709d864e3ed6e3545230601f94e11ebc65994641" +uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a" +version = "0.3.11" + +[[deps.IniFile]] +git-tree-sha1 = "f550e6e32074c939295eb5ea6de31849ac2c9625" +uuid = "83e8ac13-25f8-5344-8a64-a9f2b223428f" +version = "0.5.1" + +[[deps.InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[deps.InverseFunctions]] +deps = ["Test"] +git-tree-sha1 = "49510dfcb407e572524ba94aeae2fced1f3feb0f" +uuid = "3587e190-3f89-42d0-90ee-14403ec27112" +version = "0.1.8" + +[[deps.InvertedIndices]] +git-tree-sha1 = "bee5f1ef5bf65df56bdd2e40447590b272a5471f" +uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" +version = "1.1.0" + +[[deps.IrrationalConstants]] +git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151" +uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" +version = "0.1.1" + +[[deps.IterationControl]] +deps = ["EarlyStopping", "InteractiveUtils"] +git-tree-sha1 = "d7df9a6fdd82a8cfdfe93a94fcce35515be634da" +uuid = "b3c1a2ee-3fec-4384-bf48-272ea71de57c" +version = "0.5.3" + +[[deps.IteratorInterfaceExtensions]] +git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" +uuid = "82899510-4779-5014-852e-03e436cf321d" +version = "1.0.0" + +[[deps.JLLWrappers]] +deps = ["Preferences"] +git-tree-sha1 = "abc9885a7ca2052a736a600f7fa66209f96506e1" +uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.4.1" + +[[deps.JSON]] +deps = ["Dates", "Mmap", "Parsers", "Unicode"] +git-tree-sha1 = "3c837543ddb02250ef42f4738347454f95079d4e" +uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" +version = "0.21.3" + +[[deps.LatinHypercubeSampling]] +deps = ["Random", "StableRNGs", "StatsBase", "Test"] +git-tree-sha1 = "42938ab65e9ed3c3029a8d2c58382ca75bdab243" +uuid = "a5e1c1ea-c99a-51d3-a14d-a9a37257b02d" +version = "1.8.0" + +[[deps.LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" +version = "0.6.3" + +[[deps.LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" +version = "7.84.0+0" + +[[deps.LibGit2]] +deps = ["Base64", "NetworkOptions", "Printf", "SHA"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[deps.LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" +version = "1.10.2+0" + +[[deps.Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[deps.LinearAlgebra]] +deps = ["Libdl", "libblastrampoline_jll"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[deps.LogExpFunctions]] +deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"] +git-tree-sha1 = "94d9c52ca447e23eac0c0f074effbcd38830deb5" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.3.18" + +[[deps.Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[deps.LoggingExtras]] +deps = ["Dates", "Logging"] +git-tree-sha1 = "5d4d2d9904227b8bd66386c1138cf4d5ffa826bf" +uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" +version = "0.4.9" + +[[deps.LossFunctions]] +deps = ["InteractiveUtils", "Markdown", "RecipesBase"] +git-tree-sha1 = "53cd63a12f06a43eef6f4aafb910ac755c122be7" +uuid = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" +version = "0.8.0" + +[[deps.MLJ]] +deps = ["CategoricalArrays", "ComputationalResources", "Distributed", "Distributions", "LinearAlgebra", "MLJBase", "MLJEnsembles", "MLJIteration", "MLJModels", "MLJTuning", "OpenML", "Pkg", "ProgressMeter", "Random", "ScientificTypes", "Statistics", "StatsBase", "Tables"] +git-tree-sha1 = "025706ea81e635ac530a1d3dd365af971805bf79" +uuid = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7" +version = "0.18.5" + +[[deps.MLJBase]] +deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LinearAlgebra", "LossFunctions", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "ScientificTypes", "Serialization", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] +git-tree-sha1 = "f68deea1f25727f24a4afa9f941763e6fc44f5af" +uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d" +version = "0.20.19" + +[[deps.MLJEnsembles]] +deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Distributed", "Distributions", "MLJBase", "MLJModelInterface", "ProgressMeter", "Random", "ScientificTypesBase", "StatsBase"] +git-tree-sha1 = "ed2f724be26d0023cade9d59b55da93f528c3f26" +uuid = "50ed68f4-41fd-4504-931a-ed422449fee0" +version = "0.3.1" + +[[deps.MLJIteration]] +deps = ["IterationControl", "MLJBase", "Random", "Serialization"] +git-tree-sha1 = "024d0bd22bf4a5b273f626e89d742a9db95285ef" +uuid = "614be32b-d00c-4edb-bd02-1eb411ab5e55" +version = "0.5.0" + +[[deps.MLJModelInterface]] +deps = ["Random", "ScientificTypesBase", "StatisticalTraits"] +git-tree-sha1 = "0a36882e73833d60dac49b00d203f73acfd50b85" +uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" +version = "1.7.0" + +[[deps.MLJModels]] +deps = ["CategoricalArrays", "CategoricalDistributions", "Combinatorics", "Dates", "Distances", "Distributions", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "Markdown", "OrderedCollections", "Parameters", "Pkg", "PrettyPrinting", "REPL", "Random", "RelocatableFolders", "ScientificTypes", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] +git-tree-sha1 = "147a8e7939601f8c37204addbbe29f2bcfb876a8" +uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7" +version = "0.15.14" + +[[deps.MLJTuning]] +deps = ["ComputationalResources", "Distributed", "Distributions", "LatinHypercubeSampling", "MLJBase", "ProgressMeter", "Random", "RecipesBase"] +git-tree-sha1 = "77209966cc028c1d7730001dc32bffe17a198f29" +uuid = "03970b2e-30c4-11ea-3135-d1576263f10f" +version = "0.7.3" + +[[deps.MarchingCubes]] +deps = ["SnoopPrecompile", "StaticArrays"] +git-tree-sha1 = "ffc66942498a5f0d02b9e7b1b1af0f5873142cdc" +uuid = "299715c1-40a9-479a-aaf9-4a633d36f717" +version = "0.1.4" + +[[deps.Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[deps.MbedTLS]] +deps = ["Dates", "MbedTLS_jll", "MozillaCACerts_jll", "Random", "Sockets"] +git-tree-sha1 = "6872f9594ff273da6d13c7c1a1545d5a8c7d0c1c" +uuid = "739be429-bea8-5141-9913-cc70e7f3736d" +version = "1.1.6" + +[[deps.MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" +version = "2.28.0+0" + +[[deps.Missings]] +deps = ["DataAPI"] +git-tree-sha1 = "bf210ce90b6c9eed32d25dbcae1ebc565df2687f" +uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" +version = "1.0.2" + +[[deps.Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" + +[[deps.MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" +version = "2022.2.1" + +[[deps.NaNMath]] +deps = ["OpenLibm_jll"] +git-tree-sha1 = "a7c3d1da1189a1c2fe843a3bfa04d18d20eb3211" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "1.0.1" + +[[deps.NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +version = "1.2.0" + +[[deps.OpenBLAS_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] +uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" +version = "0.3.20+0" + +[[deps.OpenLibm_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "05823500-19ac-5b8b-9628-191a04bc5112" +version = "0.8.1+0" + +[[deps.OpenML]] +deps = ["ARFFFiles", "HTTP", "JSON", "Markdown", "Pkg"] +git-tree-sha1 = "88dfa70c818f7a4728c6b82a72a0e597e083938b" +uuid = "8b6db2d4-7670-4922-a472-f9537c81ab66" +version = "0.3.0" + +[[deps.OpenSSL]] +deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"] +git-tree-sha1 = "ebe81469e9d7b471d7ddb611d9e147ea16de0add" +uuid = "4d8831e6-92b7-49fb-bdf8-b643e874388c" +version = "1.2.1" + +[[deps.OpenSSL_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "a94dc0169bffbf7e5250fb7e1efb1a85b09105c7" +uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" +version = "1.1.18+0" + +[[deps.OpenSpecFun_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.5+0" + +[[deps.OrderedCollections]] +git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.4.1" + +[[deps.PDMats]] +deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] +git-tree-sha1 = "cf494dca75a69712a72b80bc48f59dcf3dea63ec" +uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" +version = "0.11.16" + +[[deps.Parameters]] +deps = ["OrderedCollections", "UnPack"] +git-tree-sha1 = "34c0e9ad262e5f7fc75b10a9952ca7692cfc5fbe" +uuid = "d96e819e-fc66-5662-9728-84c9c7592b0a" +version = "0.12.3" + +[[deps.Parsers]] +deps = ["Dates"] +git-tree-sha1 = "595c0b811cf2bab8b0849a70d9bd6379cc1cfb52" +uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" +version = "2.4.1" + +[[deps.Pkg]] +deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +version = "1.8.0" + +[[deps.Preferences]] +deps = ["TOML"] +git-tree-sha1 = "47e5f437cc0e7ef2ce8406ce1e7e24d44915f88d" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.3.0" + +[[deps.PrettyPrinting]] +git-tree-sha1 = "4be53d093e9e37772cc89e1009e8f6ad10c4681b" +uuid = "54e16d92-306c-5ea0-a30b-337be88ac337" +version = "0.4.0" + +[[deps.PrettyTables]] +deps = ["Crayons", "Formatting", "Markdown", "Reexport", "StringManipulation", "Tables"] +git-tree-sha1 = "460d9e154365e058c4d886f6f7d6df5ffa1ea80e" +uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" +version = "2.1.2" + +[[deps.Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[deps.ProgressMeter]] +deps = ["Distributed", "Printf"] +git-tree-sha1 = "d7a7aef8f8f2d537104f170139553b14dfe39fe9" +uuid = "92933f4c-e287-5a05-a399-4b506db050ca" +version = "1.7.2" + +[[deps.QuadGK]] +deps = ["DataStructures", "LinearAlgebra"] +git-tree-sha1 = "3c009334f45dfd546a16a57960a821a1a023d241" +uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" +version = "2.5.0" + +[[deps.REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[deps.Random]] +deps = ["SHA", "Serialization"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[deps.RecipesBase]] +deps = ["SnoopPrecompile"] +git-tree-sha1 = "612a4d76ad98e9722c8ba387614539155a59e30c" +uuid = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" +version = "1.3.0" + +[[deps.Reexport]] +git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "1.2.2" + +[[deps.RelocatableFolders]] +deps = ["SHA", "Scratch"] +git-tree-sha1 = "90bc7a7c96410424509e4263e277e43250c05691" +uuid = "05181044-ff0b-4ac5-8273-598c1e38db00" +version = "1.0.0" + +[[deps.Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.3.0" + +[[deps.Rmath]] +deps = ["Random", "Rmath_jll"] +git-tree-sha1 = "bf3188feca147ce108c76ad82c2792c57abe7b1f" +uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" +version = "0.7.0" + +[[deps.Rmath_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "68db32dff12bb6127bac73c209881191bf0efbb7" +uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" +version = "0.3.0+0" + +[[deps.SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" +version = "0.7.0" + +[[deps.ScientificTypes]] +deps = ["CategoricalArrays", "ColorTypes", "Dates", "Distributions", "PrettyTables", "Reexport", "ScientificTypesBase", "StatisticalTraits", "Tables"] +git-tree-sha1 = "82b2426c11fa6cb23bbfbe0d7378837a653ba44b" +uuid = "321657f4-b219-11e9-178b-2701a2544e81" +version = "3.0.1" + +[[deps.ScientificTypesBase]] +git-tree-sha1 = "a8e18eb383b5ecf1b5e6fc237eb39255044fd92b" +uuid = "30f210dd-8aff-4c5f-94ba-8e64358c1161" +version = "3.0.0" + +[[deps.Scratch]] +deps = ["Dates"] +git-tree-sha1 = "f94f779c94e58bf9ea243e77a37e16d9de9126bd" +uuid = "6c6a2e73-6563-6170-7368-637461726353" +version = "1.1.1" + +[[deps.Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[deps.SimpleBufferStream]] +git-tree-sha1 = "874e8867b33a00e784c8a7e4b60afe9e037b74e1" +uuid = "777ac1f9-54b0-4bf8-805c-2214025038e7" +version = "1.1.0" + +[[deps.SnoopPrecompile]] +git-tree-sha1 = "f604441450a3c0569830946e5b33b78c928e1a85" +uuid = "66db9d55-30c0-4569-8b51-7e840670fc0c" +version = "1.0.1" + +[[deps.Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[deps.SortingAlgorithms]] +deps = ["DataStructures"] +git-tree-sha1 = "b3363d7460f7d098ca0912c69b082f75625d7508" +uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" +version = "1.0.1" + +[[deps.SparseArrays]] +deps = ["LinearAlgebra", "Random"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + +[[deps.SpecialFunctions]] +deps = ["ChainRulesCore", "IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] +git-tree-sha1 = "d75bda01f8c31ebb72df80a46c88b25d1c79c56d" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "2.1.7" + +[[deps.StableRNGs]] +deps = ["Random", "Test"] +git-tree-sha1 = "3be7d49667040add7ee151fefaf1f8c04c8c8276" +uuid = "860ef19b-820b-49d6-a774-d7a799459cd3" +version = "1.0.0" + +[[deps.StaticArrays]] +deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"] +git-tree-sha1 = "f86b3a049e5d05227b10e15dbb315c5b90f14988" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "1.5.9" + +[[deps.StaticArraysCore]] +git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a" +uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +version = "1.4.0" + +[[deps.StatisticalTraits]] +deps = ["ScientificTypesBase"] +git-tree-sha1 = "30b9236691858e13f167ce829490a68e1a597782" +uuid = "64bff920-2084-43da-a3e6-9bb72801c0c9" +version = "3.2.0" + +[[deps.Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[[deps.StatsAPI]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "f9af7f195fb13589dd2e2d57fdb401717d2eb1f6" +uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" +version = "1.5.0" + +[[deps.StatsBase]] +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] +git-tree-sha1 = "d1bf48bfcc554a3761a133fe3a9bb01488e06916" +uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +version = "0.33.21" + +[[deps.StatsFuns]] +deps = ["ChainRulesCore", "HypergeometricFunctions", "InverseFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] +git-tree-sha1 = "5783b877201a82fc0014cbf381e7e6eb130473a4" +uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" +version = "1.0.1" + +[[deps.StringManipulation]] +git-tree-sha1 = "46da2434b41f41ac3594ee9816ce5541c6096123" +uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e" +version = "0.3.0" + +[[deps.SuiteSparse]] +deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] +uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" + +[[deps.TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" +version = "1.0.0" + +[[deps.TableTraits]] +deps = ["IteratorInterfaceExtensions"] +git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" +uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" +version = "1.0.1" + +[[deps.Tables]] +deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits", "Test"] +git-tree-sha1 = "c79322d36826aa2f4fd8ecfa96ddb47b174ac78d" +uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" +version = "1.10.0" + +[[deps.Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" +version = "1.10.0" + +[[deps.TensorCore]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "1feb45f88d133a655e001435632f019a9a1bcdb6" +uuid = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50" +version = "0.1.1" + +[[deps.Test]] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[deps.TranscodingStreams]] +deps = ["Random", "Test"] +git-tree-sha1 = "8a75929dcd3c38611db2f8d08546decb514fcadf" +uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" +version = "0.9.9" + +[[deps.URIs]] +git-tree-sha1 = "e59ecc5a41b000fa94423a578d29290c7266fc10" +uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" +version = "1.4.0" + +[[deps.UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[deps.UnPack]] +git-tree-sha1 = "387c1f73762231e86e0c9c5443ce3b4a0a9a0c2b" +uuid = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" +version = "1.0.2" + +[[deps.Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[deps.UnicodePlots]] +deps = ["ColorSchemes", "ColorTypes", "Contour", "Crayons", "Dates", "FileIO", "FreeType", "LinearAlgebra", "MarchingCubes", "NaNMath", "Printf", "Requires", "SnoopPrecompile", "SparseArrays", "StaticArrays", "StatsBase", "Unitful"] +git-tree-sha1 = "8a6dcd44129de81cc760b9d8af6fba188d3a01a6" +uuid = "b8865327-cd53-5732-bb35-84acbb429228" +version = "3.1.3" + +[[deps.Unitful]] +deps = ["ConstructionBase", "Dates", "LinearAlgebra", "Random"] +git-tree-sha1 = "d57a4ed70b6f9ff1da6719f5f2713706d57e0d66" +uuid = "1986cc42-f94f-5a68-af5c-568840ba703d" +version = "1.12.0" + +[[deps.Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" +version = "1.2.12+3" + +[[deps.libblastrampoline_jll]] +deps = ["Artifacts", "Libdl", "OpenBLAS_jll"] +uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" +version = "5.1.1+0" + +[[deps.nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" +version = "1.48.0+0" + +[[deps.p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" +version = "17.4.0+0" diff --git a/Project.toml b/Project.toml index 227d163..495f890 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,8 @@ version = "0.1.0" [deps] MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7" +MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" +MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] diff --git a/README.md b/README.md index cfa2de3..6d3db43 100644 --- a/README.md +++ b/README.md @@ -40,27 +40,27 @@ X, y = MLJ.make_regression(1000, 2) train, calibration, test = partition(eachindex(y), 0.4, 0.4) ``` -We then train a boosted tree ([EvoTrees](https://github.com/Evovest/EvoTrees.jl)) and follow the standard [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/) training procedure. +We then train a decision tree ([DecisionTree](https://github.com/Evovest/DecisionTree.jl)) and follow the standard [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/) training procedure. ``` julia -EvoTreeRegressor = @load EvoTreeRegressor pkg=EvoTrees -model = EvoTreeRegressor() +DecisionTreeRegressor = @load DecisionTreeRegressor pkg=DecisionTree +model = DecisionTreeRegressor() mach = machine(model, X, y) fit!(mach, rows=train) ``` -To turn our conventional machine into a conformal machine, we just need to declare it as such and then calibrate it using our calibration data: +To turn our conventional machine into a conformal model, we just need to declare it as such and then calibrate it using our calibration data: ``` julia using ConformalPrediction -conf_mach = conformal_machine(mach) -calibrate!(conf_mach, selectrows(X, calibration), y[calibration]) +conf_model = conformal_model(model) +calibrate!(conf_model, selectrows(X, calibration), y[calibration]) ``` Predictions can then be computed using the generic `predict` method. The code below produces predictions a random subset of test samples: ``` julia -predict(conf_mach, selectrows(X, rand(test,5))) +predict(conf_model, selectrows(X, rand(test,5))) ``` 5-element Vector{Vector{Pair{String, Vector{Float64}}}}: diff --git a/_freeze/docs/src/classification/simple/execute-results/md.json b/_freeze/docs/src/classification/simple/execute-results/md.json index e10bc8b..78697ee 100644 --- a/_freeze/docs/src/classification/simple/execute-results/md.json +++ b/_freeze/docs/src/classification/simple/execute-results/md.json @@ -1,7 +1,7 @@ { "hash": "23e5ff6ddc8b19eba4e8290b20658f33", "result": { - "markdown": "---\nformat:\n commonmark:\n variant: '-raw_html'\n wrap: none\n self-contained: true\ncrossref:\n fig-prefix: Figure\n tbl-prefix: Table\nbibliography: 'https://raw.githubusercontent.com/pat-alt/bib/main/bib.bib'\noutput: asis\nexecute:\n output: false\n freeze: auto\n eval: true\n echo: true\n---\n\n# Classification Tutorial\n\n[INCOMPLETE]\n\nWe firstly generate some synthetic data with three classes and partition it into a training set, a calibration set and a test set:\n\n::: {.cell execution_count=1}\n``` {.julia .cell-code}\nusing MLJ\nX, y = MLJ.make_blobs(1000, 2, centers=3, cluster_std=2)\ntrain, calibration, test = partition(eachindex(y), 0.4, 0.4)\n```\n:::\n\n\nFollowing the standard [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/) procedure, we train a boosted tree for the classification task:\n\n::: {.cell execution_count=2}\n``` {.julia .cell-code}\nEvoTreeClassifier = @load EvoTreeClassifier pkg=EvoTrees\nmodel = EvoTreeClassifier() \nmach = machine(model, X, y)\nfit!(mach, rows=train)\n```\n:::\n\n\nNext we instantiate our conformal machine and calibrate using the calibration data:\n\n::: {.cell execution_count=3}\n``` {.julia .cell-code}\nusing ConformalPrediction\nconf_mach = conformal_machine(mach)\ncalibrate!(conf_mach, selectrows(X, calibration), y[calibration])\n```\n:::\n\n\nUsing the generic `predict` method we can generate prediction sets like so:\n\n::: {.cell execution_count=4}\n``` {.julia .cell-code}\npredict(conf_mach, selectrows(X, rand(test,5)))\n```\n\n::: {.cell-output .cell-output-display execution_count=5}\n```\n╭──────────────────────────────────────────────────────────────────────────╮\n│ │\n│ (1) Pair[1 => missing, 2 => 0.6448661054062889, 3 => missing] │\n│ (2) Pair[1 => missing, 2 => missing, 3 => 0.8197529347049547] │\n│ (3) Pair[1 => missing, 2 => 0.8229512785953512, 3 => missing] │\n│ (4) Pair[1 => missing, 2 => 0.7858778376049668, 3 => missing] │\n│ (5) Pair[1 => missing, 2 => missing, 3 => 0.8197529347049547] │\n│ │\n│ │\n╰────────────────────────────────────────────────────────────── 5 items ───╯\n```\n:::\n:::\n\n\n", + "markdown": "---\nformat:\n commonmark:\n variant: '-raw_html'\n wrap: none\n self-contained: true\ncrossref:\n fig-prefix: Figure\n tbl-prefix: Table\nbibliography: 'https://raw.githubusercontent.com/pat-alt/bib/main/bib.bib'\noutput: asis\nexecute:\n output: false\n freeze: auto\n eval: true\n echo: true\n---\n\n# Classification Tutorial\n\n[INCOMPLETE]\n\nWe firstly generate some synthetic data with three classes and partition it into a training set, a calibration set and a test set:\n\n::: {.cell execution_count=1}\n``` {.julia .cell-code}\nusing MLJ\nX, y = MLJ.make_blobs(1000, 2, centers=3, cluster_std=2)\ntrain, calibration, test = partition(eachindex(y), 0.4, 0.4)\n```\n:::\n\n\nFollowing the standard [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/) procedure, we train a decision tree for the classification task:\n\n::: {.cell execution_count=2}\n``` {.julia .cell-code}\nEvoTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree\nmodel = DecisionTreeClassifier() \nmodel = machine(model, X, y)\nfit!(model, rows=train)\n```\n:::\n\n\nNext we instantiate our conformal model and calibrate using the calibration data:\n\n::: {.cell execution_count=3}\n``` {.julia .cell-code}\nusing ConformalPrediction\nconformal_model = conformal_model(model)\ncalibrate!(conf_model, selectrows(X, calibration), y[calibration])\n```\n:::\n\n\nUsing the generic `predict` method we can generate prediction sets like so:\n\n::: {.cell execution_count=4}\n``` {.julia .cell-code}\npredict(conf_model, selectrows(X, rand(test,5)))\n```\n\n::: {.cell-output .cell-output-display execution_count=5}\n```\n╭──────────────────────────────────────────────────────────────────────────╮\n│ │\n│ (1) Pair[1 => missing, 2 => 0.6448661054062889, 3 => missing] │\n│ (2) Pair[1 => missing, 2 => missing, 3 => 0.8197529347049547] │\n│ (3) Pair[1 => missing, 2 => 0.8229512785953512, 3 => missing] │\n│ (4) Pair[1 => missing, 2 => 0.7858778376049668, 3 => missing] │\n│ (5) Pair[1 => missing, 2 => missing, 3 => 0.8197529347049547] │\n│ │\n│ │\n╰────────────────────────────────────────────────────────────── 5 items ───╯\n```\n:::\n:::\n\n\n", "supporting": [ "simple_files" ], diff --git a/docs/Project.toml b/docs/Project.toml index 673c9e6..8efc189 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -2,7 +2,7 @@ ConformalPrediction = "98bfc277-1877-43dc-819b-a3e38c30242f" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5" +DecisionTree = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5" MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7" PlotThemes = "ccf2f8ad-2431-5c83-bf29-c5338b663b6a" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" diff --git a/docs/src/classification/simple.md b/docs/src/classification/simple.md index 40f958d..e153580 100644 --- a/docs/src/classification/simple.md +++ b/docs/src/classification/simple.md @@ -11,27 +11,27 @@ X, y = MLJ.make_blobs(1000, 2, centers=3, cluster_std=2) train, calibration, test = partition(eachindex(y), 0.4, 0.4) ``` -Following the standard [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/) procedure, we train a boosted tree for the classification task: +Following the standard [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/) procedure, we train a decision tree for the classification task: ``` julia -EvoTreeClassifier = @load EvoTreeClassifier pkg=EvoTrees -model = EvoTreeClassifier() +DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree +model = DecisionTreeClassifier() mach = machine(model, X, y) fit!(mach, rows=train) ``` -Next we instantiate our conformal machine and calibrate using the calibration data: +Next we instantiate our conformal model and calibrate using the calibration data: ``` julia using ConformalPrediction -conf_mach = conformal_machine(mach) -calibrate!(conf_mach, selectrows(X, calibration), y[calibration]) +conf_model = conformal_model(model) +calibrate!(conf_model, selectrows(X, calibration), y[calibration]) ``` Using the generic `predict` method we can generate prediction sets like so: ``` julia -predict(conf_mach, selectrows(X, rand(test,5))) +predict(conf_model, selectrows(X, rand(test,5))) ``` ╭──────────────────────────────────────────────────────────────────────────╮ diff --git a/docs/src/classification/simple.qmd b/docs/src/classification/simple.qmd index 4356164..2c45241 100644 --- a/docs/src/classification/simple.qmd +++ b/docs/src/classification/simple.qmd @@ -37,37 +37,37 @@ X, y = MLJ.make_moons(1000) train, calibration, test = partition(eachindex(y), 0.4, 0.4) ``` -Following the standard [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/) procedure, we train a boosted tree for the classification task: +Following the standard [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/) procedure, we train a decision tree for the classification task: ```{julia} -EvoTreeClassifier = @load EvoTreeClassifier pkg=EvoTrees -model = EvoTreeClassifier() +DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree +model = DecisionTreeClassifier() mach = machine(model, X, y) fit!(mach, rows=train) ``` -Next we instantiate our conformal machine and calibrate using the calibration data: +Next we instantiate our conformal model and calibrate using the calibration data: ```{julia} using ConformalPrediction -conf_mach = conformal_machine(mach) -calibrate!(conf_mach, selectrows(X, calibration), y[calibration]) +conf_model = conformal_model(model) +calibrate!(conf_model, selectrows(X, calibration), y[calibration]) ``` Using the generic `predict` method we can generate prediction sets like so: ```{julia} #| output: true -predict(conf_mach, selectrows(X, rand(test,5))) +predict(conf_model, selectrows(X, rand(test,5))) ``` ```{julia} coverage = 0.90 X_lim = map(i -> extrema(X[i]), eachindex(X)) X_grid = collect(map(x_lim -> range(x_lim..., length=30), X_lim)) -label_grid = [Int.(predict_mode(conf_mach.mach, [x1 x2]).refs) for x1 in X_grid[1], x2 in X_grid[2]] -p_grid = [MLJ.pdf.(predict(conf_mach.mach, [x1 x2]), 1) for x1 in X_grid[1], x2 in X_grid[2]] -C_grid = [sum([!ismissing(val) for (key,val) in predict(conf_mach, [x1 x2],coverage)[1]]) for x1 in X_grid[1], x2 in X_grid[2]] +label_grid = [Int.(predict_mode(conf_model.model, [x1 x2]).refs) for x1 in X_grid[1], x2 in X_grid[2]] +p_grid = [MLJ.pdf.(predict(conf_model.model, [x1 x2]), 1) for x1 in X_grid[1], x2 in X_grid[2]] +C_grid = [sum([!ismissing(val) for (key,val) in predict(conf_model, [x1 x2],coverage)[1]]) for x1 in X_grid[1], x2 in X_grid[2]] ``` ```{julia} diff --git a/docs/src/index.md b/docs/src/index.md index 87a499d..1499a46 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -44,27 +44,27 @@ X, y = MLJ.make_regression(1000, 2) train, calibration, test = partition(eachindex(y), 0.4, 0.4) ``` -We then train a boosted tree ([EvoTrees](https://github.com/Evovest/EvoTrees.jl)) and follow the standard [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/) training procedure. +We then train a decision tree ([DecisionTree](https://github.com/Evovest/DecisionTree.jl)) and follow the standard [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/) training procedure. ``` julia -EvoTreeRegressor = @load EvoTreeRegressor pkg=EvoTrees -model = EvoTreeRegressor() +DecisionTreeRegressor = @load DecisionTreeRegressor pkg=DecisionTree +model = DecisionTreeRegressor() mach = machine(model, X, y) fit!(mach, rows=train) ``` -To turn our conventional machine into a conformal machine, we just need to declare it as such and then calibrate it using our calibration data: +To turn our conventional machine into a conformal model, we just need to declare it as such and then calibrate it using our calibration data: ``` julia using ConformalPrediction -conf_mach = conformal_machine(mach) -calibrate!(conf_mach, selectrows(X, calibration), y[calibration]) +conf_model = conformal_model(model) +calibrate!(conf_model, selectrows(X, calibration), y[calibration]) ``` Predictions can then be computed using the generic `predict` method. The code below produces predictions a random subset of test samples: ``` julia -predict(conf_mach, selectrows(X, rand(test,5))) +predict(conf_model, selectrows(X, rand(test,5))) ``` ## Contribute 🛠 diff --git a/docs/src/intro.qmd b/docs/src/intro.qmd index 749eeb3..b8b3df5 100644 --- a/docs/src/intro.qmd +++ b/docs/src/intro.qmd @@ -42,28 +42,28 @@ X, y = MLJ.make_regression(1000, 2) train, calibration, test = partition(eachindex(y), 0.4, 0.4) ``` -We then train a boosted tree ([EvoTrees](https://github.com/Evovest/EvoTrees.jl)) and follow the standard [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/) training procedure. +We then train a decision tree ([DecisionTree](https://github.com/Evovest/DecisionTree.jl)) and follow the standard [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/) training procedure. ```{julia} -EvoTreeRegressor = @load EvoTreeRegressor pkg=EvoTrees -model = EvoTreeRegressor() +DecisionTreeRegressor = @load DecisionTreeRegressor pkg=DecisionTree +model = DecisionTreeRegressor() mach = machine(model, X, y) fit!(mach, rows=train) ``` -To turn our conventional machine into a conformal machine, we just need to declare it as such and then calibrate it using our calibration data: +To turn our conventional machine into a conformal model, we just need to declare it as such and then calibrate it using our calibration data: ```{julia} using ConformalPrediction -conf_mach = conformal_machine(mach) -calibrate!(conf_mach, selectrows(X, calibration), y[calibration]) +conf_model = conformal_model(model) +calibrate!(conf_model, selectrows(X, calibration), y[calibration]) ``` Predictions can then be computed using the generic `predict` method. The code below produces predictions a random subset of test samples: ```{julia} #| output: true -predict(conf_mach, selectrows(X, rand(test,5))) +predict(conf_model, selectrows(X, rand(test,5))) ``` ## Contribute 🛠 diff --git a/docs/src/reference.md b/docs/src/reference.md index 3f44f35..187bc3e 100644 --- a/docs/src/reference.md +++ b/docs/src/reference.md @@ -12,7 +12,7 @@ CurrentModule = ConformalPrediction ```@autodocs Modules = [ ConformalPrediction, - ConformalPrediction.ConformalMachines + ConformalPrediction.ConformalModels ] Private = false ``` @@ -22,7 +22,7 @@ Private = false ```@autodocs Modules = [ ConformalPrediction, - ConformalPrediction.ConformalMachines + ConformalPrediction.ConformalModels ] Public = false ``` \ No newline at end of file diff --git a/docs/src/regression/simple.qmd b/docs/src/regression/simple.qmd index 6afde5b..1841144 100644 --- a/docs/src/regression/simple.qmd +++ b/docs/src/regression/simple.qmd @@ -23,19 +23,19 @@ X, y = MLJ.make_regression(1000, 2) train, calibration, test = partition(eachindex(y), 0.4, 0.4) ``` -We then train a boosted tree ([EvoTrees](https://github.com/Evovest/EvoTrees.jl)) and follow the standard [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/) training procedure. +We then train a decision tree ([DecisionTree](https://github.com/Evovest/DecisionTree.jl)) and follow the standard [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/) training procedure. ```{julia} -EvoTreeRegressor = @load EvoTreeRegressor pkg=EvoTrees -model = EvoTreeRegressor() +DecisionTreeRegressor = @load DecisionTreeRegressor pkg=DecisionTree +model = DecisionTreeRegressor() mach = machine(model, X, y) fit!(mach, rows=train) ``` -To turn our conventional machine into a conformal machine, we just need to declare it as such and then calibrate it using our calibration data: +To turn our conventional machine into a conformal model, we just need to declare it as such and then calibrate it using our calibration data: ```{julia} using ConformalPrediction -conf_mach = conformal_machine(mach) -calibrate!(conf_mach, selectrows(X, calibration), y[calibration]) +conf_model = conformal_model(model) +calibrate!(conf_model, selectrows(X, calibration), y[calibration]) ``` \ No newline at end of file diff --git a/src/ConformalMachines/ConformalMachines.jl b/src/ConformalMachines/ConformalMachines.jl deleted file mode 100644 index 7605658..0000000 --- a/src/ConformalMachines/ConformalMachines.jl +++ /dev/null @@ -1,74 +0,0 @@ -module ConformalMachines - -"An abstract base type for conformal machines." -abstract type ConformalMachine end -export ConformalMachine - -""" - score(conf_mach::ConformalMachine, Xcal, ycal) - -Generic method for computing non-conformity scores for any conformal machine using calibration data. -""" -function score(conf_mach::ConformalMachine, Xcal, ycal) - # pass -end - -""" - prediction_region(conf_mach::ConformalMachine, Xnew, q̂::Real) - -Generic method for generating prediction regions from a calibrated conformal machine for a given quantile. -""" -function prediction_region(conf_mach::ConformalMachine, Xnew, q̂::Real) - # pass -end - -include("regression.jl") -export NaiveConformalRegressor - -include("classification.jl") -export LABELConformalClassifier - -"A container listing all available methods for conformal prediction." -const available_machines = Dict( - :regression => Dict( - :naive => NaiveConformalRegressor, - ), - :classification => Dict( - :label => LABELConformalClassifier, - ) -) - -# API -using MLJ -""" - conformal_machine(mach::Machine{<:Supervised}; method::Union{Nothing, Symbol}=nothing) - -A simple wrapper function that turns any `MLJ.Machine{<:Supervised}` into a conformal machine. It accepts an optional key argument that can be used to specify the desired method for conformal prediction. -""" -function conformal_machine(mach::Machine{<:Supervised}; method::Union{Nothing, Symbol}=nothing) - - is_classifier = target_scitype(mach.model) <: AbstractVector{<:Finite} - - if isnothing(method) - _method = is_classifier ? LABELConformalClassifier : NaiveConformalRegressor - else - if is_classifier - @assert method in keys(available_machines[:classification]) "$(method) is not a valid method for classifiers." - _method = available_machines[:classification][method] - else - @assert method in keys(available_machines[:regression]) "$(method) is not a valid method for regressors." - _method = available_machines[:regression][method] - end - end - - conf_mach = _method(mach, nothing) - - return conf_mach - -end -export conformal_machine - -# Other general methods: -export score, prediction_region - -end \ No newline at end of file diff --git a/src/ConformalMachines/classification.jl b/src/ConformalMachines/classification.jl deleted file mode 100644 index eaa4c8d..0000000 --- a/src/ConformalMachines/classification.jl +++ /dev/null @@ -1,31 +0,0 @@ -abstract type ConformalClassifier <: ConformalMachine end - -using MLJ - -# LABEL -"The LABEL method for conformal prediction is the simplest approach to classification." -mutable struct LABELConformalClassifier <: ConformalClassifier - mach::Machine{<:Supervised} - scores::Union{Nothing,AbstractArray} -end - -function LABELConformalClassifier(mach::Machine{<:Supervised}) - return LABELConformalClassifier(mach, nothing) -end - -using MLJ -function score(conf_mach::LABELConformalClassifier, Xcal, ycal) - ŷ = pdf.(MLJ.predict(conf_mach.mach, Xcal),ycal) - return @.(1.0 - ŷ) -end - -function prediction_region(conf_mach::LABELConformalClassifier, Xnew, q̂::Real) - L = levels(conf_mach.mach.data[2]) - ŷnew = MLJ.pdf(MLJ.predict(conf_mach.mach, Xnew), L) - # Could rephrase in sense of hypothesis test where - # H_0: Label is in prediction set. - # H_1: Label is not in prediction set. - ŷnew = map(x -> collect(key => 1-val <= q̂::Real ? val : missing for (key,val) in zip(L,x)),eachrow(ŷnew)) - return ŷnew -end - diff --git a/src/ConformalMachines/regression.jl b/src/ConformalMachines/regression.jl deleted file mode 100644 index 948e773..0000000 --- a/src/ConformalMachines/regression.jl +++ /dev/null @@ -1,25 +0,0 @@ -abstract type ConformalRegressor <: ConformalMachine end - -using MLJ - -# Naive -"The **Naive** method for conformal prediction is the simplest approach to regression." -mutable struct NaiveConformalRegressor <: ConformalRegressor - mach::Machine{<:Supervised} - scores::Union{Nothing,AbstractArray} -end - -function NaiveConformalRegressor(mach::Machine{<:Supervised}) - return NaiveConformalRegressor(mach, nothing) -end - -function score(conf_mach::NaiveConformalRegressor, Xcal, ycal) - ŷ = MLJ.predict(conf_mach.mach, Xcal) - return @.(abs(ŷ - ycal)) -end - -function prediction_region(conf_mach::NaiveConformalRegressor, Xnew, q̂::Real) - ŷnew = MLJ.predict(conf_mach.mach, Xnew) - ŷnew = map(x -> ["lower" => x .- q̂, "upper" => x .+ q̂],eachrow(ŷnew)) - return ŷnew -end \ No newline at end of file diff --git a/src/ConformalModels/ConformalModels.jl b/src/ConformalModels/ConformalModels.jl new file mode 100644 index 0000000..3799853 --- /dev/null +++ b/src/ConformalModels/ConformalModels.jl @@ -0,0 +1,48 @@ +module ConformalModels + +using MLJ +import MLJModelInterface as MMI +import MLJModelInterface: predict, fit, save, restore +import MLJBase + +"An abstract base type for conformal models." +abstract type ConformalModel <: MMI.Model end +abstract type InductiveConformalModel <: ConformalModel end +abstract type TransductiveConformalModel <: ConformalModel end +export ConformalModel, InductiveConformalModel, TransductiveConformalModel + +include("conformal_models.jl") + +include("inductive_regression.jl") +include("transductive_regression.jl") +export NaiveRegressor, SimpleInductiveRegressor + +include("inductive_classification.jl") +include("transductive_classification.jl") +export NaiveClassifier, SimpleInductiveClassifier + +"A container listing all available methods for conformal prediction." +const available_models = Dict( + :regression => Dict( + :transductive => Dict( + :naive => NaiveRegressor, + ), + :inductive => Dict( + :simple => SimpleInductiveRegressor, + ), + ), + :classification => Dict( + :transductive => Dict( + :naive => NaiveClassifier, + ), + :inductive => Dict( + :simple => SimpleInductiveClassifier, + ), + ) +) +export available_models + +# Other general methods: +export score, prediction_region + +end \ No newline at end of file diff --git a/src/ConformalModels/conformal_models.jl b/src/ConformalModels/conformal_models.jl new file mode 100644 index 0000000..92033ff --- /dev/null +++ b/src/ConformalModels/conformal_models.jl @@ -0,0 +1,126 @@ +# Main API call to wrap model: +""" + conformal_model(model::Supervised; method::Union{Nothing, Symbol}=nothing) + +A simple wrapper function that turns any `modeline{<:Supervised}` into a conformal model. It accepts an optional key argument that can be used to specify the desired method for conformal prediction. +""" +function conformal_model(model::Supervised; method::Union{Nothing, Symbol}=nothing) + + is_classifier = target_scitype(model) <: AbstractVector{<:Finite} + + if isnothing(method) + _method = is_classifier ? SimpleInductiveClassifier : NaiveRegressor + else + if is_classifier + classification_methods = merge(values(available_models[:classification])...) + @assert method in keys(classification_methods) "$(method) is not a valid method for classifiers." + _method = classification_methods[method] + else + regression_methods = merge(values(available_models[:regression])...) + @assert method in keys(regression_methods) "$(method) is not a valid method for regressors." + _method = regression_methods[method] + end + end + + conf_model = _method(model, nothing) + + return conf_model + +end +export conformal_model + +# Training +""" + fit(conf_model::TransductiveConformalModel, verbosity, X, y) + +Wrapper function to fit the underlying MLJ model and compute nonconformity scores in one single call. This method is only applicable to Transductive Conformal Prediction. +""" +function MMI.fit(conf_model::TransductiveConformalModel, verbosity, X, y) + fitresult, cache, report = MMI.fit(conf_model.model, verbosity, MMI.reformat(conf_model.model, X, y)...) + conf_model.fitresult = fitresult + # Use training data to compute nonconformity scores: + conf_model.scores = sort(ConformalModels.score(conf_model, X, y), rev=true) # non-conformity scores + return (fitresult, cache, report) +end + +""" + fit(conf_model::InductiveConformalModel, verbosity, X, y) + +Wrapper function to fit the underlying MLJ model. For Inductive Conformal Prediction the underlying model is fitted on the *proper training set*. The `fitresult` is assigned to the model instance. Computation of nonconformity scores requires a separate calibration step involving a *calibration data set* (see [`calibrate!`](@ref)). +""" +function MMI.fit(conf_model::InductiveConformalModel, verbosity, X, y) + fitresult, cache, report = MMI.fit(conf_model.model, verbosity, MMI.reformat(conf_model.model, X, y)...) + conf_model.fitresult = fitresult + return (fitresult, cache, report) +end +export fit + +# Calibration +""" + calibrate!(conf_model::InductiveConformalModel, Xcal, ycal) + +Calibrates a Inductive Conformal Model using calibration data. +""" +function calibrate!(conf_model::InductiveConformalModel, Xcal, ycal) + @assert !isnothing(conf_model.fitresult) "Cannot calibrate a model that has not been fitted." + conf_model.scores = sort(ConformalModels.score(conf_model, Xcal, ycal), rev=true) # non-conformity scores +end + +export calibrate! + +using Statistics +""" + empirical_quantile(conf_model::ConformalModel, coverage::AbstractFloat=0.95) + +Computes the empirical quantile `q̂` of the calibrated conformal scores for a user chosen coverage rate `(1-α)`. +""" +function empirical_quantile(conf_model::ConformalModel, coverage::AbstractFloat=0.95) + @assert 0.0 <= coverage <= 1.0 "Coverage out of [0,1] range." + @assert !isnothing(conf_model.scores) "conformal model has not been calibrated." + n = length(conf_model.scores) + p̂ = ceil(((n+1) * coverage)) / n + p̂ = clamp(p̂, 0.0, 1.0) + q̂ = Statistics.quantile(conf_model.scores, p̂) + return q̂ +end +export empirical_quantile + +# Prediction +""" + MMI.predict(conf_model::ConformalModel, fitresult, Xnew) + +Compulsory generic `predict` method of MMI. Simply wraps the underlying model and apply generic method to underlying model. +""" +function MMI.predict(conf_model::ConformalModel, fitresult, Xnew) + yhat = predict(conf_model.model, fitresult, MMI.reformat(conf_model.model, Xnew)...) + return yhat +end + +# Conformal prediction through dispatch: +""" + MMI.predict(conf_model::ConformalModel, Xnew, coverage::AbstractFloat=0.95) + +Computes the conformal prediction for any calibrated conformal model and new data `Xnew`. The default coverage ratio `(1-α)` is set to 95%. +""" +function MMI.predict(conf_model::ConformalModel, Xnew, coverage::AbstractFloat=0.95) + q̂ = empirical_quantile(conf_model, coverage) + return ConformalModels.prediction_region(conf_model, Xnew, q̂) +end + +""" + score(conf_model::ConformalModel, Xcal, ycal) + +Generic method for computing non-conformity scores for any conformal model using calibration data. +""" +function score(conf_model::ConformalModel, Xcal, ycal) + # pass +end + +""" + prediction_region(conf_model::ConformalModel, Xnew, q̂::Real) + +Generic method for generating prediction regions from a calibrated conformal model for a given quantile. +""" +function prediction_region(conf_model::ConformalModel, Xnew, q̂::Real) + # pass +end \ No newline at end of file diff --git a/src/ConformalModels/inductive_classification.jl b/src/ConformalModels/inductive_classification.jl new file mode 100644 index 0000000..354ed0a --- /dev/null +++ b/src/ConformalModels/inductive_classification.jl @@ -0,0 +1,29 @@ +"A base type for Inductive Conformal Classifiers." +abstract type InductiveConformalClassifier <: InductiveConformalModel end + +# Simple +"The `SimpleInductiveClassifier` is the simplest approach to Inductive Conformal Classification. Contrary to the [`NaiveClassifier`](@ref) it computes nonconformity scores using a designated calibration dataset." +mutable struct SimpleInductiveClassifier{Model <: Supervised} <: InductiveConformalClassifier + model::Model + fitresult::Any + scores::Union{Nothing,AbstractArray} +end + +function SimpleInductiveClassifier(model::Supervised, fitresult=nothing) + return SimpleInductiveClassifier(model, fitresult, nothing) +end + + +function score(conf_model::SimpleInductiveClassifier, Xcal, ycal) + ŷ = pdf.(MMI.predict(conf_model.model, conf_model.fitresult, Xcal),ycal) + return @.(1.0 - ŷ) +end + +function prediction_region(conf_model::SimpleInductiveClassifier, Xnew, q̂::Real) + p̂ = MMI.predict(conf_model.model, conf_model.fitresult, Xnew) + L = p̂.decoder.classes + ŷnew = pdf(p̂, L) + ŷnew = map(x -> collect(key => 1-val <= q̂::Real ? val : missing for (key,val) in zip(L,x)),eachrow(ŷnew)) + return ŷnew +end + diff --git a/src/ConformalModels/inductive_regression.jl b/src/ConformalModels/inductive_regression.jl new file mode 100644 index 0000000..4dfc537 --- /dev/null +++ b/src/ConformalModels/inductive_regression.jl @@ -0,0 +1,24 @@ +"A base type for Inductive Conformal Regressors." +abstract type InductiveConformalRegressor <: InductiveConformalModel end + +"The `SimpleInductiveRegressor` is the simplest approach to Inductive Conformal Regression. Contrary to the [`NaiveRegressor`](@ref) it computes nonconformity scores using a designated calibration dataset." +mutable struct SimpleInductiveRegressor{Model <: Supervised} <: InductiveConformalRegressor + model::Model + fitresult::Any + scores::Union{Nothing,AbstractArray} +end + +function SimpleInductiveRegressor(model::Supervised, fitresult=nothing) + return SimpleInductiveRegressor(model, fitresult, nothing) +end + +function score(conf_model::SimpleInductiveRegressor, Xcal, ycal) + ŷ = MMI.predict(conf_model.model, conf_model.fitresult, Xcal) + return @.(abs(ŷ - ycal)) +end + +function prediction_region(conf_model::SimpleInductiveRegressor, Xnew, q̂::Real) + ŷnew = MMI.predict(conf_model.model, conf_model.fitresult, Xnew) + ŷnew = map(x -> ["lower" => x .- q̂, "upper" => x .+ q̂],eachrow(ŷnew)) + return ŷnew +end \ No newline at end of file diff --git a/src/ConformalModels/transductive_classification.jl b/src/ConformalModels/transductive_classification.jl new file mode 100644 index 0000000..28727f4 --- /dev/null +++ b/src/ConformalModels/transductive_classification.jl @@ -0,0 +1,28 @@ +"A base type for Transductive Conformal Classifiers." +abstract type TransductiveConformalClassifier <: TransductiveConformalModel end + +# Simple +"The `NaiveClassifier` is the simplest approach to Inductive Conformal Classification. Contrary to the [`NaiveClassifier`](@ref) it computes nonconformity scores using a designated calibration dataset." +mutable struct NaiveClassifier{Model <: Supervised} <: TransductiveConformalClassifier + model::Model + fitresult::Any + scores::Union{Nothing,AbstractArray} +end + +function NaiveClassifier(model::Supervised, fitresult=nothing) + return NaiveClassifier(model, fitresult, nothing) +end + + +function score(conf_model::NaiveClassifier, Xcal, ycal) + ŷ = pdf.(MMI.predict(conf_model.model, conf_model.fitresult, Xcal),ycal) + return @.(1.0 - ŷ) +end + +function prediction_region(conf_model::NaiveClassifier, Xnew, q̂::Real) + p̂ = MMI.predict(conf_model.model, conf_model.fitresult, Xnew) + L = p̂.decoder.classes + ŷnew = pdf(p̂, L) + ŷnew = map(x -> collect(key => 1-val <= q̂::Real ? val : missing for (key,val) in zip(L,x)),eachrow(ŷnew)) + return ŷnew +end \ No newline at end of file diff --git a/src/ConformalModels/transductive_regression.jl b/src/ConformalModels/transductive_regression.jl new file mode 100644 index 0000000..2e3e513 --- /dev/null +++ b/src/ConformalModels/transductive_regression.jl @@ -0,0 +1,27 @@ +"A base type for Transductive Conformal Regressors." +abstract type TransductiveConformalRegressor <: TransductiveConformalModel end + +# Naive +"The `NaiveRegressor` for conformal prediction is the simplest approach to conformal regression. It computes nonconformity scores by simply using the training data." +mutable struct NaiveRegressor{Model <: Supervised} <: TransductiveConformalRegressor + model::Model + fitresult::Any + scores::Union{Nothing,AbstractArray} +end + +function NaiveRegressor(model::Supervised, fitresult=nothing) + return NaiveRegressor(model, fitresult, nothing) +end + +function score(conf_model::NaiveRegressor, Xcal, ycal) + ŷ = MMI.predict(conf_model.model, conf_model.fitresult, Xcal) + return @.(abs(ŷ - ycal)) +end + +function prediction_region(conf_model::NaiveRegressor, Xnew, q̂::Real) + ŷnew = MMI.predict(conf_model.model, conf_model.fitresult, Xnew) + ŷnew = map(x -> ["lower" => x .- q̂, "upper" => x .+ q̂],eachrow(ŷnew)) + return ŷnew +end + + diff --git a/src/ConformalPrediction.jl b/src/ConformalPrediction.jl index 71fd18d..d081d08 100644 --- a/src/ConformalPrediction.jl +++ b/src/ConformalPrediction.jl @@ -1,51 +1,11 @@ module ConformalPrediction -# Conformal Machines -include("ConformalMachines/ConformalMachines.jl") -using .ConformalMachines -export conformal_machine -export NaiveConformalRegressor -export LABELConformalClassifier - -# Calibration -""" - calibrate!(conf_mach::ConformalMachine, Xcal, ycal) - -Calibrates a conformal machine using calibration data. -""" -function calibrate!(conf_mach::ConformalMachine, Xcal, ycal) - conf_mach.scores = sort(ConformalMachines.score(conf_mach, Xcal, ycal), rev=true) # non-conformity scores -end -export calibrate! - -using Statistics -""" - empirical_quantile(conf_mach::ConformalMachine, coverage::AbstractFloat=0.95) - -Computes the empirical quantile `q̂` of the calibrated conformal scores for a user chosen coverage rate `(1-α)`. -""" -function empirical_quantile(conf_mach::ConformalMachine, coverage::AbstractFloat=0.95) - @assert 0.0 <= coverage <= 1.0 "Coverage out of [0,1] range." - @assert !isnothing(conf_mach.scores) "Conformal machine has not been calibrated." - n = length(conf_mach.scores) - p̂ = ceil(((n+1) * coverage)) / n - p̂ = clamp(p̂, 0.0, 1.0) - q̂ = Statistics.quantile(conf_mach.scores, p̂) - return q̂ -end -export empirical_quantile - -# Prediction -import MLJ: predict -""" - predict(conf_mach::ConformalMachine, Xnew; coverage=0.95) - -Computes the conformal prediction for any calibrated conformal machine and new data `Xnew`. The default coverage ratio `(1-α)` is set to 95%. -""" -function predict(conf_mach::ConformalMachine, Xnew, coverage::AbstractFloat=0.95) - q̂ = empirical_quantile(conf_mach, coverage) - return ConformalMachines.prediction_region(conf_mach, Xnew, q̂) -end -export predict +# conformal models +include("ConformalModels/ConformalModels.jl") +using .ConformalModels +export conformal_model, fit, calibrate! +export NaiveRegressor, SimpleInductiveRegressor +export NaiveClassifier, SimpleInductiveClassifier +export available_models end diff --git a/test/Manifest.toml b/test/Manifest.toml index fe1c687..d97b7fe 100644 --- a/test/Manifest.toml +++ b/test/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.8.1" manifest_format = "2.0" -project_hash = "28287b6138ddd0377076ad3d212b4d272ea0edc0" +project_hash = "1a94aedbb255cbd4580adc0349d35c2d43143bc3" [[deps.ARFFFiles]] deps = ["CategoricalArrays", "Dates", "Parsers", "Tables"] @@ -16,6 +16,11 @@ git-tree-sha1 = "69f7020bd72f069c219b5e8c236c1fa90d2cb409" uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" version = "1.2.1" +[[deps.AbstractTrees]] +git-tree-sha1 = "5c0b629df8a5566a06f5fef5100b53ea56e465a0" +uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" +version = "0.4.2" + [[deps.Adapt]] deps = ["LinearAlgebra"] git-tree-sha1 = "195c5505521008abea5aee4f96930717958eac6f" @@ -26,6 +31,36 @@ version = "3.4.0" uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" version = "1.1.1" +[[deps.ArrayInterface]] +deps = ["ArrayInterfaceCore", "Compat", "IfElse", "LinearAlgebra", "Static"] +git-tree-sha1 = "d6173480145eb632d6571c148d94b9d3d773820e" +uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +version = "6.0.23" + +[[deps.ArrayInterfaceCore]] +deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] +git-tree-sha1 = "5bb0f8292405a516880a3809954cb832ae7a31c5" +uuid = "30b0a656-2188-435a-8636-2ec0e6a096e2" +version = "0.1.20" + +[[deps.ArrayInterfaceOffsetArrays]] +deps = ["ArrayInterface", "OffsetArrays", "Static"] +git-tree-sha1 = "c49f6bad95a30defff7c637731f00934c7289c50" +uuid = "015c0d05-e682-4f19-8f0a-679ce4c54826" +version = "0.1.6" + +[[deps.ArrayInterfaceStaticArrays]] +deps = ["Adapt", "ArrayInterface", "ArrayInterfaceStaticArraysCore", "LinearAlgebra", "Static", "StaticArrays"] +git-tree-sha1 = "efb000a9f643f018d5154e56814e338b5746c560" +uuid = "b0d46f97-bff5-4637-a19a-dd75974142cd" +version = "0.1.4" + +[[deps.ArrayInterfaceStaticArraysCore]] +deps = ["Adapt", "ArrayInterfaceCore", "LinearAlgebra", "StaticArraysCore"] +git-tree-sha1 = "a1e2cf6ced6505cbad2490532388683f1e88c3ed" +uuid = "dd5226c6-a4d4-4bc7-8575-46859f9c95b9" +version = "0.1.0" + [[deps.Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" @@ -48,6 +83,12 @@ git-tree-sha1 = "84259bb6172806304b9101094a7cc4bc6f56dbc6" uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35" version = "0.1.5" +[[deps.BitTwiddlingConvenienceFunctions]] +deps = ["Static"] +git-tree-sha1 = "eaee37f76339077f86679787a71990c4e465477f" +uuid = "62783981-4cbd-42fc-bca8-16325de8dc4b" +version = "0.1.4" + [[deps.Bzip2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] git-tree-sha1 = "19a35467a82e236ff51bc17a3a44b69ef35185a2" @@ -59,6 +100,12 @@ git-tree-sha1 = "eb4cb44a499229b3b8426dcfb5dd85333951ff90" uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" version = "0.4.2" +[[deps.CPUSummary]] +deps = ["CpuId", "IfElse", "Static"] +git-tree-sha1 = "9bdd5aceea9fa109073ace6b430a24839d79315e" +uuid = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9" +version = "0.1.27" + [[deps.CUDA]] deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"] git-tree-sha1 = "49549e2c28ffb9cc77b3689dc10e46e6271e9452" @@ -95,6 +142,12 @@ git-tree-sha1 = "38f7a08f19d8810338d4f5085211c7dfa5d5bdd8" uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" version = "0.1.4" +[[deps.CloseOpenIntervals]] +deps = ["ArrayInterface", "Static"] +git-tree-sha1 = "5522c338564580adf5d58d91e43a55db0fa5fb39" +uuid = "fb6a15b2-703c-40df-9091-08a04967cfa9" +version = "0.1.10" + [[deps.CodecZlib]] deps = ["TranscodingStreams", "Zlib_jll"] git-tree-sha1 = "ded953804d019afa9a3f98981d99b33e3db7b6da" @@ -130,11 +183,17 @@ git-tree-sha1 = "08c8b6831dc00bfea825826be0bc8336fc369860" uuid = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" version = "1.0.2" +[[deps.CommonSubexpressions]] +deps = ["MacroTools", "Test"] +git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.3.0" + [[deps.Compat]] deps = ["Dates", "LinearAlgebra", "UUIDs"] -git-tree-sha1 = "5856d3031cdb1f3b2b6340dfdc66b6d9a149a374" +git-tree-sha1 = "3ca828fe1b75fa84b021a7860bd039eaea84d2f2" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.2.0" +version = "4.3.0" [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] @@ -157,15 +216,21 @@ git-tree-sha1 = "d05d9e7b7aedff4e5b51a029dced05cfb6125781" uuid = "d38c429a-6771-53c6-b99e-75d170b6e991" version = "0.6.2" +[[deps.CpuId]] +deps = ["Markdown"] +git-tree-sha1 = "fcbb72b032692610bfbdb15018ac16a36cf2e406" +uuid = "adafc99b-e345-5852-983c-f28acb93d879" +version = "0.3.1" + [[deps.Crayons]] git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" version = "4.1.1" [[deps.DataAPI]] -git-tree-sha1 = "1106fa7e1256b402a86a8e7b15c00c85036fef49" +git-tree-sha1 = "46d2680e618f8abd007bce0c3026cb0c4a8f2032" uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.11.0" +version = "1.12.0" [[deps.DataStructures]] deps = ["Compat", "InteractiveUtils", "OrderedCollections"] @@ -182,6 +247,12 @@ version = "1.0.0" deps = ["Printf"] uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" +[[deps.DecisionTree]] +deps = ["AbstractTrees", "DelimitedFiles", "LinearAlgebra", "Random", "ScikitLearnBase", "Statistics"] +git-tree-sha1 = "fb3f7ff27befb9877bee84076dd9173185d7d86a" +uuid = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb" +version = "0.11.2" + [[deps.DelimitedFiles]] deps = ["Mmap"] uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" @@ -192,6 +263,18 @@ git-tree-sha1 = "80c3e8639e3353e5d2912fb3a1916b8455e2494b" uuid = "b429d917-457f-4dbc-8f4c-0cc954292b1d" version = "0.4.0" +[[deps.DiffResults]] +deps = ["StaticArraysCore"] +git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "1.1.0" + +[[deps.DiffRules]] +deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] +git-tree-sha1 = "992a23afdb109d0d2f8802a30cf5ae4b1fe7ea68" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "1.11.1" + [[deps.Distances]] deps = ["LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI"] git-tree-sha1 = "3258d0659f812acde79e8a74b11f17ac06d0ca04" @@ -204,9 +287,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[deps.Distributions]] deps = ["ChainRulesCore", "DensityInterface", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test"] -git-tree-sha1 = "0d7d213133d948c56e8c2d9f4eab0293491d8e4a" +git-tree-sha1 = "04db820ebcfc1e053bd8cbb8d8bccf0ff3ead3f7" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.25.75" +version = "0.25.76" [[deps.DocStringExtensions]] deps = ["LibGit2"] @@ -238,10 +321,10 @@ uuid = "792122b4-ca99-40de-a6bc-6742525f08b6" version = "0.3.0" [[deps.EvoTrees]] -deps = ["BSON", "CUDA", "CategoricalArrays", "Distributions", "MLJModelInterface", "NetworkLayout", "Random", "RecipesBase", "SpecialFunctions", "Statistics", "StatsBase"] -git-tree-sha1 = "bb297d76065d6272781aaeae1aa1bf9217fde369" +deps = ["BSON", "CUDA", "CategoricalArrays", "Distributions", "LoopVectorization", "MLJModelInterface", "NetworkLayout", "Random", "RecipesBase", "Statistics", "StatsBase", "Tables"] +git-tree-sha1 = "2e01454a464cdb4ad82cc9b824ef762254c0e2ca" uuid = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5" -version = "0.11.0" +version = "0.12.0" [[deps.ExprTools]] git-tree-sha1 = "56559bbef6ca5ea0c0818fa5c90320398a6fbf8d" @@ -280,6 +363,12 @@ git-tree-sha1 = "8339d61043228fdd3eb658d86c926cb282ae72a8" uuid = "59287772-0a20-5a39-b81b-1366585eb4c0" version = "0.4.2" +[[deps.ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions", "StaticArrays"] +git-tree-sha1 = "187198a4ed8ccd7b5d99c41b69c679269ea2b2d4" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.32" + [[deps.FreeType]] deps = ["CEnum", "FreeType2_jll"] git-tree-sha1 = "cabd77ab6a6fdff49bfd24af2ebe76e6e018a2b4" @@ -292,12 +381,6 @@ git-tree-sha1 = "87eb71354d8ec1a96d4a7636bd57a7347dde3ef9" uuid = "d7e528f0-a631-5988-bf34-fe36492bcfd7" version = "2.10.4+0" -[[deps.FreeTypeAbstraction]] -deps = ["ColorVectorSpace", "Colors", "FreeType", "GeometryBasics"] -git-tree-sha1 = "38a92e40157100e796690421e34a11c107205c86" -uuid = "663a7486-cb36-511b-a19d-713bb74d65c9" -version = "0.10.0" - [[deps.Future]] deps = ["Random"] uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" @@ -334,9 +417,15 @@ version = "0.4.4" [[deps.HTTP]] deps = ["Base64", "CodecZlib", "Dates", "IniFile", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] -git-tree-sha1 = "4abede886fcba15cd5fd041fef776b230d004cee" +git-tree-sha1 = "e8c58d5f03b9d9eb9ed7067a2f34c7c371ab130b" uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" -version = "1.4.0" +version = "1.4.1" + +[[deps.HostCPUFeatures]] +deps = ["BitTwiddlingConvenienceFunctions", "IfElse", "Libdl", "Static"] +git-tree-sha1 = "b7b88a4716ac33fe31d6556c02fc60017594343c" +uuid = "3e5b6fbb-0976-4d2c-9146-d79de83f2fb0" +version = "0.1.8" [[deps.HypergeometricFunctions]] deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions", "Test"] @@ -344,6 +433,11 @@ git-tree-sha1 = "709d864e3ed6e3545230601f94e11ebc65994641" uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a" version = "0.3.11" +[[deps.IfElse]] +git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1" +uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" +version = "0.1.1" + [[deps.IniFile]] git-tree-sha1 = "f550e6e32074c939295eb5ea6de31849ac2c9625" uuid = "83e8ac13-25f8-5344-8a64-a9f2b223428f" @@ -355,9 +449,9 @@ uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" [[deps.InverseFunctions]] deps = ["Test"] -git-tree-sha1 = "b3364212fb5d870f724876ffcd34dd8ec6d98918" +git-tree-sha1 = "49510dfcb407e572524ba94aeae2fced1f3feb0f" uuid = "3587e190-3f89-42d0-90ee-14403ec27112" -version = "0.1.7" +version = "0.1.8" [[deps.InvertedIndices]] git-tree-sha1 = "bee5f1ef5bf65df56bdd2e40447590b272a5471f" @@ -415,15 +509,16 @@ git-tree-sha1 = "42938ab65e9ed3c3029a8d2c58382ca75bdab243" uuid = "a5e1c1ea-c99a-51d3-a14d-a9a37257b02d" version = "1.8.0" +[[deps.LayoutPointers]] +deps = ["ArrayInterface", "ArrayInterfaceOffsetArrays", "ArrayInterfaceStaticArrays", "LinearAlgebra", "ManualMemory", "SIMDTypes", "Static"] +git-tree-sha1 = "b67e749fb35530979839e7b4b606a97105fe4f1c" +uuid = "10f19ff3-798f-405d-979b-55457f8fc047" +version = "0.1.10" + [[deps.LazyArtifacts]] deps = ["Artifacts", "Pkg"] uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" -[[deps.LazyModules]] -git-tree-sha1 = "a560dd966b386ac9ae60bdd3a3d3a326062d3c3e" -uuid = "8cdb02fc-e678-4876-92c5-9defec4f444e" -version = "0.3.1" - [[deps.LibCURL]] deps = ["LibCURL_jll", "MozillaCACerts_jll"] uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" @@ -465,6 +560,12 @@ git-tree-sha1 = "5d4d2d9904227b8bd66386c1138cf4d5ffa826bf" uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" version = "0.4.9" +[[deps.LoopVectorization]] +deps = ["ArrayInterface", "ArrayInterfaceCore", "ArrayInterfaceOffsetArrays", "ArrayInterfaceStaticArrays", "CPUSummary", "ChainRulesCore", "CloseOpenIntervals", "DocStringExtensions", "ForwardDiff", "HostCPUFeatures", "IfElse", "LayoutPointers", "LinearAlgebra", "OffsetArrays", "PolyesterWeave", "SIMDDualNumbers", "SIMDTypes", "SLEEFPirates", "SnoopPrecompile", "SpecialFunctions", "Static", "ThreadingUtilities", "UnPack", "VectorizationBase"] +git-tree-sha1 = "39af6a1e398a29f568dc9fe469f459ad3aacb03b" +uuid = "bdcacae8-1622-11e9-2a5c-532679323890" +version = "0.12.133" + [[deps.LossFunctions]] deps = ["InteractiveUtils", "Markdown", "RecipesBase"] git-tree-sha1 = "53cd63a12f06a43eef6f4aafb910ac755c122be7" @@ -483,6 +584,12 @@ git-tree-sha1 = "f68deea1f25727f24a4afa9f941763e6fc44f5af" uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d" version = "0.20.19" +[[deps.MLJDecisionTreeInterface]] +deps = ["DecisionTree", "MLJModelInterface", "Random", "Tables"] +git-tree-sha1 = "d0d682ef8504e1ab705f10307c587239ebb20c4d" +uuid = "c6f25543-311c-4c74-83dc-3ea6d1015661" +version = "0.2.5" + [[deps.MLJEnsembles]] deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Distributed", "Distributions", "MLJBase", "MLJModelInterface", "ProgressMeter", "Random", "ScientificTypesBase", "StatsBase"] git-tree-sha1 = "ed2f724be26d0023cade9d59b55da93f528c3f26" @@ -497,15 +604,15 @@ version = "0.5.0" [[deps.MLJModelInterface]] deps = ["Random", "ScientificTypesBase", "StatisticalTraits"] -git-tree-sha1 = "16fa7c2e14aa5b3854bc77ab5f1dbe2cdc488903" +git-tree-sha1 = "0a36882e73833d60dac49b00d203f73acfd50b85" uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" -version = "1.6.0" +version = "1.7.0" [[deps.MLJModels]] deps = ["CategoricalArrays", "CategoricalDistributions", "Combinatorics", "Dates", "Distances", "Distributions", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "Markdown", "OrderedCollections", "Parameters", "Pkg", "PrettyPrinting", "REPL", "Random", "RelocatableFolders", "ScientificTypes", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "bce989ee5972ae420356fddb4a77e6fbc36798cd" +git-tree-sha1 = "147a8e7939601f8c37204addbbe29f2bcfb876a8" uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7" -version = "0.15.12" +version = "0.15.14" [[deps.MLJTuning]] deps = ["ComputationalResources", "Distributed", "Distributions", "LatinHypercubeSampling", "MLJBase", "ProgressMeter", "Random", "RecipesBase"] @@ -513,6 +620,17 @@ git-tree-sha1 = "77209966cc028c1d7730001dc32bffe17a198f29" uuid = "03970b2e-30c4-11ea-3135-d1576263f10f" version = "0.7.3" +[[deps.MacroTools]] +deps = ["Markdown", "Random"] +git-tree-sha1 = "42324d08725e200c23d4dfb549e0d5d89dede2d2" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.10" + +[[deps.ManualMemory]] +git-tree-sha1 = "bcaef4fc7a0cfe2cba636d84cda54b5e4e4ca3cd" +uuid = "d125e4d3-2237-4719-b19c-fa641b8a4667" +version = "0.1.8" + [[deps.MarchingCubes]] deps = ["SnoopPrecompile", "StaticArrays"] git-tree-sha1 = "ffc66942498a5f0d02b9e7b1b1af0f5873142cdc" @@ -563,6 +681,12 @@ version = "0.4.4" uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" version = "1.2.0" +[[deps.OffsetArrays]] +deps = ["Adapt"] +git-tree-sha1 = "f71d8950b724e9ff6110fc948dff5a329f901d64" +uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" +version = "1.12.8" + [[deps.OpenBLAS_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" @@ -581,15 +705,15 @@ version = "0.3.0" [[deps.OpenSSL]] deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"] -git-tree-sha1 = "02be9f845cb58c2d6029a6d5f67f4e0af3237814" +git-tree-sha1 = "ebe81469e9d7b471d7ddb611d9e147ea16de0add" uuid = "4d8831e6-92b7-49fb-bdf8-b643e874388c" -version = "1.1.3" +version = "1.2.1" [[deps.OpenSSL_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "e60321e3f2616584ff98f0a4f18d98ae6f89bbb3" +git-tree-sha1 = "a94dc0169bffbf7e5250fb7e1efb1a85b09105c7" uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" -version = "1.1.17+0" +version = "1.1.18+0" [[deps.OpenSpecFun_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] @@ -616,15 +740,21 @@ version = "0.12.3" [[deps.Parsers]] deps = ["Dates"] -git-tree-sha1 = "3d5bf43e3e8b412656404ed9466f1dcbf7c50269" +git-tree-sha1 = "595c0b811cf2bab8b0849a70d9bd6379cc1cfb52" uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "2.4.0" +version = "2.4.1" [[deps.Pkg]] deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" version = "1.8.0" +[[deps.PolyesterWeave]] +deps = ["BitTwiddlingConvenienceFunctions", "CPUSummary", "IfElse", "Static", "ThreadingUtilities"] +git-tree-sha1 = "b42fb2292fbbaed36f25d33a15c8cc0b4f287fcf" +uuid = "1d0040c9-8b98-4ee7-8388-3f51789ca0ad" +version = "0.1.10" + [[deps.Preferences]] deps = ["TOML"] git-tree-sha1 = "47e5f437cc0e7ef2ce8406ce1e7e24d44915f88d" @@ -638,9 +768,9 @@ version = "0.4.0" [[deps.PrettyTables]] deps = ["Crayons", "Formatting", "Markdown", "Reexport", "StringManipulation", "Tables"] -git-tree-sha1 = "9be26cbb85be86e293e2f65404139102c5c652d9" +git-tree-sha1 = "460d9e154365e058c4d886f6f7d6df5ffa1ea80e" uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" -version = "2.1.1" +version = "2.1.2" [[deps.Printf]] deps = ["Unicode"] @@ -691,9 +821,9 @@ version = "1.2.2" [[deps.RelocatableFolders]] deps = ["SHA", "Scratch"] -git-tree-sha1 = "22c5201127d7b243b9ee1de3b43c408879dff60f" +git-tree-sha1 = "90bc7a7c96410424509e4263e277e43250c05691" uuid = "05181044-ff0b-4ac5-8273-598c1e38db00" -version = "0.3.0" +version = "1.0.0" [[deps.Requires]] deps = ["UUIDs"] @@ -717,6 +847,23 @@ version = "0.3.0+0" uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" version = "0.7.0" +[[deps.SIMDDualNumbers]] +deps = ["ForwardDiff", "IfElse", "SLEEFPirates", "VectorizationBase"] +git-tree-sha1 = "dd4195d308df24f33fb10dde7c22103ba88887fa" +uuid = "3cdde19b-5bb0-4aaf-8931-af3e248e098b" +version = "0.1.1" + +[[deps.SIMDTypes]] +git-tree-sha1 = "330289636fb8107c5f32088d2741e9fd7a061a5c" +uuid = "94e857df-77ce-4151-89e5-788b33177be4" +version = "0.1.0" + +[[deps.SLEEFPirates]] +deps = ["IfElse", "Static", "VectorizationBase"] +git-tree-sha1 = "938c9ecffb28338a6b8b970bda0f3806a65e7906" +uuid = "476501e8-09a2-5ece-8869-fb82de89a1fa" +version = "0.6.36" + [[deps.ScientificTypes]] deps = ["CategoricalArrays", "ColorTypes", "Dates", "Distributions", "PrettyTables", "Reexport", "ScientificTypesBase", "StatisticalTraits", "Tables"] git-tree-sha1 = "82b2426c11fa6cb23bbfbe0d7378837a653ba44b" @@ -728,6 +875,12 @@ git-tree-sha1 = "a8e18eb383b5ecf1b5e6fc237eb39255044fd92b" uuid = "30f210dd-8aff-4c5f-94ba-8e64358c1161" version = "3.0.0" +[[deps.ScikitLearnBase]] +deps = ["LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "7877e55c1523a4b336b433da39c8e8c08d2f221f" +uuid = "6e75b9c4-186b-50bd-896f-2d2496a4843e" +version = "0.5.0" + [[deps.Scratch]] deps = ["Dates"] git-tree-sha1 = "f94f779c94e58bf9ea243e77a37e16d9de9126bd" @@ -772,11 +925,17 @@ git-tree-sha1 = "3be7d49667040add7ee151fefaf1f8c04c8c8276" uuid = "860ef19b-820b-49d6-a774-d7a799459cd3" version = "1.0.0" +[[deps.Static]] +deps = ["IfElse"] +git-tree-sha1 = "de4f0a4f049a4c87e4948c04acff37baf1be01a6" +uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" +version = "0.7.7" + [[deps.StaticArrays]] deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"] -git-tree-sha1 = "2189eb2c1f25cb3f43e5807f26aa864052e50c17" +git-tree-sha1 = "f86b3a049e5d05227b10e15dbb315c5b90f14988" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.5.8" +version = "1.5.9" [[deps.StaticArraysCore]] git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a" @@ -839,9 +998,9 @@ version = "1.0.1" [[deps.Tables]] deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits", "Test"] -git-tree-sha1 = "2d7164f7b8a066bcfa6224e67736ce0eb54aef5b" +git-tree-sha1 = "c79322d36826aa2f4fd8ecfa96ddb47b174ac78d" uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -version = "1.9.0" +version = "1.10.0" [[deps.Tar]] deps = ["ArgTools", "SHA"] @@ -858,6 +1017,12 @@ version = "0.1.1" deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +[[deps.ThreadingUtilities]] +deps = ["ManualMemory"] +git-tree-sha1 = "f8629df51cab659d70d2e5618a430b4d3f37f2c3" +uuid = "8290d209-cae3-49c0-8002-c8c24d57dab5" +version = "0.5.0" + [[deps.TimerOutputs]] deps = ["ExprTools", "Printf"] git-tree-sha1 = "9dfcb767e17b0849d6aaf85997c98a5aea292513" @@ -888,10 +1053,10 @@ version = "1.0.2" uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" [[deps.UnicodePlots]] -deps = ["ColorSchemes", "ColorTypes", "Contour", "Crayons", "Dates", "FileIO", "FreeTypeAbstraction", "LazyModules", "LinearAlgebra", "MarchingCubes", "NaNMath", "Printf", "Requires", "SnoopPrecompile", "SparseArrays", "StaticArrays", "StatsBase", "Unitful"] -git-tree-sha1 = "f2ac653d1b971c27f59c1ba88532ca3c259031e2" +deps = ["ColorSchemes", "ColorTypes", "Contour", "Crayons", "Dates", "FileIO", "FreeType", "LinearAlgebra", "MarchingCubes", "NaNMath", "Printf", "Requires", "SnoopPrecompile", "SparseArrays", "StaticArrays", "StatsBase", "Unitful"] +git-tree-sha1 = "8a6dcd44129de81cc760b9d8af6fba188d3a01a6" uuid = "b8865327-cd53-5732-bb35-84acbb429228" -version = "3.1.2" +version = "3.1.3" [[deps.Unitful]] deps = ["ConstructionBase", "Dates", "LinearAlgebra", "Random"] @@ -899,6 +1064,12 @@ git-tree-sha1 = "d57a4ed70b6f9ff1da6719f5f2713706d57e0d66" uuid = "1986cc42-f94f-5a68-af5c-568840ba703d" version = "1.12.0" +[[deps.VectorizationBase]] +deps = ["ArrayInterface", "CPUSummary", "HostCPUFeatures", "IfElse", "LayoutPointers", "Libdl", "LinearAlgebra", "SIMDTypes", "Static"] +git-tree-sha1 = "3bc5ea8fbf25f233c4c49c0a75f14b276d2f9a69" +uuid = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f" +version = "0.21.51" + [[deps.Zlib_jll]] deps = ["Libdl"] uuid = "83775a58-1f1d-513f-b197-d71354ab007a" diff --git a/test/Project.toml b/test/Project.toml index a59ac51..2e6ad33 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,6 @@ [deps] EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5" MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7" +MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661" +MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/classification.jl b/test/classification.jl index 1880cd6..36401b2 100644 --- a/test/classification.jl +++ b/test/classification.jl @@ -1,35 +1,63 @@ using MLJ X, y = MLJ.make_blobs(1000, 2, centers=2) train, calibration, test = partition(eachindex(y), 0.4, 0.4) -EvoTreeClassifier = @load EvoTreeClassifier pkg=EvoTrees -model = EvoTreeClassifier() +DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree +model = DecisionTreeClassifier() mach = machine(model, X, y) fit!(mach, rows=train) -available_machines = ConformalPrediction.ConformalMachines.available_machines[:classification] +available_models = ConformalPrediction.ConformalModels.available_models[:classification] @testset "Classification" begin - using ConformalPrediction - - @testset "Default" begin - conf_mach = conformal_machine(mach) - @test isnothing(conf_mach.scores) - @test typeof(conf_mach) <: ConformalPrediction.ConformalMachines.ConformalClassifier - calibrate!(conf_mach, selectrows(X, calibration), y[calibration]) - @test !isnothing(conf_mach.scores) - predict(conf_mach, selectrows(X, test)) + @testset "Inductive" begin + + for _method in keys(available_models[:inductive]) + @testset "Method: $(_method)" begin + conf_model = conformal_model(model; method=_method) + conf_model = available_models[:inductive][_method](model) + @test isnothing(conf_model.scores) + @test typeof(conf_model) <: ConformalPrediction.ConformalModels.InductiveConformalClassifier + + # No fitresult provided: + @test_throws AssertionError calibrate!(conf_model, selectrows(X, calibration), y[calibration]) + + # Use fitresult from machine: + conf_model.fitresult = mach.fitresult + calibrate!(conf_model, selectrows(X, calibration), y[calibration]) + + # Use generic fit() method: + conf_model.fitresult = nothing + _mach = machine(conf_model, X, y) + fit!(_mach, rows=train) + calibrate!(conf_model, selectrows(X, calibration), y[calibration]) + + @test !isnothing(conf_model.scores) + predict(conf_model, selectrows(X, test)) + end + end end - for _method in keys(available_machines) - @testset "Method: $(_method)" begin - conf_mach = conformal_machine(mach; method=_method) - conf_mach = available_machines[_method](mach) - @test isnothing(conf_mach.scores) - @test typeof(conf_mach) <: ConformalPrediction.ConformalMachines.ConformalClassifier - calibrate!(conf_mach, selectrows(X, calibration), y[calibration]) - @test !isnothing(conf_mach.scores) - predict(conf_mach, selectrows(X, test)) + @testset "Transductive" begin + + for _method in keys(available_models[:transductive]) + @testset "Method: $(_method)" begin + conf_model = conformal_model(model; method=_method) + conf_model = available_models[:transductive][_method](model) + @test isnothing(conf_model.scores) + @test typeof(conf_model) <: ConformalPrediction.ConformalModels.TransductiveConformalClassifier + + # Trying to use calibration data: + @test_throws MethodError calibrate!(conf_model, selectrows(X, calibration), y[calibration]) + + # Use generic fit() method: + _mach = machine(conf_model, X, y) + fit!(_mach, rows=train) + + @test !isnothing(conf_model.scores) + predict(conf_model, selectrows(X, test)) + end end + end -end +end \ No newline at end of file diff --git a/test/regression.jl b/test/regression.jl index 0e28e4a..ee8beae 100644 --- a/test/regression.jl +++ b/test/regression.jl @@ -1,35 +1,64 @@ using MLJ + X, y = MLJ.make_regression(1000, 2) train, calibration, test = partition(eachindex(y), 0.4, 0.4) -EvoTreeRegressor = @load EvoTreeRegressor pkg=EvoTrees -model = EvoTreeRegressor() +DecisionTreeRegressor = @load DecisionTreeRegressor pkg=DecisionTree +model = DecisionTreeRegressor() mach = machine(model, X, y) fit!(mach, rows=train) -available_machines = ConformalPrediction.ConformalMachines.available_machines[:regression] - -@testset "Classification" begin +available_models = ConformalPrediction.ConformalModels.available_models[:regression] - using ConformalPrediction +@testset "Regression" begin - @testset "Default" begin - conf_mach = conformal_machine(mach) - @test isnothing(conf_mach.scores) - @test typeof(conf_mach) <: ConformalPrediction.ConformalMachines.ConformalRegressor - calibrate!(conf_mach, selectrows(X, calibration), y[calibration]) - @test !isnothing(conf_mach.scores) - predict(conf_mach, selectrows(X, test)) + @testset "Inductive" begin + + for _method in keys(available_models[:inductive]) + @testset "Method: $(_method)" begin + conf_model = conformal_model(model; method=_method) + conf_model = available_models[:inductive][_method](model) + @test isnothing(conf_model.scores) + @test typeof(conf_model) <: ConformalPrediction.ConformalModels.InductiveConformalRegressor + + # No fitresult provided: + @test_throws AssertionError calibrate!(conf_model, selectrows(X, calibration), y[calibration]) + + # Use fitresult from machine: + conf_model.fitresult = mach.fitresult + calibrate!(conf_model, selectrows(X, calibration), y[calibration]) + + # Use generic fit() method: + conf_model.fitresult = nothing + _mach = machine(conf_model, X, y) + fit!(_mach, rows=train) + calibrate!(conf_model, selectrows(X, calibration), y[calibration]) + + @test !isnothing(conf_model.scores) + predict(conf_model, selectrows(X, test)) + end + end end - for _method in keys(available_machines) - @testset "Method: $(_method)" begin - conf_mach = conformal_machine(mach; method=_method) - conf_mach = available_machines[_method](mach) - @test isnothing(conf_mach.scores) - @test typeof(conf_mach) <: ConformalPrediction.ConformalMachines.ConformalRegressor - calibrate!(conf_mach, selectrows(X, calibration), y[calibration]) - @test !isnothing(conf_mach.scores) - predict(conf_mach, selectrows(X, test)) + @testset "Transductive" begin + + for _method in keys(available_models[:transductive]) + @testset "Method: $(_method)" begin + conf_model = conformal_model(model; method=_method) + conf_model = available_models[:transductive][_method](model) + @test isnothing(conf_model.scores) + @test typeof(conf_model) <: ConformalPrediction.ConformalModels.TransductiveConformalRegressor + + # Trying to use calibration data: + @test_throws MethodError calibrate!(conf_model, selectrows(X, calibration), y[calibration]) + + # Use generic fit() method: + _mach = machine(conf_model, X, y) + fit!(_mach, rows=train) + + @test !isnothing(conf_model.scores) + predict(conf_model, selectrows(X, test)) + end end + end end \ No newline at end of file