{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "md-title",
   "metadata": {},
   "source": [
    "# Log-det cost on the complex sphere\n",
    "\n",
    "Optimizes $f(P) = \\operatorname{Re} \\log\\det\\left(I_m + P^H H^H H P\\right)$ over $P \\in \\mathbb{C}^{n \\times m}$ with unit Frobenius norm, $\\|P\\|_F = 1$, i.e. over `Manifolds.ArraySphere(n, m, field = ℂ)`, using Manopt's `quasi_Newton` solver. The last section checks the hand-written Euclidean gradient against Zygote.\n",
    "\n",
    "**Note:** Manopt solvers *minimize* the cost. If the goal is to *maximize* $\\log\\det$, optimize $-f$ (with gradient $-\\nabla f$) instead. **TODO:** confirm the intended direction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b96010c-2937-4eb6-9052-2d44b61acc49",
   "metadata": {},
   "outputs": [],
   "source": [
    "using LinearAlgebra\n",
    "using Manifolds\n",
    "using ManifoldDiff\n",
    "using Manopt\n",
    "using Random\n",
    "using Zygote"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-problem-data",
   "metadata": {},
   "source": [
    "## Problem data\n",
    "\n",
    "`H` is a random $m \\times n$ complex matrix. The starting point `P` is a random $n \\times m$ complex matrix rescaled to unit Frobenius norm so that it lies on the sphere manifold; the solver expects its initial point to already be on the manifold."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e924039-68d6-4249-8948-f13838c4afb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the global RNG so that H and the starting point are reproducible.\n",
    "Random.seed!(42)\n",
    "\n",
    "m, n = 2, 8\n",
    "H = randn(ComplexF64, m, n)\n",
    "HtH = H' * H   # n×n matrix H'H, reused by the cost and its gradient\n",
    "\n",
    "# A raw randn draw does not have unit Frobenius norm, so it is not a point of\n",
    "# the sphere manifold; rescale it before handing it to the solver.\n",
    "P_raw = randn(ComplexF64, n, m)\n",
    "P = P_raw / norm(P_raw)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-cost",
   "metadata": {},
   "source": [
    "## Cost, gradients and manifold"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b137cb9-9f0f-4a59-95ae-74d2fc6cf171",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "    f(M, P)\n",
    "\n",
    "Cost `real(logdet(I + P' * HtH * P))` with `HtH = H' * H`. `real` discards the\n",
    "round-off imaginary part that `logdet` returns for complex matrices. The manifold\n",
    "argument `M` is unused; it is there because Manopt calls the cost as `f(M, p)`.\n",
    "\"\"\"\n",
    "f(M, P) = real(logdet(I(m) + P' * HtH * P))\n",
    "\n",
    "\"\"\"\n",
    "    euclidean_∇f(M, P)\n",
    "\n",
    "Euclidean gradient of `f`, `2 * HtH * P * inv(I + P' * HtH * P)`, taken with respect to\n",
    "the real inner product `real(tr(X' * Y))` on complex matrices. `M` is unused.\n",
    "\"\"\"\n",
    "euclidean_∇f(M, P) = 2HtH * P / (I(m) + P' * HtH * P)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b858502b-2e27-4572-b756-b63dd5ec1c48",
   "metadata": {},
   "outputs": [],
   "source": [
    "manif = Manifolds.ArraySphere(n, m, field = ℂ)\n",
    "\n",
    "# Riemannian gradient: evaluate the Euclidean gradient in the embedding and let\n",
    "# ManifoldDiff convert it to a tangent vector at P.\n",
    "∇f(M, P) = ManifoldDiff.riemannian_gradient(M, P, euclidean_∇f(get_embedding(M), embed(M, P)))\n",
    "\n",
    "@assert norm(P) ≈ 1 \"starting point must lie on the unit sphere\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-optimize",
   "metadata": {},
   "source": [
    "## Optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b49750e9-c52b-4db9-8628-fc1fc0a86d8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: setting `memory_size = -1` makes this call crash. TODO: investigate.\n",
    "res = @time Manopt.quasi_Newton(manif, f, ∇f, P;\n",
    "    memory_size = m * n,\n",
    "    stopping_criterion = StopAfterIteration(2000) | StopWhenGradientNormLess(1e-6),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b7213aa-2217-4def-9c27-f1b26194ad09",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check the constraint: the solution should have unit Frobenius norm (≈ 1).\n",
    "norm(res)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-gradient-check",
   "metadata": {},
   "source": [
    "## Gradient check\n",
    "\n",
    "Compare Zygote's gradient of the cost with `euclidean_∇f` at the starting point `P`. If the hand-written gradient is correct, the relative error should be near machine precision."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e40757f-4025-44b2-9fa6-309628796aa6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference gradient from automatic differentiation.\n",
    "g1 = @time Zygote.gradient(W -> f(manif, W), P)[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7699dd43-d5b0-4ac7-b9fa-648e134269a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand-written gradient.\n",
    "g2 = @time euclidean_∇f(manif, P)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62ebbde2-fc2d-41cf-8697-252f0945104e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative discrepancy between the two gradients.\n",
    "norm(g1 - g2) / norm(g2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 1.9.4",
   "language": "julia",
   "name": "julia-1.9"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.9.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}