Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add channel priority to solve task and expose to python solve #598

Merged
merged 13 commits into from
Apr 25, 2024
3 changes: 2 additions & 1 deletion crates/rattler-bin/src/commands/create.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use rattler_repodata_gateway::fetch::{
use rattler_repodata_gateway::sparse::SparseRepoData;
use rattler_solve::{
libsolv_c::{self},
resolvo, SolverImpl, SolverTask,
resolvo, ChannelPriority, SolverImpl, SolverTask,
};
use reqwest::Client;
use std::sync::Arc;
Expand Down Expand Up @@ -235,6 +235,7 @@ pub async fn create(opt: Opt) -> anyhow::Result<()> {
specs,
pinned_packages: Vec::new(),
timeout: opt.timeout.map(Duration::from_millis),
channel_priority: ChannelPriority::Strict,
};

// Next, use a solver to solve this specific problem. This provides us with all the operations
Expand Down
4 changes: 3 additions & 1 deletion crates/rattler_solve/benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion, SamplingM
use rattler_conda_types::ParseStrictness::Strict;
use rattler_conda_types::{Channel, ChannelConfig, MatchSpec};
use rattler_repodata_gateway::sparse::SparseRepoData;
use rattler_solve::{SolverImpl, SolverTask};
use rattler_solve::{ChannelPriority, SolverImpl, SolverTask};

fn conda_json_path() -> String {
format!(
Expand Down Expand Up @@ -69,6 +69,7 @@ fn bench_solve_environment(c: &mut Criterion, specs: Vec<&str>) {
virtual_packages: vec![],
specs: specs.clone(),
timeout: None,
channel_priority: ChannelPriority::Strict,
}))
.unwrap()
});
Expand All @@ -85,6 +86,7 @@ fn bench_solve_environment(c: &mut Criterion, specs: Vec<&str>) {
virtual_packages: vec![],
specs: specs.clone(),
timeout: None,
channel_priority: ChannelPriority::Strict,
}))
.unwrap()
});
Expand Down
19 changes: 19 additions & 0 deletions crates/rattler_solve/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,21 @@ impl fmt::Display for SolveError {
}
}

/// Represents the channel priority option to use during solves.
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum ChannelPriority {
/// The channel that the package is first found in will be used as the only channel
/// for that package.
Strict,

// Conda also has "Flexible" as an option, where packages present in multiple channels
// are only taken from lower-priority channels when this prevents unsatisfiable environment
// errors, but this would need implementation in the solvers.
// Flexible,
/// Packages can be retrieved from any channel as package version takes precedence.
Disabled,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shoudl we call this AllEqual or NoPriority/NoPriorities? Just thinking of something that makes it really clear.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking it was nice to be consistent with conda, but also happy to change to NoPriority if you prefer

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, totally. I am also OK with keeping the naming. Just sometimes I think we have the opportunity to come up with new & improved terminology :)

}

