Skip to content

Commit

Permalink
chore: TlsOption is required but supply default value
Browse files Browse the repository at this point in the history
  • Loading branch information
SSebo committed Nov 27, 2022
1 parent 2719eb5 commit 6bed566
Show file tree
Hide file tree
Showing 12 changed files with 120 additions and 31 deletions.
2 changes: 1 addition & 1 deletion .cargo/config.toml
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
[target.aarch64-unknown-linux-gnu]
linker = "aarch64-linux-gnu-gcc"
# linker = "aarch64-linux-gnu-gcc"
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions src/common/base/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,6 @@ serde = { version = "1.0", features = ["derive"] }
snafu = { version = "0.7", features = ["backtraces"] }
rustls = "0.20"
rustls-pemfile = "1.0"

[dev-dependencies]
serde_json = "1.0"
93 changes: 91 additions & 2 deletions src/common/base/src/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,30 @@ use rustls_pemfile::{certs, pkcs8_private_keys};
use serde::{Deserialize, Serialize};

#[derive(Debug, Default, Serialize, Deserialize, Clone)]
#[serde(rename_all = "lowercase")]
pub enum TlsMode {
#[default]
Disabled,
Prefered,
Required,
Verified,
}

#[derive(Debug, Default, Serialize, Deserialize, Clone)]
#[serde(rename_all = "lowercase")]
pub struct TlsOption {
pub mode: TlsMode,
#[serde(default)]
pub cert_path: String,
#[serde(default)]
pub key_path: String,
}

