Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow immutable borrow to access QuantumCircuit.parameters (backport #12918) #12958

Merged
merged 1 commit into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/circuit/src/circuit_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ impl CircuitData {
/// Get a (cached) sorted list of the Python-space `Parameter` instances tracked by this circuit
/// data's parameter table.
#[getter]
pub fn get_parameters<'py>(&mut self, py: Python<'py>) -> Bound<'py, PyList> {
pub fn get_parameters<'py>(&self, py: Python<'py>) -> Bound<'py, PyList> {
self.param_table.py_parameters(py)
}

Expand Down
129 changes: 77 additions & 52 deletions crates/circuit/src/parameter_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
// copyright notice, and modified files need to carry a notice indicating
// that they have been altered from the originals.

use std::cell::OnceCell;

use hashbrown::hash_map::Entry;
use hashbrown::{HashMap, HashSet};
use thiserror::Error;
Expand Down Expand Up @@ -123,18 +125,17 @@ pub struct ParameterTable {
by_name: HashMap<PyBackedStr, ParameterUuid>,
/// Additional information on any `ParameterVector` instances that have elements in the circuit.
vectors: HashMap<VectorUuid, VectorInfo>,
/// Sort order of the parameters. This is lexicographical for most parameters, except elements
/// of a `ParameterVector` are sorted within the vector by numerical index. We calculate this
/// on demand and cache it; an empty `order` implies it is not currently calculated. We don't
/// use `Option<Vec>` so we can re-use the allocation for partial parameter bindings.
/// Cache of the sort order of the parameters. This is lexicographical for most parameters,
/// except elements of a `ParameterVector` are sorted within the vector by numerical index. We
/// calculate this on demand and cache it.
///
/// Any method that adds or a removes a parameter is responsible for invalidating this cache.
order: Vec<ParameterUuid>,
/// Any method that adds or removes a parameter needs to invalidate this.
order_cache: OnceCell<Vec<ParameterUuid>>,
/// Cache of a Python-space list of the parameter objects, in order. We only generate this
/// specifically when asked.
///
/// Any method that adds or a removes a parameter is responsible for invalidating this cache.
py_parameters: Option<Py<PyList>>,
/// Any method that adds or removes a parameter needs to invalidate this.
py_parameters_cache: OnceCell<Py<PyList>>,
}

impl ParameterTable {
Expand Down Expand Up @@ -194,8 +195,6 @@ impl ParameterTable {
None
};
self.by_name.insert(name.clone(), uuid);
self.order.clear();
self.py_parameters = None;
let mut uses = HashSet::new();
if let Some(usage) = usage {
uses.insert_unique_unchecked(usage);
Expand All @@ -206,6 +205,7 @@ impl ParameterTable {
element,
object: param_ob.clone().unbind(),
});
self.invalidate_cache();
}
}
Ok(uuid)
Expand All @@ -226,43 +226,39 @@ impl ParameterTable {
}

/// Get the (maybe cached) Python list of the sorted `Parameter` objects.
pub fn py_parameters<'py>(&mut self, py: Python<'py>) -> Bound<'py, PyList> {
if let Some(py_parameters) = self.py_parameters.as_ref() {
return py_parameters.clone_ref(py).into_bound(py);
}
self.ensure_sorted();
let out = PyList::new_bound(
py,
self.order
.iter()
.map(|uuid| self.by_uuid[uuid].object.clone_ref(py).into_bound(py)),
);
self.py_parameters = Some(out.clone().unbind());
out
pub fn py_parameters<'py>(&self, py: Python<'py>) -> Bound<'py, PyList> {
self.py_parameters_cache
.get_or_init(|| {
PyList::new_bound(
py,
self.order_cache
.get_or_init(|| self.sorted_order())
.iter()
.map(|uuid| self.by_uuid[uuid].object.bind(py).clone()),
)
.unbind()
})
.bind(py)
.clone()
}

/// Get a Python set of all tracked `Parameter` objects.
pub fn py_parameters_unsorted<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PySet>> {
PySet::new_bound(py, self.by_uuid.values().map(|info| &info.object))
}