/// Represents a dependency resolution task, to be solved by one of the backends (currently only
/// libsolv is supported)
pub struct SolverTask<TAvailablePackagesIterator> {
Expand Down Expand Up @@ -102,6 +117,10 @@ pub struct SolverTask<TAvailablePackagesIterator> {

/// The timeout after which the solver should stop
pub timeout: Option<std::time::Duration>,

/// The channel priority to solve with, either [`ChannelPriority::Strict`] or
/// [`ChannelPriority::Disabled`]
pub channel_priority: ChannelPriority,
}

/// A representation of a collection of [`RepoDataRecord`] usable by a [`SolverImpl`]
Expand Down
8 changes: 6 additions & 2 deletions crates/rattler_solve/src/libsolv_c/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,10 +297,14 @@ pub fn cache_repodata(_url: String, _data: &[RepoDataRecord]) -> LibcByteSlice {
/// Note: this function relies on primitives that are only available on unix-like operating systems,
/// and will panic if called from another platform (e.g. Windows)
#[cfg(target_family = "unix")]
pub fn cache_repodata(url: String, data: &[RepoDataRecord]) -> LibcByteSlice {
pub fn cache_repodata(
url: String,
data: &[RepoDataRecord],
channel_priority: Option<i32>,
) -> LibcByteSlice {
// Add repodata to a new pool + repo
let pool = Pool::default();
let repo = Repo::new(&pool, url);
let repo = Repo::new(&pool, url, channel_priority.unwrap_or(0));
add_repodata_records(&pool, &repo, data);

// Export repo to .solv in memory
Expand Down
65 changes: 56 additions & 9 deletions crates/rattler_solve/src/libsolv_c/mod.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
//! Provides an solver implementation based on the [`rattler_libsolv_c`] crate.

use crate::{IntoRepoData, SolverRepoData};
use crate::{ChannelPriority, IntoRepoData, SolverRepoData};
use crate::{SolveError, SolverTask};
pub use input::cache_repodata;
use input::{add_repodata_records, add_solv_file, add_virtual_packages};
pub use libc_byte_slice::LibcByteSlice;
use output::get_required_packages;
use rattler_conda_types::RepoDataRecord;
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::ffi::CString;
use std::mem::ManuallyDrop;
use wrapper::{
Expand Down Expand Up @@ -106,8 +106,47 @@ impl super::SolverImpl for Solver {
});
pool.set_debug_level(Verbosity::Low);

let repodatas: Vec<Self::RepoData<'_>> = task
.available_packages
.into_iter()
.map(IntoRepoData::into)
.collect();

// Determine the channel priority for each channel in the repodata in the order in which
// the repodatas are passed, where the first channel will have the highest priority value
// and each successive channel will descend in priority value. If not strict, the highest
// priority value will be 0 and the channel priority map will not be populated as it will
// not be used.
let mut highest_priority: i32 = 0;
let channel_priority: HashMap<String, i32> =
if task.channel_priority == ChannelPriority::Strict {
let mut seen_channels = HashSet::new();
let mut channel_order: Vec<String> = Vec::new();
for channel in repodatas
.iter()
.filter(|&r| !r.records.is_empty())
.map(|r| r.records[0].channel.clone())
{
if !seen_channels.contains(&channel) {
channel_order.push(channel.clone());
seen_channels.insert(channel);
}
}
let mut channel_priority = HashMap::new();
for (index, channel) in channel_order.iter().enumerate() {
let reverse_index = channel_order.len() - index;
if index == 0 {
highest_priority = reverse_index as i32;
}
channel_priority.insert(channel.clone(), reverse_index as i32);
}
channel_priority
} else {
HashMap::new()
};

// Add virtual packages
let repo = Repo::new(&pool, "virtual_packages");
let repo = Repo::new(&pool, "virtual_packages", highest_priority);
add_virtual_packages(&pool, &repo, &task.virtual_packages);

// Mark the virtual packages as installed.
Expand All @@ -116,15 +155,19 @@ impl super::SolverImpl for Solver {
// Create repos for all channel + platform combinations
let mut repo_mapping = HashMap::new();
let mut all_repodata_records = Vec::new();
for repodata in task.available_packages.into_iter().map(IntoRepoData::into) {
for repodata in repodatas.iter() {
if repodata.records.is_empty() {
continue;
}

let channel_name = &repodata.records[0].channel;

// We dont want to drop the Repo, its stored in the pool anyway.
let repo = ManuallyDrop::new(Repo::new(&pool, channel_name));
let priority: i32 = if task.channel_priority == ChannelPriority::Strict {
*channel_priority.get(channel_name).unwrap()
} else {
0
};
let repo = ManuallyDrop::new(Repo::new(&pool, channel_name, priority));

if let Some(solv_file) = repodata.solv_file {
add_solv_file(&pool, &repo, solv_file);
Expand All @@ -134,19 +177,19 @@ impl super::SolverImpl for Solver {

// Keep our own info about repodata_records
repo_mapping.insert(repo.id(), repo_mapping.len());
all_repodata_records.push(repodata.records);
all_repodata_records.push(repodata.records.clone());
}

// Create a special pool for records that are already installed or locked.
let repo = Repo::new(&pool, "locked");
let repo = Repo::new(&pool, "locked", highest_priority);
let installed_solvables = add_repodata_records(&pool, &repo, &task.locked_packages);

// Also add the installed records to the repodata
repo_mapping.insert(repo.id(), repo_mapping.len());
all_repodata_records.push(task.locked_packages.iter().collect());

// Create a special pool for records that are pinned and cannot be changed.
let repo = Repo::new(&pool, "pinned");
let repo = Repo::new(&pool, "pinned", highest_priority);
let pinned_solvables = add_repodata_records(&pool, &repo, &task.pinned_packages);

// Also add the installed records to the repodata
Expand Down Expand Up @@ -179,6 +222,10 @@ impl super::SolverImpl for Solver {
let mut solver = pool.create_solver();
solver.set_flag(SolverFlag::allow_uninstall(), true);
solver.set_flag(SolverFlag::allow_downgrade(), true);
solver.set_flag(
SolverFlag::strict_channel_priority(),
task.channel_priority == ChannelPriority::Strict,
);

let transaction = solver.solve(&mut goal).map_err(SolveError::Unsolvable)?;

Expand Down
8 changes: 7 additions & 1 deletion crates/rattler_solve/src/libsolv_c/wrapper/flags.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use super::ffi::{SOLVER_FLAG_ALLOW_DOWNGRADE, SOLVER_FLAG_ALLOW_UNINSTALL};
use super::ffi::{
SOLVER_FLAG_ALLOW_DOWNGRADE, SOLVER_FLAG_ALLOW_UNINSTALL, SOLVER_FLAG_STRICT_REPO_PRIORITY,
};

#[repr(transparent)]
pub struct SolverFlag(u32);
Expand All @@ -12,6 +14,10 @@ impl SolverFlag {
SolverFlag(SOLVER_FLAG_ALLOW_DOWNGRADE)
}

pub fn strict_channel_priority() -> SolverFlag {
SolverFlag(SOLVER_FLAG_STRICT_REPO_PRIORITY)
}

pub fn inner(self) -> i32 {
self.0 as i32
}
Expand Down
11 changes: 6 additions & 5 deletions crates/rattler_solve/src/libsolv_c/wrapper/repo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@ impl<'pool> Drop for Repo<'pool> {

impl<'pool> Repo<'pool> {
/// Constructs a repo in the provided pool, associated to the given url
pub fn new(pool: &Pool, url: impl AsRef<str>) -> Repo<'_> {
pub fn new(pool: &Pool, url: impl AsRef<str>, priority: i32) -> Repo<'_> {
let c_url = c_string(url);

unsafe {
let repo_ptr = ffi::repo_create(pool.raw_ptr(), c_url.as_ptr());
let non_null_ptr = NonNull::new(repo_ptr).expect("repo ptr was null");
let mut non_null_ptr = NonNull::new(repo_ptr).expect("repo ptr was null");
non_null_ptr.as_mut().priority = priority;
Repo(non_null_ptr, PhantomData)
}
}
Expand Down Expand Up @@ -120,7 +121,7 @@ mod tests {
#[test]
fn test_repo_creation() {
let pool = Pool::default();
let mut _repo = Repo::new(&pool, "conda-forge");
let mut _repo = Repo::new(&pool, "conda-forge", 0);
}

#[test]
Expand All @@ -132,7 +133,7 @@ mod tests {
{
// Create a pool and a repo
let pool = Pool::default();
let repo = Repo::new(&pool, "conda-forge");
let repo = Repo::new(&pool, "conda-forge", 0);

// Add a solvable with a particular name
let solvable_id = repo.add_solvable();
Expand All @@ -148,7 +149,7 @@ mod tests {

// Create a clean pool and repo
let pool = Pool::default();
let repo = Repo::new(&pool, "conda-forge");
let repo = Repo::new(&pool, "conda-forge", 0);

// Open and read the .solv file
let mode = c_string("rb");
Expand Down
14 changes: 9 additions & 5 deletions crates/rattler_solve/src/resolvo/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Provides an solver implementation based on the [`resolvo`] crate.

use crate::{IntoRepoData, SolveError, SolverRepoData, SolverTask};
use crate::{ChannelPriority, IntoRepoData, SolveError, SolverRepoData, SolverTask};
use rattler_conda_types::package::ArchiveType;
use rattler_conda_types::{
GenericVirtualPackage, MatchSpec, NamelessMatchSpec, PackageRecord, ParseMatchSpecError,
Expand Down Expand Up @@ -177,6 +177,7 @@ impl<'a> CondaDependencyProvider<'a> {
virtual_packages: &'a [GenericVirtualPackage],
match_specs: &[MatchSpec],
stop_time: Option<std::time::SystemTime>,
channel_priority: ChannelPriority,
) -> Self {
let pool = Rc::new(Pool::default());
let mut records: HashMap<NameId, Candidates> = HashMap::default();
Expand Down Expand Up @@ -283,10 +284,12 @@ impl<'a> CondaDependencyProvider<'a> {
}

// Enforce channel priority
// This functions makes the assumtion that the records are given in order of the channels.
if let Some(first_channel) = package_name_found_in_channel
.get(&record.package_record.name.as_normalized().to_string())
{
// This function makes the assumption that the records are given in order of the channels.
if let (Some(first_channel), ChannelPriority::Strict) = (
package_name_found_in_channel
.get(&record.package_record.name.as_normalized().to_string()),
channel_priority,
) {
// Add the record to the excluded list when it is from a different channel.
if first_channel != &&record.channel {
tracing::debug!(
Expand Down Expand Up @@ -445,6 +448,7 @@ impl super::SolverImpl for Solver {
&task.virtual_packages,
task.specs.clone().as_ref(),
stop_time,
task.channel_priority,
);
let pool = provider.pool.clone();

Expand Down
Loading
Loading