-
Notifications
You must be signed in to change notification settings - Fork 873
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix(flight_sql): PreparedStatement has no token for auth. #3948
Conversation
… headers. In particular, the token is required for auth in each request.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This makes sense to me -- thank you @youngsofun
Can you please add test coverage for this feature, perhaps in
arrow-rs/arrow-flight/examples/flight_sql_server.rs
Lines 504 to 688 in 9bd2bae
mod tests { | |
use super::*; | |
use futures::TryStreamExt; | |
use std::fs; | |
use std::time::Duration; | |
use tempfile::NamedTempFile; | |
use tokio::net::{UnixListener, UnixStream}; | |
use tokio::time::sleep; | |
use tokio_stream::wrappers::UnixListenerStream; | |
use tonic::transport::ClientTlsConfig; | |
use arrow_cast::pretty::pretty_format_batches; | |
use arrow_flight::sql::client::FlightSqlServiceClient; | |
use arrow_flight::utils::flight_data_to_batches; | |
use tonic::transport::{Certificate, Endpoint}; | |
use tower::service_fn; | |
async fn client_with_uds(path: String) -> FlightSqlServiceClient { | |
let connector = service_fn(move |_| UnixStream::connect(path.clone())); | |
let channel = Endpoint::try_from("http://example.com") | |
.unwrap() | |
.connect_with_connector(connector) | |
.await | |
.unwrap(); | |
FlightSqlServiceClient::new(channel) | |
} | |
async fn create_https_server() -> Result<(), tonic::transport::Error> { | |
let cert = std::fs::read_to_string("examples/data/server.pem").unwrap(); | |
let key = std::fs::read_to_string("examples/data/server.key").unwrap(); | |
let client_ca = std::fs::read_to_string("examples/data/client_ca.pem").unwrap(); | |
let tls_config = ServerTlsConfig::new() | |
.identity(Identity::from_pem(&cert, &key)) | |
.client_ca_root(Certificate::from_pem(&client_ca)); | |
let addr = "0.0.0.0:50051".parse().unwrap(); | |
let svc = FlightServiceServer::new(FlightSqlServiceImpl {}); | |
Server::builder() | |
.tls_config(tls_config) | |
.unwrap() | |
.add_service(svc) | |
.serve(addr) | |
.await | |
} | |
#[tokio::test] | |
async fn test_select_https() { | |
tokio::spawn(async { | |
create_https_server().await.unwrap(); | |
}); | |
sleep(Duration::from_millis(2000)).await; | |
let request_future = async { | |
let cert = std::fs::read_to_string("examples/data/client1.pem").unwrap(); | |
let key = std::fs::read_to_string("examples/data/client1.key").unwrap(); | |
let server_ca = std::fs::read_to_string("examples/data/ca.pem").unwrap(); | |
let tls_config = ClientTlsConfig::new() | |
.domain_name("localhost") | |
.ca_certificate(Certificate::from_pem(&server_ca)) | |
.identity(Identity::from_pem(cert, key)); | |
let endpoint = endpoint(String::from("https://127.0.0.1:50051")) | |
.unwrap() | |
.tls_config(tls_config) | |
.unwrap(); | |
let channel = endpoint.connect().await.unwrap(); | |
let mut client = FlightSqlServiceClient::new(channel); | |
let token = client.handshake("admin", "password").await.unwrap(); | |
println!("Auth succeeded with token: {:?}", token); | |
let mut stmt = client.prepare("select 1;".to_string()).await.unwrap(); | |
let flight_info = stmt.execute().await.unwrap(); | |
let ticket = flight_info.endpoint[0].ticket.as_ref().unwrap().clone(); | |
let flight_data = client.do_get(ticket).await.unwrap(); | |
let flight_data: Vec<FlightData> = flight_data.try_collect().await.unwrap(); | |
let batches = flight_data_to_batches(&flight_data).unwrap(); | |
let res = pretty_format_batches(batches.as_slice()).unwrap(); | |
let expected = r#" | |
+-------------------+ | |
| salutation | | |
+-------------------+ | |
| Hello, FlightSQL! | | |
+-------------------+"# | |
.trim() | |
.to_string(); | |
assert_eq!(res.to_string(), expected); | |
}; | |
tokio::select! { | |
_ = request_future => println!("Client finished!"), | |
} | |
} | |
#[tokio::test] | |
async fn test_select_1() { | |
let file = NamedTempFile::new().unwrap(); | |
let path = file.into_temp_path().to_str().unwrap().to_string(); | |
let _ = fs::remove_file(path.clone()); | |
let uds = UnixListener::bind(path.clone()).unwrap(); | |
let stream = UnixListenerStream::new(uds); | |
// We would just listen on TCP, but it seems impossible to know when tonic is ready to serve | |
let service = FlightSqlServiceImpl {}; | |
let serve_future = Server::builder() | |
.add_service(FlightServiceServer::new(service)) | |
.serve_with_incoming(stream); | |
let request_future = async { | |
let mut client = client_with_uds(path).await; | |
let token = client.handshake("admin", "password").await.unwrap(); | |
println!("Auth succeeded with token: {:?}", token); | |
let mut stmt = client.prepare("select 1;".to_string()).await.unwrap(); | |
let flight_info = stmt.execute().await.unwrap(); | |
let ticket = flight_info.endpoint[0].ticket.as_ref().unwrap().clone(); | |
let flight_data = client.do_get(ticket).await.unwrap(); | |
let flight_data: Vec<FlightData> = flight_data.try_collect().await.unwrap(); | |
let batches = flight_data_to_batches(&flight_data).unwrap(); | |
let res = pretty_format_batches(batches.as_slice()).unwrap(); | |
let expected = r#" | |
+-------------------+ | |
| salutation | | |
+-------------------+ | |
| Hello, FlightSQL! | | |
+-------------------+"# | |
.trim() | |
.to_string(); | |
assert_eq!(res.to_string(), expected); | |
}; | |
tokio::select! { | |
_ = serve_future => panic!("server returned first"), | |
_ = request_future => println!("Client finished!"), | |
} | |
} | |
#[tokio::test] | |
async fn test_execute_update() { | |
let file = NamedTempFile::new().unwrap(); | |
let path = file.into_temp_path().to_str().unwrap().to_string(); | |
let _ = fs::remove_file(path.clone()); | |
let uds = UnixListener::bind(path.clone()).unwrap(); | |
let stream = UnixListenerStream::new(uds); | |
// We would just listen on TCP, but it seems impossible to know when tonic is ready to serve | |
let service = FlightSqlServiceImpl {}; | |
let serve_future = Server::builder() | |
.add_service(FlightServiceServer::new(service)) | |
.serve_with_incoming(stream); | |
let request_future = async { | |
let mut client = client_with_uds(path).await; | |
let token = client.handshake("admin", "password").await.unwrap(); | |
println!("Auth succeeded with token: {:?}", token); | |
let res = client | |
.execute_update("creat table test(a int);".to_string()) | |
.await | |
.unwrap(); | |
assert_eq!(res, FlightSqlServiceImpl::fake_update_result()); | |
}; | |
tokio::select! { | |
_ = serve_future => panic!("server returned first"), | |
_ = request_future => println!("Client finished!"), | |
} | |
} | |
fn endpoint(addr: String) -> Result<Endpoint, ArrowError> { | |
let endpoint = Endpoint::new(addr) | |
.map_err(|_| ArrowError::IoError("Cannot create endpoint".to_string()))? | |
.connect_timeout(Duration::from_secs(20)) | |
.timeout(Duration::from_secs(20)) | |
.tcp_nodelay(true) // Disable Nagle's Algorithm since we don't want packets to wait | |
.tcp_keepalive(Option::Some(Duration::from_secs(3600))) | |
.http2_keep_alive_interval(Duration::from_secs(300)) | |
.keep_alive_timeout(Duration::from_secs(20)) | |
.keep_alive_while_idle(true); | |
Ok(endpoint) | |
} | |
} |
OK |
Thanks! CC @stormasm |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks good to me, just some minor nits
@alamb thank you for your review, I will examine them tomorrow, time to sleep now in Beijing :) |
@alamb polished according to you suggestions and refactored the tests by the way. please review again. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you
Which issue does this PR close?
Rationale for this change
set_request_headers
for each request. In particular, the token is required for auth in each request.FlightSqlServiceClient
(powered bytower::service::Buffer
inChannel
)(we are using this client for unit tests for our implementation of FlightSql Service, auth fail)
What changes are included in this PR?
to
do_get
, but alsodo_put
,do_action
etc. andset_request_headers
for each request.Are there any user-facing changes?