/// Ensure that the `order` field is populated and sorted.
fn ensure_sorted(&mut self) {
// If `order` is already populated, it's sorted; it's the responsibility of the methods of
// this struct that mutate it to invalidate the cache.
if !self.order.is_empty() {
return;
}
self.order.reserve(self.by_uuid.len());
self.order.extend(self.by_uuid.keys());
self.order.sort_unstable_by_key(|uuid| {
/// Get the sorted order of the `ParameterTable`. This does not access the cache.
fn sorted_order(&self) -> Vec<ParameterUuid> {
let mut out = self.by_uuid.keys().copied().collect::<Vec<_>>();
out.sort_unstable_by_key(|uuid| {
let info = &self.by_uuid[uuid];
if let Some(vec) = info.element.as_ref() {
(&self.vectors[&vec.vector_uuid].name, vec.index)
} else {
(&info.name, 0)
}
})
});
out
}

/// Add a use of a parameter to the table.
Expand Down Expand Up @@ -305,9 +301,8 @@ impl ParameterTable {
vec_entry.remove_entry();
}
}
self.order.clear();
self.py_parameters = None;
entry.remove_entry();
self.invalidate_cache();
}
Ok(())
}
Expand All @@ -332,26 +327,28 @@ impl ParameterTable {
(vector_info.refcount > 0).then_some(vector_info)
});
}
self.order.clear();
self.py_parameters = None;
self.invalidate_cache();
Ok(info.uses)
}

/// Clear this table, yielding the Python parameter objects and their uses in sorted order.
///
/// The clearing effect is eager and not dependent on the iteration.
pub fn drain_ordered(
&'_ mut self,
) -> impl Iterator<Item = (Py<PyAny>, HashSet<ParameterUse>)> + '_ {
self.ensure_sorted();
&mut self,
) -> impl ExactSizeIterator<Item = (Py<PyAny>, HashSet<ParameterUse>)> {
let order = self
.order_cache
.take()
.unwrap_or_else(|| self.sorted_order());
let by_uuid = ::std::mem::take(&mut self.by_uuid);
self.by_name.clear();
self.vectors.clear();
self.py_parameters = None;
self.order.drain(..).map(|uuid| {
let info = self
.by_uuid
.remove(&uuid)
.expect("tracked UUIDs should be consistent");
(info.object, info.uses)
})
self.py_parameters_cache.take();
ParameterTableDrain {
order: order.into_iter(),
by_uuid,
}
}

/// Empty this `ParameterTable` of all its contents. This does not affect the capacities of the
Expand All @@ -360,8 +357,12 @@ impl ParameterTable {
self.by_uuid.clear();
self.by_name.clear();
self.vectors.clear();
self.order.clear();
self.py_parameters = None;
self.invalidate_cache();
}

fn invalidate_cache(&mut self) {
self.order_cache.take();
self.py_parameters_cache.take();
}

/// Expose the tracked data for a given parameter as directly as possible to Python space.
Expand Down Expand Up @@ -396,9 +397,33 @@ impl ParameterTable {
visit.call(&info.object)?
}
// We don't need to / can't visit the `PyBackedStr` stores.
if let Some(list) = self.py_parameters.as_ref() {
if let Some(list) = self.py_parameters_cache.get() {
visit.call(list)?
}
Ok(())
}
}

struct ParameterTableDrain {
order: ::std::vec::IntoIter<ParameterUuid>,
by_uuid: HashMap<ParameterUuid, ParameterInfo>,
}
impl Iterator for ParameterTableDrain {
type Item = (Py<PyAny>, HashSet<ParameterUse>);

fn next(&mut self) -> Option<Self::Item> {
self.order.next().map(|uuid| {
let info = self
.by_uuid
.remove(&uuid)
.expect("tracked UUIDs should be consistent");
(info.object, info.uses)
})
}

fn size_hint(&self) -> (usize, Option<usize>) {
self.order.size_hint()
}
}
impl ExactSizeIterator for ParameterTableDrain {}
impl ::std::iter::FusedIterator for ParameterTableDrain {}
Loading