diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 97a82b37..4ec3fba6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -11,7 +11,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - julia-version: ["1.0", "1.4", "1.5", "~1.6.0-0"] + julia-version: ["1.0", "1.6", "1.7"] os: [ubuntu-latest, macOS-latest, windows-latest] steps: - uses: actions/checkout@v2 @@ -21,6 +21,8 @@ jobs: arch: x64 - uses: julia-actions/julia-buildpkg@latest - uses: julia-actions/julia-runtest@latest - - uses: julia-actions/julia-uploadcodecov@latest - env: - CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + - uses: julia-actions/julia-processcoverage@v1 + - uses: codecov/codecov-action@v1 + with: + fail_ci_if_error: false + if: ${{ matrix.os =='ubuntu-latest' }} diff --git a/.github/workflows/documenter.yml b/.github/workflows/documenter.yml new file mode 100644 index 00000000..75926be8 --- /dev/null +++ b/.github/workflows/documenter.yml @@ -0,0 +1,29 @@ +name: Documenter +on: + push: + branches: [master] + tags: [v*] + pull_request: + +jobs: + docs: + name: Documentation + runs-on: ubuntu-latest + if: "contains( github.event.pull_request.labels.*.name, 'preview docs') || github.ref == 'refs/heads/master' || contains(github.ref, 'refs/tags/')" + steps: + - uses: actions/checkout@v2 + - uses: julia-actions/setup-julia@latest + with: + version: 1.7 + - uses: julia-actions/julia-docdeploy@v1 + env: + PYTHON: "" + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} + note: + name: "Documentation deployment note." + runs-on: ubuntu-latest + if: "!contains( github.event.pull_request.labels.*.name, 'preview docs')" + steps: + - name: echo instructions + run: echo 'The Documentation is only generated and pushed on a PR if the “preview docs” label is added.' diff --git a/.gitignore b/.gitignore index b582d0fe..c39388c1 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ deps/deps.jl Manifest.toml +docs/build diff --git a/.zenodo.json b/.zenodo.json new file mode 100644 index 00000000..75e29027 --- /dev/null +++ b/.zenodo.json @@ -0,0 +1,28 @@ +{ + "creators": [ + { + "affiliation": "University of California, San Francisco", + "name": "Axen, Seth", + "orcid": "0000-0003-3933-8247" + }, + { + "affiliation": "AGH University of Science and Technology", + "name": "Baran, Mateusz", + "orcid": "0000-0001-9667-5579" + }, + { + "affiliation": "NTNU Trondheim", + "name": "Bergmann, Ronny", + "orcid": "0000-0001-8342-7218" + } + ], + "description": "ManifoldsBase.jl is an interface for Riemannian manifolds in Julia", + "keywords": [ + "Riemannian manifolds", + "manifolds", + "Julia" + ], + "license": "MIT", + "title": "ManifoldsBase.jl", + "upload_type": "software" +} diff --git a/Project.toml b/Project.toml index 0774c189..c80c0d39 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ManifoldsBase" uuid = "3362f125-f0bb-47a3-aa74-596ffd7ef2fb" authors = ["Seth Axen ", "Mateusz Baran ", "Ronny Bergmann ", "Antoine Levitt "] -version = "0.12.11" +version = "0.13.0" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/README.md b/README.md deleted file mode 100644 index 2219eded..00000000 --- a/README.md +++ /dev/null @@ -1,73 +0,0 @@ -# ManifoldsBase.jl -[![](https://img.shields.io/badge/docs-stable-blue.svg)](https://juliamanifolds.github.io/Manifolds.jl/stable/interface.html) -[![](https://img.shields.io/badge/docs-dev-blue.svg)](https://juliamanifolds.github.io/Manifolds.jl/latest/interface.html) -[![Build Status](https://travis-ci.org/JuliaManifolds/ManifoldsBase.jl.svg?branch=master)](https://travis-ci.org/JuliaManifolds/ManifoldsBase.jl/) -[![codecov.io](http://codecov.io/github/JuliaManifolds/ManifoldsBase.jl/coverage.svg?branch=master)](https://codecov.io/gh/JuliaManifolds/ManifoldsBase.jl/) -[![arXiv](https://img.shields.io/badge/arXiv%20CS.MS-2106.08777-blue.svg)](https://arxiv.org/abs/2106.08777) - -Basic interface for manifolds in Julia. - -The project [`Manifolds.jl`](https://github.com/JuliaManifolds/Manifolds.jl) -is based on this interface and provides a variety of manifolds. - -## Number system - -A number system represents the field a manifold is based upon. -Most prominently, these are real-valued (`ℝ`) and complex valued (`ℂ`) fields that -parametrize certain manifolds. -A further type to represent the field of quaternions (`ℍ`) can also be used. - -## Bases - -Several different types of bases for a tangent space at `p` on a [`AbstractManifold`](https://juliamanifolds.github.io/Manifolds.jl/stable/interface.html#ManifoldsBase.AbstractManifold) are provided. -Methods are provided to obtain such a basis, to represent a tangent vector in a basis and to reconstruct a tangent vector from coefficients with respect to a basis. -The last two can be performed without computing the complete basis. -Further a basis can be cached and hence be reused, see [`CachedBasis`](https://juliamanifolds.github.io/Manifolds.jl/stable/interface.html#ManifoldsBase.CachedBasis). - -## `DecoratorManifold` - -The decorator manifold enhances a manifold by certain, in most cases implicitly -assumed to have a standard case, properties, see for example the `EmbeddedManifold`. -The decorator acts semi transparently, i.e. `:transparent` for all functions not affected by that -decorator and `:intransparent` otherwise. Another possibility is, that the decorator just -passes to `:parent` in order to fill default values. - -## `DefaultManifold` - -This interface includes a simple `DefaultManifold`, which is a reduced version -of the [`Euclidean`](https://juliamanifolds.github.io/Manifolds.jl/stable/manifolds/euclidean.html) -manifold from [`Manifolds.jl`](https://github.com/JuliaManifolds/Manifolds.jl), -such that the interface functions can be tested. - -## `EmbeddedManifold` - -The embedded manifold models the embedding of a manifold into another manifold. -This way a manifold can benefit from existing implementations. -One example is the `TransparentIsometricEmbeddingType` where a manifold uses the metric, -`inner`, from its embedding. - -## `ValidationManifold` - -The `ValidationManifold` further illustrates how one can also used types to -represent points on a manifold, tangent vectors, and cotangent vectors, -where values are encapsulated in a certain type. - -In general, `ValidationManifold` might be used for manifolds where these three types are represented -by more complicated data structures or when it is necessary to distinguish these -by type. - -This adds a semantic layer to the interface, and the default implementation of -`ValidationManifold` adds checks to all inputs and outputs of typed data. - -## Citation -If you use `ManifoldsBase.jl` in your work, please cite the following - -```biblatex -@online{2106.08777, -Author = {Seth D. Axen and Mateusz Baran and Ronny Bergmann and Krzysztof Rzecki}, -Title = {Manifolds.jl: An Extensible Julia Framework for Data Analysis on Manifolds}, -Year = {2021}, -Eprint = {2106.08777}, -Eprinttype = {arXiv}, -} -``` diff --git a/Readme.md b/Readme.md new file mode 100644 index 00000000..2426d116 --- /dev/null +++ b/Readme.md @@ -0,0 +1,38 @@ +
+ ManifoldsBase.jl Logo with text +
+ +[![](https://img.shields.io/badge/docs-stable-blue.svg)](https://juliamanifolds.github.io/Manifolds.jl/stable/interface.html) +[![](https://img.shields.io/badge/docs-dev-blue.svg)](https://juliamanifolds.github.io/Manifolds.jl/latest/interface.html) +[![Build Status](https://travis-ci.org/JuliaManifolds/ManifoldsBase.jl.svg?branch=master)](https://travis-ci.org/JuliaManifolds/ManifoldsBase.jl/) +[![codecov.io](http://codecov.io/github/JuliaManifolds/ManifoldsBase.jl/coverage.svg?branch=master)](https://codecov.io/gh/JuliaManifolds/ManifoldsBase.jl/) +[![arXiv](https://img.shields.io/badge/arXiv%20CS.MS-2106.08777-blue.svg)](https://arxiv.org/abs/2106.08777) + +## Installation + +In Julia you can install this package by typing + +```julia +] add ManifoldsBase +``` + +in the Julia REPL. + +Since this package provides an interface, you probably either want to add it as a dependency to your project/package to work on manifold generically or implement a new manifold. +A package that (only) depends on `ManifoldsBase.jl`, see [Manopt.jl](https://manoptjl.org/stable/), which implements optimization algorithms on manifolds using this interface, i.e. they can be used with any manifold based on `ManifoldsBase.jl`. A library of manifolds implemented using this interface is provided see [Manifolds.jl](https://juliamanifolds.github.io/Manifolds.jl/stable/). + +Your package is using `ManifoldsBase`? Give us a note and we add you here. + +## Citation + +If you use `ManifoldsBase.jl` in your work, please cite the following + +```biblatex +@online{2106.08777, +Author = {Seth D. Axen and Mateusz Baran and Ronny Bergmann and Krzysztof Rzecki}, +Title = {Manifolds.jl: An Extensible Julia Framework for Data Analysis on Manifolds}, +Year = {2021}, +Eprint = {2106.08777}, +Eprinttype = {arXiv}, +} +``` diff --git a/assets/logo-text.svg b/assets/logo-text.svg new file mode 100644 index 00000000..a0787b5e --- /dev/null +++ b/assets/logo-text.svg @@ -0,0 +1,211 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + ManifoldsBase.jl + + diff --git a/assets/logo_interface.jl b/assets/logo_interface.jl new file mode 100644 index 00000000..ac93e75d --- /dev/null +++ b/assets/logo_interface.jl @@ -0,0 +1,141 @@ +using Manifolds, LinearAlgebra, PGFPlotsX, Colors, Contour, Random + +# +# Settings +# +dark_mode = true + +line_offset_brightness = 0.25 +patch_opacity = 1.0 +geo_opacity = dark_mode ? 0.66 : 0.5 +geo_line_width = 30 +mesh_line_width = 5 +mesh_opacity = dark_mode ? 0.5 : 0.7 +logo_colors = [(77, 100, 174), (57, 151, 79), (202, 60, 50), (146, 89, 163)] # Julia colors + +rgb_logo_colors = map(x -> RGB(x ./ 255...), logo_colors) +rgb_logo_colors_bright = + map(x -> RGB((1 + line_offset_brightness) .* x ./ 255...), logo_colors) +rgb_logo_colors_dark = + map(x -> RGB((1 - line_offset_brightness) .* x ./ 255...), logo_colors) + +out_file_prefix = dark_mode ? "logo-interface-dark" : "logo-interface" +out_file_ext = ".svg" + +# +# Helping functions +# +polar_to_cart(r, θ) = (r * cos(θ), r * sin(θ)) +cart_to_polar(x, y) = (hypot(x, y), atan(y, x)) +normal_coord_to_vector(M, x, rθ, B) = get_vector(M, x, collect(polar_to_cart(rθ...)), B) +normal_coord_to_point(M, x, rθ, B) = exp(M, x, normal_coord_to_vector(M, x, rθ, B)) + +function plot_patch!(ax, M, x, B, r, θs; options = Dict()) + push!( + ax, + Plot3( + options, + Coordinates(map(θ -> Tuple(normal_coord_to_point(M, x, [r, θ], B)), θs)), + ), + ) + return ax +end + +function plot_geodesic!(ax, M, x, y; n = 100, options = Dict()) + γ = shortest_geodesic(M, x, y) + T = range(0, 1; length = n) + push!(ax, Plot3(options, Coordinates(Tuple.(γ.(T))))) + return ax +end + +# +# Prepare document +# +resize!(PGFPlotsX.CUSTOM_PREAMBLE, 0) +push!(PGFPlotsX.CUSTOM_PREAMBLE, raw"\pgfplotsset{scale=6.0}") +push!(PGFPlotsX.CUSTOM_PREAMBLE, raw"\usetikzlibrary{arrows.meta}") +push!(PGFPlotsX.CUSTOM_PREAMBLE, raw"\pgfplotsset{roundcaps/.style={line cap=round}}") +push!( + PGFPlotsX.CUSTOM_PREAMBLE, + raw"\pgfplotsset{meshlinestyle/.style={dash pattern=on 1.8\pgflinewidth off 1.4\pgflinewidth, line cap=round}}", +) +if dark_mode + push!(PGFPlotsX.CUSTOM_PREAMBLE, raw"\pagecolor{black}") +end +S = Sphere(2) + +center = normalize([1, 1, 1]) +x, y, z = eachrow(Matrix{Float64}(I, 3, 3)) +γ1 = shortest_geodesic(S, center, z) +γ2 = shortest_geodesic(S, center, x) +γ3 = shortest_geodesic(S, center, y) +p1 = γ1(1) +p2 = γ2(1) +p3 = γ3(1) + +# +# Setup Axes +if dark_mode + tp = @pgf Axis({ + axis_lines = "none", + axis_equal, + view = "{135}{35}", + zmin = -0.05, + zmax = 1.0, + xmin = 0.0, + xmax = 1.0, + ymin = 0.0, + ymax = 1.0, + }) +else + tp = @pgf Axis({ + axis_lines = "none", + axis_equal, + view = "{135}{35}", + zmin = -0.05, + zmax = 1.0, + xmin = 0.0, + xmax = 1.0, + ymin = 0.0, + ymax = 1.0, + }) +end +rs = range(0, π / 5; length = 6) +θs = range(0, 2π; length = 100) + +# +# Plot manifold patches +patch_colors = rgb_logo_colors[2:end] +patch_colors_line = dark_mode ? rgb_logo_colors_bright[2:end] : rgb_logo_colors_dark[2:end] +base_points = [p1, p2, p3] +basis_vectors = [log(S, p1, p2), log(S, p2, p1), log(S, p3, p1)] +for i in eachindex(base_points) + b = base_points[i] + B = DiagonalizingOrthonormalBasis(basis_vectors[i]) + basis = get_basis(S, b, B) + optionsP = @pgf {fill = patch_colors[i], draw = "none", opacity = patch_opacity} + plot_patch!(tp, S, b, basis, π / 5, θs; options = optionsP) + optionsP = + @pgf {fill = dark_mode ? "black" : "white", draw = "none", opacity = patch_opacity} + plot_patch!(tp, S, b, basis, 0.75 * π / 5, θs; options = optionsP) +end + +# +# Plot geodesics +options = @pgf { + opacity = geo_opacity, + "meshlinestyle", + no_markers, + roundcaps, + line_width = geo_line_width, + color = dark_mode ? "white" : "black", +} +plot_geodesic!(tp, S, base_points[1], base_points[2]; options = options) +plot_geodesic!(tp, S, base_points[1], base_points[3]; options = options) +plot_geodesic!(tp, S, base_points[2], base_points[3]; options = options) + +# +# Export Logo. +out_file = "$(out_file_prefix)$(out_file_ext)" +pgfsave(out_file, tp) +pgfsave("$(out_file_prefix).pdf", tp) diff --git a/docs/Project.toml b/docs/Project.toml new file mode 100644 index 00000000..5b394d13 --- /dev/null +++ b/docs/Project.toml @@ -0,0 +1,7 @@ +[deps] +Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +ManifoldsBase = "3362f125-f0bb-47a3-aa74-596ffd7ef2fb" + +[compat] +Documenter = "0.27" +ManifoldsBase = "0.13" \ No newline at end of file diff --git a/docs/make.jl b/docs/make.jl new file mode 100644 index 00000000..5eaf92ff --- /dev/null +++ b/docs/make.jl @@ -0,0 +1,24 @@ +using ManifoldsBase, Documenter + +makedocs( + format = Documenter.HTML(prettyurls = false, assets = ["assets/favicon.ico"]), + modules = [ManifoldsBase], + authors = "Seth Axen, Mateusz Baran, Ronny Bergmann, and contributors.", + sitename = "ManifoldsBase.jl", + pages = [ + "Home" => "index.md", + "How to write a manifold" => "example.md", + "Design principles" => "design.md", + "The manifold type" => "manifold_type.md", + "Functions on Maniolds" => [ + "Basic functions" => "functions.md", + "Projections" => "projections.md", + "Retractions" => "retractions.md", + "Vector Transports" => "vector_transports.md", + ], + "Manifolds" => "manifolds.md", + "Extending Manifolds" => "decorator.md", + "Bases for tangent spaces" => "bases.md", + ], +) +deploydocs(repo = "github.com/JuliaManifolds/ManifoldsBase.jl.git", push_preview = true) diff --git a/docs/src/assets/android-chrome-192x192.png b/docs/src/assets/android-chrome-192x192.png new file mode 100644 index 00000000..094d72d2 Binary files /dev/null and b/docs/src/assets/android-chrome-192x192.png differ diff --git a/docs/src/assets/android-chrome-512x512.png b/docs/src/assets/android-chrome-512x512.png new file mode 100644 index 00000000..dc01dc1e Binary files /dev/null and b/docs/src/assets/android-chrome-512x512.png differ diff --git a/docs/src/assets/apple-touch-icon.png b/docs/src/assets/apple-touch-icon.png new file mode 100644 index 00000000..a555073e Binary files /dev/null and b/docs/src/assets/apple-touch-icon.png differ diff --git a/docs/src/assets/favicon-16x16.png b/docs/src/assets/favicon-16x16.png new file mode 100644 index 00000000..b247a53c Binary files /dev/null and b/docs/src/assets/favicon-16x16.png differ diff --git a/docs/src/assets/favicon-32x32.png b/docs/src/assets/favicon-32x32.png new file mode 100644 index 00000000..88a3b022 Binary files /dev/null and b/docs/src/assets/favicon-32x32.png differ diff --git a/docs/src/assets/favicon.ico b/docs/src/assets/favicon.ico new file mode 100644 index 00000000..cca80bfe Binary files /dev/null and b/docs/src/assets/favicon.ico differ diff --git a/docs/src/assets/images/projection_illustration.png b/docs/src/assets/images/projection_illustration.png new file mode 100644 index 00000000..7bea7fce Binary files /dev/null and b/docs/src/assets/images/projection_illustration.png differ diff --git a/docs/src/assets/images/projection_illustration_600.png b/docs/src/assets/images/projection_illustration_600.png new file mode 100644 index 00000000..49ace7e6 Binary files /dev/null and b/docs/src/assets/images/projection_illustration_600.png differ diff --git a/docs/src/assets/images/retraction_illustration.png b/docs/src/assets/images/retraction_illustration.png new file mode 100644 index 00000000..2c985e81 Binary files /dev/null and b/docs/src/assets/images/retraction_illustration.png differ diff --git a/docs/src/assets/images/retraction_illustration_600.png b/docs/src/assets/images/retraction_illustration_600.png new file mode 100644 index 00000000..6b51d1b3 Binary files /dev/null and b/docs/src/assets/images/retraction_illustration_600.png differ diff --git a/docs/src/assets/logo-dark.png b/docs/src/assets/logo-dark.png new file mode 100644 index 00000000..baaed618 Binary files /dev/null and b/docs/src/assets/logo-dark.png differ diff --git a/docs/src/assets/logo-dark_bg.png b/docs/src/assets/logo-dark_bg.png new file mode 100644 index 00000000..6d1ac7c1 Binary files /dev/null and b/docs/src/assets/logo-dark_bg.png differ diff --git a/docs/src/assets/logo-interface-dark.svg b/docs/src/assets/logo-interface-dark.svg new file mode 100644 index 00000000..1d9dbee4 --- /dev/null +++ b/docs/src/assets/logo-interface-dark.svg @@ -0,0 +1,102 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/src/assets/logo-interface.svg b/docs/src/assets/logo-interface.svg new file mode 100644 index 00000000..2aa3c332 --- /dev/null +++ b/docs/src/assets/logo-interface.svg @@ -0,0 +1,96 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/src/assets/logo-text-readme.png b/docs/src/assets/logo-text-readme.png new file mode 100644 index 00000000..83e3391a Binary files /dev/null and b/docs/src/assets/logo-text-readme.png differ diff --git a/docs/src/assets/logo.png b/docs/src/assets/logo.png new file mode 100644 index 00000000..dc509864 Binary files /dev/null and b/docs/src/assets/logo.png differ diff --git a/docs/src/assets/logo_bg.png b/docs/src/assets/logo_bg.png new file mode 100644 index 00000000..c6465435 Binary files /dev/null and b/docs/src/assets/logo_bg.png differ diff --git a/docs/src/assets/mstile-150x150.png b/docs/src/assets/mstile-150x150.png new file mode 100644 index 00000000..40757b53 Binary files /dev/null and b/docs/src/assets/mstile-150x150.png differ diff --git a/docs/src/assets/safari-pinned-tab.svg b/docs/src/assets/safari-pinned-tab.svg new file mode 100644 index 00000000..1cce4156 --- /dev/null +++ b/docs/src/assets/safari-pinned-tab.svg @@ -0,0 +1,4438 @@ + + + + +Created by potrace 1.14, written by Peter Selinger 2001-2017 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/src/assets/site.webmanifest b/docs/src/assets/site.webmanifest new file mode 100644 index 00000000..b20abb7c --- /dev/null +++ b/docs/src/assets/site.webmanifest @@ -0,0 +1,19 @@ +{ + "name": "", + "short_name": "", + "icons": [ + { + "src": "/android-chrome-192x192.png", + "sizes": "192x192", + "type": "image/png" + }, + { + "src": "/android-chrome-512x512.png", + "sizes": "512x512", + "type": "image/png" + } + ], + "theme_color": "#ffffff", + "background_color": "#ffffff", + "display": "standalone" +} diff --git a/docs/src/bases.md b/docs/src/bases.md new file mode 100644 index 00000000..76785d7f --- /dev/null +++ b/docs/src/bases.md @@ -0,0 +1,33 @@ +# Bases for tangent spaces + +The following functions and types provide support for bases of the tangent space of different manifolds. +Moreover, bases of the cotangent space are also supported, though this description focuses on the tangent space. +An orthonormal basis of the tangent space ``T_p \mathcal M`` of (real) dimension ``n`` has a real-coefficient basis ``e_1, e_2, …, e_n`` if ``\mathrm{Re}(g_p(e_i, e_j)) = δ_{ij}`` for each ``i,j ∈ \{1, 2, …, n\}`` where ``g_p`` is the Riemannian metric at point ``p``. +A vector ``X`` from the tangent space ``T_p \mathcal M`` can be expressed in Einstein notation as a sum ``X = X^i e_i``, where (real) coefficients ``X^i`` are calculated as ``X^i = \mathrm{Re}(g_p(X, e_i))``. + +The main types are: + +* [`DefaultOrthonormalBasis`](@ref), which is designed to work when no special properties of the tangent space basis are required. + It is designed to make [`get_coordinates`](@ref) and [`get_vector`](@ref) fast. +* [`DiagonalizingOrthonormalBasis`](@ref), which diagonalizes the curvature tensor and makes the curvature in the selected direction equal to 0. +* [`ProjectedOrthonormalBasis`](@ref), which projects a basis of the ambient space and orthonormalizes projections to obtain a basis in a generic way. +* [`CachedBasis`](@ref), which stores (explicitly or implicitly) a precomputed basis at a certain point. + +The main functions are: + +* [`get_basis`](@ref) precomputes a basis at a certain point. +* [`get_coordinates`](@ref) returns coordinates of a tangent vector. +* [`get_vector`](@ref) returns a vector for the specified coordinates. +* [`get_vectors`](@ref) returns a vector of basis vectors. Calling it should be avoided for high-dimensional manifolds. + +```@autodocs +Modules = [ManifoldsBase] +Pages = ["bases.jl"] +Order = [:type, :function] +``` + +```@autodocs +Modules = [ManifoldsBase] +Pages = ["vector_spaces.jl"] +Order = [:type, :function] +``` diff --git a/docs/src/decorator.md b/docs/src/decorator.md new file mode 100644 index 00000000..b3e7458b --- /dev/null +++ b/docs/src/decorator.md @@ -0,0 +1,38 @@ +# A Decorator for manifolds + +Several properties of a manifold are often implicitly assumed, for example the choice of the (Riemannian) metric, the group structure or the embedding. The latter shall serve as an example how to either implicitly or explicitly specify the embedding to avoid re-implementations and/or distinguish different embeddings. + +## The abstract decorator + +When first implementing a manifold, it might be beneficial to dispatch certain computations to already existing manifolds. +For an embedded manifold that is isometrically embedded this might be the [`inner`](@ref) the manifold inherits in each tangent space from its embedding. + +This means we dispatch the default implementation of a function to some other manifold. +We refer to this as implicit decoration, since one can not “see” explicitly that a certain manifold inherits this property. +As a small example consider the [Sphere](https://juliamanifolds.github.io/Manifolds.jl/latest/manifolds/sphere.html), which in every tangent space inherits its metric from the embedding. Since in the default implementation in [Manifolds.jl](https://juliamanifolds.github.io/Manifolds.jl/stable/) points are represented by unit vectors and tangent vectors as vectors orthogonal to a point, we can just dispatch the inner product to the embedding without having to re-implement this. +The manifold using such an implicit dispatch just has to have [`AbstractDecoratorManifold`](@ref) as its super type. + +## Traits with a inheritance hierarchy + +The properties mentioned above might form a hierarchy. +For embedded manifolds, again, we might have just a manifold whose points are represented in some embedding. +If the manifold is even isometrically embedded, it is embedded but also inherits the Riemannian metric (by restriction). But it also inherits the functions form the plain embedding. +If it is even a submanifold, also further functions are inherited. + +We use a variation of [Tim Holy's Traits Trick](https://github.com/JuliaLang/julia/issues/2345#issuecomment-54537633) (THTT) which takes into account this nestedness of traits + +```@autodocs +Modules = [ManifoldsBase] +Pages = ["nested_trait.jl"] +Order = [:type, :macro, :function] +``` + +Then the following functions and macros introduce the decorator traits + +```@autodocs +Modules = [ManifoldsBase] +Pages = ["decorator_trait.jl"] +Order = [:type, :macro, :function] +``` + +For an example see the [(implicit) embedded manifold](@ref subsec-implicit-embedded). \ No newline at end of file diff --git a/docs/src/design.md b/docs/src/design.md new file mode 100644 index 00000000..4ade7d13 --- /dev/null +++ b/docs/src/design.md @@ -0,0 +1,175 @@ +# Main Design Principles + +The interface for a manifold is defined to be as generic as possible, such that applications can be implemented as independently as possible from an actual manifold. +This way, algorithms like those from [`Manopt.jl`](https://manoptjl.org) can be implemented on _arbitrary_ manifolds. + +The main design criteria for the interface are: + +* Aims to also provide _efficient_ _global state-free_, both _in-place_ and _out-of-place_ computations whenever possible. +* Provide a high level interface that is easy to use. + +Therefore this interface has 3 main features, that we will explain using two (related) +concepts, the [exponential map](https://en.wikipedia.org/wiki/Exponential_map_(Riemannian_geometry)) that maps a tangent vector ``X`` at a point ``p`` to a point ``q`` or mathematically ``\exp_p:T_p\mathcal M \to \mathcal M`` and its generalization, a [`retract`](@ref)ion ``\operatorname{retr}_p`` with same domain and range. + +You do not need to know their exact definition at this point, just that there is _one_ exponential map on a Riemannian manifold, and several retractions, where one of them is the exponential map (sometime called exponential retraction for completeness). Every retraction has its own subtype of the [`AbstractRetractionMethod`](@ref) that uniquely defines it. + +The following three design patterns aim to fulfill the criteria from above, while +also avoiding ambiguities in multiple dispatch using the [dispatch on one argument at a time](https://docs.julialang.org/en/v1/manual/methods/#Dispatch-on-one-argument-at-a-time) approach. + +## General order of parameters + +Since the central element for functions on a manifold is the manifold itself, it should always be the first parameter, even for mutating functions. Then the classical parametzers of a function (for example a point and a tangent vector for the retraction) follow and the final part are parameters to further dispatch on, which usually have their defaults. + +## A 3-Layer architecture for dispatch + +The general architecture consists of three layers + +* The high level interface for ease of use – and to dispatch on other manifolds. +* The intermediate layer to dispatch on different parameters in the last section, e.g. type of retraction or vector transport. +* The lowest layer for specific manifolds to dispatch on different types of points and tangent vectors. Usually this layer with a specific manifold and no optional parameters. + +These three layers are described in more detail in the following. The main motivation to introduce this separation is reduction of method ambiguity problems. + +### [Layer I: The high level interface and ease of use](@id design-layer1) + +THe highest layer for convenience of decorators. +A usual scheme is, that a manifold might assume several things implicitly, for example the default implementation of the sphere $\mathbb S^n$ using unit vectors in $\mathbb R^{n+1}$. +The embedding can be explicitly used to avoid re-implementations – the inner product can be “passed on” to its embedding. + +To do so, we “decorate” the manifold by making it an [`AbstractDecoratorManifold`](@ref) and activating the right traits see [the example](@ref manifold-tutorial). + +The explicit case of the [`EmbeddedManifold`](@ref) can be used to distinguish different embeddings of a manifold, but also their dispatch (onto the manifold or its embedding, depending on the type of embedding) happens here. + +Note that all other parameters of a function should be as unspecific as possible on this layer. + +With respect to the [dispatch on one argument at a time](https://docs.julialang.org/en/v1/manual/methods/#Dispatch-on-one-argument-at-a-time) paradigm, this layer dispatches the _manifold first_, but here we stay on an abstract type level. + +This layer ends usually in calling the same functions like `retract` but prefixed with a `_` to enter Layer II. + +!!! note + Usually only functions from this layer are exported from the interface, since these are the ones one should use for generic implementations. If you implement your own manifold, `import` the necessary lower layer functions as needed. + +### [Layer II: An internal dispatch interface for parameters](@id design-layer2) + +This layer is an interim layer to dispatch on the (optional/default) parameters of a function. +For example the last parameter of retraction: [`retract`](@ref) determines the type (variant) to be used. +The last function in the previous layer calls `_retract`, which is an internal function. + +On this layer, e.g. for `_retract` only these last parameters should be typed, the manifold should stay at the [`AbstractManifold`](@ref) level. It dispatches on different functions per existing parameter type (and might pass this one further on, if it has fields). + +Note that this layer is an internal one. It is automatically called for functions with parameters to dispatch on. + +It should only be extended when introducing new such parameter types, for example when introducing a new type of a retraction. + +The functions from this layer should never be called directly, are hence also not exported and carry the `_` prefix. +They should only be called as the final step in the previous layer. + +If the default parameters are not dispatched per type, using `_` might be skipped. +The following resolution might even be seen as a last step in layer I or the resolution here in layer II. + +```julia +exp(M::AbstractManifold, p, X, t::Real) = exp(M, p, t * X) +``` + +When there is no dispatch for different types of the optional parameter (here `t`), the `_` might be skipped. +One could hence see the last code line as a definition on Layer I that passes directly to Layer III, since there are not parameter to dispatch on. + +To close this section, let‘s look at an example. The high level (or level I) definition of the retraction is given by + +```julia +retract(M::AbstractManifold, p, X, m::AbstractRetractionMethod=default_retraction_method(M)) = _retract(M, p, X, m) +``` + +This level now dispatches on different retraction types. It usually passes to specific functions implemented in Layer III, +here for example + +```julia +_retract(M::AbstractManifold, p, X, m::Exponentialretraction) = exp(M, p, X) +_retract(M::AbstractManifold, p, X, m::PolarRetraction) = retract_polar(M, p, X) +``` + +or the [`PolarRetraction`](@ref) which dispatches to [`retract_polar`](@ref). + +For further details and dispatches, see [retractions and inverse retractions](@ref sec-retractions) for an overview. + +!!! note + The documentation should be attached to the high level functions, since this again fosters ease of use. + If you implement a polar retraction, you should write a method of function `retract_polar` but the doc string should be attached to `retract(::M, ::P, ::V, ::PolarRetraction)` for your types `::M, ::P, ::V` of the manifold, points and vectors, respectively. + +To summarize, with respect to the [dispatch on one argument at a time](https://docs.julialang.org/en/v1/manual/methods/#Dispatch-on-one-argument-at-a-time) paradigm, this layer dispatches the (optional) _parameters second_. + +### [Layer III: The base layer with focus on implementations](@id design-layer3) + +This lower level aims for the actual implementation of the function avoiding ambiguities. +It should have as few as possible optional parameters and as concrete as possible types. +This means + +* the function name should be similar to its high level parent (for example `retract` and `retract_polar` above) +* The manifold type in method signature should always be as narrow as possible. +* The points/vectors should either be untyped (for the default representation or if there is only one implementation) or provide all type bounds (for second representations or when using [`AbstractManifoldPoint`](@ref) and [`TVector`](@ref TVector)). + +The first step that often happens on this level is memory allocation and calling the mutating function. If faster, it might also implement the function at hand itself. + +Usually functions from this layer are not exported, when they have an analogue on the first layer. For example the function [`retract_polar`](@ref)`(M, p, X)` is not exported, since when using the interface one would use the [`PolarRetraction`](@ref) or to be precise call [`retract`](@ref)`(M, p, X, PolarRetraction())`. +When implementing your own manifold, you have to import functions like these anyways. + +To summarize, with respect to the [dispatch on one argument at a time](https://docs.julialang.org/en/v1/manual/methods/#Dispatch-on-one-argument-at-a-time) paradigm, this layer dispatches the _concrete manifold and point/vector types last_. + +## [Mutating and allocating functions](@id mutating-and-nonmutating) + +Every function, where this is applicable should provide a mutating and an allocating variant. +For example for the exponential map `exp(M, p, X)` returns a _new_ point `q` where the result is computed in. +On the other hand `exp!(M, q, p, X)` computes the result in place of `q`, where the design of the implementation +should keep in mind that also `exp!(M, p, p, X)` should correctly overwrite `p`. + +The interface provides a way to determine the allocation type and a result to compute/allocate +the resulting memory, such that the default implementation allocating functions, like [`exp`](@ref) is to allocate the resulting memory and call [`exp!`](@ref). + +!!! note + it might be useful to provide two distinct implementations, for example when using AD schemes. + The default is meant for ease of use (concerning implementation), since then one has to just implement the mutating variants. + +Non-mutating functions in `ManifoldsBase.jl` are typically implemented using mutating variants. +Allocation of new points is performed using a custom mechanism that relies on the following functions: + +The [`allocate`](@ref) function behaves like `similar` for simple representations of points and vectors (for example `Array{Float64}`). +For more complex types, such as nested representations of [`PowerManifold`](@ref) (see [`NestedPowerRepresentation`](@ref)), checked types like [`ValidationMPoint`](@ref) and more it operates differently. +While `similar` only concerns itself with the higher level of nested structures, `allocate` maps itself through all levels of nesting until a simple array of numbers is reached and then calls `similar`. +The difference can be most easily seen in the following example: + +```julia +julia> x = similar([[1.0], [2.0]]) +2-element Array{Array{Float64,1},1}: + #undef + #undef + +julia> y = allocate([[1.0], [2.0]]) +2-element Array{Array{Float64,1},1}: + [6.90031725726027e-310] + [6.9003678131654e-310] + +julia> x[1] +ERROR: UndefRefError: access to undefined reference +Stacktrace: + [1] getindex(::Array{Array{Float64,1},1}, ::Int64) at ./array.jl:744 + [2] top-level scope at REPL[12]:1 + +julia> y[1] +1-element Array{Float64,1}: + 6.90031725726027e-310 +``` + +The function [`allocate_result`](@ref ManifoldsBase.allocate_result) allocates a correct return value. It takes into account the possibility that different arguments may have different numeric [`number_eltype`](@ref) types thorough the [`allocate_result_type`](@ref ManifoldsBase.allocate_result_type) function. +The most prominent example of the usage of this function is the logarithmic function [`log`](@ref) when used with typed points. +Lets assume on a manifold `M` the have points of type `P` and corresponding tangent vector types `V`. +then the logarithmic map has the signature + +```julia +log(::M, ::P, ::P) +``` + +but the return type would be ``V``, whose internal sizes (fields/arrays) will depend on the concrete type of one of the points. This is accomplished by implementing a method `allocate_result(::M, ::typeof(log), ::P, ::P)` that returns the concrete variable for the result. This way, even with specific types, one just has to implement `log!` and the one line for the allocation. + +!!! note + This dispatch from the allocating to the mutating variant happens in Layer III, that is, functions like `exp` or `retract_polar` (but not `retract` itself) allocate their result (using `::typeof(retract)` for the second function) + and call the mutating variant `exp!` and `retract_polar!` afterwards. diff --git a/docs/src/example.md b/docs/src/example.md new file mode 100644 index 00000000..e6354286 --- /dev/null +++ b/docs/src/example.md @@ -0,0 +1,239 @@ +# [How to implement your own manifold](@id manifold-tutorial) + +```@meta +CurrentModule = ManifoldsBase +``` + +## Introduction + +This tutorial explains, how to implement a manifold using the `ManifoldsBase.jl` interface. +We assume that you are familiar with the basic terminology on Riemannian manifolds, especially +the dimension of a manifold, the exponential map, and the inner product on tangent spaces. +To read more about this you can for example check [[do Carmo, 1992](#doCarmo1992)], Chapter 3, first. + +Furthermore, we will look into a manifold that is isometrically embedded into e Euclidean space. + +In general you need just a datatype (`struct`) that inherits from [`AbstractManifold`](@ref) to define a manifold. No function is _per se_ required to be implemented. +However, it is a good idea to provide functions that might be useful to others, for example [`check_point`](@ref check_point) and [`check_vector`](@ref check_point), as we do in this tutorial. + +We start with two technical preliminaries. If you want to start directly, you can [skip](@ref manifold-tutorial-task) this paragraph and revisit it for two of the implementation details. + +After that, we will + +* [model](@ref manifold-tutorial-task) the manifold +* [implement](@ref manifold-tutorial-checks) two tests, so that points and tangent vectors can be checked for validity, for example also within [`ValidationManifold`](@ref), +* [implement](@ref manifold-tutorial-fn) two functions, the exponential map and the manifold dimension. +* [decorate](@ref manifold-tutorial-emb) the manifold with an embedding to gain further features. + +## [Technical preliminaries](@id manifold-tutorial-prel) + +There are only two small technical things we need to explain at this point. +First of all our [`AbstractManifold`](@ref)`{𝔽}` has a parameter `𝔽`. +This parameter indicates the [`number_system`](@ref) the manifold is based on, for example `ℝ` for rel manifolds, which is short for `RealNumbers()`. +This indicates that the manifold is a real manifold. + +## [Startup](@id manifold-tutorial-startup) + +As a start, let's load `ManifoldsBase.jl` and import the functions we consider throughout this tutorial. + +```@example manifold-tutorial +using ManifoldsBase, LinearAlgebra, Test +import ManifoldsBase: check_point, check_vector, manifold_dimension, exp!, inner +import Base: show +``` + +We load `LinearAlgebra` for some computations. `Test` is only loaded for illustrations in the examples. + +We import the mutating variant of the [`exp`](@ref)onential map, see [the design section on mutating and nonmutating functions](@ref mutating-and-nonmutating). + +## [The manifold](@id manifold-tutorial-task) + +The manifold we want to implement here a sphere, with a radius $r$. +Since this radius is a property inherent to the manifold, it will become a field of the manifold. +The second information, we want to store is the dimension of the sphere, for example whether it's the 1-sphere, i.e. the circle, represented by vectors $p\in\mathbb R^2$ or norm $r$ or the 2-sphere in $\mathbb R^3$ of radius $r$. +Since the latter might be something we want to [dispatch](https://en.wikipedia.org/wiki/Multiple_dispatch) on, we model it as a parameter of the type. +In general the `struct` of a manifold should provide information about the manifold, which are inherent to the manifold or has to be available without a specific point or tangent vector present. +This is -- most prominently -- a way to determine the manifold dimension. + +Note that this a slightly more general manifold than the [Sphere](https://juliamanifolds.github.io/Manifolds.jl/stable/manifolds/sphere.html) in [Manifolds.jl](https://juliamanifolds.github.io/Manifolds.jl/stable/index.html) + +For our example we define the following struct. +While a first implementation might also just take [`AbstractManifold`](@ref)`{ℝ}` as supertype, we directly take +[`AbstractDecoratorManifold`](@ref)`{ℝ}, which will be useful later. For now it does not make a difference. + +```@example manifold-tutorial +""" + ScaledSphere{N} <: AbstractDecoratorManifold{ℝ} + +Define an `N`-sphere of radius `r`. Construct by `ScaledSphere(radius,n)`. +""" +struct ScaledSphere{N} <: AbstractDecoratorManifold{ManifoldsBase.ℝ} where {N} + radius::Float64 +end +ScaledSphere(radius, n) = ScaledSphere{n}(radius) +Base.show(io::IO, M::ScaledSphere{n}) where {n} = print(io, "ScaledSphere($(M.radius),$n)") +nothing #hide +``` + +Here, the last line just provides a nicer print of a variable of that type. +Now we can already initialize our manifold that we will use later, the $2$-sphere of radius $1.5$. + +```@example manifold-tutorial +S = ScaledSphere(1.5, 2) +``` + +## [Checking points and tangents](@id manifold-tutorial-checks) + +If we have now a point, represented as an array, we would first like to check, that it is a valid point on the manifold. +For this one can use the easy interface [`is_point`](@ref is_point(M::AbstractManifold, p; kwargs...)). +This is a function on [layer 1](@ref design-layer1) which handles specialy cases, so it should not be implemented. +The actual functions where we dispatch per manifold are on [layer 3](@ref design-layer3). +For the test of points this function is [`check_point`](@ref ManifoldsBase.check_point) which we actually will implement. +This function returns `nothing` if the point is a valid point and returns an error not throw it) otherwise. +This is usually a `DomainError`. + +We have to check two things: that a point `p` is a vector with `N+1` entries and its norm is the desired radius. +To spare a few lines, we can use [short-circuit evaluation](https://docs.julialang.org/en/v1/manual/control-flow/#Short-Circuit-Evaluation-1) instead of `if` statements. +If something has to only hold up to precision, we can pass that down, too using the `kwargs...`. + +```@example manifold-tutorial +function check_point(M::ScaledSphere{N}, p; kwargs...) where {N} + (size(p)) == (N+1,) || return DomainError(size(p),"The size of $p is not $((N+1,)).") + if !isapprox(norm(p), M.radius; kwargs...) + return DomainError(norm(p), "The norm of $p is not $(M.radius).") + end + return nothing +end +nothing #hide +``` + +Similarly, we can verify, whether a tangent vector `X` is valid. +It has to fulfill the same size requirements and it has to be orthogonal to `p`. +We can again use the `kwargs`, but also provide a way to check `p`, too. + +```@example manifold-tutorial +function check_vector(M::ScaledSphere, p, X; kwargs...) + size(X) != size(p) && return DomainError(size(X), "The size of $X is not $(size(p)).") + if !isapprox(dot(p,X), 0.0; kwargs...) + return DomainError(dot(p,X), "The tangent $X is not orthogonal to $p.") + end + return nothing +end +nothing #hide +``` + +to test points we can now use + +```@example manifold-tutorial +is_point(S, [1.0,0.0,0.0]) # norm 1, so not on S, returns false +@test_throws DomainError is_point(S, [1.5,0.0], true) # only on R^2, throws an error. +p = [1.5,0.0,0.0] +X = [0.0,1.0,0.0] +# The following two tests return true +[ is_point(S, p); is_vector(S,p,X) ] +``` + +## [Functions on the manifold](@id manifold-tutorial-fn) + +For the [`manifold_dimension`](@ref manifold_dimension(M::AbstractManifold)) we have to just return the `N` parameter + +```@example manifold-tutorial +manifold_dimension(::ScaledSphere{N}) where {N} = N +manifold_dimension(S) +``` + +Note that we can even omit the variable name in the first line since we do not have to access any field or use the variable otherwise. + +To implement the exponential map, we have to implement the formula for great arcs, given a start point `p` and a direction `X` on the $n$-sphere of radius $r$ the formula reads + +````math +\exp_p X = \cos(\frac{1}{r}\lVert X \rVert)p + \sin(\frac{1}{r}\lVert X \rVert)\frac{r}{\lVert X \rVert}X. +```` + +Note that with this choice we for example implicitly assume a certain metric. This is completely fine. We only have to think about specifying a metric explicitly, when we have (at least) two different metrics on the same manifold. + +An implementation of the mutation version, see the [technical note](@ref manifold-tutorial-prel), reads + +```@example manifold-tutorial +function exp!(M::ScaledSphere{N}, q, p, X) where {N} + nX = norm(X) + if nX == 0 + q .= p + else + q .= cos(nX/M.radius)*p + M.radius*sin(nX/M.radius) .* (X./nX) + end + return q +end +nothing #hide +``` + +A first easy check can be done taking `p` from above and any vector `X` of length `1.5π` from its tangent space. +The resulting point is opposite of `p`, i.e. `-p`. + +```@example manifold-tutorial +q = exp(S, p, [0.0,1.5π,0.0]) +[isapprox(p, -q); is_point(S, q)] +``` + +## [Adding an isometric embedding](@id manifold-tutorial-emb) + +Since the sphere is isometrically embedded, we do not have to implement the [`inner`](@ref)`(M,p,X,Y)`for tangent vectors, but we can “delegate” it to the embedding. The embedding is the [Euclidean](https://juliamanifolds.github.io/Manifolds.jl/stable/manifolds/essentialmanifold.html) which is also available in `ManifoldsBase` as `DefaultManifold` for testing purposes. + +```@example manifold-tutorial +using ManifoldsBase: DefaultManifold, IsIsometricEmbeddedManifold +import ManifoldsBase: active_traits, merge_traits, get_embedding +``` + +Now we can activate a decorator by specifying that the sphere has the [`IsIsometricEmbeddedManifold`](@ref) trait for the manifold by writing + +```@example manifold-tutorial +active_traits(::ScaledSphere, args...) = merge_traits(IsIsometricEmbeddedManifold()) +nothing #hide +``` + +and then specifying that said embedding is the default manifold + +```@example manifold-tutorial +get_embedding(::ScaledSphere{N}) where {N} = DefaultManifold(N+1) +nothing #hide +``` + +Now metric related functions are passed to this embedding, so the inner product works by using the embedding + +Now we can compute the inner product by calling [`inner`](@ref) + +```@example manifold-tutorial +X = [0.0, 0.1, 3.0] +Y = [0.0, 4.0, 0.2] + # returns 1.0 by calling the inner product in DefaultManifold(3) +inner(S, p, X, Y) +``` + +## [Conclusion](@id manifold-tutorial-outlook) + +You can now just continue implementing further functions from `ManifoldsBase.jl`. +but with just [`exp!`](@ref exp!(M::AbstractManifold, q, p, X)) you for example already have + +* [`geodesic`](@ref geodesic(M::AbstractManifold, p, X)) the (not necessarily shortest) geodesic emanating from `p` in direction `X`. +* the [`ExponentialRetraction`](@ref), that the [`retract`](@ref retract(M::AbstractManifold, p, X)) function uses by default. + +For the [`shortest_geodesic`](@ref shortest_geodesic(M::AbstractManifold, p, q)) the implementation of a logarithm [`log`](@ref ManifoldsBase.log(M::AbstractManifold, p, q)), again better a [`log!`](@ref log!(M::AbstractManifold, X, p, q)) is necessary. + +Sometimes a default implementation is provided; for example if you implemented [`inner`](@ref inner(M::AbstractManifold, p, X, Y)), the [`norm`](@ref norm(M, p, X)) is defined. You should overwrite it, if you can provide a more efficient version. For a start the default should suffice. +With [`log!`](@ref log!(M::AbstractManifold, X, p, q)) and [`inner`](@ref inner(M::AbstractManifold, p, X, Y)) you get the [`distance`](@ref distance(M::AbstractManifold, p, q)), and so. + +In summary with just these few functions you can already explore the first things on your own manifold. Whenever a function from `Manifolds.jl` requires another function to be specifically implemented, you get a reasonable error message. + +## Literature + +```@raw html +
    +
  • + [doCarmo, 1992] + M. P. do Carmo, + Riemannian Geometry, + Birkhäuser Boston, 1992, + ISBN: 0-8176-3490-8. +
  • +
+``` diff --git a/docs/src/functions.md b/docs/src/functions.md new file mode 100644 index 00000000..b0d0b440 --- /dev/null +++ b/docs/src/functions.md @@ -0,0 +1,70 @@ +# Functions on manifolds + +This page collects several basic functions on manifolds. + +## [The exponential map, the logarithmic map, and geodesics](@id exp-and-log) + +Geodesics are the generalizations of a straight line to manifolds, i.e. their intrinsic acceleration is zero. +Together with geodesics one also obtains the exponential map and its inverse, the logarithmic map. +Informally speaking, the exponential map takes a vector (think of a direction and a length) at one point and returns another point, +which lies towards this direction at distance of the specified length. The logarithmic map does the inverse, i.e. given two points, it tells which vector “points towards” the other point. + +```@autodocs +Modules = [ManifoldsBase] +Pages = ["exp_log_geo.jl"] +Order = [:function] +``` + +## [Parallel transport](@id subsec-parallel-transport) + +While moving vectors from one base point to another is the identity in the Euclidean space – or in other words all tangent spaces (directions one can “walk” into) are the same. This is different on a manifold. + +If we have two points ``p,q ∈ \mathcal M``, we take a ``c: [0,1] → \mathcal M`` connecting the two points, i.e. ``c(0) = p`` and ``c(1) = q``. this could be a (or the) geodesic. +If we further consider a vector field ``X: [0,1] → T\mathcal M``, i.e. where ``X(t) ∈ T_{c(t)}\mathcal M``. +Then the vector field is called _parallel_ if its covariant derivative ``\frac{\mathrm{D}}{\mathrm{d}t}X(t) = 0`` for all ``t∈ |0,1]``. + +If we now impose a value for ``X=X(0) ∈ T_p\mathcal M``, we obtain an ODE with an initial condition. +The resulting value ``X(1) ∈ T_q\mathcal M`` is called the _parallel transport_ of `X` along ``c`` +or in case of a geodesic the _parallel transport of `X` from `p` to `q`. + +```@autodocs +Modules = [ManifoldsBase] +Pages = ["parallel_transport.jl"] +Order = [:function] +``` + +## Further functions on manifolds + +### General functions provided by the interface + +```@autodocs +Modules = [ManifoldsBase] +Pages = ["ManifoldsBase.jl"] +Order = [:type, :function] +Public=true +Private=false +``` + +### Internal functions + +While you should always add your documentation to functions from the last section, some of the functions dispatch onto functions on [the lower layer](@ref design-layer3). These are the ones +you usually implement for your manifold – unless there is no lower level function called, like for the [`manifold_dimension`](@ref). + +```@autodocs +Modules = [ManifoldsBase] +Pages = ["ManifoldsBase.jl"] +Order = [:function] +Public=false +Private=true +``` + +## Error Messages + +especially to collect and display errors on [`AbstractPowerManifold`](@ref ManifoldsBase.AbstractPowerManifold)s the following +component and collection error messages are available. + +```@autodocs +Modules = [ManifoldsBase] +Pages = ["errors.jl"] +Order = [:type] +``` diff --git a/docs/src/index.md b/docs/src/index.md new file mode 100644 index 00000000..5a8e545d --- /dev/null +++ b/docs/src/index.md @@ -0,0 +1,25 @@ +# ManifoldsBase.jl + +`ManifoldsBase.jl` is a lightweight interface for manifolds. + +This package provides an interface, so you probably either want to add it as a dependency to your project/package to work on manifold generically or implement a new manifold. +A package that (only) depends on `ManifoldsBase.jl`, see [Manopt.jl](https://manoptjl.org/stable/), which implements optimization algorithms on manifolds using this interface, i.e. they can be used with any manifold based on `ManifoldsBase.jl`. A library of manifolds implemented using this interface is provided see [Manifolds.jl](https://juliamanifolds.github.io/Manifolds.jl/stable/). + +Your package is using `ManifoldsBase`? Give us a note and we add you here. + +## Citation + +If you use `ManifoldsBase.jl` in your work, please cite the following paper, +which covers both the basic interface as well as the performance for `Manifolds.jl`. + +```biblatex +@online{2106.08777, + Author = {Seth D. Axen and Mateusz Baran and Ronny Bergmann and Krzysztof Rzecki}, + Title = {Manifolds.jl: An Extensible Julia Framework for Data Analysis on Manifolds}, + Year = {2021}, + Eprint = {2106.08777}, + Eprinttype = {arXiv}, +} +``` + +Note that the citation is in [BibLaTeX](https://ctan.org/pkg/biblatex) format. diff --git a/docs/src/manifold_type.md b/docs/src/manifold_type.md new file mode 100644 index 00000000..51d24d4b --- /dev/null +++ b/docs/src/manifold_type.md @@ -0,0 +1,35 @@ +# The Manifold Type + +## Number systems + +```@autodocs +Modules = [ManifoldsBase] +Pages = ["numbers.jl"] +Order = [:type, :function] +``` + +## The main type: The `AbstractManifold` + +The main type is the [`AbstractManifold`](@ref). It represents the manifold per se. +During the documentation we will use the [Euclidean Space](https://juliamanifolds.github.io/Manifolds.jl/latest/manifolds/euclidean.html) and the [Sphere](https://juliamanifolds.github.io/Manifolds.jl/latest/manifolds/sphere.html) (both implemented in [Manifolds.jl](https://github.com/JuliaManifolds/Manifolds.jl)) as easy examples to often illustrate properties and features of this interface + +```@autodocs +Modules = [ManifoldsBase] +Pages = ["maintypes.jl"] +Order = [:type, :function] +``` + +which should store information about the manifold, for example parameters inherent to the manifold. + +## Points and Tangent Vectors + +Points and tangent vectors do not necessarily have to be typed. Usually +one can just use any type. When a manifold has multiple representations, these should be distinguished by point and vector types. Then it might be that types just encapsulate a vector value. +This is taken into account by the following macros, that forward several actions just to this field. Most prominently vector operations for the tangent vectors. +If there is still a default case, a macro sets this type to be equivalent to calling the manifold functions just with the types field that carries the value. + +```@autodocs +Modules = [ManifoldsBase] +Pages = ["point_vector_fallbacks.jl"] +Order = [:type, :function, :macro] +``` diff --git a/docs/src/manifolds.md b/docs/src/manifolds.md new file mode 100644 index 00000000..fafd0b95 --- /dev/null +++ b/docs/src/manifolds.md @@ -0,0 +1,63 @@ +# Manifolds + +While the interface `ManifoldsBase.jl` does not cover concrete manifolds, it provides a few +helpers to build or create manifolds based on existing manifolds + +## [(Abstract) Power Manifold](@id sec-power-manifold) + +A power manifold is constructed like higher dimensional vector spaces are formed from the real line, just that for every point ``p = (p_1,\ldots,p_n) ∈ \mathcal M^n`` on the power manifold ``\mathcal M^n`` the entries of ``p`` are points ``p_1,\ldots,p_n ∈ \mathcal M`` on some manifold ``\mathcal M``. + +```@autodocs +Modules = [ManifoldsBase] +Pages = ["src/PowerManifold.jl"] +Order = [:macro, :type, :function] +``` + +## ValidationManifold + +[`ValidationManifold`](@ref) is a simple decorator using the [`AbstractDecoratorManifold`](@ref) that “decorates” a manifold with tests that all involved points and vectors are valid for the wrapped manifold. +For example involved input and output paratemers are checked before and after running a function, repectively. +This is done by calling [`is_point`](@ref) or [`is_vector`](@ref) whenever applicable. + +```@autodocs +Modules = [ManifoldsBase] +Pages = ["ValidationManifold.jl"] +Order = [:macro, :type, :function] +``` + +## DefaultManifold + +[`DefaultManifold`](@ref ManifoldsBase.DefaultManifold) is a simplified version of [`Euclidean`](https://juliamanifolds.github.io/Manifolds.jl/latest/manifolds/euclidean.html) and demonstrates a basic interface implementation. +It can be used to perform simple tests. +Since when using `Manifolds.jl` the [`Euclidean`](https://juliamanifolds.github.io/Manifolds.jl/latest/manifolds/euclidean.html) is available, the `DefaultManifold` itself is not exported. + +```@docs +ManifoldsBase.DefaultManifold +``` + +## [EmbeddedManifold](@id sec-embedded-manifold) + +The embedded manifold is a manifold ``mathcal N`` which is modelled _explicitly_ mentioning its embedding ``\mathcal N`` in which the points and tangent vectors are represented. +Most prominently [`is_point`](@ref) and [`is_vector`](@ref) of an embedded manifold are implemented to check whether the point is a valid point in the embedding. This can of course still be extended by further tests. +`ManifoldsBase.jl` provides two possibilities of easily introducing this in order to dispatch some functions to the embedding. + +## [Implicit case: the `IsEmbeddedManifold` Trait](@id subsec-implicit-embedded) + +For the implicti case, your manifold has to be a subtype of the [`AbstractDecoratorManifold`](@ref). +Setting the [`active_traits`](@ref ManifoldsBase.active_traits) function to the [`AbstractTrait`](@ref) +[`IsEmbeddedManifold`](@ref), makes a manifold an embedded manifold. you just have to also define [`get_embedding`](@ref) such that functions are passed on to that embedding. +This is the implicit case, since the manifold type itself does not carry any information about the embedding, just the trait and the function definition do. + +## [Explicit case: the `EmbeddedManifold`](@id subsec-explicit-embedded) + +The [`EmbeddedManifold`](@ref) itself is an [`AbstractDecoratorManifold`](@ref) so it is a case of the implicit embedding itself, but internally stores both the original manifold and the embedding. +They are also parameters of the type. +This way, additional embeddings can be modelled. That is, if the manifold is implemented using the implicit embedding approach from before but can also be implemented using a _different_ embedding, then this method should be chosen, since you can dispatch functions that you want to implement in this embedding then on the type which explicitly has the manifold and its embedding as parameters. + +Hence this case should be used for any further embedding after the first or if the default implementation works without an embedding and the alternative needs one. + +```@autodocs +Modules = [ManifoldsBase] +Pages = ["EmbeddedManifold.jl"] +Order = [:type, :macro, :function] +``` diff --git a/docs/src/projections.md b/docs/src/projections.md new file mode 100644 index 00000000..029bdba2 --- /dev/null +++ b/docs/src/projections.md @@ -0,0 +1,43 @@ +### Projections + +A manifold might be embedded in some space. +Often this is implicitly assumed, for example the complex [Circle](https://juliamanifolds.github.io/Manifolds.jl/latest/manifolds/circle.html) is embedded in the complex plane. +Let‘s keep the circle in mind in the following as a simple example. +For the general case see of explicitly stating an embedding and/or distinguising several, different embeddings, see [Embedded Manifolds](@ref sec-embedded-manifold) below. + +To make this a little more concrete, let‘s assume we have a manifold ``\mathcal M`` which is embedded in some manifold ``\mathcal N`` and the image ``i(\mathcal M)`` of the embedding function ``i`` is a closed set (with respect to the topology on ``\mathcal N``). Then we can do two kinds of projections. + +To make this concrete in an example for the Circle ``\mathcal M=\mathcal C := \{ p ∈ ℂ | |p| = 1\}`` +the embedding can be chosen to be the manifold ``N = ℂ`` and due to our representation of ``\mathcal C`` as complex numbers already, we have ``i(p) = p`` the identity as the embedding function. + +1. Given a point ``p∈\mathcal N`` we can look for the closest point on the manifold ``\mathcal M`` formally as + +```math + \operatorname*{arg\,min}_{q\in \mathcal M} d_{\mathcal N}(i(q),p) +``` + +And this resulting ``q`` we call the projection of ``p`` onto the manifold ``\mathcal M``. + +2. Given a point ``p∈\mathcal M`` and a vector in ``X\inT_{i(p)}\mathcal N`` in the embedding we can similarly look for the closest point to ``Y∈ T_p\mathcal M`` using the pushforward ``\mathrm{d}i_p`` of the embedding. + +```math + \operatorname*{arg\,min}_{Y\in T_p\mathcal M} \lVert \mathrm{d}i(p)[Y] - X \rVert_{i(p)} +``` + +And we call the resulting ``Y`` the projection of ``X`` onto the tangent space ``T_p\mathcal M`` at ``p``. + +Let‘s look at the little more concrete example of the complex Circle again. +Here, the closest point of ``p ∈ ℂ`` is just the projection onto the circle, or in other words ``q = \frac{p}{\lvert p \rvert}``. A tangent space ``T_p\mathcal C`` in the embedding is the line orthogonal to a point ``p∈\mathcal C`` through the origin. +This can be better visualized by looking at ``p+T_p\mathcal C`` which is actually the line tangent to ``p``. Note that this shift does not change the resulting projection relative to the origin of the tangent space. + +Here the projection can be computed as the classical projection onto the line, i.e. ``Y = X - ⟨X,p⟩X``. + +this is illustrated in the following figure + +![An example illustrating the two kinds of projections on the Circle.](assets/images/projection_illustration_600.png) + +```@autodocs +Modules = [ManifoldsBase] +Pages = ["projections.jl"] +Order = [:function] +``` diff --git a/docs/src/retractions.md b/docs/src/retractions.md new file mode 100644 index 00000000..9af7cd9e --- /dev/null +++ b/docs/src/retractions.md @@ -0,0 +1,38 @@ +## [Retractions and inverse Retractions](@id sec-retractions) + +The [exponential and logarithmic map](@ref exp-and-log) might be too expensive to evaluate or not be available in a very stable numerical way. Retractions provide a possibly cheap, fast and stable alternative. + +The following figure compares the exponential map [`exp`](@ref)`(M, p, X)` on the [Circle](https://juliamanifolds.github.io/Manifolds.jl/latest/manifolds/circle.html) `(ℂ)` (or [`Sphere`](https://juliamanifolds.github.io/Manifolds.jl/latest/manifolds/sphere.html)`(1)` embedded in $ℝ^2$ with one possible retraction, the one based on projections. Note especially that ``\mathrm{dist}(p,q)=\lVert X\rVert_p`` while this is not the case for ``q'``. + +![A comparson of the exponential map and a retraction on the Circle.](assets/images/retraction_illustration_600.png) + +```@autodocs +Modules = [ManifoldsBase] +Pages = ["retractions.jl"] +Order = [:function] +Private = false +Public = true +``` + +## Types of Retractions + +To distinguish different types of retractions, the last argument of the (inverse) retraction +specifies a type. The following ones are available. + +```@autodocs +Modules = [ManifoldsBase] +Pages = ["retractions.jl"] +Order = [:type] +``` + +## The lower layer functions + +While you should always add your documentation to [`retract`](@ref) or [`retract!`](@ref) when implementing new manifolds, the actual implementation happens on the following functions on [the lower layer](@ref design-layer3). + +```@autodocs +Modules = [ManifoldsBase] +Pages = ["retractions.jl"] +Order = [:function] +Public = false +Private = true +``` diff --git a/docs/src/vector_transports.md b/docs/src/vector_transports.md new file mode 100644 index 00000000..3d1e785a --- /dev/null +++ b/docs/src/vector_transports.md @@ -0,0 +1,50 @@ +## Vector transport + +Similar to the [exponential and logarithmic map](@ref exp-and-log) also the [parallel transport](@ref subsec-parallel-transport) might be costly to compute, especially when there is no closed form solution known and it has to be approximated with numerical methods. + +Similar to the [retraction and its inverse](@ref sec-retractions) the generalisation of the parallel transport can be phrased as follows + +A _vector transport_ is a way to transport a vector between two tangent spaces. +Let ``p,q ∈ \mathcal M`` be given, ``c`` the curve along which we want to transport (cf. [parallel transport](@ref subsec-parallel-transport), for example a geodesic or a geodesic or curve given by a retraction. We can speficy the geodesic or curve a retraction realises for example by a direction ``d``. + +More precisely using [^AbsilMahonySepulchre2008], Def. 8.1.1, a vector transport +``T_{p,d}: T_p\mathcal M \to T_q\mathcal M``, ``p∈ \mathcal M``, ``Y∈ T_p\mathcal M`` is a smooth mapping +associated to a retraction ``\operatorname{retr}_p(Y) = q`` such that + +1. (associated retraction) ``\mathcal T_{p,d}X ∈ T_q\mathcal M`` if and only if ``q = \operatorname{retr}_p(d)``. +2. (consistency) ``\mathcal T_{p,0_p}X = X`` for all ``X∈T_p\mathcal M`` +3. (linearity) ``\mathcal T_{p,d}(αX+βY) = \mathcal αT_{p,d}X + \mathcal βT_{p,d}Y`` + +hold. + +Currently the following methods for vector transport are defined in `ManifoldsBase.jl`. + +```@autodocs +Modules = [ManifoldsBase] +Pages = ["vector_transport.jl"] +Order = [:function] +Public=true +Private=false +``` + +## Types of vector transports + +To distinguish different types of vector transport we introduce the [`AbstractVectorTransportMethod`](@ref). The following concrete types are available. + +```@autodocs +Modules = [ManifoldsBase] +Pages = ["vector_transport.jl"] +Order = [:type] +``` + +## The lower layer functions + +While you should always add your documentation to the first layer vector transport methods above when implementing new manifolds, the actual implementation happens on the following functions on [the lower layer](@ref design-layer3). + +```@autodocs +Modules = [ManifoldsBase] +Pages = ["vector_transport.jl"] +Order = [:function] +Public = false +Private = true +``` diff --git a/src/DecoratorManifold.jl b/src/DecoratorManifold.jl deleted file mode 100644 index 3142e920..00000000 --- a/src/DecoratorManifold.jl +++ /dev/null @@ -1,808 +0,0 @@ -# -# Helper -# -@inline _extract_val(::Val{T}) where {T} = T - -#! format: off -# turn formatting for for the following functions -# due to the if with returns inside (formatter puts a return upfront the if) -function _split_signature(sig::Expr) - if sig.head == :where - where_exprs = sig.args[2:end] - call_expr = sig.args[1] - elseif sig.head == :call - where_exprs = [] - call_expr = sig - else - error("Incorrect syntax in $ex. Expected a :where or :call expression.") - end - fname = call_expr.args[1] - if isa(call_expr.args[2], Expr) && call_expr.args[2].head == :parameters - # we have keyword arguments - callargs = call_expr.args[3:end] - kwargs_list = call_expr.args[2].args - else - callargs = call_expr.args[2:end] - kwargs_list = [] - end - argnames = map(callargs) do arg - if isa(arg, Expr) - return arg.args[1] - else - return arg - end - end - argtypes = map(callargs) do arg - if isa(arg, Expr) - return arg.args[2] - else - return Any - end - end - - kwargs_call = map(kwargs_list) do kwarg - if kwarg.head === :... - return kwarg - else - if isa(kwarg.args[1], Symbol) - kwargname = kwarg.args[1] - else - kwargname = kwarg.args[1].args[1] - end - return :($kwargname = $kwargname) - end - end - - return (; - fname = fname, - where_exprs = where_exprs, - callargs = callargs, - kwargs_list = kwargs_list, - argnames = argnames, - argtypes = argtypes, - kwargs_call = kwargs_call, - fname__parent = Symbol(string(fname) * "__parent"), - fname__transparent = Symbol(string(fname) * "__transparent"), - fname__intransparent = Symbol(string(fname) * "__intransparent"), - ) -end -#! format: on -function _split_function(ex::Expr) - if ex.head == :function - sig = ex.args[1] - body = ex.args[2] - else - error("Incorrect syntax in $ex. Expected :function.") - end - - return (; body = body, _split_signature(sig)...) -end - -# -# Transparency types -# -""" - AbstractDecoratorType - -Decorator types can be used to specify a basic transparency for an [`AbstractDecoratorManifold`](@ref). -This can be seen as an initial (rough) transparency pattern to start a type with. - -Note that for a function `f` and it's mutating variant `f!` -* The function `f` is set to `:parent` to first invoke allocation and call of `f!` -* The mutating function `f!` is set to `transparent` -""" -abstract type AbstractDecoratorType end - -""" - DefaultDecoratorType <: AbstractDecoratorType - -A default decorator type, where all new functions are transparent by default. -""" -struct DefaultDecoratorType <: AbstractDecoratorType end - -# -# Type -# -""" - AbstractDecoratorManifold{𝔽,T<:AbstractDecoratorType} <: AbstractManifold{𝔽} - -An `AbstractDecoratorManifold` indicates that to some extent a manifold subtype -decorates another [`AbstractManifold`](@ref) in the sense that it either - -* it extends the functionality of a manifold with further features -* it defines a new manifold that internally uses functions from the decorated manifold - -with the main intent that several or most functions of [`AbstractManifold`](@ref) are transparently -passed through to the manifold that is decorated. This way a function implemented for a -decorator acts transparent on all other decorators, i.e. they just pass them through. If -the decorator the function is implemented for is not among the decorators, an error is -issued. By default all base manifold functions, for example [`exp`](@ref) and [`log`](@ref) -are transparent for all decorators. - -Transparency of functions with respect to decorators can be specified using the macros -[`@decorator_transparent_fallback`](@ref), [`@decorator_transparent_function`](@ref) and -[`@decorator_transparent_signature`](@ref). - -There are currently three modes given a new `AbstractDecoratorManifold` `M` -* `:intransparent` – this function has to be implmented for the new manifold `M` -* `:transparent` – this function is transparent, in the sense that the function is invoked - on the decorated `M.manifold`. This is the default, when introducing a function or signature. -* `:parent` specifies that (unless implemented) for this function, the classical inheritance - is issued, i.e. the function is invoked on `M`s supertype. -""" -abstract type AbstractDecoratorManifold{𝔽,T<:AbstractDecoratorType} <: AbstractManifold{𝔽} end - -# -# Macros -# -""" - @decorator_transparent_fallback(ex) - @decorator_transparent_fallback(fallback_case = :intransparent, ex) - -This macro introduces an additional implementation for a certain additional case. -This can especially be used if for an already transparent function and an abstract -intermediate type a change in the default is required. -For implementing a concrete type, neither this nor any other trick is necessary. One -just implements the function as before. Note that a decorator that [`is_default_decorator`](@ref) -still dispatches to the transparent case. - - -* `:transparent` states, that the function is transparently passed on to the manifold that - is decorated by the [`AbstractDecoratorManifold`](@ref) `M`, which is determined using - the function [`decorated_manifold`](@ref). -* `:intransparent` states that an implementation for this decorator is required, and if - none of the types provides one, an error is issued. Since this macro provides such an - implementation, this is the default. -* `:parent` states, that this function passes on to the supertype instead of to the - decorated manifold. - -Inline definitions are not supported. The function signature however may contain -keyword arguments and a where clause. It does not allow for parameters with default values. - -# Examples - -```julia -@decorator_transparent_fallback function log!(M::AbstractGroupManifold, X, p, q) - log!(decorated_manifold(M), X, p, Q) -end -@decorator_transparent_fallback :transparent function log!(M::AbstractGroupManifold, X, p, q) - log!(decorated_manifold(M), X, p, Q) -end -``` -""" -macro decorator_transparent_fallback(ex) - return esc(quote - @decorator_transparent_fallback :intransparent ($ex) - end) -end -macro decorator_transparent_fallback(fallback_case, input_ex) - ex = macroexpand(__module__, input_ex) - parts = _split_function(ex) - callargs = parts[:callargs] - where_exprs = parts[:where_exprs] - fname_fallback = Symbol(string(parts.fname) * "__" * string(fallback_case)[2:end]) - return esc( - quote - function ($(fname_fallback))( - $(callargs...); - $(parts[:kwargs_list]...), - ) where {$(where_exprs...)} - return ($(parts[:body])) - end - end, - ) -end - -""" - @decorator_transparent_function(ex) - @decorator_transparent_function(fallback_case = :intransparent, ex) - -Introduce the function specified by `ex` to act transparently with respect to -[`AbstractDecoratorManifold`](@ref)s. This introduces the possibility to modify the kind of -transparency the implementation is done for. This optional first argument, the `Symbol` -within `fallback_case`. This macro can be used to define a function and introduce it as -transparent to other decorators. Note that a decorator that [`is_default_decorator`](@ref) -still dispatches to the transparent case. - -The cases of transparency are - -* `:transparent` states, that the function is transparently passed on to the manifold that - is decorated by the [`AbstractDecoratorManifold`](@ref) `M`, which is determined using - the function [`decorated_manifold`](@ref). -* `:intransparent` states that an implementation for this decorator is required, and if - none of the types provides one, an error is issued. Since this macro provides such an - implementation, this is the default. -* `:parent` states, that this function passes on to the supertype instead of to the - decorated manifold. Passing is performed using the `invoke` function where the type of - manifold is replaced by its supertype. - -Innkoline-definitions are not yet covered – the function signature however may contain -keyword arguments and a where clause. - -# Examples - -```julia -@decorator_transparent_function log!(M::AbstractDecoratorManifold, X, p, q) - log!(decorated_manifold(M), X, p, Q) -end -@decorator_transparent_function :parent log!(M::AbstractDecoratorManifold, X, p, q) - log!(decorated_manifold(M), X, p, Q) -end -``` -""" -macro decorator_transparent_function(ex) - return esc(quote - @decorator_transparent_function :intransparent ($ex) - end) -end -macro decorator_transparent_function(fallback_case, input_ex) - ex = macroexpand(__module__, input_ex) - parts = _split_function(ex) - kwargs_list = parts[:kwargs_list] - callargs = parts[:callargs] - fname = parts[:fname] - where_exprs = parts[:where_exprs] - body = parts[:body] - argnames = parts[:argnames] - argtypes = parts[:argtypes] - kwargs_call = parts[:kwargs_call] - fname_fallback = Symbol(string(parts.fname) * "__" * string(fallback_case)[2:end]) - - return esc( - quote - function ($fname)( - $(argnames[1])::AbstractDecoratorManifold, - $(callargs[2:end]...); - $(kwargs_list...), - ) where {$(where_exprs...)} - transparency = ManifoldsBase._acts_transparently($fname, $(argnames...)) - return if transparency === Val(:parent) - return ($(parts.fname__parent))($(argnames...); $(kwargs_call...)) - elseif transparency === Val(:transparent) - return ($(parts.fname__transparent))($(argnames...); $(kwargs_call...)) - elseif transparency === Val(:intransparent) - return ($(parts.fname__intransparent))( - $(argnames...); - $(kwargs_call...), - ) - else - error("incorrect transparency: $transparency") - end - end - function ($(parts[:fname__transparent]))( - $(argnames[1])::AbstractDecoratorManifold, - $(callargs[2:end]...); - $(kwargs_list...), - ) where {$(where_exprs...)} - return ($fname)( - ManifoldsBase.decorated_manifold($(argnames[1])), - $(argnames[2:end]...); - $(kwargs_call...), - ) - end - function ($(parts[:fname__intransparent]))( - $(argnames[1])::AbstractDecoratorManifold, - $(callargs[2:end]...); - $(kwargs_list...), - ) where {$(where_exprs...)} - error_msg = ManifoldsBase.manifold_function_not_implemented_message( - $(argnames[1]), - $fname, - $(argnames[2:end]...), - ) - return error(error_msg) - end - function ($fname)( - $(argnames[1])::AbstractManifold, - $(callargs[2:end]...); - $(kwargs_list...), - ) where {$(where_exprs...)} - return error( - string( - ManifoldsBase.manifold_function_not_implemented_message( - $(argnames[1]), - $fname, - $(argnames[2:end]...), - ), - " Usually this is implemented for a ", - $(argtypes[1]), - ". Maybe you missed to implement this function for a default?", - ), - ) - end - function ($(parts[:fname__parent]))( - $(argnames[1])::AbstractDecoratorManifold, - $(callargs[2:end]...); - $(kwargs_list...), - ) where {$(where_exprs...)} - return invoke( - $fname, - Tuple{supertype($(argtypes[1])),$(argtypes[2:end]...)}, - $(argnames...); - $(kwargs_call...), - ) - end - function ($fname_fallback)( - $(callargs[1]), - $(callargs[2:end]...); - $(kwargs_list...), - ) where {$(where_exprs...)} - return ($body) - end - function decorator_transparent_dispatch( - ::typeof($fname), - $(callargs...), - ) where {$(where_exprs...)} - return Val($fallback_case) - end - end, - ) -end -#! format: off -# due to the if with returns inside (formatter puts a return upfront the if) -""" - @decorator_transparent_signature(ex) - -Introduces a given function to be transparent with respect to all decorators. -The function is adressed by its signature in `ex`. - -Supports standard, keyword arguments and `where` clauses. Doesn't support parameters with -default values. It introduces a dispatch on several transparency modes - -The cases of transparency are - -* `:transparent` states, that the function is transparently passed on to the manifold that - is decorated by the [`AbstractDecoratorManifold`](@ref) `M`, which is determined using - the function [`decorated_manifold`](@ref). This is the default. -* `:intransparent` states that an implementation for this decorator is required, and if - none of the types provides one, an error is issued. -* `:parent` states, that this function passes on to the supertype instead of to the - decorated manifold. - -Inline definitions are not supported. The function signature however may contain -keyword arguments and a where clause. - -The dispatch kind can later still be set to something different, see [`decorator_transparent_dispatch`](@ref) - -# Examples: - -```julia -@decorator_transparent_signature log!(M::AbstractDecoratorManifold, X, p, q) -@decorator_transparent_signature log!(M::TD, X, p, q) where {TD<:AbstractDecoratorManifold} -@decorator_transparent_signature isapprox(M::AbstractDecoratorManifold, p, q; kwargs...) -``` -""" -macro decorator_transparent_signature(ex) - parts = _split_signature(ex) - kwargs_list = parts[:kwargs_list] - callargs = parts[:callargs] - fname = parts[:fname] - where_exprs = parts[:where_exprs] - argnames = parts[:argnames] - argtypes = parts[:argtypes] - kwargs_call = parts[:kwargs_call] - #! format: off - return esc( - quote - function ($fname)($(callargs...); $(kwargs_list...)) where {$(where_exprs...)} - transparency = ManifoldsBase._acts_transparently($fname, $(argnames...)) - if transparency === Val(:parent) - return ($(parts.fname__parent))($(argnames...); $(kwargs_call...)) - elseif transparency === Val(:transparent) - return ($(parts.fname__transparent))($(argnames...); $(kwargs_call...)) - elseif transparency === Val(:intransparent) - return ($(parts.fname__intransparent))( - $(argnames...); - $(kwargs_call...), - ) - else - error("incorrect transparency: $transparency") - end - end - function ($(parts[:fname__transparent]))( - $(callargs...); - $(kwargs_list...), - ) where {$(where_exprs...)} - return ($fname)( - ManifoldsBase.decorated_manifold($(argnames[1])), - $(argnames[2:end]...); - $(kwargs_call...), - ) - end - function ($(parts[:fname__intransparent]))( - $(callargs...); - $(kwargs_list...), - ) where {$(where_exprs...)} - error_msg = ManifoldsBase.manifold_function_not_implemented_message( - $(argnames[1]), - $fname, - $(argnames[2:end]...), - ) - return error(error_msg) - end - function ($(parts[:fname__parent]))( - $(callargs...); - $(kwargs_list...), - ) where {$(where_exprs...)} - return invoke( - $fname, - Tuple{supertype($(argtypes[1])),$(argtypes[2:end]...)}, - $(argnames...); - $(kwargs_call...), - ) - end - end, - ) -end -#! format: on - - -# -# Functions -# - -""" - is_default_decorator(M) -> Bool - -For any manifold that is a subtype of [`AbstractDecoratorManifold`](@ref), this function -indicates whether a certain manifold `M` acts as a default decorator. - -This yields that _all_ functions are passed through to the decorated [`AbstractManifold`](@ref) -if `M` is indicated as default. This overwrites all [`is_decorator_transparent`](@ref) -values. - -This yields the following advantange: For a manifold one usually implicitly assumes for -example a metric. To avoid reimplementation of this metric when introducing a second metric, -the first metric can be set to be the default, i.e. its implementaion is already given by -the undecorated case. - -Value returned by this function is determined by [`default_decorator_dispatch`](@ref), -which returns a `Val`-wrapped boolean for type stability of certain functions. -""" -is_default_decorator(M::AbstractManifold) = _extract_val(default_decorator_dispatch(M)) - -""" - default_decorator_dispatch(M) -> Val - -Return whether by default to dispatch the the inner manifold of -a decorator (`Val(true)`) or not (`Val(false`). For more details see -[`is_decorator_transparent`](@ref). -""" -default_decorator_dispatch(::AbstractManifold) = Val(false) - -""" - is_decorator_transparent(f, M::AbstractManifold, args...) -> Bool - -Given a [`AbstractManifold`](@ref) `M` and a function `f(M, args...)`, indicate, whether an -[`AbstractDecoratorManifold`](@ref) acts transparently for `f`. This means, it -just passes through down to the internally stored manifold. -Transparency is only defined for decorator manifolds and by default all decorators are transparent. -A function that is affected by the decorator indicates this by returning `false`. To change -this behaviour, see [`decorator_transparent_dispatch`](@ref). - -If a decorator manifold is not in general transparent, it might still pass down -for the case that a decorator is the default decorator, see [`is_default_decorator`](@ref). -""" -function is_decorator_transparent(f, M::AbstractManifold, args...) - return decorator_transparent_dispatch(f, M, args...) === Val(:transparent) -end - -""" - decorator_transparent_dispatch(f, M::AbstractManifold, args...) -> Val - -Given a [`AbstractManifold`](@ref) `M` and a function `f(M,args...)`, indicate, whether a -function is `Val(:transparent)` or `Val(:intransparent)` for the (decorated) -[`AbstractManifold`](@ref) `M`. Another possibility is, that for `M` and given `args...` -the function `f` should invoke `M`s `Val(:parent)` implementation, see -[`@decorator_transparent_function`](@ref) for details. -""" -decorator_transparent_dispatch(::Any, ::AbstractManifold, args...) = Val(:transparent) - -function _acts_transparently(f, M::AbstractManifold, args...) - return _val_or( - default_decorator_dispatch(M), - decorator_transparent_dispatch(f, M, args...), - ) -end - -_val_or(::Val{true}, ::Val{T}) where {T} = Val(:transparent) -_val_or(::Val{false}, val::Val) = val - -# -# Functions overwritten with decorators -# - -function base_manifold(M::AbstractDecoratorManifold, depth::Val{N} = Val(-1)) where {N} - N == 0 && return M - N < 0 && return base_manifold(decorated_manifold(M), depth) - return base_manifold(decorated_manifold(M), Val(N - 1)) -end - -@decorator_transparent_signature check_point(M::AbstractDecoratorManifold, p; kwargs...) - -@decorator_transparent_signature check_vector(M::AbstractDecoratorManifold, p, X; kwargs...) - -""" - decorated_manifold(M::AbstractDecoratorManifold) - -Return the manifold decorated by the decorator `M`. Defaults to `M.manifold`. -""" -decorated_manifold(M::AbstractManifold) = M.manifold - -@decorator_transparent_signature copyto!(M::AbstractDecoratorManifold, q, p) -@decorator_transparent_signature copyto!(M::AbstractDecoratorManifold, Y, p, X) -@decorator_transparent_signature copy(M::AbstractDecoratorManifold, p) -@decorator_transparent_signature copy(M::AbstractDecoratorManifold, p, X) - -@decorator_transparent_signature distance(M::AbstractDecoratorManifold, p, q) - -@decorator_transparent_signature embed(M::AbstractDecoratorManifold, p, X) -@decorator_transparent_signature embed(M::AbstractDecoratorManifold, p) - -@decorator_transparent_signature embed!(M::AbstractDecoratorManifold, q, p) -@decorator_transparent_signature embed!(M::AbstractDecoratorManifold, Y, p, X) - -@decorator_transparent_signature exp(M::AbstractDecoratorManifold, p, X) - -@decorator_transparent_signature exp!(M::AbstractDecoratorManifold, q, p, X) - -@decorator_transparent_signature injectivity_radius(M::AbstractDecoratorManifold) -@decorator_transparent_signature injectivity_radius(M::AbstractDecoratorManifold, p) -@decorator_transparent_signature injectivity_radius( - M::AbstractDecoratorManifold, - m::AbstractRetractionMethod, -) -@decorator_transparent_signature injectivity_radius( - M::AbstractDecoratorManifold, - m::ExponentialRetraction, -) -@decorator_transparent_signature injectivity_radius( - M::AbstractDecoratorManifold, - p, - m::AbstractRetractionMethod, -) -@decorator_transparent_signature injectivity_radius( - M::AbstractDecoratorManifold, - p, - m::ExponentialRetraction, -) - -@decorator_transparent_signature inner(M::AbstractDecoratorManifold, p, X, Y) - -@decorator_transparent_signature inverse_retract( - M::AbstractDecoratorManifold, - p, - q, - m::AbstractInverseRetractionMethod, -) - -@decorator_transparent_signature inverse_retract( - M::AbstractDecoratorManifold, - p, - q, - m::LogarithmicInverseRetraction, -) - -@decorator_transparent_signature inverse_retract!( - M::AbstractDecoratorManifold, - X, - p, - q, - m::AbstractInverseRetractionMethod, -) - -@decorator_transparent_signature inverse_retract!( - M::AbstractDecoratorManifold, - X, - p, - q, - m::LogarithmicInverseRetraction, -) - -@decorator_transparent_signature inverse_retract( - M::AbstractDecoratorManifold, - p, - q, - m::NLsolveInverseRetraction, -) - -@decorator_transparent_signature inverse_retract!( - M::AbstractDecoratorManifold, - X, - p, - q, - m::NLsolveInverseRetraction, -) - -@decorator_transparent_signature isapprox(M::AbstractDecoratorManifold, p, q; kwargs...) -@decorator_transparent_signature isapprox(M::AbstractDecoratorManifold, p, X, Y; kwargs...) - -@decorator_transparent_signature log(M::AbstractDecoratorManifold, p, q) -@decorator_transparent_signature log!(M::AbstractDecoratorManifold, X, p, q) - -@decorator_transparent_signature manifold_dimension(M::AbstractDecoratorManifold) - -@decorator_transparent_signature mid_point(M::AbstractDecoratorManifold, p1, p2) -@decorator_transparent_signature mid_point!(M::AbstractDecoratorManifold, q, p1, p2) - -@decorator_transparent_signature number_system(M::AbstractDecoratorManifold) - -@decorator_transparent_signature project(M::AbstractDecoratorManifold, p) -@decorator_transparent_signature project!(M::AbstractDecoratorManifold, q, p) - -@decorator_transparent_signature project(M::AbstractDecoratorManifold, p, X) -@decorator_transparent_signature project!(M::AbstractDecoratorManifold, Y, p, X) - -@decorator_transparent_signature representation_size(M::AbstractDecoratorManifold) - -@decorator_transparent_signature retract( - M::AbstractDecoratorManifold, - p, - X, - m::AbstractRetractionMethod, -) - -@decorator_transparent_signature retract( - M::AbstractDecoratorManifold, - p, - X, - m::ExponentialRetraction, -) - -@decorator_transparent_signature retract!( - M::AbstractDecoratorManifold, - q, - p, - X, - m::AbstractRetractionMethod, -) - -@decorator_transparent_signature retract!( - M::AbstractDecoratorManifold, - q, - p, - X, - m::ExponentialRetraction, -) - -@decorator_transparent_signature vector_transport_along( - M::AbstractDecoratorManifold, - p, - X, - c, -) -@decorator_transparent_signature vector_transport_along!( - M::AbstractDecoratorManifold, - Y, - p, - X, - c, -) -@decorator_transparent_signature vector_transport_along!( - M::AbstractDecoratorManifold, - Y, - p, - X, - c::AbstractVector, - m::AbstractVectorTransportMethod, -) -@decorator_transparent_signature vector_transport_along!( - M::AbstractDecoratorManifold, - Y, - p, - X, - c::AbstractVector, - m::PoleLadderTransport, -) -@decorator_transparent_signature vector_transport_along!( - M::AbstractDecoratorManifold, - Y, - p, - X, - c::AbstractVector, - m::SchildsLadderTransport, -) - -@decorator_transparent_signature vector_transport_direction( - M::AbstractDecoratorManifold, - p, - X, - d, -) -@decorator_transparent_signature vector_transport_direction!( - M::AbstractDecoratorManifold, - Y, - p, - X, - d, -) - -@decorator_transparent_signature vector_transport_to( - M::AbstractDecoratorManifold, - p, - X, - q, - m::AbstractVectorTransportMethod, -) -@decorator_transparent_signature vector_transport_to!( - M::AbstractDecoratorManifold, - Y, - p, - X, - q, - m::AbstractVectorTransportMethod, -) -@decorator_transparent_signature vector_transport_to!( - M::AbstractDecoratorManifold, - Y, - p, - X, - q, - m::ProjectionTransport, -) -@decorator_transparent_signature vector_transport_to!( - M::AbstractDecoratorManifold, - Y, - p, - X, - q, - m::PoleLadderTransport, -) -@decorator_transparent_signature vector_transport_to!( - M::AbstractDecoratorManifold, - Y, - p, - X, - q, - m::ScaledVectorTransport, -) -@decorator_transparent_signature vector_transport_to!( - M::AbstractDecoratorManifold, - Y, - p, - X, - q, - m::SchildsLadderTransport, -) - -@decorator_transparent_signature zero_vector(M::AbstractDecoratorManifold, p) -@decorator_transparent_signature zero_vector!(M::AbstractDecoratorManifold, X, p) - -# -# Manually patch getindex not using the whole machinery -# -Base.@propagate_inbounds function Base.getindex( - p::AbstractArray, - M::AbstractDecoratorManifold{𝔽,<:AbstractDecoratorType}, - I::Union{Integer,Colon,AbstractVector}..., -) where {𝔽} - return getindex(p, decorated_manifold(M), I...) -end - -DEFAULT_PARENT_FUNCTIONS = [ - distance, - exp, - inner, - inverse_retract, - log, - mid_point, - norm, - retract, - vector_transport_along, - vector_transport_direction, - vector_transport_to, -] - -for f in DEFAULT_PARENT_FUNCTIONS - eval( - quote - function decorator_transparent_dispatch( - ::typeof($f), - ::AbstractDecoratorManifold{𝔽,<:AbstractDecoratorType}, - args..., - ) where {𝔽} - return Val(:parent) - end - end, - ) -end diff --git a/src/DefaultManifold.jl b/src/DefaultManifold.jl index f0cf753c..af3fabea 100644 --- a/src/DefaultManifold.jl +++ b/src/DefaultManifold.jl @@ -42,49 +42,69 @@ embed!(::DefaultManifold, Y, p, X) = copyto!(Y, X) exp!(::DefaultManifold, q, p, X) = (q .= p .+ X) -function get_basis( - ::DefaultManifold, - p, - B::DefaultOrthonormalBasis{𝔽,TangentSpaceType}, -) where {𝔽} - return CachedBasis(B, [_euclidean_basis_vector(p, i) for i in eachindex(p)]) +function get_basis_orthonormal(::DefaultManifold, p, N) + return CachedBasis( + DefaultOrthonormalBasis(N), + [_euclidean_basis_vector(p, i) for i in eachindex(p)], + ) end -function get_basis( - ::DefaultManifold, - p, - B::DefaultOrthogonalBasis{𝔽,TangentSpaceType}, -) where {𝔽} - return CachedBasis(B, [_euclidean_basis_vector(p, i) for i in eachindex(p)]) +function get_basis_orthogonal(::DefaultManifold, p, N) + return CachedBasis( + DefaultOrthogonalBasis(N), + [_euclidean_basis_vector(p, i) for i in eachindex(p)], + ) end -function get_basis(::DefaultManifold, p, B::DefaultBasis{𝔽,TangentSpaceType}) where {𝔽} - return CachedBasis(B, [_euclidean_basis_vector(p, i) for i in eachindex(p)]) +function get_basis_default(::DefaultManifold, p, N) + return CachedBasis( + DefaultBasis(N), + [_euclidean_basis_vector(p, i) for i in eachindex(p)], + ) end -function get_basis(M::DefaultManifold, p, B::DiagonalizingOrthonormalBasis) +function get_basis_diagonalizing(M::DefaultManifold, p, B) vecs = get_vectors(M, p, get_basis(M, p, DefaultOrthonormalBasis())) eigenvalues = zeros(real(eltype(p)), manifold_dimension(M)) return CachedBasis(B, DiagonalizingBasisData(B.frame_direction, eigenvalues, vecs)) end -function get_coordinates!( +# Complex manifold, real basis -> coefficients c are complesx -> reshape +# Real manifold, real basis -> reshape +function get_coordinates_orthonormal!(M::DefaultManifold, c, p, X, ::RealNumbers) + return copyto!(c, reshape(X, number_of_coordinates(M, ℝ))) +end +function get_coordinates_diagonalizing!( M::DefaultManifold, - Y, + c, p, X, - ::DefaultOrthonormalBasis{ℝ,TangentSpaceType}, + ::DiagonalizingOrthonormalBasis{ℝ}, ) - copyto!(Y, reshape(X, manifold_dimension(M))) - return Y + return copyto!(c, reshape(X, number_of_coordinates(M, ℝ))) end - -function get_vector!( +function get_coordinates_orthonormal!(::DefaultManifold, c, p, X, ::ComplexNumbers) + m = length(X) + return copyto!(c, [reshape(real(X), m); reshape(imag(X), m)]) +end +function get_vector_orthonormal!(M::DefaultManifold, Y, p, c, ::RealNumbers) + return copyto!(Y, reshape(c, representation_size(M))) +end +function get_vector_diagonalizing!( M::DefaultManifold, Y, p, - X, - ::DefaultOrthonormalBasis{ℝ,TangentSpaceType}, + c, + ::DiagonalizingOrthonormalBasis{ℝ}, ) - copyto!(Y, reshape(X, representation_size(M))) - return Y + return copyto!(Y, reshape(c, representation_size(M))) +end +function get_vector_orthonormal!( + M::DefaultManifold{T,ℂ}, + Y, + p, + c, + ::ComplexNumbers, +) where {T} + n = div(length(c), 2) + return copyto!(Y, reshape(c[1:n] + c[(n + 1):(2n)] * 1im, representation_size(M))) end injectivity_radius(::DefaultManifold) = Inf @@ -94,7 +114,7 @@ injectivity_radius(::DefaultManifold) = Inf log!(::DefaultManifold, Y, p, q) = (Y .= q .- p) @generated function manifold_dimension(::DefaultManifold{T,𝔽}) where {T,𝔽} - return *(T.parameters...) * real_dimension(𝔽) + return length(T.parameters) == 0 ? 1 : *(T.parameters...) * real_dimension(𝔽) end number_system(::DefaultManifold{T,𝔽}) where {T,𝔽} = 𝔽 @@ -110,37 +130,12 @@ function Base.show(io::IO, ::DefaultManifold{N,𝔽}) where {N,𝔽} return print(io, "DefaultManifold($(join(N.parameters, ", ")); field = $(𝔽))") end -function vector_transport_along!( - ::DefaultManifold, - Y, - p, - X, - c::AbstractVector, - ::AbstractVectorTransportMethod, -) +function parallel_transport_along!(::DefaultManifold, Y, p, X, c) return copyto!(Y, X) end -for VT in VECTOR_TRANSPORT_DISAMBIGUATION - eval( - quote - @invoke_maker 6 AbstractVectorTransportMethod vector_transport_along!( - M::DefaultManifold, - Y, - p, - X, - c::AbstractVector, - B::$VT, - ) - end, - ) -end - -function vector_transport_to!(::DefaultManifold, Y, p, X, q, ::ParallelTransport) +function parallel_transport_to!(::DefaultManifold, Y, p, X, q) return copyto!(Y, X) end -function vector_transport_to!(M::DefaultManifold, Y, p, X, q, ::ProjectionTransport) - return project!(M, Y, q, X) -end zero_vector(::DefaultManifold, p) = zero(p) diff --git a/src/EmbeddedManifold.jl b/src/EmbeddedManifold.jl index c7cfc891..f71053b3 100644 --- a/src/EmbeddedManifold.jl +++ b/src/EmbeddedManifold.jl @@ -1,81 +1,17 @@ -""" - AbstractEmbeddingType <: AbstractDecoratorType - -A type used to specify properties of an [`AbstractEmbeddedManifold`](@ref). -""" -abstract type AbstractEmbeddingType <: AbstractDecoratorType end - -""" - AbstractEmbeddedManifold{𝔽,T<:AbstractEmbeddingType,𝔽} <: AbstractDecoratorManifold{𝔽} - -This abstract type indicates that a concrete subtype is an embedded manifold with the -additional property, that its points are given in the embedding. This also means, that -the default implementation of [`embed`](@ref) is just the identity, since the points are -already stored in the form suitable for this embedding specified. This also holds true for -tangent vectors. - -Furthermore, depending on the [`AbstractEmbeddingType`](@ref) different methods are -transparently used from the embedding, for example the [`inner`](@ref) product or even the -[`distance`](@ref) function. Specifying such an embedding type transparently passes the -compuation onwards to the embedding (note again, that no [`embed`](@ref) is required) -and hence avoids to reimplement these methods in the manifold that is embedded. - -This should be used for example for [`check_point`](@ref) or [`check_vector`](@ref), -which should first invoke the test of the embedding and then test further constraints -the representation in the embedding has for these points to be valid. - -Technically this is realised by making the [`AbstractEmbeddedManifold`](@ref) is a decorator -for the [`AbstractManifold`](@ref)s that are subtypes. -""" -abstract type AbstractEmbeddedManifold{𝔽,T<:AbstractEmbeddingType} <: - AbstractDecoratorManifold{𝔽,T} end - -""" - DefaultEmbeddingType <: AbstractEmbeddingType - -A type of default embedding that does not have any special properties. -""" -struct DefaultEmbeddingType <: AbstractEmbeddingType end - -""" - AbstractIsometricEmbeddingType <: AbstractEmbeddingType - -Characterizes an embedding as isometric. For this case the [`inner`](@ref) product -is passed from the embedded manifold to the embedding. -""" -abstract type AbstractIsometricEmbeddingType <: AbstractEmbeddingType end - - -""" - DefaultIsometricEmbeddingType <: AbstractIsometricEmbeddingType - -An isometric embedding type that acts as a default, i.e. it has no specific properties -beyond its isometric property. -""" -struct DefaultIsometricEmbeddingType <: AbstractIsometricEmbeddingType end - -""" - TransparentIsometricEmbedding <: AbstractIsometricEmbeddingType - -Specify that an embedding is the default isometric embedding. This even inherits -logarithmic and exponential map as well as retraction and inverse retractions from the -embedding. - -For an example, see [`SymmetricMatrices`](@ref Main.Manifolds.SymmetricMatrices) which are -isometrically embedded in the Euclidean space of matrices but also inherit exponential -and logarithmic maps. -""" -struct TransparentIsometricEmbedding <: AbstractIsometricEmbeddingType end - """ EmbeddedManifold{𝔽, MT <: AbstractManifold, NT <: AbstractManifold} <: AbstractDecoratorManifold{𝔽} A type to represent an explicit embedding of a [`AbstractManifold`](@ref) `M` of type `MT` embedded into a manifold `N` of type `NT`. +By default, an embedded manifold is set to be embedded, but neither isometrically embedded +nor a submanifold. !!! note - This type is not required if a manifold `M` is to be embedded in one specific manifold `N`. One can then just implement - [`embed!`](@ref) and [`project!`](@ref). Only for a second –maybe considered non-default– + This type is not required if a manifold `M` is to be embedded in one specific manifold `N`. + One can then just implement [`embed!`](@ref) and [`project!`](@ref). + You can further pass functions to the embedding, for example, when it is an isometric embedding, + by using an [`AbstractDecoratorManifold`](@ref). + Only for a second –maybe considered non-default– embedding, this type should be considered in order to dispatch on different embed and project methods for different embeddings `N`. @@ -92,112 +28,33 @@ Generate the `EmbeddedManifold` of the [`AbstractManifold`](@ref) `M` into the [`AbstractManifold`](@ref) `N`. """ struct EmbeddedManifold{𝔽,MT<:AbstractManifold{𝔽},NT<:AbstractManifold} <: - AbstractDecoratorManifold{𝔽,AbstractDecoratorType} + AbstractDecoratorManifold{𝔽} manifold::MT embedding::NT end -function allocate_result(M::AbstractEmbeddedManifold, f::typeof(embed), x...) - T = allocate_result_type(M, f, x) - return allocate(x[1], T, representation_size(decorated_manifold(M))) -end - -function allocate_result(M::AbstractEmbeddedManifold, f::typeof(project), x...) - T = allocate_result_type(M, f, x) - return allocate(x[1], T, representation_size(base_manifold(M))) -end - -function allocate_result(M::EmbeddedManifold, f::typeof(embed), x...) - T = allocate_result_type(M, f, x) - return allocate(x[1], T, representation_size(get_embedding(M))) -end +@inline active_traits(f, ::EmbeddedManifold, ::Any...) = merge_traits(IsEmbeddedManifold()) function allocate_result(M::EmbeddedManifold, f::typeof(project), x...) T = allocate_result_type(M, f, x) return allocate(x[1], T, representation_size(base_manifold(M))) end - """ - base_manifold(M::AbstractEmbeddedManifold, d::Val{N} = Val(-1)) + decorated_manifold(M::EmbeddedManifold, d::Val{N} = Val(-1)) -Return the base manifold of `M` that is enhanced with its embedding. -While functions like `inner` might be overwritten to use the (decorated) manifold -representing the embedding, the base_manifold is the manifold itself in the sense that -detemining e.g. the [`is_default_metric`](@ref) does not fall back to check with -the embedding but with the manifold itself. For this abstract case, just `M` is returned. -""" -base_manifold(M::AbstractEmbeddedManifold, ::Val{N} = Val(-1)) where {N} = M -""" - base_manifold(M::EmbeddedManifold, d::Val{N} = Val(-1)) - -Return the base manifold of `M` that is enhanced with its embedding. For this specific +Return the manifold of `M` that is decorated with its embedding. For this specific type the internally stored enhanced manifold `M.manifold` is returned. -""" -base_manifold(M::EmbeddedManifold, ::Val{N} = Val(-1)) where {N} = M.manifold - - -""" - check_point(M::AbstractEmbeddedManifold, p; kwargs) - -check whether a point `p` is a valid point on the [`AbstractEmbeddedManifold`](@ref), -i.e. that `embed(M, p)` is a valid point on the embedded manifold. -""" -function check_point(M::AbstractEmbeddedManifold, p; kwargs...) - return invoke( - check_point, - Tuple{typeof(get_embedding(M)),typeof(p)}, - get_embedding(M), - p; - kwargs..., - ) -end - -""" - check_vector(M::AbstractEmbeddedManifold, p, X; kwargs...) - -Check that `embed(M, p, X)` is a valid tangent to `embed(M, p)`. -""" -function check_vector(M::AbstractEmbeddedManifold, p, X; kwargs...) - return invoke( - check_vector, - Tuple{typeof(get_embedding(M)),typeof(p),typeof(X)}, - get_embedding(M), - p, - X; - kwargs..., - ) -end - -decorated_manifold(M::EmbeddedManifold) = M.embedding - -function embed(M::EmbeddedManifold, p) - q = allocate_result(M, embed, p) - embed!(M, q, p) - return q -end - -embed!(::AbstractEmbeddedManifold, q, p) = copyto!(q, p) -embed!(::AbstractEmbeddedManifold, Y, p, X) = copyto!(Y, X) +See also [`base_manifold`](@ref), where this is used to (potentially) completely undecorate the manifold. """ - get_embedding(M::AbstractEmbeddedManifold) - -Return the [`AbstractManifold`](@ref) `N` an [`AbstractEmbeddedManifold`](@ref) is embedded into. -""" -get_embedding(::AbstractEmbeddedManifold) - -@decorator_transparent_function function get_embedding(M::AbstractEmbeddedManifold) - return decorated_manifold(M) -end +decorated_manifold(M::EmbeddedManifold) = M.manifold """ get_embedding(M::EmbeddedManifold) -Return the [`AbstractManifold`](@ref) `N` an [`EmbeddedManifold`](@ref) is embedded into. +Return the embedding [`AbstractManifold`](@ref) `N` of `M`, if it exists. """ -get_embedding(::EmbeddedManifold) - function get_embedding(M::EmbeddedManifold) return M.embedding end @@ -214,202 +71,3 @@ function show( ) where {𝔽,MT<:AbstractManifold{𝔽},NT<:AbstractManifold} return print(io, "EmbeddedManifold($(M.manifold), $(M.embedding))") end - -function default_decorator_dispatch(::EmbeddedManifold) - return Val(true) -end - -@doc raw""" - default_embedding_dispatch(M::AbstractEmbeddedManifold) - -This method indicates that an [`AbstractEmbeddedManifold`](@ref) is the default -and hence acts completely transparently and passes all functions transparently onwards. -This is used by the [`AbstractDecoratorManifold`](@ref) within -[`default_decorator_dispatch`](@ref). -By default this is set to `Val(false)`. -""" -default_embedding_dispatch(::AbstractEmbeddedManifold) = Val(false) - -# -# Abstract intransparent – i.e. new implementations necessary -for f in - [check_point, check_vector, embed!, exp!, inner, log!, manifold_dimension, project!] - eval( - quote - function decorator_transparent_dispatch( - ::typeof($f), - ::AbstractEmbeddedManifold, - args..., - ) - return Val(:intransparent) - end - end, - ) -end -# -# Abstract parent – i.e. pass to embedding -for f in [ - copy, - copyto!, - embed, - get_basis, - get_coordinates, - get_coordinates!, - get_vector, - get_vector!, - inverse_retract!, - mid_point!, - project, - retract!, - vector_transport_along, - vector_transport_direction, - vector_transport_direction!, - vector_transport_to, -] - eval( - quote - function decorator_transparent_dispatch( - ::typeof($f), - ::AbstractEmbeddedManifold, - args..., - ) - return Val(:parent) - end - end, - ) -end -# Abstract generic isometric -for f in [inverse_retract!, retract!] - eval( - quote - function decorator_transparent_dispatch( - ::typeof($f), - ::AbstractEmbeddedManifold{𝔽,<:AbstractIsometricEmbeddingType}, - args..., - ) where {𝔽} - return Val(:parent) - end - end, - ) -end -for f in [norm, inner] - eval( - quote - function decorator_transparent_dispatch( - ::typeof($f), - ::AbstractEmbeddedManifold{𝔽,<:AbstractIsometricEmbeddingType}, - args..., - ) where {𝔽} - return Val(:transparent) - end - end, - ) -end -# -# Transparent Isometric Embedding – additionally transparent -for f in [ - distance, - exp, - exp!, - inner, - inverse_retract, - inverse_retract!, - log, - log!, - mid_point, - mid_point!, - project!, - project, - retract, - retract!, - vector_transport_along, - vector_transport_direction, - vector_transport_direction!, - vector_transport_to, -] - eval( - quote - function decorator_transparent_dispatch( - ::typeof($f), - ::AbstractEmbeddedManifold{𝔽,<:TransparentIsometricEmbedding}, - args..., - ) where {𝔽} - return Val(:transparent) - end - end, - ) -end -# -# For explicit EmbeddingManifolds the following have to be reimplemented (:intransparent) -for f in [embed, project] - eval( - quote - function decorator_transparent_dispatch( - ::typeof($f), - ::EmbeddedManifold, - args..., - ) - return Val(:intransparent) - end - end, - ) -end - -# unified vector transports for the three already implemented cases, -# where _direction! still has its nice fallback -for f in [vector_transport_along!, vector_transport_to!] - eval( - quote - function decorator_transparent_dispatch( - ::typeof($f), - ::AbstractEmbeddedManifold{𝔽,<:E}, - Y, - p, - X, - q, - ::T, - ) where {𝔽,T,E} - return Val(:intransparent) - end - function decorator_transparent_dispatch( - ::typeof($f), - ::AbstractEmbeddedManifold{𝔽,<:TransparentIsometricEmbedding}, - Y, - p, - X, - q, - ::T, - ) where {𝔽,T} - return Val(:transparent) - end - end, - ) - for m in [PoleLadderTransport, SchildsLadderTransport, ScaledVectorTransport] - eval( - quote - function decorator_transparent_dispatch( - ::typeof($f), - ::AbstractEmbeddedManifold{𝔽,<:E}, - Y, - p, - X, - q, - ::$m, - ) where {𝔽,E} - return Val(:parent) - end - function decorator_transparent_dispatch( - ::typeof($f), - ::AbstractEmbeddedManifold{𝔽,<:TransparentIsometricEmbedding}, - Y, - p, - X, - q, - ::$m, - ) where {𝔽} - return Val(:parent) - end - end, - ) - end -end diff --git a/src/ManifoldsBase.jl b/src/ManifoldsBase.jl index e89e177e..13d3f5ed 100644 --- a/src/ManifoldsBase.jl +++ b/src/ManifoldsBase.jl @@ -23,6 +23,8 @@ import Markdown: @doc_str using LinearAlgebra include("maintypes.jl") +include("numbers.jl") +include("bases.jl") include("retractions.jl") include("exp_log_geo.jl") include("projections.jl") @@ -110,7 +112,7 @@ manifold for vector bundles or power manifolds. The optional parameter `depth` c to remove only the first `depth` many decorators and return the [`AbstractManifold`](@ref) from that level, whether its decorated or not. Any negative value deactivates this depth limit. """ -base_manifold(M::AbstractManifold, depth = Val(-1)) = M +base_manifold(M::AbstractManifold, ::Val = Val(-1)) = M """ check_point(M::AbstractManifold, p; kwargs...) -> Union{Nothing,String} @@ -148,8 +150,9 @@ By default, `check_size` returns `nothing`, i.e. if no checks are implemented, t assumption is to be optimistic. """ function check_size(M::AbstractManifold, p) - n = size(p) m = representation_size(M) + m === nothing && return nothing # nothing reasonable in size to check + n = size(p) if length(n) != length(m) return DomainError( length(n), @@ -166,8 +169,9 @@ end function check_size(M::AbstractManifold, p, X) mse = check_size(M, p) mse === nothing || return mse - n = size(X) m = representation_size(M) + m === nothing && return nothing # without a representation size - nothing to check. + n = size(X) if length(n) != length(m) return DomainError( length(n), @@ -280,13 +284,11 @@ the representation is changed accordingly. If you have more than one embedding, see [`EmbeddedManifold`](@ref) for defining a second embedding. If your point `p` is already represented in some embedding, -see [`AbstractEmbeddedManifold`](@ref) how you can avoid reimplementing code from the embedded manifold +see [`AbstractDecoratorManifold`](@ref) how you can avoid reimplementing code from the embedded manifold See also: [`EmbeddedManifold`](@ref), [`project!`](@ref project!(M::AbstractManifold, q, p)) """ -function embed!(M::AbstractManifold, q, p) - return error(manifold_function_not_implemented_message(M, embed!, q, p)) -end +embed!(M::AbstractManifold, q, p) """ embed(M::AbstractManifold, p, X) @@ -301,7 +303,7 @@ embedding, the representation is changed accordingly. If you have more than one embedding, see [`EmbeddedManifold`](@ref) for defining a second embedding. If your tangent vector `X` is already represented in some embedding, -see [`AbstractEmbeddedManifold`](@ref) how you can avoid reimplementing code from the embedded manifold +see [`AbstractDecoratorManifold`](@ref) how you can avoid reimplementing code from the embedded manifold See also: [`EmbeddedManifold`](@ref), [`project`](@ref project(M::AbstractManifold, p, X)) """ @@ -327,9 +329,7 @@ the tangent spaces of the embedded base points. See also: [`EmbeddedManifold`](@ref), [`project!`](@ref project!(M::AbstractManifold, Y, p, X)) """ -function embed!(M::AbstractManifold, Y, p, X) - return error(manifold_function_not_implemented_message(M, embed!, Y, p, X)) -end +embed!(M::AbstractManifold, Y, p, X) @doc raw""" injectivity_radius(M::AbstractManifold, p) @@ -349,16 +349,11 @@ Distance ``d`` such that is injective for all tangent vectors shorter than ``d`` (i.e. has an inverse) for point `p` if provided or all manifold points otherwise. """ -function injectivity_radius(M::AbstractManifold) - return error(manifold_function_not_implemented_message(M, injectivity_radius)) -end -injectivity_radius(M::AbstractManifold, p) = injectivity_radius(M) +injectivity_radius(M::AbstractManifold) +injectivity_radius(M::AbstractManifold, p) = injectivity_radius(M, ExponentialRetraction()) function injectivity_radius(M::AbstractManifold, p, method::AbstractRetractionMethod) return injectivity_radius(M, method) end -function injectivity_radius(M::AbstractManifold, method::AbstractRetractionMethod) - return error(manifold_function_not_implemented_message(M, injectivity_radius, method)) -end function injectivity_radius(M::AbstractManifold, p, ::ExponentialRetraction) return injectivity_radius(M, p) end @@ -369,12 +364,8 @@ injectivity_radius(M::AbstractManifold, ::ExponentialRetraction) = injectivity_r Compute the inner product of tangent vectors `X` and `Y` at point `p` from the [`AbstractManifold`](@ref) `M`. - -See also: [`MetricManifold`](@ref Main.Manifolds.MetricManifold) """ -function inner(M::AbstractManifold, p, X, Y) - return error(manifold_function_not_implemented_message(M, inner, p, X, Y)) -end +inner(M::AbstractManifold, p, X, Y) """ isapprox(M::AbstractManifold, p, q; kwargs...) @@ -403,10 +394,15 @@ Return whether `p` is a valid point on the [`AbstractManifold`](@ref) `M`. If `throw_error` is `false`, the function returns either `true` or `false`. If `throw_error` is `true`, the function either returns `true` or throws an error. By default the function -calls [`check_point(M, p; kwargs...)`](@ref) and checks whether the returned value +calls [`check_point`](@ref) and checks whether the returned value is `nothing` or an error. """ function is_point(M::AbstractManifold, p, throw_error = false; kwargs...) + mps = check_size(M, p) + if mps !== nothing + throw_error && throw(mps) + return false + end mpe = check_point(M, p; kwargs...) mpe === nothing && return true return throw_error ? throw(mpe) : false @@ -420,7 +416,7 @@ Returns either `true` or `false`. If `throw_error` is `false`, the function returns either `true` or `false`. If `throw_error` is `true`, the function either returns `true` or throws an error. By default the function -calls [`check_vector(M, p, X; kwargs...)`](@ref) and checks whether the returned +calls [`check_vector`](@ref) and checks whether the returned value is `nothing` or an error. If `check_base_point` is true, then the point `p` will be first checked using the @@ -430,23 +426,23 @@ function is_vector( M::AbstractManifold, p, X, - throw_error = false; - check_base_point = true, + throw_error = false, + check_base_point = true; kwargs..., ) if check_base_point - mpe = check_point(M, p; kwargs...) - if mpe !== nothing - if throw_error - throw(mpe) - else - return false - end - end + s = is_point(M, p, throw_error; kwargs...) # if throw_error, is_point throws, + !s && return false # otherwise if not a point return false + end + mXs = check_size(M, p, X) + if mXs !== nothing + throw_error && throw(mXs) + return false end - mtve = check_vector(M, p, X; kwargs...) - mtve === nothing && return true - return throw_error ? throw(mtve) : false + mXe = check_vector(M, p, X; kwargs...) + mXe === nothing && return true + throw_error && throw(mXe) + return false end @doc raw""" @@ -455,16 +451,7 @@ end The dimension $n=\dim_{\mathcal M}$ of real space $\mathbb R^n$ to which the neighborhood of each point of the [`AbstractManifold`](@ref) `M` is homeomorphic. """ -function manifold_dimension(M::AbstractManifold) - return error(manifold_function_not_implemented_message(M, manifold_dimension)) -end - -function manifold_function_not_implemented_message(M::AbstractManifold, f, x...) - s = join(map(string, map(typeof, x)), ", ", " and ") - a = length(x) > 1 ? "arguments" : "argument" - m = length(x) > 0 ? " for $(a) $(s)." : "." - return "$(f) not implemented on $(M)$(m)" -end +manifold_dimension(M::AbstractManifold) """ mid_point(M::AbstractManifold, p1, p2) @@ -565,12 +552,12 @@ function zero_vector(M::AbstractManifold, p) return X end include("errors.jl") -include("numbers.jl") +include("parallel_transport.jl") include("vector_transport.jl") -include("DecoratorManifold.jl") -include("bases.jl") include("vector_spaces.jl") include("point_vector_fallbacks.jl") +include("nested_trait.jl") +include("decorator_trait.jl") include("ValidationManifold.jl") include("EmbeddedManifold.jl") include("DefaultManifold.jl") @@ -578,43 +565,48 @@ include("PowerManifold.jl") export AbstractManifold, AbstractManifoldPoint, TVector, CoTVector, TFVector, CoTFVector export AbstractDecoratorManifold +export AbstractTrait, IsEmbeddedManifold, IsEmbeddedSubmanifold, IsIsometricEmbeddedManifold +export IsExplicitDecorator export ValidationManifold, ValidationMPoint, ValidationTVector, ValidationCoTVector -export AbstractEmbeddingType, - TransparentIsometricEmbedding, DefaultIsometricEmbeddingType, DefaultEmbeddingType -export AbstractEmbeddedManifold, EmbeddedManifold, TransparentIsometricEmbedding +export EmbeddedManifold export AbstractPowerManifold, PowerManifold export AbstractPowerRepresentation, NestedPowerRepresentation, NestedReplacingPowerRepresentation -export AbstractDecoratorType, DefaultDecoratorType - export OutOfInjectivityRadiusError export AbstractRetractionMethod, ApproximateInverseRetraction, - NLsolveInverseRetraction, + CayleyRetraction, + EmbeddedRetraction, ExponentialRetraction, + NLSolveInverseRetraction, + ODEExponentialRetraction, QRRetraction, + PadeRetraction, PolarRetraction, ProjectionRetraction, - PowerRetraction, - InversePowerRetraction + SoftmaxRetraction export AbstractInverseRetractionMethod, ApproximateInverseRetraction, + EmbeddedInverseRetraction, LogarithmicInverseRetraction, + NLSolveInverseRetraction, QRInverseRetraction, PolarInverseRetraction, - ProjectionInverseRetraction + ProjectionInverseRetraction, + SoftmaxInverseRetraction export AbstractVectorTransportMethod, DifferentiatedRetractionVectorTransport, ParallelTransport, PoleLadderTransport, - PowerVectorTransport, ProjectionTransport, ScaledVectorTransport, - SchildsLadderTransport + SchildsLadderTransport, + VectorTransportDirection, + VectorTransportTo export CachedBasis, DefaultBasis, @@ -623,15 +615,14 @@ export CachedBasis, DiagonalizingOrthonormalBasis, DefaultOrthonormalBasis, GramSchmidtOrthonormalBasis, - ProjectedOrthonormalBasis + ProjectedOrthonormalBasis, + VeeOrthogonalBasis export CompositeManifoldError, ComponentManifoldError export allocate, + angle, base_manifold, - check_point, - check_vector, - check_size, copy, copyto!, default_inverse_retraction_method, @@ -645,7 +636,6 @@ export allocate, geodesic, get_basis, get_component, - get_component!, get_coordinates, get_coordinates!, get_embedding, @@ -674,6 +664,12 @@ export allocate, number_of_coordinates, number_system, power_dimensions, + parallel_transport_along, + parallel_transport_along!, + parallel_transport_direction, + parallel_transport_direction!, + parallel_transport_to, + parallel_transport_to!, project, project!, real_dimension, @@ -682,6 +678,12 @@ export allocate, show, retract, retract!, + retract_polar, + retract_polar!, + retract_project, + retract_project!, + retract_qr, + retract_qr!, vector_transport_along, vector_transport_along!, vector_transport_direction, diff --git a/src/PowerManifold.jl b/src/PowerManifold.jl index a7fba70c..da16ddb1 100644 --- a/src/PowerManifold.jl +++ b/src/PowerManifold.jl @@ -112,60 +112,6 @@ function PowerManifold( return PowerManifold{𝔽,PowerManifold{𝔽,TM,TSize},Tuple{size...},TPR}(M) end -""" - PowerRetraction{TR<:AbstractRetractionMethod} <: AbstractRetractionMethod - -The `PowerRetraction` avoids ambiguities between dispatching on the [`AbstractPowerManifold`](@ref) -and dispatching on the [`AbstractRetractionMethod`](@ref) and encapsulates this. -This container should only be used in rare cases outside of this package. Usually a -subtype of the [`AbstractPowerManifold`](@ref) should define a way how to treat -its [`AbstractRetractionMethod`](@ref)s. - -# Constructor - - PowerRetraction(retraction::AbstractRetractionMethod) -""" -struct PowerRetraction{TR<:AbstractRetractionMethod} <: AbstractRetractionMethod - retraction::TR -end - -""" - InversePowerRetraction{TR<:AbstractInverseRetractionMethod} <: AbstractInverseRetractionMethod - -The `InversePowerRetraction` avoids ambiguities between dispatching on the [`AbstractPowerManifold`](@ref) -and dispatching on the [`AbstractInverseRetractionMethod`](@ref) and encapsulates this. -This container should only be used in rare cases outside of this package. Usually a -subtype of the [`AbstractPowerManifold`](@ref) should define a way how to treat -its [`AbstractRetractionMethod`](@ref)s. - -# Constructor - - InversePowerRetraction(inverse_retractions::AbstractInverseRetractionMethod...) -""" -struct InversePowerRetraction{TR<:AbstractInverseRetractionMethod} <: - AbstractInverseRetractionMethod - inverse_retraction::TR -end - -""" - PowerVectorTransport{TR<:AbstractVectorTransportMethod} <: - AbstractVectorTransportMethod - -The `PowerVectorTransport` avoids ambiguities between dispatching on the [`AbstractPowerManifold`](@ref) -and dispatching on the [`AbstractVectorTransportMethod`](@ref) and encapsulates this. -This container should only be used in rare cases outside of this package. Usually a -subtype of the [`AbstractPowerManifold`](@ref) should define a way how to treat -its [`AbstractVectorTransportMethod`](@ref)s. - -# Constructor - - PowerVectorTransport(method::AbstractVectorTransportMethod) -""" -struct PowerVectorTransport{TR<:AbstractVectorTransportMethod} <: - AbstractVectorTransportMethod - method::TR -end - """ PowerBasisData{TB<:AbstractArray} @@ -209,49 +155,44 @@ function allocate_result(M::PowerManifoldNested, f, x...) ] end end +# avoid ambituities - though usually not used +function allocate_result( + M::PowerManifoldNested, + f::typeof(get_coordinates), + p, + X, + B::AbstractBasis, +) + if representation_size(M.manifold) === () + return allocate(X, manifold_dimension(M)) + else + return [ + allocate_result(M.manifold, f, _access_nested(p, i), _access_nested(X, i), B) + for i in get_iterator(M) + ] + end +end function allocate_result(::PowerManifoldNestedReplacing, f, x...) return copy(x[1]) end - -for PowerRepr in [PowerManifoldNested, PowerManifoldNestedReplacing] - @eval begin - function allocate_result( - M::$PowerRepr, - f::typeof(get_coordinates), - p, - X, - B::AbstractBasis, - ) - return invoke( - allocate_result, - Tuple{AbstractManifold,typeof(get_coordinates),Any,Any,typeof(B)}, - M, - f, - p, - X, - B, - ) - end - function allocate_result( - M::$PowerRepr, - f::typeof(get_coordinates), - p, - X, - B::CachedBasis, - ) - return invoke( - allocate_result, - Tuple{AbstractManifold,typeof(get_coordinates),Any,Any,typeof(B)}, - M, - f, - p, - X, - B, - ) - end - end +# the following is not used but necessary to avoid ambiguities +function allocate_result( + M::PowerManifoldNestedReplacing, + f::typeof(get_coordinates), + p, + X, + B::AbstractBasis, +) + return invoke( + allocate_result, + Tuple{AbstractManifold,typeof(get_coordinates),Any,Any,AbstractBasis}, + M, + f, + p, + X, + B, + ) end - function allocate_result(M::PowerManifoldNested, f::typeof(get_vector), p, X) return [allocate_result(M.manifold, f, _access_nested(p, i)) for i in get_iterator(M)] end @@ -410,14 +351,6 @@ function get_basis(M::AbstractPowerManifold, p, B::DiagonalizingOrthonormalBasis ] return CachedBasis(B, PowerBasisData(vs)) end -for BT in ManifoldsBase.DISAMBIGUATION_BASIS_TYPES - if BT == DiagonalizingOrthonormalBasis - continue - end - eval(quote - @invoke_maker 3 AbstractBasis get_basis(M::AbstractPowerManifold, p, B::$BT) - end) -end """ get_component(M::AbstractPowerManifold, p, idx...) @@ -429,7 +362,7 @@ function get_component(M::AbstractPowerManifold, p, idx...) return _read(M, rep_size, p, idx) end -function get_coordinates(M::AbstractPowerManifold, p, X, B::DefaultOrthonormalBasis) +function get_coordinates(M::AbstractPowerManifold, p, X, B::AbstractBasis) rep_size = representation_size(M.manifold) vs = [ get_coordinates(M.manifold, _read(M, rep_size, p, i), _read(M, rep_size, X, i), B) for i in get_iterator(M) @@ -454,44 +387,37 @@ function get_coordinates( return reduce(vcat, reshape(vs, length(vs))) end -function get_coordinates!(M::AbstractPowerManifold, Y, p, X, B::DefaultOrthonormalBasis) +function get_coordinates!(M::AbstractPowerManifold, c, p, X, B::AbstractBasis) rep_size = representation_size(M.manifold) - dim = manifold_dimension(M.manifold) - v_iter = 1 for i in get_iterator(M) - # TODO: this view is really suboptimal when `dim` can be statically determined get_coordinates!( M.manifold, - view(Y, v_iter:(v_iter + dim - 1)), + _write_coordinates(M, c, i), _read(M, rep_size, p, i), _read(M, rep_size, X, i), B, ) - v_iter += dim end - return Y + return c end function get_coordinates!( M::AbstractPowerManifold, - Y, + c, p, X, B::CachedBasis{𝔽,<:AbstractBasis,<:PowerBasisData}, ) where {𝔽} rep_size = representation_size(M.manifold) - dim = manifold_dimension(M.manifold) - v_iter = 1 for i in get_iterator(M) get_coordinates!( M.manifold, - view(Y, v_iter:(v_iter + dim - 1)), + _write_coordinates(M, c, i), _read(M, rep_size, p, i), _read(M, rep_size, X, i), _access_nested(B.data.bases, i), ) - v_iter += dim end - return Y + return c end get_iterator(::PowerManifold{𝔽,<:AbstractManifold{𝔽},Tuple{N}}) where {𝔽,N} = Base.OneTo(N) @@ -502,77 +428,62 @@ get_iterator(::PowerManifold{𝔽,<:AbstractManifold{𝔽},Tuple{N}}) where { return Base.product(map(Base.OneTo, size_tuple)...) end -function get_vector!( +function get_vector( M::AbstractPowerManifold, - Y, p, - X, + c, B::CachedBasis{𝔽,<:AbstractBasis{𝔽},<:PowerBasisData}, ) where {𝔽} - dim = manifold_dimension(M.manifold) + Y = allocate_result(M, get_vector, p, c) rep_size = representation_size(M.manifold) - v_iter = 1 for i in get_iterator(M) - get_vector!( + Y[i...] = get_vector( M.manifold, - _write(M, rep_size, Y, i), _read(M, rep_size, p, i), - X[v_iter:(v_iter + dim - 1)], + _read_coordinates(M, c, i), _access_nested(B.data.bases, i), ) - v_iter += dim end return Y end -function get_vector!(M::AbstractPowerManifold, Y, p, X, B::DefaultOrthonormalBasis) - dim = manifold_dimension(M.manifold) +function get_vector!( + M::AbstractPowerManifold, + Y, + p, + c, + B::CachedBasis{𝔽,<:AbstractBasis{𝔽},<:PowerBasisData}, +) where {𝔽} rep_size = representation_size(M.manifold) - v_iter = 1 for i in get_iterator(M) get_vector!( M.manifold, _write(M, rep_size, Y, i), _read(M, rep_size, p, i), - X[v_iter:(v_iter + dim - 1)], - B, + _read_coordinates(M, c, i), + _access_nested(B.data.bases, i), ) - v_iter += dim end return Y end -function get_vector!( - M::PowerManifoldNestedReplacing, - Y, - p, - X, - B::CachedBasis{𝔽,<:AbstractBasis{𝔽},<:PowerBasisData}, -) where {𝔽} - dim = manifold_dimension(M.manifold) +function get_vector(M::AbstractPowerManifold, p, c, B::AbstractBasis) + Y = allocate_result(M, get_vector, p, c) rep_size = representation_size(M.manifold) - v_iter = 1 for i in get_iterator(M) - Y[i...] = get_vector( - M.manifold, - _read(M, rep_size, p, i), - X[v_iter:(v_iter + dim - 1)], - _access_nested(B.data.bases, i), - ) - v_iter += dim + Y[i...] = + get_vector(M.manifold, _read(M, rep_size, p, i), _read_coordinates(M, c, i), B) end return Y end -function get_vector!(M::PowerManifoldNestedReplacing, Y, p, X, B::DefaultOrthonormalBasis) - dim = manifold_dimension(M.manifold) +function get_vector!(M::AbstractPowerManifold, Y, p, c, B::AbstractBasis) rep_size = representation_size(M.manifold) - v_iter = 1 for i in get_iterator(M) - Y[i...] = get_vector( + get_vector!( M.manifold, + _write(M, rep_size, Y, i), _read(M, rep_size, p, i), - X[v_iter:(v_iter + dim - 1)], + _read_coordinates(M, c, i), B, ) - v_iter += dim end return Y end @@ -615,23 +526,10 @@ function injectivity_radius(M::AbstractPowerManifold, p) end return radius end +function injectivity_radius(M::AbstractPowerManifold, ::ExponentialRetraction) + return injectivity_radius(M.manifold) +end injectivity_radius(M::AbstractPowerManifold) = injectivity_radius(M.manifold) -eval( - quote - @invoke_maker 1 AbstractManifold injectivity_radius( - M::AbstractPowerManifold, - rm::AbstractRetractionMethod, - ) - end, -) -eval( - quote - @invoke_maker 1 AbstractManifold injectivity_radius( - M::AbstractPowerManifold, - rm::ExponentialRetraction, - ) - end, -) @doc raw""" inner(M::AbstractPowerManifold, p, X, Y) @@ -685,16 +583,22 @@ function Base.isapprox(M::AbstractPowerManifold, p, X, Y; kwargs...) end @doc raw""" - inverse_retract(M::AbstractPowerManifold, p, q, m::InversePowerRetraction) + inverse_retract(M::AbstractPowerManifold, p, q, m::AbstractInverseRetractionMethod) Compute the inverse retraction from `p` with respect to `q` on an [`AbstractPowerManifold`](@ref) `M` -using an [`InversePowerRetraction`](@ref), which by default encapsulates a inverse retraction -of the base manifold. Then this method is performed elementwise, so the encapsulated inverse +using an [`AbstractInverseRetractionMethod`](@ref). +Then this method is performed elementwise, so the inverse retraction method has to be one that is available on the base [`AbstractManifold`](@ref). """ inverse_retract(::AbstractPowerManifold, ::Any...) -function inverse_retract!(M::AbstractPowerManifold, X, p, q, method::InversePowerRetraction) +function inverse_retract!( + M::AbstractPowerManifold, + X, + p, + q, + m::AbstractInverseRetractionMethod = LogarithmicInverseRetraction(), +) rep_size = representation_size(M.manifold) for i in get_iterator(M) inverse_retract!( @@ -702,7 +606,7 @@ function inverse_retract!(M::AbstractPowerManifold, X, p, q, method::InversePowe _write(M, rep_size, X, i), _read(M, rep_size, p, i), _read(M, rep_size, q, i), - method.inverse_retraction, + m, ) end return X @@ -712,7 +616,7 @@ function inverse_retract!( X, p, q, - method::InversePowerRetraction, + m::AbstractInverseRetractionMethod = LogarithmicInverseRetraction(), ) rep_size = representation_size(M.manifold) for i in get_iterator(M) @@ -720,28 +624,11 @@ function inverse_retract!( M.manifold, _read(M, rep_size, p, i), _read(M, rep_size, q, i), - method.inverse_retraction, + m, ) end return X end -# log and power have to be explicitly stated to avoid an ambiguity in the third case with AbstractPower -@invoke_maker 5 AbstractInverseRetractionMethod inverse_retract!( - M::AbstractPowerManifold, - X, - q, - p, - m::LogarithmicInverseRetraction, -) -function inverse_retract!( - M::AbstractPowerManifold, - X, - q, - p, - m::AbstractInverseRetractionMethod, -) - return inverse_retract!(M, X, q, p, InversePowerRetraction(m)) -end @doc raw""" log(M::AbstractPowerManifold, p, q) @@ -894,6 +781,16 @@ Base.@propagate_inbounds @inline function _read( return _read(M, rep_size, x, (i,)) end +Base.@propagate_inbounds @inline function _read_coordinates( + M::AbstractPowerManifold, + c::AbstractArray, + i::Int, +) + d = manifold_dimension(M.manifold) + k = LinearIndices(power_dimensions(M))[i] + return c[((k - 1) * d + 1):(k * d)] +end + Base.@propagate_inbounds @inline function _read( ::Union{PowerManifoldNested,PowerManifoldNestedReplacing}, rep_size::Tuple, @@ -910,16 +807,22 @@ end @doc raw""" - retract(M::AbstractPowerManifold, p, X, method::PowerRetraction) + retract(M::AbstractPowerManifold, p, X, method::AbstractRetractionMethod) Compute the retraction from `p` with tangent vector `X` on an [`AbstractPowerManifold`](@ref) `M` -using a [`PowerRetraction`](@ref), which by default encapsulates a retraction of the -base manifold. Then this method is performed elementwise, so the encapsulated retraction +using a [`AbstractRetractionMethod`](@ref). +Then this method is performed elementwise, so the retraction method has to be one that is available on the base [`AbstractManifold`](@ref). """ retract(::AbstractPowerManifold, ::Any...) -function retract!(M::AbstractPowerManifold, q, p, X, method::PowerRetraction) +function retract!( + M::AbstractPowerManifold, + q, + p, + X, + m::AbstractRetractionMethod = ExponentialRetraction(), +) rep_size = representation_size(M.manifold) for i in get_iterator(M) retract!( @@ -927,36 +830,25 @@ function retract!(M::AbstractPowerManifold, q, p, X, method::PowerRetraction) _write(M, rep_size, q, i), _read(M, rep_size, p, i), _read(M, rep_size, X, i), - method.retraction, - ) - end - return q -end -function retract!(M::PowerManifoldNestedReplacing, q, p, X, method::PowerRetraction) - rep_size = representation_size(M.manifold) - for i in get_iterator(M) - q[i...] = retract( - M.manifold, - _read(M, rep_size, p, i), - _read(M, rep_size, X, i), - method.retraction, + m, ) end return q end -# exp and power have to be explicitly stated, since the third case otherwise introduces and ambiguity. -@invoke_maker 5 AbstractRetractionMethod retract!( - M::AbstractPowerManifold, +function retract!( + M::PowerManifoldNestedReplacing, q, p, X, - m::ExponentialRetraction, + m::AbstractRetractionMethod = ExponentialRetraction(), ) -function retract!(M::AbstractPowerManifold, q, p, X, m::AbstractRetractionMethod) - return retract!(M, q, p, X, PowerRetraction(m)) + rep_size = representation_size(M.manifold) + for i in get_iterator(M) + q[i...] = retract(M.manifold, _read(M, rep_size, p, i), _read(M, rep_size, X, i), m) + end + return q end - """ set_component!(M::AbstractPowerManifold, q, p, idx...) @@ -1008,59 +900,76 @@ function Base.show( return nothing end -function vector_transport_direction(M::AbstractPowerManifold, p, X, d) - return vector_transport_direction(M, p, X, d, PowerVectorTransport(ParallelTransport())) -end - -function vector_transport_direction!(M::AbstractPowerManifold, Y, p, X, d) - return vector_transport_direction!( - M, - Y, - p, - X, - d, - PowerVectorTransport(ParallelTransport()), - ) -end function vector_transport_direction!( M::AbstractPowerManifold, Y, p, X, d, - m::AbstractVectorTransportMethod, + m::AbstractVectorTransportMethod = default_vector_transport_method(M), ) - return vector_transport_direction!(M, Y, p, X, d, PowerVectorTransport(m)) + rep_size = representation_size(M.manifold) + for i in get_iterator(M) + vector_transport_direction!( + M.manifold, + _write(M, rep_size, Y, i), + _read(M, rep_size, p, i), + _read(M, rep_size, X, i), + _read(M, rep_size, d, i), + m, + ) + end + return Y end -function vector_transport_direction!( +function vector_transport_direction( M::AbstractPowerManifold, - Y, p, X, d, - m::PowerVectorTransport, + m::AbstractVectorTransportMethod = default_vector_transport_method(M), ) + Y = allocate_result(M, vector_transport_direction, p, X, d) rep_size = representation_size(M.manifold) for i in get_iterator(M) - vector_transport_direction!( + Y[i...] = vector_transport_direction( M.manifold, - _write(M, rep_size, Y, i), _read(M, rep_size, p, i), _read(M, rep_size, X, i), _read(M, rep_size, d, i), - m.method, + m, ) end return Y end + function vector_transport_direction!( M::PowerManifoldNestedReplacing, Y, p, X, d, - m::PowerVectorTransport, + m::AbstractVectorTransportMethod = default_vector_transport_method(M), +) + rep_size = representation_size(M.manifold) + for i in get_iterator(M) + Y[i...] = vector_transport_direction( + M.manifold, + _read(M, rep_size, p, i), + _read(M, rep_size, X, i), + _read(M, rep_size, d, i), + m, + ) + end + return Y +end +function vector_transport_direction( + M::PowerManifoldNestedReplacing, + p, + X, + d, + m::AbstractVectorTransportMethod = default_vector_transport_method(M), ) + Y = allocate_result(M, vector_transport_direction, p, X, d) rep_size = representation_size(M.manifold) for i in get_iterator(M) Y[i...] = vector_transport_direction( @@ -1068,49 +977,54 @@ function vector_transport_direction!( _read(M, rep_size, p, i), _read(M, rep_size, X, i), _read(M, rep_size, d, i), - m.method, + m, ) end return Y end @doc raw""" - vector_transport_to(M::AbstractPowerManifold, p, X, q, method::PowerVectorTransport) + vector_transport_to(M::AbstractPowerManifold, p, X, q, method::AbstractVectorTransportMethod) Compute the vector transport the tangent vector `X`at `p` to `q` on the -[`PowerManifold`](@ref) `M` using an [`PowerVectorTransport`](@ref) `m`. +[`PowerManifold`](@ref) `M` using an [`AbstractVectorTransportMethod`](@ref) `m`. This method is performed elementwise, i.e. the method `m` has to be implemented on the base manifold. """ -vector_transport_to(::AbstractPowerManifold, ::Any, ::Any, ::Any, ::PowerVectorTransport) -function vector_transport_to(M::AbstractPowerManifold, p, X, q) - return vector_transport_to(M, p, X, q, PowerVectorTransport(ParallelTransport())) -end - -function vector_transport_to!(M::AbstractPowerManifold, Y, p, X, q) - return vector_transport_to!(M, Y, p, X, q, PowerVectorTransport(ParallelTransport())) -end -function vector_transport_to!(M::AbstractPowerManifold, Y, p, X, q, m::PowerVectorTransport) +vector_transport_to( + ::AbstractPowerManifold, + ::Any, + ::Any, + ::Any, + ::AbstractVectorTransportMethod, +) +function vector_transport_to( + M::AbstractPowerManifold, + p, + X, + q, + m::AbstractVectorTransportMethod = default_vector_transport_method(M), +) rep_size = representation_size(M.manifold) + Y = allocate_result(M, vector_transport_to, p, X) for i in get_iterator(M) - vector_transport_to!( + Y[i...] = vector_transport_to( M.manifold, - _write(M, rep_size, Y, i), _read(M, rep_size, p, i), _read(M, rep_size, X, i), _read(M, rep_size, q, i), - m.method, + m, ) end return Y end function vector_transport_to!( - M::PowerManifoldNestedReplacing, + M::AbstractPowerManifold, Y, p, X, q, - m::PowerVectorTransport, + m::AbstractVectorTransportMethod = default_vector_transport_method(M), ) rep_size = representation_size(M.manifold) for i in get_iterator(M) @@ -1119,35 +1033,30 @@ function vector_transport_to!( _read(M, rep_size, p, i), _read(M, rep_size, X, i), _read(M, rep_size, q, i), - m.method, + m, ) end return Y end - function vector_transport_to!( - M::AbstractPowerManifold, + M::PowerManifoldNestedReplacing, Y, p, X, q, - m::AbstractVectorTransportMethod, + m::AbstractVectorTransportMethod = default_vector_transport_method(M), ) - return vector_transport_to!(M, Y, p, X, q, PowerVectorTransport(m)) -end -for VT in ManifoldsBase.VECTOR_TRANSPORT_DISAMBIGUATION - eval( - quote - @invoke_maker 6 AbstractVectorTransportMethod vector_transport_to!( - M::AbstractPowerManifold, - Y, - p, - X, - q, - B::$VT, - ) - end, - ) + rep_size = representation_size(M.manifold) + for i in get_iterator(M) + Y[i...] = vector_transport_to( + M.manifold, + _read(M, rep_size, p, i), + _read(M, rep_size, X, i), + _read(M, rep_size, q, i), + m, + ) + end + return Y end """ @@ -1169,13 +1078,25 @@ end return _write(M, rep_size, x, (i,)) end -@inline function _write(::PowerManifoldNested, rep_size::Tuple, x::AbstractArray, i::Tuple) +@inline function _write( + ::AbstractPowerManifold, + rep_size::Tuple, + x::AbstractArray, + i::Tuple, +) return view(x[i...], rep_size_to_colons(rep_size)...) end @inline function _write(::PowerManifoldNested, ::Tuple{}, x::AbstractArray, i::Tuple) return view(x, i...) end +@inline function _write_coordinates(M::AbstractPowerManifold, c::AbstractVector, i::Int) + d = manifold_dimension(M.manifold) + k = LinearIndices(power_dimensions(M))[i] + return view(c, ((k - 1) * d + 1):(k * d)) +end + + function zero_vector!(M::AbstractPowerManifold, X, p) rep_size = representation_size(M.manifold) for i in get_iterator(M) diff --git a/src/ValidationManifold.jl b/src/ValidationManifold.jl index 3355f65d..1d640f61 100644 --- a/src/ValidationManifold.jl +++ b/src/ValidationManifold.jl @@ -9,8 +9,7 @@ encapsulated/stripped automatically when needed. This manifold is a decorator for a manifold, i.e. it decorates a [`AbstractManifold`](@ref) `M` with types points, vectors, and covectors. """ -struct ValidationManifold{𝔽,M<:AbstractManifold{𝔽}} <: - AbstractDecoratorManifold{𝔽,DefaultDecoratorType} +struct ValidationManifold{𝔽,M<:AbstractManifold{𝔽}} <: AbstractDecoratorManifold{𝔽} manifold::M end @@ -72,6 +71,8 @@ array_value(p::AbstractArray) = p array_value(p::ValidationMPoint) = p.value array_value(X::ValidationFibreVector) = X.value +decorated_manifold(M::ValidationManifold) = M.manifold + function check_point(M::ValidationManifold, p; kwargs...) return check_point(M.manifold, array_value(p); kwargs...) end @@ -226,59 +227,18 @@ function get_basis( end return Ξ end -for BT in DISAMBIGUATION_BASIS_TYPES - if BT <: - Union{AbstractOrthonormalBasis,CachedBasis{𝔽,<:AbstractOrthonormalBasis} where 𝔽} - CT = AbstractOrthonormalBasis - elseif BT <: - Union{AbstractOrthogonalBasis,CachedBasis{𝔽,<:AbstractOrthogonalBasis} where 𝔽} - CT = AbstractOrthogonalBasis - else - CT = AbstractBasis - end - eval(quote - @invoke_maker 3 $CT get_basis(M::ValidationManifold, p, B::$BT; kwargs...) - end) -end function get_coordinates(M::ValidationManifold, p, X, B::AbstractBasis; kwargs...) is_point(M, p, true; kwargs...) is_vector(M, p, X, true; kwargs...) return get_coordinates(M.manifold, p, X, B) end -for BT in DISAMBIGUATION_BASIS_TYPES - eval( - quote - @invoke_maker 4 AbstractBasis get_coordinates( - M::ValidationManifold, - p, - X, - B::$BT; - kwargs..., - ) - end, - ) -end function get_coordinates!(M::ValidationManifold, Y, p, X, B::AbstractBasis; kwargs...) is_vector(M, p, X, true; kwargs...) get_coordinates!(M.manifold, Y, p, X, B) return Y end -for BT in [DISAMBIGUATION_BASIS_TYPES..., DISAMBIGUATION_COTANGENT_BASIS_TYPES...] - eval( - quote - @invoke_maker 5 AbstractBasis get_coordinates!( - M::ValidationManifold, - Y, - p, - X, - B::$BT; - kwargs..., - ) - end, - ) -end function get_vector(M::ValidationManifold, p, X, B::AbstractBasis; kwargs...) is_point(M, p, true; kwargs...) @@ -287,19 +247,6 @@ function get_vector(M::ValidationManifold, p, X, B::AbstractBasis; kwargs...) size(Y) == representation_size(M) || error("Incorrect size of tangent vector Y") return Y end -for BT in DISAMBIGUATION_BASIS_TYPES - eval( - quote - @invoke_maker 4 AbstractBasis get_vector( - M::ValidationManifold, - p, - X, - B::$BT; - kwargs..., - ) - end, - ) -end function get_vector!(M::ValidationManifold, Y, p, X, B::AbstractBasis; kwargs...) is_point(M, p, true; kwargs...) @@ -308,20 +255,6 @@ function get_vector!(M::ValidationManifold, Y, p, X, B::AbstractBasis; kwargs... size(Y) == representation_size(M) || error("Incorrect size of tangent vector Y") return Y end -for BT in [DISAMBIGUATION_BASIS_TYPES..., DISAMBIGUATION_COTANGENT_BASIS_TYPES...] - eval( - quote - @invoke_maker 5 AbstractBasis get_vector!( - M::ValidationManifold, - Y, - p, - X, - B::$BT; - kwargs..., - ) - end, - ) -end injectivity_radius(M::ValidationManifold) = injectivity_radius(M.manifold) function injectivity_radius(M::ValidationManifold, method::AbstractRetractionMethod) @@ -414,6 +347,20 @@ function project!(M::ValidationManifold, Y, p, X; kwargs...) return Y end +function vector_transport_along( + M::ValidationManifold, + p, + X, + c::AbstractVector, + m::AbstractVectorTransportMethod; + kwargs..., +) + is_vector(M, p, X, true; kwargs...) + Y = vector_transport_along(M.manifold, array_value(p), array_value(X), c, m) + is_vector(M, c[end], Y, true; kwargs...) + return Y +end + function vector_transport_along!( M::ValidationManifold, Y, @@ -435,21 +382,21 @@ function vector_transport_along!( is_vector(M, c[end], Y, true; kwargs...) return Y end -for VT in VECTOR_TRANSPORT_DISAMBIGUATION - eval( - quote - @invoke_maker 6 AbstractVectorTransportMethod vector_transport_along!( - M::ValidationManifold, - vto, - x, - v, - c::AbstractVector, - B::$VT, - ) - end, - ) -end +function vector_transport_to( + M::ValidationManifold, + p, + X, + q, + m::AbstractVectorTransportMethod; + kwargs..., +) + is_point(M, q, true; kwargs...) + is_vector(M, p, X, true; kwargs...) + Y = vector_transport_to(M.manifold, array_value(p), array_value(X), array_value(q), m) + is_vector(M, q, Y, true; kwargs...) + return Y +end function vector_transport_to!( M::ValidationManifold, Y, @@ -473,30 +420,6 @@ function vector_transport_to!( return Y end -for T in [ - PoleLadderTransport, - ProjectionTransport, - ScaledVectorTransport, - SchildsLadderTransport, -] - @eval begin - function vector_transport_to!(M::ValidationManifold, Y, p, X, q, m::$T; kwargs...) - is_point(M, q, true; kwargs...) - is_vector(M, p, X, true; kwargs...) - vector_transport_to!( - M.manifold, - array_value(Y), - array_value(p), - array_value(X), - array_value(q), - m, - ) - is_vector(M, q, Y, true; kwargs...) - return Y - end - end -end - function zero_vector(M::ValidationManifold, p; kwargs...) is_point(M, p, true; kwargs...) w = zero_vector(M.manifold, array_value(p)) diff --git a/src/bases.jl b/src/bases.jl index b49a2051..9156860c 100644 --- a/src/bases.jl +++ b/src/bases.jl @@ -8,7 +8,7 @@ Every vector space `fiber` is supposed to provide: * a method of constructing vectors, * basic operations: addition, subtraction, multiplication by a scalar and negation (unary minus), -* [`zero_vector(fiber, p)`](@ref Main.Manifolds.zero_vector) to construct zero vectors at point `p`, +* `zero_vector(fiber, p)` to construct zero vectors at point `p`, * `allocate(X)` and `allocate(X, T)` for vector `X` and type `T`, * `copyto!(X, Y)` for vectors `X` and `Y`, * `number_eltype(v)` for vector `v`, @@ -17,10 +17,10 @@ Every vector space `fiber` is supposed to provide: Optionally: * inner product via `inner` (used to provide Riemannian metric on vector bundles), -* [`flat`](@ref Main.Manifolds.flat) and [`sharp`](@ref Main.Manifolds.sharp), +* [`flat`](https://juliamanifolds.github.io/Manifolds.jl/stable/features/atlases.html#Manifolds.flat-Tuple{AbstractManifold,%20Any,%20Any}) and [`sharp`](https://juliamanifolds.github.io/Manifolds.jl/stable/features/atlases.html#Manifolds.sharp-Tuple{AbstractManifold,%20Any,%20Any}), * `norm` (by default uses `inner`), * [`project`](@ref) (for embedded vector spaces), -* [`representation_size`](@ref) (if support for [`ProductArray`](@ref Main.Manifolds.ProductArray) is desired), +* [`representation_size`](@ref), * broadcasting for basic operations. """ abstract type VectorSpaceType end @@ -241,7 +241,7 @@ struct CachedBasis{𝔽,B,V} <: AbstractBasis{𝔽,TangentSpaceType} where {B<:AbstractBasis{𝔽,TangentSpaceType},V} data::V end -function CachedBasis(::B, data::V) where {V,𝔽,B<:AbstractBasis{𝔽,TangentSpaceType}} +function CachedBasis(::B, data::V) where {V,𝔽,B<:AbstractBasis{𝔽,<:TangentSpaceType}} return CachedBasis{𝔽,B,V}(data) end function CachedBasis(basis::CachedBasis) # avoid double encapsulation @@ -266,24 +266,6 @@ const all_uncached_bases{T} = Union{ DefaultOrthogonalBasis{<:Any,T}, DefaultOrthonormalBasis{<:Any,T}, } -const DISAMBIGUATION_BASIS_TYPES = [ - CachedBasis, - DefaultBasis, - DefaultBasis{<:Any,TangentSpaceType}, - DefaultOrthonormalBasis, - DefaultOrthonormalBasis{<:Any,TangentSpaceType}, - DefaultOrthogonalBasis, - DefaultOrthogonalBasis{<:Any,TangentSpaceType}, - DiagonalizingOrthonormalBasis, - ProjectedOrthonormalBasis{:svd,ℝ}, - ProjectedOrthonormalBasis{:gram_schmidt,ℝ}, - VeeOrthogonalBasis, -] -const DISAMBIGUATION_COTANGENT_BASIS_TYPES = [ - DefaultBasis{<:Any,CotangentSpaceType}, - DefaultOrthonormalBasis{<:Any,CotangentSpaceType}, - DefaultOrthogonalBasis{<:Any,CotangentSpaceType}, -] """ allocate_coordinates(M::AbstractManifold, p, T, n::Int) @@ -291,28 +273,19 @@ const DISAMBIGUATION_COTANGENT_BASIS_TYPES = [ Allocate vector of coordinates of length `n` of type `T` of a vector at point `p` on manifold `M`. """ -allocate_coordinates(M::AbstractManifold, p, T, n::Int) = allocate(p, T, n) - -function allocate_result( - M::AbstractManifold, - f::typeof(get_coordinates), - p, - X, - B::AbstractBasis, -) - T = allocate_result_type(M, f, (p, X)) - return allocate_coordinates(M, p, T, number_of_coordinates(M, B)) +allocate_coordinates(::AbstractManifold, p, T, n::Int) = allocate(p, T, n) +function allocate_coordinates(M::AbstractManifold, p::Int, T, n::Int) + return (representation_size(M) == () && n == 0) ? zero(T) : zeros(T, n) end - function allocate_result( M::AbstractManifold, f::typeof(get_coordinates), p, X, - B::CachedBasis, + basis::AbstractBasis, ) T = allocate_result_type(M, f, (p, X)) - return allocate_coordinates(M, p, T, number_of_coordinates(M, B)) + return allocate_coordinates(M, p, T, number_of_coordinates(M, basis)) end @inline function allocate_result_type( @@ -360,16 +333,16 @@ such that $v^i(v_j) = δ^i_j$, where $δ^i_j$ is the Kronecker delta symbol: \end{cases} ```` """ -dual_basis(M::AbstractManifold, p, B::AbstractBasis) +dual_basis(M::AbstractManifold, p, B::AbstractBasis) = _dual_basis(M, p, B) -function dual_basis( +function _dual_basis( ::AbstractManifold, p, ::DefaultOrthonormalBasis{𝔽,TangentSpaceType}, ) where {𝔽} return DefaultOrthonormalBasis{𝔽}(CotangentSpace) end -function dual_basis( +function _dual_basis( ::AbstractManifold, p, ::DefaultOrthonormalBasis{𝔽,CotangentSpaceType}, @@ -391,7 +364,7 @@ function _euclidean_basis_vector(p, i) end """ - get_basis(M::AbstractManifold, p, B::AbstractBasis) -> CachedBasis + get_basis(M::AbstractManifold, p, B::AbstractBasis; kwargs...) -> CachedBasis Compute the basis vectors of the tangent space at a point on manifold `M` represented by `p`. @@ -402,31 +375,26 @@ the function [`get_vectors`](@ref) needs to be used to retrieve the basis vector See also: [`get_coordinates`](@ref), [`get_vector`](@ref) """ -get_basis(M::AbstractManifold, p, B::AbstractBasis) -@decorator_transparent_signature get_basis( - M::AbstractDecoratorManifold, - p, - B::AbstractBasis, -) -function decorator_transparent_dispatch(::typeof(get_basis), ::AbstractManifold, args...) - return Val(:parent) +function get_basis(M::AbstractManifold, p, B::AbstractBasis; kwargs...) + return _get_basis(M, p, B; kwargs...) end -function get_basis( +function _get_basis( M::AbstractManifold, p, - B::DefaultOrthonormalBasis{<:Any,TangentSpaceType}, + B::DefaultOrthonormalBasis{<:Any,TangentSpaceType}; + kwargs..., ) - dim = manifold_dimension(M) + dim = number_of_coordinates(M, B) return CachedBasis( B, [get_vector(M, p, [ifelse(i == j, 1, 0) for j in 1:dim], B) for i in 1:dim], ) end -function get_basis(::AbstractManifold, ::Any, B::CachedBasis) +function _get_basis(::AbstractManifold, ::Any, B::CachedBasis) return B end -function get_basis(M::AbstractManifold, p, B::ProjectedOrthonormalBasis{:svd,ℝ}) +function _get_basis(M::AbstractManifold, p, B::ProjectedOrthonormalBasis{:svd,ℝ}) S = representation_size(M) PS = prod(S) dim = manifold_dimension(M) @@ -449,7 +417,7 @@ function get_basis(M::AbstractManifold, p, B::ProjectedOrthonormalBasis{:svd,ℝ end return CachedBasis(B, vecs) end -function get_basis( +function _get_basis( M::AbstractManifold, p, B::ProjectedOrthonormalBasis{:gram_schmidt,ℝ}; @@ -459,18 +427,37 @@ function get_basis( V = gram_schmidt(M, p, E; kwargs...) return CachedBasis(B, V) end -for BT in DISAMBIGUATION_BASIS_TYPES - eval( - quote - @decorator_transparent_signature get_basis( - M::AbstractDecoratorManifold, - p, - B::$BT, - ) - end, - ) +function _get_basis(M::AbstractManifold, p, B::VeeOrthogonalBasis) + return get_basis_vee(M, p, number_system(B)) +end +function get_basis_vee(M::AbstractManifold, p, N) + return get_basis(M, p, DefaultOrthogonalBasis(N)) +end + +function _get_basis(M::AbstractManifold, p, B::DefaultBasis) + return get_basis_default(M, p, number_system(B)) +end +function get_basis_default(M::AbstractManifold, p, N) + return get_basis(M, p, DefaultOrthogonalBasis(N)) end +function _get_basis(M::AbstractManifold, p, B::DefaultOrthogonalBasis) + return get_basis_orthogonal(M, p, number_system(B)) +end +function get_basis_orthogonal(M::AbstractManifold, p, N) + return get_basis(M, p, DefaultOrthonormalBasis(N)) +end + +function _get_basis(M::AbstractManifold, p, B::DiagonalizingOrthonormalBasis) + return get_basis_diagonalizing(M, p, B) +end +function get_basis_diagonalizing end + +function _get_basis(M::AbstractManifold, p, B::DefaultOrthonormalBasis) + return get_basis_orthonormal(M, p, number_system(B)) +end +function get_basis_orthonormal end + @doc raw""" get_coordinates(M::AbstractManifold, p, X, B::AbstractBasis) get_coordinates(M::AbstractManifold, p, X, B::CachedBasis) @@ -489,62 +476,113 @@ requires either a dual basis or the cached basis to be selfdual, for example ort See also: [`get_vector`](@ref), [`get_basis`](@ref) """ function get_coordinates(M::AbstractManifold, p, X, B::AbstractBasis) - Y = allocate_result(M, get_coordinates, p, X, B) - return get_coordinates!(M, Y, p, X, B) + return _get_coordinates(M, p, X, B) +end + +function _get_coordinates(M::AbstractManifold, p, X, B::VeeOrthogonalBasis) + return get_coordinates_vee(M, p, X, number_system(B)) +end +function get_coordinates_vee(M::AbstractManifold, p, X, N) + return get_coordinates(M, p, X, DefaultOrthogonalBasis(N)) end -@decorator_transparent_signature get_coordinates( - M::AbstractDecoratorManifold, + +function _get_coordinates(M::AbstractManifold, p, X, B::DefaultBasis) + return get_coordinates_default(M, p, X, number_system(B)) +end +function get_coordinates_default(M::AbstractManifold, p, X, N::AbstractNumbers) + return get_coordinates(M, p, X, DefaultOrthogonalBasis(N)) +end + +function _get_coordinates(M::AbstractManifold, p, X, B::DefaultOrthogonalBasis) + return get_coordinates_orthogonal(M, p, X, number_system(B)) +end +function get_coordinates_orthogonal(M::AbstractManifold, p, X, N) + return get_coordinates_orthonormal(M, p, X, N) +end + +function _get_coordinates(M::AbstractManifold, p, X, B::DefaultOrthonormalBasis) + return get_coordinates_orthonormal(M, p, X, number_system(B)) +end +function get_coordinates_orthonormal(M::AbstractManifold, p, X, N) + Y = allocate_result(M, get_coordinates, p, X, DefaultOrthonormalBasis(N)) + return get_coordinates_orthonormal!(M, Y, p, X, N) +end + +function _get_coordinates(M::AbstractManifold, p, X, B::DiagonalizingOrthonormalBasis) + return get_coordinates_diagonalizing(M, p, X, B) +end +function get_coordinates_diagonalizing( + M::AbstractManifold, p, X, - B::AbstractBasis, + B::DiagonalizingOrthonormalBasis, ) -function decorator_transparent_dispatch( - ::typeof(get_coordinates), - ::AbstractManifold, - args..., -) - return Val(:parent) + Y = allocate_result(M, get_coordinates, p, X, B) + return get_coordinates_diagonalizing!(M, Y, p, X, B) end -function get_coordinates!(M::AbstractManifold, Y, p, X, B::AbstractBasis) - return error( - "get_coordinates! not implemented for manifold of type $(typeof(M)) coordinates of type $(typeof(Y)), a point of type $(typeof(p)), tangent vector of type $(typeof(X)) and basis of type $(typeof(B)).", - ) +function _get_coordinates(M::AbstractManifold, p, X, B::CachedBasis) + return get_coordinates_cached(M, number_system(M), p, X, B, number_system(B)) end -@decorator_transparent_signature get_coordinates!( - M::AbstractDecoratorManifold, - Y, +function get_coordinates_cached( + M::AbstractManifold, + ::ComplexNumbers, p, X, - B::AbstractBasis, + B::CachedBasis, + ::RealNumbers, ) -for BT in [DISAMBIGUATION_BASIS_TYPES..., DISAMBIGUATION_COTANGENT_BASIS_TYPES...] - eval( - quote - @decorator_transparent_signature get_coordinates!( - M::AbstractDecoratorManifold, - Y, - p, - X, - B::$BT, - ) - end, - ) + return map(vb -> conj(inner(M, p, X, vb)), get_vectors(M, p, B)) +end +function get_coordinates_cached( + M::AbstractManifold, + ::𝔽, + p, + X, + C::CachedBasis, + ::𝔽, +) where {𝔽} + return map(vb -> real(inner(M, p, X, vb)), get_vectors(M, p, C)) +end + +function get_coordinates!(M::AbstractManifold, Y, p, X, B::AbstractBasis) + return _get_coordinates!(M, Y, p, X, B) +end +function _get_coordinates!(M::AbstractManifold, Y, p, X, B::VeeOrthogonalBasis) + return get_coordinates_vee!(M, Y, p, X, number_system(B)) +end +function get_coordinates_vee!(M::AbstractManifold, Y, p, X, N) + return get_coordinates!(M, Y, p, X, DefaultOrthogonalBasis(N)) end -function get_coordinates!(M::AbstractManifold, Y, p, X, B::VeeOrthogonalBasis) - return get_coordinates!(M, Y, p, X, DefaultOrthogonalBasis(number_system(B))) +function _get_coordinates!(M::AbstractManifold, Y, p, X, B::DefaultBasis) + return get_coordinates_default!(M, Y, p, X, number_system(B)) +end +function get_coordinates_default!(M::AbstractManifold, Y, p, X, N) + return get_coordinates!(M, Y, p, X, DefaultOrthogonalBasis(N)) +end + +function _get_coordinates!(M::AbstractManifold, Y, p, X, B::DefaultOrthogonalBasis) + return get_coordinates_orthogonal!(M, Y, p, X, number_system(B)) +end +function get_coordinates_orthogonal!(M::AbstractManifold, Y, p, X, N) + return get_coordinates!(M, Y, p, X, DefaultOrthonormalBasis(N)) end -function get_coordinates!(M::AbstractManifold, Y, p, X, B::DefaultBasis) - return get_coordinates!(M, Y, p, X, DefaultOrthogonalBasis(number_system(B))) + +function _get_coordinates!(M::AbstractManifold, Y, p, X, B::DefaultOrthonormalBasis) + return get_coordinates_orthonormal!(M, Y, p, X, number_system(B)) end -function get_coordinates!(M::AbstractManifold, Y, p, X, B::DefaultOrthogonalBasis) - return get_coordinates!(M, Y, p, X, DefaultOrthonormalBasis(number_system(B))) +function get_coordinates_orthonormal! end + +function _get_coordinates!(M::AbstractManifold, Y, p, X, B::DiagonalizingOrthonormalBasis) + return get_coordinates_diagonalizing!(M, Y, p, X, B) end -function get_coordinates!(M::AbstractManifold, Y, p, X, B::CachedBasis) - return _get_coordinates!(M, number_system(M), Y, p, X, B, number_system(B)) +function get_coordinates_diagonalizing! end + +function _get_coordinates!(M::AbstractManifold, Y, p, X, B::CachedBasis) + return get_coordinates_cached!(M, number_system(M), Y, p, X, B, number_system(B)) end -function _get_coordinates!( +function get_coordinates_cached!( M::AbstractManifold, ::ComplexNumbers, Y, @@ -556,21 +594,21 @@ function _get_coordinates!( map!(vb -> conj(inner(M, p, X, vb)), Y, get_vectors(M, p, B)) return Y end -function _get_coordinates!( +function get_coordinates_cached!( M::AbstractManifold, - a::𝔽, + ::𝔽, Y, p, X, C::CachedBasis, - b::𝔽, + ::𝔽, ) where {𝔽} map!(vb -> real(inner(M, p, X, vb)), Y, get_vectors(M, p, C)) return Y end """ - get_vector(M::AbstractManifold, p, X, B::AbstractBasis) + X = get_vector(M::AbstractManifold, p, c, B::AbstractBasis) Convert a one-dimensional vector of coefficients in a basis `B` of the tangent space at `p` on manifold `M` to a tangent vector `X` at `p`. @@ -584,58 +622,113 @@ requires either a dual basis or the cached basis to be selfdual, for example ort See also: [`get_coordinates`](@ref), [`get_basis`](@ref) """ -function get_vector(M::AbstractManifold, p, X, B::AbstractBasis) - Y = allocate_result(M, get_vector, p, X) - return get_vector!(M, Y, p, X, B) +function get_vector(M::AbstractManifold, p, c, B::AbstractBasis) + return _get_vector(M, p, c, B) end -@decorator_transparent_signature get_vector( - M::AbstractDecoratorManifold, - p, - X, - B::AbstractBasis, -) -function decorator_transparent_dispatch(::typeof(get_vector), ::AbstractManifold, args...) - return Val(:parent) + +function _get_vector(M::AbstractManifold, p, c, B::VeeOrthogonalBasis) + return get_vector_vee(M, p, c, number_system(B)) +end +function get_vector_vee(M::AbstractManifold, p, c, N) + return get_vector(M, p, c, DefaultOrthogonalBasis(N)) end -function get_vector!(M::AbstractManifold, Y, p, X, B::AbstractBasis) - return error( - "get_vector! not implemented for manifold of type $(typeof(M)) vector of type $(typeof(Y)), a point of type $(typeof(p)), coordinates of type $(typeof(X)) and basis of type $(typeof(B)).", - ) +function _get_vector(M::AbstractManifold, p, c, B::DefaultBasis) + return get_vector_default(M, p, c, number_system(B)) end -@decorator_transparent_signature get_vector!( - M::AbstractDecoratorManifold, - Y, +function get_vector_default(M::AbstractManifold, p, c, N) + return get_vector(M, p, c, DefaultOrthogonalBasis(N)) +end + +function _get_vector(M::AbstractManifold, p, c, B::DefaultOrthogonalBasis) + return get_vector_orthogonal(M, p, c, number_system(B)) +end +function get_vector_orthogonal(M::AbstractManifold, p, c, N) + return get_vector_orthonormal(M, p, c, N) +end + +function _get_vector(M::AbstractManifold, p, c, B::DefaultOrthonormalBasis) + return get_vector_orthonormal(M, p, c, number_system(B)) +end +function get_vector_orthonormal(M::AbstractManifold, p, c, N) + B = DefaultOrthonormalBasis(N) + Y = allocate_result(M, get_vector, p, c) + return get_vector!(M, Y, p, c, B) +end + +function _get_vector(M::AbstractManifold, p, c, B::DiagonalizingOrthonormalBasis) + return get_vector_diagonalizing(M, p, c, B) +end +function get_vector_diagonalizing( + M::AbstractManifold, p, - X, - B::AbstractBasis, + c, + B::DiagonalizingOrthonormalBasis, ) -for BT in [DISAMBIGUATION_BASIS_TYPES..., DISAMBIGUATION_COTANGENT_BASIS_TYPES...] - eval( - quote - @decorator_transparent_signature get_vector!( - M::AbstractDecoratorManifold, - Y, - p, - X, - B::$BT, - ) - end, - ) + Y = allocate_result(M, get_vector, p, c) + return get_vector!(M, Y, p, c, B) end +function _get_vector(M::AbstractManifold, p, c, B::CachedBasis) + return get_vector_cached(M, p, c, B) +end _get_vector_cache_broadcast(::Any) = Val(true) +function get_vector_cached(M::AbstractManifold, p, X, B::CachedBasis) + # quite convoluted but: + # 1) preserves the correct `eltype` + # 2) guarantees a reasonable array type `Y` + # (for example scalar * `SizedValidation` is an `SArray`) + bvectors = get_vectors(M, p, B) + if _get_vector_cache_broadcast(bvectors[1]) === Val(false) + Xt = X[1] * bvectors[1] + for i in 2:length(X) + copyto!(Xt, Xt + X[i] * bvectors[i]) + end + else + Xt = X[1] .* bvectors[1] + for i in 2:length(X) + Xt .+= X[i] .* bvectors[i] + end + end + return Xt +end +function get_vector!(M::AbstractManifold, Y, p, c, B::AbstractBasis) + return _get_vector!(M, Y, p, c, B) +end -function get_vector!(M::AbstractManifold, Y, p, X, B::VeeOrthogonalBasis) - return get_vector!(M, Y, p, X, DefaultOrthogonalBasis(number_system(B))) +function _get_vector!(M::AbstractManifold, Y, p, c, B::VeeOrthogonalBasis) + return get_vector_vee!(M, Y, p, c, number_system(B)) end -function get_vector!(M::AbstractManifold, Y, p, X, B::DefaultBasis) - return get_vector!(M, Y, p, X, DefaultOrthogonalBasis(number_system(B))) +get_vector_vee!(M, Y, p, c, N) = get_vector!(M, Y, p, c, DefaultOrthogonalBasis(N)) + +function _get_vector!(M::AbstractManifold, Y, p, c, B::DefaultBasis) + return get_vector_default!(M, Y, p, c, number_system(B)) +end +function get_vector_default!(M::AbstractManifold, Y, p, c, N) + return get_vector!(M, Y, p, c, DefaultOrthogonalBasis(N)) +end + +function _get_vector!(M::AbstractManifold, Y, p, c, B::DefaultOrthogonalBasis) + return get_vector_orthogonal!(M, Y, p, c, number_system(B)) end -function get_vector!(M::AbstractManifold, Y, p, X, B::DefaultOrthogonalBasis) - return get_vector!(M, Y, p, X, DefaultOrthonormalBasis(number_system(B))) +function get_vector_orthogonal!(M::AbstractManifold, Y, p, c, N) + return get_vector!(M, Y, p, c, DefaultOrthonormalBasis(N)) +end + +function _get_vector!(M::AbstractManifold, Y, p, c, B::DefaultOrthonormalBasis) + return get_vector_orthonormal!(M, Y, p, c, number_system(B)) +end +function get_vector_orthonormal! end + +function _get_vector!(M::AbstractManifold, Y, p, c, B::DiagonalizingOrthonormalBasis) + return get_vector_diagonalizing!(M, Y, p, c, B) end -function get_vector!(M::AbstractManifold, Y, p, X, B::CachedBasis) +function get_vector_diagonalizing! end + +function _get_vector!(M::AbstractManifold, Y, p, c, B::CachedBasis) + return get_vector_cached!(M, Y, p, c, B) +end +function get_vector_cached!(M::AbstractManifold, Y, p, X, B::CachedBasis) # quite convoluted but: # 1) preserves the correct `eltype` # 2) guarantees a reasonable array type `Y` @@ -664,20 +757,16 @@ end Get the basis vectors of basis `B` of the tangent space at point `p`. """ function get_vectors(M::AbstractManifold, p, B::AbstractBasis) - return error( - "get_vectors not implemented for manifold of type $(typeof(M)) a point of type $(typeof(p)) and basis of type $(typeof(B)).", - ) + return _get_vectors(M, p, B) end -function get_vectors(::AbstractManifold, ::Any, B::CachedBasis) +function _get_vectors(::AbstractManifold, ::Any, B::CachedBasis) return _get_vectors(B) end -#internal for directly cached basis i.e. those that are just arrays – used in show _get_vectors(B::CachedBasis{𝔽,<:AbstractBasis,<:AbstractArray}) where {𝔽} = B.data function _get_vectors(B::CachedBasis{𝔽,<:AbstractBasis,<:DiagonalizingBasisData}) where {𝔽} return B.data.vectors end - @doc raw""" gram_schmidt(M::AbstractManifold{𝔽}, p, B::AbstractBasis{𝔽}) where {𝔽} gram_schmidt(M::AbstractManifold, p, V::AbstractVector) @@ -793,17 +882,20 @@ hat(M::AbstractManifold, p, X) = get_vector(M, p, X, VeeOrthogonalBasis()) hat!(M::AbstractManifold, Y, p, X) = get_vector!(M, Y, p, X, VeeOrthogonalBasis()) """ - number_of_coordinates(M::AbstractManifold, B::AbstractBasis) + number_of_coordinates(M::AbstractManifold{𝔽}, B::AbstractBasis) + number_of_coordinates(M::AbstractManifold{𝔽}, ::𝔾) -Compute the number of coordinates in basis `B` of manifold `M`. +Compute the number of coordinates in basis of field type `𝔾` on a manifold `M`. This also corresponds to the number of vectors represented by `B`, or stored within `B` in case of a [`CachedBasis`](@ref). """ -function number_of_coordinates(M::AbstractManifold{𝔽}, B::AbstractBasis{𝔾}) where {𝔽,𝔾} - return div(manifold_dimension(M), real_dimension(𝔽)) * real_dimension(𝔾) +function number_of_coordinates(M::AbstractManifold{𝔽}, ::AbstractBasis{𝔾}) where {𝔽,𝔾} + return number_of_coordinates(M, 𝔾) end -function number_of_coordinates(M::AbstractManifold{𝔽}, B::AbstractBasis{𝔽}) where {𝔽} - return manifold_dimension(M) +function number_of_coordinates(M::AbstractManifold{𝔽}, f::𝔾) where {𝔽,𝔾} + # for odd manifolds this first case has to match. + (real_dimension(𝔽) == real_dimension(f)) && return manifold_dimension(M) + return div(manifold_dimension(M), real_dimension(𝔽)) * real_dimension(f) end """ @@ -851,7 +943,7 @@ end function show(io::IO, ::ProjectedOrthonormalBasis{method,𝔽}) where {method,𝔽} return print(io, "ProjectedOrthonormalBasis($(repr(method)), $(𝔽))") end -function show(io::IO, mime::MIME"text/plain", onb::DiagonalizingOrthonormalBasis) +function show(io::IO, ::MIME"text/plain", onb::DiagonalizingOrthonormalBasis) println( io, "DiagonalizingOrthonormalBasis($(number_system(onb))) with eigenvalue 0 in direction:", @@ -862,7 +954,7 @@ function show(io::IO, mime::MIME"text/plain", onb::DiagonalizingOrthonormalBasis end function show( io::IO, - mime::MIME"text/plain", + ::MIME"text/plain", B::CachedBasis{𝔽,T,D}, ) where {𝔽,T<:AbstractBasis,D} print( @@ -879,7 +971,7 @@ function show( end function show( io::IO, - mime::MIME"text/plain", + ::MIME"text/plain", B::CachedBasis{𝔽,T,D}, ) where {𝔽,T<:DiagonalizingOrthonormalBasis,D<:DiagonalizingBasisData} vectors = _get_vectors(B) @@ -913,31 +1005,3 @@ inverse. """ vee(M::AbstractManifold, p, X) = get_coordinates(M, p, X, VeeOrthogonalBasis()) vee!(M::AbstractManifold, Y, p, X) = get_coordinates!(M, Y, p, X, VeeOrthogonalBasis()) - -macro invoke_maker(argnum, type, sig) - parts = ManifoldsBase._split_signature(sig) - kwargs_list = parts[:kwargs_list] - callargs = parts[:callargs] - fname = parts[:fname] - where_exprs = parts[:where_exprs] - argnames = parts[:argnames] - argtypes = parts[:argtypes] - kwargs_call = parts[:kwargs_call] - - return esc( - quote - function ($fname)($(callargs...); $(kwargs_list...)) where {$(where_exprs...)} - return invoke( - $fname, - Tuple{ - $(argtypes[1:(argnum - 1)]...), - $type, - $(argtypes[(argnum + 1):end]...), - }, - $(argnames...); - $(kwargs_call...), - ) - end - end, - ) -end diff --git a/src/decorator_trait.jl b/src/decorator_trait.jl new file mode 100644 index 00000000..9111bb29 --- /dev/null +++ b/src/decorator_trait.jl @@ -0,0 +1,598 @@ +# +# Base passons +# +manifold_dimension(M::AbstractDecoratorManifold) = manifold_dimension(base_manifold(M)) + +# +# Traits - each passed to a function that is properly documented +# + +""" + IsEmbeddedManifold <: AbstractTrait + +A trait to declare an [`AbstractManifold`](@ref) as an embedded manifold. +""" +struct IsEmbeddedManifold <: AbstractTrait end + +""" + IsIsometricManifoldEmbeddedManifold <: AbstractTrait + +A Trait to determine whether an [`AbstractDecoratorManifold`](@ref) `M` is +an isometrically embedded manifold. +It is a special case of the [`IsEmbeddedManifold`](@ref) trait, i.e. it has all properties of this trait. + +Here, additionally, netric related functions like [`inner`](@ref) and [`norm`](@ref) are passed to the embedding +""" +struct IsIsometricEmbeddedManifold <: AbstractTrait end + +parent_trait(::IsIsometricEmbeddedManifold) = IsEmbeddedManifold() + +""" + IsEmbeddedSubmanifold <: AbstractTrait + +A trait to determine whether an [`AbstractDecoratorManifold`](@ref) `M` is an embedded submanifold. +It is a special case of the [`IsIsometricEmbeddedManifold`](@ref) trait, i.e. it has all properties of +this trait. + +In this trait, additionally to the isometric embedded manifold, all retractions, inverse retractions, +and vectors transports, especially [`exp`](@ref), [`log`](@ref), and [`parallel_transport_to`](@ref) +are passed to the embedding. +""" +struct IsEmbeddedSubmanifold <: AbstractTrait end + +parent_trait(::IsEmbeddedSubmanifold) = IsIsometricEmbeddedManifold() + + +# +# Generic Decorator functions +@doc raw""" + decorated_manifold(M::AbstractDecoratorManifold) + +For a manifold `M` that is decorated with some properties, this function returns +the manifold without that manifold, i.e. the manifold that _was decorated_. +""" +decorated_manifold(M::AbstractDecoratorManifold) +decorated_manifold(M::AbstractManifold) = M +@trait_function decorated_manifold(M::AbstractDecoratorManifold) + +# +# Implemented Traits +function base_manifold(M::AbstractDecoratorManifold, depth::Val{N} = Val(-1)) where {N} + # end recursion I: depth is 0 + N == 0 && return M + # end recursion II: M is equal to its decorated manifold (avoid stack overflow) + D = decorated_manifold(M) + M === D && return M + # indefinite many steps for negative values of M + N < 0 && return base_manifold(D, depth) + # reduce depth otherwise + return base_manifold(D, Val(N - 1)) +end + +# +# Embedded specifix functions. +""" + get_embedding(M::AbstractDecoratorManifold) + get_embedding(M::AbstractDecoratorManifold, p) + +Specify the embedding of a manifold that has abstract decorators. +the embedding might depend on a point representation, where different point representations +are distinguished as subtypes of [`AbstractManifoldPoint`](@ref). +A unique or default representation might also just be an `AbstractArray`. +""" +get_embedding(M::AbstractDecoratorManifold, p) = get_embedding(M) + +# +# ----------------------------------------------------------------------------------------- +# This is one new function + +# Introduction and default fallbacks could become a macro? +# Introduce trait +function allocate_result(M::AbstractDecoratorManifold, f, x...) + return allocate_result(trait(allocate_result, M, f, x...), M, f, x...) +end +# disambiguation +@invoke_maker 1 AbstractManifold allocate_result( + M::AbstractDecoratorManifold, + f::typeof(get_coordinates), + p, + X, + B::AbstractBasis, +) + +# Introduce fallback +@inline function allocate_result(::EmptyTrait, M::AbstractManifold, f, x...) + return invoke( + allocate_result, + Tuple{AbstractManifold,typeof(f),typeof(x).parameters...}, + M, + f, + x..., + ) +end +# Introduce automatic forward +@inline function allocate_result(t::TraitList, M::AbstractManifold, f, x...) + return allocate_result(next_trait(t), M, f, x...) +end +function allocate_result( + ::TraitList{IsEmbeddedManifold}, + M::AbstractDecoratorManifold, + f::typeof(embed), + x..., +) + T = allocate_result_type(get_embedding(M, x[1]), f, x) + return allocate(x[1], T, representation_size(get_embedding(M, x[1]))) +end +function allocate_result( + ::TraitList{IsEmbeddedManifold}, + M::AbstractDecoratorManifold, + f::typeof(project), + x..., +) + T = allocate_result_type(get_embedding(M, x[1]), f, x) + return allocate(x[1], T, representation_size(M)) +end + + +# Introduce Deco Trait | automatic foward | fallback +@trait_function check_size(M::AbstractDecoratorManifold, p) +# Embedded +function check_size(::TraitList{IsEmbeddedManifold}, M::AbstractDecoratorManifold, p) + return check_size(get_embedding(M, p), p) +end + +# Introduce Deco Trait | automatic foward | fallback +@trait_function check_size(M::AbstractDecoratorManifold, p, X) +# Embedded +function check_size(::TraitList{IsEmbeddedManifold}, M::AbstractDecoratorManifold, p, X) + return check_size(get_embedding(M, p), p, X) +end + +# Introduce Deco Trait | automatic foward | fallback +@trait_function embed(M::AbstractDecoratorManifold, p) +# EmbeddedManifold +function embed(::TraitList{IsEmbeddedManifold}, M::AbstractDecoratorManifold, p) + q = allocate_result(M, embed, p) + return embed!(M, q, p) +end + +# Introduce Deco Trait | automatic foward | fallback +@trait_function embed!(M::AbstractDecoratorManifold, q, p) +# EmbeddedManifold +function embed!(::TraitList{IsEmbeddedManifold}, M::AbstractDecoratorManifold, q, p) + return copyto!(M, q, p) +end + +# Introduce Deco Trait | automatic foward | fallback +@trait_function embed(M::AbstractDecoratorManifold, p, X) +# EmbeddedManifold +function embed(::TraitList{IsEmbeddedManifold}, M::AbstractDecoratorManifold, p, X) + q = allocate_result(M, embed, p, X) + return embed!(M, q, p, X) +end + +# Introduce Deco Trait | automatic foward | fallback +@trait_function embed!(M::AbstractDecoratorManifold, Y, p, X) +# EmbeddedManifold +function embed!(::TraitList{IsEmbeddedManifold}, M::AbstractDecoratorManifold, Y, p, X) + return copyto!(M, Y, p, X) +end + +# Introduce Deco Trait | automatic foward | fallback +@trait_function exp(M::AbstractDecoratorManifold, p, X) +# EmbeddedSubManifold +function exp(::TraitList{IsEmbeddedSubmanifold}, M::AbstractDecoratorManifold, p, X) + return exp(get_embedding(M, p), p, X) +end + +# Introduce Deco Trait | automatic foward | fallback +@trait_function exp!(M::AbstractDecoratorManifold, q, p, X) +# EmbeddedSubManifold +function exp!(::TraitList{IsEmbeddedSubmanifold}, M::AbstractDecoratorManifold, q, p, X) + return exp!(get_embedding(M, p), q, p, X) +end + +# Introduce Deco Trait | automatic foward | fallback +@trait_function get_basis(M::AbstractDecoratorManifold, p, B::AbstractBasis) + +# Introduce Deco Trait | automatic foward | fallback +@trait_function get_coordinates(M::AbstractDecoratorManifold, p, X, B::AbstractBasis) + +# Introduce Deco Trait | automatic foward | fallback +@trait_function get_coordinates!(M::AbstractDecoratorManifold, Y, p, X, B::AbstractBasis) + +# Introduce Deco Trait | automatic foward | fallback +@trait_function get_vector(M::AbstractDecoratorManifold, p, c, B::AbstractBasis) + +# Introduce Deco Trait | automatic foward | fallback +@trait_function get_vector!(M::AbstractDecoratorManifold, Y, p, c, B::AbstractBasis) + +# Introduce Deco Trait | automatic foward | fallback +@trait_function inner(M::AbstractDecoratorManifold, p, X, Y) +# Isometric Embedded submanifold +function inner( + ::TraitList{IsIsometricEmbeddedManifold}, + M::AbstractDecoratorManifold, + p, + X, + Y, +) + return inner(get_embedding(M, p), p, X, Y) +end + +# Introduce Deco Trait | automatic foward | fallback +@trait_function inverse_retract( + M::AbstractDecoratorManifold, + p, + q, + m::AbstractInverseRetractionMethod = default_inverse_retraction_method(M), +) +# Transparent for Submanifolds +function inverse_retract( + ::TraitList{IsEmbeddedSubmanifold}, + M::AbstractDecoratorManifold, + p, + q, + m::AbstractInverseRetractionMethod = default_inverse_retraction_method(M), +) + return inverse_retract(get_embedding(M, p), p, q, m) +end + +# Introduce Deco Trait | automatic foward | fallback +@trait_function inverse_retract!(M::AbstractDecoratorManifold, X, p, q) +function inverse_retract!( + ::TraitList{IsEmbeddedSubmanifold}, + M::AbstractDecoratorManifold, + X, + p, + q, + m::AbstractInverseRetractionMethod = default_inverse_retraction_method(M), +) + return inverse_retract!(get_embedding(M, p), X, p, q, m) +end + +# Introduce Deco Trait | automatic foward | fallback +@trait_function is_point(M::AbstractDecoratorManifold, p, te = false; kwargs...) +# Embedded +function is_point( + ::TraitList{IsEmbeddedManifold}, + M::AbstractDecoratorManifold, + p, + te = false; + kwargs..., +) + # to be safe check_size first + es = check_size(M, p) + if es !== nothing + te && throw(es) + return false + end + # this throws if te=true + ep = is_point(get_embedding(M, p), embed(M, p), te; kwargs...) + !ep && return false # otherwise if we get here with ep=false, end with false + mpe = check_point(M, p; kwargs...) + mpe === nothing && return true + te && throw(mpe) + return false +end + +# Introduce Deco Trait | automatic foward | fallback +@trait_function is_vector( + M::AbstractDecoratorManifold, + p, + X, + te = false, + cbp = true; + kwargs..., +) +# EmbeddedManifold +# I am not yet sure how to properly document this embedding behaviour here in a docstring. +function is_vector( + ::TraitList{IsEmbeddedManifold}, + M::AbstractDecoratorManifold, + p, + X, + te = false, + cbp = true; + kwargs..., +) + if cbp + # check whether p is valid before embedding the tangent vector + # throws it te=true + ep = is_point(M, p, te; kwargs...) + !ep && return false + end + # now that we know p is valid, check size of X + es = check_size(M, p, X) + if es !== nothing + te && throw(es) # error & throw? + return false + end + # Check vector in embedding + ev = is_vector(get_embedding(M, p), embed(M, p), embed(M, p, X), te, cbp; kwargs...) + (!ev && !te) && return false # if te, the line before throws an error, otherwise we end with false early here + # Check (additional) local stuff + mtve = check_vector(M, p, X; kwargs...) + mtve === nothing && return true + te && throw(mtve) + return false +end + +@trait_function norm(M::AbstractDecoratorManifold, p, X) +function norm(::TraitList{IsIsometricEmbeddedManifold}, M::AbstractDecoratorManifold, p, X) + return norm(get_embedding(M, p), p, X) +end + +@trait_function log(M::AbstractDecoratorManifold, p, q) +function log(::TraitList{IsEmbeddedSubmanifold}, M::AbstractDecoratorManifold, p, q) + return log(get_embedding(M, p), p, q) +end + +# Introduce Deco Trait | automatic foward | fallback +@trait_function log!(M::AbstractDecoratorManifold, X, p, q) +function log!(::TraitList{IsEmbeddedSubmanifold}, M::AbstractDecoratorManifold, X, p, q) + return log!(get_embedding(M, p), X, p, q) +end + +# Introduce Deco Trait | automatic foward | fallback +@trait_function parallel_transport_along(M::AbstractDecoratorManifold, p, X, c) +# EmbeddedSubManifold +function parallel_transport_along( + ::TraitList{IsEmbeddedSubmanifold}, + M::AbstractDecoratorManifold, + p, + X, + c, +) + return parallel_transport_along(get_embedding(M, p), p, X, c) +end + +# Introduce Deco Trait | automatic foward | fallback +@trait_function parallel_transport_along!(M::AbstractDecoratorManifold, Y, p, X, c) +# EmbeddedSubManifold +function parallel_transport_along!( + ::TraitList{IsEmbeddedSubmanifold}, + M::AbstractDecoratorManifold, + Y, + p, + X, + c, +) + return parallel_transport_along!(get_embedding(M, p), Y, p, X, c) +end + +# Introduce Deco Trait | automatic foward | fallback +@trait_function parallel_transport_direction(M::AbstractDecoratorManifold, p, X, q) +# EmbeddedSubManifold +function parallel_transport_direction( + ::TraitList{IsEmbeddedSubmanifold}, + M::AbstractDecoratorManifold, + p, + X, + q, +) + return parallel_transport_direction(get_embedding(M, p), p, X, q) +end + +# Introduce Deco Trait | automatic foward | fallback +@trait_function parallel_transport_direction!(M::AbstractDecoratorManifold, Y, p, X, q) +# EmbeddedSubManifold +function parallel_transport_direction!( + ::TraitList{IsEmbeddedSubmanifold}, + M::AbstractDecoratorManifold, + Y, + p, + X, + q, +) + return parallel_transport_direction!(get_embedding(M, p), Y, p, X, q) +end + +# Introduce Deco Trait | automatic foward | fallback +@trait_function parallel_transport_to(M::AbstractDecoratorManifold, p, X, q) +# EmbeddedSubManifold +function parallel_transport_to( + ::TraitList{IsEmbeddedSubmanifold}, + M::AbstractDecoratorManifold, + p, + X, + q, +) + return parallel_transport_to(get_embedding(M, p), p, X, q) +end + +# Introduce Deco Trait | automatic foward | fallback +@trait_function parallel_transport_to!(M::AbstractDecoratorManifold, Y, p, X, q) +# EmbeddedSubManifold +function parallel_transport_to!( + ::TraitList{IsEmbeddedSubmanifold}, + M::AbstractDecoratorManifold, + Y, + p, + X, + q, +) + return parallel_transport_to!(get_embedding(M, p), Y, p, X, q) +end + +# Introduce Deco Trait | automatic foward | fallback +@trait_function project(M::AbstractDecoratorManifold, p) + +# Introduce Deco Trait | automatic foward | fallback +@trait_function project!(M::AbstractDecoratorManifold, q, p) + +# Introduce Deco Trait | automatic foward | fallback +@trait_function project(M::AbstractDecoratorManifold, p, X) + +# Introduce Deco Trait | automatic foward | fallback +@trait_function project!(M::AbstractDecoratorManifold, Y, p, X) + +# Introduce Deco Trait | automatic foward | fallback +@trait_function representation_size(M::AbstractDecoratorManifold) (no_empty,) +# Isometric Embedded submanifold +function representation_size(::TraitList{IsEmbeddedManifold}, M::AbstractDecoratorManifold) + return representation_size(get_embedding(M)) +end +function representation_size(::EmptyTrait, M::AbstractDecoratorManifold) + return representation_size(decorated_manifold(M)) +end + + +# Introduce Deco Trait | automatic foward | fallback +@trait_function retract( + M::AbstractDecoratorManifold, + p, + X, + m::AbstractRetractionMethod = default_retraction_method(M), +) +function retract( + ::TraitList{IsEmbeddedSubmanifold}, + M::AbstractDecoratorManifold, + p, + X, + m::AbstractRetractionMethod = default_retraction_method(M), +) + return retract(get_embedding(M, p), p, X, m) +end + +@trait_function retract!( + M::AbstractDecoratorManifold, + q, + p, + X, + m::AbstractRetractionMethod = default_retraction_method(M), +) +function retract!( + ::TraitList{IsEmbeddedSubmanifold}, + M::AbstractDecoratorManifold, + q, + p, + X, + m::AbstractRetractionMethod = default_retraction_method(M), +) + return retract!(get_embedding(M, p), q, p, X, m) +end + +@trait_function vector_transport_along( + M::AbstractDecoratorManifold, + q, + p, + X, + m::AbstractVectorTransportMethod = default_vector_transport_method(M), +) +function vector_transport_along( + ::TraitList{IsEmbeddedSubmanifold}, + M::AbstractDecoratorManifold, + p, + X, + c, + m::AbstractVectorTransportMethod = default_vector_transport_method(M), +) + return vector_transport_along(get_embedding(M, p), p, X, c, m) +end + +@trait_function vector_transport_along!( + M::AbstractDecoratorManifold, + Y, + p, + X, + c::AbstractVector, + m::AbstractVectorTransportMethod = default_vector_transport_method(M), +) +function vector_transport_along!( + ::TraitList{IsEmbeddedSubmanifold}, + M::AbstractDecoratorManifold, + Y, + p, + X, + c::AbstractVector, + m::AbstractVectorTransportMethod = default_vector_transport_method(M), +) + return vector_transport_along!(get_embedding(M, p), Y, p, X, c, m) +end + +@trait_function vector_transport_direction( + M::AbstractDecoratorManifold, + p, + X, + d, + m::AbstractVectorTransportMethod = default_vector_transport_method(M), +) +function vector_transport_direction( + ::TraitList{IsEmbeddedSubmanifold}, + M::AbstractDecoratorManifold, + p, + X, + d, + m::AbstractVectorTransportMethod = default_vector_transport_method(M), +) + return vector_transport_direction(get_embedding(M, p), p, X, d, m) +end + +@trait_function vector_transport_direction!( + M::AbstractDecoratorManifold, + Y, + p, + X, + d, + m::AbstractVectorTransportMethod = default_vector_transport_method(M), +) +function vector_transport_direction!( + ::TraitList{IsEmbeddedSubmanifold}, + M::AbstractDecoratorManifold, + Y, + p, + X, + d, + m::AbstractVectorTransportMethod = default_vector_transport_method(M), +) + return vector_transport_direction!(get_embedding(M, p), Y, p, X, d, m) +end + +@trait_function vector_transport_to( + M::AbstractDecoratorManifold, + p, + X, + q, + m::AbstractVectorTransportMethod = default_vector_transport_method(M), +) +function vector_transport_to( + ::TraitList{IsEmbeddedSubmanifold}, + M::AbstractDecoratorManifold, + p, + X, + q, + m::AbstractVectorTransportMethod = default_vector_transport_method(M), +) + return vector_transport_to(get_embedding(M, p), p, X, q, m) +end + +@trait_function vector_transport_to!( + M::AbstractDecoratorManifold, + Y, + p, + X, + q, + m::AbstractVectorTransportMethod = default_vector_transport_method(M), +) +function vector_transport_to!( + ::TraitList{IsEmbeddedSubmanifold}, + M::AbstractDecoratorManifold, + Y, + p, + X, + q, + m::AbstractVectorTransportMethod = default_vector_transport_method(M), +) + return vector_transport_to!(get_embedding(M, p), Y, p, X, q, m) +end + +@trait_function zero_vector(M::AbstractDecoratorManifold, p) +function zero_vector(::TraitList{IsEmbeddedManifold}, M::AbstractDecoratorManifold, p) + return zero_vector(get_embedding(M, p), p) +end + +@trait_function zero_vector!(M::AbstractDecoratorManifold, X, p) +function zero_vector!(::TraitList{IsEmbeddedManifold}, M::AbstractDecoratorManifold, X, p) + return zero_vector!(get_embedding(M, p), X, p) +end diff --git a/src/exp_log_geo.jl b/src/exp_log_geo.jl index 809f49b9..a6248fbd 100644 --- a/src/exp_log_geo.jl +++ b/src/exp_log_geo.jl @@ -29,9 +29,7 @@ The result is saved to `q`. See also [`exp`](@ref). """ -function exp!(M::AbstractManifold, q, p, X) - return error(manifold_function_not_implemented_message(M, exp!, q, p, X)) -end +exp!(M::AbstractManifold, q, p, X) exp!(M::AbstractManifold, q, p, X, t::Real) = exp!(M, q, p, t * X) @doc raw""" @@ -84,9 +82,7 @@ Note that the logarithmic map might not be globally defined. see also [`log`](@ref) and [`inverse_retract!`](@ref), """ -function log!(M::AbstractManifold, X, p, q) - return error(manifold_function_not_implemented_message(M, log!, X, p, q)) -end +log!(M::AbstractManifold, X, p, q) @doc raw""" shortest_geodesic(M::AbstractManifold, p, q) -> Function diff --git a/src/maintypes.jl b/src/maintypes.jl index 2e37a56b..233c7f9c 100644 --- a/src/maintypes.jl +++ b/src/maintypes.jl @@ -12,9 +12,6 @@ real (ℝ) and complex (ℂ) manifolds. For subtypes the preferred order of parameters is: size and simple value parameters, followed by the [`AbstractNumbers`](@ref) `field`, followed by data type parameters, which might depend on the abstract number field type. - -For more details see [interface-types-and-functions](@ref) in the ManifoldsBase.jl documentation at -[https://juliamanifolds.github.io/Manifolds.jl/stable/interface.html#Types-and-functions](https://juliamanifolds.github.io/Manifolds.jl/stable/interface.html#Types-and-functions). """ abstract type AbstractManifold{𝔽} end @@ -24,6 +21,7 @@ abstract type AbstractManifold{𝔽} end Type for a point on a manifold. While a [`AbstractManifold`](@ref) does not necessarily require this type, for example when it is implemented for `Vector`s or `Matrix` type elements, this type can be used either + * for more complicated representations, * semantic verification, or * even dispatch for different representations of points on a manifold. diff --git a/src/nested_trait.jl b/src/nested_trait.jl new file mode 100644 index 00000000..2adb9735 --- /dev/null +++ b/src/nested_trait.jl @@ -0,0 +1,315 @@ +@doc raw""" + AbstractDecoratorManifold{𝔽} <: AbstractManifold{𝔽} + +Declare a manifold to be an abstract decorator. +A manifold which is a subtype of is a __decorated manifold__, i.e. has + +* certain additional properties or +* delegates certain properties to other manifolds. + +Most prominently, a manifold might be an embedded manifold, i.e. points on a manifold ``\mathcal M`` +are represented by (some, maybe not all) points on another manifold ``\mathcal N``. +Depending on the type of embedding, several functions are dedicated to the embedding. +For example if the embedding is isometric, then the [`inner`](@ref) does not have to be +implemented for ``\mathcal M`` but can be automatically implemented by deligation to ``\mathcal N``. + +This is modelled by the `AbstractDecoratorManifold` and traits. These are mapped to functions, +which determine the types of transparencies. +""" +abstract type AbstractDecoratorManifold{𝔽} <: AbstractManifold{𝔽} end + +""" + AbstractTrait + +An abstract trait type to build a sequence of traits +""" +abstract type AbstractTrait end + +""" + EmptyTrait <: AbstractTrait + +A Trait indicating that no feature is present. +""" +struct EmptyTrait <: AbstractTrait end + +""" + IsExplicitDecorator <: AbstractTrait + +Specify that a certain type should dispatch the function to the field of name `.fieldname` +of the decorator. This means the decorator is stated to explicitly decorate its `.fieldname`s +type. + +!!! note +Any decorator _behind_ this decorator might not have any effect, since the function dispatch +is moved to its field at this point. Therefore this decorator should always be _last_ in the +[`TraitList`](@ref). +""" +struct IsExplicitDecorator <: AbstractTrait end + +""" + TraitList <: AbstractTrait + +Combine two traits into a combined trait. Note that this introduces a preceedence. +the first of the traits takes preceedence if a trait is implemented for both functions. + +# Constructor + + TraitList(head::AbstractTrait, tail::AbstractTrait) +""" +struct TraitList{T1<:AbstractTrait,T2<:AbstractTrait} <: AbstractTrait + head::T1 + tail::T2 +end + +function Base.show(io::IO, t::TraitList) + return print(io, "TraitList(", t.head, ", ", t.tail, ")") +end + +""" + active_traits(f, args...) + +Return the list of traits applicable to the given call of function `f``. This function should be +overloaded for specific function calls. +""" +@inline active_traits(f, args...) = EmptyTrait() + +""" + merge_traits(t1, t2, trest...) + +Merge two traits into a nested list of traits. Note that this takes trait preceedence into account, +i.e. `t1` takes preceedence over `t2` is any operations. +It always returns either ab [`EmptyTrait`](@ref) or a [`TraitList`](@ref). + +This means that for +* one argument it just returns the trait itself if it is list-like, or wraps the trait in a + single-element list otherwise, +* two arguments that are list-like, it merges them, +* two arguments of which only the first one is list-like and the second one is not, + it appends the second argument to the list, +* two arguments of which only the second one is list-like, it prepends the first one to + the list, +* two arguments of which none is list-like, it creates a two-element list. +* more than two arguments it recursively performs a left-assiciative recursive reduction + on arguments, that is for example `merge_traits(t1, t2, t3)` is equivalent to + `merge_traits(merge_traits(t1, t2), t3)` +""" +merge_traits() + +@inline merge_traits() = EmptyTrait() +@inline merge_traits(t::EmptyTrait) = t +@inline merge_traits(t::TraitList) = t +@inline merge_traits(t::AbstractTrait) = TraitList(t, EmptyTrait()) +@inline merge_traits(t1::EmptyTrait, ::EmptyTrait) = t1 +@inline merge_traits(::EmptyTrait, t2::AbstractTrait) = merge_traits(t2) +@inline merge_traits(t1::AbstractTrait, t2::EmptyTrait) = TraitList(t1, t2) +@inline merge_traits(t1::AbstractTrait, t2::TraitList) = TraitList(t1, t2) +@inline merge_traits(::EmptyTrait, t2::TraitList) = t2 +@inline function merge_traits(t1::AbstractTrait, t2::AbstractTrait) + return TraitList(t1, TraitList(t2, EmptyTrait())) +end +@inline merge_traits(t1::TraitList, ::EmptyTrait) = t1 +@inline function merge_traits(t1::TraitList, t2::AbstractTrait) + return TraitList(t1.head, merge_traits(t1.tail, t2)) +end +@inline function merge_traits(t1::TraitList, t2::TraitList) + return TraitList(t1.head, merge_traits(t1.tail, t2)) +end +@inline function merge_traits( + t1::AbstractTrait, + t2::AbstractTrait, + t3::AbstractTrait, + trest::AbstractTrait..., +) + return merge_traits(merge_traits(t1, t2), t3, trest...) +end + +""" + parent_trait(t::AbstractTrait) + +Return the parent trait for trait `t`, that is the more general trait whose behaviour it +inherits as a fallback. +""" +@inline parent_trait(::AbstractTrait) = EmptyTrait() + +@inline function trait(f::TF, args...) where {TF} + bt = active_traits(f, args...) + return expand_trait(bt) +end + +""" + expand_trait(t::AbstractTrait) + +Expand given trait into an ordered [`TraitList`](@ref) list of traits with their parent +traits obtained using [`parent_trait`](@ref). +""" +expand_trait(::AbstractTrait) + +@inline expand_trait(e::EmptyTrait) = e +@inline expand_trait(t::AbstractTrait) = merge_traits(t, expand_trait(parent_trait(t))) +@inline function expand_trait(t::TraitList) + et1 = expand_trait(t.head) + et2 = expand_trait(t.tail) + return merge_traits(et1, et2) +end + +""" + next_trait(t::AbstractTrait) + +Return the next trait to be considered after `t`. +""" +next_trait(t::AbstractTrait) + +@inline next_trait(t::TraitList) = t.tail + +#! format: off +# turn formatting for for the following functions +# due to the if with returns inside (formatter puts a return upfront the if) +function _split_signature(sig::Expr) + if sig.head == :where + where_exprs = sig.args[2:end] + call_expr = sig.args[1] + elseif sig.head == :call + where_exprs = [] + call_expr = sig + else + error("Incorrect syntax in $sig. Expected a :where or :call expression.") + end + fname = call_expr.args[1] + if isa(call_expr.args[2], Expr) && call_expr.args[2].head == :parameters + # we have keyword arguments + callargs = call_expr.args[3:end] + kwargs_list = call_expr.args[2].args + else + callargs = call_expr.args[2:end] + kwargs_list = [] + end + argnames = map(callargs) do arg + if isa(arg, Expr) && arg.head === :kw # default val present + arg = arg.args[1] + end + if isa(arg, Expr) && arg.head === :(::) # typed + return arg.args[1] + end + return arg + end + argtypes = map(callargs) do arg + if isa(arg, Expr) && arg.head === :kw # default val present + arg = arg.args[1] + end + if isa(arg, Expr) + return arg.args[2] + else + return Any + end + end + + kwargs_call = map(kwargs_list) do kwarg + if kwarg.head === :... + return kwarg + else + if isa(kwarg.args[1], Symbol) + kwargname = kwarg.args[1] + else + kwargname = kwarg.args[1].args[1] + end + return :($kwargname = $kwargname) + end + end + + return (; + fname = fname, + where_exprs = where_exprs, + callargs = callargs, + kwargs_list = kwargs_list, + argnames = argnames, + argtypes = argtypes, + kwargs_call = kwargs_call, + ) +end +#! format: on + +macro invoke_maker(argnum, type, sig) + parts = ManifoldsBase._split_signature(sig) + kwargs_list = parts[:kwargs_list] + callargs = parts[:callargs] + fname = parts[:fname] + where_exprs = parts[:where_exprs] + argnames = parts[:argnames] + argtypes = parts[:argtypes] + kwargs_call = parts[:kwargs_call] + + return esc( + quote + function ($fname)($(callargs...); $(kwargs_list...)) where {$(where_exprs...)} + return invoke( + $fname, + Tuple{ + $(argtypes[1:(argnum - 1)]...), + $type, + $(argtypes[(argnum + 1):end]...), + }, + $(argnames...); + $(kwargs_call...), + ) + end + end, + ) +end + + +macro trait_function(sig, opts = :()) + parts = ManifoldsBase._split_signature(sig) + kwargs_list = parts[:kwargs_list] + callargs = parts[:callargs] + fname = parts[:fname] + where_exprs = parts[:where_exprs] + argnames = parts[:argnames] + argtypes = parts[:argtypes] + kwargs_call = parts[:kwargs_call] + + block = quote + function ($fname)($(callargs...); $(kwargs_list...)) where {$(where_exprs...)} + return ($fname)(trait($fname, $(argnames...)), $(argnames...); $(kwargs_call...)) + end + function ($fname)( + t::TraitList, + $(callargs...); + $(kwargs_list...), + ) where {$(where_exprs...)} + return ($fname)(next_trait(t), $(argnames...); $(kwargs_call...)) + end + function ($fname)( + t::TraitList{IsExplicitDecorator}, + $(callargs...); + $(kwargs_list...), + ) where {$(where_exprs...)} + arg1 = decorated_manifold($(argnames[1])) + argt1 = typeof(arg1) + return invoke( + $fname, + Tuple{argt1,$(argtypes[2:end]...)}, + arg1, + $(argnames[2:end]...); + $(kwargs_call...), + ) + end + end + if !(:no_empty in opts.args) + block = quote + $block + function ($fname)( + ::EmptyTrait, + $(callargs...); + $(kwargs_list...), + ) where {$(where_exprs...)} + return invoke( + $fname, + Tuple{supertype($(argtypes[1])),$(argtypes[2:end]...)}, + $(argnames...); + $(kwargs_call...), + ) + end + end + end + return esc(block) +end diff --git a/src/parallel_transport.jl b/src/parallel_transport.jl new file mode 100644 index 00000000..eea7e6db --- /dev/null +++ b/src/parallel_transport.jl @@ -0,0 +1,52 @@ +function parallel_transport_along! end + +@doc raw""" + Y = parallel_transport_along(M::AbstractManifold, p, X, c) + +Compute the parallel transport of the vector `X` from the tangent space at `p` +along the curve `c`. + +To be precise let ``c(t)`` be a curve ``c(0)=p`` for [`vector_transport_along`](@ref) ``\mathcal P^cY`` + +THen In the result ``Y\in T_p\mathcal M`` is the vector ``X`` from the tangent space at ``p=c(0)`` +to the tangent space at ``c(1)``. + +Let ``Z\colon [0,1] \to T\mathcal M``, ``Z(t)\in T_{c(t)}\mathcal M`` be a smooth vector field +along the curve ``c`` with ``Z(0) = Y``, such that ``Z`` is _parallel_, i.e. +its covariant derivative ``\frac{\mathrm{D}}{\mathrm{d}t}Z`` is zero. Note that such a ``Z`` always exists and is unique. + +Then the parallel transport is given by ``Z(1)``. +""" +function parallel_transport_along(M::AbstractManifold, p, X, c) + Y = allocate_result(M, vector_transport_along, X, p) + return parallel_transport_along!(M, Y, p, X, c) +end + +function parallel_transport_direction!(M::AbstractManifold, Y, p, X, d) + return parallel_transport_to!(M, Y, p, X, exp(M, p, d)) +end + +@doc raw""" + parallel_transport_direction(M::AbstractManifold, p, X, d) + +Compute the [`parallel_transport_along`](@ref) the curve ``c(t) = γ_{p,q}(t)``, +i.e. the * the unique geodesic ``c(t)=γ_{p,X}(t)`` from ``γ_{p,d}(0)=p`` into direction ``\dot γ_{p,d}(0)=d``, of the tangent vector `X`. + +By default this function calls [`parallel_transport_to`](@ref)`(M, p, X, q)`, where ``q=\exp_pX``. +""" +function parallel_transport_direction(M::AbstractManifold, p, X, d) + return parallel_transport_to(M, p, X, exp(M, p, d)) +end + +function parallel_transport_to! end + +@doc raw""" + parallel_transport_to(M::AbstractManifold, p, X, q) + +Compute the [`parallel_transport_along`](@ref) the curve ``c(t) = γ_{p,q}(t)``, +i.e. the (assumed to be unique) [`geodesic`](@ref) connecting `p` and `q`, of the tangent vector `X`. +""" +function parallel_transport_to(M::AbstractManifold, p, X, q) + Y = allocate_result(M, vector_transport_to, X, p, q) + return parallel_transport_to!(M, Y, p, X, q) +end diff --git a/src/point_vector_fallbacks.jl b/src/point_vector_fallbacks.jl index 9f5a3518..edc95f07 100644 --- a/src/point_vector_fallbacks.jl +++ b/src/point_vector_fallbacks.jl @@ -13,11 +13,12 @@ List of forwarded functions: * [`copyto!`](@ref), * [`number_eltype`](@ref) (only for values, not the type itself), * `similar`, +* `size`, * `==`. """ macro manifold_element_forwards(T, field::Symbol) return esc(quote - @manifold_element_forwards ($T) _ ($field) + ManifoldsBase.@manifold_element_forwards ($T) _ ($field) end) end macro manifold_element_forwards(T, Twhere, field::Symbol) @@ -27,7 +28,11 @@ macro manifold_element_forwards(T, Twhere, field::Symbol) function ManifoldsBase.allocate(p::$T, ::Type{P}) where {P,$Twhere} return $T(allocate(p.$field, P)) end - function allocate(p::$T, ::Type{P}, dims::Tuple) where {P,$Twhere} + function ManifoldsBase.allocate( + p::$T, + ::Type{P}, + dims::Tuple, + ) where {P,$Twhere} return $T(allocate(p.$field, P, dims)) end @@ -45,6 +50,8 @@ macro manifold_element_forwards(T, Twhere, field::Symbol) Base.similar(p::$T) where {$Twhere} = $T(similar(p.$field)) Base.similar(p::$T, ::Type{P}) where {P,$Twhere} = $T(similar(p.$field, P)) + Base.size(p::$T) where {$Twhere} = size(p.$field) + Base.:(==)(p::$T, q::$T) where {$Twhere} = (p.$field == q.$field) end, ) @@ -60,6 +67,19 @@ points of type `TP`, tangent vectors of type `TV`, with forwarding to fields `pf """ macro default_manifold_fallbacks(TM, TP, TV, pfield::Symbol, vfield::Symbol) block = quote + function ManifoldsBase.allocate_result(::$TM, ::typeof(log), p::$TP, ::$TP) + a = allocate(p.$vfield) + return $TV(a) + end + function ManifoldsBase.allocate_result( + ::$TM, + ::typeof(inverse_retract), + p::$TP, + ::$TP, + ) + a = allocate(p.$vfield) + return $TV(a) + end function ManifoldsBase.allocate_coordinates(M::$TM, p::$TP, T, n::Int) return ManifoldsBase.allocate_coordinates(M, p.$pfield, T, n) end @@ -73,7 +93,7 @@ macro default_manifold_fallbacks(TM, TP, TV, pfield::Symbol, vfield::Symbol) end function ManifoldsBase.check_vector(M::$TM, p::$TP, X::$TV; kwargs...) - return check_vector(M, p.$pfield, X.$vfield; kwargs...) + return ManifoldsBase.check_vector(M, p.$pfield, X.$vfield; kwargs...) end function ManifoldsBase.distance(M::$TM, p::$TP, q::$TP) @@ -81,11 +101,13 @@ macro default_manifold_fallbacks(TM, TP, TV, pfield::Symbol, vfield::Symbol) end function ManifoldsBase.embed!(M::$TM, q::$TP, p::$TP) - return embed!(M, q.$pfield, p.$pfield) + embed!(M, q.$pfield, p.$pfield) + return q end function ManifoldsBase.embed!(M::$TM, Y::$TV, p::$TP, X::$TV) - return embed!(M, Y.$vfield, p.$pfield, X.$vfield) + embed!(M, Y.$vfield, p.$pfield, X.$vfield) + return Y end function ManifoldsBase.exp!(M::$TM, q::$TP, p::$TP, X::$TV) @@ -116,20 +138,6 @@ macro default_manifold_fallbacks(TM, TP, TV, pfield::Symbol, vfield::Symbol) return isapprox(M, p.$pfield, X.$vfield, Y.$vfield; kwargs...) end - function ManifoldsBase.allocate_result(::$TM, ::typeof(log), p::$TP, ::$TP) - a = allocate(p.$vfield) - return $TV(a) - end - function ManifoldsBase.allocate_result( - ::$TM, - ::typeof(inverse_retract), - p::$TP, - ::$TP, - ) - a = allocate(p.$vfield) - return $TV(a) - end - function ManifoldsBase.log!(M::$TM, X::$TV, p::$TP, q::$TP) log!(M, X.$vfield, p.$pfield, q.$pfield) return X @@ -164,66 +172,235 @@ macro default_manifold_fallbacks(TM, TP, TV, pfield::Symbol, vfield::Symbol) return X end end - - for BT in [ - ManifoldsBase.DISAMBIGUATION_BASIS_TYPES..., - ManifoldsBase.DISAMBIGUATION_COTANGENT_BASIS_TYPES..., - ] + for f_postfix in [:default, :orthogonal, :orthonormal, :vee, :cached, :diagonalizing] + ca = Symbol("get_coordinates_$(f_postfix)") + cm = Symbol("get_coordinates_$(f_postfix)!") + va = Symbol("get_vector_$(f_postfix)") + vm = Symbol("get_vector_$(f_postfix)!") + push!(block.args, quote + function ManifoldsBase.$ca(M::$TM, p::$TP, X::$TV, B) + return ManifoldsBase.$ca(M, p.$pfield, X.$vfield, B) + end + function ManifoldsBase.$cm(M::$TM, Y, p::$TP, X::$TV, B) + ManifoldsBase.$cm(M, Y, p.$pfield, X.$vfield, B) + return Y + end + function ManifoldsBase.$va(M::$TM, p::$TP, X, B) + return $TV(ManifoldsBase.$va(M, p.$pfield, X, B)) + end + function ManifoldsBase.$vm(M::$TM, Y::$TV, p::$TP, X, B) + ManifoldsBase.$vm(M, Y.$vfield, p.$pfield, X, B) + return Y + end + end) + end + # TODO forward retraction / inverse_retraction + for f_postfix in [:polar, :project, :qr, :softmax] + ra = Symbol("retract_$(f_postfix)") + rm = Symbol("retract_$(f_postfix)!") + push!(block.args, quote + function ManifoldsBase.$ra(M::$TM, p::$TP, X::$TV) + return $TP(ManifoldsBase.$ra(M, p.$pfield, X.$vfield)) + end + function ManifoldsBase.$rm(M::$TM, q, p::$TP, X::$TV) + ManifoldsBase.$rm(M, q.$pfield, p.$pfield, X.$vfield) + return q + end + end) + end + push!( + block.args, + quote + function ManifoldsBase.retract_exp_ode(M::$TM, p::$TP, X::$TV, m, B) + return $TP(ManifoldsBase.retract_exp_ode(M, p.$pfield, X.$vfield, m, B)) + end + function ManifoldsBase.retract_exp_ode!(M::$TM, q::$TP, p::$TP, X::$TV, m, B) + ManifoldsBase.retract_exp_ode!(M, q.$pfield, p.$pfield, X.$vfield, m, B) + return q + end + function ManifoldsBase.retract_pade(M::$TM, p::$TP, X::$TV, n) + return $TP(ManifoldsBase.retract_pade(M, p.$pfield, X.$vfield, n)) + end + function ManifoldsBase.retract_pade!(M::$TM, q::$TP, p::$TP, X::$TV, n) + ManifoldsBase.retract_pade!(M, q.$pfield, p.$pfield, X.$vfield, n) + return q + end + function ManifoldsBase.retract_embedded(M::$TM, p::$TP, X::$TV, m) + return $TP(ManifoldsBase.retract_embedded(M, p.$pfield, X.$vfield, m)) + end + function ManifoldsBase.retract_embedded!(M::$TM, q::$TP, p::$TP, X::$TV, m) + ManifoldsBase.retract_embedded!(M, q.$pfield, p.$pfield, X.$vfield, m) + return q + end + end, + ) + for f_postfix in [:polar, :project, :qr, :softmax] + ra = Symbol("inverse_retract_$(f_postfix)") + rm = Symbol("inverse_retract_$(f_postfix)!") + push!(block.args, quote + function ManifoldsBase.$ra(M::$TM, p::$TP, q::$TP) + return $TV((ManifoldsBase.$ra)(M, p.$pfield, q.$pfield)) + end + function ManifoldsBase.$rm(M::$TM, Y::$TV, p::$TP, q::$TP) + ManifoldsBase.$rm(M, Y.$vfield, p.$pfield, q.$pfield) + return Y + end + end) + end + push!( + block.args, + quote + function ManifoldsBase.inverse_retract_embedded(M::$TM, p::$TP, q::$TP, m) + return $TV( + ManifoldsBase.inverse_retract_embedded(M, p.$pfield, q.$pfield, m), + ) + end + function ManifoldsBase.inverse_retract_embedded!( + M::$TM, + X::$TV, + p::$TP, + q::$TP, + m, + ) + ManifoldsBase.inverse_retract_embedded!( + M, + X.$vfield, + p.$pfield, + q.$pfield, + m, + ) + return X + end + function ManifoldsBase.inverse_retract_nlsolve(M::$TM, p::$TP, q::$TP, m) + return $TV( + ManifoldsBase.inverse_retract_nlsolve(M, p.$pfield, q.$pfield, m), + ) + end + function ManifoldsBase.inverse_retract_nlsolve!( + M::$TM, + X::$TV, + p::$TP, + q::$TP, + m, + ) + ManifoldsBase.inverse_retract_nlsolve!( + M, + X.$vfield, + p.$pfield, + q.$pfield, + m, + ) + return X + end + end, + ) + # forward vector transports + + for sub in [:project, :diff] + # project & diff + vtaa = Symbol("vector_transport_along_$(sub)") + vtam = Symbol("vector_transport_along_$(sub)!") + vtta = Symbol("vector_transport_to_$(sub)") + vttm = Symbol("vector_transport_to_$(sub)!") push!( block.args, quote - function ManifoldsBase.get_coordinates!(M::$TM, Y, p::$TP, X::$TV, B::$BT) - return get_coordinates!(M, Y, p.$pfield, X.$vfield, B) + function ManifoldsBase.$vtaa(M::$TM, p::$TP, X::$TV, c) + return $TV(ManifoldsBase.$vtaa(M, p.$pfield, X.$vfield, c)) end - - function ManifoldsBase.get_vector(M::$TM, p::$TP, X, B::$BT) - return $TV(get_vector(M, p.$pfield, X, B)) + function ManifoldsBase.$vtam(M::$TM, Y::$TV, p::$TP, X::$TV, c) + ManifoldsBase.$vtam(M, Y.$vfield, p.$pfield, X.$vfield, c) + return Y end - - function ManifoldsBase.get_vector!(M::$TM, Y::$TV, p::$TP, X, B::$BT) - return get_vector!(M, Y.$vfield, p.$pfield, X, B) + function ManifoldsBase.$vtta(M::$TM, p::$TP, X::$TV, q::$TP) + return $TV(ManifoldsBase.$vtta(M, p.$pfield, X.$vfield, q.$pfield)) + end + function ManifoldsBase.$vttm(M::$TM, Y::$TV, p::$TP, X::$TV, q::$TP) + ManifoldsBase.$vttm(M, Y.$vfield, p.$pfield, X.$vfield, q.$pfield) + return Y end end, ) end - - for VTM in [ParallelTransport, VECTOR_TRANSPORT_DISAMBIGUATION...] - push!( - block.args, - quote - function ManifoldsBase.vector_transport_direction!( - M::$TM, - Y::$TV, - p::$TP, - X::$TV, - d::$TV, - m::$VTM, + # parallel transports + push!( + block.args, + quote + function ManifoldsBase.parallel_transport_along(M::$TM, p::$TP, X::$TV, c) + return $TV( + ManifoldsBase.parallel_transport_along(M, p.$pfield, X.$vfield, c), + ) + end + function ManifoldsBase.parallel_transport_along!( + M::$TM, + Y::$TV, + p::$TP, + X::$TV, + c, + ) + ManifoldsBase.parallel_transport_along!( + M, + Y.$vfield, + p.$pfield, + X.$vfield, + c, ) - vector_transport_direction!( + return Y + end + function ManifoldsBase.parallel_transport_direction( + M::$TM, + p::$TP, + X::$TV, + d::$TV, + ) + return $TV( + ManifoldsBase.parallel_transport_direction( M, - Y.$vfield, p.$pfield, X.$vfield, d.$vfield, - m, - ) - return Y - end - function ManifoldsBase.vector_transport_to!( - M::$TM, - Y::$TV, - p::$TP, - X::$TV, - q::$TP, - m::$VTM, + ), ) - vector_transport_to!(M, Y.$vfield, p.$pfield, X.$vfield, q.$pfield, m) - return Y - end - end, - ) - end - + end + function ManifoldsBase.parallel_transport_direction!( + M::$TM, + Y::$TV, + p::$TP, + X::$TV, + d::$TV, + ) + ManifoldsBase.parallel_transport_direction!( + M, + Y.$vfield, + p.$pfield, + X.$vfield, + d.$vfield, + ) + return Y + end + function ManifoldsBase.parallel_transport_to(M::$TM, p::$TP, X::$TV, q::$TP) + return $TV( + ManifoldsBase.parallel_transport_to(M, p.$pfield, X.$vfield, q.$pfield), + ) + end + function ManifoldsBase.parallel_transport_to!( + M::$TM, + Y::$TV, + p::$TP, + X::$TV, + q::$TP, + ) + ManifoldsBase.parallel_transport_to!( + M, + Y.$vfield, + p.$pfield, + X.$vfield, + q.$pfield, + ) + return Y + end + end, + ) return esc(block) end @@ -248,7 +425,7 @@ List of forwarded functions: """ macro manifold_vector_forwards(T, field::Symbol) return esc(quote - @manifold_vector_forwards ($T) _ ($field) + ManifoldsBase.@manifold_vector_forwards ($T) _ ($field) end) end macro manifold_vector_forwards(T, Twhere, field::Symbol) @@ -264,7 +441,7 @@ macro manifold_vector_forwards(T, Twhere, field::Symbol) Base.:+(X::$T) where {$Twhere} = $T(X.$field) Base.zero(X::$T) where {$Twhere} = $T(zero(X.$field)) - @eval @manifold_element_forwards $T $Twhere $field + @eval ManifoldsBase.@manifold_element_forwards $T $Twhere $field Base.axes(p::$T) where {$Twhere} = axes(p.$field) diff --git a/src/projections.jl b/src/projections.jl index b6331b1a..c407c4bf 100644 --- a/src/projections.jl +++ b/src/projections.jl @@ -27,9 +27,7 @@ accordingly. See also: [`EmbeddedManifold`](@ref), [`embed!`](@ref embed!(M::AbstractManifold, q, p)) """ -function project!(M::AbstractManifold, q, p) - return error(manifold_function_not_implemented_message(M, project!, q, p)) -end +project!(::AbstractManifold, q, p) """ project(M::AbstractManifold, p, X) @@ -69,6 +67,4 @@ Lie algebra is perfomed, too. See also: [`EmbeddedManifold`](@ref), [`embed!`](@ref embed!(M::AbstractManifold, Y, p, X)) """ -function project!(M::AbstractManifold, Y, p, X) - return error(manifold_function_not_implemented_message(M, project!, Y, p, X)) -end +project!(::AbstractManifold, Y, p, X) diff --git a/src/retractions.jl b/src/retractions.jl index a2d8185e..26be4f25 100644 --- a/src/retractions.jl +++ b/src/retractions.jl @@ -16,6 +16,7 @@ abstract type AbstractRetractionMethod end ApproximateInverseRetraction <: AbstractInverseRetractionMethod An abstract type for representing approximate inverse retraction methods. + """ abstract type ApproximateInverseRetraction <: AbstractInverseRetractionMethod end @@ -26,13 +27,85 @@ An abstract type for representing approximate retraction methods. """ abstract type ApproximateRetraction <: AbstractRetractionMethod end +@doc raw""" + EmbeddedRetraction{T<:AbstractRetractionMethod} <: AbstractRetractionMethod + +Compute a retraction by using the retraction of type `T` in the embedding and projecting the result. + +# Constructor + + EmbeddedRetraction(r::AbstractRetractionMethod) + +Generate the retraction with retraction `r` to use in the embedding. """ - ExponentialRetraction +struct EmbeddedRetraction{T<:AbstractRetractionMethod} <: AbstractRetractionMethod + retraction::T +end + +""" + ExponentialRetraction <: AbstractRetractionMethod Retraction using the exponential map. """ struct ExponentialRetraction <: AbstractRetractionMethod end +@doc raw""" + ODEExponentialRetraction{T<:AbstractRetractionMethod, B<:AbstractBasis} <: AbstractRetractionMethod + +Approximate the exponential map on the manifold by evaluating the ODE descripting the geodesic at 1, +assuming the default connection of the given manifold by solving the ordinary differential +equation + +```math +\frac{d^2}{dt^2} p^k + Γ^k_{ij} \frac{d}{dt} p_i \frac{d}{dt} p_j = 0, +``` + +where ``Γ^k_{ij}`` are the Christoffel symbols of the second kind, and +the Einstein summation convention is assumed. + +# Constructor + + ODEExponentialRetraction( + r::AbstractRetractionMethod, + b::AbstractBasis=DefaultOrthogonalBasis(), + ) + +Generate the retraction with a retraction to use internally (for some approaches) +and a basis for the tangent space(s). +""" +struct ODEExponentialRetraction{T<:AbstractRetractionMethod,B<:AbstractBasis} <: + AbstractRetractionMethod + retraction::T + basis::B +end +function ODEExponentialRetraction(r::T) where {T<:AbstractRetractionMethod} + return ODEExponentialRetraction(r, DefaultOrthonormalBasis()) +end +function ODEExponentialRetraction(::T, b::CachedBasis) where {T<:AbstractRetractionMethod} + return throw( + DomainError( + b, + "Cached Bases are currently not supported, since the basis has to be implemented in a surrounding of the start point as well.", + ), + ) +end +function ODEExponentialRetraction(r::ExponentialRetraction, ::AbstractBasis) + return throw( + DomainError( + r, + "You can not use the exponential map as an inner method to solve the ode for the exponential map.", + ), + ) +end +function ODEExponentialRetraction(r::ExponentialRetraction, ::CachedBasis) + return throw( + DomainError( + r, + "Neither the exponential map nor a Cached Basis can be used with this retraction type.", + ), + ) +end + """ PolarRetraction <: AbstractRetractionMethod @@ -57,7 +130,53 @@ matrix / matrices for point and tangent vector on a [`AbstractManifold`](@ref) struct QRRetraction <: AbstractRetractionMethod end """ - LogarithmicInverseRetraction + SoftmaxRetraction <: AbstractRetractionMethod + +Describes a retraction that is based on the softmax function. +""" +struct SoftmaxRetraction <: AbstractRetractionMethod end + + +@doc raw""" + PadeRetraction{m} <: AbstractRetractionMethod + +A retraction based on the Padé approximation of order $m$ +""" +struct PadeRetraction{m} <: AbstractRetractionMethod end + +function PadeRetraction(m::Int) + (m < 1) && error( + "The Padé based retraction is only available for positive orders, not for order $m.", + ) + return PadeRetraction{m}() +end +@doc raw""" + CayleyRetraction <: AbstractRetractionMethod + +A retraction based on the Cayley transform, which is realized by using the +[`PadeRetraction`](@ref)`{1}`. +""" +const CayleyRetraction = PadeRetraction{1} + +@doc raw""" + EmbeddedInverseRetraction{T<:AbstractInverseRetractionMethod} <: AbstractInverseRetractionMethod + +Compute an inverse retraction by using the inverse retraction of type `T` in the embedding and projecting the result + +# Constructor + + EmbeddedInverseRetraction(r::AbstractInverseRetractionMethod) + +Generate the inverse retraction with inverse retraction `r` to use in the embedding. +""" +struct EmbeddedInverseRetraction{T<:AbstractInverseRetractionMethod} <: + AbstractInverseRetractionMethod + inverse_retraction::T +end + + +""" + LogarithmicInverseRetraction <: AbstractInverseRetractionMethod Inverse retraction using the [`log`](@ref)arithmic map. """ @@ -87,14 +206,14 @@ matrix / matrices for point and tangent vector on a [`AbstractManifold`](@ref) struct QRInverseRetraction <: AbstractInverseRetractionMethod end """ - NLsolveInverseRetraction{T<:AbstractRetractionMethod,TV,TK} <: + NLSolveInverseRetraction{T<:AbstractRetractionMethod,TV,TK} <: ApproximateInverseRetraction An inverse retraction method for approximating the inverse of a retraction using `NLsolve`. # Constructor - NLsolveInverseRetraction( + NLSolveInverseRetraction( method::AbstractRetractionMethod[, X0]; project_tangent=false, project_point=false, @@ -107,14 +226,14 @@ vector is projected before the retraction using `project`. If `project_point` is then the resulting point is projected after the retraction. `nlsolve_kwargs` are keyword arguments passed to `NLsolve.nlsolve`. """ -struct NLsolveInverseRetraction{TR<:AbstractRetractionMethod,TV,TK} <: +struct NLSolveInverseRetraction{TR<:AbstractRetractionMethod,TV,TK} <: ApproximateInverseRetraction retraction::TR X0::TV project_tangent::Bool project_point::Bool nlsolve_kwargs::TK - function NLsolveInverseRetraction(m, X0, project_point, project_tangent, nlsolve_kwargs) + function NLSolveInverseRetraction(m, X0, project_point, project_tangent, nlsolve_kwargs) return new{typeof(m),typeof(X0),typeof(nlsolve_kwargs)}( m, X0, @@ -124,16 +243,23 @@ struct NLsolveInverseRetraction{TR<:AbstractRetractionMethod,TV,TK} <: ) end end -function NLsolveInverseRetraction( +function NLSolveInverseRetraction( m, X0 = nothing; project_tangent::Bool = false, project_point::Bool = false, nlsolve_kwargs..., ) - return NLsolveInverseRetraction(m, X0, project_point, project_tangent, nlsolve_kwargs) + return NLSolveInverseRetraction(m, X0, project_point, project_tangent, nlsolve_kwargs) end +""" + SoftmaxInverseRetraction <: AbstractInverseRetractionMethod + +Describes an inverse retraction that is based on the softmax function. +""" +struct SoftmaxInverseRetraction <: AbstractInverseRetractionMethod end + """ default_inverse_retraction_method(M::AbstractManifold) @@ -164,29 +290,90 @@ available methods. See also [`retract!`](@ref). """ -function inverse_retract!(M::AbstractManifold, X, p, q) - return inverse_retract!(M, X, p, q, default_inverse_retraction_method(M)) -end function inverse_retract!( M::AbstractManifold, X, p, q, - method::LogarithmicInverseRetraction, + m::AbstractInverseRetractionMethod = default_inverse_retraction_method(M), ) + return _inverse_retract!(M, X, p, q, m) +end +function _inverse_retract!(M::AbstractManifold, X, p, q, ::LogarithmicInverseRetraction) return log!(M, X, p, q) end -function inverse_retract!( - M::AbstractManifold, - X, - p, - q, - method::AbstractInverseRetractionMethod, -) - return error( - manifold_function_not_implemented_message(M, inverse_retract!, X, p, q, method), + +# +# dispatch to lower level +function _inverse_retract!(M::AbstractManifold, X, p, q, m::EmbeddedInverseRetraction) + return inverse_retract_embedded!(M, X, p, q, m.inverse_retraction) +end +function _inverse_retract!(M::AbstractManifold, X, p, q, ::PolarInverseRetraction) + return inverse_retract_polar!(M, X, p, q) +end +function _inverse_retract!(M::AbstractManifold, X, p, q, ::ProjectionInverseRetraction) + return inverse_retract_project!(M, X, p, q) +end +function _inverse_retract!(M::AbstractManifold, X, p, q, ::QRInverseRetraction) + return inverse_retract_qr!(M, X, p, q) +end +function _inverse_retract!(M::AbstractManifold, X, p, q, ::SoftmaxInverseRetraction) + return inverse_retract_softmax!(M, X, p, q) +end +function _inverse_retract!(M::AbstractManifold, X, p, q, m::NLSolveInverseRetraction) + return inverse_retract_nlsolve!(M, X, p, q, m) +end +""" + inverse_retract_embedded!(M::AbstractManifold, X, p, q, m) + +computes the mutating variant of the [`EmbeddedInverseRetraction`](@ref) using +the [`AbstractInverseRetractionMethod`](@ref) `m` in the embedding (see [`get_embedding`](@ref)) +and projecting the result back. +""" +function inverse_retract_embedded!(M::AbstractManifold, X, p, q, m) + return project!( + M, + X, + p, + inverse_retract( + get_embedding(M), + embed(get_embedding(M), p), + embed(get_embedding(M), q), + m, + ), ) end +""" + inverse_retract_qr!(M::AbstractManifold, X, p, q) + +computes the mutating variant of the [`QRInverseRetraction`](@ref). +""" +function inverse_retract_qr! end +""" + inverse_retract_project!(M::AbstractManifold, X, p, q) + +computes the mutating variant of the [`ProjectionInverseRetraction`](@ref). +""" +function inverse_retract_project! end +""" + inverse_retract_polar!(M::AbstractManifold, X, p, q) + +computes the mutating variant of the [`PolarInverseRetraction`](@ref). +""" +function inverse_retract_polar! end +""" + inverse_retract_nlsolve!(M::AbstractManifold, X, p, q, m) + +computes the mutating variant of the [`NLSolveInverseRetraction`](@ref) `m`. +""" +function inverse_retract_nlsolve! end +""" + inverse_retract_softmax!(M::AbstractManifold, X, p, q) + +computes the mutating variant of the [`SoftmaxInverseRetraction`](@ref). +""" +function inverse_retract_softmax! end + """ inverse_retract(M::AbstractManifold, p, q) @@ -202,22 +389,217 @@ corresponding manifold. See also [`retract`](@ref). """ -function inverse_retract(M::AbstractManifold, p, q) +function inverse_retract( + M::AbstractManifold, + p, + q, + m::AbstractInverseRetractionMethod = default_inverse_retraction_method(M), +) + return _inverse_retract(M, p, q, m) +end +function _inverse_retract(M::AbstractManifold, p, q, m::EmbeddedInverseRetraction) + return inverse_retract_embedded(M, p, q, m.inverse_retraction) +end +_inverse_retract(M::AbstractManifold, p, q, ::LogarithmicInverseRetraction) = log(M, p, q) +function _inverse_retract(M::AbstractManifold, p, q, m::NLSolveInverseRetraction) + return inverse_retract_nlsolve(M, p, q, m) +end +function _inverse_retract(M::AbstractManifold, p, q, ::PolarInverseRetraction) + return inverse_retract_polar(M, p, q) +end +function _inverse_retract(M::AbstractManifold, p, q, ::ProjectionInverseRetraction) + return inverse_retract_project(M, p, q) +end +function _inverse_retract(M::AbstractManifold, p, q, ::QRInverseRetraction) + return inverse_retract_qr(M, p, q) +end +function _inverse_retract(M::AbstractManifold, p, q, ::SoftmaxInverseRetraction) + return inverse_retract_softmax(M, p, q) +end +""" + inverse_retract_embedded(M::AbstractManifold, p, q, m) + +computes the allocating variant of the [`EmbeddedInverseRetraction`](@ref) using +the [`AbstractInverseRetractionMethod`](@ref) `m` in the embedding (see [`get_embedding`](@ref)) +and projecting the result back. +""" +function inverse_retract_embedded(M::AbstractManifold, p, q, m) + return project( + M, + p, + inverse_retract( + get_embedding(M), + embed(get_embedding(M), p), + embed(get_embedding(M), q), + m, + ), + ) +end +""" + inverse_retract_polar(M::AbstractManifold, p, q) + +computes the allocating variant of the [`PolarInverseRetraction`](@ref), +which by default allocates and calls [`inverse_retract_polar!`](@ref). +""" +function inverse_retract_polar(M::AbstractManifold, p, q) + X = allocate_result(M, inverse_retract, p, q) + return inverse_retract_polar!(M, X, p, q) +end +""" + inverse_retract_project(M::AbstractManifold, p, q) + +computes the allocating variant of the [`ProjectionInverseRetraction`](@ref), +which by default allocates and calls [`inverse_retract_project!`](@ref). +""" +function inverse_retract_project(M::AbstractManifold, p, q) + X = allocate_result(M, inverse_retract, p, q) + return inverse_retract_project!(M, X, p, q) +end +""" + inverse_retract_qr(M::AbstractManifold, p, q) + +computes the allocating variant of the [`QRInverseRetraction`](@ref), +which by default allocates and calls [`inverse_retract_qr!`](@ref). +""" +function inverse_retract_qr(M::AbstractManifold, p, q) X = allocate_result(M, inverse_retract, p, q) - inverse_retract!(M, X, p, q) - return X + return inverse_retract_qr!(M, X, p, q) end -function inverse_retract(M::AbstractManifold, p, q, method::AbstractInverseRetractionMethod) +""" + inverse_retract_nlsolve(M::AbstractManifold, p, q, m) + +computes the allocating variant of the [`NLSolveInverseRetraction`](@ref) `m`, +which by default allocates and calls [`inverse_retract_nlsolve!`](@ref). +""" +function inverse_retract_nlsolve(M::AbstractManifold, p, q, m) X = allocate_result(M, inverse_retract, p, q) - inverse_retract!(M, X, p, q, method) - return X + return inverse_retract_nlsolve!(M, X, p, q, m) +end +""" + inverse_retract_softmax(M::AbstractManifold, p, q) + +computes the allocating variant of the [`SoftmaxInverseRetraction`](@ref), +which by default allocates and calls [`inverse_retract_softmax!`](@ref). +""" +function inverse_retract_softmax(M::AbstractManifold, p, q) + X = allocate_result(M, inverse_retract, p, q) + return inverse_retract_softmax!(M, X, p, q) +end + + +""" + retract!(M::AbstractManifold, q, p, X) + retract!(M::AbstractManifold, q, p, X, t::Real=1) + retract!(M::AbstractManifold, q, p, X, method::AbstractRetractionMethod) + retract!(M::AbstractManifold, q, p, X, t::Real=1, method::AbstractRetractionMethod) + +Compute a retraction, a cheaper, approximate version of the [`exp`](@ref)onential map, +from `p` into direction `X`, scaled by `t`, on the [`AbstractManifold`](@ref) manifold `M`. +Result is saved to `q`. + +Retraction method can be specified by the last argument, defaulting to +[`default_retraction_method`](@ref)`(M)`. See the documentation of respective manifolds for available +methods. + +See [`retract`](@ref) for more details. +""" +function retract!( + M::AbstractManifold, + q, + p, + X, + m::AbstractRetractionMethod = default_retraction_method(M), +) + return _retract!(M, q, p, X, m) +end +function retract!( + M::AbstractManifold, + q, + p, + X, + t::Real, + method::AbstractRetractionMethod = default_retraction_method(M), +) + return _retract!(M, q, p, t * X, method) +end +# dispatch to lower level +function _retract!(M::AbstractManifold, q, p, X, m::EmbeddedRetraction) + return retract_embedded!(M, q, p, X, m.retraction) +end +_retract!(M::AbstractManifold, q, p, X, ::ExponentialRetraction) = exp!(M, q, p, X) +function _retract!(M::AbstractManifold, q, p, X, m::ODEExponentialRetraction) + return retract_exp_ode!(M, q, p, X, m.retraction, m.basis) +end +_retract!(M::AbstractManifold, q, p, X, ::PolarRetraction) = retract_polar!(M, q, p, X) +function _retract!(M::AbstractManifold, q, p, X, ::ProjectionRetraction) + return retract_project!(M, q, p, X) +end +_retract!(M::AbstractManifold, q, p, X, ::QRRetraction) = retract_qr!(M, q, p, X) +_retract!(M::AbstractManifold, q, p, X, ::SoftmaxRetraction) = retract_softmax!(M, q, p, X) +_retract!(M::AbstractManifold, q, p, X, ::CayleyRetraction) = retract_pade!(M, q, p, X, 1) +function _retract!(M::AbstractManifold, q, p, X, ::PadeRetraction{n}) where {n} + return retract_pade!(M, q, p, X, n) +end + +""" + retract_embedded!(M::AbstractManifold, X, p, q, m) + +computes the mutating variant of the [`EmbeddedRetraction`](@ref) using +the [`AbstractRetractionMethod`](@ref) `m` in the embedding (see [`get_embedding`](@ref)) +and projecting the result back. +""" +function retract_embedded!(M::AbstractManifold, q, p, X, m) + return project!( + M, + q, + retract( + get_embedding(M), + embed(get_embedding(M), p), + embed(get_embedding(M), p, X), + m, + ), + ) end +""" + retract_exp_ode!(M::AbstractManifold, q, p, X, m, B) + +computes the mutating variant of the [`ODEExponentialRetraction`](@ref)`(m, B)`. +""" +function retract_exp_ode! end +""" + retract_pade!(M::AbstractManifold, q, p, n) + +computes the mutating variant of the [`PadeRetraction`](@ref)`(n)`. +""" +function retract_pade! end +""" + retract_project!(M::AbstractManifold, q, p, X) + +computes the mutating variant of the [`ProjectionRetraction`](@ref). +""" +function retract_project! end +""" + retract_polar!(M::AbstractManifold, q, p, X) + +computes the mutating variant of the [`PolarRetraction`](@ref). +""" +function retract_polar! end +""" + retract_qr!(M::AbstractManifold, q, p, X) + +computes the mutating variant of the [`QRRetraction`](@ref). +""" +function retract_qr! end +""" + retract_softmax!(M::AbstractManifold, q, p, X) + +computes the mutating variant of the [`SoftmaxRetraction`](@ref). +""" +function retract_softmax! end @doc raw""" - retract(M::AbstractManifold, p, X) - retract(M::AbstractManifold, p, X, t::Real=1) - retract(M::AbstractManifold, p, X, method::AbstractRetractionMethod) - retract(M::AbstractManifold, p, X, t::Real=1, method::AbstractRetractionMethod) + retract(M::AbstractManifold, p, X, method::AbstractRetractionMethod=default_retraction_method(M)) + retract(M::AbstractManifold, p, X, t::Real=1, method::AbstractRetractionMethod=default_retraction_method(M)) Compute a retraction, a cheaper, approximate version of the [`exp`](@ref)onential map, from `p` into direction `X`, scaled by `t`, on the [`AbstractManifold`](@ref) `M`. @@ -237,43 +619,116 @@ Retraction method can be specified by the last argument, defaulting to Locally, the retraction is invertible. For the inverse operation, see [`inverse_retract`](@ref). """ -function retract(M::AbstractManifold, p, X) - q = allocate_result(M, retract, p, X) - retract!(M, q, p, X) - return q +function retract( + M::AbstractManifold, + p, + X, + m::AbstractRetractionMethod = default_retraction_method(M), +) + return _retract(M, p, X, m) end -retract(M::AbstractManifold, p, X, t::Real) = retract(M, p, t * X) -function retract(M::AbstractManifold, p, X, method::AbstractRetractionMethod) - q = allocate_result(M, retract, p, X) - retract!(M, q, p, X, method) - return q +function retract( + M::AbstractManifold, + p, + X, + t::Real, + m::AbstractRetractionMethod = default_retraction_method(M), +) + return _retract(M, p, t * X, m) +end +function _retract(M::AbstractManifold, p, X, m::EmbeddedRetraction) + return retract_embedded(M, p, X, m.retraction) end -function retract(M::AbstractManifold, p, X, t::Real, method::AbstractRetractionMethod) - return retract(M, p, t * X, method) +_retract(M::AbstractManifold, p, X, ::ExponentialRetraction) = exp(M, p, X) +function _retract(M::AbstractManifold, p, X, m::ODEExponentialRetraction) + return retract_exp_ode(M, p, X, m.retraction, m.basis) end +_retract(M::AbstractManifold, p, X, ::PolarRetraction) = retract_polar(M, p, X) +_retract(M::AbstractManifold, p, X, ::ProjectionRetraction) = retract_project(M, p, X) +_retract(M::AbstractManifold, p, X, ::QRRetraction) = retract_qr(M, p, X) +_retract(M::AbstractManifold, p, X, ::SoftmaxRetraction) = retract_softmax(M, p, X) +_retract(M::AbstractManifold, p, X, ::CayleyRetraction) = retract_pade(M, p, X, 1) +function _retract(M::AbstractManifold, p, X, ::PadeRetraction{n}) where {n} + return retract_pade(M, p, X, n) +end +""" + retract_embedded(M::AbstractManifold, p, X, m) +computes the allocating variant of the [`EmbeddedRetraction`](@ref) using +the [`AbstractRetractionMethod`](@ref) `m` in the embedding (see [`get_embedding`](@ref)) +and projecting the result back. """ - retract!(M::AbstractManifold, q, p, X) - retract!(M::AbstractManifold, q, p, X, t::Real=1) - retract!(M::AbstractManifold, q, p, X, method::AbstractRetractionMethod) - retract!(M::AbstractManifold, q, p, X, t::Real=1, method::AbstractRetractionMethod) +function retract_embedded(M::AbstractManifold, p, X, m) + return project( + M, + retract( + get_embedding(M), + embed(get_embedding(M), p), + embed(get_embedding(M), p, X), + m, + ), + ) +end +""" + retract_polar(M::AbstractManifold, p, q) -Compute a retraction, a cheaper, approximate version of the [`exp`](@ref)onential map, -from `p` into direction `X`, scaled by `t`, on the [`AbstractManifold`](@ref) manifold `M`. -Result is saved to `q`. +computes the allocating variant of the [`PolarRetraction`](@ref), +which by default allocates and calls [`retract_polar!`](@ref). +""" +function retract_polar(M::AbstractManifold, p, X) + q = allocate_result(M, retract, p, X) + return retract_polar!(M, q, p, X) +end +""" + retract_project(M::AbstractManifold, p, q) -Retraction method can be specified by the last argument, defaulting to -[`default_retraction_method`](@ref)`(M)`. See the documentation of respective manifolds for available -methods. +computes the allocating variant of the [`ProjectionRetraction`](@ref), +which by default allocates and calls [`retract_project!`](@ref). +""" +function retract_project(M::AbstractManifold, p, X) + q = allocate_result(M, retract, p, X) + return retract_project!(M, q, p, X) +end +""" + retract_qr(M::AbstractManifold, p, q) -See [`retract`](@ref) for more details. +computes the allocating variant of the [`QRRetraction`](@ref), +which by default allocates and calls [`retract_qr!`](@ref). +""" +function retract_qr(M::AbstractManifold, p, X) + q = allocate_result(M, retract, p, X) + return retract_qr!(M, q, p, X) +end +""" + retract_exp_ode(M::AbstractManifold, p, q, m, B) + +computes the allocating variant of the [`ODEExponentialRetraction`](@ref)`(m,B)`, +which by default allocates and calls [`retract_exp_ode!`](@ref). """ -retract!(M::AbstractManifold, q, p, X) = retract!(M, q, p, X, default_retraction_method(M)) -retract!(M::AbstractManifold, q, p, X, t::Real) = retract!(M, q, p, t * X) -retract!(M::AbstractManifold, q, p, X, ::ExponentialRetraction) = exp!(M, q, p, X) -function retract!(M::AbstractManifold, q, p, X, t::Real, method::AbstractRetractionMethod) - return retract!(M, q, p, t * X, method) +function retract_exp_ode(M::AbstractManifold, p, X, m, B) + q = allocate_result(M, retract, p, X) + return retract_exp_ode!(M, q, p, X, m, B) end -function retract!(M::AbstractManifold, q, p, X, method::AbstractRetractionMethod) - return error(manifold_function_not_implemented_message(M, retract!, q, p, method)) +""" + retract_softmax(M::AbstractManifold, p, q) + +computes the allocating variant of the [`SoftmaxRetraction`](@ref), +which by default allocates and calls [`retract_softmax!`](@ref). +""" +function retract_softmax(M::AbstractManifold, p, X) + q = allocate_result(M, retract, p, X) + return retract_softmax!(M, q, p, X) end +""" + retract_pade(M::AbstractManifold, p, q) + +computes the allocating variant of the [`PadeRetraction`](@ref)`(n)`, +which by default allocates and calls [`retract_pade!`](@ref). +""" +function retract_pade(M::AbstractManifold, p, X, n) + q = allocate_result(M, retract, p, X) + return retract_pade!(M, q, p, X, n) +end + +Base.show(io::IO, ::CayleyRetraction) = print(io, "CayleyRetraction()") +Base.show(io::IO, ::PadeRetraction{m}) where {m} = print(io, "PadeRetraction($(m))") diff --git a/src/vector_transport.jl b/src/vector_transport.jl index d4540007..a4e5c0f0 100644 --- a/src/vector_transport.jl +++ b/src/vector_transport.jl @@ -25,7 +25,7 @@ abstract type AbstractLinearVectorTransportMethod <: AbstractVectorTransportMeth A type to specify a vector transport that is given by differentiating a retraction. This can be introduced in two ways. Let ``\mathcal M`` be a Riemannian manifold, -``p\in\mathcal M`` a point, and ``X,Y\in T_p\mathcal M`` denote two tangent vectors at ``p``. +``p∈\mathcal M`` a point, and ``X,Y∈ T_p\mathcal M`` denote two tangent vectors at ``p``. Given a retraction (cf. [`AbstractRetractionMethod`](@ref)) ``\operatorname{retr}``, the vector transport of `X` in direction `Y` (cf. [`vector_transport_direction`](@ref)) @@ -51,45 +51,24 @@ compute ``Y = \operatorname{retr}_p^{-1}q``. # Constructor DifferentiatedRetractionVectorTransport(m::AbstractRetractionMethod) - -[^AbsilMahonySepulchre2008]: - > Absil, P.-A., Mahony, R. and Sepulchre R., - > _Optimization Algorithms on Matrix Manifolds_ - > Princeton University Press, 2008, - > doi: [10.1515/9781400830244](https://doi.org/10.1515/9781400830244) - > [open access](http://press.princeton.edu/chapters/absil/) """ struct DifferentiatedRetractionVectorTransport{R<:AbstractRetractionMethod} <: - AbstractLinearVectorTransportMethod end -function DifferentiatedRetractionVectorTransport(::R) where {R<:AbstractRetractionMethod} - return DifferentiatedRetractionVectorTransport{R}() + AbstractLinearVectorTransportMethod + retraction::R + function DifferentiatedRetractionVectorTransport( + r::R, + ) where {R<:AbstractRetractionMethod} + return new{R}(r) + end end @doc raw""" - ParallelTransport = DifferentiatedRetractionVectorTransport{ExponentialRetraction} + ParallelTransport <: AbstractVectorTransportMethod -Specify to use parallel transport vector transport method. - -To be precise let ``c(t)`` be a curve depending on the method - -* the (assumed to be unique) geodesic ``c(t) = γ_{p,q}(t)`` from ``γ_{p,q}(0)=p`` to ``γ_{p,q}(1)=q`` for [`vector_transport_to`](@ref) ``\mathcal P_{q\gets p}Y`` -* the unique geodesic ``c(t)=γ_{p,X}(t)`` from ``γ_{p,X}(0)=p`` into direction ``\dot γ_{p,X}(0)=X`` for [`vector_transport_direction`](@ref) ``\mathcal P_{p,X}Y`` -* a given curve ``c(0)=p`` for [`vector_transport_along`](@ref) ``\mathcal P^cY`` - -In these cases ``Y\in T_p\mathcal M`` is the vector that we would like to transport from -the tangent space at ``p=c(0)`` to the tangent space at ``c(1)``. - -Let ``Z\colon [0,1] \to T\mathcal M``, ``Z(t)\in T_{c(t)}\mathcal M`` be a smooth vector field -along the curve ``c`` with ``Z(0) = Y``, such that ``Z`` is _parallel_, i.e. -its covariant derivative ``\frac{\mathrm{D}}{\mathrm{d}t}Z`` is zero. Note that such a ``Z`` always exists and is unique. - -Then the parallel transport is given by ``Z(1)``. - -Note that since it is technically the [`DifferentiatedRetractionVectorTransport`](@ref) of -the [`exp`](@ref exp(M::AbstractManifold, p, X)) (cf. [`ExponentialRetraction`](@ref)), we define -`ParallelTransport` as an alias. +Compute the vector transport by parallel transport, see +[`parallel_transport_to`](@ref) """ -const ParallelTransport = DifferentiatedRetractionVectorTransport{ExponentialRetraction} +struct ParallelTransport <: AbstractLinearVectorTransportMethod end """ ProjectionTransport <: AbstractVectorTransportMethod @@ -98,7 +77,7 @@ Specify to use projection onto tangent space as vector transport method within [`vector_transport_to`](@ref), [`vector_transport_direction`](@ref), or [`vector_transport_along`](@ref). See [`project`](@ref) for details. """ -struct ProjectionTransport <: AbstractVectorTransportMethod end +struct ProjectionTransport <: AbstractLinearVectorTransportMethod end @doc raw""" @@ -108,7 +87,7 @@ Specify to use [`pole_ladder`](@ref) as vector transport method within [`vector_transport_to`](@ref), [`vector_transport_direction`](@ref), or [`vector_transport_along`](@ref), i.e. -Let $X\in T_p\mathcal M$ be a tangent vector at $p\in\mathcal M$ and $q\in\mathcal M$ the +Let $X∈ T_p\mathcal M$ be a tangent vector at $p∈\mathcal M$ and $q∈\mathcal M$ the point to transport to. Then $x = \exp_pX$ is used to call `y = `[`pole_ladder`](@ref)`(M, p, x, q)` and the resulting vector is obtained by computing $Y = -\log_qy$. @@ -168,7 +147,7 @@ end ScaledVectorTransport{T} <: AbstractVectorTransportMethod Introduce a scaled variant of any [`AbstractVectorTransportMethod`](@ref) `T`, -as introduced in [^SatoIwai2013] for some ``X\in T_p\mathcal M`` as +as introduced in [^SatoIwai2013] for some ``X∈ T_p\mathcal M`` as ```math \mathcal T^{\mathrm{S}}(X) = \frac{\lVert X\rVert_p}{\lVert \mathcal T(X)\rVert_q}\mathcal T(X). @@ -200,7 +179,7 @@ Specify to use [`schilds_ladder`](@ref) as vector transport method within [`vector_transport_to`](@ref), [`vector_transport_direction`](@ref), or [`vector_transport_along`](@ref), i.e. -Let $X\in T_p\mathcal M$ be a tangent vector at $p\in\mathcal M$ and $q\in\mathcal M$ the +Let $X∈ T_p\mathcal M$ be a tangent vector at $p∈\mathcal M$ and $q∈\mathcal M$ the point to transport to. Then ````math @@ -252,16 +231,70 @@ struct SchildsLadderTransport{ end end +""" + VectorTransportDirection{VM<:AbstractVectorTransportMethod,RM<:AbstractRetractionMethod} + <: AbstractVectorTransportMethod + + Specify a [`vector_transport_direction`](@ref) using a [`AbstractVectorTransportMethod`](@ref) + with explicitly using the [`AbstractRetractionMethod`](@ref) to determine the point in + the specified direction where to transsport to. + Note that you only need this for the non-default (non-implicit) second retraction method associated to a vector transport, + i.e. when a first implementation assumed an implicit associated retraction. +""" +struct VectorTransportDirection{ + VM<:AbstractVectorTransportMethod, + RM<:AbstractRetractionMethod, +} <: AbstractVectorTransportMethod + retraction::RM + vector_transport::VM + function VectorTransportDirection( + vector_transport = ParallelTransport(), + retraction = ExponentialRetraction(), + ) + return new{typeof(vector_transport),typeof(retraction)}( + retraction, + vector_transport, + ) + end +end + +""" + VectorTransportTo{VM<:AbstractVectorTransportMethod,RM<:AbstractRetractionMethod} + <: AbstractVectorTransportMethod + + Specify a [`vector_transport_to`](@ref) using a [`AbstractVectorTransportMethod`](@ref) + with explicitly using the [`AbstractInverseRetractionMethod`](@ref) to determine the direction + that transports from in `p`to `q`. + Note that you only need this for the non-default (non-implicit) second retraction method associated to a vector transport, + i.e. when a first implementation assumed an implicit associated retraction. +""" +struct VectorTransportTo{ + VM<:AbstractVectorTransportMethod, + IM<:AbstractInverseRetractionMethod, +} <: AbstractVectorTransportMethod + inverse_retraction::IM + vector_transport::VM + function VectorTransportTo( + vector_transport = ParallelTransport(), + inverse_retraction = LogarithmicInverseRetraction(), + ) + return new{typeof(vector_transport),typeof(inverse_retraction)}( + inverse_retraction, + vector_transport, + ) + end +end + """ default_vector_transport_method(M::AbstractManifold) The [`AbstractVectorTransportMethod`](@ref) that is used when calling [`vector_transport_along`](@ref), [`vector_transport_to`](@ref), or [`vector_transport_direction`](@ref) without specifying the vector transport method. -By default, this is [`DifferentiatedRetractionVectorTransport`](@ref)([`default_retraction_method`](@ref)`(M))`. +By default, this is [`ParallelTransport`](@ref). """ -function default_vector_transport_method(M::AbstractManifold) - return DifferentiatedRetractionVectorTransport(default_retraction_method(M)) +function default_vector_transport_method(::AbstractManifold) + return ParallelTransport() end @doc raw""" @@ -289,7 +322,7 @@ to different [`AbstractRetractionMethod`](@ref) and [`AbstractInverseRetractionM When you have $X=log_pd$ and $Y = -\log_q \operatorname{Pl}(p,d,q)$, you will obtain the [`PoleLadderTransport`](@ref). When performing multiple steps, this -method avoidsd the switching to the tangent space. Keep in mind that after $n$ successive +method avoids the switching to the tangent space. Keep in mind that after $n$ successive steps the tangent vector reads $Y_n = (-1)^n\log_q \operatorname{Pl}(p_{n-1},d_{n-1},p_n)$. It is cheaper to evaluate than [`schilds_ladder`](@ref), sinc if you want to form multiple @@ -428,59 +461,96 @@ end """ vector_transport_along(M::AbstractManifold, p, X, c) - vector_transport_along(M::AbstractManifold, p, X, c, method::AbstractVectorTransportMethod) + vector_transport_along(M::AbstractManifold, p, X, c, m::AbstractVectorTransportMethod) Transport a vector `X` from the tangent space at a point `p` on the [`AbstractManifold`](@ref) `M` along the curve represented by `c` using the `method`, which defaults to [`default_vector_transport_method`](@ref)`(M)`. """ -function vector_transport_along(M::AbstractManifold, p, X, c) - return vector_transport_along(M, p, X, c, default_vector_transport_method(M)) -end function vector_transport_along( M::AbstractManifold, p, X, c, - m::AbstractVectorTransportMethod, + m::AbstractVectorTransportMethod = default_vector_transport_method(M), +) + return _vector_transport_along(M, p, X, c, m) +end +function _vector_transport_along(M::AbstractManifold, p, X, c, ::ParallelTransport) + return parallel_transport_along(M, p, X, c) +end +function _vector_transport_along( + M::AbstractManifold, + p, + X, + c, + m::DifferentiatedRetractionVectorTransport, +) + return vector_transport_along_diff(M, p, X, c, m.retraction) +end +function vector_transport_along_diff( + M::AbstractManifold, + p, + X, + c, + m::AbstractRetractionMethod, ) Y = allocate_result(M, vector_transport_along, X, p) - vector_transport_along!(M, Y, p, X, c, m) - return Y + return vector_transport_along_diff!(M, Y, p, X, c, m) +end +function _vector_transport_along(M::AbstractManifold, p, X, c, ::ProjectionTransport) + return vector_transport_along_project(M, p, X, c) +end +function vector_transport_along_project(M::AbstractManifold, p, X, c) + Y = allocate_result(M, vector_transport_along, X, p) + return vector_transport_along_project!(M, Y, p, X, c) +end +function _vector_transport_along(M::AbstractManifold, p, X, c, m::PoleLadderTransport) + Y = allocate_result(M, vector_transport_along, X, p) + return _vector_transport_along!(M, Y, p, X, c, m) +end +function _vector_transport_along(M::AbstractManifold, p, X, c, m::SchildsLadderTransport) + Y = allocate_result(M, vector_transport_along, X, p) + return _vector_transport_along!(M, Y, p, X, c, m) end """ vector_transport_along!(M::AbstractManifold, Y, p, X, c) - vector_transport_along!(M::AbstractManifold, Y, p, X, c, method::AbstractVectorTransportMethod) + vector_transport_along!(M::AbstractManifold, Y, p, X, c, m::AbstractVectorTransportMethod) Transport a vector `X` from the tangent space at a point `p` on the [`AbstractManifold`](@ref) `M` along the curve represented by `c` using the `method`, which defaults to [`default_vector_transport_method`](@ref)`(M)`. The result is saved to `Y`. """ -function vector_transport_along!(M::AbstractManifold, Y, p, X, c) - return vector_transport_along!(M, Y, p, X, c, default_vector_transport_method(M)) -end function vector_transport_along!( M::AbstractManifold, Y, p, X, c, - method::AbstractVectorTransportMethod, + m::AbstractVectorTransportMethod = default_vector_transport_method(M), ) - return error( - manifold_function_not_implemented_message( - M, - vector_transport_along!, - M, - Y, - p, - X, - c, - method, - ), - ) + return _vector_transport_along!(M, Y, p, X, c, m) +end +function _vector_transport_along!(M::AbstractManifold, Y, p, X, c, ::ParallelTransport) + return parallel_transport_along!(M, Y, p, X, c) +end +function _vector_transport_along!( + M::AbstractManifold, + Y, + p, + X, + c, + m::DifferentiatedRetractionVectorTransport, +) + return vector_transport_along_diff!(M, Y, p, X, c, m.retraction) +end +function vector_transport_along_diff! end +function _vector_transport_along!(M::AbstractManifold, Y, p, X, c, ::ProjectionTransport) + return vector_transport_along_project!(M, Y, p, X, c) end +function vector_transport_along_project! end + @doc raw""" vector_transport_along!( M::AbstractManifold, @@ -488,7 +558,7 @@ end p, X, c::AbstractVector, - method::AbstractVectorTransportMethod + m::AbstractVectorTransportMethod ) where {T} Compute the vector transport along a discretized curve `c` using an @@ -500,23 +570,20 @@ function vector_transport_along!( p, X, c::AbstractVector, - method::AbstractVectorTransportMethod, + m::AbstractVectorTransportMethod, ) n = length(c) if n == 0 copyto!(Y, X) else - # we shouldn't assume that vector_transport_to! works when both input and output - # vectors are the same object - Y2 = allocate(X) - vector_transport_to!(M, Y2, p, X, c[1], method) + vector_transport_to!(M, Y, p, X, c[1], m) for i in 1:(length(c) - 1) - vector_transport_to!(M, Y, c[i], Y2, c[i + 1], method) - copyto!(Y2, Y) + vector_transport_to!(M, Y, c[i], Y, c[i + 1], m) end end return Y end + @doc raw""" function vector_transport_along!( M::AbstractManifold, @@ -524,7 +591,7 @@ end p, X, c::AbstractVector, - method::PoleLadderTransport + m::PoleLadderTransport ) Compute the vector transport along a discretized curve using @@ -532,36 +599,36 @@ Compute the vector transport along a discretized curve using This method is avoiding additional allocations as well as inner exp/log by performing all ladder steps on the manifold and only computing one tangent vector in the end. """ -function vector_transport_along!( +function _vector_transport_along!( M::AbstractManifold, Y, p, X, c::AbstractVector, - method::PoleLadderTransport, + m::PoleLadderTransport, ) clen = length(c) if clen == 0 copyto!(Y, X) else - d = retract(M, p, X, method.retraction) - m = mid_point(M, p, c[1]) + d = retract(M, p, X, m.retraction) + mp = mid_point(M, p, c[1]) pole_ladder!( M, d, p, d, c[1], - m, + mp, Y; - retraction = method.retraction, - inverse_retraction = method.inverse_retraction, + retraction = m.retraction, + inverse_retraction = m.inverse_retraction, ) for i in 1:(clen - 1) # precompute mid point inplace ci = c[i] cip1 = c[i + 1] - mid_point!(M, m, ci, cip1) + mid_point!(M, mp, ci, cip1) # compute new ladder point pole_ladder!( M, @@ -569,13 +636,13 @@ function vector_transport_along!( ci, d, cip1, - m, + mp, Y; - retraction = method.retraction, - inverse_retraction = method.inverse_retraction, + retraction = m.retraction, + inverse_retraction = m.inverse_retraction, ) end - inverse_retract!(M, Y, c[clen], d, method.inverse_retraction) + inverse_retract!(M, Y, c[clen], d, m.inverse_retraction) Y *= (-1)^clen end return Y @@ -587,7 +654,7 @@ end p, X, c::AbstractVector, - method::SchildsLadderTransport + m::SchildsLadderTransport ) Compute the vector transport along a discretized curve using @@ -595,36 +662,36 @@ Compute the vector transport along a discretized curve using This method is avoiding additional allocations as well as inner exp/log by performing all ladder steps on the manifold and only computing one tangent vector in the end. """ -function vector_transport_along!( +function _vector_transport_along!( M::AbstractManifold, Y, p, X, c::AbstractVector, - method::SchildsLadderTransport, + m::SchildsLadderTransport, ) clen = length(c) if clen == 0 copyto!(Y, X) else - d = retract(M, p, X, method.retraction) - m = mid_point(M, c[1], d) + d = retract(M, p, X, m.retraction) + mp = mid_point(M, c[1], d) schilds_ladder!( M, d, p, d, c[1], - m, + mp, Y; - retraction = method.retraction, - inverse_retraction = method.inverse_retraction, + retraction = m.retraction, + inverse_retraction = m.inverse_retraction, ) for i in 1:(clen - 1) ci = c[i] cip1 = c[i + 1] # precompute mid point inplace - mid_point!(M, m, cip1, d) + mid_point!(M, mp, cip1, d) # compute new ladder point schilds_ladder!( M, @@ -632,120 +699,281 @@ function vector_transport_along!( ci, d, cip1, - m, + mp, Y; - retraction = method.retraction, - inverse_retraction = method.inverse_retraction, + retraction = m.retraction, + inverse_retraction = m.inverse_retraction, ) end - inverse_retract!(M, Y, c[clen], d, method.inverse_retraction) + inverse_retract!(M, Y, c[clen], d, m.inverse_retraction) end return Y end - -""" +@doc raw""" vector_transport_direction(M::AbstractManifold, p, X, d) - vector_transport_direction(M::AbstractManifold, p, X, d, method::AbstractVectorTransportMethod) + vector_transport_direction(M::AbstractManifold, p, X, d, m::AbstractVectorTransportMethod) -Transport a vector `X` from the tangent space at a point `p` on the [`AbstractManifold`](@ref) `M` -in the direction indicated by the tangent vector `d` at `p`. By default, [`retract`](@ref) and -[`vector_transport_to!`](@ref) are used with the `method`, which defaults -to [`default_vector_transport_method`](@ref)`(M)`. +Given an [`AbstractManifold`](@ref) ``\mathcal M`` the vector transport is a generalization of the +[`parallel_transport_direction`](@ref) that identifies vectors from different tangent spaces. + +More precisely using [^AbsilMahonySepulchre2008], Def. 8.1.1, a vector transport +``T_{p,d}: T_p\mathcal M \to T_q\mathcal M``, ``p∈ \mathcal M``, ``Y∈ T_p\mathcal M`` is a smooth mapping +associated to a retraction ``\operatorname{retr}_p(Y) = q`` such that + +1. (associated retraction) ``\mathcal T_{p,d}X ∈ T_q\mathcal M`` if and only if ``q = \operatorname{retr}_p(d)``. +2. (consistency) ``\mathcal T_{p,0_p}X = X`` for all ``X∈T_p\mathcal M`` +3. (linearity) ``\mathcal T_{p,d}(αX+βY) = α\mathcal T_{p,d}X + β\mathcal T_{p,d}Y`` + +For the [`AbstractVectorTransportMethod`](@ref) we might even omit the third point. +The [`AbstractLinearVectorTransportMethod`](@ref)s are linear. + +# Input Parameters +* `M` a manifold +* `p` indicating the tangent space of +* `X` the tangent vector to be transported +* `d` indicating a transport direction (and distance through its length) +* `m` an [`AbstractVectorTransportMethod`](@ref), by default [`default_vector_transport_method`](@ref), so usually [`ParallelTransport`](@ref) + +Usually this method requires a [`AbstractRetractionMethod`](@ref) as well. +By default this is assumed to be the [`default_retraction_method`](@ref) or +implicitly given (and documented) for a vector transport. +To explicitly distinguish different retractions for a vector transport, +see [`VectorTransportDirection`](@ref). + +Instead of spcifying a start direction `d` one can equivalently also specify a target tanget space +``T_q\mathcal M``, see [`vector_transport_to`](@ref). +By default [`vector_transport_direction`](@ref) falls back to using [`vector_transport_to`](@ref), +using the [`default_retraction_method`](@ref) on `M`. + +[^AbsilMahonySepulchre2008]: + > Absil, P.-A., Mahony, R. and Sepulchre R., + > _Optimization Algorithms on Matrix Manifolds_ + > Princeton University Press, 2008, + > doi: [10.1515/9781400830244](https://doi.org/10.1515/9781400830244) + > [open access](http://press.princeton.edu/chapters/absil/) """ -function vector_transport_direction(M::AbstractManifold, p, X, d) - return vector_transport_direction(M, p, X, d, default_vector_transport_method(M)) -end function vector_transport_direction( M::AbstractManifold, p, X, d, - method::AbstractVectorTransportMethod, + m::AbstractVectorTransportMethod = default_vector_transport_method(M), ) - Y = allocate_result(M, vector_transport_direction, X, p, d) - vector_transport_direction!(M, Y, p, X, d, method) - return Y + return _vector_transport_direction(M, p, X, d, m) +end +function _vector_transport_direction( + M::AbstractManifold, + p, + X, + d, + m::AbstractVectorTransportMethod = default_vector_transport_method(M), +) + r = default_retraction_method(M) + return vector_transport_to(M, p, X, retract(M, p, d, r), m) end +function _vector_transport_direction( + M::AbstractManifold, + p, + X, + d, + m::VectorTransportDirection, +) + return vector_transport_to(M, p, X, retract(M, p, d, m.retraction), m.vector_transport) +end +function _vector_transport_direction( + M::AbstractManifold, + p, + X, + d, + m::DifferentiatedRetractionVectorTransport{R}, +) where {R<:AbstractRetractionMethod} + return vector_transport_direction_diff(M, p, X, d, m.retraction) +end +function vector_transport_direction_diff( + M::AbstractManifold, + p, + X, + d, + r::AbstractRetractionMethod, +) + Y = allocate_result(M, vector_transport_direction, p, X, d) + return vector_transport_direction_diff!(M, Y, p, X, d, r) +end +function _vector_transport_direction(M::AbstractManifold, p, X, d, ::ParallelTransport) + return parallel_transport_direction(M, p, X, d) +end + """ vector_transport_direction!(M::AbstractManifold, Y, p, X, d) - vector_transport_direction!(M::AbstractManifold, Y, p, X, d, method::AbstractVectorTransportMethod) + vector_transport_direction!(M::AbstractManifold, Y, p, X, d, m::AbstractVectorTransportMethod) Transport a vector `X` from the tangent space at a point `p` on the [`AbstractManifold`](@ref) `M` in the direction indicated by the tangent vector `d` at `p`. By default, [`retract`](@ref) and -[`vector_transport_to!`](@ref) are used with the `method`, which defaults -to [`default_vector_transport_method`](@ref)`(M)`. The result is saved to `Y`. +[`vector_transport_to!`](@ref) are used with the `m` and `r`, which default +to [`default_vector_transport_method`](@ref)`(M)` and [`default_retraction_method`](@ref)`(M)`, respectively. +The result is saved to `Y`. + +See [`vector_transport_direction`](@ref) for more details. """ -function vector_transport_direction!(M::AbstractManifold, Y, p, X, d) - return vector_transport_direction!(M, Y, p, X, d, default_vector_transport_method(M)) -end function vector_transport_direction!( M::AbstractManifold, Y, p, X, d, - method::AbstractVectorTransportMethod, + m::AbstractVectorTransportMethod = default_vector_transport_method(M), ) - y = retract(M, p, d) - return vector_transport_to!(M, Y, p, X, y, method) + return _vector_transport_direction!(M, Y, p, X, d, m) +end +function _vector_transport_direction!( + M::AbstractManifold, + Y, + p, + X, + d, + m::AbstractVectorTransportMethod = default_vector_transport_method(M), +) + r = default_retraction_method(M) + return vector_transport_to!(M, Y, p, X, retract(M, p, d, r), m) +end +function _vector_transport_direction!( + M::AbstractManifold, + Y, + p, + X, + d, + m::VectorTransportDirection, +) + return vector_transport_to!( + M, + Y, + p, + X, + retract(M, p, d, m.retraction), + m.vector_transport, + ) +end +function _vector_transport_direction!( + M::AbstractManifold, + Y, + p, + X, + d, + m::DifferentiatedRetractionVectorTransport, +) + return vector_transport_direction_diff!(M, Y, p, X, d, m.retraction) +end +function vector_transport_direction_diff! end +function _vector_transport_direction!(M::AbstractManifold, Y, p, X, d, ::ParallelTransport) + return parallel_transport_direction!(M, Y, p, X, d) end -""" +@doc raw""" vector_transport_to(M::AbstractManifold, p, X, q) - vector_transport_to(M::AbstractManifold, p, X, q, method::AbstractVectorTransportMethod) + vector_transport_to(M::AbstractManifold, p, X, q, m::AbstractVectorTransportMethod) + vector_transport_to(M::AbstractManifold, p, X, q, m::AbstractVectorTransportMethod) Transport a vector `X` from the tangent space at a point `p` on the [`AbstractManifold`](@ref) `M` -along the [`shortest_geodesic`](@ref) to the tangent space at another point `q`. -By default, the [`AbstractVectorTransportMethod`](@ref) `method` is -[`default_vector_transport_method`](@ref)`(M)`. +along a curve implicitly given by an [`AbstractRetractionMethod`](@ref) associated to `m`. +By default `m` is the [`default_vector_transport_method`](@ref)`(M)`. +To explicitly specify a (different) retraction to the implicitly assumeed retraction, see [`VectorTransportTo`](@ref). +Note that some vector transport methods might also carry their own retraction they are associated to, +like the [`DifferentiatedRetractionVectorTransport`](@ref) and some are even independent of the retraction, for example the [`ProjectionTransport`](@ref). + +This method is equivalent to using ``d = \operatorname{retr}^{-1}_p(q)`` in [`vector_transport_direction`](@ref)`(M, p, X, q, m, r)`, +where you can find the formal definition. This is the fallback for [`VectorTransportTo`](@ref). """ -function vector_transport_to(M::AbstractManifold, p, X, q) - return vector_transport_to(M, p, X, q, default_vector_transport_method(M)) -end function vector_transport_to( M::AbstractManifold, p, X, q, - method::AbstractVectorTransportMethod, + m::AbstractVectorTransportMethod = default_vector_transport_method(M), ) - Y = allocate_result(M, vector_transport_to, X, p, q) - vector_transport_to!(M, Y, p, X, q, method) - return Y + return _vector_transport_to(M, p, X, q, m) +end +function _vector_transport_to(M::AbstractManifold, p, X, q, m::VectorTransportTo) + d = inverse_retract(M, p, q, m.inverse_retraction) + return vector_transport_direction(M, p, X, d, m.vector_transport) +end +function _vector_transport_to(M::AbstractManifold, p, X, q, ::ParallelTransport) + return parallel_transport_to(M, p, X, q) +end +function _vector_transport_to( + M::AbstractManifold, + p, + X, + q, + m::DifferentiatedRetractionVectorTransport, +) + return vector_transport_to_diff(M, p, X, q, m.retraction) +end +function vector_transport_to_diff(M::AbstractManifold, p, X, q, r) + Y = allocate_result(M, vector_transport_to, X, p) + return vector_transport_to_diff!(M, Y, p, X, q, r) +end +function _vector_transport_to(M::AbstractManifold, p, X, q, ::ProjectionTransport) + return vector_transport_to_project(M, p, X, q) +end +function vector_transport_to_project(M::AbstractManifold, p, X, q) + Y = allocate_result(M, vector_transport_to, X, p) + return vector_transport_to_project!(M, Y, p, X, q) +end +function vector_transport_to_project!(M::AbstractManifold, Y, p, X, q) + return project!(M, Y, q, X) end + """ vector_transport_to!(M::AbstractManifold, Y, p, X, q) - vector_transport_to!(M::AbstractManifold, Y, p, X, q, method::AbstractVectorTransportMethod) + vector_transport_to!(M::AbstractManifold, Y, p, X, q, m::AbstractVectorTransportMethod) Transport a vector `X` from the tangent space at a point `p` on the [`AbstractManifold`](@ref) `M` -along the [`shortest_geodesic`](@ref) to the tangent space at another point `q`. -By default, the [`AbstractVectorTransportMethod`](@ref) `method` is -[`default_vector_transport_method`](@ref)`(M)`. The result is saved to `Y`. -""" -function vector_transport_to!(M::AbstractManifold, Y, p, X, q) - return vector_transport_to!(M, Y, p, X, q, default_vector_transport_method(M)) -end -""" - vector_transport_to!(M::AbstractManifold, Y, p, X, q, method::ProjectionTransport) +to `q` using the [`AbstractVectorTransportMethod`](@ref) `m` and the [`AbstractRetractionMethod`](@ref) `r`. -Transport a vector `X` from the tangent space at `p` on the [`AbstractManifold`](@ref) `M` by -interpreting it as an element of the embedding and then projecting it onto the tangent space -at `q`. This function needs to be separately implemented for each manifold because -projection [`project`](@ref project(M::AbstractManifold, p, X)) may also change vector -representation (if it's different than in the embedding) and it is assumed that the vector -`X` already has the correct representation for `M`. +The result is computed in `Y`. +See [`vector_transport_to`](@ref) for more details. """ -vector_transport_to!(M::AbstractManifold, Y, p, X, q, m::ProjectionTransport) - -@doc raw""" - vector_transport_to!(M::AbstractManifold, Y, p, X, q, method::PoleLadderTransport) +function vector_transport_to!( + M::AbstractManifold, + Y, + p, + X, + q, + m::AbstractVectorTransportMethod = default_vector_transport_method(M), +) + return _vector_transport_to!(M, Y, p, X, q, m) +end +function _vector_transport_to!(M::AbstractManifold, Y, p, X, q, m::VectorTransportTo) + d = inverse_retract(M, p, q, m.inverse_retraction) + return vector_transport_direction!(M, Y, p, X, d, m.vector_transport) +end +function _vector_transport_to!(M::AbstractManifold, Y, p, X, q, ::ParallelTransport) + return parallel_transport_to!(M, Y, p, X, q) +end +function _vector_transport_to!( + M::AbstractManifold, + Y, + p, + X, + q, + m::DifferentiatedRetractionVectorTransport, +) + return vector_transport_to_diff!(M, Y, p, X, q, m.retraction) +end +function vector_transport_to_diff! end +function _vector_transport_to!(M::AbstractManifold, Y, p, X, q, ::ProjectionTransport) + return vector_transport_to_project!(M, Y, p, X, q) +end +function vector_transport_to_project! end -Perform a vector transport by using [`PoleLadderTransport`](@ref). -""" -function vector_transport_to!(M::AbstractManifold, Y, p, X, q, m::PoleLadderTransport) +function _vector_transport_to(M::AbstractManifold, p, X, q, m::PoleLadderTransport) + Y = allocate_result(M, vector_transport_to, X, p) + return _vector_transport_to!(M, Y, p, X, q, m) +end +function _vector_transport_to!(M::AbstractManifold, Y, p, X, q, m::PoleLadderTransport) inverse_retract!( M, Y, @@ -764,25 +992,20 @@ function vector_transport_to!(M::AbstractManifold, Y, p, X, q, m::PoleLadderTran return Y end -function vector_transport_to!( - M::AbstractManifold, - Y, - p, - X, - q, - m::ScaledVectorTransport{T}, -) where {T<:AbstractVectorTransportMethod} +function _vector_transport_to(M::AbstractManifold, p, X, c, m::ScaledVectorTransport) + Y = allocate_result(M, vector_transport_to, X, p) + return _vector_transport_to!(M, Y, p, X, c, m) +end +function _vector_transport_to!(M::AbstractManifold, Y, p, X, q, m::ScaledVectorTransport) vector_transport_to!(M, Y, p, X, q, m.method) Y .*= norm(M, p, X) / norm(M, q, Y) return Y end - -@doc raw""" - vector_transport_to!(M::AbstractManifold, Y, p, X, q, method::SchildsLadderTransport) - -Perform a vector transport by using [`SchildsLadderTransport`](@ref). -""" -function vector_transport_to!(M::AbstractManifold, Y, p, X, q, m::SchildsLadderTransport) +function _vector_transport_to(M::AbstractManifold, p, X, c, m::SchildsLadderTransport) + Y = allocate_result(M, vector_transport_to, X, p) + return _vector_transport_to!(M, Y, p, X, c, m) +end +function _vector_transport_to!(M::AbstractManifold, Y, p, X, q, m::SchildsLadderTransport) return inverse_retract!( M, Y, @@ -798,28 +1021,3 @@ function vector_transport_to!(M::AbstractManifold, Y, p, X, q, m::SchildsLadderT m.inverse_retraction, ) end - - -function vector_transport_to!( - M::AbstractManifold, - Y, - p, - X, - q, - method::AbstractVectorTransportMethod, -) - return error( - manifold_function_not_implemented_message( - M, - vector_transport_to!, - Y, - p, - X, - q, - method, - ), - ) -end - -const VECTOR_TRANSPORT_DISAMBIGUATION = - [PoleLadderTransport, ScaledVectorTransport, SchildsLadderTransport] diff --git a/test/bases.jl b/test/bases.jl index c477db70..a8a6a468 100644 --- a/test/bases.jl +++ b/test/bases.jl @@ -1,51 +1,25 @@ using LinearAlgebra using ManifoldsBase -using ManifoldsBase: DefaultManifold, ℝ, ℂ +using ManifoldsBase: DefaultManifold, ℝ, ℂ, RealNumbers, ComplexNumbers using ManifoldsBase: CotangentSpace, CotangentSpaceType, TangentSpace, TangentSpaceType using ManifoldsBase: FVector using Test import Base: +, -, *, copyto!, isapprox -import ManifoldsBase: allocate +import ManifoldsBase: + allocate, + get_vector_orthonormal!, + get_coordinates_orthonormal!, + get_basis_orthogonal, + get_basis_orthonormal struct ProjManifold <: AbstractManifold{ℝ} end -ManifoldsBase.inner(::ProjManifold, x, w, v) = dot(w, v) -ManifoldsBase.project!(::ProjManifold, w, x, v) = (w .= v .- dot(x, v) .* x) +ManifoldsBase.inner(::ProjManifold, p, X, Y) = dot(X, Y) +ManifoldsBase.project!(::ProjManifold, Y, p, X) = (Y .= X .- dot(p, X) .* p) ManifoldsBase.representation_size(::ProjManifold) = (2, 3) ManifoldsBase.manifold_dimension(::ProjManifold) = 5 -ManifoldsBase.get_vector(::ProjManifold, x, v, ::DefaultOrthonormalBasis) = reverse(v) - -@testset "Dispatch" begin - @test ManifoldsBase.decorator_transparent_dispatch( - get_basis, - DefaultManifold(3), - [0.0, 0.0, 0.0], - DefaultBasis(), - ) === Val(:parent) - @test ManifoldsBase.decorator_transparent_dispatch( - get_coordinates, - DefaultManifold(3), - [0.0, 0.0, 0.0], - ) === Val(:parent) - @test ManifoldsBase.decorator_transparent_dispatch( - get_coordinates!, - DefaultManifold(3), - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - ) === Val(:transparent) - @test ManifoldsBase.decorator_transparent_dispatch( - get_vector, - DefaultManifold(3), - [0.0, 0.0, 0.0], - ) === Val(:parent) - @test ManifoldsBase.decorator_transparent_dispatch( - get_vector!, - DefaultManifold(3), - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - ) === Val(:transparent) -end +ManifoldsBase.get_vector_orthonormal(::ProjManifold, p, X, N) = reverse(X) struct ProjectionTestManifold <: AbstractManifold{ℝ} end @@ -57,75 +31,6 @@ function ManifoldsBase.project!(::ProjectionTestManifold, Y, p, X) end ManifoldsBase.manifold_dimension(::ProjectionTestManifold) = 100 -@testset "Projected and arbitrary orthonormal basis" begin - M = ProjManifold() - x = [ - sqrt(2)/2 0.0 0.0 - 0.0 sqrt(2)/2 0.0 - ] - - for pB in (ProjectedOrthonormalBasis(:svd), ProjectedOrthonormalBasis(:gram_schmidt)) - if pB isa ProjectedOrthonormalBasis{:gram_schmidt,ℝ} - pb = get_basis( - M, - x, - pB; - warn_linearly_dependent = true, - skip_linearly_dependent = true, - ) # skip V4, wich is -V1 after proj. - @test_throws ErrorException get_basis(M, x, pB) # error - else - pb = get_basis(M, x, pB) # skips V4 automatically - end - @test number_system(pb) == ℝ - N = manifold_dimension(M) - @test isa(pb, CachedBasis) - @test CachedBasis(pb) === pb - @test length(get_vectors(M, x, pb)) == N - # test orthonormality - for i in 1:N - @test norm(M, x, get_vectors(M, x, pb)[i]) ≈ 1 - for j in (i + 1):N - @test inner(M, x, get_vectors(M, x, pb)[i], get_vectors(M, x, pb)[j]) ≈ 0 atol = - 1e-15 - end - end - end - aonb = get_basis(M, x, DefaultOrthonormalBasis()) - @test size(get_vectors(M, x, aonb)) == (5,) - @test get_vectors(M, x, aonb)[1] ≈ [0, 0, 0, 0, 1] - - @testset "Gram-Schmidt" begin - # for a basis - M = ManifoldsBase.DefaultManifold(3) - p = zeros(3) - V = [[2.0, 0.0, 0.0], [1.1, 2.2, 0.0], [0.0, 3.3, 4.4]] - b1 = ManifoldsBase.gram_schmidt(M, p, V) - b2 = ManifoldsBase.gram_schmidt(M, zeros(3), CachedBasis(DefaultBasis(), V)) - @test b1 == get_vectors(M, p, b2) - # projected gram schmidt - tm = ProjectionTestManifold() - bt = ProjectedOrthonormalBasis(:gram_schmidt) - p = [sqrt(2) / 2, 0.0, sqrt(2) / 2, 0.0, 0.0] - @test_throws ErrorException get_basis(tm, p, bt) - b = get_basis( - tm, - p, - bt; - return_incomplete_set = true, - skip_linearly_dependent = true, #skips 3 and 5 - ) - @test length(get_vectors(tm, p, b)) == 3 - @test_throws ErrorException ManifoldsBase.gram_schmidt(M, p, [V[1]]) - @test_throws ErrorException ManifoldsBase.gram_schmidt( - M, - p, - [V[1], V[1], V[1]]; - skip_linearly_dependent = true, - ) - end -end - struct NonManifold <: AbstractManifold{ℝ} end struct NonBasis <: ManifoldsBase.AbstractBasis{ℝ,TangentSpaceType} end @@ -172,10 +77,10 @@ function ManifoldsBase.exp!( return copyto!(y, x + v) end -function ManifoldsBase.get_basis( +function ManifoldsBase.get_basis_orthonormal( ::DefaultManifold, p::NonBroadcastBasisThing, - B::DefaultOrthonormalBasis{ℝ,TangentSpaceType}, + 𝔽, ) return CachedBasis( B, @@ -185,26 +90,22 @@ function ManifoldsBase.get_basis( ], ) end -function ManifoldsBase.get_basis( - ::DefaultManifold, - p::NonBroadcastBasisThing, - B::DefaultOrthogonalBasis{ℝ,TangentSpaceType}, -) +function ManifoldsBase.get_basis_orthogonal(::DefaultManifold, p::NonBroadcastBasisThing, 𝔽) return CachedBasis( - B, + DefaultOrthogonalBasis(𝔽), [ NonBroadcastBasisThing(ManifoldsBase._euclidean_basis_vector(p.v, i)) for i in eachindex(p.v) ], ) end -function ManifoldsBase.get_basis( +function ManifoldsBase.get_basis_default( ::DefaultManifold, p::NonBroadcastBasisThing, - B::DefaultBasis{ℝ,TangentSpaceType}, + N::ManifoldsBase.RealNumbers, ) return CachedBasis( - B, + DefaultBasis(N), [ NonBroadcastBasisThing(ManifoldsBase._euclidean_basis_vector(p.v, i)) for i in eachindex(p.v) @@ -212,23 +113,23 @@ function ManifoldsBase.get_basis( ) end -function ManifoldsBase.get_coordinates!( +function ManifoldsBase.get_coordinates_orthonormal!( M::DefaultManifold, Y, ::NonBroadcastBasisThing, X::NonBroadcastBasisThing, - ::DefaultOrthonormalBasis{ℝ,TangentSpaceType}, + ::RealNumbers, ) copyto!(Y, reshape(X.v, manifold_dimension(M))) return Y end -function ManifoldsBase.get_vector!( +function ManifoldsBase.get_vector_orthonormal!( M::DefaultManifold, Y::NonBroadcastBasisThing, ::NonBroadcastBasisThing, X, - ::DefaultOrthonormalBasis{ℝ,TangentSpaceType}, + ::RealNumbers, ) copyto!(Y.v, reshape(X, representation_size(M))) return Y @@ -247,296 +148,410 @@ ManifoldsBase._get_vector_cache_broadcast(::NonBroadcastBasisThing) = Val(false) DiagonalizingBasisProxy() = DiagonalizingOrthonormalBasis([1.0, 0.0, 0.0]) -@testset "ManifoldsBase.jl stuff" begin +@testset "Bases" begin + @testset "Projected and arbitrary orthonormal basis" begin + M = ProjManifold() + x = [ + sqrt(2)/2 0.0 0.0 + 0.0 sqrt(2)/2 0.0 + ] + + for pB in + (ProjectedOrthonormalBasis(:svd), ProjectedOrthonormalBasis(:gram_schmidt)) + if pB isa ProjectedOrthonormalBasis{:gram_schmidt,ℝ} + pb = get_basis( + M, + x, + pB; + warn_linearly_dependent = true, + skip_linearly_dependent = true, + ) # skip V4, wich is -V1 after proj. + @test_throws ErrorException get_basis(M, x, pB) # error + else + pb = get_basis(M, x, pB) # skips V4 automatically + end + @test number_system(pb) == ℝ + N = manifold_dimension(M) + @test isa(pb, CachedBasis) + @test CachedBasis(pb) === pb + @test length(get_vectors(M, x, pb)) == N + # test orthonormality + for i in 1:N + @test norm(M, x, get_vectors(M, x, pb)[i]) ≈ 1 + for j in (i + 1):N + @test inner(M, x, get_vectors(M, x, pb)[i], get_vectors(M, x, pb)[j]) ≈ + 0 atol = 1e-15 + end + end + end + aonb = get_basis(M, x, DefaultOrthonormalBasis()) + @test size(get_vectors(M, x, aonb)) == (5,) + @test get_vectors(M, x, aonb)[1] ≈ [0, 0, 0, 0, 1] + + @testset "Gram-Schmidt" begin + # for a basis + M = ManifoldsBase.DefaultManifold(3) + p = zeros(3) + V = [[2.0, 0.0, 0.0], [1.1, 2.2, 0.0], [0.0, 3.3, 4.4]] + b1 = ManifoldsBase.gram_schmidt(M, p, V) + b2 = ManifoldsBase.gram_schmidt(M, zeros(3), CachedBasis(DefaultBasis(), V)) + @test b1 == get_vectors(M, p, b2) + # projected gram schmidt + tm = ProjectionTestManifold() + bt = ProjectedOrthonormalBasis(:gram_schmidt) + p = [sqrt(2) / 2, 0.0, sqrt(2) / 2, 0.0, 0.0] + @test_throws ErrorException get_basis(tm, p, bt) + b = get_basis( + tm, + p, + bt; + return_incomplete_set = true, + skip_linearly_dependent = true, #skips 3 and 5 + ) + @test length(get_vectors(tm, p, b)) == 3 + @test_throws ErrorException ManifoldsBase.gram_schmidt(M, p, [V[1]]) + @test_throws ErrorException ManifoldsBase.gram_schmidt( + M, + p, + [V[1], V[1], V[1]]; + skip_linearly_dependent = true, + ) + end + end - @testset "Errors" begin - m = NonManifold() - onb = DefaultOrthonormalBasis() + @testset "ManifoldsBase.jl stuff" begin - @test_throws ErrorException get_basis(m, [0], onb) - @test_throws MethodError get_basis(m, [0], NonBasis()) - @test_throws ErrorException get_coordinates(m, [0], [0], onb) - @test_throws ErrorException get_coordinates!(m, [0], [0], [0], onb) - @test_throws ErrorException get_vector(m, [0], [0], onb) - @test_throws ErrorException get_vector!(m, [0], [0], [0], onb) - @test_throws ErrorException get_vectors(m, [0], NonBasis()) - end + @testset "Errors" begin + m = NonManifold() + onb = DefaultOrthonormalBasis() - M = DefaultManifold(3) + @test_throws MethodError get_basis(m, [0], onb) + @test_throws MethodError get_basis(m, [0], NonBasis()) + @test_throws MethodError get_coordinates(m, [0], [0], onb) + @test_throws MethodError get_coordinates!(m, [0], [0], [0], onb) + @test_throws MethodError get_vector(m, [0], [0], onb) + @test_throws MethodError get_vector!(m, [0], [0], [0], onb) + @test_throws MethodError get_vectors(m, [0], NonBasis()) + end - @testset "Constructors" begin - @test DefaultBasis{ℂ,TangentSpaceType}() === DefaultBasis(ℂ) - @test DefaultOrthogonalBasis{ℂ,TangentSpaceType}() === DefaultOrthogonalBasis(ℂ) - @test DefaultOrthonormalBasis{ℂ,TangentSpaceType}() === DefaultOrthonormalBasis(ℂ) + M = DefaultManifold(3) - @test DefaultBasis{ℂ}(CotangentSpace) === DefaultBasis(ℂ, CotangentSpace) - @test DefaultOrthogonalBasis{ℂ}(CotangentSpace) === - DefaultOrthogonalBasis(ℂ, CotangentSpace) - @test DefaultOrthonormalBasis{ℂ}(CotangentSpace) === - DefaultOrthonormalBasis(ℂ, CotangentSpace) - end + @testset "Constructors" begin + @test DefaultBasis{ℂ,TangentSpaceType}() === DefaultBasis(ℂ) + @test DefaultOrthogonalBasis{ℂ,TangentSpaceType}() === DefaultOrthogonalBasis(ℂ) + @test DefaultOrthonormalBasis{ℂ,TangentSpaceType}() === + DefaultOrthonormalBasis(ℂ) - _pts = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]] - @testset "basis representation" for BT in ( - DefaultBasis, - DefaultOrthonormalBasis, - DefaultOrthogonalBasis, - DiagonalizingBasisProxy, - ), - pts in (_pts, map(NonBroadcastBasisThing, _pts)) - - if BT == DiagonalizingBasisProxy && pts !== _pts - continue + @test DefaultBasis{ℂ}(CotangentSpace) === DefaultBasis(ℂ, CotangentSpace) + @test DefaultOrthogonalBasis{ℂ}(CotangentSpace) === + DefaultOrthogonalBasis(ℂ, CotangentSpace) + @test DefaultOrthonormalBasis{ℂ}(CotangentSpace) === + DefaultOrthonormalBasis(ℂ, CotangentSpace) end - v1 = log(M, pts[1], pts[2]) - @test ManifoldsBase.number_of_coordinates(M, BT()) == 3 - if BT != DiagonalizingBasisProxy + _pts = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]] + @testset "basis representation" for BT in ( + DefaultBasis, + DefaultOrthonormalBasis, + DefaultOrthogonalBasis, + DiagonalizingBasisProxy, + ), + pts in (_pts, map(NonBroadcastBasisThing, _pts)) + + if BT == DiagonalizingBasisProxy && pts !== _pts + continue + end + v1 = log(M, pts[1], pts[2]) + @test ManifoldsBase.number_of_coordinates(M, BT()) == 3 + vb = get_coordinates(M, pts[1], v1, BT()) @test isa(vb, AbstractVector) vbi = get_vector(M, pts[1], vb, BT()) @test isapprox(M, pts[1], v1, vbi) - end - b = get_basis(M, pts[1], BT()) - if BT != DiagonalizingBasisProxy - if pts[1] isa Array - @test isa(b, CachedBasis{ℝ,BT{ℝ,TangentSpaceType},Vector{Vector{Float64}}}) - else - @test isa( - b, - CachedBasis{ - ℝ, - BT{ℝ,TangentSpaceType}, - Vector{NonBroadcastBasisThing{Vector{Float64}}}, - }, - ) + b = get_basis(M, pts[1], BT()) + if BT != DiagonalizingBasisProxy + if pts[1] isa Array + @test isa( + b, + CachedBasis{ℝ,BT{ℝ,TangentSpaceType},Vector{Vector{Float64}}}, + ) + else + @test isa( + b, + CachedBasis{ + ℝ, + BT{ℝ,TangentSpaceType}, + Vector{NonBroadcastBasisThing{Vector{Float64}}}, + }, + ) + end end - end - @test get_basis(M, pts[1], b) === b - N = manifold_dimension(M) - @test length(get_vectors(M, pts[1], b)) == N - # check orthonormality - if BT isa DefaultOrthonormalBasis && pts[1] isa Vector - for i in 1:N - @test norm(M, pts[1], get_vectors(M, pts[1], b)[i]) ≈ 1 - for j in (i + 1):N - @test inner( - M, - pts[1], - get_vectors(M, pts[1], b)[i], - get_vectors(M, pts[1], b)[j], - ) ≈ 0 + @test get_basis(M, pts[1], b) === b + N = manifold_dimension(M) + @test length(get_vectors(M, pts[1], b)) == N + # check orthonormality + if BT isa DefaultOrthonormalBasis && pts[1] isa Vector + for i in 1:N + @test norm(M, pts[1], get_vectors(M, pts[1], b)[i]) ≈ 1 + for j in (i + 1):N + @test inner( + M, + pts[1], + get_vectors(M, pts[1], b)[i], + get_vectors(M, pts[1], b)[j], + ) ≈ 0 + end + end + # check that the coefficients correspond to the basis + for i in 1:N + @test inner(M, pts[1], v1, get_vectors(M, pts[1], b)[i]) ≈ vb[i] end end - # check that the coefficients correspond to the basis - for i in 1:N - @test inner(M, pts[1], v1, get_vectors(M, pts[1], b)[i]) ≈ vb[i] + + if BT != DiagonalizingBasisProxy + @test get_coordinates(M, pts[1], v1, b) ≈ + get_coordinates(M, pts[1], v1, BT()) + @test get_vector(M, pts[1], vb, b) ≈ get_vector(M, pts[1], vb, BT()) end - end - if BT != DiagonalizingBasisProxy - @test get_coordinates(M, pts[1], v1, b) ≈ get_coordinates(M, pts[1], v1, BT()) - @test get_vector(M, pts[1], vb, b) ≈ get_vector(M, pts[1], vb, BT()) - end + v1c = Vector{Float64}(undef, 3) + get_coordinates!(M, v1c, pts[1], v1, b) + @test v1c ≈ get_coordinates(M, pts[1], v1, b) - v1c = Vector{Float64}(undef, 3) - get_coordinates!(M, v1c, pts[1], v1, b) - @test v1c ≈ get_coordinates(M, pts[1], v1, b) + v1cv = allocate(v1) + get_vector!(M, v1cv, pts[1], v1c, b) + @test isapprox(M, pts[1], v1, v1cv) + end + @testset "() Manifolds" begin + M = ManifoldsBase.DefaultManifold() + ManifoldsBase.allocate_coordinates(M, 1, Float64, 0) == 0.0 + ManifoldsBase.allocate_coordinates(M, 1, Float64, 1) == zeros(Float64, 1) + ManifoldsBase.allocate_coordinates(M, 1, Float64, 2) == zeros(Float64, 2) + end + end - v1cv = allocate(v1) - get_vector!(M, v1cv, pts[1], v1c, b) - @test isapprox(M, pts[1], v1, v1cv) + @testset "Complex DefaultManifold with real and complex Cached Bases" begin + M = ManifoldsBase.DefaultManifold(3; field = ℂ) + p = [1.0, 2.0im, 3.0] + X = [1.2, 2.2im, 2.3im] + b = [Matrix{Float64}(I, 3, 3)[:, i] for i in 1:3] + Bℝ = CachedBasis(DefaultOrthonormalBasis{ℝ}(), b) + aℝ = get_coordinates(M, p, X, Bℝ) + Yℝ = get_vector(M, p, aℝ, Bℝ) + @test Yℝ ≈ X + @test ManifoldsBase.number_of_coordinates(M, Bℝ) == 3 + + bℂ = [b..., (b .* 1im)...] + Bℂ = CachedBasis(DefaultOrthonormalBasis{ℂ}(), bℂ) + aℂ = get_coordinates(M, p, X, Bℂ) + Yℂ = get_vector(M, p, aℂ, Bℂ) + @test Yℂ ≈ X + @test ManifoldsBase.number_of_coordinates(M, Bℂ) == 6 end -end -@testset "Complex DeaultManifold with real and complex Cached Bases" begin - M = ManifoldsBase.DefaultManifold(3; field = ℂ) - p = [1.0, 2.0im, 3.0] - X = [1.2, 2.2im, 2.3im] - b = [Matrix{Float64}(I, 3, 3)[:, i] for i in 1:3] - Bℝ = CachedBasis(DefaultOrthonormalBasis{ℝ}(), b) - aℝ = get_coordinates(M, p, X, Bℝ) - Yℝ = get_vector(M, p, aℝ, Bℝ) - @test Yℝ ≈ X - @test ManifoldsBase.number_of_coordinates(M, Bℝ) == 3 - - bℂ = [b..., (b .* 1im)...] - Bℂ = CachedBasis(DefaultOrthonormalBasis{ℂ}(), bℂ) - aℂ = get_coordinates(M, p, X, Bℂ) - Yℂ = get_vector(M, p, aℂ, Bℂ) - @test Yℂ ≈ X - @test ManifoldsBase.number_of_coordinates(M, Bℂ) == 6 -end + @testset "Basis show methods" begin + @test sprint(show, DefaultBasis()) == "DefaultBasis(ℝ)" + @test sprint(show, DefaultOrthogonalBasis()) == "DefaultOrthogonalBasis(ℝ)" + @test sprint(show, DefaultOrthonormalBasis()) == "DefaultOrthonormalBasis(ℝ)" + @test sprint(show, DefaultOrthonormalBasis(ℂ)) == "DefaultOrthonormalBasis(ℂ)" + @test sprint(show, GramSchmidtOrthonormalBasis(ℂ)) == + "GramSchmidtOrthonormalBasis(ℂ)" + @test sprint(show, ProjectedOrthonormalBasis(:svd)) == + "ProjectedOrthonormalBasis(:svd, ℝ)" + @test sprint(show, ProjectedOrthonormalBasis(:gram_schmidt, ℂ)) == + "ProjectedOrthonormalBasis(:gram_schmidt, ℂ)" + + diag_onb = DiagonalizingOrthonormalBasis(Float64[1, 2, 3]) + @test sprint(show, "text/plain", diag_onb) == """ + DiagonalizingOrthonormalBasis(ℝ) with eigenvalue 0 in direction: + 3-element $(sprint(show, Vector{Float64})): + 1.0 + 2.0 + 3.0""" + + M = DefaultManifold(2, 3) + x = collect(reshape(1.0:6.0, (2, 3))) + pb = get_basis(M, x, DefaultOrthonormalBasis()) + B2 = DefaultOrthonormalBasis(ManifoldsBase.ℝ, ManifoldsBase.CotangentSpace) + pb2 = get_basis(M, x, B2) + + test_basis_string = """ + DefaultOrthonormalBasis(ℝ) with 6 basis vectors: + E1 = + 2×3 $(sprint(show, Matrix{Float64})): + 1.0 0.0 0.0 + 0.0 0.0 0.0 + E2 = + 2×3 $(sprint(show, Matrix{Float64})): + 0.0 0.0 0.0 + 1.0 0.0 0.0 + ⋮ + E5 = + 2×3 $(sprint(show, Matrix{Float64})): + 0.0 0.0 1.0 + 0.0 0.0 0.0 + E6 = + 2×3 $(sprint(show, Matrix{Float64})): + 0.0 0.0 0.0 + 0.0 0.0 1.0""" + + @test sprint(show, "text/plain", pb) == test_basis_string + @test sprint(show, "text/plain", pb2) == test_basis_string + + b = DiagonalizingOrthonormalBasis(get_vectors(M, x, pb)[1]) + dpb = CachedBasis(b, Float64[1, 2, 3, 4, 5, 6], get_vectors(M, x, pb)) + @test sprint(show, "text/plain", dpb) == """ + DiagonalizingOrthonormalBasis(ℝ) with eigenvalue 0 in direction: + 2×3 $(sprint(show, Matrix{Float64})): + 1.0 0.0 0.0 + 0.0 0.0 0.0 + and 6 basis vectors. + Basis vectors: + E1 = + 2×3 $(sprint(show, Matrix{Float64})): + 1.0 0.0 0.0 + 0.0 0.0 0.0 + E2 = + 2×3 $(sprint(show, Matrix{Float64})): + 0.0 0.0 0.0 + 1.0 0.0 0.0 + ⋮ + E5 = + 2×3 $(sprint(show, Matrix{Float64})): + 0.0 0.0 1.0 + 0.0 0.0 0.0 + E6 = + 2×3 $(sprint(show, Matrix{Float64})): + 0.0 0.0 0.0 + 0.0 0.0 1.0 + Eigenvalues: + 6-element $(sprint(show, Vector{Float64})): + 1.0 + 2.0 + 3.0 + 4.0 + 5.0 + 6.0""" + + M = DefaultManifold(1, 1, 1) + x = reshape(Float64[1], (1, 1, 1)) + pb = get_basis(M, x, DefaultOrthonormalBasis()) + @test sprint(show, "text/plain", pb) == """ + DefaultOrthonormalBasis(ℝ) with 1 basis vector: + E1 = + 1×1×1 $(sprint(show, Array{Float64,3})): + [:, :, 1] = + 1.0""" + + dpb = CachedBasis( + DiagonalizingOrthonormalBasis(get_vectors(M, x, pb)), + Float64[1], + get_vectors(M, x, pb), + ) -@testset "Basis show methods" begin - @test sprint(show, DefaultBasis()) == "DefaultBasis(ℝ)" - @test sprint(show, DefaultOrthogonalBasis()) == "DefaultOrthogonalBasis(ℝ)" - @test sprint(show, DefaultOrthonormalBasis()) == "DefaultOrthonormalBasis(ℝ)" - @test sprint(show, DefaultOrthonormalBasis(ℂ)) == "DefaultOrthonormalBasis(ℂ)" - @test sprint(show, GramSchmidtOrthonormalBasis(ℂ)) == "GramSchmidtOrthonormalBasis(ℂ)" - @test sprint(show, ProjectedOrthonormalBasis(:svd)) == - "ProjectedOrthonormalBasis(:svd, ℝ)" - @test sprint(show, ProjectedOrthonormalBasis(:gram_schmidt, ℂ)) == - "ProjectedOrthonormalBasis(:gram_schmidt, ℂ)" - - diag_onb = DiagonalizingOrthonormalBasis(Float64[1, 2, 3]) - @test sprint(show, "text/plain", diag_onb) == """ - DiagonalizingOrthonormalBasis(ℝ) with eigenvalue 0 in direction: - 3-element $(sprint(show, Vector{Float64})): - 1.0 - 2.0 - 3.0""" - - M = DefaultManifold(2, 3) - x = collect(reshape(1.0:6.0, (2, 3))) - pb = get_basis(M, x, DefaultOrthonormalBasis()) - - @test sprint(show, "text/plain", pb) == """ - DefaultOrthonormalBasis(ℝ) with 6 basis vectors: - E1 = - 2×3 $(sprint(show, Matrix{Float64})): - 1.0 0.0 0.0 - 0.0 0.0 0.0 - E2 = - 2×3 $(sprint(show, Matrix{Float64})): - 0.0 0.0 0.0 - 1.0 0.0 0.0 - ⋮ - E5 = - 2×3 $(sprint(show, Matrix{Float64})): - 0.0 0.0 1.0 - 0.0 0.0 0.0 - E6 = - 2×3 $(sprint(show, Matrix{Float64})): - 0.0 0.0 0.0 - 0.0 0.0 1.0""" - - b = DiagonalizingOrthonormalBasis(get_vectors(M, x, pb)[1]) - dpb = CachedBasis(b, Float64[1, 2, 3, 4, 5, 6], get_vectors(M, x, pb)) - @test sprint(show, "text/plain", dpb) == """ - DiagonalizingOrthonormalBasis(ℝ) with eigenvalue 0 in direction: - 2×3 $(sprint(show, Matrix{Float64})): - 1.0 0.0 0.0 - 0.0 0.0 0.0 - and 6 basis vectors. - Basis vectors: - E1 = - 2×3 $(sprint(show, Matrix{Float64})): - 1.0 0.0 0.0 - 0.0 0.0 0.0 - E2 = - 2×3 $(sprint(show, Matrix{Float64})): - 0.0 0.0 0.0 - 1.0 0.0 0.0 - ⋮ - E5 = - 2×3 $(sprint(show, Matrix{Float64})): - 0.0 0.0 1.0 - 0.0 0.0 0.0 - E6 = - 2×3 $(sprint(show, Matrix{Float64})): - 0.0 0.0 0.0 - 0.0 0.0 1.0 - Eigenvalues: - 6-element $(sprint(show, Vector{Float64})): - 1.0 - 2.0 - 3.0 - 4.0 - 5.0 - 6.0""" - - M = DefaultManifold(1, 1, 1) - x = reshape(Float64[1], (1, 1, 1)) - pb = get_basis(M, x, DefaultOrthonormalBasis()) - @test sprint(show, "text/plain", pb) == """ - DefaultOrthonormalBasis(ℝ) with 1 basis vector: - E1 = - 1×1×1 $(sprint(show, Array{Float64,3})): - [:, :, 1] = - 1.0""" - - dpb = CachedBasis( - DiagonalizingOrthonormalBasis(get_vectors(M, x, pb)), - Float64[1], - get_vectors(M, x, pb), - ) + @test sprint(show, "text/plain", dpb) == """ + DiagonalizingOrthonormalBasis(ℝ) with eigenvalue 0 in direction: + 1-element $(sprint(show, Vector{Array{Float64,3}})): + $(sprint(show, dpb.data.frame_direction[1])) + and 1 basis vector. + Basis vectors: + E1 = + 1×1×1 $(sprint(show, Array{Float64,3})): + [:, :, 1] = + 1.0 + Eigenvalues: + 1-element $(sprint(show, Vector{Float64})): + 1.0""" + end - @test sprint(show, "text/plain", dpb) == """ - DiagonalizingOrthonormalBasis(ℝ) with eigenvalue 0 in direction: - 1-element $(sprint(show, Vector{Array{Float64,3}})): - $(sprint(show, dpb.data.frame_direction[1])) - and 1 basis vector. - Basis vectors: - E1 = - 1×1×1 $(sprint(show, Array{Float64,3})): - [:, :, 1] = - 1.0 - Eigenvalues: - 1-element $(sprint(show, Vector{Float64})): - 1.0""" -end + @testset "Bases of cotangent spaces" begin + b1 = DefaultOrthonormalBasis(ℝ, CotangentSpace) + @test b1.vector_space == CotangentSpace -@testset "Bases of cotangent spaces" begin - b1 = DefaultOrthonormalBasis(ℝ, CotangentSpace) - @test b1.vector_space == CotangentSpace + b2 = DefaultOrthogonalBasis(ℝ, CotangentSpace) + @test b2.vector_space == CotangentSpace - b2 = DefaultOrthogonalBasis(ℝ, CotangentSpace) - @test b2.vector_space == CotangentSpace + b3 = DefaultBasis(ℝ, CotangentSpace) + @test b3.vector_space == CotangentSpace - b3 = DefaultBasis(ℝ, CotangentSpace) - @test b3.vector_space == CotangentSpace + M = DefaultManifold(2; field = ℂ) + p = [1.0, 2.0im] + b1_d = ManifoldsBase.dual_basis(M, p, b1) + @test b1_d isa DefaultOrthonormalBasis + @test b1_d.vector_space == TangentSpace - M = DefaultManifold(2; field = ℂ) - p = [1.0, 2.0im] - b1_d = ManifoldsBase.dual_basis(M, p, b1) - @test b1_d isa DefaultOrthonormalBasis - @test b1_d.vector_space == TangentSpace + b1_d_d = ManifoldsBase.dual_basis(M, p, b1_d) + @test b1_d_d isa DefaultOrthonormalBasis + @test b1_d_d.vector_space == CotangentSpace + end - b1_d_d = ManifoldsBase.dual_basis(M, p, b1_d) - @test b1_d_d isa DefaultOrthonormalBasis - @test b1_d_d.vector_space == CotangentSpace -end + @testset "Complex Basis - Mutating cases" begin + Mc = ManifoldsBase.DefaultManifold(2, field = ManifoldsBase.ℂ) + p = [1.0, 1.0im] + X = [2.0, 1.0im] + Bc = DefaultOrthonormalBasis(ManifoldsBase.ℂ) + CBc = get_basis(Mc, p, Bc) + @test CBc.data == [[1.0, 0.0], [0.0, 1.0], [1.0im, 0.0], [0.0, 1.0im]] + B = DefaultOrthonormalBasis(ManifoldsBase.ℝ) + CB = get_basis(Mc, p, B) + @test CB.data == [[1.0, 0.0], [0.0, 1.0]] + @test get_coordinates(Mc, p, X, CBc) == [2.0, 0.0, 0.0, 1.0] + @test get_coordinates(Mc, p, X, CB) == [2.0, 1.0im] + # ONB + cc = zeros(4) + @test get_coordinates!(Mc, cc, p, X, Bc) == [2.0, 0.0, 0.0, 1.0] + @test cc == [2.0, 0.0, 0.0, 1.0] + c = zeros(ComplexF64, 2) + @test get_coordinates!(Mc, c, p, X, B) == [2.0, 1.0im] + @test c == [2.0, 1.0im] + # Cached + @test get_coordinates!(Mc, cc, p, X, CBc) == [2.0, 0.0, 0.0, 1.0] + @test cc == [2.0, 0.0, 0.0, 1.0] + @test get_coordinates!(Mc, c, p, X, CB) == [2.0, 1.0im] + @test c == [2.0, 1.0im] -@testset "FVector" begin - @test sprint(show, TangentSpace) == "TangentSpace" - @test sprint(show, CotangentSpace) == "CotangentSpace" - tvs = ([1.0, 0.0, 0.0], [0.0, 1.0, 0.0]) - fv_tvs = map(v -> TFVector(v, DefaultOrthonormalBasis()), tvs) - fv1 = fv_tvs[1] - tv1s = allocate(fv_tvs[1]) - @test isa(tv1s, FVector) - @test tv1s.type == TangentSpace - @test size(tv1s.data) == size(tvs[1]) - @test number_eltype(tv1s) == number_eltype(tvs[1]) - @test number_eltype(tv1s) == number_eltype(typeof(tv1s)) - @test isa(fv1 + fv1, FVector) - @test (fv1 + fv1).type == TangentSpace - @test isa(fv1 - fv1, FVector) - @test (fv1 - fv1).type == TangentSpace - @test isa(-fv1, FVector) - @test (-fv1).type == TangentSpace - @test isa(2 * fv1, FVector) - @test (2 * fv1).type == TangentSpace - tv1s_32 = allocate(fv_tvs[1], Float32) - @test isa(tv1s, FVector) - @test eltype(tv1s_32.data) === Float32 - copyto!(tv1s, fv_tvs[2]) - @test isapprox(tv1s.data, fv_tvs[2].data) - - @test sprint(show, fv1) == "TFVector([1.0, 0.0, 0.0], $(fv1.basis))" - - cofv1 = CoTFVector(tvs[1], DefaultOrthonormalBasis(ℝ, CotangentSpace)) - @test cofv1 isa CoTFVector - @test sprint(show, cofv1) == "CoTFVector([1.0, 0.0, 0.0], $(fv1.basis))" -end + end -@testset "vector_space_dimension" begin - M = ManifoldsBase.DefaultManifold(3) - MC = ManifoldsBase.DefaultManifold(3; field = ℂ) - @test ManifoldsBase.vector_space_dimension(M, TangentSpace) == 3 - @test ManifoldsBase.vector_space_dimension(M, CotangentSpace) == 3 - @test ManifoldsBase.vector_space_dimension(MC, TangentSpace) == 6 - @test ManifoldsBase.vector_space_dimension(MC, CotangentSpace) == 6 + @testset "FVector" begin + @test sprint(show, TangentSpace) == "TangentSpace" + @test sprint(show, CotangentSpace) == "CotangentSpace" + tvs = ([1.0, 0.0, 0.0], [0.0, 1.0, 0.0]) + fv_tvs = map(v -> TFVector(v, DefaultOrthonormalBasis()), tvs) + fv1 = fv_tvs[1] + tv1s = allocate(fv_tvs[1]) + @test isa(tv1s, FVector) + @test tv1s.type == TangentSpace + @test size(tv1s.data) == size(tvs[1]) + @test number_eltype(tv1s) == number_eltype(tvs[1]) + @test number_eltype(tv1s) == number_eltype(typeof(tv1s)) + @test isa(fv1 + fv1, FVector) + @test (fv1 + fv1).type == TangentSpace + @test isa(fv1 - fv1, FVector) + @test (fv1 - fv1).type == TangentSpace + @test isa(-fv1, FVector) + @test (-fv1).type == TangentSpace + @test isa(2 * fv1, FVector) + @test (2 * fv1).type == TangentSpace + tv1s_32 = allocate(fv_tvs[1], Float32) + @test isa(tv1s, FVector) + @test eltype(tv1s_32.data) === Float32 + copyto!(tv1s, fv_tvs[2]) + @test isapprox(tv1s.data, fv_tvs[2].data) + + @test sprint(show, fv1) == "TFVector([1.0, 0.0, 0.0], $(fv1.basis))" + + cofv1 = CoTFVector(tvs[1], DefaultOrthonormalBasis(ℝ, CotangentSpace)) + @test cofv1 isa CoTFVector + @test sprint(show, cofv1) == "CoTFVector([1.0, 0.0, 0.0], $(fv1.basis))" + end + + @testset "vector_space_dimension" begin + M = ManifoldsBase.DefaultManifold(3) + MC = ManifoldsBase.DefaultManifold(3; field = ℂ) + @test ManifoldsBase.vector_space_dimension(M, TangentSpace) == 3 + @test ManifoldsBase.vector_space_dimension(M, CotangentSpace) == 3 + @test ManifoldsBase.vector_space_dimension(MC, TangentSpace) == 6 + @test ManifoldsBase.vector_space_dimension(MC, CotangentSpace) == 6 + end end diff --git a/test/decorator_manifold.jl b/test/decorator_manifold.jl deleted file mode 100644 index 2f49445e..00000000 --- a/test/decorator_manifold.jl +++ /dev/null @@ -1,260 +0,0 @@ -using ManifoldsBase -using Test -using ManifoldsBase: - @decorator_transparent_function, - @decorator_transparent_fallback, - @decorator_transparent_signature, - is_decorator_transparent -import ManifoldsBase: decorator_transparent_dispatch - -struct TestDecorator{M<:AbstractManifold{ℝ}} <: - AbstractDecoratorManifold{ℝ,DefaultDecoratorType} - manifold::M -end - -abstract type AbstractTestDecorator <: AbstractDecoratorManifold{ℝ,DefaultDecoratorType} end - -struct TestDecorator2{M<:AbstractManifold{ℝ}} <: AbstractTestDecorator - manifold::M -end - -struct TestDecorator3{M<:AbstractManifold{ℝ}} <: AbstractTestDecorator - manifold::M -end - -abstract type AbstractParentDecorator <: AbstractDecoratorManifold{ℝ,DefaultDecoratorType} end - -struct ChildDecorator{M<:AbstractManifold{ℝ}} <: AbstractParentDecorator - manifold::M -end - -struct DefaultDecorator{M<:AbstractManifold{ℝ}} <: - AbstractDecoratorManifold{ℝ,DefaultDecoratorType} - manifold::M -end -ManifoldsBase.default_decorator_dispatch(::DefaultDecorator) = Val(true) - -test1(M::AbstractManifold, p; a = 0) = 101 + a -test2(M::AbstractManifold, p; a = 0) = 102 + a -test3(M::AbstractManifold, p; a = 0) = 103 + a -function test4(M::AbstractManifold, p; a = 0) - return error(ManifoldsBase.manifold_function_not_implemented_message(M, test4, p)) -end - -function test1(M::TestDecorator, p; a = 0) - return 1 + a -end - -function decorator_transparent_dispatch(::typeof(test1), M::TestDecorator, args...) - return Val(:intransparent) -end -function decorator_transparent_dispatch(::typeof(test2), M::TestDecorator, args...) - return Val(:transparent) -end -decorator_transparent_dispatch(::typeof(test3), M::TestDecorator, args...) = Val(:parent) -function decorator_transparent_dispatch(::typeof(test4), M::TestDecorator, args...) - return Val(:intransparent) -end - -@decorator_transparent_function :transparent function test5(M::AbstractDecoratorManifold, p) - return 5 -end - -@decorator_transparent_function @inline function test6(M::TestDecorator, p) - return 6 -end - -@decorator_transparent_function :parent function test7(M::TestDecorator, p) - return 7 -end - -@decorator_transparent_fallback :parent @inline function test7(M::TestDecorator, p) - return 17 -end - -test8(M::AbstractManifold, p; a = 0) = 8 + a - -@decorator_transparent_function :parent function test9( - M::AbstractDecoratorManifold, - p; - a = 0, - kwargs..., -) - return 9 + a + (haskey(kwargs, :b) ? kwargs[:b] : 0) -end - -@decorator_transparent_fallback :parent @inline function test9( - M::AbstractTestDecorator, - p::TP; - a = 0, - kwargs..., -) where {TP} - return 19 + a + (haskey(kwargs, :b) ? kwargs[:b] : 0) -end - -function test9(M::TestDecorator3, p::TP; a = 0, kwargs...) where {TP} - return 109 + a + (haskey(kwargs, :b) ? kwargs[:b] : 0) -end - -test10(M::AbstractTestDecorator, p::TP; a = 0) where {TP} = 10 * a -@decorator_transparent_function function test10(M::TestDecorator3, p::TP; a = 0) where {TP} - return 5 * a -end -# the following then ignores the previous definition and passes again to the parent above -decorator_transparent_dispatch(::typeof(test10), M::TestDecorator3, args...) = Val(:parent) - -@decorator_transparent_function function test11( - M::TestDecorator3, - p::TP; - a::Int = 0, -) where {TP} - return 15 * a -end - -@decorator_transparent_function function test12(M::ManifoldsBase.DefaultManifold, p) - return 12 * p -end -ManifoldsBase._acts_transparently(test12, TestDecorator3, p) = Val(:foo) - -@decorator_transparent_function :none function test13(M::TestDecorator3, p) - return 13.5 * p -end -function decorator_transparent_dispatch(::typeof(test13), M::TestDecorator, args...) - return Val(:intransparent) -end -function decorator_transparent_dispatch(::typeof(test13), M::TestDecorator2, args...) - return Val(:transparent) -end -test13(::ManifoldsBase.DefaultManifold, a) = 13 * a - -function test14(M::AbstractDecoratorManifold, p) - return 14.5 * p -end -@decorator_transparent_signature test14(M::AbstractDecoratorManifold, p) -decorator_transparent_dispatch(::typeof(test14), M::TestDecorator3, args...) = Val(:none) -function decorator_transparent_dispatch(::typeof(test14), M::TestDecorator, args...) - return Val(:intransparent) -end -function decorator_transparent_dispatch(::typeof(test14), M::TestDecorator2, args...) - return Val(:transparent) -end -test14(::ManifoldsBase.DefaultManifold, a) = 14 * a - -test15(::ManifoldsBase.DefaultManifold, a) = 15.5 * a -@decorator_transparent_function function test15(M::AbstractDecoratorManifold, p) - return error("Not yet implemented") -end -test15(::AbstractParentDecorator, p) = 15 * p -decorator_transparent_dispatch(::typeof(test15), M::ChildDecorator, args...) = Val(:parent) - -function test16(::AbstractParentDecorator, p) - return 16 * p -end -test16(::ManifoldsBase.DefaultManifold, a) = 16.5 * a -@decorator_transparent_signature test16(M::AbstractDecoratorManifold, p) -decorator_transparent_dispatch(::typeof(test16), M::ChildDecorator, args...) = Val(:parent) - -function test17(M::ManifoldsBase.DefaultManifold, p) - return 17 * p -end -@decorator_transparent_signature test17(M::AbstractDecoratorManifold, p) -function decorator_transparent_dispatch( - ::typeof(test17), - M::AbstractDecoratorManifold, - args..., -) - return Val(:intransparent) -end -default_decorator_dispatch(::DefaultDecorator) = Val(true) - -@decorator_transparent_function function test18(M::AbstractDecoratorManifold, p) - return 18.25 * p -end -decorator_transparent_dispatch(::typeof(test18), M::ChildDecorator, args...) = Val(:parent) - -@testset "Testing decorator manifold functions" begin - M = ManifoldsBase.DefaultManifold(3) - A = ValidationManifold(M) - - @test (@inferred base_manifold(M)) == M - @test (@inferred base_manifold(A)) == M - @test (@inferred ManifoldsBase.decorated_manifold(A)) == M - @test ManifoldsBase._extract_val(Val(:transparent)) === :transparent - - @test number_system(M) == ℝ - @test number_system(ManifoldsBase.DefaultManifold(3; field = ℂ)) == ℂ - - @test (@inferred base_manifold(M, Val(1))) == M - @test (@inferred base_manifold(M, Val(0))) == M - @test (@inferred base_manifold(A, Val(1))) == M - @test (@inferred base_manifold(A, Val(0))) == A - - x = 0 - @test_throws LoadError eval(:(@decorator_transparent_fallback x = x + 1)) - @test_throws LoadError eval(:(@decorator_transparent_function x = x + 1)) - @test_throws LoadError eval(:(@decorator_transparent_signature x = x + 1)) - - @test representation_size(M) == (3,) - @test representation_size(A) == (3,) - - @test manifold_dimension(M) == 3 - @test manifold_dimension(A) == 3 - - p = [1.0, 0.0, 0.0] - X = [2.0, 1.0, 3.0] - @test inner(A, p, X, X) ≈ ManifoldsBase.inner__transparent(A, p, X, X) - @test_throws ErrorException ManifoldsBase.inner__intransparent(A, p, X, X) - - TD = TestDecorator(M) - - @test (@inferred ManifoldsBase.default_decorator_dispatch(M)) === Val(false) - @test ManifoldsBase.is_default_decorator(M) === false - - @test injectivity_radius(TD, ManifoldsBase.ExponentialRetraction()) == Inf - - @test test1(TD, p) == 1 - @test test1(TD, p; a = 1000) == 1001 - @test test2(TD, p) == 102 - @test test2(TD, p; a = 1000) == 1102 - @test test3(TD, p) == 103 - @test test3(TD, p; a = 1000) == 1103 - @test_throws ErrorException test4(TD, p) - @test_throws ErrorException test4(TD, p; a = 1000) - @test (@inferred decorator_transparent_dispatch(test5, TD, p)) === Val(:transparent) - @test is_decorator_transparent(test5, TD, p) - @test test5(TD, p) == 5 - @test (@inferred decorator_transparent_dispatch(test6, TD, p)) === Val(:intransparent) - @test_throws ErrorException test7(M, p) - @test test7(TD, p) == 17 - @test (@inferred decorator_transparent_dispatch(test8, M, p)) === Val(:transparent) - @test is_decorator_transparent(test8, M, p) - @test_throws ErrorException test9(M, p; a = 1000) - @test test9(TD, p; a = 1000) == 1009 - @test test9(TD, p; a = 1000, b = 10000) == 11009 - @test test9(TestDecorator2(TD), p; a = 1000) == 1019 - @test test9(TestDecorator2(TD), p; a = 1000, b = 10000) == 11019 - @test test9(TestDecorator3(TestDecorator2(TD)), p; a = 1000) == 1109 - @test test9(TestDecorator3(TestDecorator2(TD)), p; a = 1000, b = 10000) == 11109 - @test test9(TestDecorator3(TD), p; a = 1000) == 1109 - @test test9(TestDecorator3(TD), p; a = 1000, b = 10000) == 11109 - @test test10(TestDecorator3(TD), p; a = 11) == 110 - @test test11(TestDecorator3(TD), p; a = 12) == 180 - @test_throws ErrorException test12(TestDecorator3(TD), p) - - @test_throws ErrorException test13(TestDecorator3(M), 1) # :none nonexistent - @test_throws ErrorException test13(TestDecorator(M), 1) # not implemented - @test test13(TestDecorator2(M), 2) == 26 # from parent - - @test_throws ErrorException test14(TestDecorator3(M), 1) # :none nonexistent - @test_throws ErrorException test14(TestDecorator(M), 1) # not implemented - @test test14(TestDecorator2(M), 2) == 28 # from parent - - @test test15(ChildDecorator(M), 1) == 15 - @test test16(ChildDecorator(M), 1) == 16 - - @test_throws ErrorException test17(TestDecorator(ManifoldsBase.DefaultManifold(3)), 1) - @test test17(DefaultDecorator(ManifoldsBase.DefaultManifold(3)), 1) == 17 - - # states that child has to implement at least a parent case - @test_throws ErrorException test18(ChildDecorator(ManifoldsBase.DefaultManifold(3)), 1) -end diff --git a/test/decorator_traits.jl b/test/decorator_traits.jl new file mode 100644 index 00000000..2efdf8ca --- /dev/null +++ b/test/decorator_traits.jl @@ -0,0 +1,141 @@ +using Test +using ManifoldsBase + +using ManifoldsBase: AbstractTrait, TraitList, EmptyTrait, trait, merge_traits +using ManifoldsBase: expand_trait, next_trait +import ManifoldsBase: active_traits, parent_trait + +struct IsCool <: AbstractTrait end +struct IsNice <: AbstractTrait end + +abstract type AbstractA end +abstract type DecoA <: AbstractA end +# A few concrete types +struct A <: AbstractA end +struct A1 <: DecoA end +struct A2 <: DecoA end +struct A3 <: DecoA end +struct A4 <: DecoA end +struct A5 <: DecoA end + +# just some function +f(::AbstractA, x) = x + 2 +g(::DecoA, x, y) = x + y + +struct IsGreat <: AbstractTrait end # a special case of IsNice +parent_trait(::IsGreat) = IsNice() + +active_traits(f, ::A1, ::Any) = merge_traits(IsNice()) +active_traits(f, ::A2, ::Any) = merge_traits(IsCool()) +active_traits(f, ::A3, ::Any) = merge_traits(IsCool(), IsNice()) +active_traits(f, ::A5, ::Any) = merge_traits(IsGreat()) + +f(a::DecoA, b) = f(trait(f, a, b), a, b) + +f(::TraitList{IsNice}, a, b) = g(a, b, 3) +f(::TraitList{IsCool}, a, b) = g(a, b, 5) + +# generic forward to the next trait to be looked at +f(t::TraitList, a, b) = f(next_trait(t), a, b) +# generic fallback when no traits are defined +f(::EmptyTrait, a, b) = invoke(f, Tuple{AbstractA,typeof(b)}, a, b) + +@testset "Decorator trait tests" begin + t = ManifoldsBase.EmptyTrait() + t2 = ManifoldsBase.TraitList(t, t) + @test merge_traits() == t + @test merge_traits(t) == t + @test merge_traits(t, t) == t + @test merge_traits(t2) == t2 + @test merge_traits( + merge_traits(IsGreat(), IsNice()), + merge_traits(IsGreat(), IsNice()), + ) === merge_traits(IsGreat(), IsNice(), IsGreat(), IsNice()) + @test expand_trait(merge_traits(IsGreat(), IsCool())) === + merge_traits(IsGreat(), IsNice(), IsCool()) + @test expand_trait(merge_traits(IsCool(), IsGreat())) === + merge_traits(IsCool(), IsGreat(), IsNice()) + + @test string(merge_traits(IsGreat(), IsNice())) == + "TraitList(IsGreat(), TraitList(IsNice(), EmptyTrait()))" + + global f + @test f(A(), 0) == 2 + @test f(A2(), 0) == 5 + @test f(A3(), 0) == 5 + @test f(A4(), 0) == 2 + @test f(A5(), 890) == 893 + f(::TraitList{IsGreat}, a, b) = g(a, b, 54) + @test f(A5(), 890) == 944 +end + +# +# A Manifold decorator test - check that EmptyTrait cases call Abstract and those fail with +# MethodError due to ambiguities (between Abstract and Decorator) +struct NonDecoratorManifold <: AbstractDecoratorManifold{ManifoldsBase.ℝ} end +ManifoldsBase.representation_size(::NonDecoratorManifold) = (2,) + +@testset "Testing a NonDecoratorManifold - emptytrait fallbacks" begin + M = NonDecoratorManifold() + p = [2.0, 1.0] + q = similar(p) + X = [1.0, 0.0] + Y = similar(X) + @test ManifoldsBase.check_size(M, p) === nothing + @test ManifoldsBase.check_size(M, p, X) === nothing + # Ambiguous since not implemented + @test_throws MethodError embed(M, p) + @test_throws MethodError embed!(M, q, p) + @test_throws MethodError embed(M, p, X) + @test_throws MethodError embed!(M, Y, p, X) + # the following is implemented but passes to the second and hence fails + @test_throws MethodError exp(M, p, X) + @test_throws MethodError exp!(M, q, p, X) + @test_throws MethodError retract(M, p, X) + @test_throws MethodError retract!(M, q, p, X) + @test_throws MethodError log(M, p, q) + @test_throws MethodError log!(M, Y, p, q) + @test_throws MethodError inverse_retract(M, p, q) + @test_throws MethodError inverse_retract!(M, Y, p, q) + @test_throws MethodError parallel_transport_along(M, p, X, :curve) + @test_throws MethodError parallel_transport_along!(M, Y, p, X, :curve) + @test_throws MethodError parallel_transport_direction(M, p, X, X) + @test_throws MethodError parallel_transport_direction!(M, Y, p, X, X) + @test_throws MethodError parallel_transport_to(M, p, X, q) + @test_throws MethodError parallel_transport_to!(M, Y, p, X, q) + @test_throws MethodError vector_transport_along(M, p, X, :curve) + @test_throws MethodError vector_transport_along!(M, Y, p, X, :curve) +end + +# With even less, check that representation size stack overflows +struct NonDecoratorNonManifold <: AbstractDecoratorManifold{ManifoldsBase.ℝ} end +@testset "Testing a NonDecoratorNonManifold - emptytrait fallback Errors" begin + N = NonDecoratorNonManifold() + @test_throws StackOverflowError representation_size(N) +end + + +h(::AbstractA, x::Float64, y; a = 1) = x + y - a +h(::DecoA, x, y) = x + y +ManifoldsBase.@invoke_maker 1 AbstractA h(A::DecoA, x::Float64, y) + +@testset "@invoke_maker" begin + @test h(A1(), 1.0, 2) == 2 + sig = ManifoldsBase._split_signature(:(fname(x::T; k::Float64 = 1) where {T})) + @test sig.fname == :fname + @test sig.where_exprs == Any[:T] + @test sig.callargs == Any[:(x::T)] + @test sig.kwargs_list == Any[:($(Expr(:kw, :(k::Float64), 1)))] + @test sig.argnames == [:x] + @test sig.argtypes == [:T] + @test sig.kwargs_call == Expr[:(k = k)] + @test_throws ErrorException ManifoldsBase._split_signature(:(a = b)) + + sig2 = ManifoldsBase._split_signature(:(fname(x; kwargs...))) + @test sig2.kwargs_call == Expr[:(kwargs...)] + + sig3 = ManifoldsBase._split_signature(:(fname(x::T, y::Int = 10; k1 = 1) where {T})) + @test sig3.kwargs_call == Expr[:(k1 = k1)] + @test sig3.callargs[2] == :($(Expr(:kw, :(y::Int), 10))) + @test sig3.argnames == [:x, :y] +end diff --git a/test/default_manifold.jl b/test/default_manifold.jl index f05d47c4..75d0ac78 100644 --- a/test/default_manifold.jl +++ b/test/default_manifold.jl @@ -1,6 +1,7 @@ using ManifoldsBase using ManifoldsBase: @manifold_element_forwards, @manifold_vector_forwards, @default_manifold_fallbacks +using ManifoldsBase: DefaultManifold import ManifoldsBase: number_eltype, check_point, @@ -10,6 +11,7 @@ import ManifoldsBase: inner, isapprox, log!, + parallel_transport_to!, retract!, inverse_retract! import Base: angle, convert @@ -24,59 +26,68 @@ struct CustomDefinedRetraction <: ManifoldsBase.AbstractRetractionMethod end struct CustomUndefinedRetraction <: ManifoldsBase.AbstractRetractionMethod end struct CustomDefinedInverseRetraction <: ManifoldsBase.AbstractInverseRetractionMethod end -function ManifoldsBase.injectivity_radius( - ::ManifoldsBase.DefaultManifold, - ::CustomDefinedRetraction, -) - return 10.0 -end -function ManifoldsBase.retract!( - ::ManifoldsBase.DefaultManifold, - q, - p, - X, - ::CustomDefinedRetraction, -) - return (q .= p .+ X) -end -function ManifoldsBase.inverse_retract!( - ::ManifoldsBase.DefaultManifold, - X, - p, - q, - ::CustomDefinedInverseRetraction, -) - return (X .= q .- p) -end - - -struct MatrixVectorTransport{T} <: AbstractVector{T} - m::Matrix{T} -end - -Base.getindex(x::MatrixVectorTransport, i) = x.m[:, i] - -Base.size(x::MatrixVectorTransport) = (size(x.m, 2),) - struct DefaultPoint{T} <: AbstractManifoldPoint value::T end DefaultPoint(v::T) where {T} = DefaultPoint{T}(v) convert(::Type{DefaultPoint{T}}, v::T) where {T} = DefaultPoint(v) - +Base.size(p::DefaultPoint) = size(p.value) Base.eltype(v::DefaultPoint) = eltype(v.value) struct DefaultTVector{T} <: TVector value::T end DefaultTVector(v::T) where {T} = DefaultTVector{T}(v) - -Base.eltype(v::DefaultTVector) = eltype(v.value) +Base.size(X::DefaultTVector) = size(X.value) +Base.eltype(X::DefaultTVector) = eltype(X.value) ManifoldsBase.@manifold_element_forwards DefaultPoint value ManifoldsBase.@manifold_vector_forwards DefaultTVector value ManifoldsBase.@default_manifold_fallbacks ManifoldsBase.DefaultManifold DefaultPoint DefaultTVector value value +function ManifoldsBase.injectivity_radius(::DefaultManifold, ::CustomDefinedRetraction) + return 10.0 +end +function ManifoldsBase._retract(M::DefaultManifold, p, X, ::CustomDefinedRetraction) + return retract_custom(M, p, X) +end +function retract_custom(::DefaultManifold, p::DefaultPoint, X::DefaultTVector) + return DefaultPoint(p.value + X.value) +end +function ManifoldsBase._inverse_retract( + M::DefaultManifold, + p, + q, + ::CustomDefinedInverseRetraction, +) + return inverse_retract_custom(M, p, q) +end +function inverse_retract_custom(::DefaultManifold, p::DefaultPoint, q::DefaultPoint) + return DefaultTVector(q.value - p.value) +end +struct MatrixVectorTransport{T} <: AbstractVector{T} + m::Matrix{T} +end +# dummy retractions, inverse retracions for fallback tests - mutating should be enough +ManifoldsBase.retract_polar!(::DefaultManifold, q, p, X) = (q .= p .+ X) +ManifoldsBase.retract_project!(::DefaultManifold, q, p, X) = (q .= p .+ X) +ManifoldsBase.retract_qr!(::DefaultManifold, q, p, X) = (q .= p .+ X) +ManifoldsBase.retract_exp_ode!(::DefaultManifold, q, p, X, m, B) = (q .= p .+ X) +ManifoldsBase.retract_pade!(::DefaultManifold, q, p, X, i) = (q .= p .+ X) +ManifoldsBase.retract_softmax!(::DefaultManifold, q, p, X) = (q .= p .+ X) +ManifoldsBase.get_embedding(M::DefaultManifold) = M # dummy embedding +ManifoldsBase.inverse_retract_polar!(::DefaultManifold, Y, p, q) = (Y .= q .- p) +ManifoldsBase.inverse_retract_project!(::DefaultManifold, Y, p, q) = (Y .= q .- p) +ManifoldsBase.inverse_retract_qr!(::DefaultManifold, Y, p, q) = (Y .= q .- p) +ManifoldsBase.inverse_retract_softmax!(::DefaultManifold, Y, p, q) = (Y .= q .- p) +ManifoldsBase.inverse_retract_nlsolve!(::DefaultManifold, Y, p, q, m) = (Y .= q .- p) +ManifoldsBase.vector_transport_along_project!(::DefaultManifold, Y, p, X, c) = (Y .= X) + + +Base.getindex(x::MatrixVectorTransport, i) = x.m[:, i] + +Base.size(x::MatrixVectorTransport) = (size(x.m, 2),) + @testset "Testing Default (Euclidean)" begin M = ManifoldsBase.DefaultManifold(3) types = [ @@ -119,8 +130,6 @@ ManifoldsBase.@default_manifold_fallbacks ManifoldsBase.DefaultManifold DefaultP @test injectivity_radius(M, rm) == Inf @test injectivity_radius(M, rm2) == 10 @test injectivity_radius(M, pts[1], rm2) == 10 - @test_throws ErrorException injectivity_radius(M, rm3) - @test_throws ErrorException injectivity_radius(M, pts[1], rm3) tv1 = log(M, pts[1], pts[2]) @@ -261,18 +270,16 @@ ManifoldsBase.@default_manifold_fallbacks ManifoldsBase.DefaultManifold DefaultP end @testset "vector transport" begin - # test constructor and alias + # test constructor @test default_vector_transport_method(M) == ParallelTransport() - @test DifferentiatedRetractionVectorTransport(ExponentialRetraction()) == - ParallelTransport() v1 = log(M, pts[1], pts[2]) v2 = log(M, pts[1], pts[3]) v1t1 = vector_transport_to(M, pts[1], v1, pts[3]) v1t2 = zero(v1t1) vector_transport_to!(M, v1t2, pts[1], v1, v2, ProjectionTransport()) v1t3 = vector_transport_direction(M, pts[1], v1, v2) - @test is_vector(M, pts[3], v1t1) - @test is_vector(M, pts[3], v1t3) + @test ManifoldsBase.is_vector(M, pts[3], v1t1) + @test ManifoldsBase.is_vector(M, pts[3], v1t3) @test isapprox(M, pts[3], v1t1, v1t3) # along a `Vector` of points c = [pts[1]] @@ -281,6 +288,11 @@ ManifoldsBase.@default_manifold_fallbacks ManifoldsBase.DefaultManifold DefaultP v1t5 = allocate(v1) vector_transport_along!(M, v1t5, pts[1], v1, c) @test isapprox(M, pts[1], v1, v1t5) + # transport along more than one interims point + @test vector_transport_along(M, pts[1], v1, pts[2:3]) == v1 + v1t6 = allocate(v1) + vector_transport_along!(M, v1t6, pts[1], v1, pts[2:3]) + @test isapprox(M, pts[1], v1, v1t6) # along a custom type of points if T <: DefaultPoint S = eltype(pts[1].value) @@ -373,8 +385,9 @@ ManifoldsBase.@default_manifold_fallbacks ManifoldsBase.DefaultManifold DefaultP end @testset "Retraction" begin - a = NLsolveInverseRetraction(ExponentialRetraction()) + a = NLSolveInverseRetraction(ExponentialRetraction()) @test a.retraction isa ExponentialRetraction + end @testset "copy of points and vectors" begin @@ -407,17 +420,110 @@ ManifoldsBase.@default_manifold_fallbacks ManifoldsBase.DefaultManifold DefaultP Y = DefaultTVector([1.0, 0.0, 0.0]) @test angle(M, p, X, Y) ≈ π / 2 @test inverse_retract(M, p, q, LogarithmicInverseRetraction()) == -Y + @test retract(M, q, Y, CustomDefinedRetraction()) == p @test retract(M, q, Y, ExponentialRetraction()) == p + # rest not implemented - so they also fall back even onto mutating + Z = similar(Y) + r = similar(p) + # test passthrough using the dummy implementations + for retr in [PolarRetraction, ProjectionRetraction, QRRetraction, SoftmaxRetraction] + @test retract(M, q, Y, retr()) == DefaultPoint(q.value + Y.value) + @test retract!(M, r, q, Y, retr()) == DefaultPoint(q.value + Y.value) + end + @test retract( + M, + q, + Y, + ODEExponentialRetraction(PolarRetraction(), DefaultBasis()), + ) == DefaultPoint(q.value + Y.value) + @test retract!( + M, + r, + q, + Y, + ODEExponentialRetraction(PolarRetraction(), DefaultBasis()), + ) == DefaultPoint(q.value + Y.value) + @test retract(M, q, Y, PadeRetraction(2)) == DefaultPoint(q.value + Y.value) + @test retract!(M, r, q, Y, PadeRetraction(2)) == DefaultPoint(q.value + Y.value) + @test retract!(M, r, q, Y, EmbeddedRetraction(ExponentialRetraction())) == + DefaultPoint(q.value + Y.value) + @test retract(M, q, Y, EmbeddedRetraction(ExponentialRetraction())) == + DefaultPoint(q.value + Y.value) p2 = allocate(p, eltype(p.value), size(p.value)) @test size(p2.value) == size(p.value) X2 = allocate(X, eltype(X.value), size(X.value)) @test size(X2.value) == size(X.value) - # Dispatch on custom - dispatch not working, check for new scheme later. - @test_broken inverse_retract(M, p, q, CustomDefinedInverseRetraction()) == -Y - @test_broken retract(M, q, Y, CustomDefinedRetraction()) == p + X3 = ManifoldsBase.allocate_result(M, log, p, q) + @test log!(M, X3, p, q) == log(M, p, q) + @test X3 == log(M, p, q) + @test log!(M, X3, p, q) == log(M, p, q) + @test X3 == log(M, p, q) + @test inverse_retract(M, p, q, CustomDefinedInverseRetraction()) == -Y + X4 = ManifoldsBase.allocate_result(M, inverse_retract, p, q) + @test inverse_retract!(M, X4, p, q) == inverse_retract(M, p, q) + @test X4 == inverse_retract(M, p, q) + # rest not implemented but check passthrough + for r in [ + PolarInverseRetraction, + ProjectionInverseRetraction, + QRInverseRetraction, + SoftmaxInverseRetraction, + ] + @test inverse_retract(M, q, p, r()) == DefaultTVector(p.value - q.value) + @test inverse_retract!(M, Z, q, p, r()) == DefaultTVector(p.value - q.value) + end + @test inverse_retract( + M, + q, + p, + EmbeddedInverseRetraction(LogarithmicInverseRetraction()), + ) == DefaultTVector(p.value - q.value) + @test inverse_retract(M, q, p, NLSolveInverseRetraction(ExponentialRetraction())) == + DefaultTVector(p.value - q.value) + @test inverse_retract!( + M, + Z, + q, + p, + EmbeddedInverseRetraction(LogarithmicInverseRetraction()), + ) == DefaultTVector(p.value - q.value) + @test inverse_retract!( + M, + Z, + q, + p, + NLSolveInverseRetraction(ExponentialRetraction()), + ) == DefaultTVector(p.value - q.value) + c = ManifoldsBase.allocate_coordinates(M, p, Float64, manifold_dimension(M)) + @test c isa Vector + @test length(c) == 3 @test 2.0 \ X == DefaultTVector(2.0 \ X.value) @test X + Y == DefaultTVector(X.value + Y.value) @test +X == X @test (Y .= X) === Y + # vector transport pass through + @test vector_transport_to(M, p, X, q, ProjectionTransport()) == X + @test vector_transport_direction(M, p, X, X, ProjectionTransport()) == X + @test vector_transport_to!(M, Y, p, X, q, ProjectionTransport()) == X + @test vector_transport_direction!(M, Y, p, X, X, ProjectionTransport()) == X + @test vector_transport_along(M, p, X, X, ProjectionTransport()) == X + @test vector_transport_along!(M, Z, p, X, X, ProjectionTransport()) == X + @test vector_transport_to(M, p, X, :q, ProjectionTransport()) == X + @test parallel_transport_to(M, p, X, q) == X + @test parallel_transport_direction(M, p, X, X) == X + @test parallel_transport_along(M, p, X, :c) == X + @test parallel_transport_to!(M, Y, p, X, q) == X + @test parallel_transport_direction!(M, Y, p, X, X) == X + @test parallel_transport_along!(M, Y, p, X, :c) == X + end + @testset "DefaultManifold and ONB" begin + M = ManifoldsBase.DefaultManifold(3) + p = [1.0, 0.0, 0.0] + CB = get_basis(M, p, DefaultOrthonormalBasis()) + @test CB.data == [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]] + end + @testset "Show methods" begin + @test repr(CayleyRetraction()) == "CayleyRetraction()" + @test repr(PadeRetraction(2)) == "PadeRetraction(2)" end end diff --git a/test/domain_errors.jl b/test/domain_errors.jl index ae6602b7..d5a840ee 100644 --- a/test/domain_errors.jl +++ b/test/domain_errors.jl @@ -1,8 +1,17 @@ using ManifoldsBase +using ManifoldsBase: ℝ using Test struct ErrorTestManifold <: AbstractManifold{ℝ} end +function ManifoldsBase.check_size(::ErrorTestManifold, p) + size(p) != (2,) && return DomainError(size(p), " size $p not (2,)") + return nothing +end +function ManifoldsBase.check_size(::ErrorTestManifold, p, X) + size(X) != (2,) && return DomainError(size(X), " size $X not (2,)") + return nothing +end function ManifoldsBase.check_point(::ErrorTestManifold, x) if any(u -> u < 0, x) return DomainError(x, "<0") @@ -10,7 +19,7 @@ function ManifoldsBase.check_point(::ErrorTestManifold, x) return nothing end function ManifoldsBase.check_vector(M::ErrorTestManifold, x, v) - mpe = check_point(M, x) + mpe = ManifoldsBase.check_point(M, x) mpe === nothing || return mpe if any(u -> u < 0, v) return DomainError(v, "<0") @@ -20,15 +29,23 @@ end @testset "Domain errors" begin M = ErrorTestManifold() - @test isa(check_point(M, [-1, 1]), DomainError) - @test check_point(M, [1, 1]) === nothing + @test isa(ManifoldsBase.check_point(M, [-1, 1]), DomainError) + @test isa(ManifoldsBase.check_size(M, [-1, 1, 1]), DomainError) + @test isa(ManifoldsBase.check_size(M, [-1, 1], [1, 1, 1]), DomainError) + @test ManifoldsBase.check_point(M, [1, 1]) === nothing @test !is_point(M, [-1, 1]) + @test !is_point(M, [1, 1, 1]) # checksize fails + @test_throws DomainError is_point(M, [-1, 1, 1], true) # checksize errors @test is_point(M, [1, 1]) @test_throws DomainError is_point(M, [-1, 1], true) - @test isa(check_vector(M, [1, 1], [-1, 1]), DomainError) - @test check_vector(M, [1, 1], [1, 1]) === nothing + @test isa(ManifoldsBase.check_vector(M, [1, 1], [-1, 1]), DomainError) + @test ManifoldsBase.check_vector(M, [1, 1], [1, 1]) === nothing @test !is_vector(M, [1, 1], [-1, 1]) + @test !is_vector(M, [1, 1], [1, 1, 1]) + @test_throws DomainError is_vector(M, [1, 1], [-1, 1, 1], true) + @test !is_vector(M, [1, 1, 1], [1, 1, 1], false, true) + @test_throws DomainError is_vector(M, [1, 1, 1], [1, 1], true, true) @test is_vector(M, [1, 1], [1, 1]) @test_throws DomainError is_vector(M, [1, 1], [-1, 1], true) end diff --git a/test/embedded_manifold.jl b/test/embedded_manifold.jl index 963842bc..5830dad1 100644 --- a/test/embedded_manifold.jl +++ b/test/embedded_manifold.jl @@ -1,36 +1,86 @@ +using ManifoldsBase, Test + using ManifoldsBase: DefaultManifold, ℝ +# +# A first artificial (not real) manifold that is modelled as a submanifold +# half plane with euclidean metric is not a manifold but should test all things correctly here +# +struct HalfPlanemanifold <: AbstractDecoratorManifold{ℝ} end -struct PlaneManifold <: AbstractEmbeddedManifold{ℝ,TransparentIsometricEmbedding} end +ManifoldsBase.get_embedding(::HalfPlanemanifold) = ManifoldsBase.DefaultManifold(1, 3) +ManifoldsBase.decorated_manifold(::HalfPlanemanifold) = ManifoldsBase.DefaultManifold(2) +ManifoldsBase.representation_size(::HalfPlanemanifold) = (2,) -ManifoldsBase.decorated_manifold(::PlaneManifold) = ManifoldsBase.DefaultManifold(1, 3) -ManifoldsBase.base_manifold(::PlaneManifold) = ManifoldsBase.DefaultManifold(2) +function ManifoldsBase.check_point(::HalfPlanemanifold, p) + return p[1] > 0 ? nothing : DomainError(p[1], "p[1] ≤ 0") +end +function ManifoldsBase.check_vector(::HalfPlanemanifold, p, X) + return X[1] > 0 ? nothing : DomainError(X[1], "X[1] ≤ 0") +end +ManifoldsBase.embed(::HalfPlanemanifold, p) = reshape(p, 1, :) +ManifoldsBase.embed(::HalfPlanemanifold, p, X) = reshape(X, 1, :) -ManifoldsBase.project!(::PlaneManifold, q, p) = (q .= [p[1] p[2] 0.0]) -ManifoldsBase.project!(::PlaneManifold, Y, p, X) = (Y .= [X[1] X[2] 0.0]) +ManifoldsBase.project!(::HalfPlanemanifold, q, p) = (q .= [p[1] p[2] 0.0]) +ManifoldsBase.project!(::HalfPlanemanifold, Y, p, X) = (Y .= [X[1] X[2] 0.0]) -struct AnotherPlaneManifold <: AbstractEmbeddedManifold{ℝ,DefaultIsometricEmbeddingType} end +function ManifoldsBase.get_coordinates_orthonormal!( + ::HalfPlanemanifold, + Y, + p, + X, + ::ManifoldsBase.RealNumbers, +) + return (Y .= [X[1], X[2]]) +end +function ManifoldsBase.get_vector_orthonormal!( + ::HalfPlanemanifold, + Y, + p, + c, + ::ManifoldsBase.RealNumbers, +) + return (Y .= [c[1] c[2] 0.0]) +end + +function ManifoldsBase.active_traits(f, ::HalfPlanemanifold, args...) + return ManifoldsBase.merge_traits(ManifoldsBase.IsEmbeddedSubmanifold()) +end + +# +# A second manifold that is modelled as just isometrically embedded but not a submanifold +# +struct AnotherHalfPlanemanifold <: AbstractDecoratorManifold{ℝ} end + +ManifoldsBase.get_embedding(::AnotherHalfPlanemanifold) = ManifoldsBase.DefaultManifold(3) +function ManifoldsBase.decorated_manifold(::AnotherHalfPlanemanifold) + return ManifoldsBase.DefaultManifold(2) +end +ManifoldsBase.representation_size(::AnotherHalfPlanemanifold) = (2,) -ManifoldsBase.decorated_manifold(::AnotherPlaneManifold) = ManifoldsBase.DefaultManifold(3) -ManifoldsBase.base_manifold(::AnotherPlaneManifold) = ManifoldsBase.DefaultManifold(2) +function ManifoldsBase.active_traits(f, ::AnotherHalfPlanemanifold, args...) + return ManifoldsBase.merge_traits(ManifoldsBase.IsIsometricEmbeddedManifold()) +end -function ManifoldsBase.embed!(::AnotherPlaneManifold, q, p) +function ManifoldsBase.embed!(::AnotherHalfPlanemanifold, q, p) q[1:2] .= p q[3] = 0 return q end -function ManifoldsBase.embed!(::AnotherPlaneManifold, Y, p, X) +function ManifoldsBase.embed!(::AnotherHalfPlanemanifold, Y, p, X) Y[1:2] .= X Y[3] = 0 return Y end -function ManifoldsBase.project!(::AnotherPlaneManifold, q, p) +function ManifoldsBase.project!(::AnotherHalfPlanemanifold, q, p) return q .= [p[1], p[2]] end -function ManifoldsBase.project!(::AnotherPlaneManifold, Y, p, X) +function ManifoldsBase.project!(::AnotherHalfPlanemanifold, Y, p, X) return Y .= [X[1], X[2]] end - +# +# Third example - explicitly mention an embedding. +# function ManifoldsBase.embed!( ::EmbeddedManifold{𝔽,DefaultManifold{nL,𝔽},DefaultManifold{mL,𝔽2}}, q, @@ -79,19 +129,44 @@ function ManifoldsBase.project!( return q end -struct NotImplementedEmbeddedManifold <: - AbstractEmbeddedManifold{ℝ,TransparentIsometricEmbedding} end -function ManifoldsBase.decorated_manifold(::NotImplementedEmbeddedManifold) - return ManifoldsBase.DefaultManifold(2) +# +# A manifold that is a submanifold but otherwise has not implementations +# +struct NotImplementedEmbeddedSubManifold <: AbstractDecoratorManifold{ℝ} end +function ManifoldsBase.get_embedding(::NotImplementedEmbeddedSubManifold) + return ManifoldsBase.DefaultManifold(3) end -function ManifoldsBase.base_manifold(::NotImplementedEmbeddedManifold) +function ManifoldsBase.decorated_manifold(::NotImplementedEmbeddedSubManifold) return ManifoldsBase.DefaultManifold(2) end +function ManifoldsBase.active_traits(f, ::NotImplementedEmbeddedSubManifold, args...) + return ManifoldsBase.merge_traits(ManifoldsBase.IsEmbeddedSubmanifold()) +end -struct NotImplementedEmbeddedManifold2 <: - AbstractEmbeddedManifold{ℝ,DefaultIsometricEmbeddingType} end +# +# A manifold that is isometrically embedded but has no implementations +# +struct NotImplementedIsometricEmbeddedManifold <: AbstractDecoratorManifold{ℝ} end +function ManifoldsBase.active_traits(f, ::NotImplementedIsometricEmbeddedManifold, args...) + return ManifoldsBase.merge_traits(ManifoldsBase.IsIsometricEmbeddedManifold()) +end -struct NotImplementedEmbeddedManifold3 <: AbstractEmbeddedManifold{ℝ,DefaultEmbeddingType} end +# +# A manifold that is an embedded manifold but not isometric and has no other implementation +# +struct NotImplementedEmbeddedManifold <: AbstractDecoratorManifold{ℝ} end +function ManifoldsBase.active_traits(f, ::NotImplementedEmbeddedManifold, args...) + return ManifoldsBase.merge_traits(ManifoldsBase.IsEmbeddedManifold()) +end + +# +# A Manifold with a fallback +# +struct FallbackManifold <: AbstractDecoratorManifold{ℝ} end +function ManifoldsBase.active_traits(f, ::FallbackManifold, args...) + return ManifoldsBase.merge_traits(ManifoldsBase.IsExplicitDecorator()) +end +ManifoldsBase.decorated_manifold(::FallbackManifold) = DefaultManifold(3) @testset "Embedded Manifolds" begin @testset "EmbeddedManifold basic tests" begin @@ -102,33 +177,45 @@ struct NotImplementedEmbeddedManifold3 <: AbstractEmbeddedManifold{ℝ,DefaultEm @test repr(M) == "EmbeddedManifold($(sprint(show, M.manifold)), $(sprint(show, M.embedding)))" @test base_manifold(M) == ManifoldsBase.DefaultManifold(2) + @test base_manifold(M, Val(0)) == M + @test base_manifold(M, Val(1)) == ManifoldsBase.DefaultManifold(2) + @test base_manifold(M, Val(2)) == ManifoldsBase.DefaultManifold(2) @test get_embedding(M) == ManifoldsBase.DefaultManifold(3) - @test ManifoldsBase.decorated_manifold(M) == ManifoldsBase.DefaultManifold(3) - @test ManifoldsBase.default_decorator_dispatch(M) === Val(true) + @test get_embedding(M, [1, 2, 3]) == ManifoldsBase.DefaultManifold(3) end - @testset "PlaneManifold" begin - M = PlaneManifold() - @test repr(M) == "PlaneManifold()" - @test ManifoldsBase.default_decorator_dispatch(M) === Val(false) - @test ManifoldsBase.default_embedding_dispatch(M) === Val(false) + @testset "HalfPlanemanifold" begin + M = HalfPlanemanifold() + @test repr(M) == "HalfPlanemanifold()" @test get_embedding(M) == ManifoldsBase.DefaultManifold(1, 3) - # Check fallbacks to check embed->check_manifoldpoint Defaults - @test_throws DomainError is_point(M, [1, 0, 0], true) - @test_throws DomainError is_point(M, [1 0], true) + @test representation_size(M) == (2,) + # Check point checks using embedding + @test is_point(M, [1 0.1 0.1], true) + @test_throws DomainError is_point(M, [-1, 0, 0], true) #wrong dim (3,1) + @test !is_point(M, [-1, 0, 0]) + @test_throws DomainError is_point(M, [1, 0.1], true) # size @test is_point(M, [1 0 0], true) - @test_throws DomainError is_vector(M, [1 0 0], [1], true) - @test_throws DomainError is_vector(M, [1 0 0], [0 0 0 0], true) + @test !is_point(M, [-1 0 0]) # right size but <0 1st + @test_throws DomainError is_point(M, [-1 0 0], true) # right size but <0 1st + @test_throws DomainError is_vector(M, [1, 0, 0], [1 0 0], true) + @test_throws DomainError is_vector(M, [1 0 0], [1], true) # right point, wrong size vector + @test !is_vector(M, [1 0 0], [1]) + @test_throws DomainError is_vector(M, [1 0 0], [-1 0 0], true) # right point, vec 1st <0 + @test !is_vector(M, [1 0 0], [-1 0 0]) @test is_vector(M, [1 0 0], [1 0 1], true) + @test_throws DomainError is_vector(M, [-1, 0, 0], [0, 0, 0], true) + @test_throws DomainError is_vector(M, [1, 0, 0], [-1, 0, 0], true) + @test !is_vector(M, [-1, 0, 0], [0, 0, 0]) + @test !is_vector(M, [1, 0, 0], [-1, 0, 0]) p = [1.0 1.0 0.0] q = [1.0 0.0 0.0] X = q - p - @test check_size(M, p) === nothing - @test check_size(M, p, X) === nothing - @test check_size(M, [1, 2]) isa DomainError - @test check_size(M, [1 2 3 4]) isa DomainError - @test check_size(M, p, [1, 2]) isa DomainError - @test check_size(M, p, [1 2 3 4]) isa DomainError + @test ManifoldsBase.check_size(M, p) === nothing + @test ManifoldsBase.check_size(M, p, X) === nothing + @test ManifoldsBase.check_size(M, [1, 2]) isa DomainError + @test ManifoldsBase.check_size(M, [1 2 3 4]) isa DomainError + @test ManifoldsBase.check_size(M, p, [1, 2]) isa DomainError + @test ManifoldsBase.check_size(M, p, [1 2 3 4]) isa DomainError @test embed(M, p) == p pE = similar(p) embed!(M, pE, p) @@ -137,6 +224,10 @@ struct NotImplementedEmbeddedManifold3 <: AbstractEmbeddedManifold{ℝ,DefaultEm Q = similar(P) @test project!(M, Q, P) == project!(M, Q, P) @test project!(M, Q, P) == [1.0 1.0 0.0] + @test isapprox(M, p, zero_vector(M, p), [0 0 0]) + XZ = similar(X) + zero_vector!(M, XZ, p) + @test isapprox(M, p, XZ, [0 0 0]) XE = similar(X) embed!(M, XE, p, X) @@ -153,10 +244,36 @@ struct NotImplementedEmbeddedManifold3 <: AbstractEmbeddedManifold{ℝ,DefaultEm exp!(M, r, p, X) @test r == q @test distance(M, p, r) == norm(r - p) + + @test retract(M, p, X) == q + q2 = similar(q) + @test retract!(M, q2, p, X) == q + @test q2 == q + @test inverse_retract(M, p, q) == X + Y = similar(X) + @test inverse_retract!(M, Y, p, q) == X + @test Y == X + + @test vector_transport_along(M, p, X, []) == X + @test vector_transport_along!(M, Y, p, X, []) == X + @test parallel_transport_along(M, p, X, []) == X + @test parallel_transport_along!(M, Y, p, X, []) == X + @test parallel_transport_direction(M, p, X, X) == X + @test parallel_transport_direction!(M, Y, p, X, X) == X + @test parallel_transport_to(M, p, X, q) == X + @test parallel_transport_to!(M, Y, p, X, q) == X + + @test get_basis(M, p, DefaultOrthonormalBasis()) isa CachedBasis + Xc = [X[1], X[2]] + Yc = similar(Xc) + @test get_coordinates(M, p, X, DefaultOrthonormalBasis()) == Xc + @test get_coordinates!(M, Yc, p, X, DefaultOrthonormalBasis()) == Xc + @test get_vector(M, p, Xc, DefaultOrthonormalBasis()) == X + @test get_vector!(M, Y, p, Xc, DefaultOrthonormalBasis()) == X end - @testset "AnotherPlaneManifold" begin - M = AnotherPlaneManifold() + @testset "AnotherHalfPlanemanifold" begin + M = AnotherHalfPlanemanifold() p = [1.0, 2.0] pe = embed(M, p) @test pe == [1.0, 2.0, 0.0] @@ -168,21 +285,19 @@ struct NotImplementedEmbeddedManifold3 <: AbstractEmbeddedManifold{ℝ,DefaultEm end @testset "Test nonimplemented fallbacks" begin - @testset "Default Isometric Embedding Fallback Error Tests" begin - M = NotImplementedEmbeddedManifold() + @testset "Submanifold Embedding Fallbacks & Error Tests" begin + M = NotImplementedEmbeddedSubManifold() A = zeros(2) - # without any extra tests just the embedding is asked - @test check_point(M, [1, 2]) === nothing - @test check_vector(M, [1, 2], [3, 4]) === nothing + # for a submanifold quite a lot of functions are passed on + @test ManifoldsBase.check_point(M, [1, 2]) === nothing + @test ManifoldsBase.check_vector(M, [1, 2], [3, 4]) === nothing @test norm(M, [1, 2], [2, 3]) ≈ sqrt(13) @test distance(M, [1, 2], [3, 4]) ≈ sqrt(8) - @test inner(M, [1, 2], [2, 3], [2, 3]) ≈ 13 - @test_throws ErrorException manifold_dimension(M) - # without any implementation the projections are the identity - @test project(M, [1, 2]) == [1, 2] - @test project(M, [1, 2], [2, 3]) == [2, 3] - project!(M, A, [1, 2], [2, 3]) - @test A == [2, 3] + @test inner(M, [1, 2], [2, 3], [2, 3]) == 13 + @test manifold_dimension(M) == 2 # since base is defined is defined + @test_throws MethodError project(M, [1, 2]) + @test_throws MethodError project(M, [1, 2], [2, 3]) == [2, 3] + @test_throws MethodError project!(M, A, [1, 2], [2, 3]) @test vector_transport_direction(M, [1, 2], [2, 3], [3, 4]) == [2, 3] vector_transport_direction!(M, A, [1, 2], [2, 3], [3, 4]) @test A == [2, 3] @@ -190,143 +305,56 @@ struct NotImplementedEmbeddedManifold3 <: AbstractEmbeddedManifold{ℝ,DefaultEm vector_transport_to!(M, A, [1, 2], [2, 3], [3, 4]) @test A == [2, 3] end - @testset "General Isometric Embedding Fallback Error Tests" begin - M2 = NotImplementedEmbeddedManifold2() + @testset "Isometric Embedding Fallbacks & Error Tests" begin + M2 = NotImplementedIsometricEmbeddedManifold() @test base_manifold(M2) == M2 A = zeros(2) - @test_throws ErrorException exp(M2, [1, 2], [2, 3]) - @test_throws ErrorException exp!(M2, A, [1, 2], [2, 3]) - @test_throws ErrorException log(M2, [1, 2], [2, 3]) - @test_throws ErrorException log!(M2, A, [1, 2], [2, 3]) - @test_throws ErrorException distance(M2, [1, 2], [2, 3]) - @test_throws ErrorException manifold_dimension(M2) - @test_throws ErrorException project(M2, [1, 2]) - @test_throws ErrorException project!(M2, A, [1, 2]) - @test_throws ErrorException project(M2, [1, 2], [2, 3]) - @test_throws ErrorException project!(M2, A, [1, 2], [2, 3]) - @test_throws ErrorException vector_transport_along(M2, [1, 2], [2, 3], [[1, 2]]) - @test_throws ErrorException vector_transport_along( + # Check that all of these report not to be implemented, i.e. + @test_throws MethodError exp(M2, [1, 2], [2, 3]) + @test_throws MethodError exp!(M2, A, [1, 2], [2, 3]) + @test_throws MethodError retract(M2, [1, 2], [2, 3]) + @test_throws MethodError retract!(M2, A, [1, 2], [2, 3]) + @test_throws MethodError log(M2, [1, 2], [2, 3]) + @test_throws MethodError log!(M2, A, [1, 2], [2, 3]) + @test_throws MethodError inverse_retract(M2, [1, 2], [2, 3]) + @test_throws MethodError inverse_retract!(M2, A, [1, 2], [2, 3]) + @test_throws MethodError distance(M2, [1, 2], [2, 3]) + @test_throws StackOverflowError manifold_dimension(M2) + @test_throws MethodError project(M2, [1, 2]) + @test_throws MethodError project!(M2, A, [1, 2]) + @test_throws MethodError project(M2, [1, 2], [2, 3]) + @test_throws MethodError project!(M2, A, [1, 2], [2, 3]) + @test_throws MethodError vector_transport_along(M2, [1, 2], [2, 3], [[1, 2]]) + @test_throws MethodError vector_transport_along( M2, [1, 2], [2, 3], [[1, 2]], ParallelTransport(), ) - @test_throws ErrorException vector_transport_along!(M2, A, [1, 2], [2, 3], []) - @test_throws ErrorException vector_transport_direction( - M2, - [1, 2], - [2, 3], - [3, 4], - ) - @test_throws ErrorException vector_transport_direction!( + @test vector_transport_along!(M2, A, [1, 2], [2, 3], []) == [2, 3] + @test A == [2, 3] + @test_throws MethodError vector_transport_direction(M2, [1, 2], [2, 3], [3, 4]) + @test_throws MethodError vector_transport_direction!( M2, A, [1, 2], [2, 3], [3, 4], ) - @test_throws ErrorException vector_transport_to(M2, [1, 2], [2, 3], [3, 4]) - @test_throws ErrorException vector_transport_to!(M2, A, [1, 2], [2, 3], [3, 4]) + @test_throws MethodError vector_transport_to(M2, [1, 2], [2, 3], [3, 4]) + @test_throws MethodError vector_transport_to!(M2, A, [1, 2], [2, 3], [3, 4]) end @testset "Nonisometric Embedding Fallback Error Rests" begin - M3 = NotImplementedEmbeddedManifold3() - @test_throws ErrorException inner(M3, [1, 2], [2, 3], [2, 3]) - @test_throws ErrorException manifold_dimension(M3) - @test_throws ErrorException distance(M3, [1, 2], [2, 3]) - @test_throws ErrorException norm(M3, [1, 2], [2, 3]) - @test_throws ErrorException embed(M3, [1, 2], [2, 3]) - @test_throws ErrorException embed(M3, [1, 2]) + M3 = NotImplementedEmbeddedManifold() + @test_throws MethodError inner(M3, [1, 2], [2, 3], [2, 3]) + @test_throws StackOverflowError manifold_dimension(M3) + @test_throws MethodError distance(M3, [1, 2], [2, 3]) + @test_throws MethodError norm(M3, [1, 2], [2, 3]) + @test_throws MethodError embed(M3, [1, 2], [2, 3]) + @test_throws MethodError embed(M3, [1, 2]) end end - - @testset "EmbeddedManifold decorator dispatch" begin - TM = NotImplementedEmbeddedManifold() # transparently iso - IM = NotImplementedEmbeddedManifold2() # iso - AM = NotImplementedEmbeddedManifold3() # general - for f in [ - embed, - exp, - get_basis, - get_coordinates, - get_vector, - inverse_retract, - log, - norm, - distance, - ] - @test ManifoldsBase.decorator_transparent_dispatch(f, AM) === Val(:parent) - end - for f in - [project, retract, inverse_retract!, retract!, get_coordinates!, get_vector!] - @test ManifoldsBase.decorator_transparent_dispatch(f, AM) === Val(:parent) - end - for f in [vector_transport_along, vector_transport_direction, vector_transport_to] - @test ManifoldsBase.decorator_transparent_dispatch(f, AM) === Val(:parent) - end - for f in [mid_point, mid_point!] - @test ManifoldsBase.decorator_transparent_dispatch(f, AM) === Val(:parent) - end - for f in [check_point, check_vector, exp!, inner, embed!] - @test ManifoldsBase.decorator_transparent_dispatch(f, AM) === - Val(:intransparent) - end - for f in [log!, manifold_dimension, project!] - @test ManifoldsBase.decorator_transparent_dispatch(f, AM) === - Val(:intransparent) - end - @test ManifoldsBase.decorator_transparent_dispatch(vector_transport_along!, AM) === - Val(:transparent) - @test ManifoldsBase.decorator_transparent_dispatch( - vector_transport_to!, - AM, - 1, - 1, - 1, - 1, - 1, - ) === Val(:intransparent) - @test ManifoldsBase.decorator_transparent_dispatch( - vector_transport_direction!, - AM, - ) === Val(:parent) - - for f in [inner, norm] - @test ManifoldsBase.decorator_transparent_dispatch(f, IM) === Val(:transparent) - end - for f in [inverse_retract!, retract!, mid_point!, distance] - @test ManifoldsBase.decorator_transparent_dispatch(f, IM) === Val(:parent) - end - - for f in [exp, inverse_retract, log, project, retract, mid_point, distance] - @test ManifoldsBase.decorator_transparent_dispatch(f, TM) === Val(:transparent) - end - for f in [exp!, inverse_retract!, log!, project!, retract!, mid_point!] - @test ManifoldsBase.decorator_transparent_dispatch(f, TM) === Val(:transparent) - end - for f in [vector_transport_along, vector_transport_direction, vector_transport_to] - @test ManifoldsBase.decorator_transparent_dispatch(f, TM) === Val(:transparent) - end - for f in - [vector_transport_along!, vector_transport_direction!, vector_transport_to!] - @test ManifoldsBase.decorator_transparent_dispatch(f, TM) === Val(:transparent) - end - for t in [PoleLadderTransport(), SchildsLadderTransport()], M in [AM, TM, IM] - @test ManifoldsBase.decorator_transparent_dispatch( - vector_transport_to!, - M, - 1, - 1, - 1, - 1, - PoleLadderTransport(), - ) === Val(:parent) - end - @test ManifoldsBase.decorator_transparent_dispatch(embed, TM) === Val{:parent}() - @test ManifoldsBase.decorator_transparent_dispatch(embed!, TM) === - Val(:intransparent) - end - @testset "Explicit Embeddings using EmbeddedManifold" begin M = DefaultManifold(3, 3) N = DefaultManifold(4, 4) @@ -355,8 +383,11 @@ struct NotImplementedEmbeddedManifold3 <: AbstractEmbeddedManifold{ℝ,DefaultEm @test_throws DomainError embed!(O, zeros(3, 3), zeros(4, 4)) @test_throws DomainError project!(O, zeros(3, 3, 5), zeros(3, 3)) @test_throws DomainError project!(O, zeros(4, 4), zeros(3, 3)) - for f in [embed, project] - @test ManifoldsBase.decorator_transparent_dispatch(f, O) === Val(:intransparent) - end + end + @testset "Explicit Fallback" begin + M = FallbackManifold() + # test the explicit fallback to DefaultManifold(3) + @test inner(M, [1, 0, 0], [1, 2, 3], [0, 1, 0]) == 2 + @test is_point(M, [1, 0, 0]) end end diff --git a/test/empty_manifold.jl b/test/empty_manifold.jl index 02855a06..a20b7a2e 100644 --- a/test/empty_manifold.jl +++ b/test/empty_manifold.jl @@ -17,7 +17,7 @@ struct NotImplementedInverseRetraction <: AbstractInverseRetractionMethod end @test number_system(M) === ℝ @test representation_size(M) === nothing - @test_throws ErrorException manifold_dimension(M) + @test_throws MethodError manifold_dimension(M) # by default isapprox compares given points or vectors @test isapprox(M, [0], [0]) @@ -31,121 +31,109 @@ struct NotImplementedInverseRetraction <: AbstractInverseRetractionMethod end exp_retr = ManifoldsBase.ExponentialRetraction() - @test_throws ErrorException retract!(M, p, p, v) - @test_throws ErrorException retract!(M, p, p, v, exp_retr) - @test_throws ErrorException retract!(M, p, p, [0.0], 0.0) - @test_throws ErrorException retract!(M, p, p, [0.0], 0.0, exp_retr) - @test_throws ErrorException retract!(M, [0], [0], [0]) - @test_throws ErrorException retract!(M, [0], [0], [0], exp_retr) - @test_throws ErrorException retract!(M, [0], [0], [0], 0.0) - @test_throws ErrorException retract!(M, [0], [0], [0], 0.0, exp_retr) - @test_throws ErrorException retract(M, [0], [0]) - @test_throws ErrorException retract(M, [0], [0], exp_retr) - @test_throws ErrorException retract(M, [0], [0], 0.0) - @test_throws ErrorException retract(M, [0], [0], 0.0, exp_retr) - @test_throws ErrorException retract(M, [0.0], [0.0]) - @test_throws ErrorException retract(M, [0.0], [0.0], exp_retr) - @test_throws ErrorException retract(M, [0.0], [0.0], 0.0) - @test_throws ErrorException retract(M, [0.0], [0.0], 0.0, exp_retr) - @test_throws ErrorException retract(M, [0.0], [0.0], NotImplementedRetraction()) + @test_throws MethodError retract!(M, p, p, v) + @test_throws MethodError retract!(M, p, p, v, exp_retr) + @test_throws MethodError retract!(M, p, p, [0.0], 0.0) + @test_throws MethodError retract!(M, p, p, [0.0], 0.0, exp_retr) + @test_throws MethodError retract!(M, [0], [0], [0]) + @test_throws MethodError retract!(M, [0], [0], [0], exp_retr) + @test_throws MethodError retract!(M, [0], [0], [0], 0.0) + @test_throws MethodError retract!(M, [0], [0], [0], 0.0, exp_retr) + @test_throws MethodError retract(M, [0], [0]) + @test_throws MethodError retract(M, [0], [0], exp_retr) + @test_throws MethodError retract(M, [0], [0], 0.0) + @test_throws MethodError retract(M, [0], [0], 0.0, exp_retr) + @test_throws MethodError retract(M, [0.0], [0.0]) + @test_throws MethodError retract(M, [0.0], [0.0], exp_retr) + @test_throws MethodError retract(M, [0.0], [0.0], 0.0) + @test_throws MethodError retract(M, [0.0], [0.0], 0.0, exp_retr) + @test_throws MethodError retract(M, [0.0], [0.0], NotImplementedRetraction()) log_invretr = ManifoldsBase.LogarithmicInverseRetraction() - @test_throws ErrorException inverse_retract!(M, p, p, p) - @test_throws ErrorException inverse_retract!(M, p, p, p, log_invretr) - @test_throws ErrorException inverse_retract!(M, [0], [0], [0]) - @test_throws ErrorException inverse_retract!(M, [0], [0], [0], log_invretr) - @test_throws ErrorException inverse_retract(M, [0], [0]) - @test_throws ErrorException inverse_retract(M, [0], [0], log_invretr) - @test_throws ErrorException inverse_retract(M, [0.0], [0.0]) - @test_throws ErrorException inverse_retract(M, [0.0], [0.0], log_invretr) - @test_throws ErrorException inverse_retract( + @test_throws MethodError inverse_retract!(M, p, p, p) + @test_throws MethodError inverse_retract!(M, p, p, p, log_invretr) + @test_throws MethodError inverse_retract!(M, [0], [0], [0]) + @test_throws MethodError inverse_retract!(M, [0], [0], [0], log_invretr) + @test_throws MethodError inverse_retract(M, [0], [0]) + @test_throws MethodError inverse_retract(M, [0], [0], log_invretr) + @test_throws MethodError inverse_retract(M, [0.0], [0.0]) + @test_throws MethodError inverse_retract(M, [0.0], [0.0], log_invretr) + @test_throws MethodError inverse_retract( M, [0.0], [0.0], NotImplementedInverseRetraction(), ) - @test_throws ErrorException project!(M, p, [0]) - @test_throws ErrorException project!(M, [0], [0]) - @test_throws ErrorException project(M, [0]) - - @test_throws ErrorException project!(M, v, p, [0.0]) - @test_throws ErrorException project!(M, [0], [0], [0]) - @test_throws ErrorException project(M, [0], [0]) - @test_throws ErrorException project(M, [0.0], [0.0]) - - @test_throws ErrorException inner(M, p, v, v) - @test_throws ErrorException inner(M, [0], [0], [0]) - @test_throws ErrorException norm(M, p, v) - @test_throws ErrorException norm(M, [0], [0]) - @test_throws ErrorException angle(M, p, v, v) - @test_throws ErrorException angle(M, [0], [0], [0]) - - @test_throws ErrorException distance(M, [0.0], [0.0]) - - @test_throws ErrorException exp!(M, p, p, v) - @test_throws ErrorException exp!(M, p, p, v, 0.0) - @test_throws ErrorException exp!(M, [0], [0], [0]) - @test_throws ErrorException exp!(M, [0], [0], [0], 0.0) - @test_throws ErrorException exp(M, [0], [0]) - @test_throws ErrorException exp(M, [0], [0], 0.0) - @test_throws ErrorException exp(M, [0.0], [0.0]) - @test_throws ErrorException exp(M, [0.0], [0.0], 0.0) - - @test_throws ErrorException embed!(M, p, [0]) - @test_throws ErrorException embed!(M, [0], [0]) - @test_throws ErrorException embed(M, [0]) - - @test_throws ErrorException embed!(M, v, p, [0.0]) - @test_throws ErrorException embed!(M, [0], [0], [0]) - @test_throws ErrorException embed(M, [0], [0]) - @test_throws ErrorException embed(M, [0.0], [0.0]) - - - @test_throws ErrorException log!(M, v, p, p) - @test_throws ErrorException log!(M, [0], [0], [0]) - @test_throws ErrorException log(M, [0.0], [0.0]) - - @test_throws ErrorException vector_transport_to!(M, [0], [0], [0], [0]) - @test_throws ErrorException vector_transport_to(M, [0], [0], [0]) - @test_throws ErrorException vector_transport_to!( - M, - [0], - [0], - [0], - ProjectionTransport(), - ) + @test_throws MethodError project!(M, p, [0]) + @test_throws MethodError project!(M, [0], [0]) + @test_throws MethodError project(M, [0]) - @test_throws ErrorException vector_transport_direction!(M, [0], [0], [0], [0]) - @test_throws ErrorException vector_transport_direction(M, [0], [0], [0]) + @test_throws MethodError project!(M, v, p, [0.0]) + @test_throws MethodError project!(M, [0], [0], [0]) + @test_throws MethodError project(M, [0], [0]) + @test_throws MethodError project(M, [0.0], [0.0]) - @test_throws ErrorException ManifoldsBase.vector_transport_along!( - M, - [0], - [0], - [0], - x -> x, - ) - @test_throws ErrorException ManifoldsBase.vector_transport_along(M, [0], [0], x -> x) + @test_throws MethodError inner(M, p, v, v) + @test_throws MethodError inner(M, [0], [0], [0]) + @test_throws MethodError norm(M, p, v) + @test_throws MethodError norm(M, [0], [0]) + @test_throws MethodError angle(M, p, v, v) + @test_throws MethodError angle(M, [0], [0], [0]) + + @test_throws MethodError distance(M, [0.0], [0.0]) + + @test_throws MethodError exp!(M, p, p, v) + @test_throws MethodError exp!(M, p, p, v, 0.0) + @test_throws MethodError exp!(M, [0], [0], [0]) + @test_throws MethodError exp!(M, [0], [0], [0], 0.0) + @test_throws MethodError exp(M, [0], [0]) + @test_throws MethodError exp(M, [0], [0], 0.0) + @test_throws MethodError exp(M, [0.0], [0.0]) + @test_throws MethodError exp(M, [0.0], [0.0], 0.0) + + @test_throws MethodError embed!(M, p, [0]) + @test_throws MethodError embed!(M, [0], [0]) + @test_throws MethodError embed(M, [0]) + + @test_throws MethodError embed!(M, v, p, [0.0]) + @test_throws MethodError embed!(M, [0], [0], [0]) + @test_throws MethodError embed(M, [0], [0]) + @test_throws MethodError embed(M, [0.0], [0.0]) + + + @test_throws MethodError log!(M, v, p, p) + @test_throws MethodError log!(M, [0], [0], [0]) + @test_throws MethodError log(M, [0.0], [0.0]) + + @test_throws MethodError vector_transport_to!(M, [0], [0], [0], [0]) + @test_throws MethodError vector_transport_to(M, [0], [0], [0]) + @test_throws MethodError vector_transport_to!(M, [0], [0], [0], ProjectionTransport()) + + @test_throws MethodError vector_transport_direction!(M, [0], [0], [0], [0]) + @test_throws MethodError vector_transport_direction(M, [0], [0], [0]) + + @test_throws MethodError ManifoldsBase.vector_transport_along!(M, [0], [0], [0], x -> x) + @test_throws MethodError vector_transport_along(M, [0], [0], x -> x) - @test_throws ErrorException injectivity_radius(M) - @test_throws ErrorException injectivity_radius(M, [0]) - @test_throws ErrorException injectivity_radius(M, [0], exp_retr) + @test_throws MethodError injectivity_radius(M) + @test_throws MethodError injectivity_radius(M, [0]) + @test_throws MethodError injectivity_radius(M, [0], exp_retr) - @test_throws ErrorException zero_vector!(M, [0], [0]) - @test_throws ErrorException zero_vector(M, [0]) + @test_throws MethodError zero_vector!(M, [0], [0]) + @test_throws MethodError zero_vector(M, [0]) - @test check_point(M, [0]) === nothing - @test check_point(M, p) === nothing + @test ManifoldsBase.check_point(M, [0]) === nothing + @test ManifoldsBase.check_point(M, p) === nothing @test is_point(M, [0]) - @test check_point(M, [0]) === nothing + @test ManifoldsBase.check_point(M, [0]) === nothing - @test check_vector(M, [0], [0]) === nothing - @test check_vector(M, p, v) === nothing + @test ManifoldsBase.check_vector(M, [0], [0]) === nothing + @test ManifoldsBase.check_vector(M, p, v) === nothing @test is_vector(M, [0], [0]) - @test check_vector(M, [0], [0]) === nothing + @test ManifoldsBase.check_vector(M, [0], [0]) === nothing - @test_throws ErrorException hat!(M, [0], [0], [0]) - @test_throws ErrorException vee!(M, [0], [0], [0]) + @test_throws MethodError hat!(M, [0], [0], [0]) + @test_throws MethodError vee!(M, [0], [0], [0]) end diff --git a/test/manifold_fallbacks.jl b/test/manifold_fallbacks.jl new file mode 100644 index 00000000..e5c878d5 --- /dev/null +++ b/test/manifold_fallbacks.jl @@ -0,0 +1,89 @@ +using Test +using ManifoldsBase + +struct NonManifold <: AbstractManifold{ManifoldsBase.ℝ} end + +@testset "NotImplemented Errors" begin + M = NonManifold() + p = [1.0] + q = similar(p) + X = [2.0] + Y = similar(X) + for B in [ + VeeOrthogonalBasis(), + DefaultBasis(), + DefaultOrthogonalBasis(), + DefaultOrthonormalBasis(), + DiagonalizingOrthonormalBasis(X), + CachedBasis(DefaultBasis(), X), + ] + if !(B isa CachedBasis) + @test_throws MethodError get_basis(M, p, B) + @test_throws MethodError get_vector(M, p, X, B) + @test_throws MethodError get_vector!(M, Y, p, X, B) + else + @test get_basis(M, p, B) == B + @test get_vector(M, p, X, B) == 4 # since we have 1 vector + @test get_vector!(M, Y, p, X, B) == [4] # since Y is a vector + end + @test_throws MethodError get_coordinates(M, p, X, B) + @test_throws MethodError get_coordinates!(M, Y, p, X, B) + end + @test_throws MethodError inverse_retract(M, p, q) + @test_throws MethodError inverse_retract!(M, Y, p, q) + for IR in [ + LogarithmicInverseRetraction(), + EmbeddedInverseRetraction(ProjectionInverseRetraction()), + PolarInverseRetraction(), + ProjectionInverseRetraction(), + QRInverseRetraction(), + NLSolveInverseRetraction(ExponentialRetraction()), + SoftmaxInverseRetraction(), + ] + @test_throws MethodError inverse_retract(M, p, q, IR) + @test_throws MethodError inverse_retract!(M, Y, p, q, IR) + end + @test_throws MethodError retract(M, p, X) + @test_throws MethodError retract!(M, q, p, X) + for R in [ + EmbeddedRetraction(ProjectionRetraction()), + ExponentialRetraction(), + ODEExponentialRetraction(ProjectionRetraction(), DefaultBasis()), + ODEExponentialRetraction(ProjectionRetraction()), + PolarRetraction(), + ProjectionRetraction(), + QRRetraction(), + SoftmaxRetraction(), + CayleyRetraction(), + PadeRetraction(2), + ] + @test_throws MethodError retract(M, p, X, R) + @test_throws MethodError retract!(M, q, p, X, R) + end + for VT in [ + DifferentiatedRetractionVectorTransport(ExponentialRetraction()), + ProjectionTransport(), + ParallelTransport(), + ] + @test_throws MethodError vector_transport_along(M, p, X, :curve, VT) + @test_throws MethodError vector_transport_along!(M, Y, p, X, :curve, VT) + @test_throws MethodError vector_transport_direction(M, p, X, X, VT) + @test_throws MethodError vector_transport_direction!(M, Y, p, X, X, VT) + @test_throws MethodError vector_transport_to(M, p, X, q, VT) + @test_throws MethodError vector_transport_to!(M, Y, p, X, q, VT) + end +end + +@testset "Default Fallbacks and Error Messages" begin + M = ManifoldsBase.DefaultManifold(3) + p = [1.0, 0.0, 0.0] + @test number_of_coordinates(M, ManifoldsBase.ℝ) == 3 + B = get_basis(M, p, DefaultBasis()) + @test_throws DomainError ODEExponentialRetraction(ProjectionRetraction(), B) + @test_throws DomainError ODEExponentialRetraction( + ExponentialRetraction(), + DefaultBasis(), + ) + @test_throws DomainError ODEExponentialRetraction(ExponentialRetraction(), B) + @test_throws ErrorException PadeRetraction(0) +end diff --git a/test/passthrough_decorator.jl b/test/passthrough_decorator.jl new file mode 100644 index 00000000..fe8ef127 --- /dev/null +++ b/test/passthrough_decorator.jl @@ -0,0 +1,46 @@ + +using ManifoldsBase +using Test + +using ManifoldsBase: TraitList, merge_traits + +struct PassthoughTrait <: AbstractTrait end + +struct PassthroughDecorator{𝔽,MT<:AbstractManifold{𝔽}} <: AbstractDecoratorManifold{𝔽} + manifold::MT +end + +function ManifoldsBase.active_traits(f, ::PassthroughDecorator, ::Any...) + return merge_traits(PassthoughTrait()) +end + +function ManifoldsBase.log!( + ::TraitList{PassthoughTrait}, + M::AbstractDecoratorManifold, + X, + p, + q, +) + return log!(M.manifold, X, p, q) +end +function ManifoldsBase.exp!( + ::TraitList{PassthoughTrait}, + M::AbstractDecoratorManifold, + q, + p, + X, +) + return log!(M.manifold, q, p, X) +end + +@testset "PassthroughDecorator" begin + M = PassthroughDecorator(ManifoldsBase.DefaultManifold(2)) + + q = [0.0, 0.0] + p = [0.0, 0.0] + X = [1.0, 2.0] + Y = [0.0, 0.0] + @test inverse_retract!(M, q, p, X) == [1.0, 2.0] + @test retract(M, p, q) == [1.0, 2.0] + @test retract!(M, Y, p, q) == [1.0, 2.0] +end diff --git a/test/power.jl b/test/power.jl index da3cb248..d6d2f110 100644 --- a/test/power.jl +++ b/test/power.jl @@ -3,157 +3,187 @@ using ManifoldsBase using ManifoldsBase: AbstractNumbers, ℝ, ℂ, NestedReplacingPowerRepresentation using StaticArrays -struct DummyPowerRepresentation <: AbstractPowerRepresentation end -struct DummyDecorator{TM<:AbstractManifold{ManifoldsBase.ℝ}} <: - AbstractDecoratorManifold{ManifoldsBase.ℝ,DefaultDecoratorType} - manifold::TM -end - power_array_wrapper(::Type{NestedPowerRepresentation}, ::Int) = identity power_array_wrapper(::Type{NestedReplacingPowerRepresentation}, i::Int) = SVector{i} -@testset "Power AbstractManifold" begin - M = ManifoldsBase.DefaultManifold(3) +struct TestArrayRepresentation <: AbstractPowerRepresentation end + +@testset "Power Manifold" begin + + @testset "Power Manifold with a test representation" begin + M = ManifoldsBase.DefaultManifold(3) + N = PowerManifold(M, TestArrayRepresentation(), 2) + O = PowerManifold(N, TestArrayRepresentation(), 3) # joins instead of nesting. + @test repr(O) == + "PowerManifold(DefaultManifold(3; field = ℝ), TestArrayRepresentation(), 2, 3)" + end + + @testset "PowerManifold and allocation with empty representation size" begin + M = ManifoldsBase.DefaultManifold() + N = PowerManifold(M, NestedPowerRepresentation(), 2) + p = [1, 1] + X = [2, 2] + # check - though only because this function exists for avoiding ambiguities. + cm = ManifoldsBase.allocate_result(N, get_coordinates, p, X, DefaultBasis()) + @test size(X) == size(cm) + end + for PowerRepr in [NestedPowerRepresentation, NestedReplacingPowerRepresentation] - N = PowerManifold(M, PowerRepr(), 2) - wrapper_2 = power_array_wrapper(PowerRepr, 2) - wrapper_3 = power_array_wrapper(PowerRepr, 3) - p = wrapper_3.([zeros(3), ones(3)]) - q = wrapper_3.([ones(3), zeros(3)]) + @testset "PowerManifold with $(PowerRepr)" begin + M = ManifoldsBase.DefaultManifold(3) + N = PowerManifold(M, PowerRepr(), 2) + wrapper_2 = power_array_wrapper(PowerRepr, 2) + wrapper_3 = power_array_wrapper(PowerRepr, 3) + p = wrapper_3.([zeros(3), ones(3)]) + q = wrapper_3.([ones(3), zeros(3)]) - @testset "Constructors" begin - @test repr(N) == - "PowerManifold(DefaultManifold(3; field = ℝ), $(PowerRepr)(), 2)" - # add to type - @test repr(N^3) == - "PowerManifold(DefaultManifold(3; field = ℝ), $(PowerRepr)(), 2, 3)" - # add to type - @test repr(PowerManifold(N, 3)) == - "PowerManifold(DefaultManifold(3; field = ℝ), $(PowerRepr)(), 2, 3)" - # switch type - @test repr(PowerManifold(N, DummyPowerRepresentation(), 3)) == - "PowerManifold(DefaultManifold(3; field = ℝ), DummyPowerRepresentation(), 2, 3)" - # nest - @test repr(PowerManifold(N, PowerRepr(), 3)) == - "PowerManifold(PowerManifold(DefaultManifold(3; field = ℝ), $(PowerRepr)(), 2), $(PowerRepr)(), 3)" - end + @testset "Constructors" begin + @test repr(N) == + "PowerManifold(DefaultManifold(3; field = ℝ), $(PowerRepr)(), 2)" + # add to type + @test repr(N^3) == + "PowerManifold(DefaultManifold(3; field = ℝ), $(PowerRepr)(), 2, 3)" + # add to type + @test repr(PowerManifold(N, 3)) == + "PowerManifold(DefaultManifold(3; field = ℝ), $(PowerRepr)(), 2, 3)" + # nest + @test repr(PowerManifold(N, PowerRepr(), 3)) == + "PowerManifold(PowerManifold(DefaultManifold(3; field = ℝ), $(PowerRepr)(), 2), $(PowerRepr)(), 3)" + end - @test N^3 == PowerManifold(M, PowerRepr(), 2, 3) - @test ManifoldsBase.get_iterator(N^3) == Base.product(Base.OneTo(2), Base.OneTo(3)) + @test N^3 == PowerManifold(M, PowerRepr(), 2, 3) + @test ManifoldsBase.get_iterator(N^3) == + Base.product(Base.OneTo(2), Base.OneTo(3)) - @testset "point/tangent checks" begin - pE1 = [wrapper_3(zeros(3)), wrapper_2(ones(2))] # one component wrong - pE2 = wrapper_2.([zeros(2), ones(2)]) # both wrong - @test is_point(N, p) - @test is_point(N, p, true) - @test !is_point(N, pE1) - @test !is_point(N, pE2) - @test_throws ComponentManifoldError is_point(N, pE1, true) - @test_throws CompositeManifoldError is_point(N, pE2, true) - # tangent - test base - @test is_vector(N, p, p) - @test is_vector(N, p, p, true) - @test !is_vector(N, pE1, p) - @test !is_vector(N, pE2, p) - # tangents - with proper base - @test is_vector(N, p, p, true) - @test !is_vector(N, p, pE1) - @test !is_vector(N, p, pE2) - @test_throws ComponentManifoldError is_vector(N, p, pE1, true) - @test_throws CompositeManifoldError is_vector(N, p, pE2, true) - end - @testset "specific functions" begin - @test distance(N, p, q) == sqrt(sum(distance.(Ref(M), p, q) .^ 2)) - @test exp(N, p, q) == p .+ q - @test retract(N, p, q) == p .+ q - @test ManifoldsBase.get_iterator(N) == Base.OneTo(2) - @test injectivity_radius(N) == injectivity_radius(M) - @test injectivity_radius(N, p) == injectivity_radius(M, p) - p2 = allocate(p) - copyto!(N, p2, p) - @test !(p === p2) - @test p == p2 - q2 = allocate(q) - copyto!(N, q2, p, q) - @test !(q === q2) - @test q == q2 - @test inner(N, p, q, q) == sum(inner.(Ref(M), p, q, q)) - @test isapprox(N, p, q) == (all(isapprox.(Ref(M), p, q))) - @test isapprox(N, p, p) == (all(isapprox.(Ref(M), p, p))) - @test isapprox(N, p, q, p) == (all(isapprox.(Ref(M), p, q, p))) - @test isapprox(N, p, p, p) == (all(isapprox.(Ref(M), p, p, p))) - @test log(N, p, q) == q .- p - @test inverse_retract(N, p, q) == q .- p - @test manifold_dimension(N) == 2 * manifold_dimension(M) - @test mid_point(N, p, q) == mid_point.(Ref(M), p, q) - @test sqrt(inner(N, p, q, q)) == norm(N, p, q) - @test project(N, p) == p - @test project(N, p, q) == q - @test power_dimensions(N) == (2,) - @test power_dimensions(N^3) == (2, 3) - m = ParallelTransport() - @test vector_transport_to(N, p, p, q) == p - @test vector_transport_to(N, p, p, q, m) == p - @test vector_transport_to(N, p, p, q, PowerVectorTransport(m)) == p - q2 = [zeros(3), zeros(3)] - vector_transport_to!(N, q2, p, q, p) - @test q2 == q - @test vector_transport_direction(N, p, p, q) == p - @test vector_transport_direction(N, p, p, q, m) == p - @test vector_transport_direction(N, p, p, q, PowerVectorTransport(m)) == p - q2 = [zeros(3), zeros(3)] - vector_transport_direction!(N, q2, p, q, p) - @test q2 == q - @test p[N, 1] == p[1] - p[N, 1] = 2 .* ones(3) - @test p[N, 1] == 2 .* ones(3) - @test p[N, [2, 1]] == [p[2], p[1]] - if PowerRepr == NestedPowerRepresentation - @test view(p, N, 2) isa SubArray - @test view(p, N, 2) == p[2] + @testset "point/tangent checks" begin + pE1 = [wrapper_3(zeros(3)), wrapper_2(ones(2))] # one component wrong + pE2 = wrapper_2.([zeros(2), ones(2)]) # both wrong + @test is_point(N, p) + @test is_point(N, p, true) + @test !is_point(N, pE1) + @test !is_point(N, pE2) + @test_throws ComponentManifoldError is_point(N, pE1, true) + @test_throws CompositeManifoldError is_point(N, pE2, true) + # tangent - test base + @test is_vector(N, p, p) + @test is_vector(N, p, p, true) + @test !is_vector(N, pE1, p) + @test !is_vector(N, pE2, p) + # tangents - with proper base + @test is_vector(N, p, p, true) + @test !is_vector(N, p, pE1) + @test !is_vector(N, p, pE2) + @test_throws ComponentManifoldError is_vector(N, p, pE1, true) + @test_throws CompositeManifoldError is_vector(N, p, pE2, true) end - end - @testset "Basis, coordinates & vector" begin - v = get_coordinates(N, p, q, DefaultBasis()) - @test v == [q[1]..., q[2]...] - @test get_vector(N, p, v, DefaultBasis()) == q - B = get_basis(N, p, DefaultBasis()) - @test get_coordinates(N, p, q, B) == v - # the method tested below should not be used but it prevents ambiguities from occurring - # and the test is here to make coverage happy - @test ManifoldsBase.allocate_result(N, get_coordinates, p, q, B) isa Vector - v2 = zeros(size(v)) - get_coordinates!(N, v2, p, q, B) - @test v2 == v - @test get_coordinates(N, p, q, DefaultOrthonormalBasis()) == v - @test B.data.bases[1].data == get_basis(M, p[1], DefaultBasis()).data - @test B.data.bases[2].data == get_basis(M, p[2], DefaultBasis()).data - @test get_vector(N, p, v, B) == q - @test zero_vector(N, p) == [zeros(3), zeros(3)] - B2 = DiagonalizingOrthonormalBasis([ones(3), ones(3)]) - B3 = get_basis(N, p, B2) - @test sprint(show, "text/plain", B) == """$(DefaultBasis()) for a power manifold - Basis for component (1,): - $(sprint(show, "text/plain", B.data.bases[1])) - Basis for component (2,): - $(sprint(show, "text/plain", B.data.bases[2])) - """ - end + @testset "specific functions" begin + @test distance(N, p, q) == sqrt(sum(distance.(Ref(M), p, q) .^ 2)) + @test exp(N, p, q) == p .+ q - @testset "Zero index manifold" begin - Mzero = ManifoldsBase.DefaultManifold() - N = PowerManifold(Mzero, PowerRepr(), 3) - p = [1.0, 2.0, 3.0] + @test retract(N, p, q) == p .+ q + @test retract(N, p, q, ExponentialRetraction()) == p .+ q + r = allocate(p) + @test retract!(N, r, p, q, ExponentialRetraction()) == p .+ q + @test r == p .+ q + @test inverse_retract(N, p, r) == q + @test inverse_retract(N, p, r, LogarithmicInverseRetraction()) == q + X = allocate(p) + @test inverse_retract!(N, X, p, r, LogarithmicInverseRetraction()) == q + @test X == q - @test p[N, 1] == 1.0 - @test zero_vector(N, p) == zero(p) - end - @testset "Decorator passthrough for getindex" begin - Mzero = ManifoldsBase.DefaultManifold() - N = PowerManifold(Mzero, PowerRepr(), 3) - p = [1.0, 2.0, 3.0] - DN = DummyDecorator(N) - @test p[DN, 1] == p[N, 1] + @test ManifoldsBase.get_iterator(N) == Base.OneTo(2) + @test injectivity_radius(N) == injectivity_radius(M) + @test injectivity_radius(N, ExponentialRetraction()) == + injectivity_radius(M) + @test injectivity_radius(N, p) == injectivity_radius(M, p) + p2 = allocate(p) + copyto!(N, p2, p) + @test !(p === p2) + @test p == p2 + q2 = allocate(q) + copyto!(N, q2, p, q) + @test !(q === q2) + @test q == q2 + @test inner(N, p, q, q) == sum(inner.(Ref(M), p, q, q)) + @test isapprox(N, p, q) == (all(isapprox.(Ref(M), p, q))) + @test isapprox(N, p, p) == (all(isapprox.(Ref(M), p, p))) + @test isapprox(N, p, q, p) == (all(isapprox.(Ref(M), p, q, p))) + @test isapprox(N, p, p, p) == (all(isapprox.(Ref(M), p, p, p))) + @test log(N, p, q) == q .- p + @test inverse_retract(N, p, q) == q .- p + @test manifold_dimension(N) == 2 * manifold_dimension(M) + @test mid_point(N, p, q) == mid_point.(Ref(M), p, q) + @test sqrt(inner(N, p, q, q)) == norm(N, p, q) + @test project(N, p) == p + @test project(N, p, q) == q + @test power_dimensions(N) == (2,) + @test power_dimensions(N^3) == (2, 3) + m = ParallelTransport() + @test vector_transport_to(N, p, p, q) == p + @test vector_transport_to(N, p, p, q, m) == p + q2 = [zeros(3), zeros(3)] + vector_transport_to!(N, q2, p, q, p) + @test q2 == q + @test vector_transport_direction(N, p, p, q) == p + @test vector_transport_direction(N, p, p, q, m) == p + q2 = [zeros(3), zeros(3)] + vector_transport_direction!(N, q2, p, q, p) + @test q2 == q + @test p[N, 1] == p[1] + p[N, 1] = 2 .* ones(3) + @test p[N, 1] == 2 .* ones(3) + @test p[N, [2, 1]] == [p[2], p[1]] + if PowerRepr == NestedPowerRepresentation + @test view(p, N, 2) isa SubArray + @test view(p, N, 2) == p[2] + end + end + @testset "Basis, coordinates & vector" begin + v = get_coordinates(N, p, q, DefaultBasis()) + @test v == [q[1]..., q[2]...] + @test get_vector(N, p, v, DefaultBasis()) == q + B = get_basis(N, p, DefaultBasis()) + @test get_coordinates(N, p, q, B) == v + # the method tested below should not be used but it prevents ambiguities from occurring + # and the test is here to make coverage happy + @test ManifoldsBase.allocate_result(N, get_coordinates, p, q, B) isa Vector + v2 = similar(v) + get_coordinates!(N, v2, p, q, B) + @test v2 == v + @test get_coordinates(N, p, q, DefaultOrthonormalBasis()) == v + v3 = similar(v2) + @test get_coordinates!(N, v3, p, q, DefaultOrthonormalBasis()) == v + @test v3 == v + @test B.data.bases[1].data == get_basis(M, p[1], DefaultBasis()).data + @test B.data.bases[2].data == get_basis(M, p[2], DefaultBasis()).data + @test get_vector(N, p, v, B) == q + q2 = similar.(q) + @test get_vector!(N, q2, p, v, B) == q + @test q2 == q + q3 = similar.(q) + @test get_vector(N, p, v, DefaultOrthonormalBasis()) == q + @test get_vector!(N, q3, p, v, DefaultOrthonormalBasis()) == q + @test q3 == q + @test zero_vector(N, p) == [zeros(3), zeros(3)] + B2 = DiagonalizingOrthonormalBasis([ones(3), ones(3)]) + B3 = get_basis(N, p, B2) + @test sprint(show, "text/plain", B) == + """$(DefaultBasis()) for a power manifold +Basis for component (1,): +$(sprint(show, "text/plain", B.data.bases[1])) +Basis for component (2,): +$(sprint(show, "text/plain", B.data.bases[2])) +""" + end + + @testset "Zero index manifold" begin + Mzero = ManifoldsBase.DefaultManifold() + N = PowerManifold(Mzero, PowerRepr(), 3) + p = [1.0, 2.0, 3.0] + + @test p[N, 1] == 1.0 + @test zero_vector(N, p) == zero(p) + end end end end diff --git a/test/runtests.jl b/test/runtests.jl index 767ca0a8..94ed88a8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,25 +1,28 @@ using Test using ManifoldsBase @testset "ManifoldsBase" begin - + bound = 0 if VERSION >= v"1.1" num_ambiguities = length(Test.detect_ambiguities(ManifoldsBase)) #num_ambiguities > 0 && @warn "The number of ambiguities in ManifoldsBase is $(num_ambiguities)." - if VERSION >= v"1.7-DEV" - @test num_ambiguities <= 4 + if VERSION >= v"1.8-DEV" + @test num_ambiguities <= bound + 4 elseif VERSION >= v"1.6-DEV" # At the time of writing there seem to be two ambiguities regarding `getindex`, # one with a method from SparseArrays and one from VSCode's JSON processing # that's automatically loaded when running code in VSCode. - @test num_ambiguities <= 2 + @test num_ambiguities <= bound + 1 else - @test num_ambiguities == 0 + @test num_ambiguities == bound end end + + include("decorator_traits.jl") + include("passthrough_decorator.jl") include("allocation.jl") include("numbers.jl") include("bases.jl") - include("decorator_manifold.jl") + include("manifold_fallbacks.jl") include("empty_manifold.jl") include("errors.jl") include("default_manifold.jl") @@ -28,5 +31,5 @@ using ManifoldsBase include("embedded_manifold.jl") include("power.jl") include("domain_errors.jl") - include("vector_transport_along.jl") + include("vector_transport.jl") end diff --git a/test/validation_manifold.jl b/test/validation_manifold.jl index c081a13e..b4387e05 100644 --- a/test/validation_manifold.jl +++ b/test/validation_manifold.jl @@ -104,28 +104,25 @@ end @test isapprox(A, zero_vector(A, x), zero_vector(M, x)) vector_transport_to!(A, v2s, x2, v2, y2) @test isapprox(A, x2, v2, v2s) - vector_transport_to!(A, v2s, x2, v2, y2, ManifoldsBase.SchildsLadderTransport()) + @test isapprox(A, x2, v2, vector_transport_to(A, x2, v2, y2)) + slt = ManifoldsBase.SchildsLadderTransport() + vector_transport_to!(A, v2s, x2, v2, y2, slt) @test isapprox(A, x2, v2, v2s) - vector_transport_to!(A, v2s, x2, v2, y, ManifoldsBase.PoleLadderTransport()) + @test isapprox(A, x2, v2, vector_transport_to(A, x2, v2, y2, slt)) + plt = ManifoldsBase.PoleLadderTransport() + vector_transport_to!(A, v2s, x2, v2, y, plt) + @test isapprox(A, x2, v2, vector_transport_to(A, x2, v2, y2, plt)) @test isapprox(A, x2, v2, v2s) - vector_transport_to!(A, v2s, x2, v2, y2, ManifoldsBase.ProjectionTransport()) + pt = ManifoldsBase.ProjectionTransport() + vector_transport_to!(A, v2s, x2, v2, y2, pt) @test isapprox(A, x2, v2, v2s) + @test isapprox(A, x2, v2, vector_transport_to(A, x2, v2, y2, pt)) zero_vector!(A, v2s, x) @test isapprox(A, x, v2s, zero_vector(M, x)) - c = [x2] + c2 = [x2] v3 = similar(v2) - @test isapprox( - A, - x2, - v2, - vector_transport_along!(A, v3, x2, v2, c, ParallelTransport()), - ) - @test isapprox( - A, - x2, - v2, - vector_transport_along(A, x2, v2, c, ManifoldsBase.ProjectionTransport()), - ) + @test isapprox(A, x2, v2, vector_transport_along!(A, v3, x2, v2, c2)) + @test isapprox(A, x2, v2, vector_transport_along(A, x2, v2, c2)) @test injectivity_radius(A) == Inf @test injectivity_radius(A, x) == Inf @test injectivity_radius(A, ManifoldsBase.ExponentialRetraction()) == Inf diff --git a/test/vector_transport_along.jl b/test/vector_transport.jl similarity index 70% rename from test/vector_transport_along.jl rename to test/vector_transport.jl index 70e5e33e..2d610c4f 100644 --- a/test/vector_transport_along.jl +++ b/test/vector_transport.jl @@ -3,20 +3,16 @@ # also the Schild and pole special cases # using ManifoldsBase, Test - +import ManifoldsBase: parallel_transport_to!, parallel_transport_along! struct NonDefaultEuclidean <: AbstractManifold{ManifoldsBase.ℝ} end ManifoldsBase.log!(::NonDefaultEuclidean, v, x, y) = (v .= y .- x) ManifoldsBase.exp!(::NonDefaultEuclidean, y, x, v) = (y .= x .+ v) -function ManifoldsBase.vector_transport_to!( - ::NonDefaultEuclidean, - vto, - x, - v, - y, - ::ParallelTransport, -) - return copyto!(vto, v) +function ManifoldsBase.parallel_transport_to!(::NonDefaultEuclidean, Y, p, X, q) + return copyto!(Y, X) +end +function ManifoldsBase.parallel_transport_along!(::NonDefaultEuclidean, Y, p, X, q) + return copyto!(Y, X) end @testset "vector_transport_along" begin @@ -43,3 +39,17 @@ end end end end + +@testset "vector-transport fallback types" begin + VT = VectorTransportDirection() + M = NonDefaultEuclidean() + p = [1.0, 0.0, 0.0] + q = [0.0, 1.0, 0.0] + X = [0.0, 0.0, 1.0] + Y = similar(X) + @test vector_transport_direction(M, p, X, p - q, VT) == X + @test vector_transport_direction!(M, Y, p, X, p - q, VT) == X + VT2 = VectorTransportTo() + @test vector_transport_to(M, p, X, q, VT2) == X + @test vector_transport_to!(M, Y, p, X, q, VT2) == X +end