From 69f69f61b4eb047ec1f2e55c1b99975b7fd0304c Mon Sep 17 00:00:00 2001 From: Dennis Yatunin Date: Thu, 16 Nov 2023 11:25:14 -0800 Subject: [PATCH] Add matrix-free iterative solver and fix bugs in FieldNameSet --- Project.toml | 2 + benchmarks/bickleyjet/Manifest.toml | 27 +- docs/Manifest.toml | 8 +- docs/src/matrix_fields.md | 32 +- examples/Manifest.toml | 43 +- perf/Manifest.toml | 43 +- src/MatrixFields/MatrixFields.jl | 7 +- .../field_matrix_iterative_solver.jl | 499 +++++++++++++++ src/MatrixFields/field_matrix_solver.jl | 406 ++++++++---- src/MatrixFields/field_name_dict.jl | 148 ++++- src/MatrixFields/field_name_set.jl | 312 ++++----- src/MatrixFields/single_field_solver.jl | 2 +- src/MatrixFields/unrolled_functions.jl | 10 + test/MatrixFields/field_matrix_solvers.jl | 367 +++++++---- test/MatrixFields/field_names.jl | 601 +++++++++++------- 15 files changed, 1802 insertions(+), 705 deletions(-) create mode 100644 src/MatrixFields/field_matrix_iterative_solver.jl diff --git a/Project.toml b/Project.toml index 1dfb306ed9..f071b5fb8b 100644 --- a/Project.toml +++ b/Project.toml @@ -18,6 +18,7 @@ GilbertCurves = "88fa7841-ef32-4516-bb70-c6ec135699d9" HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" +KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Memoize = "c03570c3-d221-55d1-a50c-7939bbd78826" PkgVersion = "eebad327-c553-4316-9ea0-9fa01ccd7688" @@ -51,6 +52,7 @@ HDF5 = "0.16, 0.17" InteractiveUtils = "1" IntervalSets = "0.5, 0.6, 0.7" Krylov = "0.9" +KrylovKit = "0.6" LinearAlgebra = "1" Memoize = "0.4" PkgVersion = "0.1, 0.2, 0.3" diff --git a/benchmarks/bickleyjet/Manifest.toml b/benchmarks/bickleyjet/Manifest.toml index 0249528eae..910364dccf 100644 --- a/benchmarks/bickleyjet/Manifest.toml +++ b/benchmarks/bickleyjet/Manifest.toml @@ -9,15 +9,12 @@ deps = ["LinearAlgebra"] git-tree-sha1 = 
"d92ad398961a3ed262d8bf04a1a2b8340f915fef" uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" version = "1.5.0" +weakdeps = ["ChainRulesCore", "Test"] [deps.AbstractFFTs.extensions] AbstractFFTsChainRulesCoreExt = "ChainRulesCore" AbstractFFTsTestExt = "Test" - [deps.AbstractFFTs.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - [[deps.Adapt]] deps = ["LinearAlgebra", "Requires"] git-tree-sha1 = "02f731463748db57cc2ebfbd9fbc9ce8280d3433" @@ -150,6 +147,16 @@ git-tree-sha1 = "4b859a208b2397a7a623a03449e4636bdb17bcf2" uuid = "83423d85-b0ee-5818-9007-b63ccbeb887a" version = "1.16.1+1" +[[deps.ChainRulesCore]] +deps = ["Compat", "LinearAlgebra"] +git-tree-sha1 = "e0af648f0692ec1691b5d094b8724ba1346281cf" +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "1.18.0" +weakdeps = ["SparseArrays"] + + [deps.ChainRulesCore.extensions] + ChainRulesCoreSparseArraysExt = "SparseArrays" + [[deps.ClimaComms]] deps = ["CUDA", "MPI"] git-tree-sha1 = "57c054ddd4280ca8e2b5915ef1cf1395c4edbc78" @@ -157,7 +164,7 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d" version = "0.5.6" [[deps.ClimaCore]] -deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "Memoize", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "WeakValueDicts"] +deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LinearAlgebra", "Memoize", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "WeakValueDicts"] path = "../.." 
uuid = "d414da3d-4745-48bb-8d80-42e94e092884" version = "0.11.0" @@ -573,6 +580,12 @@ version = "0.9.13" [deps.KernelAbstractions.weakdeps] EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +[[deps.KrylovKit]] +deps = ["ChainRulesCore", "GPUArraysCore", "LinearAlgebra", "Printf"] +git-tree-sha1 = "1a5e1d9941c783b0119897d29f2eb665d876ecf3" +uuid = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" +version = "0.6.0" + [[deps.LAME_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] git-tree-sha1 = "f6250b16881adf048549549fba48b1161acdac8c" @@ -1097,13 +1110,11 @@ deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_j git-tree-sha1 = "e2cfc4012a19088254b3950b85c3c1d8882d864d" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" version = "2.3.1" +weakdeps = ["ChainRulesCore"] [deps.SpecialFunctions.extensions] SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" - [deps.SpecialFunctions.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - [[deps.Static]] deps = ["IfElse"] git-tree-sha1 = "f295e0a1da4ca425659c57441bcb59abb035a4bc" diff --git a/docs/Manifest.toml b/docs/Manifest.toml index ff68b0c818..09f41a145f 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -288,7 +288,7 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d" version = "0.5.6" [[deps.ClimaCore]] -deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "Memoize", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "WeakValueDicts"] +deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LinearAlgebra", "Memoize", "PkgVersion", "RecursiveArrayTools", 
"RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "WeakValueDicts"] path = ".." uuid = "d414da3d-4745-48bb-8d80-42e94e092884" version = "0.11.0" @@ -1177,6 +1177,12 @@ git-tree-sha1 = "17e462054b42dcdda73e9a9ba0c67754170c88ae" uuid = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" version = "0.9.4" +[[deps.KrylovKit]] +deps = ["ChainRulesCore", "GPUArraysCore", "LinearAlgebra", "Printf"] +git-tree-sha1 = "1a5e1d9941c783b0119897d29f2eb665d876ecf3" +uuid = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" +version = "0.6.0" + [[deps.LAME_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] git-tree-sha1 = "f6250b16881adf048549549fba48b1161acdac8c" diff --git a/docs/src/matrix_fields.md b/docs/src/matrix_fields.md index 3a1fa51b78..12c2d3abd6 100644 --- a/docs/src/matrix_fields.md +++ b/docs/src/matrix_fields.md @@ -34,8 +34,23 @@ FieldMatrixSolver field_matrix_solve! BlockDiagonalSolve BlockLowerTriangularSolve -SchurComplementSolve -ApproximateFactorizationSolve +BlockArrowheadSolve +BlockLUDecompositionSolve +LazyFieldMatrixSolverAlgorithm +StationaryIterativeSolve +ApproximateBlockArrowheadIterativeSolve +``` + +# Preconditioners + +```@docs +PreconditionerAlgorithm +MainDiagonalPreconditioner +BlockDiagonalPreconditioner +BlockArrowheadPreconditioner +BlockArrowheadSchurComplementPreconditioner +WeightedPreconditioner +CustomPreconditioner ``` ## Internals @@ -57,6 +72,19 @@ FieldNameTree FieldNameSet FieldNameDict field_vector_view +concrete_field_vector +lazy_main_diagonal +lazy_mul +LazySchurComplement +field_matrix_solver_cache +check_field_matrix_solver +run_field_matrix_solver! 
+solver_algorithm +lazy_preconditioner +preconditioner_cache +check_preconditioner +lazy_or_concrete_preconditioner +apply_preconditioner ``` ## Utilities diff --git a/examples/Manifest.toml b/examples/Manifest.toml index be11726bf2..4c484bd819 100644 --- a/examples/Manifest.toml +++ b/examples/Manifest.toml @@ -20,15 +20,12 @@ deps = ["LinearAlgebra"] git-tree-sha1 = "d92ad398961a3ed262d8bf04a1a2b8340f915fef" uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" version = "1.5.0" +weakdeps = ["ChainRulesCore", "Test"] [deps.AbstractFFTs.extensions] AbstractFFTsChainRulesCoreExt = "ChainRulesCore" AbstractFFTsTestExt = "Test" - [deps.AbstractFFTs.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - [[deps.AbstractTrees]] git-tree-sha1 = "faa260e4cb5aba097a73fab382dd4b5819d8ec8c" uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" @@ -208,15 +205,12 @@ deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_ git-tree-sha1 = "76582ae19006b1186e87dadd781747f76cead72c" uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" version = "5.1.1" +weakdeps = ["ChainRulesCore", "SpecialFunctions"] [deps.CUDA.extensions] ChainRulesCoreExt = "ChainRulesCore" SpecialFunctionsExt = "SpecialFunctions" - [deps.CUDA.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" - [[deps.CUDA_Driver_jll]] deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] git-tree-sha1 = "1e42ef1bdb45487ff28de16182c0df4920181dc3" @@ -241,6 +235,16 @@ git-tree-sha1 = "4b859a208b2397a7a623a03449e4636bdb17bcf2" uuid = "83423d85-b0ee-5818-9007-b63ccbeb887a" version = "1.16.1+1" +[[deps.ChainRulesCore]] +deps = ["Compat", "LinearAlgebra"] +git-tree-sha1 = "e0af648f0692ec1691b5d094b8724ba1346281cf" +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "1.18.0" +weakdeps = ["SparseArrays"] + + [deps.ChainRulesCore.extensions] + ChainRulesCoreSparseArraysExt 
= "SparseArrays" + [[deps.ClimaComms]] deps = ["CUDA", "MPI"] git-tree-sha1 = "57c054ddd4280ca8e2b5915ef1cf1395c4edbc78" @@ -248,7 +252,7 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d" version = "0.5.6" [[deps.ClimaCore]] -deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "Memoize", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "WeakValueDicts"] +deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LinearAlgebra", "Memoize", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "WeakValueDicts"] path = ".." 
uuid = "d414da3d-4745-48bb-8d80-42e94e092884" version = "0.11.0" @@ -506,15 +510,12 @@ deps = ["LinearAlgebra", "Statistics", "StatsAPI"] git-tree-sha1 = "5225c965635d8c21168e32a12954675e7bea1151" uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" version = "0.10.10" +weakdeps = ["ChainRulesCore", "SparseArrays"] [deps.Distances.extensions] DistancesChainRulesCoreExt = "ChainRulesCore" DistancesSparseArraysExt = "SparseArrays" - [deps.Distances.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" - [[deps.Distributed]] deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" @@ -959,6 +960,12 @@ git-tree-sha1 = "17e462054b42dcdda73e9a9ba0c67754170c88ae" uuid = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" version = "0.9.4" +[[deps.KrylovKit]] +deps = ["ChainRulesCore", "GPUArraysCore", "LinearAlgebra", "Printf"] +git-tree-sha1 = "1a5e1d9941c783b0119897d29f2eb665d876ecf3" +uuid = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" +version = "0.6.0" + [[deps.LAME_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] git-tree-sha1 = "f6250b16881adf048549549fba48b1161acdac8c" @@ -1213,16 +1220,12 @@ deps = ["ArrayInterface", "CPUSummary", "CloseOpenIntervals", "DocStringExtensio git-tree-sha1 = "0f5648fbae0d015e3abe5867bca2b362f67a5894" uuid = "bdcacae8-1622-11e9-2a5c-532679323890" version = "0.12.166" +weakdeps = ["ChainRulesCore", "ForwardDiff", "SpecialFunctions"] [deps.LoopVectorization.extensions] ForwardDiffExt = ["ChainRulesCore", "ForwardDiff"] SpecialFunctionsExt = "SpecialFunctions" - [deps.LoopVectorization.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" - SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" - [[deps.MKL_jll]] deps = ["Artifacts", "IntelOpenMP_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] git-tree-sha1 = "eb006abbd7041c28e0d16260e50a24f8f9104913" @@ -1876,13 +1879,11 @@ 
deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_j git-tree-sha1 = "e2cfc4012a19088254b3950b85c3c1d8882d864d" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" version = "2.3.1" +weakdeps = ["ChainRulesCore"] [deps.SpecialFunctions.extensions] SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" - [deps.SpecialFunctions.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - [[deps.SplittablesBase]] deps = ["Setfield", "Test"] git-tree-sha1 = "e08a62abc517eb79667d0a29dc08a3b589516bb5" diff --git a/perf/Manifest.toml b/perf/Manifest.toml index 4ba467f5d3..29a2dfc6f2 100644 --- a/perf/Manifest.toml +++ b/perf/Manifest.toml @@ -20,15 +20,12 @@ deps = ["LinearAlgebra"] git-tree-sha1 = "d92ad398961a3ed262d8bf04a1a2b8340f915fef" uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" version = "1.5.0" +weakdeps = ["ChainRulesCore", "Test"] [deps.AbstractFFTs.extensions] AbstractFFTsChainRulesCoreExt = "ChainRulesCore" AbstractFFTsTestExt = "Test" - [deps.AbstractFFTs.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - [[deps.AbstractTrees]] git-tree-sha1 = "faa260e4cb5aba097a73fab382dd4b5819d8ec8c" uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" @@ -189,15 +186,12 @@ deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_ git-tree-sha1 = "76582ae19006b1186e87dadd781747f76cead72c" uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" version = "5.1.1" +weakdeps = ["ChainRulesCore", "SpecialFunctions"] [deps.CUDA.extensions] ChainRulesCoreExt = "ChainRulesCore" SpecialFunctionsExt = "SpecialFunctions" - [deps.CUDA.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" - [[deps.CUDA_Driver_jll]] deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] git-tree-sha1 = "1e42ef1bdb45487ff28de16182c0df4920181dc3" @@ -222,6 +216,16 @@ git-tree-sha1 = 
"4b859a208b2397a7a623a03449e4636bdb17bcf2" uuid = "83423d85-b0ee-5818-9007-b63ccbeb887a" version = "1.16.1+1" +[[deps.ChainRulesCore]] +deps = ["Compat", "LinearAlgebra"] +git-tree-sha1 = "e0af648f0692ec1691b5d094b8724ba1346281cf" +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "1.18.0" +weakdeps = ["SparseArrays"] + + [deps.ChainRulesCore.extensions] + ChainRulesCoreSparseArraysExt = "SparseArrays" + [[deps.ClimaComms]] deps = ["CUDA", "MPI"] git-tree-sha1 = "57c054ddd4280ca8e2b5915ef1cf1395c4edbc78" @@ -229,7 +233,7 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d" version = "0.5.6" [[deps.ClimaCore]] -deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "Memoize", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "WeakValueDicts"] +deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LinearAlgebra", "Memoize", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "WeakValueDicts"] path = ".." 
uuid = "d414da3d-4745-48bb-8d80-42e94e092884" version = "0.11.0" @@ -505,15 +509,12 @@ deps = ["LinearAlgebra", "Statistics", "StatsAPI"] git-tree-sha1 = "5225c965635d8c21168e32a12954675e7bea1151" uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" version = "0.10.10" +weakdeps = ["ChainRulesCore", "SparseArrays"] [deps.Distances.extensions] DistancesChainRulesCoreExt = "ChainRulesCore" DistancesSparseArraysExt = "SparseArrays" - [deps.Distances.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" - [[deps.Distributed]] deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" @@ -993,6 +994,12 @@ git-tree-sha1 = "17e462054b42dcdda73e9a9ba0c67754170c88ae" uuid = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" version = "0.9.4" +[[deps.KrylovKit]] +deps = ["ChainRulesCore", "GPUArraysCore", "LinearAlgebra", "Printf"] +git-tree-sha1 = "1a5e1d9941c783b0119897d29f2eb665d876ecf3" +uuid = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" +version = "0.6.0" + [[deps.LAME_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] git-tree-sha1 = "f6250b16881adf048549549fba48b1161acdac8c" @@ -1247,16 +1254,12 @@ deps = ["ArrayInterface", "CPUSummary", "CloseOpenIntervals", "DocStringExtensio git-tree-sha1 = "0f5648fbae0d015e3abe5867bca2b362f67a5894" uuid = "bdcacae8-1622-11e9-2a5c-532679323890" version = "0.12.166" +weakdeps = ["ChainRulesCore", "ForwardDiff", "SpecialFunctions"] [deps.LoopVectorization.extensions] ForwardDiffExt = ["ChainRulesCore", "ForwardDiff"] SpecialFunctionsExt = "SpecialFunctions" - [deps.LoopVectorization.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" - SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" - [[deps.LoweredCodeUtils]] deps = ["JuliaInterpreter"] git-tree-sha1 = "60168780555f3e663c536500aa790b6368adc02a" @@ -1962,13 +1965,11 @@ deps = ["IrrationalConstants", "LogExpFunctions", 
"OpenLibm_jll", "OpenSpecFun_j git-tree-sha1 = "e2cfc4012a19088254b3950b85c3c1d8882d864d" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" version = "2.3.1" +weakdeps = ["ChainRulesCore"] [deps.SpecialFunctions.extensions] SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" - [deps.SpecialFunctions.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - [[deps.Static]] deps = ["IfElse"] git-tree-sha1 = "f295e0a1da4ca425659c57441bcb59abb035a4bc" diff --git a/src/MatrixFields/MatrixFields.jl b/src/MatrixFields/MatrixFields.jl index cbfa38dcb1..16ee884ff1 100644 --- a/src/MatrixFields/MatrixFields.jl +++ b/src/MatrixFields/MatrixFields.jl @@ -44,9 +44,11 @@ multiples of `LinearAlgebra.I`. This comes with the following functionality: module MatrixFields import CUDA -import LinearAlgebra: I, UniformScaling, Adjoint, AdjointAbsVec, mul!, inv +import LinearAlgebra: I, UniformScaling, Adjoint, AdjointAbsVec, mul!, inv, norm import StaticArrays: SMatrix, SVector import BandedMatrices: BandedMatrix, band, _BandedMatrix +import RecursiveArrayTools: recursive_bottom_eltype +import KrylovKit: eigsolve import ClimaComms import ..Utilities: PlusHalf, half @@ -96,8 +98,9 @@ include("unrolled_functions.jl") include("field_name.jl") include("field_name_set.jl") include("field_name_dict.jl") -include("field_matrix_solver.jl") include("single_field_solver.jl") +include("field_matrix_solver.jl") +include("field_matrix_iterative_solver.jl") function Base.show(io::IO, field::ColumnwiseBandMatrixField) print(io, eltype(field), "-valued Field") diff --git a/src/MatrixFields/field_matrix_iterative_solver.jl b/src/MatrixFields/field_matrix_iterative_solver.jl new file mode 100644 index 0000000000..b83d81fb71 --- /dev/null +++ b/src/MatrixFields/field_matrix_iterative_solver.jl @@ -0,0 +1,499 @@ +""" + PreconditionerAlgorithm + +Description of how to approximate a `FieldMatrix` or something similar like a +[`LazySchurComplement`](@ref) with a preconditioner `P` for which `P * x = 
b` is +easy to solve for `x`. If `P` is a diagonal matrix, then `x` can be computed as +`@. inv(P) * b`; otherwise, the `PreconditionerAlgorithm` must specify a +`FieldMatrixSolverAlgorithm` that can be used to solve `P * x = b` for `x`. + +# Interface + +Every subtype of `PreconditionerAlgorithm` must implement methods for the +following functions: +- [`solver_algorithm`](@ref) +- [`lazy_preconditioner`](@ref) +""" +abstract type PreconditionerAlgorithm end + +""" + solver_algorithm(P_alg) + +A `FieldMatrixSolverAlgorithm` that can be used to solve `P * x = b` for `x`, +where `P` is the preconditioner generated by the `PreconditionerAlgorithm` +`P_alg`. If `P_alg` is `nothing` instead of a `PreconditionerAlgorithm`, or if +`P` is a diagonal matrix (and no solver is required to invert it), this returns +`nothing`. +""" +solver_algorithm(::Nothing) = nothing + +is_diagonal(P_alg) = isnothing(solver_algorithm(P_alg)) + +""" + lazy_preconditioner(P_alg, A) + +Constructs an un-materialized `FieldMatrixBroadcasted` (or just a `FieldMatrix` +when possible) that approximates `A` according to the `PreconditionerAlgorithm` +`P_alg`. If `P_alg` is `nothing` instead of a `PreconditionerAlgorithm`, this +returns `one(A)`. +""" +lazy_preconditioner(::Nothing, A::FieldMatrix) = one(A) + +""" + preconditioner_cache(P_alg, A, b) + +Allocates the cache required to solve the equation `P * x = b`, where `P` is the +preconditioner generated by the `PreconditionerAlgorithm` `P_alg` for `A`. +""" +function preconditioner_cache(P_alg, A, b) + is_diagonal(P_alg) && return (;) + lazy_P = lazy_preconditioner(P_alg, A) + is_lazy_P_concrete = !(lazy_P isa FieldMatrixBroadcasted) + P = is_lazy_P_concrete ? lazy_P : Base.Broadcast.materialize(lazy_P) + P_if_needed = is_lazy_P_concrete ? 
(;) : (; P) + x = similar_to_x(P, b) + alg = solver_algorithm(P_alg) + cache = field_matrix_solver_cache(alg, P, b) + return (; P_if_needed..., b = similar(b), x, cache) +end + +""" + check_preconditioner(P_alg, P_cache, A, b) + +Checks that `P` is compatible with `b` in the equation `P * x = b`, where `P` is +the preconditioner generated by the `PreconditionerAlgorithm` `P_alg` for `A`. +If `P_alg` requires a `FieldMatrixSolverAlgorithm` `alg` to solve the equation, +this also calls [`check_field_matrix_solver`](@ref) on `alg`. +""" +function check_preconditioner(P_alg, P_cache, A, b) + isnothing(P_alg) && return nothing + lazy_P = lazy_preconditioner(P_alg, A) + if is_diagonal(P_alg) + check_block_diagonal_matrix_has_no_missing_blocks(lazy_P, b) + else + alg = solver_algorithm(P_alg) + check_field_matrix_solver(alg, P_cache.cache, lazy_P, b) + end +end + +""" + lazy_or_concrete_preconditioner(P_alg, P_cache, A) + +A wrapper for [`lazy_preconditioner`](@ref) that turns the un-materialized +`FieldMatrixBroadcasted` `P` into a concrete `FieldMatrix` when the +`PreconditionerAlgorithm` `P_alg` requires a `FieldMatrixSolverAlgorithm` to +invert it. +""" +function lazy_or_concrete_preconditioner(P_alg, P_cache, A) + isnothing(P_alg) && return nothing + lazy_P = lazy_preconditioner(P_alg, A) + is_lazy_P_concrete = !(lazy_P isa FieldMatrixBroadcasted) + (is_diagonal(P_alg) || is_lazy_P_concrete) && return lazy_P + @. P_cache.P = lazy_P + return P_cache.P +end + +""" + apply_preconditioner(P_alg, P_cache, P, lazy_b) + +Constructs an un-materialized `FieldMatrixBroadcasted` (or just a `FieldMatrix` +when possible) that represents the product `@. inv(P) * b`. Here, `lazy_b` +denotes an un-materialized `FieldVectorViewBroadcasted` (or a `FieldVectorView`) +that represents `b`. +""" +function apply_preconditioner(P_alg, P_cache, P, lazy_b) + isnothing(P_alg) && return lazy_b + is_diagonal(P_alg) && return lazy_mul(lazy_inv(P), lazy_b) + @. 
P_cache.b = lazy_b + alg = solver_algorithm(P_alg) + run_field_matrix_solver!(alg, P_cache.cache, P_cache.x, P, P_cache.b) + return P_cache.x +end + +################################################################################ + +""" + MainDiagonalPreconditioner() + +A `PreconditionerAlgorithm` that sets `P` to the main diagonal of `A`. +""" +struct MainDiagonalPreconditioner <: PreconditionerAlgorithm end + +solver_algorithm(::MainDiagonalPreconditioner) = nothing +lazy_preconditioner(::MainDiagonalPreconditioner, A::FieldMatrix) = + lazy_main_diagonal(A) + +""" + BlockDiagonalPreconditioner() + +A `PreconditionerAlgorithm` that sets `P` to the block diagonal entries of `A`. +""" +struct BlockDiagonalPreconditioner <: PreconditionerAlgorithm end + +solver_algorithm(::BlockDiagonalPreconditioner) = BlockDiagonalSolve() +lazy_preconditioner(::BlockDiagonalPreconditioner, A::FieldMatrix) = + A[matrix_diagonal_keys(keys(A))] + +""" + BlockArrowheadPreconditioner(names₁...; [P_alg₁], [alg₂]) + +A `PreconditionerAlgorithm` for a 2×2 block matrix: +```math +A = \\begin{bmatrix} A_{11} & A_{12} \\\\ A_{21} & A_{22} \\end{bmatrix} +``` +The `FieldName`s in `names₁` correspond to the subscript `₁`, while all other +`FieldName`s correspond to the subscript `₂`. The preconditioner `P` is set to +the following matrix: +```math +P = \\begin{bmatrix} P_{11} & A_{12} \\\\ A_{21} & A_{22} \\end{bmatrix}, \\quad +\\text{where } P_{11} \\text{ is a diagonal matrix} +``` +The internal preconditioner `P₁₁` is generated by the `PreconditionerAlgorithm` +`P_alg₁`, which is set to [`MainDiagonalPreconditioner`](@ref) by default. The +Schur complement of `P₁₁` in `P`, `A₂₂ - A₂₁ * inv(P₁₁) * A₁₂`, is inverted +using the `FieldMatrixSolverAlgorithm` `alg₂`, which is set to +[`BlockDiagonalSolve`](@ref) by default. 
+""" +struct BlockArrowheadPreconditioner{ + N <: NTuple{<:Any, FieldName}, + P <: Union{Nothing, PreconditionerAlgorithm}, + A <: FieldMatrixSolverAlgorithm, +} <: PreconditionerAlgorithm + names₁::N + P_alg₁::P + alg₂::A +end +function BlockArrowheadPreconditioner( + names₁...; + P_alg₁ = MainDiagonalPreconditioner(), + alg₂ = BlockDiagonalSolve(), +) + is_diagonal(P_alg₁) || + error("BlockArrowheadPreconditioner requires a preconditioner P_alg₁ \ + that generates a diagonal matrix") + return BlockArrowheadPreconditioner(names₁, P_alg₁, alg₂) +end + +solver_algorithm(P_alg::BlockArrowheadPreconditioner) = + BlockArrowheadSolve(P_alg.names₁...; P_alg.alg₂) +function lazy_preconditioner( + P_alg::BlockArrowheadPreconditioner, + A::FieldMatrix, +) + A₁₁, A₁₂, A₂₁, A₂₂ = partition_blocks(P_alg.names₁, A) + lazy_P₁₁ = lazy_preconditioner(P_alg.P_alg₁, A₁₁) + return lazy_add(lazy_P₁₁, A₁₂, A₂₁, A₂₂) +end + +""" + BlockArrowheadSchurComplementPreconditioner(; [P_alg₁], [alg₂]) + +A `PreconditionerAlgorithm` that is equivalent to a +[`BlockArrowheadPreconditioner`](@ref), but only applied to the Schur complement +of `A₁₁` in `A` (represented using a [`LazySchurComplement`](@ref)). 
+""" +struct BlockArrowheadSchurComplementPreconditioner{ + P <: Union{Nothing, PreconditionerAlgorithm}, + A <: FieldMatrixSolverAlgorithm, +} <: PreconditionerAlgorithm + P_alg₁::P + alg₂::A +end +function BlockArrowheadSchurComplementPreconditioner(; + P_alg₁ = MainDiagonalPreconditioner(), + alg₂ = BlockDiagonalSolve(), +) + is_diagonal(P_alg₁) || + error("BlockArrowheadSchurComplementPreconditioner requires a \ + preconditioner P_alg₁ that generates a diagonal matrix") + return BlockArrowheadSchurComplementPreconditioner(P_alg₁, alg₂) +end + +solver_algorithm(P_alg₂::BlockArrowheadSchurComplementPreconditioner) = + P_alg₂.alg₂ +function lazy_preconditioner( + P_alg₂::BlockArrowheadSchurComplementPreconditioner, + A₂₂′::LazySchurComplement, +) + (; A₁₁, A₁₂, A₂₁, A₂₂) = A₂₂′ + lazy_P₁₁ = lazy_preconditioner(P_alg₂.P_alg₁, A₁₁) + return lazy_sub(A₂₂, lazy_mul(A₂₁, lazy_inv(lazy_P₁₁), A₁₂)) +end + +""" + WeightedPreconditioner(M, unweighted_P_alg) + +A `PreconditionerAlgorithm` that sets `P` to `M * P′`, where `M` is a diagonal +`FieldMatrix` and `P′` is the preconditioner generated by `unweighted_P_alg`. +""" +struct WeightedPreconditioner{M <: FieldMatrix, P <: PreconditionerAlgorithm} <: + PreconditionerAlgorithm + M::M + unweighted_P_alg::P +end +function WeightedPreconditioner(M, unweighted_P_alg) + check_diagonal_matrix( + M, + "WeightedPreconditioner cannot use M as a weighting matrix because it", + ) + P = typeof(unweighted_P_alg) + return WeightedPreconditioner{typeof(M), P}(M, unweighted_P_alg) +end + +solver_algorithm(P_alg::WeightedPreconditioner) = + solver_algorithm(P_alg.unweighted_P_alg) +lazy_preconditioner(P_alg::WeightedPreconditioner, A) = + lazy_mul(P_alg.M, lazy_preconditioner(P_alg.unweighted_P_alg, A)) + +""" + CustomPreconditioner(M, alg) + +A `PreconditionerAlgorithm` that sets `P` to the `FieldMatrix` `M` and inverts +`P` using the `FieldMatrixSolverAlgorithm` `alg`. 
+""" +struct CustomPreconditioner{ + M <: FieldMatrix, + A <: Union{Nothing, FieldMatrixSolverAlgorithm}, +} <: PreconditionerAlgorithm + M::M + alg::A +end +function CustomPreconditioner(M; alg = nothing) + isnothing(alg) && check_diagonal_matrix( + M, + "CustomPreconditioner requires alg to be specified for the matrix M \ + because it", + ) + return CustomPreconditioner(M, alg) +end + +solver_algorithm(P_alg::CustomPreconditioner) = P_alg.alg +lazy_preconditioner(P_alg::CustomPreconditioner, _) = P_alg.M + +################################################################################ + +""" + StationaryIterativeSolve(; [kwargs...]) + +A `LazyFieldMatrixSolverAlgorithm` that solves `A * x = b` by setting `x` to +some initial value `x[0]` (usually just the zero vector ``\\mathbf{0}``) and +then iteratively updating it to +```math +x[n] = x[n - 1] + \\textrm{inv}(P) * (b - A * x[n - 1]). +``` +The matrix `P` is called a "left preconditioner" for `A`. In general, this +algorithm converges more quickly when `P` is a close approximation of `A`, +although more complicated approximations often come with a performance penalty. + +# Background + +Let `x'` denote the value of `x` for which `A * x = b`. Replacing `b` with +`A * x'` in the formula for `x[n]` tells us that +```math +x[n] = x' + (I - \\textrm{inv}(P) * A) * (x[n - 1] - x'). +``` +In other words, the error on iteration `n`, `x[n] - x'`, can be expressed in +terms of the error on the previous iteration, `x[n - 1] - x'`, as +```math +x[n] - x' = (I - \\textrm{inv}(P) * A) * (x[n - 1] - x'). +``` +By induction, this means that the error on iteration `n` is +```math +x[n] - x' = (I - \\textrm{inv}(P) * A)^n * (x[0] - x'). +``` +If we pick some norm ``||\\cdot||``, we find that the norm of the error is +bounded by +```math +||x[n] - x'|| ≤ ||(I - \\textrm{inv}(P) * A)^n|| * ||x[0] - x'||. 
+``` +For any matrix ``M``, the spectral radius of ``M`` is defined as +```math +\\rho(M) = \\max\\{|λ| : λ \\text{ is an eigenvalue of } M\\}. +``` +The spectral radius has the property that +```math +||M^n|| \\sim \\rho(M)^n, \\quad \\text{i.e.,} \\quad +\\lim_{n \\to \\infty} \\frac{||M^n||}{\\rho(M)^n} = 1. +``` +So, as the value of `n` increases, the norm of the error becomes bounded by +```math +||x[n] - x'|| \\leq \\rho(I - \\textrm{inv}(P) * A)^n * ||x[0] - x'||. +``` +This indicates that `x[n]` will converge to `x'` (i.e., that the norm of the +error will converge to 0) when `ρ(I - inv(P) * A) < 1`, and that the convergence +rate is roughly bounded by `ρ(I - inv(P) * A)` for large values of `n`. More +precisely, it can be shown that `x[n]` will converge to `x'` if and only if +`ρ(I - inv(P) * A) < 1`. In practice, though, the convergence eventually stops +due to the limits of floating point precision. + +Also, if we assume that `x[n] ≈ x'`, we can use the formula for `x[n]` to +approximate the error on the previous iteration as +```math +x[n - 1] - x' ≈ x[n - 1] - x[n] = \\textrm{inv}(P) * (A * x[n - 1] - b). 
+``` + +# Keyword Arguments + +- `P_alg = nothing`: a `PreconditionerAlgorithm` that specifies how to compute + `P` and solve `P * x = b` for `x`, or `nothing` if preconditioning + is not required (in which case `P` is effectively set to `one(A)`) +- `n_iters = 1`: the number of iterations +- `correlated_solves = false`: whether to set `x[0]` to a value of `x` that was + generated during an earlier call to `field_matrix_solve!`, instead of setting + it to ``\\mathbf{0}`` (it is always set to ``\\mathbf{0}`` on the first call + to `field_matrix_solve!`) +- `print_norm = false`: whether to print `||x[n] - x'||₂` on every iteration, + where the error `x[n] - x'` is approximated as described above +- `print_radius = false`: whether to print `ρ(I - inv(P) * A)`, which is + approximated using the `eigsolve` function from KrylovKit.jl +- `eigsolve_kwargs = (;)`: keyword arguments for the `eigsolve` function that + can be used to tune its accuracy and speed +""" +struct StationaryIterativeSolve{ + correlated_solves, + print_norm, + print_radius, + P <: Union{Nothing, PreconditionerAlgorithm}, + K <: NamedTuple, +} <: LazyFieldMatrixSolverAlgorithm + P_alg::P + n_iters::Int + eigsolve_kwargs::K +end +function StationaryIterativeSolve(; + P_alg = nothing, + n_iters = 1, + correlated_solves = false, + print_norm = false, + print_radius = false, + eigsolve_kwargs = (;), +) + # Since Field operations can be much slower than typical Array operations, + # the default values of krylovdim and maxiter specified in KrylovKit.jl + # should be replaced with smaller values, and the default value of tol + # should be replaced with a larger value. + eigsolve_kwargs′ = + (; krylovdim = 4, maxiter = 20, tol = 0.01, eigsolve_kwargs...) + K = typeof(eigsolve_kwargs′) + # Turn all boolean fields into type parameters to ensure type-stability. 
+ params = (correlated_solves, print_norm, print_radius, typeof(P_alg), K) + return StationaryIterativeSolve{params...}(P_alg, n_iters, eigsolve_kwargs′) +end + +Base.getproperty( + alg::StationaryIterativeSolve{correlated_solves, print_norm, print_radius}, + name::Symbol, +) where {correlated_solves, print_norm, print_radius} = + if name == :correlated_solves + correlated_solves + elseif name == :print_norm + print_norm + elseif name == :print_radius + print_radius + else + getfield(alg, name) + end # Extract the boolean type parameters as if they were regular fields. + +function field_matrix_solver_cache(alg::StationaryIterativeSolve, A, b) + P_cache = preconditioner_cache(alg.P_alg, A, b) + # Note: We cannot use similar_to_x here because it doesn't work for some + # particularly complicated unit tests. For now, we will assume that x is + # similar to b, rather than just keys(x) == keys(b). + previous_x_cache = + alg.correlated_solves ? (; previous_x = zero.(similar(b))) : (;) + return (; P_cache, previous_x_cache...) +end + +check_field_matrix_solver(alg::StationaryIterativeSolve, cache, A, b) = + check_preconditioner(alg.P_alg, cache.P_cache, A, b) + +function run_field_matrix_solver!(alg::StationaryIterativeSolve, cache, x, A, b) + P = lazy_or_concrete_preconditioner(alg.P_alg, cache.P_cache, A) + if alg.print_radius + e₀ = concrete_field_vector(b) # Initialize e to any nonzero vector. + eigenvalues, _, info = eigsolve(e₀, 1; alg.eigsolve_kwargs...) do e + e_view = field_vector_view(e, keys(b).name_tree) + lazy_A_e = lazy_mul(A, e_view) + lazy_invP_A_e = + apply_preconditioner(alg.P_alg, cache.P_cache, P, lazy_A_e) + concrete_field_vector(@. 
e_view - lazy_invP_A_e) + end + if info.converged == 0 + (; tol, maxiter) = alg.eigsolve_kwargs + @warn "Unable to approximate ρ(I - inv(P) * A) to within a \ + tolerance of $(100 * tol) % in $maxiter or fewer iterations" + else + spectral_radius = abs(eigenvalues[1]) + if spectral_radius < 1 + @info "ρ(I - inv(P) * A) ≈ $spectral_radius" + else + @warn "StationaryIterativeSolve may not converge because \ + ρ(I - inv(P) * A) ≈ $spectral_radius is not less than 1" + end + end + end + if alg.correlated_solves + @. x = cache.previous_x + else + @. x = zero(x) + end + for iter in 1:(alg.n_iters) + lazy_Δb = lazy_sub(b, lazy_mul(A, x)) + lazy_Δx = apply_preconditioner(alg.P_alg, cache.P_cache, P, lazy_Δb) + if alg.print_norm + norm_Δx = norm(concrete_field_vector(Base.materialize(lazy_Δx))) + @info "||x[$(iter - 1)] - x'||₂ ≈ $norm_Δx" + end + @. x += lazy_Δx + end + if alg.print_norm + lazy_Δb = lazy_sub(b, lazy_mul(A, x)) + lazy_Δx = apply_preconditioner(alg.P_alg, cache.P_cache, P, lazy_Δb) + norm_Δx = norm(concrete_field_vector(Base.materialize(lazy_Δx))) + @info "||x[$(alg.n_iters)] - x'||₂ ≈ $norm_Δx" + end + if alg.correlated_solves + @. cache.previous_x = x + end +end + +""" + ApproximateBlockArrowheadIterativeSolve(names₁...; [P_alg₁], [alg₁], [alg₂], [kwargs...]) + +Shorthand for constructing a [`BlockLUDecompositionSolve`](@ref) that wraps a +[`StationaryIterativeSolve`](@ref) with a +[`BlockArrowheadSchurComplementPreconditioner`](@ref). The keyword argument +`alg₁` is passed to the constructor for `BlockLUDecompositionSolve`, the keyword +arguments `P_alg₁` and `alg₂` are passed to the constructor for +`BlockArrowheadSchurComplementPreconditioner`, and any remaining +keyword arguments are passed to the constructor for `StationaryIterativeSolve`.
+ +Although this algorithm is similar to a `StationaryIterativeSolve` with a +[`BlockArrowheadPreconditioner`](@ref), it usually converges much more quickly +because the spectral radius of its iteration matrix (`I - inv(P) * A`) tends to +be smaller. Roughly speaking, this is due to the fact that it runs the iterative +solver on an equation with fewer variables (the Schur complement equation), +which means that a smaller error caused by interactions between variables is +accumulated on each iteration. However, even though it converges more quickly, +its iterations take longer because they involve using `alg₁` to invert `A₁₁`. +So, if only a few iterations are required, the simpler, slower-converging +algorithm may be more performant. + +In the context of computational fluid dynamics, this algorithm is called a +"Schur complement reduction" or "segregated" solve, and the simpler alternative +is called a "coupled" solve; see Section 1 of [A robust and efficient iterative +method for hyper-elastodynamics with nested block preconditioning by J. Liu and +A. Marsden](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6781635) for additional +information. +""" +function ApproximateBlockArrowheadIterativeSolve( + names₁...; + P_alg₁ = MainDiagonalPreconditioner(), + alg₁ = BlockDiagonalSolve(), + alg₂ = BlockDiagonalSolve(), + kwargs..., +) + P_alg₂ = BlockArrowheadSchurComplementPreconditioner(; P_alg₁, alg₂) + outer_alg₂ = StationaryIterativeSolve(; P_alg = P_alg₂, kwargs...) + return BlockLUDecompositionSolve(names₁...; alg₁, alg₂ = outer_alg₂) +end diff --git a/src/MatrixFields/field_matrix_solver.jl b/src/MatrixFields/field_matrix_solver.jl index dbe2675e9d..05c8ea7065 100644 --- a/src/MatrixFields/field_matrix_solver.jl +++ b/src/MatrixFields/field_matrix_solver.jl @@ -5,13 +5,46 @@ Description of how to solve an equation of the form `A * x = b` for `x`, where `A` is a `FieldMatrix` and where `x` and `b` are both `FieldVector`s. 
Different algorithms can be nested inside each other, enabling the construction of specialized linear solvers that fully utilize the sparsity pattern of `A`. + +# Interface + +Every subtype of `FieldMatrixSolverAlgorithm` must implement methods for the +following functions: +- [`field_matrix_solver_cache`](@ref) +- [`check_field_matrix_solver`](@ref) +- [`run_field_matrix_solver!`](@ref) """ abstract type FieldMatrixSolverAlgorithm end +""" + field_matrix_solver_cache(alg, A, b) + +Allocates the cache required by the `FieldMatrixSolverAlgorithm` `alg` to solve +the equation `A * x = b`. +""" +function field_matrix_solver_cache end + +""" + check_field_matrix_solver(alg, cache, A, b) + +Checks that the sparsity structure of `A` is supported by the +`FieldMatrixSolverAlgorithm` `alg`, and that `A` is compatible with `b` in the +equation `A * x = b`. +""" +function check_field_matrix_solver end + +""" + run_field_matrix_solver!(alg, cache, x, A, b) + +Sets `x` to the value that solves the equation `A * x = b` using the +`FieldMatrixSolverAlgorithm` `alg`. +""" +function run_field_matrix_solver! end + """ FieldMatrixSolver(alg, A, b) -Combination of a `FieldMatrixSolverAlgorithm` and the cache that it requires to +Combination of a `FieldMatrixSolverAlgorithm` `alg` and the cache it requires to solve the equation `A * x = b` for `x`. 
The values of `A` and `b` that get passed to this constructor should be `similar` to the ones that get passed to `field_matrix_solve!` in order to ensure that the cache gets allocated @@ -27,15 +60,16 @@ function FieldMatrixSolver( b::Fields.FieldVector, ) b_view = field_vector_view(b) - cache = field_matrix_solver_cache(alg, A, b_view) - check_field_matrix_solver(alg, cache, A, b_view) + A_with_tree = FieldMatrix(pairs(A)...; keys(b_view).name_tree) + cache = field_matrix_solver_cache(alg, A_with_tree, b_view) + check_field_matrix_solver(alg, cache, A_with_tree, b_view) return FieldMatrixSolver(alg, cache) end """ field_matrix_solve!(solver, x, A, b) -Solves the equation `A * x = b` for `x` using the given `FieldMatrixSolver`. +Solves the equation `A * x = b` for `x` using the `FieldMatrixSolver` `solver`. """ function field_matrix_solve!( solver::FieldMatrixSolver, @@ -43,14 +77,16 @@ function field_matrix_solve!( A::FieldMatrix, b::Fields.FieldVector, ) + (; alg, cache) = solver x_view = field_vector_view(x) b_view = field_vector_view(b) keys(x_view) == keys(b_view) || error( "The linear system cannot be solved because x and b have incompatible \ keys: $(set_string(keys(x_view))) vs. 
$(set_string(keys(b_view)))", ) - check_field_matrix_solver(solver.alg, solver.cache, A, b_view) - field_matrix_solve!(solver.alg, solver.cache, x_view, A, b_view) + A_with_tree = FieldMatrix(pairs(A)...; keys(b_view).name_tree) + check_field_matrix_solver(alg, cache, A_with_tree, b_view) + run_field_matrix_solver!(alg, cache, x_view, A_with_tree, b_view) return x end @@ -68,46 +104,131 @@ function check_block_diagonal_matrix_has_no_missing_blocks(A, b) entries at the following keys: $(set_string(missing_keys))") end -function partition_blocks(names₁, A, b, x = nothing) - keys₁ = FieldVectorKeys(names₁, keys(b).name_tree) +function partition_blocks(names₁, A, b = nothing, x = nothing) + keys₁ = FieldVectorKeys(names₁, keys(A).name_tree) keys₂ = set_complement(keys₁) A₁₁ = A[cartesian_product(keys₁, keys₁)] A₁₂ = A[cartesian_product(keys₁, keys₂)] A₂₁ = A[cartesian_product(keys₂, keys₁)] A₂₂ = A[cartesian_product(keys₂, keys₂)] - return isnothing(x) ? (A₁₁, A₁₂, A₂₁, A₂₂, b[keys₁], b[keys₂]) : - (A₁₁, A₁₂, A₂₁, A₂₂, b[keys₁], b[keys₂], x[keys₁], x[keys₂]) + b_blocks = isnothing(b) ? () : (b[keys₁], b[keys₂]) + x_blocks = isnothing(x) ? () : (x[keys₁], x[keys₂]) + return (A₁₁, A₁₂, A₂₁, A₂₂, b_blocks..., x_blocks...) +end + +function similar_to_x(A, b) + entries = map(matrix_row_keys(keys(A))) do name + similar(b[name], x_eltype(A[name, name], b[name])) + end + return FieldVectorView(matrix_row_keys(keys(A)), entries) +end + +################################################################################ + +# Lazy (i.e., as matrix-free as possible) operations for FieldMatrix and +# analogues of FieldMatrix + +lazy_inv(A) = Base.Broadcast.broadcasted(inv, A) +lazy_add(As...) = Base.Broadcast.broadcasted(+, As...) +lazy_sub(As...) = Base.Broadcast.broadcasted(-, As...) + +""" + lazy_mul(A, args...) + +Constructs an un-materialized `FieldMatrixBroadcasted` that represents the +product `@. *(A, args...)`. 
This involves regular broadcasting when `A` is a +`FieldMatrix` or `FieldMatrixBroadcasted`, but it has more complex behavior for +other objects like the [`LazySchurComplement`](@ref). +""" +lazy_mul(A, args...) = Base.Broadcast.broadcasted(*, A, args...) + +""" + LazySchurComplement(A₁₁, A₁₂, A₂₁, A₂₂, [alg₁, cache₁, A₁₂_x₂, invA₁₁_A₁₂_x₂]) + +An analogue of a `FieldMatrix` that represents the Schur complement of `A₁₁` in +`A`, `A₂₂ - A₂₁ * inv(A₁₁) * A₁₂`. Since `inv(A₁₁)` will generally be a dense +matrix, it would not be efficient to directly compute the Schur complement. So, +this object only supports the "lazy" functions [`lazy_mul`](@ref), which allows +it to be multiplied by the vector `x₂`, and [`lazy_preconditioner`](@ref), which +allows it to be approximated with a `FieldMatrix`. + +The values `alg₁`, `cache₁`, `A₁₂_x₂`, and `invA₁₁_A₁₂_x₂` need to be specified +in order for `lazy_mul` to be able to compute `inv(A₁₁) * A₁₂ * x₂`. When a +`LazySchurComplement` is not passed to `lazy_mul`, these values can be omitted. +""" +struct LazySchurComplement{M11, M12, M21, M22, A1, C1, V1, V2} + A₁₁::M11 + A₁₂::M12 + A₂₁::M21 + A₂₂::M22 + alg₁::A1 + cache₁::C1 + A₁₂_x₂::V1 + invA₁₁_A₁₂_x₂::V2 +end +LazySchurComplement(A₁₁, A₁₂, A₂₁, A₂₂) = + LazySchurComplement(A₁₁, A₁₂, A₂₁, A₂₂, nothing, nothing, nothing, nothing) + +function lazy_mul(A₂₂′::LazySchurComplement, x₂) + (; A₁₁, A₁₂, A₂₁, A₂₂, alg₁, cache₁, A₁₂_x₂, invA₁₁_A₁₂_x₂) = A₂₂′ + zero_rows = setdiff(keys(A₁₂_x₂), matrix_row_keys(keys(A₁₂))) + @. A₁₂_x₂ = A₁₂ * x₂ + zero(A₁₂_x₂[zero_rows]) + run_field_matrix_solver!(alg₁, cache₁, invA₁₁_A₁₂_x₂, A₁₁, A₁₂_x₂) + return lazy_sub(lazy_mul(A₂₂, x₂), lazy_mul(A₂₁, invA₁₁_A₁₂_x₂)) end +""" + LazyFieldMatrixSolverAlgorithm + +A `FieldMatrixSolverAlgorithm` that does not require `A` to be a `FieldMatrix`; +i.e., a "matrix-free" algorithm. 
Internally, a `FieldMatrixSolverAlgorithm` +(e.g., [`BlockLUDecompositionSolve`](@ref)) might run a +`LazyFieldMatrixSolverAlgorithm` on a "lazy" representation of a `FieldMatrix` +(e.g., a [`LazySchurComplement`](@ref)). + +The only operations used by a `LazyFieldMatrixSolverAlgorithm` that depend on +`A` are [`lazy_mul`](@ref) and, when required, [`lazy_preconditioner`](@ref). +These and other lazy operations are used to minimize the number of calls to +`Base.materialize!`, since each call comes with a small performance penalty. +""" +abstract type LazyFieldMatrixSolverAlgorithm <: FieldMatrixSolverAlgorithm end + ################################################################################ """ BlockDiagonalSolve() -A `FieldMatrixSolverAlgorithm` for a block diagonal matrix `A`, which solves -each block's equation `Aᵢᵢ * xᵢ = bᵢ` in sequence. The equation for `xᵢ` is -solved as follows: -- If `Aᵢᵢ = λᵢ * I`, the equation is solved by setting `xᵢ .= inv(λᵢ) .* bᵢ`. -- If `Aᵢᵢ = Dᵢ`, where `Dᵢ` is a diagonal matrix, the equation is solved by - making a single pass over the data, setting each `xᵢ[n] = inv(Dᵢ[n]) * bᵢ[n]`. -- If `Aᵢᵢ = Lᵢ * Dᵢ * Uᵢ`, where `Dᵢ` is a diagonal matrix and where `Lᵢ` and - `Uᵢ` are unit lower and upper triangular matrices, respectively, the equation +A `FieldMatrixSolverAlgorithm` for a block diagonal matrix: +```math +A = \\begin{bmatrix} + A_{11} & \\mathbf{0} & \\mathbf{0} & \\cdots & \\mathbf{0} \\\\ +\\mathbf{0} & A_{22} & \\mathbf{0} & \\cdots & \\mathbf{0} \\\\ +\\mathbf{0} & \\mathbf{0} & A_{33} & \\cdots & \\mathbf{0} \\\\ + \\vdots & \\vdots & \\vdots & \\ddots & \\vdots \\\\ +\\mathbf{0} & \\mathbf{0} & \\mathbf{0} & \\cdots & A_{NN} +\\end{bmatrix} +``` +The `N` block equations `Aₙₙ * xₙ = bₙ` are solved sequentially according to the +following rules: +- If `Aₙₙ = λₙ * I`, where `λₙ` is a scalar, the equation is solved by setting + `xₙ .= inv(λₙ) .* bₙ`. 
+- If `Aₙₙ = Dₙ`, where `Dₙ` is a diagonal matrix, the equation is solved by + making a single pass over the data, setting each `xₙ[i] = inv(Dₙ[i]) * bₙ[i]`. +- If `Aₙₙ = Lₙ * Dₙ * Uₙ`, where `Dₙ` is a diagonal matrix and where `Lₙ` and + `Uₙ` are unit lower and upper triangular matrices, respectively, the equation is solved using Gauss-Jordan elimination, which makes two passes over the - data. The first pass multiplies both sides of the equation by `inv(Lᵢ * Dᵢ)`, - replacing `Aᵢᵢ` with `Uᵢ` and `bᵢ` with `Uᵢxᵢ`, which is also referred to as - putting `Aᵢᵢ` into "reduced row echelon form". The second pass solves - `Uᵢ * xᵢ = Uᵢxᵢ` for `xᵢ` using a unit upper triangular matrix solver, which + data. The first pass multiplies both sides of the equation by `inv(Lₙ * Dₙ)`, + replacing `Aₙₙ` with `Uₙ` and `bₙ` with `Uₙxₙ`, which is also referred to as + putting `Aₙₙ` into "reduced row echelon form". The second pass solves + `Uₙ * xₙ = Uₙxₙ` for `xₙ` using a unit upper triangular matrix solver, which is also referred to as "back substitution". Only tri-diagonal and - penta-diagonal matrices `Aᵢᵢ` are currently supported. -- The general case of `Aᵢᵢ = inv(Pᵢ) * Lᵢ * Uᵢ`, where `Pᵢ` is a row permutation - matrix (i.e., LU factorization with partial pivoting), is not currently - supported. + penta-diagonal matrices `Aₙₙ` are currently supported. 
""" struct BlockDiagonalSolve <: FieldMatrixSolverAlgorithm end function field_matrix_solver_cache(::BlockDiagonalSolve, A, b) caches = map(matrix_row_keys(keys(A))) do name - single_field_solver_cache(A[(name, name)], b[name]) + single_field_solver_cache(A[name, name], b[name]) end return FieldNameDict{FieldName}(matrix_row_keys(keys(A)), caches) end @@ -119,51 +240,46 @@ function check_field_matrix_solver(::BlockDiagonalSolve, _, A, b) ) check_block_diagonal_matrix_has_no_missing_blocks(A, b) foreach(matrix_row_keys(keys(A))) do name - check_single_field_solver(A[(name, name)], b[name]) + check_single_field_solver(A[name, name], b[name]) end end -field_matrix_solve!(::BlockDiagonalSolve, cache, x, A, b) = +run_field_matrix_solver!(::BlockDiagonalSolve, cache, x, A, b) = foreach(matrix_row_keys(keys(A))) do name - single_field_solve!(cache[name], x[name], A[(name, name)], b[name]) + single_field_solve!(cache[name], x[name], A[name, name], b[name]) end """ BlockLowerTriangularSolve(names₁...; [alg₁], [alg₂]) -A `FieldMatrixSolverAlgorithm` for a block lower triangular matrix `A`, which -solves for `x` by executing the following steps: -1. Partition the entries in `A`, `x`, and `b` into the blocks `A₁₁`, `A₁₂`, - `A₂₁`, `A₂₂`, `x₁`, `x₂`, `b₁`, and `b₂`, based on the `FieldName`s in - `names₁`. In this notation, the subscript `₁` corresponds to `FieldName`s - that are covered by `names₁`, while the subscript `₂` corresponds to all - other `FieldNames`. A subscript in the first position refers to `FieldName`s - that are used as row indices, while a subscript in the second position refers - to column indices. This algorithm requires that the upper triangular block - `A₁₂` be empty. (Any upper triangular solve can also be expressed as a lower - triangular solve by swapping the subscripts `₁` and `₂`.) -2. Solve `A₁₁ * x₁ = b₁` for `x₁` using the algorithm `alg₁`, which is set to - `BlockDiagonalSolve()` by default. -3. 
Solve `A₂₂ * x₂ = b₂ - A₂₁ * x₁` for `x₂` using the algorithm `alg₂`, which - is set to `BlockDiagonalSolve()` by default. +A `FieldMatrixSolverAlgorithm` for a 2×2 block lower triangular matrix: +```math +A = \\begin{bmatrix} A_{11} & \\mathbf{0} \\\\ A_{21} & A_{22} \\end{bmatrix} +``` +The `FieldName`s in `names₁` correspond to the subscript `₁`, while all other +`FieldName`s correspond to the subscript `₂`. This algorithm has 2 steps: +1. Solve `A₁₁ * x₁ = b₁` for `x₁` using the algorithm `alg₁`, which is set to + [`BlockDiagonalSolve()`](@ref) by default. +2. Solve `A₂₂ * x₂ = b₂ - A₂₁ * x₁` for `x₂` using the algorithm `alg₂`, which + is set to [`BlockDiagonalSolve()`](@ref) by default. """ struct BlockLowerTriangularSolve{ - V <: NTuple{<:Any, FieldName}, + N <: NTuple{<:Any, FieldName}, A1 <: FieldMatrixSolverAlgorithm, A2 <: FieldMatrixSolverAlgorithm, } <: FieldMatrixSolverAlgorithm - names₁::V + names₁::N alg₁::A1 alg₂::A2 end BlockLowerTriangularSolve( - names₁::FieldName...; + names₁...; alg₁ = BlockDiagonalSolve(), alg₂ = BlockDiagonalSolve(), ) = BlockLowerTriangularSolve(names₁, alg₁, alg₂) function field_matrix_solver_cache(alg::BlockLowerTriangularSolve, A, b) - A₁₁, _, A₂₁, A₂₂, b₁, b₂ = partition_blocks(alg.names₁, A, b) + A₁₁, _, _, A₂₂, b₁, b₂ = partition_blocks(alg.names₁, A, b) cache₁ = field_matrix_solver_cache(alg.alg₁, A₁₁, b₁) b₂′ = similar(b₂) cache₂ = field_matrix_solver_cache(alg.alg₂, A₂₂, b₂′) @@ -180,129 +296,137 @@ function check_field_matrix_solver(alg::BlockLowerTriangularSolve, cache, A, b) check_field_matrix_solver(alg.alg₂, cache.cache₂, A₂₂, cache.b₂′) end -function field_matrix_solve!(alg::BlockLowerTriangularSolve, cache, x, A, b) +function run_field_matrix_solver!( + alg::BlockLowerTriangularSolve, + cache, + x, + A, + b, +) A₁₁, _, A₂₁, A₂₂, b₁, b₂, x₁, x₂ = partition_blocks(alg.names₁, A, b, x) - field_matrix_solve!(alg.alg₁, cache.cache₁, x₁, A₁₁, b₁) + run_field_matrix_solver!(alg.alg₁, cache.cache₁, x₁, A₁₁, b₁) @. 
cache.b₂′ = b₂ - A₂₁ * x₁ - field_matrix_solve!(alg.alg₂, cache.cache₂, x₂, A₂₂, cache.b₂′) + run_field_matrix_solver!(alg.alg₂, cache.cache₂, x₂, A₂₂, cache.b₂′) end """ - SchurComplementSolve(names₁...; [alg₁]) - -A `FieldMatrixSolverAlgorithm` for a block matrix `A`, which solves for `x` by -executing the following steps: -1. Partition the entries in `A`, `x`, and `b` into the blocks `A₁₁`, `A₁₂`, - `A₂₁`, `A₂₂`, `x₁`, `x₂`, `b₁`, and `b₂`, based on the `FieldName`s in - `names₁`. In this notation, the subscript `₁` corresponds to `FieldName`s - that are covered by `names₁`, while the subscript `₂` corresponds to all - other `FieldNames`. A subscript in the first position refers to `FieldName`s - that are used as row indices, while a subscript in the second position refers - to column indices. This algorithm requires that the block `A₂₂` be a diagonal - matrix, which allows it to assume that `inv(A₂₂)` can be computed on the fly. -2. Solve `(A₁₁ - A₁₂ * inv(A₂₂) * A₂₁) * x₁ = b₁ - A₁₂ * inv(A₂₂) * b₂` for `x₁` - using the algorithm `alg₁`, which is set to `BlockDiagonalSolve()` by - default. The matrix `A₁₁ - A₁₂ * inv(A₂₂) * A₂₁` is called the "Schur - complement" of `A₂₂` in `A`. -3. Set `x₂` to `inv(A₂₂) * (b₂ - A₂₁ * x₁)`. + BlockArrowheadSolve(names₁...; [alg₂]) + +A `FieldMatrixSolverAlgorithm` for a 2×2 block arrowhead matrix: +```math +A = \\begin{bmatrix} A_{11} & A_{12} \\\\ A_{21} & A_{22} \\end{bmatrix}, \\quad +\\text{where } A_{11} \\text{ is a diagonal matrix} +``` +The `FieldName`s in `names₁` correspond to the subscript `₁`, while all other +`FieldName`s correspond to the subscript `₂`. This algorithm has only 1 step: +1. Solve `(A₂₂ - A₂₁ * inv(A₁₁) * A₁₂) * x₂ = b₂ - A₂₁ * inv(A₁₁) * b₁` for `x₂` + using the algorithm `alg₂`, which is set to [`BlockDiagonalSolve()`](@ref) by + default, and set `x₁` to `inv(A₁₁) * (b₁ - A₁₂ * x₂)`. 
+ +Since `A₁₁` is a diagonal matrix, `inv(A₁₁)` is easy to compute, which means +that the Schur complement of `A₁₁` in `A`, `A₂₂ - A₂₁ * inv(A₁₁) * A₁₂`, as well +as the vectors `b₂ - A₂₁ * inv(A₁₁) * b₁` and `inv(A₁₁) * (b₁ - A₁₂ * x₂)`, are +also easy to compute. """ -struct SchurComplementSolve{ - V <: NTuple{<:Any, FieldName}, +struct BlockArrowheadSolve{ + N <: NTuple{<:Any, FieldName}, A <: FieldMatrixSolverAlgorithm, } <: FieldMatrixSolverAlgorithm - names₁::V - alg₁::A + names₁::N + alg₂::A end -SchurComplementSolve(names₁::FieldName...; alg₁ = BlockDiagonalSolve()) = - SchurComplementSolve(names₁, alg₁) +BlockArrowheadSolve(names₁...; alg₂ = BlockDiagonalSolve()) = + BlockArrowheadSolve(names₁, alg₂) -function field_matrix_solver_cache(alg::SchurComplementSolve, A, b) - A₁₁, A₁₂, A₂₁, A₂₂, b₁, b₂ = partition_blocks(alg.names₁, A, b) - A₁₁′ = @. A₁₁ - A₁₂ * inv(A₂₂) * A₂₁ # A₁₁′ could have more blocks than A₁₁ - b₁′ = similar(b₁) - cache₁ = field_matrix_solver_cache(alg.alg₁, A₁₁′, b₁′) - return (; A₁₁′, b₁′, cache₁) +function field_matrix_solver_cache(alg::BlockArrowheadSolve, A, b) + A₁₁, A₁₂, A₂₁, A₂₂, _, b₂ = partition_blocks(alg.names₁, A, b) + A₂₂′ = @. 
A₂₂ - A₂₁ * inv(A₁₁) * A₁₂ + b₂′ = similar(b₂) + cache₂ = field_matrix_solver_cache(alg.alg₂, A₂₂′, b₂′) + return (; A₂₂′, b₂′, cache₂) end -function check_field_matrix_solver(alg::SchurComplementSolve, cache, A, b) - _, _, _, A₂₂, _, b₂ = partition_blocks(alg.names₁, A, b) - check_diagonal_matrix(A₂₂, "SchurComplementSolve cannot be used because A") - check_block_diagonal_matrix_has_no_missing_blocks(A₂₂, b₂) - check_field_matrix_solver(alg.alg₁, cache.cache₁, cache.A₁₁′, cache.b₁′) +function check_field_matrix_solver(alg::BlockArrowheadSolve, cache, A, b) + A₁₁, _, _, _, b₁, _ = partition_blocks(alg.names₁, A, b) + check_diagonal_matrix(A₁₁, "BlockArrowheadSolve cannot be used because A") + check_block_diagonal_matrix_has_no_missing_blocks(A₁₁, b₁) + check_field_matrix_solver(alg.alg₂, cache.cache₂, cache.A₂₂′, cache.b₂′) end -function field_matrix_solve!(alg::SchurComplementSolve, cache, x, A, b) +function run_field_matrix_solver!(alg::BlockArrowheadSolve, cache, x, A, b) A₁₁, A₁₂, A₂₁, A₂₂, b₁, b₂, x₁, x₂ = partition_blocks(alg.names₁, A, b, x) - @. cache.A₁₁′ = A₁₁ - A₁₂ * inv(A₂₂) * A₂₁ - @. cache.b₁′ = b₁ - A₁₂ * inv(A₂₂) * b₂ - field_matrix_solve!(alg.alg₁, cache.cache₁, x₁, cache.A₁₁′, cache.b₁′) - @. x₂ = inv(A₂₂) * (b₂ - A₂₁ * x₁) + @. cache.A₂₂′ = A₂₂ - A₂₁ * inv(A₁₁) * A₁₂ + @. cache.b₂′ = b₂ - A₂₁ * inv(A₁₁) * b₁ + run_field_matrix_solver!(alg.alg₂, cache.cache₂, x₂, cache.A₂₂′, cache.b₂′) + @. x₁ = inv(A₁₁) * (b₁ - A₁₂ * x₂) end """ - ApproximateFactorizationSolve(name_pairs₁...; [alg₁], [alg₂]) - -A `FieldMatrixSolverAlgorithm` for a block matrix `A`, which (approximately) -solves for `x` by executing the following steps: -1. Use the entries in `A = M + I = M₁ + M₂ + I` to compute `A₁ = M₁ + I` and - `A₂ = M₂ + I`, based on the pairs of `FieldName`s in `name_pairs₁`. In this - notation, the subscript `₁` refers to pairs of `FieldName`s that are covered - by `name_pairs₁`, while the subscript `₂` refers to all other pairs of - `FieldNames`s. 
This algorithm approximates the matrix `A` as the product - `A₁ * A₂`, which introduces an error that scales roughly with the norm of - `A₁ * A₂ - A = M₁ * M₂`. (More precisely, the error introduced by this - algorithm is `x_exact - x_approx = inv(A) * b - inv(A₁ * A₂) * b`.) -2. Solve `A₁ * A₂x = b` for `A₂x` using the algorithm `alg₁`, which is set to - `BlockDiagonalSolve()` by default. -3. Solve `A₂ * x = A₂x` for `x` using the algorithm `alg₂`, which is set to - `BlockDiagonalSolve()` by default. + BlockLUDecompositionSolve(names₁...; [alg₁], alg₂) + +A `FieldMatrixSolverAlgorithm` for a 2×2 block matrix: +```math +A = \\begin{bmatrix} A_{11} & A_{12} \\\\ A_{21} & A_{22} \\end{bmatrix} +``` +The `FieldName`s in `names₁` correspond to the subscript `₁`, while all other +`FieldName`s correspond to the subscript `₂`. This algorithm has 3 steps: +1. Solve `A₁₁ * x₁′ = b₁` for `x₁′` using the algorithm `alg₁`, which is set to + [`BlockDiagonalSolve()`](@ref) by default. +2. Solve `(A₂₂ - A₂₁ * inv(A₁₁) * A₁₂) * x₂ = b₂ - A₂₁ * x₁′` for `x₂` + using the algorithm `alg₂`. +3. Solve `A₁₁ * x₁ = b₁ - A₁₂ * x₂` for `x₁` using the algorithm `alg₁`. + +Since `A₁₁` is not a diagonal matrix, `inv(A₁₁)` will generally be a dense +matrix, which means that the Schur complement of `A₁₁` in `A`, +`A₂₂ - A₂₁ * inv(A₁₁) * A₁₂`, cannot be computed efficiently. So, `alg₂` must be +set to a `LazyFieldMatrixSolverAlgorithm`, which can evaluate the matrix-vector +product `(A₂₂ - A₂₁ * inv(A₁₁) * A₁₂) * x₂` without actually computing the Schur +complement matrix. This involves using `alg₁` to solve an equation with `A₁₁`, +instead of just multiplying by `inv(A₁₁)`.
""" -struct ApproximateFactorizationSolve{ - V <: NTuple{<:Any, FieldNamePair}, +struct BlockLUDecompositionSolve{ + N <: NTuple{<:Any, FieldName}, A1 <: FieldMatrixSolverAlgorithm, - A2 <: FieldMatrixSolverAlgorithm, + A2 <: LazyFieldMatrixSolverAlgorithm, } <: FieldMatrixSolverAlgorithm - name_pairs₁::V + names₁::N alg₁::A1 alg₂::A2 end -ApproximateFactorizationSolve( - name_pairs₁::FieldNamePair...; - alg₁ = BlockDiagonalSolve(), - alg₂ = BlockDiagonalSolve(), -) = ApproximateFactorizationSolve(name_pairs₁, alg₁, alg₂) -# Note: This algorithm assumes that x is `similar` to b. In other words, it -# assumes that typeof(x) == typeof(b), rather than just keys(x) == keys(b). +BlockLUDecompositionSolve(names₁...; alg₁ = BlockDiagonalSolve(), alg₂) = + BlockLUDecompositionSolve(names₁, alg₁, alg₂) -function approximate_factors(name_pairs₁, A, b) - keys₁ = FieldMatrixKeys(name_pairs₁, keys(b).name_tree) - keys₂ = set_complement(keys₁) - A₁ = A[keys₁] .+ one(A)[keys₂] # `one` can be used because x is similar to b - A₂ = A[keys₂] .+ one(A)[keys₁] - return A₁, A₂ +function field_matrix_solver_cache(alg::BlockLUDecompositionSolve, A, b) + A₁₁, A₁₂, A₂₁, A₂₂, b₁, b₂ = partition_blocks(alg.names₁, A, b) + b₁′ = similar(b₁) + cache₁ = field_matrix_solver_cache(alg.alg₁, A₁₁, b₁) + A₂₂′ = LazySchurComplement(A₁₁, A₁₂, A₂₁, A₂₂) + b₂′ = similar(b₂) + cache₂ = field_matrix_solver_cache(alg.alg₂, A₂₂′, b₂′) + return (; b₁′, cache₁, b₂′, cache₂) end -function field_matrix_solver_cache(alg::ApproximateFactorizationSolve, A, b) - A₁, A₂ = approximate_factors(alg.name_pairs₁, A, b) - cache₁ = field_matrix_solver_cache(alg.alg₁, A₁, b) - A₂x = @. 
A₂ * b # x can be replaced with b because they are similar - cache₂ = field_matrix_solver_cache(alg.alg₂, A₂, A₂x) - return (; cache₁, A₂x, cache₂) +function check_field_matrix_solver(alg::BlockLUDecompositionSolve, cache, A, b) + A₁₁, A₁₂, A₂₁, A₂₂, b₁, _ = partition_blocks(alg.names₁, A, b) + check_field_matrix_solver(alg.alg₁, cache.cache₁, A₁₁, b₁) + A₂₂′ = LazySchurComplement(A₁₁, A₁₂, A₂₁, A₂₂) + check_field_matrix_solver(alg.alg₂, cache.cache₂, A₂₂′, cache.b₂′) end -function check_field_matrix_solver( - alg::ApproximateFactorizationSolve, +function run_field_matrix_solver!( + alg::BlockLUDecompositionSolve, cache, + x, A, b, ) - A₁, A₂ = approximate_factors(alg.name_pairs₁, A, b) - check_field_matrix_solver(alg.alg₁, cache.cache₁, A₁, b) - check_field_matrix_solver(alg.alg₂, cache.cache₂, A₂, cache.A₂x) -end - -function field_matrix_solve!(alg::ApproximateFactorizationSolve, cache, x, A, b) - A₁, A₂ = approximate_factors(alg.name_pairs₁, A, b) - field_matrix_solve!(alg.alg₁, cache.cache₁, cache.A₂x, A₁, b) - field_matrix_solve!(alg.alg₂, cache.cache₂, x, A₂, cache.A₂x) + A₁₁, A₁₂, A₂₁, A₂₂, b₁, b₂, x₁, x₂ = partition_blocks(alg.names₁, A, b, x) + x₁′ = x₁ # Use x₁ as temporary storage to avoid additional allocations. + schur_complement_args = (alg.alg₁, cache.cache₁, cache.b₁′, x₁′) + A₂₂′ = LazySchurComplement(A₁₁, A₁₂, A₂₁, A₂₂, schur_complement_args...) + run_field_matrix_solver!(alg.alg₁, cache.cache₁, x₁′, A₁₁, b₁) + @. cache.b₂′ = b₂ - A₂₁ * x₁′ + run_field_matrix_solver!(alg.alg₂, cache.cache₂, x₂, A₂₂′, cache.b₂′) + @. 
cache.b₁′ = b₁ - A₁₂ * x₂ + run_field_matrix_solver!(alg.alg₁, cache.cache₁, x₁, A₁₁, cache.b₁′) end diff --git a/src/MatrixFields/field_name_dict.jl b/src/MatrixFields/field_name_dict.jl index ed26b40822..994d7b52c3 100644 --- a/src/MatrixFields/field_name_dict.jl +++ b/src/MatrixFields/field_name_dict.jl @@ -44,11 +44,14 @@ struct FieldNameDict{T1, T2, K <: FieldNameSet{T1}, E <: NTuple{<:Any, T2}} <: ) where {T1, T2, N} = new{T1, T2, typeof(keys), typeof(entries)}(keys, entries) end -FieldNameDict{T1, T2}(key_entry_pairs::Pair{<:T1, <:T2}...) where {T1, T2} = - FieldNameDict{T1, T2}( - FieldNameSet{T1}(unrolled_map(pair -> pair[1], key_entry_pairs)), - unrolled_map(pair -> pair[2], key_entry_pairs), - ) +function FieldNameDict{T1, T2}( + key_entry_pairs::Pair{<:T1, <:T2}...; + name_tree = nothing, +) where {T1, T2} + keys = unrolled_map(pair -> pair[1], key_entry_pairs) + entries = unrolled_map(pair -> pair[2], key_entry_pairs) + return FieldNameDict{T1, T2}(FieldNameSet{T1}(keys, name_tree), entries) +end FieldNameDict{T1}(args...) where {T1} = FieldNameDict{T1, Any}(args...) 
const FieldVectorView = FieldNameDict{FieldName, Fields.Field} @@ -66,8 +69,24 @@ const FieldMatrixBroadcasted = FieldNameDict{ dict_type(::FieldNameDict{T1, T2}) where {T1, T2} = FieldNameDict{T1, T2} function Base.show(io::IO, dict::FieldNameDict) - strings = map((key, value) -> " $key => $value", pairs(dict)) - print(io, "$(dict_type(dict))($(join(strings, ",\n")))") + print(io, "$(dict_type(dict)) with $(length(dict)) entries:") + for (key, entry) in dict + print(io, "\n $key => ") + if entry isa Fields.Field + print(io, eltype(entry), "-valued Field:") + Fields._show_compact_field(io, entry, " ", true) + elseif entry isa UniformScaling + if entry.λ == 1 + print(io, "I") + elseif entry.λ == -1 + print(io, "-I") + else + print(io, "$(entry.λ) * I") + end + else + print(io, entry) + end + end end Base.keys(dict::FieldNameDict) = dict.keys @@ -118,18 +137,18 @@ function get_internal_entry( # See note above matrix_product_keys in field_name_set.jl for more details. T = eltype(eltype(entry)) if name_pair == (@name(), @name()) - # multiplication case 1, either argument entry - elseif broadcasted_has_field(T, name_pair[1]) && name_pair[2] == @name() + elseif name_pair[1] == name_pair[2] + # multiplication case 3 or 4, first argument + @assert T <: SingleValue && !broadcasted_has_field(T, name_pair[1]) + entry + elseif name_pair[2] == @name() && broadcasted_has_field(T, name_pair[1]) # multiplication case 2 or 4, second argument Base.broadcasted(entry) do matrix_row map(matrix_row) do matrix_row_entry broadcasted_get_field(matrix_row_entry, name_pair[1]) end end # Note: This assumes that the entry is in a FieldMatrixBroadcasted. 
- elseif T <: SingleValue && name_pair[1] == name_pair[2] - # multiplication case 3 or 4, first argument - entry else unsupported_internal_entry_error(entry, name_pair) end @@ -168,13 +187,7 @@ end function check_diagonal_matrix(matrix, error_message_start = "The matrix") check_block_diagonal_matrix(matrix, error_message_start) non_diagonal_entry_pairs = unrolled_filter(pairs(matrix)) do pair - !( - pair[2] isa UniformScaling || - pair[2] isa ColumnwiseBandMatrixField && - eltype(pair[2]) <: DiagonalMatrixRow || - pair[2] isa Base.AbstractBroadcasted && - eltype(pair[2]) <: DiagonalMatrixRow - ) + !(pair[2] isa UniformScaling || eltype(pair[2]) <: DiagonalMatrixRow) end non_diagonal_entry_keys = FieldMatrixKeys(unrolled_map(pair -> pair[1], non_diagonal_entry_pairs)) @@ -185,15 +198,85 @@ function check_diagonal_matrix(matrix, error_message_start = "The matrix") end """ - field_vector_view(x) + lazy_main_diagonal(matrix) + +Creates an un-materialized `FieldMatrixBroadcasted` that extracts the main +diagonal of the `FieldMatrix`/`FieldMatrixBroadcasted` `matrix`. +""" +function lazy_main_diagonal(matrix) + diagonal_keys = matrix_diagonal_keys(keys(matrix)) + entries = map(diagonal_keys) do key + entry = matrix[key] + entry isa UniformScaling || eltype(entry) <: DiagonalMatrixRow ? + entry : + Base.Broadcast.broadcasted(row -> DiagonalMatrixRow(row[0]), entry) + end + return FieldMatrixBroadcasted(diagonal_keys, entries) +end + +""" + field_vector_view(x, [name_tree]) -Constructs a `FieldVectorView` that contains all the top-level `Field`s in the -`FieldVector` `x`. +Constructs a `FieldVectorView` that contains all of the `Field`s in the +`FieldVector` `x`. The default `name_tree` is `FieldNameTree(x)`, but this can +be modified if needed. 
""" -function field_vector_view(x) - top_level_keys = FieldVectorKeys(top_level_names(x), FieldNameTree(x)) - entries = map(name -> get_field(x, name), top_level_keys) - return FieldVectorView(top_level_keys, entries) +function field_vector_view(x, name_tree = FieldNameTree(x)) + keys_of_fields = FieldVectorKeys(names_of_fields(x, name_tree), name_tree) + entries = map(name -> get_field(x, name), keys_of_fields) + return FieldVectorView(keys_of_fields, entries) +end +names_of_fields(x, name_tree) = + unrolled_mapflatten(top_level_names(x)) do name + entry = get_field(x, name) + if entry isa Fields.Field + (name,) + elseif entry isa Fields.FieldVector + unrolled_map(names_of_fields(entry, name_tree)) do internal_name + append_internal_name(name, internal_name) + end + else + error("field_vector_view does not support entries of type \ + $(typeof(entry).name.name)") + end + end + +""" + concrete_field_vector(vector) + +Converts the `FieldVectorView` `vector` back into a `FieldVector`. +""" +concrete_field_vector(vector) = + concrete_field_vector_within_subtree(keys(vector).name_tree, vector) +concrete_field_vector_within_subtree(tree, vector) = + if tree.name in keys(vector) + vector[tree.name] + else + subtrees = unrolled_filter(tree.subtrees) do subtree + unrolled_any(keys(vector).values) do key + is_child_name(key, subtree.name) + end + end + internal_names = unrolled_map(subtrees) do subtree + extract_first(extract_internal_name(subtree.name, tree.name)) + end + internal_entries = unrolled_map(subtrees) do subtree + concrete_field_vector_within_subtree(subtree, vector) + end + entry_eltypes = unrolled_map(recursive_bottom_eltype, internal_entries) + T = promote_type(entry_eltypes...) + Fields.FieldVector{T}(NamedTuple{internal_names}(internal_entries)) + end + +# This is required for type-stability as of Julia 1.9. +if hasfield(Method, :recursion_relation) + dont_limit = (args...) 
-> true + for m in methods(names_of_fields) + m.recursion_relation = dont_limit + end + for m in methods(concrete_field_vector_within_subtree) + m.recursion_relation = dont_limit + end end ################################################################################ @@ -260,6 +343,19 @@ Base.Broadcast.broadcasted( arg::FieldMatrixStyleType, ) = arg +function Base.Broadcast.broadcasted( + ::FieldMatrixStyle, + ::typeof(zero), + vector_or_matrix::FieldMatrixStyleType, +) + FieldNameDictType = dict_type(vector_or_matrix) + entries = unrolled_map(values(vector_or_matrix)) do entry + entry isa UniformScaling ? zero(entry) : + Base.Broadcast.broadcasted(value -> rzero(typeof(value)), entry) + end + return FieldNameDictType(keys(vector_or_matrix), entries) +end + function Base.Broadcast.broadcasted( ::FieldMatrixStyle, ::typeof(-), diff --git a/src/MatrixFields/field_name_set.jl b/src/MatrixFields/field_name_set.jl index 1f28960bff..8896f3f4f2 100644 --- a/src/MatrixFields/field_name_set.jl +++ b/src/MatrixFields/field_name_set.jl @@ -77,48 +77,41 @@ function Base.issubset(set1::FieldNameSet, set2::FieldNameSet) end Base.:(==)(set1::FieldNameSet, set2::FieldNameSet) = - issubset(set1, set2) && issubset(set2, set1) + unrolled_all(value -> unrolled_in(value, set2.values), set1.values) && + unrolled_all(value -> unrolled_in(value, set1.values), set2.values) function Base.intersect(set1::FieldNameSet{T}, set2::FieldNameSet{T}) where {T} name_tree = combine_name_trees(set1.name_tree, set2.name_tree) - values1′, values2′ = set1.values, set2.values - values1, values2 = non_overlapping_values(values1′, values2′, name_tree) - result_values = unrolled_filter(values2) do value - unrolled_any(isequal(value), values1) + all_values = union_values(set1.values, set2.values, name_tree) + result_values = unrolled_filter(all_values) do value + is_value_in_set(value, set1.values, name_tree) && + is_value_in_set(value, set2.values, name_tree) end return FieldNameSet{T}(result_values, 
name_tree) end function Base.union(set1::FieldNameSet{T}, set2::FieldNameSet{T}) where {T} name_tree = combine_name_trees(set1.name_tree, set2.name_tree) - values1′, values2′ = set1.values, set2.values - values1, values2 = non_overlapping_values(values1′, values2′, name_tree) - values2_minus_values1 = unrolled_filter(values2) do value - !unrolled_any(isequal(value), values1) - end - result_values = (values1..., values2_minus_values1...) + result_values = union_values(set1.values, set2.values, name_tree) return FieldNameSet{T}(result_values, name_tree) end function Base.setdiff(set1::FieldNameSet{T}, set2::FieldNameSet{T}) where {T} name_tree = combine_name_trees(set1.name_tree, set2.name_tree) - set2_complement_values = set_complement_values(T, set2.values, name_tree) - set2_complement = FieldNameSet{T}(set2_complement_values, name_tree) - return intersect(set1, set2_complement) + all_values = union_values(set1.values, set2.values, name_tree) + result_values = unrolled_filter(all_values) do value + !is_value_in_set(value, set2.values, name_tree) + end + return FieldNameSet{T}(result_values, name_tree) end -set_string(set) = - length(set) == 2 ? 
join(set.values, " and ") : - join(set.values, ", ", ", and ") +set_string(set) = values_string(set.values) + +set_complement(set) = setdiff(universal_set(eltype(set), set.name_tree), set) is_subset_that_covers_set(set1, set2) = issubset(set1, set2) && isempty(setdiff(set2, set1)) -function set_complement(set::FieldNameSet{T}) where {T} - result_values = set_complement_values(T, set.values, set.name_tree) - return FieldNameSet{T}(result_values, set.name_tree) -end - function corresponding_matrix_keys(set::FieldVectorKeys) result_values = unrolled_map(name -> (name, name), set.values) return FieldMatrixKeys(result_values, set.name_tree) @@ -126,9 +119,7 @@ end function cartesian_product(set1::FieldVectorKeys, set2::FieldVectorKeys) name_tree = combine_name_trees(set1.name_tree, set2.name_tree) - result_values = unrolled_mapflatten(set1.values) do row_name - unrolled_map(col_name -> (row_name, col_name), set2.values) - end + result_values = unrolled_product(set1.values, set2.values) return FieldMatrixKeys(result_values, name_tree) end @@ -217,10 +208,16 @@ function summand_names_for_matrix_product( result_values = unrolled_mapflatten(overlapping_set1_values) do name_pair1 overlapping_set2_values = unrolled_filter(set2.values) do value2 row_name2 = eltype(set2) <: FieldName ? value2 : value2[1] - names_are_overlapping(name_pair1[2], row_name2) && ( - eltype(set2) <: FieldName || - names_are_overlapping(product_key[2], value2[2]) - ) + names_are_overlapping(name_pair1[2], row_name2) && + ( + eltype(set2) <: FieldName || + names_are_overlapping(product_key[2], value2[2]) + ) && + ( + is_child_name(name_pair1[2], row_name2) || + product_row_name == row_name2 && + name_pair1[1] == name_pair1[2] + ) end unrolled_map(overlapping_set2_values) do value2 row_name2 = eltype(set2) <: FieldName ? 
value2 : value2[1] @@ -240,9 +237,6 @@ function summand_names_for_matrix_product( end else # multiplication case 3 - product_row_name == row_name2 && - name_pair1[1] == name_pair1[2] || - error("Invalid matrix product key $product_key") row_name2 end end @@ -267,16 +261,15 @@ check_values(values, name_tree) = overlapping_values = unrolled_filter(values) do value′ value != value′ && values_are_overlapping(value, value′) end - if !isempty(overlapping_values) - overlapping_values_string = - length(overlapping_values) == 2 ? - join(overlapping_values, " or ") : - join(overlapping_values, ", ", ", or ") - error("Overlapping FieldNameSet values: $value cannot be in the \ - same FieldNameSet as $overlapping_values_string") - end + isempty(overlapping_values) || error( + "Overlapping FieldNameSet values: $value cannot be in the same \ + FieldNameSet as $(values_string(overlapping_values))", + ) end +values_string(values) = + length(values) == 2 ? join(values, " and ") : join(values, ", ", ", and ") + combine_name_trees(::Nothing, ::Nothing) = nothing combine_name_trees(name_tree1, ::Nothing) = name_tree1 combine_name_trees(::Nothing, name_tree2) = name_tree2 @@ -285,6 +278,18 @@ combine_name_trees(name_tree1, name_tree2) = error("Mismatched FieldNameTrees: The ability to combine different \ FieldNameTrees has not been implemented") +function universal_set(::Type{FieldName}, name_tree) + isnothing(name_tree) && error( + "Missing FieldNameTree: Cannot compute complement of FieldNameSet \ + without a FieldNameTree", + ) + return FieldVectorKeys(child_names(@name(), name_tree), name_tree) +end +function universal_set(::Type{FieldNamePair}, name_tree) + row_set = universal_set(FieldName, name_tree) + return cartesian_product(row_set, row_set) +end + is_valid_value(name::FieldName, name_tree) = is_valid_name(name, name_tree) is_valid_value(name_pair::FieldNamePair, name_tree) = is_valid_name(name_pair[1], name_tree) && @@ -302,141 +307,142 @@ 
is_child_value(name_pair1::FieldNamePair, name_pair2::FieldNamePair) = is_child_name(name_pair1[2], name_pair2[2]) is_value_in_set(value, values, name_tree) = - if unrolled_any(isequal(value), values) - true - elseif unrolled_any(value′ -> is_child_value(value, value′), values) - isnothing(name_tree) && error( - "Cannot check if $value is in FieldNameSet without a FieldNameTree", - ) - is_valid_value(value, name_tree) - else - false - end + unrolled_in(value, values) || + unrolled_any(value′ -> is_child_value(value, value′), values) && + (isnothing(name_tree) ? true : is_valid_value(value, name_tree)) -function non_overlapping_values(values1, values2, name_tree) - new_values1 = unrolled_mapflatten(values1) do value - value_or_non_overlapping_children(value, values2, name_tree) - end - new_values2 = unrolled_mapflatten(values2) do value - value_or_non_overlapping_children(value, values1, name_tree) - end - if eltype(values1) <: FieldName - new_values1, new_values2 - else - # Repeat the above operation to handle complex matrix key overlaps. 
- new_values1′ = unrolled_mapflatten(new_values1) do value - value_or_non_overlapping_children(value, new_values2, name_tree) +function unique_and_non_overlapping_values(values, name_tree) + unique_values = unrolled_unique(values) + overlapping_values, non_overlapping_values = + unrolled_split(unique_values) do value + unrolled_any(unique_values) do value′ + value != value′ && values_are_overlapping(value, value′) + end end - new_values2′ = unrolled_mapflatten(new_values2) do value - value_or_non_overlapping_children(value, new_values1, name_tree) + isempty(overlapping_values) && return unique_values + isnothing(name_tree) && + error("Missing FieldNameTree: Cannot eliminate overlaps among \ + $(values_string(overlapping_values)) without a FieldNameTree") + expanded_overlapping_values = + unrolled_mapflatten(overlapping_values) do value + values_overlapping_with_value = + unrolled_filter(overlapping_values) do value′ + value != value′ && values_are_overlapping(value, value′) + end + expand_child_values(value, values_overlapping_with_value, name_tree) end - return new_values1′, new_values2′ - end + no_longer_overlapping_values = unique_and_non_overlapping_values( + expanded_overlapping_values, + name_tree, + ) + return (non_overlapping_values..., no_longer_overlapping_values...) end -function unique_and_non_overlapping_values(values, name_tree) - new_values = unrolled_mapflatten(values) do value - value_or_non_overlapping_children(value, values, name_tree) - end - return unrolled_unique(new_values) +# The function union_values(values1, values2, name_tree) gives the same result +# as unique_and_non_overlapping_values((values1..., values2...), name_tree), but +# it is slightly more efficient (and faster to compile) because it makes use of +# the fact that values1 == unique_and_non_overlapping_values(values1, name_tree) +# and values2 == unique_and_non_overlapping_values(values2, name_tree). 
+function union_values(values1, values2, name_tree) + unique_values2 = + unrolled_filter(value2 -> !unrolled_in(value2, values1), values2) + overlapping_values1, non_overlapping_values1 = + unrolled_split(values1) do value1 + unrolled_any(unique_values2) do value2 + values_are_overlapping(value1, value2) + end + end + isempty(overlapping_values1) && return (values1..., unique_values2...) + overlapping_values2, non_overlapping_values2 = + unrolled_split(unique_values2) do value2 + unrolled_any(values1) do value1 + values_are_overlapping(value1, value2) + end + end + isnothing(name_tree) && error( + "Missing FieldNameTree: Cannot eliminate overlaps between \ + $overlapping_values1 and $overlapping_values2 without a FieldNameTree", + ) + expanded_overlapping_values1 = + unrolled_mapflatten(overlapping_values1) do value1 + values2_overlapping_value1 = + unrolled_filter(overlapping_values2) do value2 + values_are_overlapping(value1, value2) + end + expand_child_values(value1, values2_overlapping_value1, name_tree) + end + expanded_overlapping_values2 = + unrolled_mapflatten(overlapping_values2) do value2 + values1_overlapping_value2 = + unrolled_filter(overlapping_values1) do value1 + values_are_overlapping(value1, value2) + end + expand_child_values(value2, values1_overlapping_value2, name_tree) + end + union_of_overlapping_values = union_values( + expanded_overlapping_values1, + expanded_overlapping_values2, + name_tree, + ) + return ( + non_overlapping_values1..., + non_overlapping_values2..., + union_of_overlapping_values..., + ) end -function value_or_non_overlapping_children(name::FieldName, names, name_tree) - need_child_names = unrolled_any(names) do name′ - is_child_value(name′, name) && name′ != name - end - need_child_names || return (name,) - isnothing(name_tree) && - error("Cannot compute child names of $name without a FieldNameTree") - return unrolled_mapflatten(child_names(name, name_tree)) do child_name - value_or_non_overlapping_children(child_name, 
names, name_tree) - end -end -function value_or_non_overlapping_children( +expand_child_values(name::FieldName, overlapping_names, name_tree) = + unrolled_all(overlapping_names) do name′ + name′ != name && is_child_name(name′, name) + end ? child_names(name, name_tree) : (name,) +function expand_child_values( name_pair::FieldNamePair, - name_pairs, + overlapping_name_pairs, name_tree, ) - need_row_child_names = unrolled_any(name_pairs) do name_pair′ - is_child_value(name_pair′, name_pair) && name_pair′[1] != name_pair[1] - end - need_col_child_names = unrolled_any(name_pairs) do name_pair′ - is_child_value(name_pair′, name_pair) && name_pair′[2] != name_pair[2] - end - need_row_child_names || need_col_child_names || return (name_pair,) - isnothing(name_tree) && error( - "Cannot compute child name pairs of $name_pair without a FieldNameTree", - ) + row_name, col_name = name_pair + row_name_children_needed = + unrolled_all(overlapping_name_pairs) do name_pair′ + name_pair′[1] != row_name && is_child_name(name_pair′[1], row_name) + end + col_name_children_needed = + unrolled_all(overlapping_name_pairs) do name_pair′ + name_pair′[2] != col_name && is_child_name(name_pair′[2], col_name) + end row_name_children = - need_row_child_names ? child_names(name_pair[1], name_tree) : - (name_pair[1],) + row_name_children_needed ? child_names(row_name, name_tree) : () col_name_children = - need_col_child_names ? child_names(name_pair[2], name_tree) : - (name_pair[2],) - return unrolled_mapflatten(row_name_children) do row_name_child - unrolled_mapflatten(col_name_children) do col_name_child - child_pair = (row_name_child, col_name_child) - value_or_non_overlapping_children(child_pair, name_pairs, name_tree) + col_name_children_needed ? child_names(col_name, name_tree) : () + # Note: We need special cases for when either row_name or col_name only has + # one child name, since automatically expanding that name can generate + # results with unnecessary expansions. 
For example, it can lead to a + # situation in which issubset(set1, set2) && union(set1, set2) != set2 + # evaluates to true because union(set1, set2) has too many expanded values. + return if length(row_name_children) > 1 && length(col_name_children) > 1 || + length(row_name_children) == 1 && length(col_name_children) == 1 + unrolled_product(row_name_children, col_name_children) + elseif length(row_name_children) > 1 && length(col_name_children) == 1 || + length(row_name_children) > 0 && length(col_name_children) == 0 + unrolled_map(row_name_children) do row_name_child + (row_name_child, col_name) end - end -end - -set_complement_values(_, _, ::Nothing) = - error("Cannot compute complement of a FieldNameSet without a FieldNameTree") -set_complement_values(::Type{<:FieldName}, names, name_tree::FieldNameTree) = - complement_values_in_subtree(names, name_tree) -set_complement_values( - ::Type{<:FieldNamePair}, - name_pairs, - name_tree::FieldNameTree, -) = complement_values_in_subtree_pair(name_pairs, (name_tree, name_tree)) - -function complement_values_in_subtree(names, subtree) - name = subtree.name - unrolled_all(name′ -> !is_child_value(name, name′), names) || return () - unrolled_any(name′ -> is_child_value(name′, name), names) || return (name,) - return unrolled_mapflatten(subtree.subtrees) do subsubtree - complement_values_in_subtree(names, subsubtree) - end -end - -function complement_values_in_subtree_pair(name_pairs, subtree_pair) - name_pair = (subtree_pair[1].name, subtree_pair[2].name) - is_name_pair_in_complement = unrolled_all(name_pairs) do name_pair′ - !is_child_value(name_pair, name_pair′) - end - is_name_pair_in_complement || return () - need_row_subsubtrees = unrolled_any(name_pairs) do name_pair′ - is_child_value(name_pair′, name_pair) && name_pair′[1] != name_pair[1] - end - need_col_subsubtrees = unrolled_any(name_pairs) do name_pair′ - is_child_value(name_pair′, name_pair) && name_pair′[2] != name_pair[2] - end - need_row_subsubtrees || 
need_col_subsubtrees || return (name_pair,) - row_subsubtrees = - need_row_subsubtrees ? subtree_pair[1].subtrees : (subtree_pair[1],) - col_subsubtrees = - need_col_subsubtrees ? subtree_pair[2].subtrees : (subtree_pair[2],) - return unrolled_mapflatten(row_subsubtrees) do row_subsubtree - unrolled_mapflatten(col_subsubtrees) do col_subsubtree - subsubtree_pair = (row_subsubtree, col_subsubtree) - complement_values_in_subtree_pair(name_pairs, subsubtree_pair) + elseif length(row_name_children) == 1 && length(col_name_children) > 1 || + length(row_name_children) == 0 && length(col_name_children) > 0 + unrolled_map(col_name_children) do col_name_child + (row_name, col_name_child) end + else # length(row_name_children) == 0 && length(col_name_children) == 0 + (name_pair,) end end -################################################################################ - # This is required for type-stability as of Julia 1.9. if hasfield(Method, :recursion_relation) dont_limit = (args...) -> true - for m in methods(value_or_non_overlapping_children) - m.recursion_relation = dont_limit - end - for m in methods(complement_values_in_subtree) + for m in methods(unique_and_non_overlapping_values) m.recursion_relation = dont_limit end - for m in methods(complement_values_in_subtree_pair) + for m in methods(union_values) m.recursion_relation = dont_limit end end diff --git a/src/MatrixFields/single_field_solver.jl b/src/MatrixFields/single_field_solver.jl index 62ad93e8ee..967d4ccdd6 100644 --- a/src/MatrixFields/single_field_solver.jl +++ b/src/MatrixFields/single_field_solver.jl @@ -22,7 +22,7 @@ unit_eltype(::Type{T_A}) where {T_A} = ################################################################################ check_single_field_solver(::UniformScaling, _) = nothing -function check_single_field_solver(A::ColumnwiseBandMatrixField, b) +function check_single_field_solver(A, b) matrix_shape(A) == Square() || error( "Cannot solve linear system because a diagonal entry in A is not a 
\ square matrix", diff --git a/src/MatrixFields/unrolled_functions.jl b/src/MatrixFields/unrolled_functions.jl index 947a45084d..03cbb2a8a1 100644 --- a/src/MatrixFields/unrolled_functions.jl +++ b/src/MatrixFields/unrolled_functions.jl @@ -48,6 +48,8 @@ unrolled_any(f::F, values) where {F} = unrolled_all(f::F, values) where {F} = unrolled_foldl(&, unrolled_map(f, values), true) +unrolled_in(value, values) = unrolled_any(isequal(value), values) + unrolled_filter(f::F, values) where {F} = unrolled_foldl(values, ()) do filtered_values, value f(value) ? (filtered_values..., value) : filtered_values @@ -69,6 +71,14 @@ unrolled_flatten(values) = unrolled_mapflatten(f::F, values) where {F} = unrolled_flatten(unrolled_map(f, values)) +unrolled_product(values1, values2) = + unrolled_mapflatten(values1) do value1 + unrolled_map(value2 -> (value1, value2), values2) + end + +unrolled_split(f::F, values) where {F} = + (unrolled_filter(f, values), unrolled_filter(value -> !f(value), values)) + function unrolled_findonly(f::F, values) where {F} filtered_values = unrolled_filter(f, values) length(filtered_values) == 1 || diff --git a/test/MatrixFields/field_matrix_solvers.jl b/test/MatrixFields/field_matrix_solvers.jl index 2a21087e97..e959a68e14 100644 --- a/test/MatrixFields/field_matrix_solvers.jl +++ b/test/MatrixFields/field_matrix_solvers.jl @@ -13,8 +13,8 @@ function test_field_matrix_solver(; alg, A, b, - ignore_approximation_error = false, - skip_correctness_test = false, + use_rel_error = false, + allocations_test_broken = false, ) @testset "$test_name" begin x = similar(b) @@ -30,21 +30,13 @@ function test_field_matrix_solver(; time_ratio = solve_time_rounded / mul_time_rounded time_ratio_rounded = round(time_ratio; sigdigits = 2) - # If possible, test that A * (inv(A) * b) == b. 
- if skip_correctness_test - relative_error = - norm(abs.(parent(b_test) .- parent(b))) / norm(parent(b)) - relative_error_rounded = round(relative_error; sigdigits = 2) - error_string = "Relative Error = $(relative_error_rounded * 100) %" + error_vector = abs.(parent(b_test) .- parent(b)) + if use_rel_error + rel_error = norm(error_vector) / norm(parent(b)) + rel_error_rounded = round(rel_error; sigdigits = 2) + error_string = "Relative Error = $rel_error_rounded" else - if ignore_approximation_error - @assert alg isa MatrixFields.ApproximateFactorizationSolve - b_view = MatrixFields.field_vector_view(b) - A₁, A₂ = - MatrixFields.approximate_factors(alg.name_pairs₁, A, b_view) - @. b_test = A₁ * A₂ * x - end - max_error = maximum(abs.(parent(b_test) .- parent(b))) + max_error = maximum(error_vector) max_eps_error = ceil(Int, max_error / eps(typeof(max_error))) error_string = "Maximum Error = $max_eps_error eps" end @@ -53,14 +45,22 @@ function test_field_matrix_solver(; Multiplication Time = $mul_time_rounded s (Ratio = \ $time_ratio_rounded)\n\t$error_string" - skip_correctness_test || @test max_eps_error <= 3 + if use_rel_error + @test rel_error < 1e-6 + else + @test max_eps_error <= 3 + end @test_opt ignored_modules = ignore_cuda FieldMatrixSolver(alg, A, b) @test_opt ignored_modules = ignore_cuda field_matrix_solve!(args...) @test_opt ignored_modules = ignore_cuda field_matrix_mul!(b, A, x) - using_cuda || @test @allocated(field_matrix_solve!(args...)) == 0 - using_cuda || @test @allocated(field_matrix_mul!(b, A, x)) == 0 + if !using_cuda + @test @allocated(field_matrix_solve!(args...)) == 0 broken = + allocations_test_broken + @test @allocated(field_matrix_mul!(b, A, x)) == 0 broken = + allocations_test_broken + end end end @@ -111,11 +111,15 @@ end # TODO: Add a simple test where typeof(x) != typeof(b). + # Note: The round-off error of StationaryIterativeSolve can be much larger + # on GPUs, so n_iters often has to be increased when using_cuda is true. 
+ for alg in ( MatrixFields.BlockDiagonalSolve(), MatrixFields.BlockLowerTriangularSolve(@name(c)), - MatrixFields.SchurComplementSolve(@name(f)), - MatrixFields.ApproximateFactorizationSolve((@name(c), @name(c))), + MatrixFields.BlockArrowheadSolve(@name(c)), + MatrixFields.ApproximateBlockArrowheadIterativeSolve(@name(c)), + MatrixFields.StationaryIterativeSolve(; n_iters = using_cuda ? 28 : 18), ) test_field_matrix_solver(; test_name = "$(typeof(alg).name.name) for a block diagonal matrix \ @@ -154,9 +158,9 @@ end ) test_field_matrix_solver(; - test_name = "SchurComplementSolve for a block matrix with diagonal, \ + test_name = "BlockArrowheadSolve for a block matrix with diagonal, \ quad-diagonal, bi-diagonal, and penta-diagonal blocks", - alg = MatrixFields.SchurComplementSolve(@name(f)), + alg = MatrixFields.BlockArrowheadSolve(@name(c)), A = MatrixFields.FieldMatrix( (@name(c), @name(c)) => ᶜᶜmat1, (@name(c), @name(f)) => ᶜᶠmat4, @@ -166,14 +170,14 @@ end b = Fields.FieldVector(; c = ᶜvec, f = ᶠvec), ) + # Since test_field_matrix_solver runs the solver many times with the same + # values of x, A, and b for benchmarking, setting correlated_solves to true + # is equivalent to setting n_iters to some very large number. 
test_field_matrix_solver(; - test_name = "ApproximateFactorizationSolve for a block matrix with \ - tri-diagonal, quad-diagonal, bi-diagonal, and \ - penta-diagonal blocks", - alg = MatrixFields.ApproximateFactorizationSolve( - (@name(c), @name(c)); - alg₂ = MatrixFields.SchurComplementSolve(@name(f)), - ), + test_name = "StationaryIterativeSolve with correlated_solves for a \ + block matrix with tri-diagonal, quad-diagonal, \ + bi-diagonal, and penta-diagonal blocks", + alg = MatrixFields.StationaryIterativeSolve(; correlated_solves = true), A = MatrixFields.FieldMatrix( (@name(c), @name(c)) => ᶜᶜmat3, (@name(c), @name(f)) => ᶜᶠmat4, @@ -181,8 +185,156 @@ end (@name(f), @name(f)) => ᶠᶠmat5, ), b = Fields.FieldVector(; c = ᶜvec, f = ᶠvec), - ignore_approximation_error = true, ) + + # Each of the scaled identity matrices below was chosen to minimize the + # value of ρ(I - P⁻¹ * A), which was found by setting print_radius to true. + # Each value of n_iters below was then chosen to be the smallest value for + # which the relative error was less than 1e-6. + scaled_identity_matrix(scalar) = + MatrixFields.FieldMatrix((@name(), @name()) => scalar * I) + for (P_name, alg) in ( + ( + "no (identity matrix)", + MatrixFields.StationaryIterativeSolve(; + n_iters = using_cuda ? 10 : 7, + ), + ), # ρ(I - P⁻¹ * A) ≈ 0.3777 + ( + "Richardson (damped identity matrix)", + MatrixFields.StationaryIterativeSolve(; + P_alg = MatrixFields.CustomPreconditioner( + scaled_identity_matrix(FT(1.12)), + ), + n_iters = using_cuda ? 8 : 7, + ), + ), # ρ(I - P⁻¹ * A) ≈ 0.2294 + ( + "Jacobi (diagonal)", + MatrixFields.StationaryIterativeSolve(; + P_alg = MatrixFields.MainDiagonalPreconditioner(), + n_iters = using_cuda ? 8 : 6, + ), + ), # ρ(I - P⁻¹ * A) ≈ 0.3241 + ( + "damped Jacobi (diagonal)", + MatrixFields.StationaryIterativeSolve(; + P_alg = MatrixFields.WeightedPreconditioner( + scaled_identity_matrix(FT(1.08)), + MatrixFields.MainDiagonalPreconditioner(), + ), + n_iters = using_cuda ? 
8 : 7, + ), + ), # ρ(I - P⁻¹ * A) ≈ 0.2249 + ( + "block Jacobi (diagonal)", + MatrixFields.StationaryIterativeSolve(; + P_alg = MatrixFields.BlockDiagonalPreconditioner(), + n_iters = 7, + ), + ), # ρ(I - P⁻¹ * A) ≈ 0.1450 + ( + "damped block Jacobi (diagonal)", + MatrixFields.StationaryIterativeSolve(; + P_alg = MatrixFields.WeightedPreconditioner( + scaled_identity_matrix(FT(1.002)), + MatrixFields.BlockDiagonalPreconditioner(), + ), + n_iters = 7, + ), + ), # ρ(I - P⁻¹ * A) ≈ 0.1427 + ( + "block arrowhead", + MatrixFields.StationaryIterativeSolve(; + P_alg = MatrixFields.BlockArrowheadPreconditioner(@name(c)), + n_iters = 6, + ), + ), # ρ(I - P⁻¹ * A) ≈ 0.1356 + ( + "damped block arrowhead", + MatrixFields.StationaryIterativeSolve(; + P_alg = MatrixFields.BlockArrowheadPreconditioner( + @name(c); + P_alg₁ = MatrixFields.WeightedPreconditioner( + scaled_identity_matrix(FT(1.0001)), + MatrixFields.MainDiagonalPreconditioner(), + ), + ), + n_iters = 6, + ), + ), # ρ(I - P⁻¹ * A) ≈ 0.1355 + ( + "block arrowhead Schur complement", + MatrixFields.ApproximateBlockArrowheadIterativeSolve( + @name(c); + n_iters = 3, + ), + ), # ρ(I - P⁻¹ * A) ≈ 0.0009 + ( + "damped block arrowhead Schur complement", + MatrixFields.ApproximateBlockArrowheadIterativeSolve( + @name(c); + P_alg₁ = MatrixFields.WeightedPreconditioner( + scaled_identity_matrix(FT(1.09)), + MatrixFields.MainDiagonalPreconditioner(), + ), + n_iters = 2, + ), + ), # ρ(I - P⁻¹ * A) ≈ 0.000006 + ) + test_field_matrix_solver(; + test_name = "approximate iterative solve with $P_name \ + preconditioning for a block matrix with tri-diagonal, \ + quad-diagonal, bi-diagonal, and penta-diagonal blocks", + alg, + A = MatrixFields.FieldMatrix( + (@name(c), @name(c)) => ᶜᶜmat3, + (@name(c), @name(f)) => ᶜᶠmat4, + (@name(f), @name(c)) => ᶠᶜmat2, + (@name(f), @name(f)) => ᶠᶠmat5, + ), + b = Fields.FieldVector(; c = ᶜvec, f = ᶠvec), + use_rel_error = true, + ) + end + + # Since printing causes both allocations and type 
instabilities, we do not + # use test_field_matrix_solver to test print_radius and print_norm. Instead, + # we just run the solver and make sure it does not throw an error. + @testset "approximate iterative solve with print_norm/print_radius" begin + A = MatrixFields.FieldMatrix( + (@name(c), @name(c)) => ᶜᶜmat3, + (@name(c), @name(f)) => ᶜᶠmat4, + (@name(f), @name(c)) => ᶠᶜmat2, + (@name(f), @name(f)) => ᶠᶠmat5, + ) + b = Fields.FieldVector(; c = ᶜvec, f = ᶠvec) + + alg = MatrixFields.ApproximateBlockArrowheadIterativeSolve( + @name(c); + P_alg₁ = MatrixFields.WeightedPreconditioner( + scaled_identity_matrix(FT(1.09)), + MatrixFields.MainDiagonalPreconditioner(), + ), + n_iters = 2, + print_norm = true, + ) + @info "output of print_norm for the previous solve:" + field_matrix_solve!(FieldMatrixSolver(alg, A, b), similar(b), A, b) + + if !using_cuda # KrylovKit's eigsolve function is not GPU-compatible. + alg = MatrixFields.ApproximateBlockArrowheadIterativeSolve( + @name(c); + P_alg₁ = MatrixFields.WeightedPreconditioner( + scaled_identity_matrix(FT(1.09)), + MatrixFields.MainDiagonalPreconditioner(), + ), + print_radius = true, + ) + @info "output of print_radius for the previous solve:" + field_matrix_solve!(FieldMatrixSolver(alg, A, b), similar(b), A, b) + end + end end @testset "FieldMatrixSolver ClimaAtmos-Based Tests" begin @@ -249,7 +401,7 @@ end test_field_matrix_solver(; test_name = "similar solve to ClimaAtmos's dry dycore with implicit \ acoustic waves", - alg = MatrixFields.SchurComplementSolve(@name(f)), + alg = MatrixFields.BlockArrowheadSolve(@name(c)), A = MatrixFields.FieldMatrix( (@name(c.ρ), @name(c.ρ)) => I, (@name(c.ρe_tot), @name(c.ρe_tot)) => I, @@ -266,11 +418,9 @@ end test_field_matrix_solver(; test_name = "similar solve to ClimaAtmos's dry dycore with implicit \ acoustic waves and diffusion", - alg = MatrixFields.ApproximateFactorizationSolve( - (@name(c), @name(f)), - (@name(f), @name(c)), - (@name(f), @name(f)); - alg₁ = 
MatrixFields.SchurComplementSolve(@name(f)), + alg = MatrixFields.ApproximateBlockArrowheadIterativeSolve( + @name(c); + n_iters = 6, ), A = MatrixFields.FieldMatrix( (@name(c.ρ), @name(c.ρ)) => I, @@ -283,17 +433,14 @@ end (@name(f.u₃), @name(f.u₃)) => ᶠᶠmat3_u₃_u₃, ), b = b_dry_dycore, - ignore_approximation_error = true, ) test_field_matrix_solver(; test_name = "similar solve to ClimaAtmos's moist dycore + diagnostic \ EDMF with implicit acoustic waves and SGS fluxes", - alg = MatrixFields.ApproximateFactorizationSolve( - (@name(c), @name(f)), - (@name(f), @name(c)), - (@name(f), @name(f)); - alg₁ = MatrixFields.SchurComplementSolve(@name(f)), + alg = MatrixFields.ApproximateBlockArrowheadIterativeSolve( + @name(c); + n_iters = 6, ), A = MatrixFields.FieldMatrix( (@name(c.ρ), @name(c.ρ)) => I, @@ -310,73 +457,73 @@ end (@name(f.u₃), @name(f.u₃)) => ᶠᶠmat3_u₃_u₃, ), b = b_moist_dycore_diagnostic_edmf, - ignore_approximation_error = true, ) - # TODO: This unit test is currently broken. 
- # test_field_matrix_solver(; - # test_name = "similar solve to ClimaAtmos's moist dycore + prognostic \ - # EDMF + prognostic surface temperature with implicit \ - # acoustic waves and SGS fluxes", - # alg = MatrixFields.BlockLowerTriangularSolve( - # @name(c.sgsʲs), - # @name(f.sgsʲs); - # alg₁ = MatrixFields.SchurComplementSolve(@name(f)), - # alg₂ = MatrixFields.ApproximateFactorizationSolve( - # (@name(c), @name(f)), - # (@name(f), @name(c)), - # (@name(f), @name(f)); - # alg₁ = MatrixFields.SchurComplementSolve(@name(f)), - # ), - # ), - # A = MatrixFields.FieldMatrix( - # # GS-GS blocks: - # (@name(sfc), @name(sfc)) => I, - # (@name(c.ρ), @name(c.ρ)) => I, - # (@name(c.ρe_tot), @name(c.ρe_tot)) => ᶜᶜmat3, - # (@name(c.ρatke), @name(c.ρatke)) => ᶜᶜmat3, - # (@name(c.ρχ), @name(c.ρχ)) => ᶜᶜmat3, - # (@name(c.uₕ), @name(c.uₕ)) => ᶜᶜmat3, - # (@name(c.ρ), @name(f.u₃)) => ᶜᶠmat2_scalar_u₃, - # (@name(c.ρe_tot), @name(f.u₃)) => ᶜᶠmat2_scalar_u₃, - # (@name(c.ρatke), @name(f.u₃)) => ᶜᶠmat2_scalar_u₃, - # (@name(c.ρχ), @name(f.u₃)) => ᶜᶠmat2_ρχ_u₃, - # (@name(f.u₃), @name(c.ρ)) => ᶠᶜmat2_u₃_scalar, - # (@name(f.u₃), @name(c.ρe_tot)) => ᶠᶜmat2_u₃_scalar, - # (@name(f.u₃), @name(f.u₃)) => ᶠᶠmat3_u₃_u₃, - # # GS-SGS blocks: - # (@name(c.ρe_tot), @name(c.sgsʲs.:(1).ρae_tot)) => ᶜᶜmat3, - # (@name(c.ρχ.ρq_tot), @name(c.sgsʲs.:(1).ρaχ.ρaq_tot)) => ᶜᶜmat3, - # (@name(c.ρχ.ρq_liq), @name(c.sgsʲs.:(1).ρaχ.ρaq_liq)) => ᶜᶜmat3, - # (@name(c.ρχ.ρq_ice), @name(c.sgsʲs.:(1).ρaχ.ρaq_ice)) => ᶜᶜmat3, - # (@name(c.ρχ.ρq_rai), @name(c.sgsʲs.:(1).ρaχ.ρaq_rai)) => ᶜᶜmat3, - # (@name(c.ρχ.ρq_sno), @name(c.sgsʲs.:(1).ρaχ.ρaq_sno)) => ᶜᶜmat3, - # (@name(c.ρe_tot), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3, - # (@name(c.ρatke), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3, - # (@name(c.ρχ), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3_ρχ_scalar, - # (@name(c.uₕ), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3_uₕ_scalar, - # (@name(c.ρe_tot), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_scalar_u₃, - # (@name(c.ρatke), @name(f.sgsʲs.:(1).u₃)) 
=> ᶜᶠmat2_scalar_u₃, - # (@name(c.ρχ), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_ρχ_u₃, - # (@name(c.uₕ), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_uₕ_u₃, - # (@name(f.u₃), @name(c.sgsʲs.:(1).ρa)) => ᶠᶜmat2_u₃_scalar, - # (@name(f.u₃), @name(f.sgsʲs.:(1).u₃)) => ᶠᶠmat3_u₃_u₃, - # # SGS-SGS blocks: - # (@name(c.sgsʲs.:(1).ρa), @name(c.sgsʲs.:(1).ρa)) => I, - # (@name(c.sgsʲs.:(1).ρae_tot), @name(c.sgsʲs.:(1).ρae_tot)) => I, - # (@name(c.sgsʲs.:(1).ρaχ), @name(c.sgsʲs.:(1).ρaχ)) => I, - # (@name(c.sgsʲs.:(1).ρa), @name(f.sgsʲs.:(1).u₃)) => - # ᶜᶠmat2_scalar_u₃, - # (@name(c.sgsʲs.:(1).ρae_tot), @name(f.sgsʲs.:(1).u₃)) => - # ᶜᶠmat2_scalar_u₃, - # (@name(c.sgsʲs.:(1).ρaχ), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_ρaχ_u₃, - # (@name(f.sgsʲs.:(1).u₃), @name(c.sgsʲs.:(1).ρa)) => - # ᶠᶜmat2_u₃_scalar, - # (@name(f.sgsʲs.:(1).u₃), @name(c.sgsʲs.:(1).ρae_tot)) => - # ᶠᶜmat2_u₃_scalar, - # (@name(f.sgsʲs.:(1).u₃), @name(f.sgsʲs.:(1).u₃)) => ᶠᶠmat3_u₃_u₃, - # ), - # b = b_moist_dycore_prognostic_edmf_prognostic_surface, - # skip_correctness_test = true, - # ) + # TODO: This test currently triggers allocations on CPUs and takes around 40 + # minutes to compile on GPUs. Fix this by switching to Unrolled.jl. 
+ !using_cuda && test_field_matrix_solver(; + test_name = "similar solve to ClimaAtmos's moist dycore + prognostic \ + EDMF + prognostic surface temperature with implicit \ + acoustic waves and SGS fluxes", + alg = MatrixFields.BlockLowerTriangularSolve( + @name(c.sgsʲs), + @name(f.sgsʲs); + alg₁ = MatrixFields.BlockArrowheadSolve(@name(c)), + alg₂ = MatrixFields.ApproximateBlockArrowheadIterativeSolve( + @name(c); + n_iters = 6, + ), + ), + A = MatrixFields.FieldMatrix( + # GS-GS blocks: + (@name(sfc), @name(sfc)) => I, + (@name(c.ρ), @name(c.ρ)) => I, + (@name(c.ρe_tot), @name(c.ρe_tot)) => ᶜᶜmat3, + (@name(c.ρatke), @name(c.ρatke)) => ᶜᶜmat3, + (@name(c.ρχ), @name(c.ρχ)) => ᶜᶜmat3, + (@name(c.uₕ), @name(c.uₕ)) => ᶜᶜmat3, + (@name(c.ρ), @name(f.u₃)) => ᶜᶠmat2_scalar_u₃, + (@name(c.ρe_tot), @name(f.u₃)) => ᶜᶠmat2_scalar_u₃, + (@name(c.ρatke), @name(f.u₃)) => ᶜᶠmat2_scalar_u₃, + (@name(c.ρχ), @name(f.u₃)) => ᶜᶠmat2_ρχ_u₃, + (@name(f.u₃), @name(c.ρ)) => ᶠᶜmat2_u₃_scalar, + (@name(f.u₃), @name(c.ρe_tot)) => ᶠᶜmat2_u₃_scalar, + (@name(f.u₃), @name(f.u₃)) => ᶠᶠmat3_u₃_u₃, + # GS-SGS blocks: + (@name(c.ρe_tot), @name(c.sgsʲs.:(1).ρae_tot)) => ᶜᶜmat3, + (@name(c.ρχ.ρq_tot), @name(c.sgsʲs.:(1).ρaχ.ρaq_tot)) => ᶜᶜmat3, + (@name(c.ρχ.ρq_liq), @name(c.sgsʲs.:(1).ρaχ.ρaq_liq)) => ᶜᶜmat3, + (@name(c.ρχ.ρq_ice), @name(c.sgsʲs.:(1).ρaχ.ρaq_ice)) => ᶜᶜmat3, + (@name(c.ρχ.ρq_rai), @name(c.sgsʲs.:(1).ρaχ.ρaq_rai)) => ᶜᶜmat3, + (@name(c.ρχ.ρq_sno), @name(c.sgsʲs.:(1).ρaχ.ρaq_sno)) => ᶜᶜmat3, + (@name(c.ρe_tot), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3, + (@name(c.ρatke), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3, + (@name(c.ρχ), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3_ρχ_scalar, + (@name(c.uₕ), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3_uₕ_scalar, + (@name(c.ρe_tot), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_scalar_u₃, + (@name(c.ρatke), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_scalar_u₃, + (@name(c.ρχ), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_ρχ_u₃, + (@name(c.uₕ), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_uₕ_u₃, + (@name(f.u₃), 
@name(c.sgsʲs.:(1).ρa)) => ᶠᶜmat2_u₃_scalar, + (@name(f.u₃), @name(f.sgsʲs.:(1).u₃)) => ᶠᶠmat3_u₃_u₃, + # SGS-SGS blocks: + (@name(c.sgsʲs.:(1).ρa), @name(c.sgsʲs.:(1).ρa)) => I, + (@name(c.sgsʲs.:(1).ρae_tot), @name(c.sgsʲs.:(1).ρae_tot)) => I, + (@name(c.sgsʲs.:(1).ρaχ), @name(c.sgsʲs.:(1).ρaχ)) => I, + (@name(c.sgsʲs.:(1).ρa), @name(f.sgsʲs.:(1).u₃)) => + ᶜᶠmat2_scalar_u₃, + (@name(c.sgsʲs.:(1).ρae_tot), @name(f.sgsʲs.:(1).u₃)) => + ᶜᶠmat2_scalar_u₃, + (@name(c.sgsʲs.:(1).ρaχ), @name(f.sgsʲs.:(1).u₃)) => + ᶜᶠmat2_ρaχ_u₃, + (@name(f.sgsʲs.:(1).u₃), @name(c.sgsʲs.:(1).ρa)) => + ᶠᶜmat2_u₃_scalar, + (@name(f.sgsʲs.:(1).u₃), @name(c.sgsʲs.:(1).ρae_tot)) => + ᶠᶜmat2_u₃_scalar, + (@name(f.sgsʲs.:(1).u₃), @name(f.sgsʲs.:(1).u₃)) => + ᶠᶠmat3_u₃_u₃, + ), + b = b_moist_dycore_prognostic_edmf_prognostic_surface, + allocations_test_broken = true, + ) end diff --git a/test/MatrixFields/field_names.jl b/test/MatrixFields/field_names.jl index 6834a9eef1..fe9582426d 100644 --- a/test/MatrixFields/field_names.jl +++ b/test/MatrixFields/field_names.jl @@ -1,4 +1,6 @@ -import ClimaCore.MatrixFields: @name +import LinearAlgebra: I +import ClimaCore.DataLayouts: replace_basetype +import ClimaCore.MatrixFields: @name, is_subset_that_covers_set include("matrix_field_test_utils.jl") @@ -7,7 +9,8 @@ struct Foo{T} end Base.propertynames(::Foo) = (:value,) Base.getproperty(foo::Foo, s::Symbol) = - s == :value ? getfield(foo, :_value) : nothing + s == :value ? 
getfield(foo, :_value) : error("Invalid property name") +Base.convert(::Type{Foo{T}}, foo::Foo) where {T} = Foo{T}(foo.value) const x = (; foo = Foo(0), a = (; b = 1, c = ((; d = 2), (;), ((), nothing)))) @@ -18,7 +21,7 @@ const x = (; foo = Foo(0), a = (; b = 1, c = ((; d = 2), (;), ((), nothing)))) @test_throws "not a valid property name" @macroexpand @name("a") @test_throws "not a valid property name" @macroexpand @name([a]) - @test_throws "not a valid property name" @macroexpand @name((a.c.:(3)):(1)) + @test_throws "not a valid property name" @macroexpand @name((a.c.:3.0):1) @test_throws "not a valid property name" @macroexpand @name(a.c.:(3).(1)) @test string(@name()) == "@name()" @@ -106,6 +109,7 @@ end @testset "FieldNameSet Unit Tests" begin name_tree = MatrixFields.FieldNameTree(x) + vector_keys(names...) = MatrixFields.FieldVectorKeys(names, name_tree) matrix_keys(name_pairs...) = MatrixFields.FieldMatrixKeys(name_pairs, name_tree) @@ -114,12 +118,17 @@ end matrix_keys_no_tree(name_pairs...) = MatrixFields.FieldMatrixKeys(name_pairs) - @testset "FieldNameSet Construction" begin + drop_tree(set) = + set isa MatrixFields.FieldVectorKeys ? 
+ MatrixFields.FieldVectorKeys(set.values) : + MatrixFields.FieldMatrixKeys(set.values) + + @testset "FieldNameSet Constructors" begin @test_throws "Invalid FieldNameSet value" vector_keys( - @name(foo.invalid_name), + @name(invalid_name), ) @test_throws "Invalid FieldNameSet value" matrix_keys(( - @name(foo.invalid_name), + @name(invalid_name), @name(a.c), ),) @@ -145,14 +154,51 @@ end end end - @testset "FieldNameSet Iteration" begin - v_set1 = vector_keys(@name(foo), @name(a.c)) - v_set1_no_tree = vector_keys_no_tree(@name(foo), @name(a.c)) - m_set1 = matrix_keys((@name(foo), @name(a.c)), (@name(a.b), @name(foo))) - m_set1_no_tree = matrix_keys_no_tree( - (@name(foo), @name(a.c)), - (@name(a.b), @name(foo)), - ) + v_set1 = vector_keys(@name(foo), @name(a.c)) + m_set1 = matrix_keys((@name(foo), @name(a.c)), (@name(a.b), @name(foo))) + + # Proper subsets of v_set1 and m_set1. + v_set2 = vector_keys(@name(foo)) + m_set2 = matrix_keys((@name(foo), @name(a.c))) + + # Subsets of v_set1 and m_set1 that cover those sets. + v_set3 = vector_keys( + @name(foo.value), + @name(a.c.:(1)), + @name(a.c.:(2)), + @name(a.c.:(3)), + ) + m_set3 = matrix_keys( + (@name(foo.value), @name(a.c.:(1))), + (@name(foo), @name(a.c.:(2))), + (@name(foo), @name(a.c.:(3))), + (@name(a.b), @name(foo.value)), + ) + + # Sets that overlap with v_set1 and m_set1, but are neither subsets nor + # supersets of those sets. Some of the values in m_set4 overlap with values + # in m_set1, but they are neither children nor parents of those values (this + # is only possible with matrix keys). 
+ v_set4 = vector_keys(@name(a.b), @name(a.c.:(1)), @name(a.c.:(2))) + m_set4 = matrix_keys( + (@name(), @name(a.c.:(1))), + (@name(foo.value), @name(foo)), + (@name(foo.value), @name(a.c.:(2))), + (@name(a), @name(foo.value)), + (@name(a.c.:(3)), @name(a.c.:(3))), + ) + + @testset "FieldNameSet Basic Operations" begin + @test string(v_set1) == + "FieldVectorKeys(@name(foo), @name(a.c); )" + @test string(drop_tree(v_set1)) == + "FieldVectorKeys(@name(foo), @name(a.c))" + @test string(m_set1) == + "FieldMatrixKeys((@name(foo), @name(a.c)), (@name(a.b), \ + @name(foo)); )" + @test string(drop_tree(m_set1)) == + "FieldMatrixKeys((@name(foo), @name(a.c)), (@name(a.b), \ + @name(foo)))" @test_all map(name -> (name, name), v_set1) == ((@name(foo), @name(foo)), (@name(a.c), @name(a.c))) @@ -162,155 +208,212 @@ end @test_all isnothing(foreach(name -> (name, name), v_set1)) @test_all isnothing(foreach(name_pair -> name_pair[1], m_set1)) - @test string(v_set1) == - "FieldVectorKeys(@name(foo), @name(a.c); )" - @test string(v_set1_no_tree) == - "FieldVectorKeys(@name(foo), @name(a.c))" - @test string(m_set1) == "FieldMatrixKeys((@name(foo), @name(a.c)), \ - (@name(a.b), @name(foo)); )" - @test string(m_set1_no_tree) == "FieldMatrixKeys((@name(foo), \ - @name(a.c)), (@name(a.b), @name(foo)))" - - for set in (v_set1, v_set1_no_tree) - @test_all @name(foo) in set - @test_all !(@name(a.b) in set) - @test_all !(@name(invalid_name) in set) + for set1 in (v_set1, drop_tree(v_set1)) + @test_all @name(foo) in set1 + @test_all @name(foo.value) in set1 + @test_all !(@name(a.b) in set1) + @test_all !(@name(invalid_name) in set1) end - for set in (m_set1, m_set1_no_tree) - @test_all (@name(foo), @name(a.c)) in set - @test_all !((@name(foo), @name(a.b)) in set) - @test_all !((@name(foo), @name(invalid_name)) in set) + for set1 in (m_set1, drop_tree(m_set1)) + @test_all (@name(foo), @name(a.c)) in set1 + @test_all (@name(foo.value), @name(a.c)) in set1 + @test_all !((@name(foo), @name(a.b)) 
in set1) + @test_all !((@name(foo), @name(invalid_name)) in set1) end - @test_all @name(foo.value) in v_set1 @test_all !(@name(foo.invalid_name) in v_set1) - @test_throws "FieldNameTree" @name(foo.value) in v_set1_no_tree - @test_throws "FieldNameTree" @name(foo.invalid_name) in v_set1_no_tree - - @test_all (@name(foo.value), @name(a.c)) in m_set1 + @test_all @name(foo.invalid_name) in drop_tree(v_set1) @test_all !((@name(foo.invalid_name), @name(a.c)) in m_set1) - @test_throws "FieldNameTree" (@name(foo.value), @name(a.c)) in - m_set1_no_tree - @test_throws "FieldNameTree" (@name(foo.invalid_name), @name(a.c)) in - m_set1_no_tree + @test_all (@name(foo.invalid_name), @name(a.c)) in drop_tree(m_set1) end - @testset "FieldNameSet Operations for Addition/Subtraction" begin - v_set1 = vector_keys(@name(foo), @name(a.c)) - v_set1_no_tree = vector_keys_no_tree(@name(foo), @name(a.c)) - m_set1 = matrix_keys((@name(foo), @name(a.c)), (@name(a.b), @name(foo))) - m_set1_no_tree = matrix_keys_no_tree( - (@name(foo), @name(a.c)), - (@name(a.b), @name(foo)), + @testset "FieldNameSet Complement Sets" begin + @test_all MatrixFields.set_complement(v_set1) == + vector_keys_no_tree(@name(a.b)) + @test_all MatrixFields.set_complement(v_set2) == + vector_keys_no_tree(@name(a)) + @test_all MatrixFields.set_complement(v_set3) == + vector_keys_no_tree(@name(a.b)) + @test_all MatrixFields.set_complement(v_set4) == + vector_keys_no_tree(@name(foo), @name(a.c.:(3))) + @test_throws "FieldNameTree" MatrixFields.set_complement( + drop_tree(v_set1), ) - v_set2 = vector_keys(@name(foo)) - v_set2_no_tree = vector_keys_no_tree(@name(foo)) - m_set2 = matrix_keys((@name(foo), @name(a.c))) - m_set2_no_tree = matrix_keys_no_tree((@name(foo), @name(a.c))) - - v_set3 = vector_keys( - @name(foo.value), - @name(a.c.:(1)), - @name(a.c.:(2)), - @name(a.c.:(3)), + @test_all MatrixFields.set_complement(m_set1) == matrix_keys_no_tree( + (@name(foo), @name(foo)), + (@name(foo), @name(a.b)), + (@name(a), 
@name(a)), + (@name(a.c), @name(foo)), ) - v_set3_no_tree = vector_keys_no_tree( - @name(foo.value), - @name(a.c.:(1)), - @name(a.c.:(2)), - @name(a.c.:(3)), + @test_all MatrixFields.set_complement(m_set2) == matrix_keys_no_tree( + (@name(foo), @name(foo)), + (@name(foo), @name(a.b)), + (@name(a), @name(foo)), + (@name(a), @name(a)), ) - m_set3 = matrix_keys( - (@name(foo), @name(a.c.:(1))), - (@name(foo), @name(a.c.:(2))), - (@name(foo.value), @name(a.c.:(3))), - (@name(a.b), @name(foo)), + @test_all MatrixFields.set_complement(m_set3) == matrix_keys_no_tree( + (@name(foo), @name(foo)), + (@name(foo), @name(a.b)), + (@name(a), @name(a)), + (@name(a.c), @name(foo)), ) - m_set3_no_tree = matrix_keys_no_tree( - (@name(foo), @name(a.c.:(1))), - (@name(foo), @name(a.c.:(2))), - (@name(foo.value), @name(a.c.:(3))), - (@name(a.b), @name(foo)), + @test_all MatrixFields.set_complement(m_set4) == matrix_keys_no_tree( + (@name(foo), @name(a.b)), + (@name(foo), @name(a.c.:(3))), + (@name(a), @name(a.b)), + (@name(a), @name(a.c.:(2))), + (@name(a.b), @name(a.c.:(3))), + (@name(a.c.:(1)), @name(a.c.:(3))), + (@name(a.c.:(2)), @name(a.c.:(3))), ) - m_set3_no_tree′ = matrix_keys_no_tree( - (@name(foo.value), @name(a.c.:(1))), - (@name(foo.value), @name(a.c.:(2))), - (@name(foo.value), @name(a.c.:(3))), - (@name(a.b), @name(foo)), + @test_throws "FieldNameTree" MatrixFields.set_complement( + drop_tree(m_set1), ) + end + + @testset "FieldNameSet Binary Set Operations" begin + for set1 in (v_set1, drop_tree(v_set1), m_set1, drop_tree(m_set1)) + @test_all set1 == set1 + @test_all issubset(set1, set1) + @test_all is_subset_that_covers_set(set1, set1) + @test_all intersect(set1, set1) == set1 + @test_all union(set1, set1) == set1 + @test_all isempty(setdiff(set1, set1)) + end for (set1, set2) in ( (v_set1, v_set2), - (v_set1, v_set2_no_tree), - (v_set1_no_tree, v_set2), + (v_set1, drop_tree(v_set2)), + (drop_tree(v_set1), v_set2), + (drop_tree(v_set1), drop_tree(v_set2)), (m_set1, 
m_set2), - (m_set1, m_set2_no_tree), - (m_set1_no_tree, m_set2), + (m_set1, drop_tree(m_set2)), + (drop_tree(m_set1), m_set2), + (drop_tree(m_set1), drop_tree(m_set2)), ) - @test_all set1 != set2 - @test_all !issubset(set1, set2) - @test_all issubset(set2, set1) - @test_all intersect(set1, set2) == set2 - @test_all union(set1, set2) == set1 - @test_all !MatrixFields.is_subset_that_covers_set(set1, set2) - @test_all !MatrixFields.is_subset_that_covers_set(set2, set1) - end - - for (set1, set2) in - ((v_set1_no_tree, v_set2_no_tree), (m_set1_no_tree, m_set2_no_tree)) - @test_all set1 != set2 - @test_all !issubset(set1, set2) - @test_all issubset(set2, set1) - @test_all intersect(set1, set2) == set2 - @test_all union(set1, set2) == set1 - @test_all !MatrixFields.is_subset_that_covers_set(set1, set2) - @test_throws "FieldNameTree" MatrixFields.is_subset_that_covers_set( - set2, - set1, - ) + @test_all set1 != set2 && set2 != set1 + @test_all !issubset(set1, set2) && issubset(set2, set1) + @test_all !is_subset_that_covers_set(set1, set2) && + !is_subset_that_covers_set(set2, set1) + @test_all intersect(set1, set2) == intersect(set2, set1) == set2 + @test_all union(set1, set2) == union(set2, set1) == set1 + if set1 isa MatrixFields.FieldVectorKeys + @test_all setdiff(set1, set2) == vector_keys(@name(a.c)) + else + @test_all setdiff(set1, set2) == + matrix_keys((@name(a.b), @name(foo))) + end + @test_all isempty(setdiff(set2, set1)) end for (set1, set3) in ( (v_set1, v_set3), - (v_set1, v_set3_no_tree), - (v_set1_no_tree, v_set3), + (v_set1, drop_tree(v_set3)), + (drop_tree(v_set1), v_set3), + (m_set1, m_set3), + (m_set1, drop_tree(m_set3)), + (drop_tree(m_set1), m_set3), ) - @test_all set1 != set3 - @test_all !issubset(set1, set3) - @test_all issubset(set3, set1) - @test_all intersect(set1, set3) == set3 - @test_all union(set1, set3) == set3 - @test_all !MatrixFields.is_subset_that_covers_set(set1, set3) - @test_all MatrixFields.is_subset_that_covers_set(set3, set1) + 
@test_all set1 != set3 && set3 != set1 + @test_all !issubset(set1, set3) && issubset(set3, set1) + @test_all !is_subset_that_covers_set(set1, set3) && + is_subset_that_covers_set(set3, set1) + @test_all intersect(set1, set3) == intersect(set3, set1) == set3 + @test_all union(set1, set3) == union(set3, set1) == set3 + @test_all isempty(setdiff(set1, set3)) && + isempty(setdiff(set3, set1)) end for (set1, set3) in ( - (m_set1, m_set3), - (m_set1, m_set3_no_tree), - (m_set1_no_tree, m_set3), + (drop_tree(v_set1), drop_tree(v_set3)), + (drop_tree(m_set1), drop_tree(m_set3)), ) - @test_all set1 != set3 - @test_all !issubset(set1, set3) - @test_all issubset(set3, set1) - @test_all intersect(set1, set3) == m_set3_no_tree′ - @test_all union(set1, set3) == m_set3_no_tree′ - @test_all !MatrixFields.is_subset_that_covers_set(set1, set3) - @test_all MatrixFields.is_subset_that_covers_set(set3, set1) + @test_all set1 != set3 && set3 != set1 + @test_all !issubset(set1, set3) && issubset(set3, set1) + @test_all !is_subset_that_covers_set(set1, set3) + @test_throws "FieldNameTree" is_subset_that_covers_set(set3, set1) + @test_throws "FieldNameTree" intersect(set1, set3) + @test_throws "FieldNameTree" intersect(set3, set1) + @test_throws "FieldNameTree" union(set1, set3) + @test_throws "FieldNameTree" union(set3, set1) + @test_throws "FieldNameTree" setdiff(set1, set3) + @test_throws "FieldNameTree" setdiff(set3, set1) end - for (set1, set3) in - ((v_set1_no_tree, v_set3_no_tree), (m_set1_no_tree, m_set3_no_tree)) - @test_all set1 != set3 - @test_all !issubset(set1, set3) - @test_throws "FieldNameTree" issubset(set3, set1) - @test_throws "FieldNameTree" intersect(set1, set3) == set3 - @test_throws "FieldNameTree" union(set1, set3) == set3 - @test_all !MatrixFields.is_subset_that_covers_set(set1, set3) - @test_throws "FieldNameTree" MatrixFields.is_subset_that_covers_set( - set3, - set1, - ) + for (set1, set4) in ( + (v_set1, v_set4), + (v_set1, drop_tree(v_set4)), + 
(drop_tree(v_set1), v_set4), + (m_set1, m_set4), + (m_set1, drop_tree(m_set4)), + (drop_tree(m_set1), m_set4), + ) + @test_all set1 != set4 && set4 != set1 + @test_all !issubset(set1, set4) && !issubset(set4, set1) + @test_all !is_subset_that_covers_set(set1, set4) && + !is_subset_that_covers_set(set4, set1) + if set1 isa MatrixFields.FieldVectorKeys + @test_all intersect(set1, set4) == + intersect(set4, set1) == + vector_keys_no_tree(@name(a.c.:(1)), @name(a.c.:(2))) + @test_all union(set1, set4) == + union(set4, set1) == + vector_keys_no_tree( + @name(foo), + @name(a.b), + @name(a.c.:(1)), + @name(a.c.:(2)), + @name(a.c.:(3)), + ) + @test_all setdiff(set1, set4) == + vector_keys_no_tree(@name(foo), @name(a.c.:(3))) + @test_all setdiff(set4, set1) == vector_keys_no_tree(@name(a.b)) + else + @test_all intersect(set1, set4) == + intersect(set4, set1) == + matrix_keys_no_tree( + (@name(foo), @name(a.c.:(1))), + (@name(foo.value), @name(a.c.:(2))), + (@name(a.b), @name(foo.value)), + ) + @test_all union(set1, set4) == + union(set4, set1) == + matrix_keys_no_tree( + (@name(foo), @name(a.c.:(1))), + (@name(foo), @name(a.c.:(3))), + (@name(foo.value), @name(foo)), + (@name(foo.value), @name(a.c.:(2))), + (@name(a), @name(a.c.:(1))), + (@name(a.b), @name(foo.value)), + (@name(a.c), @name(foo.value)), + (@name(a.c.:(3)), @name(a.c.:(3))), + ) + @test_all setdiff(set1, set4) == + matrix_keys_no_tree((@name(foo), @name(a.c.:(3)))) + @test_all setdiff(set4, set1) == matrix_keys_no_tree( + (@name(foo.value), @name(foo)), + (@name(a), @name(a.c.:(1))), + (@name(a.c), @name(foo.value)), + (@name(a.c.:(3)), @name(a.c.:(3))), + ) + end + end + + for (set1, set4) in ( + (drop_tree(v_set1), drop_tree(v_set4)), + (drop_tree(m_set1), drop_tree(m_set4)), + ) + @test_all set1 != set4 && set4 != set1 + @test_all !issubset(set1, set4) && !issubset(set4, set1) + @test_all !is_subset_that_covers_set(set1, set4) && + !is_subset_that_covers_set(set4, set1) + @test_throws "FieldNameTree" 
intersect(set1, set4) + @test_throws "FieldNameTree" intersect(set4, set1) + @test_throws "FieldNameTree" union(set1, set4) + @test_throws "FieldNameTree" union(set4, set1) + @test_throws "FieldNameTree" setdiff(set1, set4) + @test_throws "FieldNameTree" setdiff(set4, set1) end end @@ -503,111 +606,171 @@ end end @testset "Other FieldNameSet Operations" begin - v_set1 = vector_keys(@name(foo), @name(a.c)) - v_set1_no_tree = vector_keys_no_tree(@name(foo), @name(a.c)) - m_set1 = matrix_keys((@name(foo), @name(a.c)), (@name(a.b), @name(foo))) - m_set1_no_tree = matrix_keys_no_tree( - (@name(foo), @name(a.c)), - (@name(a.b), @name(foo)), + # With one exception, none of the following operations require a + # FieldNameTree. + + @test_all MatrixFields.corresponding_matrix_keys(drop_tree(v_set1)) == + matrix_keys_no_tree( + (@name(foo), @name(foo)), + (@name(a.c), @name(a.c)), + ) + + @test_all MatrixFields.cartesian_product( + drop_tree(v_set1), + drop_tree(v_set4), + ) == matrix_keys_no_tree( + (@name(foo), @name(a.b)), + (@name(foo), @name(a.c.:(1))), + (@name(foo), @name(a.c.:(2))), + (@name(a.c), @name(a.b)), + (@name(a.c), @name(a.c.:(1))), + (@name(a.c), @name(a.c.:(2))), ) - v_set2 = vector_keys(@name(foo.value), @name(a.c.:(1)), @name(a.c.:(3))) - v_set2_no_tree = vector_keys_no_tree( + @test_all MatrixFields.matrix_row_keys(drop_tree(m_set1)) == + vector_keys_no_tree(@name(foo), @name(a.b)) + + @test_all MatrixFields.matrix_row_keys(m_set4) == vector_keys_no_tree( @name(foo.value), + @name(a.b), @name(a.c.:(1)), + @name(a.c.:(2)), @name(a.c.:(3)) ) - m_set2 = matrix_keys( - (@name(foo), @name(foo)), - (@name(foo), @name(a.c.:(1))), - (@name(foo.value), @name(a.c.:(3))), - (@name(a.b), @name(foo.value)), - (@name(a), @name(a.c)), - ) - m_set2_no_tree = matrix_keys_no_tree( - (@name(foo), @name(foo)), - (@name(foo), @name(a.c.:(1))), - (@name(foo.value), @name(a.c.:(3))), - (@name(a.b), @name(foo.value)), - (@name(a), @name(a.c)), + @test_throws "FieldNameTree" 
MatrixFields.matrix_row_keys( + drop_tree(m_set4), ) - @test_all MatrixFields.set_complement(v_set2) == - vector_keys(@name(a.b), @name(a.c.:(2))) - @test_throws "FieldNameTree" MatrixFields.set_complement(v_set2_no_tree) - - @test_all MatrixFields.set_complement(m_set2) == matrix_keys( - (@name(foo.value), @name(a.b)), + @test_all MatrixFields.matrix_off_diagonal_keys(drop_tree(m_set4)) == + matrix_keys_no_tree( + (@name(), @name(a.c.:(1))), + (@name(foo.value), @name(foo)), (@name(foo.value), @name(a.c.:(2))), - (@name(a.c), @name(foo.value)), - (@name(a), @name(a.b)), + (@name(a), @name(foo.value)), ) - @test_throws "FieldNameTree" MatrixFields.set_complement(m_set2_no_tree) - for (set1, set2) in ( - (v_set1, v_set2), - (v_set1, v_set2_no_tree), - (v_set1_no_tree, v_set2), + @test_all MatrixFields.matrix_diagonal_keys(drop_tree(m_set4)) == + matrix_keys_no_tree( + (@name(foo.value), @name(foo.value)), + (@name(a.c.:(1)), @name(a.c.:(1))), + (@name(a.c.:(3)), @name(a.c.:(3))), ) - @test_all setdiff(set1, set2) == vector_keys(@name(a.c.:(2))) - end + end +end - for (set1, set2) in ( - (m_set1, m_set2), - (m_set1, m_set2_no_tree), - (m_set1_no_tree, m_set2), - ) - @test_all setdiff(set1, set2) == - matrix_keys((@name(foo.value), @name(a.c.:(2)))) - end +@testset "FieldNameDict Unit Tests" begin + FT = Float64 + center_space, face_space = test_spaces(FT) - for (set1, set2) in - ((v_set1_no_tree, v_set2_no_tree), (m_set1_no_tree, m_set2_no_tree)) - @test_throws "FieldNameTree" setdiff(set1, set2) - end + x_FT = convert(replace_basetype(Int, FT, typeof(x)), x) - # With one exception, none of the following operations require a - # FieldNameTree. 
+ seed!(1) # ensures reproducibility - @test_all MatrixFields.corresponding_matrix_keys(v_set1_no_tree) == - matrix_keys( - (@name(foo), @name(foo)), - (@name(a.c), @name(a.c)), - ) + vector = Fields.FieldVector(; + foo = random_field(typeof(x_FT.foo), center_space), + a = random_field(typeof(x_FT.a), face_space), + ) - @test_all MatrixFields.cartesian_product( - v_set1_no_tree, - v_set2_no_tree, - ) == matrix_keys( - (@name(foo), @name(foo.value)), - (@name(foo), @name(a.c.:(1))), - (@name(foo), @name(a.c.:(3))), - (@name(a.c), @name(foo.value)), - (@name(a.c), @name(a.c.:(1))), - (@name(a.c), @name(a.c.:(3))), - ) + matrix = MatrixFields.FieldMatrix( + (@name(foo), @name(foo)) => -I, + (@name(a), @name(a)) => + random_field(DiagonalMatrixRow{FT}, face_space), + (@name(foo), @name(a.b)) => random_field( + BidiagonalMatrixRow{typeof(x_FT.foo)}, + center_space, + ), + (@name(a), @name(foo._value)) => + random_field(QuaddiagonalMatrixRow{typeof(x_FT.a)}, face_space); + name_tree = MatrixFields.FieldNameTree(vector), + ) - @test_all MatrixFields.matrix_row_keys(m_set1_no_tree) == - vector_keys(@name(foo), @name(a.b)) + @test_all MatrixFields.field_vector_view(vector) == + MatrixFields.FieldVectorView( + @name(foo) => vector.foo, + @name(a) => vector.a, + ) - @test_all MatrixFields.matrix_row_keys(m_set2) == - vector_keys(@name(foo.value), @name(a.b), @name(a.c)) - @test_throws "FieldNameTree" MatrixFields.matrix_row_keys( - m_set2_no_tree, - ) + vector_view = MatrixFields.field_vector_view(vector) - @test_all MatrixFields.matrix_off_diagonal_keys(m_set2_no_tree) == - matrix_keys( - (@name(foo), @name(a.c.:(1))), - (@name(foo.value), @name(a.c.:(3))), - (@name(a.b), @name(foo.value)), - (@name(a), @name(a.c)), - ) + # Some of the `.*`s in the following RegEx strings are needed to account for + # module qualifiers that may or may not get printed, depending on how these + # tests are run. 
- @test_all MatrixFields.matrix_diagonal_keys(m_set2_no_tree) == - matrix_keys( - (@name(foo), @name(foo)), - (@name(a.c), @name(a.c)), - ) - end + @test startswith( + string(vector_view), + r""" + .*FieldVectorView with 2 entries: + @name\(foo\) => .*-valued Field: + _value: \[.*\] + @name\(a\) => .*-valued Field: + """, + ) + + @test startswith( + string(matrix), + r""" + .*FieldMatrix with 4 entries: + \(@name\(foo\), @name\(foo\)\) => -I + \(@name\(a\), @name\(a\)\) => .*DiagonalMatrixRow{.*}-valued Field: + entries: \ + 1: \[.*\] + \(@name\(foo\), @name\(a.b\)\) => .*BidiagonalMatrixRow{.*}-valued Field: + entries: \ + 1: \ + _value: \[.*\] + 2: \ + _value: \[.*\] + \(@name\(a\), @name\(foo._value\)\) => .*QuaddiagonalMatrixRow{.*}-valued Field: + """, + ) + + @test_all vector_view[@name(foo)] == vector.foo + @test_throws KeyError vector_view[@name(invalid_name)] + @test_throws KeyError vector_view[@name(foo.invalid_name)] + + @test_all matrix[@name(foo), @name(foo)] == -I + @test_throws KeyError matrix[@name(invalid_name), @name(foo)] + @test_throws KeyError matrix[@name(foo.invalid_name), @name(foo)] + + @test_all vector_view[@name(foo._value)] == vector.foo._value + @test_all vector_view[@name(a.c)] == vector.a.c + + @test_all matrix[@name(foo._value), @name(foo._value)] == + matrix[@name(foo), @name(foo)] + @test_throws "get_internal_entry" matrix[@name(foo), @name(foo._value)] + @test_throws "get_internal_entry" matrix[@name(foo._value), @name(foo)] + + @test_all matrix[@name(a.c), @name(a.c)] == matrix[@name(a), @name(a)] + @test_throws "get_internal_entry" matrix[@name(a), @name(a.c)] + @test_throws "get_internal_entry" matrix[@name(a.c), @name(a)] + + @test_all matrix[@name(foo._value), @name(a.b)] isa Base.AbstractBroadcasted + @test Base.materialize(matrix[@name(foo._value), @name(a.b)]) == + map(row -> map(foo -> foo.value, row), matrix[@name(foo), @name(a.b)]) + + @test_all matrix[@name(a.c), @name(foo._value)] isa Base.AbstractBroadcasted + @test 
Base.materialize(matrix[@name(a.c), @name(foo._value)]) == + map(row -> map(a -> a.c, row), matrix[@name(a), @name(foo._value)]) + + vector_keys = MatrixFields.FieldVectorKeys((@name(foo), @name(a.c))) + @test_all vector_view[vector_keys] == MatrixFields.FieldVectorView( + @name(foo) => vector_view[@name(foo)], + @name(a.c) => vector_view[@name(a.c)], + ) + + matrix_keys = MatrixFields.FieldMatrixKeys(( + (@name(foo), @name(foo)), + (@name(a.c), @name(a.c)), + ),) + @test_all matrix[matrix_keys] == MatrixFields.FieldMatrix( + (@name(foo), @name(foo)) => matrix[@name(foo), @name(foo)], + (@name(a.c), @name(a.c)) => matrix[@name(a.c), @name(a.c)], + ) + + @test_all one(matrix) == MatrixFields.FieldMatrix( + (@name(foo), @name(foo)) => I, + (@name(a), @name(a)) => I, + ) + + # FieldNameDict broadcast operations are tested in field_matrix_solvers.jl. end