diff --git a/Cargo.lock b/Cargo.lock index c4965d9e046b..9d76509f74fc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1153,6 +1153,8 @@ dependencies = [ "bytes", "common-error", "paste", + "rustls", + "rustls-pemfile 1.0.1", "serde", "snafu", ] @@ -2207,6 +2209,7 @@ dependencies = [ "openmetrics-parser", "prost 0.11.0", "query", + "rustls", "serde", "serde_json", "servers", @@ -3707,16 +3710,17 @@ dependencies = [ [[package]] name = "opensrv-mysql" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4c24c12fd688cb5aa5b1a54c6ccb2e30fb9b5132debb0e89fcb432b3f73db8f" +version = "0.3.0" +source = "git+https://github.com/SSebo/opensrv.git#aa9beefdeef075c6b2da85766ad4d20bb585e694" dependencies = [ "async-trait", "byteorder", "chrono", "mysql_common", "nom", + "pin-project-lite", "tokio", + "tokio-rustls", ] [[package]] @@ -5453,6 +5457,7 @@ dependencies = [ "table", "tokio", "tokio-postgres", + "tokio-rustls", "tokio-stream", "tokio-test", "tonic", diff --git a/src/common/base/Cargo.toml b/src/common/base/Cargo.toml index c1718a2ae0d2..ae0399397b3a 100644 --- a/src/common/base/Cargo.toml +++ b/src/common/base/Cargo.toml @@ -11,3 +11,5 @@ common-error = { path = "../error" } paste = "1.0" serde = { version = "1.0", features = ["derive"] } snafu = { version = "0.7", features = ["backtraces"] } +rustls = "0.20" +rustls-pemfile = "1.0" diff --git a/src/common/base/src/lib.rs b/src/common/base/src/lib.rs index c3d8a047295b..c87731af4b57 100644 --- a/src/common/base/src/lib.rs +++ b/src/common/base/src/lib.rs @@ -15,5 +15,6 @@ pub mod bit_vec; pub mod buffer; pub mod bytes; +pub mod tls; pub use bit_vec::BitVec; diff --git a/src/common/base/src/tls.rs b/src/common/base/src/tls.rs new file mode 100644 index 000000000000..50af6a8fcfa3 --- /dev/null +++ b/src/common/base/src/tls.rs @@ -0,0 +1,45 @@ +// Copyright 2022 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::fs::File; +use std::io::{BufReader, ErrorKind}; + +use rustls::{Certificate, PrivateKey, ServerConfig}; +use rustls_pemfile::{certs, pkcs8_private_keys}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Default, Serialize, Deserialize, Clone)] +pub struct TlsOption { + pub cert_path: String, + pub key_path: String, +} + +impl TlsOption { + pub fn setup(&self) -> Result { + let cert = certs(&mut BufReader::new(File::open(&self.key_path)?)) + .map_err(|_| std::io::Error::new(ErrorKind::InvalidInput, "invalid cert")) + .map(|mut certs| certs.drain(..).map(Certificate).collect())?; + let key = pkcs8_private_keys(&mut BufReader::new(File::open(&self.key_path)?)) + .map_err(|_| std::io::Error::new(ErrorKind::InvalidInput, "invalid key")) + .map(|mut keys| keys.drain(..).map(PrivateKey).next().unwrap())?; + + let config = ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(cert, key) + .map_err(|err| std::io::Error::new(ErrorKind::InvalidInput, err))?; + + Ok(config) + } +} diff --git a/src/datanode/src/datanode.rs b/src/datanode/src/datanode.rs index af620146970b..6751afda0317 100644 --- a/src/datanode/src/datanode.rs +++ b/src/datanode/src/datanode.rs @@ -14,6 +14,7 @@ use std::sync::Arc; +use common_base::tls::TlsOption; use common_telemetry::info; use meta_client::MetaClientOpts; use serde::{Deserialize, Serialize}; @@ -49,6 +50,7 @@ pub struct DatanodeOptions { pub storage: ObjectStoreConfig, pub enable_memory_catalog: bool, pub mode: Mode, + pub tls: Option>, } impl Default for DatanodeOptions { @@ -64,6 +66,7 @@ impl Default for DatanodeOptions { storage: ObjectStoreConfig::default(), enable_memory_catalog: false, mode: Mode::Standalone, + tls: None, } } } diff --git a/src/datanode/src/server.rs b/src/datanode/src/server.rs index f2cef8ca744a..05922eb07144 100644 --- a/src/datanode/src/server.rs +++ b/src/datanode/src/server.rs @@ -62,6 +62,7 @@ impl Services { Some(MysqlServer::create_server( instance.clone(), mysql_io_runtime, + None, )) } }; diff --git a/src/frontend/Cargo.toml b/src/frontend/Cargo.toml index a56affcf17b4..d89124c11da6 100644 --- a/src/frontend/Cargo.toml +++ b/src/frontend/Cargo.toml @@ -45,6 +45,7 @@ sql = { path = "../sql" } store-api = { path = "../store-api" } table = { path = "../table" } tokio = { version = "1.18", features = ["full"] } +rustls = "0.20" [dev-dependencies] datanode = { path = "../datanode" } diff --git a/src/frontend/src/mysql.rs b/src/frontend/src/mysql.rs index 71bb600753a5..ddc4e0a39187 100644 --- a/src/frontend/src/mysql.rs +++ b/src/frontend/src/mysql.rs @@ -12,12 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::sync::Arc; + +use common_base::tls::TlsOption; use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, Serialize, Deserialize)] pub struct MysqlOptions { pub addr: String, pub runtime_size: usize, + pub tls: Option>, } impl Default for MysqlOptions { @@ -25,6 +29,7 @@ impl Default for MysqlOptions { Self { addr: "127.0.0.1:4002".to_string(), runtime_size: 2, + tls: None, } } } diff --git a/src/frontend/src/postgres.rs b/src/frontend/src/postgres.rs index 41a11233bc33..8dd2fdc43fc8 100644 --- a/src/frontend/src/postgres.rs +++ b/src/frontend/src/postgres.rs @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::sync::Arc; + +use common_base::tls::TlsOption; use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, Serialize, Deserialize)] @@ -19,6 +22,7 @@ pub struct PostgresOptions { pub addr: String, pub runtime_size: usize, pub check_pwd: bool, + pub tls: Option>, } impl Default for PostgresOptions { @@ -27,6 +31,7 @@ impl Default for PostgresOptions { addr: "127.0.0.1:4003".to_string(), runtime_size: 2, check_pwd: false, + tls: None, } } } diff --git a/src/frontend/src/server.rs b/src/frontend/src/server.rs index ede8adc8b55b..8eee23da0ce1 100644 --- a/src/frontend/src/server.rs +++ b/src/frontend/src/server.rs @@ -69,7 +69,8 @@ impl Services { .context(error::RuntimeResourceSnafu)?, ); - let mysql_server = MysqlServer::create_server(instance.clone(), mysql_io_runtime); + let mysql_server = + MysqlServer::create_server(instance.clone(), mysql_io_runtime, opts.tls.clone()); Some((mysql_server, mysql_addr)) } else { @@ -90,6 +91,7 @@ impl Services { let pg_server = Box::new(PostgresServer::new( instance.clone(), opts.check_pwd, + opts.tls.clone(), pg_io_runtime, )) as Box; diff --git a/src/servers/Cargo.toml b/src/servers/Cargo.toml index 8a4c3f4345d5..8fc97d7c5096 100644 --- a/src/servers/Cargo.toml +++ b/src/servers/Cargo.toml @@ -30,7 +30,7 @@ metrics = "0.20" num_cpus = "1.13" once_cell = "1.16" openmetrics-parser = "0.4" -opensrv-mysql = "0.2" +opensrv-mysql = { git = "https://github.com/SSebo/opensrv.git" } pgwire = "0.5" prost = "0.11" regex = "1.6" @@ -47,6 +47,7 @@ tonic = "0.8" tonic-reflection = "0.5" tower = { version = "0.4", features = ["full"] } tower-http = { version = "0.3", features = ["full"] } +tokio-rustls = "0.23" [dev-dependencies] axum-test-helper = { git = "https://github.com/sunng87/axum-test-helper.git", branch = "patch-1" } diff --git a/src/servers/src/mysql/server.rs b/src/servers/src/mysql/server.rs index f66669303c32..cd9bdc3005c6 100644 --- a/src/servers/src/mysql/server.rs +++ b/src/servers/src/mysql/server.rs @@ -17,13 +17,17 @@ use std::net::SocketAddr; use std::sync::Arc; use async_trait::async_trait; +use common_base::tls::TlsOption; use common_runtime::Runtime; use common_telemetry::logging::{error, info}; use futures::StreamExt; -use opensrv_mysql::AsyncMysqlIntermediary; +use opensrv_mysql::{ + plain_run_with_options, secure_run_with_options, AsyncMysqlIntermediary, IntermediaryOptions, +}; use tokio; use tokio::io::BufWriter; use tokio::net::TcpStream; +use tokio_rustls::rustls::ServerConfig; use crate::error::Result; use crate::mysql::handler::MysqlInstanceShim; @@ -36,16 +40,19 @@ const DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE: usize = 100 * 1024; pub struct MysqlServer { base_server: BaseTcpServer, query_handler: SqlQueryHandlerRef, + tls: Option>, } impl MysqlServer { pub fn create_server( query_handler: SqlQueryHandlerRef, io_runtime: Arc, + tls: Option>, ) -> Box { Box::new(MysqlServer { base_server: BaseTcpServer::create_server("MySQL", io_runtime), query_handler, + tls, }) } @@ -53,16 +60,20 @@ impl MysqlServer { &self, io_runtime: Arc, stream: AbortableStream, + tls_conf: Option>, ) -> impl Future { let query_handler = self.query_handler.clone(); + stream.for_each(move |tcp_stream| { let io_runtime = io_runtime.clone(); let query_handler = query_handler.clone(); + let tls_conf = tls_conf.clone(); async move { match tcp_stream { Err(error) => error!("Broken pipe: {}", error), // IoError doesn't impl ErrorExt. Ok(io_stream) => { - if let Err(error) = Self::handle(io_stream, io_runtime, query_handler).await + if let Err(error) = + Self::handle(io_stream, io_runtime, query_handler, tls_conf).await { error!(error; "Unexpected error when handling TcpStream"); }; @@ -76,15 +87,28 @@ impl MysqlServer { stream: TcpStream, io_runtime: Arc, query_handler: SqlQueryHandlerRef, + tls_conf: Option>, ) -> Result<()> { info!("MySQL connection coming from: {}", stream.peer_addr()?); - let shim = MysqlInstanceShim::create(query_handler, stream.peer_addr()?.to_string()); + let mut shim = MysqlInstanceShim::create(query_handler, stream.peer_addr()?.to_string()); - let (r, w) = stream.into_split(); - let w = BufWriter::with_capacity(DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE, w); + let (mut r, w) = stream.into_split(); + let mut w = BufWriter::with_capacity(DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE, w); // TODO(LFC): Use `output_stream` to write large MySQL ResultSet to client. let spawn_result = io_runtime - .spawn(AsyncMysqlIntermediary::run_on(shim, r, w)) + .spawn(async move { + let ops = IntermediaryOptions::default(); + let (is_ssl, init_params) = + AsyncMysqlIntermediary::init_before_ssl(&mut shim, &mut r, &mut w, &tls_conf) + .await?; + + match tls_conf { + Some(tls_conf) if is_ssl => { + secure_run_with_options(shim, w, ops, tls_conf, init_params).await + } + _ => plain_run_with_options(shim, w, ops, init_params).await, + } + }) .await; match spawn_result { Ok(run_result) => { @@ -110,7 +134,14 @@ impl Server for MysqlServer { let (stream, addr) = self.base_server.bind(listening).await?; let io_runtime = self.base_server.io_runtime(); - let join_handle = tokio::spawn(self.accept(io_runtime, stream)); + let tls_conf = match &self.tls { + Some(tls) => { + let server_conf = tls.setup()?; + Some(Arc::new(server_conf)) + } + None => None, + }; + let join_handle = tokio::spawn(self.accept(io_runtime, stream, tls_conf)); self.base_server.start_with(join_handle).await?; Ok(addr) } diff --git a/src/servers/src/postgres/server.rs b/src/servers/src/postgres/server.rs index 61adc65d6375..592e477ba2dd 100644 --- a/src/servers/src/postgres/server.rs +++ b/src/servers/src/postgres/server.rs @@ -17,11 +17,13 @@ use std::net::SocketAddr; use std::sync::Arc; use async_trait::async_trait; +use common_base::tls::TlsOption; use common_runtime::Runtime; use common_telemetry::logging::error; use futures::StreamExt; use pgwire::tokio::process_socket; use tokio; +use tokio_rustls::TlsAcceptor; use crate::error::Result; use crate::postgres::auth_handler::PgAuthStartupHandler; @@ -33,6 +35,7 @@ pub struct PostgresServer { base_server: BaseTcpServer, auth_handler: Arc, query_handler: Arc, + tls: Option>, } impl PostgresServer { @@ -40,6 +43,7 @@ impl PostgresServer { pub fn new( query_handler: SqlQueryHandlerRef, check_pwd: bool, + tls: Option>, io_runtime: Arc, ) -> PostgresServer { let postgres_handler = Arc::new(PostgresServerHandler::new(query_handler)); @@ -48,6 +52,7 @@ impl PostgresServer { base_server: BaseTcpServer::create_server("Postgres", io_runtime), auth_handler: startup_handler, query_handler: postgres_handler, + tls, } } @@ -55,6 +60,7 @@ impl PostgresServer { &self, io_runtime: Arc, accepting_stream: AbortableStream, + tls_acceptor: Option>, ) -> impl Future { let auth_handler = self.auth_handler.clone(); let query_handler = self.query_handler.clone(); @@ -63,6 +69,7 @@ impl PostgresServer { let io_runtime = io_runtime.clone(); let auth_handler = auth_handler.clone(); let query_handler = query_handler.clone(); + let tls_acceptor = tls_acceptor.clone(); async move { match tcp_stream { @@ -70,7 +77,7 @@ impl PostgresServer { Ok(io_stream) => { io_runtime.spawn(process_socket( io_stream, - None, + tls_acceptor.clone(), auth_handler.clone(), query_handler.clone(), query_handler.clone(), @@ -91,8 +98,17 @@ impl Server for PostgresServer { async fn start(&self, listening: SocketAddr) -> Result { let (stream, addr) = self.base_server.bind(listening).await?; + let tls_acceptor = match &self.tls { + Some(tls) => { + let server_conf = tls.setup()?; + Some(Arc::new(TlsAcceptor::from(Arc::new(server_conf)))) + } + None => None, + }; + let io_runtime = self.base_server.io_runtime(); - let join_handle = tokio::spawn(self.accept(io_runtime, stream)); + let join_handle = tokio::spawn(self.accept(io_runtime, stream, tls_acceptor)); + self.base_server.start_with(join_handle).await?; Ok(addr) } diff --git a/src/servers/tests/mysql/mysql_server_test.rs b/src/servers/tests/mysql/mysql_server_test.rs index ba82e8d68f8a..5a3c47294e69 100644 --- a/src/servers/tests/mysql/mysql_server_test.rs +++ b/src/servers/tests/mysql/mysql_server_test.rs @@ -39,7 +39,7 @@ fn create_mysql_server(table: MemTable) -> Result> { .build() .unwrap(), ); - Ok(MysqlServer::create_server(query_handler, io_runtime)) + Ok(MysqlServer::create_server(query_handler, io_runtime, None)) } #[tokio::test] diff --git a/src/servers/tests/postgres/mod.rs b/src/servers/tests/postgres/mod.rs index bbbe88ee7848..dfd8a280663e 100644 --- a/src/servers/tests/postgres/mod.rs +++ b/src/servers/tests/postgres/mod.rs @@ -39,6 +39,7 @@ fn create_postgres_server(table: MemTable, check_pwd: bool) -> Result