Skip to content

Commit

Permalink
remove ureq (#2495)
Browse files Browse the repository at this point in the history
ureq-proto released a breaking change in 2.0.6 which broke our build

ureq used because of reddit:
https://old.reddit.com/r/rust/comments/f39ueb/how_to_use_reqwest_without_async/fhiyw6n

instead use futures::unfold to properly implement async streams with an async http client,
as a bonus we replace our dependency on ureq with our existing dependency on reqwest
  • Loading branch information
serprex authored Jan 28, 2025
1 parent ab93b08 commit 9cc6d99
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 132 deletions.
114 changes: 9 additions & 105 deletions nexus/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion nexus/peer-snowflake/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,4 @@ sha2 = "0.10"
sqlparser.workspace = true
tokio.workspace = true
tracing.workspace = true
ureq = { version = "3", features = ["json", "charset"] }
value = { path = "../value" }
66 changes: 40 additions & 26 deletions nexus/peer-snowflake/src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ pub(crate) enum SnowflakeDataType {
Variant,
}

#[derive(Clone)]
pub struct SnowflakeSchema {
schema: Schema,
}
Expand Down Expand Up @@ -82,13 +83,17 @@ impl SnowflakeSchema {
}

pub struct SnowflakeRecordStream {
schema: SnowflakeSchema,
stream: Pin<Box<dyn Stream<Item = PgWireResult<Record>> + Send + Sync>>,
}

pub struct SnowflakeRecordStreamInner {
result_set: ResultSet,
partition_index: usize,
partition_number: usize,
schema: SnowflakeSchema,
auth: SnowflakeAuth,

endpoint_url: String,
auth: SnowflakeAuth,
schema: SnowflakeSchema,
}

impl SnowflakeRecordStream {
Expand All @@ -98,19 +103,26 @@ impl SnowflakeRecordStream {
partition_number: usize,
endpoint_url: String,
auth: SnowflakeAuth,
) -> Self {
let sf_schema = SnowflakeSchema::from_result_set(&result_set);
) -> SnowflakeRecordStream {
let schema = SnowflakeSchema::from_result_set(&result_set);

Self {
let inner = SnowflakeRecordStreamInner {
result_set,
schema: sf_schema,
partition_index,
partition_number,
endpoint_url,
auth,
}
schema: schema.clone(),
};
let stream = futures::stream::unfold(inner, |mut inner| async {
inner.advance().await.map(|val| (val, inner))
});

Self { schema, stream: Box::pin(stream) }
}
}

impl SnowflakeRecordStreamInner {
pub fn convert_result_set_item(&mut self) -> anyhow::Result<Record> {
let mut row_values = Vec::new();

Expand Down Expand Up @@ -202,7 +214,7 @@ impl SnowflakeRecordStream {
})
}

fn advance_partition(&mut self) -> anyhow::Result<bool> {
async fn advance_partition(&mut self) -> anyhow::Result<bool> {
if (self.partition_number + 1) == self.result_set.resultSetMetaData.partitionInfo.len() {
return Ok(false);
}
Expand All @@ -213,39 +225,41 @@ impl SnowflakeRecordStream {
let statement_handle = self.result_set.statementHandle.clone();
let url = self.endpoint_url.clone();
println!("Secret: {:#?}", secret);
let response: PartitionResult = ureq::get(format!("{}/{}", url, statement_handle))
.query("partition", partition_number.to_string())
let response = reqwest::Client::new().get(format!("{}/{}", url, statement_handle))
.query(&[("partition", partition_number.to_string())])
.header("Authorization", format!("Bearer {}", secret))
.header("X-Snowflake-Authorization-Token-Type", "KEYPAIR_JWT")
.header("user-agent", "ureq")
.call()?
.body_mut()
.read_json()
.send().await?
.json::<PartitionResult>().await
.map_err(|_| anyhow::anyhow!("get_partition failed"))?;
println!("Response: {:#?}", response.data);

self.result_set.data = response.data;
Ok(true)
}

fn advance(&mut self) -> anyhow::Result<bool> {
Ok((self.partition_index < self.result_set.data.len()) || self.advance_partition()?)
async fn advance(&mut self) -> Option<PgWireResult<Record>> {
let next = self.partition_index < self.result_set.data.len() || {
match self.advance_partition().await {
Ok(val) => val,
Err(err) => return Some(Err(PgWireError::ApiError(err.into()))),
}
};
if next {
let record = self.convert_result_set_item();
Some(record.map_err(|e| PgWireError::ApiError(e.into())))
} else {
None
}
}
}

impl Stream for SnowflakeRecordStream {
type Item = PgWireResult<Record>;

fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match self.advance() {
Ok(true) => {
let record = self.convert_result_set_item();
let result = record.map_err(|e| PgWireError::ApiError(e.into()));
Poll::Ready(Some(result))
}
Ok(false) => Poll::Ready(None),
Err(err) => Poll::Ready(Some(Err(PgWireError::ApiError(err.into())))),
}
fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Stream::poll_next(self.stream.as_mut(), ctx)
}
}

Expand Down

0 comments on commit 9cc6d99

Please sign in to comment.