Skip to content

Commit

Permalink
Use OnceCell instead of RefCell
Browse files Browse the repository at this point in the history
`OnceCell` has less runtime checking than `RefCell` (only whether it is
initialised or not, which is an `Option` check), and better represents
the dynamic extensions to the borrow checker that we actually need for
the caching in this method.

All methods that can invalidate the cache all necessarily take `&mut
ParameterTable` already, since they will modify Rust-space data.  A
`OnceCell` can be deinitialised through a mutable reference, so this is
fine.  The only reason a `&ParameterTable` method would need to mutate
the cache is to create it, which is the allowed set of `OnceCell`
operations.
  • Loading branch information
jakelishman committed Aug 9, 2024
1 parent 98f0856 commit 54d261c
Showing 1 changed file with 38 additions and 69 deletions.
107 changes: 38 additions & 69 deletions crates/circuit/src/parameter_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
// copyright notice, and modified files need to carry a notice indicating
// that they have been altered from the originals.

use std::cell::RefCell;
use std::cell::OnceCell;

use hashbrown::hash_map::Entry;
use hashbrown::{HashMap, HashSet};
Expand Down Expand Up @@ -116,30 +116,6 @@ impl<'py> FromPyObject<'py> for VectorUuid {
}
}

#[derive(Clone, Default, Debug)]
struct ParameterTableOrder {
/// The Rust-space sort order.
uuids: Vec<ParameterUuid>,
/// Cache of a Python-space list of the parameter objects, in order. We only generate this
/// specifically when asked.
py_parameters: Option<Py<PyList>>,
}

impl ParameterTableOrder {
fn uuids(&self) -> Option<&[ParameterUuid]> {
(!self.uuids.is_empty()).then_some(self.uuids.as_slice())
}

fn py_parameters(&self) -> Option<&Py<PyList>> {
self.py_parameters.as_ref()
}

fn invalidate(&mut self) {
self.uuids.clear();
self.py_parameters = None;
}
}

#[derive(Clone, Default, Debug)]
pub struct ParameterTable {
/// Mapping of the parameter key (its UUID) to the information on it tracked by this table.
Expand All @@ -149,14 +125,17 @@ pub struct ParameterTable {
by_name: HashMap<PyBackedStr, ParameterUuid>,
/// Additional information on any `ParameterVector` instances that have elements in the circuit.
vectors: HashMap<VectorUuid, VectorInfo>,
/// Cache related to the sort order of the parameters. This is lexicographical for most
/// parameters, except elements of a `ParameterVector` are sorted within the vector by numerical
/// index. We calculate this on demand and cache it; an empty `order` implies it is not
/// currently calculated. We don't use `Option<Vec>` so we can re-use the allocation for
/// partial parameter bindings.
/// Cache of the sort order of the parameters. This is lexicographical for most parameters,
/// except elements of a `ParameterVector` are sorted within the vector by numerical index. We
/// calculate this on demand and cache it.
///
/// Any method that adds or removes a parameter needs to invalidate this.
order_cache: RefCell<ParameterTableOrder>,
order_cache: OnceCell<Vec<ParameterUuid>>,
/// Cache of a Python-space list of the parameter objects, in order. We only generate this
/// specifically when asked.
///
/// Any method that adds or removes a parameter needs to invalidate this.
py_parameters_cache: OnceCell<Py<PyList>>,
}

impl ParameterTable {
Expand Down Expand Up @@ -216,7 +195,6 @@ impl ParameterTable {
None
};
self.by_name.insert(name.clone(), uuid);
self.order_cache.get_mut().invalidate();
let mut uses = HashSet::new();
if let Some(usage) = usage {
uses.insert_unique_unchecked(usage);
Expand All @@ -227,6 +205,7 @@ impl ParameterTable {
element,
object: param_ob.clone().unbind(),
});
self.invalidate_cache();
}
}
Ok(uuid)
Expand All @@ -248,32 +227,19 @@ impl ParameterTable {

/// Get the (maybe cached) Python list of the sorted `Parameter` objects.
pub fn py_parameters<'py>(&self, py: Python<'py>) -> Bound<'py, PyList> {
if let Some(py_parameters) = self.order_cache.borrow().py_parameters() {
return py_parameters.clone_ref(py).into_bound(py);
}
let make_parameters = |order: &[ParameterUuid]| {
PyList::new_bound(
py,
order
.iter()
.map(|uuid| self.by_uuid[uuid].object.bind(py).clone()),
)
};
let out = match self.order_cache.borrow().uuids() {
Some(uuids) => make_parameters(uuids),
None => {
let uuids = self.sorted_order();
let out = make_parameters(&uuids);
if let Ok(mut cache) = self.order_cache.try_borrow_mut() {
cache.uuids = uuids;
}
out
}
};
if let Ok(mut cache) = self.order_cache.try_borrow_mut() {
cache.py_parameters = Some(out.clone().unbind());
}
out
self.py_parameters_cache
.get_or_init(|| {
PyList::new_bound(
py,
self.order_cache
.get_or_init(|| self.sorted_order())
.iter()
.map(|uuid| self.by_uuid[uuid].object.bind(py).clone()),
)
.unbind()
})
.bind(py)
.clone()
}

/// Get a Python set of all tracked `Parameter` objects.
Expand Down Expand Up @@ -335,8 +301,8 @@ impl ParameterTable {
vec_entry.remove_entry();
}
}
self.order_cache.get_mut().invalidate();
entry.remove_entry();
self.invalidate_cache();
}
Ok(())
}
Expand All @@ -361,7 +327,7 @@ impl ParameterTable {
(vector_info.refcount > 0).then_some(vector_info)
});
}
self.order_cache.get_mut().invalidate();
self.invalidate_cache();
Ok(info.uses)
}

Expand All @@ -371,16 +337,14 @@ impl ParameterTable {
pub fn drain_ordered(
&mut self,
) -> impl ExactSizeIterator<Item = (Py<PyAny>, HashSet<ParameterUse>)> {
let cache = self.order_cache.get_mut();
cache.py_parameters = None;
let order = if cache.uuids.is_empty() {
self.sorted_order()
} else {
::std::mem::take(&mut cache.uuids)
};
let order = self
.order_cache
.take()
.unwrap_or_else(|| self.sorted_order());
let by_uuid = ::std::mem::take(&mut self.by_uuid);
self.by_name.clear();
self.vectors.clear();
self.py_parameters_cache.take();
ParameterTableDrain {
order: order.into_iter(),
by_uuid,
Expand All @@ -393,7 +357,12 @@ impl ParameterTable {
self.by_uuid.clear();
self.by_name.clear();
self.vectors.clear();
self.order_cache.get_mut().invalidate();
self.invalidate_cache();
}

fn invalidate_cache(&mut self) {
self.order_cache.take();
self.py_parameters_cache.take();
}

/// Expose the tracked data for a given parameter as directly as possible to Python space.
Expand Down Expand Up @@ -428,7 +397,7 @@ impl ParameterTable {
visit.call(&info.object)?
}
// We don't need to / can't visit the `PyBackedStr` stores.
if let Some(list) = self.order_cache.borrow().py_parameters() {
if let Some(list) = self.py_parameters_cache.get() {
visit.call(list)?
}
Ok(())
Expand Down

0 comments on commit 54d261c

Please sign in to comment.