Skip to content

Commit

Permalink
feat: mysql and pg server support tls
Browse files Browse the repository at this point in the history
  • Loading branch information
SSebo committed Nov 26, 2022
1 parent 30940e6 commit d06eec9
Show file tree
Hide file tree
Showing 15 changed files with 134 additions and 15 deletions.
11 changes: 8 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions src/common/base/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@ common-error = { path = "../error" }
paste = "1.0"
serde = { version = "1.0", features = ["derive"] }
snafu = { version = "0.7", features = ["backtraces"] }
rustls = "0.20"
rustls-pemfile = "1.0"
1 change: 1 addition & 0 deletions src/common/base/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@
pub mod bit_vec;
pub mod buffer;
pub mod bytes;
pub mod tls;

pub use bit_vec::BitVec;
45 changes: 45 additions & 0 deletions src/common/base/src/tls.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Copyright 2022 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::fs::File;
use std::io::{BufReader, ErrorKind};

use rustls::{Certificate, PrivateKey, ServerConfig};
use rustls_pemfile::{certs, pkcs8_private_keys};
use serde::{Deserialize, Serialize};

#[derive(Debug, Default, Serialize, Deserialize, Clone)]
pub struct TlsOption {
pub cert_path: String,
pub key_path: String,
}

impl TlsOption {
pub fn setup(&self) -> Result<ServerConfig, std::io::Error> {
let cert = certs(&mut BufReader::new(File::open(&self.key_path)?))
.map_err(|_| std::io::Error::new(ErrorKind::InvalidInput, "invalid cert"))
.map(|mut certs| certs.drain(..).map(Certificate).collect())?;
let key = pkcs8_private_keys(&mut BufReader::new(File::open(&self.key_path)?))
.map_err(|_| std::io::Error::new(ErrorKind::InvalidInput, "invalid key"))
.map(|mut keys| keys.drain(..).map(PrivateKey).next().unwrap())?;

let config = ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(cert, key)
.map_err(|err| std::io::Error::new(ErrorKind::InvalidInput, err))?;

Ok(config)
}
}
3 changes: 3 additions & 0 deletions src/datanode/src/datanode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

use std::sync::Arc;

use common_base::tls::TlsOption;
use common_telemetry::info;
use meta_client::MetaClientOpts;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -49,6 +50,7 @@ pub struct DatanodeOptions {
pub storage: ObjectStoreConfig,
pub enable_memory_catalog: bool,
pub mode: Mode,
pub tls: Option<Arc<TlsOption>>,
}

impl Default for DatanodeOptions {
Expand All @@ -64,6 +66,7 @@ impl Default for DatanodeOptions {
storage: ObjectStoreConfig::default(),
enable_memory_catalog: false,
mode: Mode::Standalone,
tls: None,
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/datanode/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ impl Services {
Some(MysqlServer::create_server(
instance.clone(),
mysql_io_runtime,
None,
))
}
};
Expand Down
1 change: 1 addition & 0 deletions src/frontend/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ sql = { path = "../sql" }
store-api = { path = "../store-api" }
table = { path = "../table" }
tokio = { version = "1.18", features = ["full"] }
rustls = "0.20"

[dev-dependencies]
datanode = { path = "../datanode" }
Expand Down
5 changes: 5 additions & 0 deletions src/frontend/src/mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,24 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use common_base::tls::TlsOption;
use serde::{Deserialize, Serialize};

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct MysqlOptions {
pub addr: String,
pub runtime_size: usize,
pub tls: Option<Arc<TlsOption>>,
}

impl Default for MysqlOptions {
fn default() -> Self {
Self {
addr: "127.0.0.1:4002".to_string(),
runtime_size: 2,
tls: None,
}
}
}
5 changes: 5 additions & 0 deletions src/frontend/src/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use common_base::tls::TlsOption;
use serde::{Deserialize, Serialize};

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PostgresOptions {
pub addr: String,
pub runtime_size: usize,
pub check_pwd: bool,
pub tls: Option<Arc<TlsOption>>,
}

impl Default for PostgresOptions {
Expand All @@ -27,6 +31,7 @@ impl Default for PostgresOptions {
addr: "127.0.0.1:4003".to_string(),
runtime_size: 2,
check_pwd: false,
tls: None,
}
}
}
4 changes: 3 additions & 1 deletion src/frontend/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ impl Services {
.context(error::RuntimeResourceSnafu)?,
);

let mysql_server = MysqlServer::create_server(instance.clone(), mysql_io_runtime);
let mysql_server =
MysqlServer::create_server(instance.clone(), mysql_io_runtime, opts.tls.clone());

