diff --git a/src/config.rs b/src/config.rs index 296d9f867..d4eeb6f60 100644 --- a/src/config.rs +++ b/src/config.rs @@ -23,7 +23,7 @@ //! [librdkafka-config]: https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md use std::collections::HashMap; -use std::ffi::{CStr, CString}; +use std::ffi::CString; use std::iter::FromIterator; use std::os::raw::c_char; use std::ptr; diff --git a/src/consumer/base_consumer.rs b/src/consumer/base_consumer.rs index c67c90cb2..a26350b94 100644 --- a/src/consumer/base_consumer.rs +++ b/src/consumer/base_consumer.rs @@ -345,7 +345,7 @@ where self.queue.ptr(), )) }; - if err.is_error() { + if let Some(err) = err { Err(KafkaError::ConsumerQueueClose(err.code())) } else { Ok(()) @@ -423,7 +423,7 @@ where assignment.ptr(), )) }; - if ret.is_error() { + if let Some(ret) = ret { let error = ret.name(); return Err(KafkaError::Subscription(error)); }; @@ -437,7 +437,7 @@ where assignment.ptr(), )) }; - if ret.is_error() { + if let Some(ret) = ret { let error = ret.name(); return Err(KafkaError::Subscription(error)); }; @@ -477,7 +477,7 @@ where timeout.into().as_millis(), )) }; - if ret.is_error() { + if let Some(ret) = ret { let error = ret.name(); return Err(KafkaError::Seek(error)); } diff --git a/src/consumer/stream_consumer.rs b/src/consumer/stream_consumer.rs index 5a7f60552..728ad2b1d 100644 --- a/src/consumer/stream_consumer.rs +++ b/src/consumer/stream_consumer.rs @@ -129,16 +129,12 @@ impl<'a, C: ConsumerContext> MessageStream<'a, C> { self.consumer.poll(Duration::ZERO) } } -} - -impl<'a, C: ConsumerContext> Stream for MessageStream<'a, C> { - type Item = KafkaResult>; - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_next_item(&self, cx: &mut Context<'_>) -> Poll>> { // If there is a message ready, yield it immediately to avoid the // taking the lock in `self.set_waker`. if let Some(message) = self.poll() { - return Poll::Ready(Some(message)); + return Poll::Ready(message); } // Otherwise, we need to wait for a message to become available. Store @@ -153,11 +149,19 @@ impl<'a, C: ConsumerContext> Stream for MessageStream<'a, C> { // installed the waker. match self.poll() { None => Poll::Pending, - Some(message) => Poll::Ready(Some(message)), + Some(message) => Poll::Ready(message), } } } +impl<'a, C: ConsumerContext> Stream for MessageStream<'a, C> { + type Item = KafkaResult>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_next_item(cx).map(Some) + } +} + impl<'a, C: ConsumerContext> Drop for MessageStream<'a, C> { fn drop(&mut self) { self.wakers.unregister(self.slot); @@ -297,10 +301,8 @@ where /// /// [cancellation safe]: https://docs.rs/tokio/latest/tokio/macro.select.html#cancellation-safety pub async fn recv(&self) -> Result, KafkaError> { - self.stream() - .next() - .await - .expect("kafka streams never terminate") + let stream = self.stream(); + futures_util::future::poll_fn(move |cx| stream.poll_next_item(cx)).await } /// Splits messages for the specified partition into their own stream. diff --git a/src/error.rs b/src/error.rs index 312a6bb65..7217b7488 100644 --- a/src/error.rs +++ b/src/error.rs @@ -3,7 +3,6 @@ use std::error::Error; use std::ffi::{self, CStr}; use std::fmt; -use std::ptr; use std::sync::Arc; use rdkafka_sys as rdsys; @@ -39,13 +38,13 @@ impl IsError for RDKafkaConfRes { impl IsError for RDKafkaError { fn is_error(&self) -> bool { - self.0.is_some() + true } } /// Native rdkafka error. #[derive(Clone)] -pub struct RDKafkaError(Option>>); +pub struct RDKafkaError(Arc>); unsafe impl KafkaDrop for rdsys::rd_kafka_error_t { const TYPE: &'static str = "error"; @@ -56,15 +55,12 @@ unsafe impl Send for RDKafkaError {} unsafe impl Sync for RDKafkaError {} impl RDKafkaError { - pub(crate) unsafe fn from_ptr(ptr: *mut rdsys::rd_kafka_error_t) -> RDKafkaError { - RDKafkaError(NativePtr::from_ptr(ptr).map(Arc::new)) + pub(crate) unsafe fn from_ptr(ptr: *mut rdsys::rd_kafka_error_t) -> Option { + NativePtr::from_ptr(ptr).map(|p| RDKafkaError(Arc::new(p))) } fn ptr(&self) -> *const rdsys::rd_kafka_error_t { - match &self.0 { - None => ptr::null(), - Some(p) => p.ptr(), - } + self.0.ptr() } /// Returns the error code or [`RDKafkaErrorCode::NoError`] if the error is diff --git a/src/producer/base_producer.rs b/src/producer/base_producer.rs index 1cc6e05ce..9cd6df168 100644 --- a/src/producer/base_producer.rs +++ b/src/producer/base_producer.rs @@ -537,7 +537,7 @@ where timeout.into().as_millis(), )) }; - if ret.is_error() { + if let Some(ret) = ret { Err(KafkaError::Transaction(ret)) } else { Ok(()) @@ -547,7 +547,7 @@ where fn begin_transaction(&self) -> KafkaResult<()> { let ret = unsafe { RDKafkaError::from_ptr(rdsys::rd_kafka_begin_transaction(self.native_ptr())) }; - if ret.is_error() { + if let Some(ret) = ret { Err(KafkaError::Transaction(ret)) } else { Ok(()) @@ -568,7 +568,7 @@ where timeout.into().as_millis(), )) }; - if ret.is_error() { + if let Some(ret) = ret { Err(KafkaError::Transaction(ret)) } else { Ok(()) @@ -589,7 +589,7 @@ where timeout.as_millis(), )) }; - if ret.is_error() { + if let Some(ret) = ret { Err(KafkaError::Transaction(ret)) } else { Ok(()) @@ -603,7 +603,7 @@ where timeout.into().as_millis(), )) }; - if ret.is_error() { + if let Some(ret) = ret { Err(KafkaError::Transaction(ret)) } else { Ok(())