From cf62106fd65f501c1a608449d96fc9ad93f06e59 Mon Sep 17 00:00:00 2001 From: Peter Huene Date: Wed, 7 Feb 2024 12:27:00 -0800 Subject: [PATCH] Fix transaction implementation in `LogState::validate`. (#243) This PR removes the incorrect transaction implementation for `LogState::validate`. Now `validate` takes ownership of `self` and returns `Result`. This means that callers that expect to keep the log state following an invalid log entry must clone the state prior to validation. As the in-memory data store is the only store that persists the log state in memory, it now clones the state before validation and updates the log state upon successful validation. For the postgres data store, the log state was loaded from the database and is discarded on error, so no clone is necessary. Fixes #242. --- crates/client/src/lib.rs | 15 +-- crates/protocol/src/lib.rs | 4 +- crates/protocol/src/operator/state.rs | 109 +++++------------- crates/protocol/src/package/state.rs | 119 ++++++-------------- crates/protocol/tests/operator.rs | 8 +- crates/protocol/tests/package.rs | 8 +- crates/server/src/api/debug/mod.rs | 5 +- crates/server/src/datastore/memory.rs | 34 +++--- crates/server/src/datastore/postgres/mod.rs | 6 +- 9 files changed, 101 insertions(+), 207 deletions(-) diff --git a/crates/client/src/lib.rs b/crates/client/src/lib.rs index 628ddd48..e3af19ec 100644 --- a/crates/client/src/lib.rs +++ b/crates/client/src/lib.rs @@ -432,7 +432,7 @@ impl Client { if operator.head_registry_index.is_none() || proto_envelope.registry_index > operator.head_registry_index.unwrap() { - operator + operator.state = operator .state .validate(&proto_envelope.envelope) .map_err(|inner| ClientError::OperatorValidationFailed { inner })?; @@ -454,12 +454,13 @@ impl Client { if package.head_registry_index.is_none() || proto_envelope.registry_index > package.head_registry_index.unwrap() { - package - .state - .validate(&proto_envelope.envelope) - .map_err(|inner| ClientError::PackageValidationFailed { - name: package.name.clone(), - inner, + let state = std::mem::take(&mut package.state); + package.state = + state.validate(&proto_envelope.envelope).map_err(|inner| { + ClientError::PackageValidationFailed { + name: package.name.clone(), + inner, + } })?; package.head_registry_index = Some(proto_envelope.registry_index); package.head_fetch_token = Some(record.fetch_token); diff --git a/crates/protocol/src/lib.rs b/crates/protocol/src/lib.rs index e3e7ffba..5ab671dd 100644 --- a/crates/protocol/src/lib.rs +++ b/crates/protocol/src/lib.rs @@ -22,7 +22,7 @@ pub trait Record: Clone + Decode + Send + Sync { fn contents(&self) -> HashSet<&AnyHash>; } -/// Trait implemented by the validator types. +/// Trait implemented by the log state types. pub trait Validator: std::fmt::Debug + Serialize + DeserializeOwned + Default + Send + Sync { @@ -33,7 +33,7 @@ pub trait Validator: type Error: Send; /// Validates the given record. - fn validate(&mut self, record: &ProtoEnvelope) -> Result<(), Self::Error>; + fn validate(self, record: &ProtoEnvelope) -> Result; } /// Helpers for converting to and from protobuf diff --git a/crates/protocol/src/operator/state.rs b/crates/protocol/src/operator/state.rs index 2cdf3671..11545d26 100644 --- a/crates/protocol/src/operator/state.rs +++ b/crates/protocol/src/operator/state.rs @@ -114,16 +114,16 @@ pub struct LogState { /// This is `None` until the first (i.e. init) record is validated. #[serde(skip_serializing_if = "Option::is_none")] algorithm: Option, - /// The current head of the validator. + /// The current head of the state. #[serde(skip_serializing_if = "Option::is_none")] head: Option, /// The permissions of each key. #[serde(skip_serializing_if = "IndexMap::is_empty")] permissions: IndexMap>, - /// The keys known to the validator. + /// The keys known to the state. #[serde(skip_serializing_if = "IndexMap::is_empty")] keys: IndexMap, - /// The namespaces known to the validator. The key is the lowercased namespace. + /// The namespaces known to the state. The key is the lowercased namespace. #[serde(skip_serializing_if = "IndexMap::is_empty")] namespaces: IndexMap, } @@ -146,20 +146,14 @@ impl LogState { /// It is expected that `validate` is called in order of the /// records in the log. /// - /// This operation is transactional: if any entry in the record - /// fails to validate, the validator state will remain unchanged. + /// Note that on failure, the log state is consumed to prevent + /// invalid state from being used in future validations. pub fn validate( - &mut self, + mut self, record: &ProtoEnvelope, - ) -> Result<(), ValidationError> { - let snapshot = self.snapshot(); - - let result = self.validate_record(record); - if result.is_err() { - self.rollback(snapshot); - } - - result + ) -> Result { + self.validate_record(record)?; + Ok(self) } /// Gets the public key of the given key id. @@ -230,7 +224,7 @@ impl LogState { // Validate the envelope signature model::OperatorRecord::verify(key, envelope.content_bytes(), envelope.signature())?; - // Update the validator head + // Update the state head self.head = Some(Head { digest: RecordId::operator_record::(envelope), timestamp: record.timestamp, @@ -466,59 +460,17 @@ impl LogState { } Ok(()) } - - fn snapshot(&self) -> Snapshot { - let Self { - algorithm, - head, - permissions, - keys, - namespaces, - } = self; - - Snapshot { - algorithm: *algorithm, - head: head.clone(), - permissions: permissions.len(), - keys: keys.len(), - namespaces: namespaces.len(), - } - } - - fn rollback(&mut self, snapshot: Snapshot) { - let Snapshot { - algorithm, - head, - permissions, - keys, - namespaces, - } = snapshot; - - self.algorithm = algorithm; - self.head = head; - self.permissions.truncate(permissions); - self.keys.truncate(keys); - self.namespaces.truncate(namespaces); - } } impl crate::Validator for LogState { type Record = model::OperatorRecord; type Error = ValidationError; - fn validate(&mut self, record: &ProtoEnvelope) -> Result<(), Self::Error> { + fn validate(self, record: &ProtoEnvelope) -> Result { self.validate(record) } } -struct Snapshot { - algorithm: Option, - head: Option, - permissions: usize, - keys: usize, - namespaces: usize, -} - #[cfg(test)] mod tests { use pretty_assertions::assert_eq; @@ -547,11 +499,11 @@ mod tests { let envelope = ProtoEnvelope::signed_contents(&alice_priv, record).expect("failed to sign envelope"); - let mut validator = LogState::default(); - validator.validate(&envelope).unwrap(); + let state = LogState::default(); + let state = state.validate(&envelope).unwrap(); assert_eq!( - validator, + state, LogState { head: Some(Head { digest: RecordId::operator_record::(&envelope), @@ -591,8 +543,8 @@ mod tests { let envelope = ProtoEnvelope::signed_contents(&alice_priv, record).expect("failed to sign envelope"); - let mut validator = LogState::default(); - validator.validate(&envelope).unwrap(); + let state = LogState::default(); + let state = state.validate(&envelope).unwrap(); let expected = LogState { head: Some(Head { @@ -612,7 +564,7 @@ mod tests { namespaces: IndexMap::new(), }; - assert_eq!(validator, expected); + assert_eq!(state, expected); let record = model::OperatorRecord { prev: Some(RecordId::operator_record::(&envelope)), @@ -639,14 +591,11 @@ mod tests { let envelope = ProtoEnvelope::signed_contents(&alice_priv, record).expect("failed to sign envelope"); - // This validation should fail and the validator state should remain unchanged - match validator.validate(&envelope).unwrap_err() { + // This validation should fail + match state.validate(&envelope).unwrap_err() { ValidationError::PermissionNotFoundToRevoke { .. } => {} _ => panic!("expected a different error"), } - - // The validator should not have changed - assert_eq!(validator, expected); } #[test] @@ -676,8 +625,8 @@ mod tests { let envelope = ProtoEnvelope::signed_contents(&alice_priv, record).expect("failed to sign envelope"); - let mut validator = LogState::default(); - validator.validate(&envelope).unwrap(); + let state = LogState::default(); + let state = state.validate(&envelope).unwrap(); let expected = LogState { head: Some(Head { @@ -714,7 +663,7 @@ mod tests { ]), }; - assert_eq!(validator, expected); + assert_eq!(state, expected); { let record = model::OperatorRecord { @@ -737,14 +686,11 @@ mod tests { let envelope = ProtoEnvelope::signed_contents(&alice_priv, record) .expect("failed to sign envelope"); - // This validation should fail and the validator state should remain unchanged - match validator.validate(&envelope).unwrap_err() { + // This validation should fail + match state.clone().validate(&envelope).unwrap_err() { ValidationError::NamespaceAlreadyDefined { .. } => {} _ => panic!("expected a different error"), } - - // The validator should not have changed - assert_eq!(validator, expected); } { @@ -768,14 +714,11 @@ mod tests { let envelope = ProtoEnvelope::signed_contents(&alice_priv, record) .expect("failed to sign envelope"); - // This validation should fail and the validator state should remain unchanged - match validator.validate(&envelope).unwrap_err() { + // This validation should fail + match state.validate(&envelope).unwrap_err() { ValidationError::NamespaceConflict { .. } => {} _ => panic!("expected a different error"), } - - // The validator should not have changed - assert_eq!(validator, expected); } } } diff --git a/crates/protocol/src/package/state.rs b/crates/protocol/src/package/state.rs index 857417e7..7180d4d1 100644 --- a/crates/protocol/src/package/state.rs +++ b/crates/protocol/src/package/state.rs @@ -143,7 +143,7 @@ pub struct LogState { /// This is `None` until the first (i.e. init) record is validated. #[serde(skip_serializing_if = "Option::is_none")] algorithm: Option, - /// The current head of the validator. + /// The current head of the state. #[serde(skip_serializing_if = "Option::is_none")] head: Option, /// The permissions of each key. @@ -152,18 +152,18 @@ pub struct LogState { /// The releases in the package log. #[serde(skip_serializing_if = "IndexMap::is_empty")] releases: IndexMap, - /// The keys known to the validator. + /// The keys known to the state. #[serde(skip_serializing_if = "IndexMap::is_empty")] keys: IndexMap, } impl LogState { - /// Create a new package log validator. + /// Create a new package log state. pub fn new() -> Self { Self::default() } - /// Gets the current head of the validator. + /// Gets the current head of the state. /// /// Returns `None` if no records have been validated yet. pub fn head(&self) -> &Option { @@ -175,23 +175,17 @@ impl LogState { /// It is expected that `validate` is called in order of the /// records in the log. /// - /// This operation is transactional: if any entry in the record - /// fails to validate, the validator state will remain unchanged. + /// Note that on failure, the log state is consumed to prevent + /// invalid state from being used in future validations. pub fn validate( - &mut self, + mut self, record: &ProtoEnvelope, - ) -> Result<(), ValidationError> { - let snapshot = self.snapshot(); - - let result = self.validate_record(record); - if result.is_err() { - self.rollback(snapshot); - } - - result + ) -> Result { + self.validate_record(record)?; + Ok(self) } - /// Gets the releases known to the validator. + /// Gets the releases known to the state. /// /// The releases are returned in package log order. /// @@ -268,7 +262,7 @@ impl LogState { // Validate the envelope signature model::PackageRecord::verify(key, envelope.content_bytes(), envelope.signature())?; - // Update the validator head + // Update the state head self.head = Some(Head { digest: record_id, timestamp: record.timestamp, @@ -527,59 +521,17 @@ impl LogState { } Ok(()) } - - fn snapshot(&self) -> Snapshot { - let Self { - algorithm, - head, - releases, - permissions, - keys, - } = self; - - Snapshot { - algorithm: *algorithm, - head: head.clone(), - releases: releases.len(), - permissions: permissions.len(), - keys: keys.len(), - } - } - - fn rollback(&mut self, snapshot: Snapshot) { - let Snapshot { - algorithm, - head, - releases, - permissions, - keys, - } = snapshot; - - self.algorithm = algorithm; - self.head = head; - self.releases.truncate(releases); - self.permissions.truncate(permissions); - self.keys.truncate(keys); - } } impl crate::Validator for LogState { type Record = model::PackageRecord; type Error = ValidationError; - fn validate(&mut self, record: &ProtoEnvelope) -> Result<(), Self::Error> { + fn validate(self, record: &ProtoEnvelope) -> Result { self.validate(record) } } -struct Snapshot { - algorithm: Option, - head: Option, - releases: usize, - permissions: usize, - keys: usize, -} - #[cfg(test)] mod tests { use super::*; @@ -605,11 +557,11 @@ mod tests { }; let envelope = ProtoEnvelope::signed_contents(&alice_priv, record).unwrap(); - let mut validator = LogState::default(); - validator.validate(&envelope).unwrap(); + let state = LogState::default(); + let state = state.validate(&envelope).unwrap(); assert_eq!( - validator, + state, LogState { head: Some(Head { digest: RecordId::package_record::(&envelope), @@ -634,7 +586,7 @@ mod tests { let bob_id = bob_pub.fingerprint(); let hash_algo = HashAlgorithm::Sha256; - let mut validator = LogState::default(); + let state = LogState::default(); // In envelope 0: alice inits and grants bob release let timestamp0 = SystemTime::now(); @@ -654,7 +606,7 @@ mod tests { ], }; let envelope0 = ProtoEnvelope::signed_contents(&alice_priv, record0).unwrap(); - validator.validate(&envelope0).unwrap(); + let state = state.validate(&envelope0).unwrap(); // In envelope 1: bob releases 1.1.0 let timestamp1 = timestamp0 + Duration::from_secs(1); @@ -671,11 +623,11 @@ mod tests { let envelope1 = ProtoEnvelope::signed_contents(&bob_priv, record1).unwrap(); let record_id1 = RecordId::package_record::(&envelope1); - validator.validate(&envelope1).unwrap(); + let state = state.validate(&envelope1).unwrap(); - // At this point, the validator should consider 1.1.0 released + // At this point, the state should consider 1.1.0 released assert_eq!( - validator.find_latest_release(&"~1".parse().unwrap()), + state.find_latest_release(&"~1".parse().unwrap()), Some(&Release { record_id: record_id1.clone(), version: Version::new(1, 1, 0), @@ -686,11 +638,11 @@ mod tests { } }) ); - assert!(validator + assert!(state .find_latest_release(&"~1.2".parse().unwrap()) .is_none()); assert_eq!( - validator.releases().collect::>(), + state.releases().collect::>(), vec![&Release { record_id: record_id1.clone(), version: Version::new(1, 1, 0), @@ -717,14 +669,12 @@ mod tests { ], }; let envelope2 = ProtoEnvelope::signed_contents(&alice_priv, record2).unwrap(); - validator.validate(&envelope2).unwrap(); + let state = state.validate(&envelope2).unwrap(); - // At this point, the validator should consider 1.1.0 yanked - assert!(validator - .find_latest_release(&"~1".parse().unwrap()) - .is_none()); + // At this point, the state should consider 1.1.0 yanked + assert!(state.find_latest_release(&"~1".parse().unwrap()).is_none()); assert_eq!( - validator.releases().collect::>(), + state.releases().collect::>(), vec![&Release { record_id: record_id1.clone(), version: Version::new(1, 1, 0), @@ -738,7 +688,7 @@ mod tests { ); assert_eq!( - validator, + state, LogState { algorithm: Some(HashAlgorithm::Sha256), head: Some(Head { @@ -789,8 +739,8 @@ mod tests { let envelope = ProtoEnvelope::signed_contents(&alice_priv, record).expect("failed to sign envelope"); - let mut validator = LogState::default(); - validator.validate(&envelope).unwrap(); + let state = LogState::default(); + let state = state.validate(&envelope).unwrap(); let expected = LogState { head: Some(Head { @@ -806,7 +756,7 @@ mod tests { keys: IndexMap::from([(alice_id, alice_pub)]), }; - assert_eq!(validator, expected); + assert_eq!(state, expected); let record = model::PackageRecord { prev: Some(RecordId::package_record::(&envelope)), @@ -829,13 +779,10 @@ mod tests { let envelope = ProtoEnvelope::signed_contents(&alice_priv, record).expect("failed to sign envelope"); - // This validation should fail and the validator state should remain unchanged - match validator.validate(&envelope).unwrap_err() { + // This validation should fail + match state.validate(&envelope).unwrap_err() { ValidationError::PermissionNotFoundToRevoke { .. } => {} _ => panic!("expected a different error"), } - - // The validator should not have changed - assert_eq!(validator, expected); } } diff --git a/crates/protocol/tests/operator.rs b/crates/protocol/tests/operator.rs index 32cd2edd..11497c33 100644 --- a/crates/protocol/tests/operator.rs +++ b/crates/protocol/tests/operator.rs @@ -47,9 +47,9 @@ fn validate_input(input: Vec) -> Result { Some(envelope) }) - .try_fold(LogState::new(), |mut validator, record| { - validator.validate(&record)?; - Ok(validator) + .try_fold(LogState::new(), |state, record| { + let state = state.validate(&record)?; + Ok(state) }) } @@ -81,7 +81,7 @@ fn execute_test(input_path: &Path) { .unwrap(); let output = match validate_input(input) { - Ok(validator) => Output::Valid(validator), + Ok(state) => Output::Valid(state), Err(e) => Output::Error(e.to_string()), }; diff --git a/crates/protocol/tests/package.rs b/crates/protocol/tests/package.rs index 2a451689..e6e79bf8 100644 --- a/crates/protocol/tests/package.rs +++ b/crates/protocol/tests/package.rs @@ -46,9 +46,9 @@ fn validate_input(input: Vec) -> Result { Some(envelope) }) - .try_fold(LogState::new(), |mut validator, record| { - validator.validate(&record)?; - Ok(validator) + .try_fold(LogState::new(), |state, record| { + let state = state.validate(&record)?; + Ok(state) }) } @@ -77,7 +77,7 @@ fn execute_test(input_path: &Path) { .unwrap(); let output = match validate_input(input) { - Ok(validator) => Output::Valid(validator), + Ok(state) => Output::Valid(state), Err(e) => Output::Error(e.to_string()), }; diff --git a/crates/server/src/api/debug/mod.rs b/crates/server/src/api/debug/mod.rs index 3e12694c..23438c8e 100644 --- a/crates/server/src/api/debug/mod.rs +++ b/crates/server/src/api/debug/mod.rs @@ -109,9 +109,8 @@ async fn get_package_info( let records = records .into_iter() .map(|record| { - package_state - .validate(&record.envelope) - .context("validate")?; + let state = std::mem::take(&mut package_state); + package_state = state.validate(&record.envelope).context("validate")?; let record_id = RecordId::package_record::(&record.envelope); let timestamp = record .envelope diff --git a/crates/server/src/datastore/memory.rs b/crates/server/src/datastore/memory.rs index 938bd857..02cf0f32 100644 --- a/crates/server/src/datastore/memory.rs +++ b/crates/server/src/datastore/memory.rs @@ -22,18 +22,18 @@ struct Entry { record_content: ProtoEnvelope, } -struct Log { - validator: V, +struct Log { + state: S, entries: Vec>, } -impl Default for Log +impl Default for Log where - V: Default, + S: Default, { fn default() -> Self { Self { - validator: V::default(), + state: S::default(), entries: Vec::new(), } } @@ -253,11 +253,13 @@ impl DataStore for MemoryDataStore { let record = record.take().unwrap(); let log = operators.entry(log_id.clone()).or_default(); match log - .validator + .state + .clone() .validate(&record) .map_err(DataStoreError::from) { - Ok(_) => { + Ok(s) => { + log.state = s; let index = log.entries.len(); log.entries.push(Entry { registry_index, @@ -378,11 +380,13 @@ impl DataStore for MemoryDataStore { let record = record.take().unwrap(); let log = packages.entry(log_id.clone()).or_default(); match log - .validator + .state + .clone() .validate(&record) .map_err(DataStoreError::from) { - Ok(_) => { + Ok(state) => { + log.state = state; let index = log.entries.len(); log.entries.push(Entry { registry_index, @@ -705,7 +709,7 @@ impl DataStore for MemoryDataStore { let key = match state .packages .get(log_id) - .and_then(|log| log.validator.public_key(record.key_id())) + .and_then(|log| log.state.public_key(record.key_id())) { Some(key) => Some(key), None => match record.as_ref().entries.first() { @@ -731,7 +735,7 @@ impl DataStore for MemoryDataStore { .operators .get(operator_log_id) .ok_or_else(|| DataStoreError::LogNotFound(operator_log_id.clone()))? - .validator + .state .namespace_state(package_name.namespace()) { Ok(Some(state)) => match state { @@ -777,14 +781,14 @@ impl DataStore for MemoryDataStore { ) -> Result<(), DataStoreError> { let state = self.0.read().await; - let validator = &state + let state = &state .operators .get(operator_log_id) .ok_or_else(|| DataStoreError::LogNotFound(operator_log_id.clone()))? - .validator; + .state; TimestampedCheckpoint::verify( - validator + state .public_key(ts_checkpoint.key_id()) .ok_or(DataStoreError::UnknownKey(ts_checkpoint.key_id().clone()))?, &ts_checkpoint.as_ref().encode(), @@ -794,7 +798,7 @@ impl DataStore for MemoryDataStore { ts_checkpoint.signature().clone(), )))?; - if !validator.key_has_permission_to_sign_checkpoints(ts_checkpoint.key_id()) { + if !state.key_has_permission_to_sign_checkpoints(ts_checkpoint.key_id()) { return Err(DataStoreError::KeyUnauthorized( ts_checkpoint.key_id().clone(), )); diff --git a/crates/server/src/datastore/postgres/mod.rs b/crates/server/src/datastore/postgres/mod.rs index 13595f85..fcb41adb 100644 --- a/crates/server/src/datastore/postgres/mod.rs +++ b/crates/server/src/datastore/postgres/mod.rs @@ -224,7 +224,7 @@ where conn.transaction::<_, DataStoreError, _>(|conn| { async move { // Get the record content and validator - let (id, content, mut validator) = schema::records::table + let (id, content, validator) = schema::records::table .inner_join(schema::logs::table) .select(( schema::records::id, @@ -251,12 +251,12 @@ where })?; // Validate the record - validator.validate(&record).map_err(Into::into)?; + let validator = validator.0.validate(&record).map_err(Into::into)?; // Store the updated validation state diesel::update(schema::logs::table) .filter(schema::logs::id.eq(log_id)) - .set(schema::logs::validator.eq(validator)) + .set(schema::logs::validator.eq(Json(validator))) .execute(conn) .await?;