Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RUST-1337 Use tokio's AsyncRead and AsyncWrite traits (backport) #669

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ chrono = "0.4.7"
derivative = "2.1.1"
flate2 = { version = "1.0", optional = true }
futures-core = "0.3.14"
futures-io = "0.3.14"
futures-util = { version = "0.3.14", features = ["io"] }
futures-executor = "0.3.14"
hex = "0.4.0"
Expand Down
12 changes: 8 additions & 4 deletions src/bson_util/mod.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
use std::{convert::TryFrom, io::Read, time::Duration};
use std::{
convert::TryFrom,
io::{Read, Write},
time::Duration,
};

use bson::RawBsonRef;
use serde::{de::Error as SerdeDeError, ser, Deserialize, Deserializer, Serialize, Serializer};

use crate::{
bson::{doc, Bson, Document},
error::{ErrorKind, Result},
runtime::{SyncLittleEndianRead, SyncLittleEndianWrite},
runtime::SyncLittleEndianRead,
};

/// Coerce numeric types into an `i64` if it would be lossless to do so. If this Bson is not numeric
Expand Down Expand Up @@ -203,10 +207,10 @@ fn num_decimal_digits(mut n: usize) -> u64 {

/// Read a document's raw BSON bytes from the provided reader.
pub(crate) fn read_document_bytes<R: Read>(mut reader: R) -> Result<Vec<u8>> {
let length = reader.read_i32()?;
let length = reader.read_i32_sync()?;

let mut bytes = Vec::with_capacity(length as usize);
bytes.write_i32(length)?;
bytes.write_all(&length.to_le_bytes())?;

reader.take(length as u64 - 4).read_to_end(&mut bytes)?;

Expand Down
20 changes: 9 additions & 11 deletions src/cmap/conn/wire/header.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
use futures_io::{AsyncRead, AsyncWrite};
use futures_util::AsyncWriteExt;
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};

use crate::{
error::{ErrorKind, Result},
runtime::AsyncLittleEndianRead,
};
use crate::error::{ErrorKind, Result};

/// The wire protocol op codes.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
Expand Down Expand Up @@ -54,11 +50,13 @@ impl Header {
}

/// Reads bytes from `r` and deserializes them into a header.
pub(crate) async fn read_from<R: AsyncRead + Unpin + Send>(reader: &mut R) -> Result<Self> {
let length = reader.read_i32().await?;
let request_id = reader.read_i32().await?;
let response_to = reader.read_i32().await?;
let op_code = OpCode::from_i32(reader.read_i32().await?)?;
pub(crate) async fn read_from<R: tokio::io::AsyncRead + Unpin + Send>(
reader: &mut R,
) -> Result<Self> {
let length = reader.read_i32_le().await?;
let request_id = reader.read_i32_le().await?;
let response_to = reader.read_i32_le().await?;
let op_code = OpCode::from_i32(reader.read_i32_le().await?)?;
Ok(Self {
length,
request_id,
Expand Down
33 changes: 14 additions & 19 deletions src/cmap/conn/wire/message.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
use std::io::Read;

use bitflags::bitflags;
use futures_io::AsyncWrite;
use futures_util::{
io::{BufReader, BufWriter},
AsyncReadExt,
AsyncWriteExt,
};
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter};

use super::header::{Header, OpCode};
use crate::{
Expand All @@ -16,7 +11,7 @@ use crate::{
Command,
},
error::{Error, ErrorKind, Result},
runtime::{AsyncLittleEndianWrite, AsyncStream, SyncLittleEndianRead},
runtime::{AsyncStream, SyncLittleEndianRead},
};

use crate::compression::{Compressor, Decoder};
Expand Down Expand Up @@ -129,7 +124,7 @@ impl Message {
let mut reader = buf.as_slice();

// Read original opcode (should be OP_MSG)
let original_opcode = reader.read_i32()?;
let original_opcode = reader.read_i32_sync()?;
if original_opcode != OpCode::Message as i32 {
return Err(ErrorKind::InvalidResponse {
message: format!(
Expand All @@ -142,10 +137,10 @@ impl Message {
}

// Read uncompressed size
let uncompressed_size = reader.read_i32()?;
let uncompressed_size = reader.read_i32_sync()?;

// Read compressor id
let compressor_id: u8 = reader.read_u8()?;
let compressor_id: u8 = reader.read_u8_sync()?;

// Get decoder
let decoder = Decoder::from_u8(compressor_id)?;
Expand Down Expand Up @@ -178,7 +173,7 @@ impl Message {
mut length_remaining: i32,
header: &Header,
) -> Result<Self> {
let flags = MessageFlags::from_bits_truncate(reader.read_u32()?);
let flags = MessageFlags::from_bits_truncate(reader.read_u32_sync()?);
length_remaining -= std::mem::size_of::<u32>() as i32;

let mut count_reader = SyncCountReader::new(&mut reader);
Expand All @@ -193,7 +188,7 @@ impl Message {
let mut checksum = None;

if length_remaining == 4 && flags.contains(MessageFlags::CHECKSUM_PRESENT) {
checksum = Some(reader.read_u32()?);
checksum = Some(reader.read_u32_sync()?);
} else if length_remaining != 0 {
return Err(ErrorKind::InvalidResponse {
message: format!(
Expand Down Expand Up @@ -241,11 +236,11 @@ impl Message {
};

header.write_to(&mut writer).await?;
writer.write_u32(self.flags.bits()).await?;
writer.write_u32_le(self.flags.bits()).await?;
writer.write_all(&sections_bytes).await?;

if let Some(checksum) = self.checksum {
writer.write_u32(checksum).await?;
writer.write_u32_le(checksum).await?;
}

writer.flush().await?;
Expand Down Expand Up @@ -292,9 +287,9 @@ impl Message {
// Write header
header.write_to(&mut writer).await?;
// Write original (pre-compressed) opcode (always OP_MSG)
writer.write_i32(OpCode::Message as i32).await?;
writer.write_i32_le(OpCode::Message as i32).await?;
// Write uncompressed size
writer.write_i32(uncompressed_len as i32).await?;
writer.write_i32_le(uncompressed_len as i32).await?;
// Write compressor id
writer.write_u8(compressor_id).await?;
// Write compressed message
Expand Down Expand Up @@ -329,15 +324,15 @@ pub(crate) enum MessageSection {
impl MessageSection {
/// Reads bytes from `reader` and deserializes them into a MessageSection.
fn read<R: Read>(reader: &mut R) -> Result<Self> {
let payload_type = reader.read_u8()?;
let payload_type = reader.read_u8_sync()?;

if payload_type == 0 {
return Ok(MessageSection::Document(bson_util::read_document_bytes(
reader,
)?));
}

let size = reader.read_i32()?;
let size = reader.read_i32_sync()?;
let mut length_remaining = size - std::mem::size_of::<i32>() as i32;

let mut identifier = String::new();
Expand Down Expand Up @@ -385,7 +380,7 @@ impl MessageSection {
// Write payload type.
writer.write_u8(1).await?;

writer.write_i32(*size).await?;
writer.write_i32_le(*size).await?;
super::util::write_cstring(writer, identifier).await?;

for doc in documents {
Expand Down
3 changes: 1 addition & 2 deletions src/cmap/conn/wire/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@ use std::{
sync::atomic::{AtomicI32, Ordering},
};

use futures_io::{self, AsyncWrite};
use futures_util::AsyncWriteExt;
use lazy_static::lazy_static;
use tokio::io::{AsyncWrite, AsyncWriteExt};

use crate::error::Result;

Expand Down
56 changes: 0 additions & 56 deletions src/runtime/async_read_ext.rs

This file was deleted.

51 changes: 0 additions & 51 deletions src/runtime/async_write_ext.rs

This file was deleted.

6 changes: 2 additions & 4 deletions src/runtime/mod.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
mod acknowledged_message;
mod async_read_ext;
mod async_write_ext;
mod http;
#[cfg(feature = "async-std-runtime")]
mod interval;
mod join_handle;
mod resolver;
mod stream;
mod sync_read_ext;
#[cfg(feature = "openssl-tls")]
mod tls_openssl;
#[cfg_attr(feature = "openssl-tls", allow(unused))]
Expand All @@ -16,11 +15,10 @@ use std::{future::Future, net::SocketAddr, time::Duration};

pub(crate) use self::{
acknowledged_message::AcknowledgedMessage,
async_read_ext::{AsyncLittleEndianRead, SyncLittleEndianRead},
async_write_ext::{AsyncLittleEndianWrite, SyncLittleEndianWrite},
join_handle::AsyncJoinHandle,
resolver::AsyncResolver,
stream::AsyncStream,
sync_read_ext::SyncLittleEndianRead,
};
use crate::{error::Result, options::ServerAddress};
pub(crate) use http::HttpClient;
Expand Down
Loading