impl TlsOption {
pub fn setup(&self) -> Result<ServerConfig, std::io::Error> {
pub fn setup(&self) -> Result<Option<ServerConfig>, std::io::Error> {
if let TlsMode::Disabled = self.mode {
return Ok(None);
}
let cert = certs(&mut BufReader::new(File::open(&self.key_path)?))
.map_err(|_| std::io::Error::new(ErrorKind::InvalidInput, "invalid cert"))
.map(|mut certs| certs.drain(..).map(Certificate).collect())?;
Expand All @@ -40,6 +57,78 @@ impl TlsOption {
.with_single_cert(cert, key)
.map_err(|err| std::io::Error::new(ErrorKind::InvalidInput, err))?;

Ok(config)
Ok(Some(config))
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_tls_option_disable() {
let s = r#"
{
"mode": "disabled"
}
"#;

let t: TlsOption = serde_json::from_str(s).unwrap();
assert!(matches!(t.mode, TlsMode::Disabled));
assert!(t.key_path.is_empty());
assert!(t.cert_path.is_empty());

let setup = t.setup();
assert!(setup.is_ok());
let setup = setup.unwrap();
assert!(setup.is_none());
}

#[test]
fn test_tls_option_prefered() {
let s = r#"
{
"mode": "prefered",
"cert_path": "/some_dir/some.crt",
"key_path": "/some_dir/some.key"
}
"#;

let t: TlsOption = serde_json::from_str(s).unwrap();
assert!(matches!(t.mode, TlsMode::Prefered));
assert!(!t.key_path.is_empty());
assert!(!t.cert_path.is_empty());
}

#[test]
fn test_tls_option_required() {
let s = r#"
{
"mode": "required",
"cert_path": "/some_dir/some.crt",
"key_path": "/some_dir/some.key"
}
"#;

let t: TlsOption = serde_json::from_str(s).unwrap();
assert!(matches!(t.mode, TlsMode::Required));
assert!(!t.key_path.is_empty());
assert!(!t.cert_path.is_empty());
}

#[test]
fn test_tls_option_verified() {
let s = r#"
{
"mode": "required",
"cert_path": "/some_dir/some.crt",
"key_path": "/some_dir/some.key"
}
"#;

let t: TlsOption = serde_json::from_str(s).unwrap();
assert!(matches!(t.mode, TlsMode::Required));
assert!(!t.key_path.is_empty());
assert!(!t.cert_path.is_empty());
}
}
3 changes: 0 additions & 3 deletions src/datanode/src/datanode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

use std::sync::Arc;

use common_base::tls::TlsOption;
use common_telemetry::info;
use meta_client::MetaClientOpts;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -50,7 +49,6 @@ pub struct DatanodeOptions {
pub storage: ObjectStoreConfig,
pub enable_memory_catalog: bool,
pub mode: Mode,
pub tls: Option<Arc<TlsOption>>,
}

impl Default for DatanodeOptions {
Expand All @@ -66,7 +64,6 @@ impl Default for DatanodeOptions {
storage: ObjectStoreConfig::default(),
enable_memory_catalog: false,
mode: Mode::Standalone,
tls: None,
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/datanode/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ impl Services {
Some(MysqlServer::create_server(
instance.clone(),
mysql_io_runtime,
None,
Default::default(),
))
}
};
Expand Down
5 changes: 3 additions & 2 deletions src/frontend/src/mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,16 @@ use serde::{Deserialize, Serialize};
pub struct MysqlOptions {
pub addr: String,
pub runtime_size: usize,
pub tls: Option<Arc<TlsOption>>,
#[serde(default = "Default::default")]
pub tls: Arc<TlsOption>,
}

impl Default for MysqlOptions {
fn default() -> Self {
Self {
addr: "127.0.0.1:4002".to_string(),
runtime_size: 2,
tls: None,
tls: Arc::new(TlsOption::default()),
}
}
}
5 changes: 3 additions & 2 deletions src/frontend/src/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ pub struct PostgresOptions {
pub addr: String,
pub runtime_size: usize,
pub check_pwd: bool,
pub tls: Option<Arc<TlsOption>>,
#[serde(default = "Default::default")]
pub tls: Arc<TlsOption>,
}

impl Default for PostgresOptions {
Expand All @@ -31,7 +32,7 @@ impl Default for PostgresOptions {
addr: "127.0.0.1:4003".to_string(),
runtime_size: 2,
check_pwd: false,
tls: None,
tls: Default::default(),
}
}
}
14 changes: 5 additions & 9 deletions src/servers/src/mysql/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,14 @@ const DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE: usize = 100 * 1024;
pub struct MysqlServer {
base_server: BaseTcpServer,
query_handler: SqlQueryHandlerRef,
tls: Option<Arc<TlsOption>>,
tls: Arc<TlsOption>,
}

impl MysqlServer {
pub fn create_server(
query_handler: SqlQueryHandlerRef,
io_runtime: Arc<Runtime>,
tls: Option<Arc<TlsOption>>,
tls: Arc<TlsOption>,
) -> Box<dyn Server> {
Box::new(MysqlServer {
base_server: BaseTcpServer::create_server("MySQL", io_runtime),
Expand Down Expand Up @@ -134,13 +134,9 @@ impl Server for MysqlServer {
let (stream, addr) = self.base_server.bind(listening).await?;

let io_runtime = self.base_server.io_runtime();
let tls_conf = match &self.tls {
Some(tls) => {
let server_conf = tls.setup()?;
Some(Arc::new(server_conf))
}
None => None,
};

let tls_conf = self.tls.setup()?.map(Arc::new);

let join_handle = tokio::spawn(self.accept(io_runtime, stream, tls_conf));
self.base_server.start_with(join_handle).await?;
Ok(addr)
Expand Down
15 changes: 6 additions & 9 deletions src/servers/src/postgres/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ pub struct PostgresServer {
base_server: BaseTcpServer,
auth_handler: Arc<PgAuthStartupHandler>,
query_handler: Arc<PostgresServerHandler>,
tls: Option<Arc<TlsOption>>,
tls: Arc<TlsOption>,
}

impl PostgresServer {
/// Creates a new Postgres server with provided query_handler and async runtime
pub fn new(
query_handler: SqlQueryHandlerRef,
check_pwd: bool,
tls: Option<Arc<TlsOption>>,
tls: Arc<TlsOption>,
io_runtime: Arc<Runtime>,
) -> PostgresServer {
let postgres_handler = Arc::new(PostgresServerHandler::new(query_handler));
Expand Down Expand Up @@ -98,13 +98,10 @@ impl Server for PostgresServer {
async fn start(&self, listening: SocketAddr) -> Result<SocketAddr> {
let (stream, addr) = self.base_server.bind(listening).await?;

let tls_acceptor = match &self.tls {
Some(tls) => {
let server_conf = tls.setup()?;
Some(Arc::new(TlsAcceptor::from(Arc::new(server_conf))))
}
None => None,
};
let tls_acceptor = self
.tls
.setup()?
.map(|server_conf| Arc::new(TlsAcceptor::from(Arc::new(server_conf))));

let io_runtime = self.base_server.io_runtime();
let join_handle = tokio::spawn(self.accept(io_runtime, stream, tls_acceptor));
Expand Down
6 changes: 5 additions & 1 deletion src/servers/tests/mysql/mysql_server_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ fn create_mysql_server(table: MemTable) -> Result<Box<dyn Server>> {
.build()
.unwrap(),
);
Ok(MysqlServer::create_server(query_handler, io_runtime, None))
Ok(MysqlServer::create_server(
query_handler,
io_runtime,
Default::default(),
))
}

#[tokio::test]
Expand Down
2 changes: 1 addition & 1 deletion src/servers/tests/postgres/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ fn create_postgres_server(table: MemTable, check_pwd: bool) -> Result<Box<dyn Se
Ok(Box::new(PostgresServer::new(
query_handler,
check_pwd,
None,
Default::default(),
io_runtime,
)))
}
Expand Down

0 comments on commit 6bed566

Please sign in to comment.