diff --git a/scylla/src/transport/cluster.rs b/scylla/src/transport/cluster.rs index 2313402bc..9bf9bc52f 100644 --- a/scylla/src/transport/cluster.rs +++ b/scylla/src/transport/cluster.rs @@ -784,12 +784,13 @@ pub(crate) fn use_keyspace_result( for result in use_keyspace_results { match result { Ok(()) => was_ok = true, - Err(err) => match err { - QueryError::BrokenConnection(_) | QueryError::ConnectionPoolError(_) => { + Err(err) => { + if err.is_connection_broken() { broken_conn_error = Some(err) + } else { + return Err(err); } - _ => return Err(err), - }, + } } } diff --git a/scylla/src/transport/errors.rs b/scylla/src/transport/errors.rs index f6b31b6c9..a1f2988e1 100644 --- a/scylla/src/transport/errors.rs +++ b/scylla/src/transport/errors.rs @@ -100,6 +100,33 @@ pub enum QueryError { RequestTimeout(String), } +impl QueryError { + pub(crate) fn is_connection_broken(&self) -> bool { + // Do not remove this lint! + // It's there for a reason - we don't want new variants + // automatically fall under `_` pattern when they are introduced. + #[deny(clippy::wildcard_enum_match_arm)] + match self { + // Error variants that imply that some connection error appeared before/during execution. + QueryError::BrokenConnection(_) | QueryError::ConnectionPoolError(_) => true, + + // Other errors. + QueryError::DbError(_, _) + | QueryError::BadQuery(_) + | QueryError::CqlRequestSerialization(_) + | QueryError::BodyExtensionsParseError(_) + | QueryError::EmptyPlan + | QueryError::CqlResultParseError(_) + | QueryError::CqlErrorParseError(_) + | QueryError::MetadataError(_) + | QueryError::ProtocolError(_) + | QueryError::TimeoutError + | QueryError::UnableToAllocStreamId + | QueryError::RequestTimeout(_) => false, + } + } +} + impl From for QueryError { fn from(serialized_err: SerializeValuesError) -> QueryError { QueryError::BadQuery(BadQuery::SerializeValuesError(serialized_err))