From a360f7af8469fae1074fa6feb0adb6cd70f77540 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Uzarski?= Date: Fri, 4 Oct 2024 13:29:53 +0200 Subject: [PATCH] use_keyspace: don't use wildcard '_' in QueryError match Since last time, during error refactor I introduced a silent bug to the code (https://github.com/scylladb/scylla-rust-driver/pull/1075), I'd like to prevent that from happening in the future. This is why we replace a `_` match with explicit error variants when deciding if error received after `USE KEYSPACE` should be ignored. We also enable the `wildcard_enum_match_arm` clippy lint to disallow using `_` matches. --- scylla/src/transport/cluster.rs | 9 +++++---- scylla/src/transport/errors.rs | 27 +++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/scylla/src/transport/cluster.rs b/scylla/src/transport/cluster.rs index 2313402bc..9bf9bc52f 100644 --- a/scylla/src/transport/cluster.rs +++ b/scylla/src/transport/cluster.rs @@ -784,12 +784,13 @@ pub(crate) fn use_keyspace_result( for result in use_keyspace_results { match result { Ok(()) => was_ok = true, - Err(err) => match err { - QueryError::BrokenConnection(_) | QueryError::ConnectionPoolError(_) => { + Err(err) => { + if err.is_connection_broken() { broken_conn_error = Some(err) + } else { + return Err(err); } - _ => return Err(err), - }, + } } } diff --git a/scylla/src/transport/errors.rs b/scylla/src/transport/errors.rs index 8ef4a9943..d8d736578 100644 --- a/scylla/src/transport/errors.rs +++ b/scylla/src/transport/errors.rs @@ -100,6 +100,33 @@ pub enum QueryError { RequestTimeout(String), } +impl QueryError { + pub(crate) fn is_connection_broken(&self) -> bool { + // Do not remove this lint! + // It's there for a reason - we don't want new variants + // automatically fall under `_` pattern when they are introduced. + #[deny(clippy::wildcard_enum_match_arm)] + match self { + // Error variants that imply that some connection error appeared before/during execution. + QueryError::BrokenConnection(_) | QueryError::ConnectionPoolError(_) => true, + + // Other errors. + QueryError::DbError(_, _) + | QueryError::BadQuery(_) + | QueryError::CqlRequestSerialization(_) + | QueryError::BodyExtensionsParseError(_) + | QueryError::EmptyPlan + | QueryError::CqlResultParseError(_) + | QueryError::CqlErrorParseError(_) + | QueryError::MetadataError(_) + | QueryError::ProtocolError(_) + | QueryError::TimeoutError + | QueryError::UnableToAllocStreamId + | QueryError::RequestTimeout(_) => false, + } + } +} + impl From for QueryError { fn from(serialized_err: SerializeValuesError) -> QueryError { QueryError::BadQuery(BadQuery::SerializeValuesError(serialized_err))