Some((mysql_server, mysql_addr))
} else {
Expand All @@ -90,6 +91,7 @@ impl Services {
let pg_server = Box::new(PostgresServer::new(
instance.clone(),
opts.check_pwd,
opts.tls.clone(),
pg_io_runtime,
)) as Box<dyn Server>;

Expand Down
3 changes: 2 additions & 1 deletion src/servers/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ metrics = "0.20"
num_cpus = "1.13"
once_cell = "1.16"
openmetrics-parser = "0.4"
opensrv-mysql = "0.2"
opensrv-mysql = { git = "https://github.com/SSebo/opensrv.git" }
pgwire = "0.5"
prost = "0.11"
regex = "1.6"
Expand All @@ -47,6 +47,7 @@ tonic = "0.8"
tonic-reflection = "0.5"
tower = { version = "0.4", features = ["full"] }
tower-http = { version = "0.3", features = ["full"] }
tokio-rustls = "0.23"

[dev-dependencies]
axum-test-helper = { git = "https://github.com/sunng87/axum-test-helper.git", branch = "patch-1" }
Expand Down
45 changes: 38 additions & 7 deletions src/servers/src/mysql/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,17 @@ use std::net::SocketAddr;
use std::sync::Arc;

use async_trait::async_trait;
use common_base::tls::TlsOption;
use common_runtime::Runtime;
use common_telemetry::logging::{error, info};
use futures::StreamExt;
use opensrv_mysql::AsyncMysqlIntermediary;
use opensrv_mysql::{
plain_run_with_options, secure_run_with_options, AsyncMysqlIntermediary, IntermediaryOptions,
};
use tokio;
use tokio::io::BufWriter;
use tokio::net::TcpStream;
use tokio_rustls::rustls::ServerConfig;

use crate::error::Result;
use crate::mysql::handler::MysqlInstanceShim;
Expand All @@ -36,33 +40,40 @@ const DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE: usize = 100 * 1024;
pub struct MysqlServer {
base_server: BaseTcpServer,
query_handler: SqlQueryHandlerRef,
tls: Option<Arc<TlsOption>>,
}

impl MysqlServer {
pub fn create_server(
query_handler: SqlQueryHandlerRef,
io_runtime: Arc<Runtime>,
tls: Option<Arc<TlsOption>>,
) -> Box<dyn Server> {
Box::new(MysqlServer {
base_server: BaseTcpServer::create_server("MySQL", io_runtime),
query_handler,
tls,
})
}

fn accept(
&self,
io_runtime: Arc<Runtime>,
stream: AbortableStream,
tls_conf: Option<Arc<ServerConfig>>,
) -> impl Future<Output = ()> {
let query_handler = self.query_handler.clone();

stream.for_each(move |tcp_stream| {
let io_runtime = io_runtime.clone();
let query_handler = query_handler.clone();
let tls_conf = tls_conf.clone();
async move {
match tcp_stream {
Err(error) => error!("Broken pipe: {}", error), // IoError doesn't impl ErrorExt.
Ok(io_stream) => {
if let Err(error) = Self::handle(io_stream, io_runtime, query_handler).await
if let Err(error) =
Self::handle(io_stream, io_runtime, query_handler, tls_conf).await
{
error!(error; "Unexpected error when handling TcpStream");
};
Expand All @@ -76,15 +87,28 @@ impl MysqlServer {
stream: TcpStream,
io_runtime: Arc<Runtime>,
query_handler: SqlQueryHandlerRef,
tls_conf: Option<Arc<ServerConfig>>,
) -> Result<()> {
info!("MySQL connection coming from: {}", stream.peer_addr()?);
let shim = MysqlInstanceShim::create(query_handler, stream.peer_addr()?.to_string());
let mut shim = MysqlInstanceShim::create(query_handler, stream.peer_addr()?.to_string());

let (r, w) = stream.into_split();
let w = BufWriter::with_capacity(DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE, w);
let (mut r, w) = stream.into_split();
let mut w = BufWriter::with_capacity(DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE, w);
// TODO(LFC): Use `output_stream` to write large MySQL ResultSet to client.
let spawn_result = io_runtime
.spawn(AsyncMysqlIntermediary::run_on(shim, r, w))
.spawn(async move {
let ops = IntermediaryOptions::default();
let (is_ssl, init_params) =
AsyncMysqlIntermediary::init_before_ssl(&mut shim, &mut r, &mut w, &tls_conf)
.await?;

match tls_conf {
Some(tls_conf) if is_ssl => {
secure_run_with_options(shim, w, ops, tls_conf, init_params).await
}
_ => plain_run_with_options(shim, w, ops, init_params).await,
}
})
.await;
match spawn_result {
Ok(run_result) => {
Expand All @@ -110,7 +134,14 @@ impl Server for MysqlServer {
let (stream, addr) = self.base_server.bind(listening).await?;

let io_runtime = self.base_server.io_runtime();
let join_handle = tokio::spawn(self.accept(io_runtime, stream));
let tls_conf = match &self.tls {
Some(tls) => {
let server_conf = tls.setup()?;
Some(Arc::new(server_conf))
}
None => None,
};
let join_handle = tokio::spawn(self.accept(io_runtime, stream, tls_conf));
self.base_server.start_with(join_handle).await?;
Ok(addr)
}
Expand Down
Loading

0 comments on commit d06eec9

Please sign in to comment.