Skip to content

Commit

Permalink
Support manually in-place authentication token update (#65)
Browse files Browse the repository at this point in the history
  • Loading branch information
EAimTY authored May 29, 2024
1 parent d2fa07a commit 1e1251a
Show file tree
Hide file tree
Showing 11 changed files with 71 additions and 41 deletions.
12 changes: 5 additions & 7 deletions src/auth.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
//! Authentication service.

use http::{header::AUTHORIZATION, HeaderValue, Request};
use std::sync::Arc;
use std::sync::{Arc, RwLock};
use std::task::{Context, Poll};
use tower_service::Service;

#[derive(Debug, Clone)]
pub struct AuthService<S> {
inner: S,
token: Option<Arc<HeaderValue>>,
token: Arc<RwLock<Option<HeaderValue>>>,
}

impl<S> AuthService<S> {
#[inline]
pub fn new(inner: S, token: Option<Arc<HeaderValue>>) -> Self {
pub fn new(inner: S, token: Arc<RwLock<Option<HeaderValue>>>) -> Self {
Self { inner, token }
}
}
Expand All @@ -33,10 +33,8 @@ where

#[inline]
fn call(&mut self, mut request: Request<Body>) -> Self::Future {
if let Some(token) = &self.token {
request
.headers_mut()
.insert(AUTHORIZATION, token.as_ref().clone());
if let Some(token) = self.token.read().unwrap().as_ref() {
request.headers_mut().insert(AUTHORIZATION, token.clone());
}

self.inner.call(request)
Expand Down
33 changes: 24 additions & 9 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,10 @@ use crate::OpenSslResult;
#[cfg(feature = "tls")]
use crate::TlsOptions;
use http::uri::Uri;
use http::HeaderValue;

use std::str::FromStr;
use std::sync::Arc;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use tokio::sync::mpsc::Sender;

Expand Down Expand Up @@ -104,7 +105,10 @@ impl Client {
}

let mut options = options;
let auth_token = Self::auth(channel.clone(), &mut options).await?;

let auth_token = Arc::new(RwLock::new(None));
Self::auth(channel.clone(), &mut options, &auth_token).await?;

Ok(Self::build_client(channel, tx, auth_token, options))
}

Expand Down Expand Up @@ -210,28 +214,29 @@ impl Client {
async fn auth(
channel: Channel,
options: &mut Option<ConnectOptions>,
) -> Result<Option<Arc<http::HeaderValue>>> {
auth_token: &Arc<RwLock<Option<HeaderValue>>>,
) -> Result<()> {
let user = match options {
None => return Ok(None),
None => return Ok(()),
Some(opt) => {
// Take away the user, the password should not be stored in client.
opt.user.take()
}
};

if let Some((name, password)) = user {
let mut tmp_auth = AuthClient::new(channel, None);
let mut tmp_auth = AuthClient::new(channel, auth_token.clone());
let resp = tmp_auth.authenticate(name, password).await?;
Ok(Some(Arc::new(resp.token().parse()?)))
} else {
Ok(None)
auth_token.write().unwrap().replace(resp.token().parse()?);
}

Ok(())
}

fn build_client(
channel: Channel,
tx: Sender<Change<Uri, Endpoint>>,
auth_token: Option<Arc<http::HeaderValue>>,
auth_token: Arc<RwLock<Option<HeaderValue>>>,
options: Option<ConnectOptions>,
) -> Self {
let kv = KvClient::new(channel.clone(), auth_token.clone());
Expand Down Expand Up @@ -730,6 +735,16 @@ impl Client {
pub async fn resign(&mut self, option: Option<ResignOptions>) -> Result<ResignResponse> {
self.election.resign(option).await
}

/// Sets client-side authentication.
pub async fn set_client_auth(&mut self, name: String, password: String) -> Result<()> {
self.auth.set_client_auth(name, password).await
}

/// Removes client-side authentication.
pub fn remove_client_auth(&mut self) {
self.auth.remove_client_auth();
}
}

/// Options for `Connect` operation.
Expand Down
9 changes: 3 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,9 @@
//!
//! # Feature Flags
//!
//! - `tls`: Enables the `rustls`-based TLS connection. Not
//! enabled by default.
//! - `tls-roots`: Adds system trust roots to `rustls`-based TLS connection using the
//! `rustls-native-certs` crate. Not enabled by default.
//! - `pub-response-field`: Exposes structs used to create regular `etcd-client` responses
//! including internal protobuf representations. Useful for mocking. Not enabled by default.
//! - `tls`: Enables the `rustls`-based TLS connection. Not enabled by default.
//! - `tls-roots`: Adds system trust roots to `rustls`-based TLS connection using the `rustls-native-certs` crate. Not enabled by default.
//! - `pub-response-field`: Exposes structs used to create regular `etcd-client` responses including internal protobuf representations. Useful for mocking. Not enabled by default.

#![cfg_attr(docsrs, feature(doc_cfg))]

Expand Down
32 changes: 25 additions & 7 deletions src/rpc/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,25 +35,41 @@ use crate::rpc::pb::etcdserverpb::{
use crate::rpc::ResponseHeader;
use crate::rpc::{get_prefix, KeyRange};
use http::HeaderValue;
use std::sync::RwLock;
use std::{string::String, sync::Arc};
use tonic::{IntoRequest, Request};

/// Client for Auth operations.
#[repr(transparent)]
#[derive(Clone)]
pub struct AuthClient {
inner: PbAuthClient<AuthService<Channel>>,
auth_token: Arc<RwLock<Option<HeaderValue>>>,
}

impl AuthClient {
/// Creates an auth client.
#[inline]
pub(crate) fn new(channel: Channel, auth_token: Option<Arc<HeaderValue>>) -> Self {
let inner = PbAuthClient::new(AuthService::new(channel, auth_token));
Self { inner }
pub(crate) fn new(channel: Channel, auth_token: Arc<RwLock<Option<HeaderValue>>>) -> Self {
let inner = PbAuthClient::new(AuthService::new(channel, auth_token.clone()));
Self { inner, auth_token }
}

/// Sets client-side authentication.
pub async fn set_client_auth(&mut self, name: String, password: String) -> Result<()> {
let resp = self.authenticate(name, password).await?;
self.auth_token
.write()
.unwrap()
.replace(resp.token().parse()?);
Ok(())
}

/// Removes client-side authentication.
pub fn remove_client_auth(&mut self) {
self.auth_token.write().unwrap().take();
}

/// Enables authentication.
/// Enables authentication for the etcd cluster.
#[inline]
pub async fn auth_enable(&mut self) -> Result<AuthEnableResponse> {
let resp = self
Expand All @@ -64,7 +80,7 @@ impl AuthClient {
Ok(AuthEnableResponse::new(resp))
}

/// Disables authentication.
/// Disables authentication for the etcd cluster.
#[inline]
pub async fn auth_disable(&mut self) -> Result<AuthDisableResponse> {
let resp = self
Expand All @@ -75,7 +91,9 @@ impl AuthClient {
Ok(AuthDisableResponse::new(resp))
}

/// Processes an authenticate request.
/// Sends an authenticate request.
/// Note that this does not set or update client-side authentication settings.
/// Call [`set_client_auth`] to set or update client-side authentication.
#[inline]
pub async fn authenticate(
&mut self,
Expand Down
3 changes: 2 additions & 1 deletion src/rpc/cluster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use crate::rpc::pb::etcdserverpb::{
};
use crate::rpc::ResponseHeader;
use http::HeaderValue;
use std::sync::RwLock;
use std::{string::String, sync::Arc};
use tonic::{IntoRequest, Request};

Expand All @@ -27,7 +28,7 @@ pub struct ClusterClient {
impl ClusterClient {
/// Creates an Cluster client.
#[inline]
pub(crate) fn new(channel: Channel, auth_token: Option<Arc<HeaderValue>>) -> Self {
pub(crate) fn new(channel: Channel, auth_token: Arc<RwLock<Option<HeaderValue>>>) -> Self {
let inner = PbClusterClient::new(AuthService::new(channel, auth_token));
Self { inner }
}
Expand Down
3 changes: 2 additions & 1 deletion src/rpc/election.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use crate::rpc::pb::v3electionpb::{
};
use crate::rpc::{KeyValue, ResponseHeader};
use http::HeaderValue;
use std::sync::RwLock;
use std::task::{Context, Poll};
use std::{pin::Pin, sync::Arc};
use tokio_stream::Stream;
Expand Down Expand Up @@ -486,7 +487,7 @@ impl From<&PbLeaderKey> for &LeaderKey {
impl ElectionClient {
/// Creates a election
#[inline]
pub(crate) fn new(channel: Channel, auth_token: Option<Arc<HeaderValue>>) -> Self {
pub(crate) fn new(channel: Channel, auth_token: Arc<RwLock<Option<HeaderValue>>>) -> Self {
let inner = PbElectionClient::new(AuthService::new(channel, auth_token));
Self { inner }
}
Expand Down
4 changes: 2 additions & 2 deletions src/rpc/kv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use crate::rpc::{get_prefix, KeyRange, KeyValue, ResponseHeader};
use crate::vec::VecExt;
use http::HeaderValue;
use std::mem::ManuallyDrop;
use std::sync::Arc;
use std::sync::{Arc, RwLock};
use tonic::{IntoRequest, Request};

/// Client for KV operations.
Expand All @@ -35,7 +35,7 @@ pub struct KvClient {
impl KvClient {
/// Creates a kv client.
#[inline]
pub(crate) fn new(channel: Channel, auth_token: Option<Arc<HeaderValue>>) -> Self {
pub(crate) fn new(channel: Channel, auth_token: Arc<RwLock<Option<HeaderValue>>>) -> Self {
let inner = PbKvClient::new(AuthService::new(channel, auth_token));
Self { inner }
}
Expand Down
4 changes: 2 additions & 2 deletions src/rpc/lease.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::vec::VecExt;
use crate::Error;
use http::HeaderValue;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::{Arc, RwLock};
use std::task::{Context, Poll};
use tokio::sync::mpsc::{channel, Sender};
use tokio_stream::wrappers::ReceiverStream;
Expand All @@ -35,7 +35,7 @@ pub struct LeaseClient {
impl LeaseClient {
/// Creates a `LeaseClient`.
#[inline]
pub(crate) fn new(channel: Channel, auth_token: Option<Arc<HeaderValue>>) -> Self {
pub(crate) fn new(channel: Channel, auth_token: Arc<RwLock<Option<HeaderValue>>>) -> Self {
let inner = PbLeaseClient::new(AuthService::new(channel, auth_token));
Self { inner }
}
Expand Down
4 changes: 2 additions & 2 deletions src/rpc/lock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::channel::Channel;
use crate::error::Result;
use crate::rpc::ResponseHeader;
use http::HeaderValue;
use std::sync::Arc;
use std::sync::{Arc, RwLock};
use tonic::{IntoRequest, Request};
use v3lockpb::lock_client::LockClient as PbLockClient;
use v3lockpb::{
Expand All @@ -24,7 +24,7 @@ pub struct LockClient {
impl LockClient {
/// Creates a lock client.
#[inline]
pub(crate) fn new(channel: Channel, auth_token: Option<Arc<HeaderValue>>) -> Self {
pub(crate) fn new(channel: Channel, auth_token: Arc<RwLock<Option<HeaderValue>>>) -> Self {
let inner = PbLockClient::new(AuthService::new(channel, auth_token));
Self { inner }
}
Expand Down
4 changes: 2 additions & 2 deletions src/rpc/maintenance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use crate::rpc::ResponseHeader;
use etcdserverpb::maintenance_client::MaintenanceClient as PbMaintenanceClient;
use etcdserverpb::AlarmMember as PbAlarmMember;
use http::HeaderValue;
use std::sync::Arc;
use std::sync::{Arc, RwLock};
use tonic::codec::Streaming as PbStreaming;
use tonic::{IntoRequest, Request};

Expand Down Expand Up @@ -556,7 +556,7 @@ impl MoveLeaderResponse {
impl MaintenanceClient {
/// Creates a maintenance client.
#[inline]
pub(crate) fn new(channel: Channel, auth_token: Option<Arc<HeaderValue>>) -> Self {
pub(crate) fn new(channel: Channel, auth_token: Arc<RwLock<Option<HeaderValue>>>) -> Self {
let inner = PbMaintenanceClient::new(AuthService::new(channel, auth_token));
Self { inner }
}
Expand Down
4 changes: 2 additions & 2 deletions src/rpc/watch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::rpc::pb::mvccpb::Event as PbEvent;
use crate::rpc::{KeyRange, KeyValue, ResponseHeader};
use http::HeaderValue;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::{Arc, RwLock};
use std::task::{Context, Poll};
use tokio::sync::mpsc::{channel, Sender};
use tokio_stream::{wrappers::ReceiverStream, Stream};
Expand All @@ -31,7 +31,7 @@ pub struct WatchClient {
impl WatchClient {
/// Creates a watch client.
#[inline]
pub(crate) fn new(channel: Channel, auth_token: Option<Arc<HeaderValue>>) -> Self {
pub(crate) fn new(channel: Channel, auth_token: Arc<RwLock<Option<HeaderValue>>>) -> Self {
let inner = PbWatchClient::new(AuthService::new(channel, auth_token));
Self { inner }
}
Expand Down

0 comments on commit 1e1251a

Please sign in to comment.