Skip to content

Commit

Permalink
Postgres TLS Support (#403)
Browse files Browse the repository at this point in the history
* Implement TlsStream for Socket

* starttls trait

* Implement TlsConnection in workers-rs

* Update example

* fmt

* docs and fix example
  • Loading branch information
kflansburg authored Nov 8, 2023
1 parent 1d30714 commit 92a1127
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 10 deletions.
28 changes: 22 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions examples/tokio-postgres/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
name = "tokio-postgres-on-workers"
version = "0.1.0"
edition = "2021"
resolver = "2"

# https://github.com/rustwasm/wasm-pack/issues/1247
[package.metadata.wasm-pack.profile.release]
Expand All @@ -11,5 +12,5 @@ wasm-opt = false
crate-type = ["cdylib"]

[dependencies]
worker = { workspace=true }
tokio-postgres = { git="https://github.com/sfackler/rust-postgres", branch="master", features=['js'], default-features=false }
worker = { workspace=true, features=["tokio-postgres"] }
tokio-postgres = { version="0.7", features=['js'], default-features=false }
7 changes: 5 additions & 2 deletions examples/tokio-postgres/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use worker::postgres_tls::PassthroughTls;
use worker::*;

#[event(fetch)]
Expand All @@ -7,9 +8,11 @@ async fn main(_req: Request, _env: Env, _ctx: Context) -> Result<Response> {
config.user("postgres");

// Connect using Worker Socket
let socket = Socket::builder().connect("database_url", 5432)?;
let socket = Socket::builder()
.secure_transport(SecureTransport::StartTls)
.connect("database_url", 5432)?;
let (_client, connection) = config
.connect_raw(socket, tokio_postgres::tls::NoTls)
.connect_raw(socket, PassthroughTls)
.await
.map_err(|e| worker::Error::RustError(format!("tokio-postgres: {:?}", e)))?;

Expand Down
6 changes: 6 additions & 0 deletions worker/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ features = [
"WritableStreamDefaultWriter"
]

[dependencies.tokio-postgres]
version = "0.7"
default-features=false
features = ["js"]
optional = true

[features]
queue = ["worker-macros/queue", "worker-sys/queue"]
d1 = ["worker-sys/d1"]
54 changes: 54 additions & 0 deletions worker/src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,60 @@ fn handle_data(buf: &mut ReadBuf<'_>, mut data: Vec<u8>) -> (Reading, Poll<IoRes
}
}

#[cfg(feature = "tokio-postgres")]
/// Implements [`TlsConnect`](tokio_postgres::TlsConnect) for
/// [`Socket`](crate::Socket) to enable `tokio_postgres` connections
/// to databases using TLS.
pub mod postgres_tls {
use super::Socket;
use futures_util::future::{ready, Ready};
use std::error::Error;
use std::fmt::{self, Display, Formatter};
use tokio_postgres::tls::{ChannelBinding, TlsConnect, TlsStream};

/// Supply this to `connect_raw` in place of `NoTls` to specify TLS
/// when using Workers.
///
/// ```rust
/// let config = tokio_postgres::config::Config::new();
/// let socket = Socket::builder()
/// .secure_transport(SecureTransport::StartTls)
/// .connect("database_url", 5432)?;
/// let _ = config.connect_raw(socket, PassthroughTls).await?;
/// ```
pub struct PassthroughTls;

#[derive(Debug)]
/// Error type for PassthroughTls.
/// Should never be returned.
pub struct PassthroughTlsError;

impl Error for PassthroughTlsError {}

impl Display for PassthroughTlsError {
fn fmt(&self, fmt: &mut Formatter<'_>) -> fmt::Result {
fmt.write_str("PassthroughTlsError")
}
}

impl TlsConnect<Socket> for PassthroughTls {
type Stream = Socket;
type Error = PassthroughTlsError;
type Future = Ready<Result<Socket, PassthroughTlsError>>;

fn connect(self, s: Self::Stream) -> Self::Future {
let tls = s.start_tls();
ready(Ok(tls))
}
}

impl TlsStream for Socket {
fn channel_binding(&self) -> ChannelBinding {
ChannelBinding::none()
}
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down

0 comments on commit 92a1127

Please sign in to comment.