From 2dec4ae1751f09004739c6c68ceda7ab90223ef1 Mon Sep 17 00:00:00 2001 From: Antoine Prouvost Date: Thu, 25 May 2023 16:59:28 +0200 Subject: [PATCH] Add ObjSolver (#2504) * Add ObjSolver * Add ObjSolver problems * Add ObjSolver rule getters * Add ObjPool::add_conda_dependency * Add ObjSolver::solve test * Add extra header --- libmamba/CMakeLists.txt | 2 + libmamba/src/solv-cpp/ids.hpp | 3 + libmamba/src/solv-cpp/pool.cpp | 14 +++ libmamba/src/solv-cpp/pool.hpp | 13 +- libmamba/src/solv-cpp/repo.hpp | 5 +- libmamba/src/solv-cpp/solvable.hpp | 5 +- libmamba/src/solv-cpp/solver.cpp | 115 ++++++++++++++++++ libmamba/src/solv-cpp/solver.hpp | 101 ++++++++++++++++ libmamba/tests/CMakeLists.txt | 1 + libmamba/tests/src/solv-cpp/test_pool.cpp | 7 ++ libmamba/tests/src/solv-cpp/test_solver.cpp | 126 ++++++++++++++++++++ 11 files changed, 389 insertions(+), 3 deletions(-) create mode 100644 libmamba/src/solv-cpp/solver.cpp create mode 100644 libmamba/src/solv-cpp/solver.hpp create mode 100644 libmamba/tests/src/solv-cpp/test_solver.cpp diff --git a/libmamba/CMakeLists.txt b/libmamba/CMakeLists.txt index e752f866ec..20c4c54053 100644 --- a/libmamba/CMakeLists.txt +++ b/libmamba/CMakeLists.txt @@ -123,6 +123,7 @@ set(LIBMAMBA_SOURCES ${LIBMAMBA_SOURCE_DIR}/solv-cpp/pool.cpp ${LIBMAMBA_SOURCE_DIR}/solv-cpp/repo.cpp ${LIBMAMBA_SOURCE_DIR}/solv-cpp/solvable.cpp + ${LIBMAMBA_SOURCE_DIR}/solv-cpp/solver.cpp # C++ wrapping of libcurl ${LIBMAMBA_SOURCE_DIR}/core/curl.cpp # C++ wrapping of compression libs (zstd and bzlib) @@ -275,6 +276,7 @@ set(LIBMAMBA_PRIVATE_HEADERS ${LIBMAMBA_SOURCE_DIR}/solv-cpp/ids.hpp ${LIBMAMBA_SOURCE_DIR}/solv-cpp/pool.hpp ${LIBMAMBA_SOURCE_DIR}/solv-cpp/solvable.hpp + ${LIBMAMBA_SOURCE_DIR}/solv-cpp/solver.hpp ${LIBMAMBA_SOURCE_DIR}/solv-cpp/repo.hpp # C++ wrapping of compression libs (zstd and bzlib) ${LIBMAMBA_SOURCE_DIR}/core/compression.hpp diff --git a/libmamba/src/solv-cpp/ids.hpp b/libmamba/src/solv-cpp/ids.hpp index c3a3d3f7cb..97c18b7290 100644 --- a/libmamba/src/solv-cpp/ids.hpp +++ b/libmamba/src/solv-cpp/ids.hpp @@ -15,9 +15,12 @@ namespace mamba::solv using DependencyId = ::Id; using RepoId = ::Id; using SolvableId = ::Id; + using RuleId = ::Id; + using ProblemId = ::Id; using RelationFlag = int; using DistType = int; + using SolverFlag = int; } #endif diff --git a/libmamba/src/solv-cpp/pool.cpp b/libmamba/src/solv-cpp/pool.cpp index 8dcc52a4f4..c02081e3a5 100644 --- a/libmamba/src/solv-cpp/pool.cpp +++ b/libmamba/src/solv-cpp/pool.cpp @@ -12,6 +12,10 @@ #include #include #include +extern "C" // Incomplete header +{ +#include +} #include "solv-cpp/pool.hpp" @@ -112,6 +116,16 @@ namespace mamba::solv return id; } + auto ObjPool::add_conda_dependency(raw_str_view dep) -> DependencyId + { + return ::pool_conda_matchspec(raw(), dep); + } + + auto ObjPool::add_conda_dependency(const std::string& dep) -> DependencyId + { + return add_conda_dependency(dep.c_str()); + } + auto ObjPool::get_dependency_name(DependencyId id) const -> std::string_view { return ::pool_id2str(raw(), id); diff --git a/libmamba/src/solv-cpp/pool.hpp b/libmamba/src/solv-cpp/pool.hpp index e9cf31975e..8ed17f745a 100644 --- a/libmamba/src/solv-cpp/pool.hpp +++ b/libmamba/src/solv-cpp/pool.hpp @@ -18,7 +18,10 @@ #include "solv-cpp/repo.hpp" #include "solv-cpp/solvable.hpp" -using Pool = struct s_Pool; +extern "C" +{ + using Pool = struct s_Pool; +} namespace mamba::solv { @@ -33,6 +36,8 @@ namespace mamba::solv { public: + using raw_str_view = const char*; + ObjPool(); ~ObjPool(); @@ -95,6 +100,12 @@ namespace mamba::solv */ auto add_dependency(StringId name_id, RelationFlag flag, StringId version_id) -> DependencyId; + /** + * Parse a dependency from string and add it to the pool. + */ + auto add_conda_dependency(raw_str_view dep) -> DependencyId; + auto add_conda_dependency(const std::string& dep) -> DependencyId; + /** Get the registered name of a dependency. */ auto get_dependency_name(DependencyId id) const -> std::string_view; diff --git a/libmamba/src/solv-cpp/repo.hpp b/libmamba/src/solv-cpp/repo.hpp index 309182819a..85eea1459c 100644 --- a/libmamba/src/solv-cpp/repo.hpp +++ b/libmamba/src/solv-cpp/repo.hpp @@ -14,7 +14,10 @@ #include "solv-cpp/ids.hpp" #include "solv-cpp/solvable.hpp" -using Repo = struct s_Repo; +extern "C" +{ + using Repo = struct s_Repo; +} namespace fs { diff --git a/libmamba/src/solv-cpp/solvable.hpp b/libmamba/src/solv-cpp/solvable.hpp index 3b216f7993..df866712f9 100644 --- a/libmamba/src/solv-cpp/solvable.hpp +++ b/libmamba/src/solv-cpp/solvable.hpp @@ -16,7 +16,10 @@ #include "solv-cpp/ids.hpp" #include "solv-cpp/queue.hpp" -using Solvable = struct s_Solvable; +extern "C" +{ + using Solvable = struct s_Solvable; +} namespace mamba::solv { diff --git a/libmamba/src/solv-cpp/solver.cpp b/libmamba/src/solv-cpp/solver.cpp new file mode 100644 index 0000000000..901e73558c --- /dev/null +++ b/libmamba/src/solv-cpp/solver.cpp @@ -0,0 +1,115 @@ +// Copyright (c) 2023, QuantStack and Mamba Contributors +// +// Distributed under the terms of the BSD 3-Clause License. +// +// The full license is in the file LICENSE, distributed with this software. + +#include + +#include +#include +#include +// broken headers go last +#include +#include + +#include "solv-cpp/pool.hpp" +#include "solv-cpp/queue.hpp" +#include "solv-cpp/solver.hpp" + +namespace mamba::solv +{ + void ObjSolver::SolverDeleter::operator()(::Solver* ptr) + { + ::solver_free(ptr); + } + + ObjSolver::ObjSolver(const ObjPool& pool) + : m_solver(::solver_create(const_cast<::Pool*>(pool.raw()))) + { + } + + ObjSolver::~ObjSolver() = default; + + auto ObjSolver::raw() -> ::Solver* + { + return m_solver.get(); + } + + void ObjSolver::set_flag(SolverFlag flag, bool value) + { + ::solver_set_flag(raw(), flag, value); + } + + auto ObjSolver::get_flag(SolverFlag flag) const -> bool + { + const auto val = ::solver_get_flag(const_cast<::Solver*>(raw()), flag); + assert((val == 0) || (val == 1)); + return val != 0; + } + + auto ObjSolver::raw() const -> const ::Solver* + { + return m_solver.get(); + } + + auto ObjSolver::solve(const ObjPool& /* pool */, const ObjQueue& jobs) -> bool + { + // pool is captured inside solver so we take it as a parameter to be explicit. + const auto n_pbs = ::solver_solve(raw(), const_cast<::Queue*>(jobs.raw())); + return n_pbs == 0; + } + + auto ObjSolver::problem_count() const -> std::size_t + { + return ::solver_problem_count(const_cast<::Solver*>(raw())); + } + + auto ObjSolver::problem_to_string(const ObjPool& /* pool */, ProblemId id) const -> std::string + { + // pool is captured inside solver so we take it as a parameter to be explicit. + return ::solver_problem2str(const_cast<::Solver*>(raw()), id); + } + + auto ObjSolver::next_problem(ProblemId id) const -> ProblemId + { + return ::solver_next_problem(const_cast<::Solver*>(raw()), id); + } + + auto ObjSolver::problem_rules(ProblemId id) const -> ObjQueue + { + ObjQueue rules = {}; + ::solver_findallproblemrules(const_cast<::Solver*>(raw()), id, rules.raw()); + return rules; + } + + auto ObjSolver::get_rule_info(const ObjPool& /* pool */, RuleId id) const -> ObjRuleInfo + { + // pool is captured inside solver so we take it as a parameter to be explicit. + SolvableId from_id = 0; + SolvableId to_id = 0; + DependencyId dep_id = 0; + const auto type = ::solver_ruleinfo(const_cast<::Solver*>(raw()), id, &from_id, &to_id, &dep_id); + + return { + /* .from_id= */ (from_id != 0) ? std::optional{ from_id } : std::nullopt, + /* .to_id= */ (to_id != 0) ? std::optional{ to_id } : std::nullopt, + /* .dep_id= */ (dep_id != 0) ? std::optional{ dep_id } : std::nullopt, + /* .type= */ type, + /* .klass= */ ::solver_ruleclass(const_cast<::Solver*>(raw()), id), + }; + } + + auto ObjSolver::rule_info_to_string(const ObjPool& /* pool */, ObjRuleInfo ri) const -> std::string + { + // pool is captured inside solver so we take it as a parameter to be explicit. + return ::solver_ruleinfo2str( + const_cast<::Solver*>(raw()), + ri.type, + ri.from_id.value_or(0), + ri.to_id.value_or(0), + ri.dep_id.value_or(0) + ); + } + +} diff --git a/libmamba/src/solv-cpp/solver.hpp b/libmamba/src/solv-cpp/solver.hpp new file mode 100644 index 0000000000..aaa08430b2 --- /dev/null +++ b/libmamba/src/solv-cpp/solver.hpp @@ -0,0 +1,101 @@ +// Copyright (c) 2023, QuantStack and Mamba Contributors +// +// Distributed under the terms of the BSD 3-Clause License. +// +// The full license is in the file LICENSE, distributed with this software. + +#ifndef MAMBA_SOLV_SOLVER_HPP +#define MAMBA_SOLV_SOLVER_HPP + +#include +#include +#include + +// START Only required for broken header +#include +extern "C" +{ + typedef struct s_Solvable Solvable; + typedef struct s_Map Map; + typedef struct s_Queue Queue; +} +// END +#include + +#include "solv-cpp/ids.hpp" +#include "solv-cpp/queue.hpp" + +extern "C" +{ + using Solver = struct s_Solver; +} + +namespace mamba::solv +{ + class ObjPool; + class ObjQueue; + + struct ObjRuleInfo + { + std::optional from_id; + std::optional to_id; + std::optional dep_id; + ::SolverRuleinfo type; + ::SolverRuleinfo klass; + }; + + class ObjSolver + { + public: + + ObjSolver(const ObjPool& pool); + ~ObjSolver(); + + auto raw() -> ::Solver*; + auto raw() const -> const ::Solver*; + + void set_flag(SolverFlag flag, bool value); + [[nodiscard]] auto get_flag(SolverFlag flag) const -> bool; + + [[nodiscard]] auto solve(const ObjPool& pool, const ObjQueue& jobs) -> bool; + + [[nodiscard]] auto problem_count() const -> std::size_t; + [[nodiscard]] auto problem_to_string(const ObjPool& pool, ProblemId id) const -> std::string; + template + void for_each_problem_id(UnaryFunc&& func) const; + + /** + * Return an @ref ObjQueue of @ref RuleId with all rules involved in a current problem. + */ + [[nodiscard]] auto problem_rules(ProblemId id) const -> ObjQueue; + [[nodiscard]] auto get_rule_info(const ObjPool& pool, RuleId id) const -> ObjRuleInfo; + [[nodiscard]] auto rule_info_to_string(const ObjPool& pool, ObjRuleInfo id) const + -> std::string; + + private: + + struct SolverDeleter + { + void operator()(::Solver* ptr); + }; + + std::unique_ptr<::Solver, ObjSolver::SolverDeleter> m_solver = nullptr; + + auto next_problem(ProblemId id = 0) const -> ProblemId; + }; + + /********************************* + * Implementation of ObjSolver * + *********************************/ + + template + void ObjSolver::for_each_problem_id(UnaryFunc&& func) const + { + for (ProblemId id = next_problem(); id != 0; id = next_problem(id)) + { + func(id); + } + } + +} +#endif diff --git a/libmamba/tests/CMakeLists.txt b/libmamba/tests/CMakeLists.txt index 84595008e8..052fcba9b0 100644 --- a/libmamba/tests/CMakeLists.txt +++ b/libmamba/tests/CMakeLists.txt @@ -17,6 +17,7 @@ set(LIBMAMBA_TEST_SRCS src/solv-cpp/test_pool.cpp src/solv-cpp/test_repo.cpp src/solv-cpp/test_solvable.cpp + src/solv-cpp/test_solver.cpp # Utility library src/util/test_flat_set.cpp src/util/test_graph.cpp diff --git a/libmamba/tests/src/solv-cpp/test_pool.cpp b/libmamba/tests/src/solv-cpp/test_pool.cpp index 55e2f34379..c5345810ad 100644 --- a/libmamba/tests/src/solv-cpp/test_pool.cpp +++ b/libmamba/tests/src/solv-cpp/test_pool.cpp @@ -72,6 +72,13 @@ TEST_SUITE("ObjPool") CHECK_EQ(pool.get_dependency_relation(id_rel), " > "); CHECK_EQ(pool.get_dependency_version(id_rel), "1.0.0"); CHECK_EQ(pool.dependency_to_string(id_rel), "mamba > 1.0.0"); + + SUBCASE("Parse a conda dependency") + { + const auto id_conda = pool.add_conda_dependency("rattler < 0.1"); + CHECK_EQ(pool.get_dependency_name(id_conda), "rattler"); + CHECK_EQ(pool.get_dependency_version(id_conda), "<0.1"); + } } SUBCASE("Add repo") diff --git a/libmamba/tests/src/solv-cpp/test_solver.cpp b/libmamba/tests/src/solv-cpp/test_solver.cpp new file mode 100644 index 0000000000..32e2ac3580 --- /dev/null +++ b/libmamba/tests/src/solv-cpp/test_solver.cpp @@ -0,0 +1,126 @@ +// Copyright (c) 2023, QuantStack and Mamba Contributors +// +// Distributed under the terms of the BSD 3-Clause License. +// +// The full license is in the file LICENSE, distributed with this software. + +#include +#include + +#include +#include + +#include "solv-cpp/pool.hpp" +#include "solv-cpp/solver.hpp" + +using namespace mamba::solv; + +struct SimplePkg +{ + std::string name; + std::string version; + std::vector dependencies = {}; +}; + +auto +make_simple_packages() -> std::vector +{ + return { + { "menu", "1.5.0", { "dropdown=2.*" } }, + { "menu", "1.4.0", { "dropdown=2.*" } }, + { "menu", "1.3.0", { "dropdown=2.*" } }, + { "menu", "1.2.0", { "dropdown=2.*" } }, + { "menu", "1.1.0", { "dropdown=2.*" } }, + { "menu", "1.0.0", { "dropdown=1.*" } }, + { "dropdown", "2.3.0", { "icons=2.*" } }, + { "dropdown", "2.2.0", { "icons=2.*" } }, + { "dropdown", "2.1.0", { "icons=2.*" } }, + { "dropdown", "2.0.0", { "icons=2.*" } }, + { "dropdown", "1.8.0", { "icons=1.*", "intl=3.*" } }, + { "icons", "2.0.0" }, + { "icons", "1.0.0" }, + { "intl", "5.0.0" }, + { "intl", "4.0.0" }, + { "intl", "3.0.0" }, + }; +} + + +TEST_SUITE("ObjSolver") +{ + TEST_CASE("Create a solver") + { + auto pool = ObjPool(); + auto [repo_id, repo] = pool.add_repo("forge"); + + for (const auto& pkg : make_simple_packages()) + { + auto [solv_id, solv] = repo.add_solvable(); + solv.set_name(pkg.name); + solv.set_version(pkg.version); + for (const auto& dep : pkg.dependencies) + { + solv.add_dependency(pool.add_conda_dependency(dep)); + } + solv.add_self_provide(); + } + repo.internalize(); + + auto solver = ObjSolver(pool); + + CHECK_EQ(solver.problem_count(), 0); + + SUBCASE("Flag default value") + { + CHECK_FALSE(solver.get_flag(SOLVER_FLAG_ALLOW_DOWNGRADE)); + } + + SUBCASE("Set flag") + { + solver.set_flag(SOLVER_FLAG_ALLOW_DOWNGRADE, true); + CHECK(solver.get_flag(SOLVER_FLAG_ALLOW_DOWNGRADE)); + } + + SUBCASE("Add packages") + { + SUBCASE("Solve successfully") + { + // The job is matched with the ``provides`` field of the solvable + auto jobs = ObjQueue{ + SOLVER_INSTALL | SOLVER_SOLVABLE_PROVIDES, + pool.add_conda_dependency("menu"), + SOLVER_INSTALL | SOLVER_SOLVABLE_PROVIDES, + pool.add_conda_dependency("icons=2.*"), + }; + CHECK(solver.solve(pool, jobs)); + CHECK_EQ(solver.problem_count(), 0); + } + + SUBCASE("Solve unsuccessfully") + { + // The job is matched with the ``provides`` field of the solvable + auto jobs = ObjQueue{ + SOLVER_INSTALL | SOLVER_SOLVABLE_PROVIDES, + pool.add_conda_dependency("menu"), + SOLVER_INSTALL | SOLVER_SOLVABLE_PROVIDES, + pool.add_conda_dependency("icons=1.*"), + SOLVER_INSTALL | SOLVER_SOLVABLE_PROVIDES, + pool.add_conda_dependency("intl=5.*"), + }; + + CHECK_FALSE(solver.solve(pool, jobs)); + CHECK_NE(solver.problem_count(), 0); + + auto all_rules = ObjQueue{}; + solver.for_each_problem_id( + [&](auto pb) + { + auto pb_rules = solver.problem_rules(pb); + all_rules.insert(all_rules.end(), pb_rules.cbegin(), pb_rules.cend()); + } + ); + CHECK_FALSE(all_rules.empty()); + } + } + } +}