Skip to content

Commit

Permalink
Revert "jwt_authn: Add logic to refetch JWT on KID mismatch (#36458)" (
Browse files Browse the repository at this point in the history
  • Loading branch information
wbpcode authored Dec 19, 2024
1 parent 857107b commit b9c4ff2
Show file tree
Hide file tree
Showing 18 changed files with 65 additions and 704 deletions.
19 changes: 0 additions & 19 deletions api/envoy/extensions/filters/http/jwt_authn/v3/config.proto
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,6 @@ message JwtCacheConfig {
}

// This message specifies how to fetch JWKS from remote and how to cache it.
// [#next-free-field: 6]
message RemoteJwks {
option (udpa.annotations.versioning).previous_message_type =
"envoy.config.filter.http.jwt_authn.v2alpha.RemoteJwks";
Expand Down Expand Up @@ -453,24 +452,6 @@ message RemoteJwks {
//
//
config.core.v3.RetryPolicy retry_policy = 4;

// Refetch JWKS if extracted JWT has no KID or a KID that does not match any cached JWKS's KID.
//
//
// In envoy, if :ref:`async JWKS fetching <envoy_v3_api_field_extensions.filters.http.jwt_authn.v3.RemoteJwks.async_fetch>`
// is enabled along with this field, then KID mismatch will trigger a new async fetch after appropriate backoff delay.
//
//
// If async fetching is disabled, new JWKS is fetched on demand and the cache is isolated to the fetched worker thread.
//
// There is exponential backoff built into this retrieval system for two cases to avoid DoS on JWKS Server:
//
// * If there is a request containing a JWT with no KID, a new fetch will be made for this request. Upon retrieval,
// a backoff will be triggered.
// * If there is a fetch due to KID mismatch, which results in a failed fetch or verification, a backoff will be triggered.
//
// During a backoff, no further fetches will be made due to KID mismatch.
bool refetch_jwks_on_kid_mismatch = 5;
}

// Fetch Jwks asynchronously in the main thread when the filter config is parsed.
Expand Down
5 changes: 0 additions & 5 deletions changelogs/current.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -342,11 +342,6 @@ new_features:
change: |
Added :ref:`attribute <arch_overview_attributes>` ``upstream.cx_pool_ready_duration``
to get the duration from when the upstream request was created to when the upstream connection pool is ready.
- area: jwt_authn
change: |
Added :ref:`refetch_jwks_on_kid_mismatch
<envoy_v3_api_field_extensions.filters.http.jwt_authn.v3.RemoteJwks.refetch_jwks_on_kid_mismatch>`
to allow filter to refetch JWKS when extracted JWT's KID does not match cached JWKS's KID.
- area: health_check
change: |
Added new health check filter stats including total requests, successful/failed checks, cached responses, and
Expand Down
1 change: 0 additions & 1 deletion source/extensions/filters/http/jwt_authn/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ envoy_cc_library(
"jwks_async_fetcher_lib",
":jwt_cache_lib",
"//source/common/config:datasource_lib",
"//source/common/config:utility_lib",
"@com_github_google_jwt_verify//:jwt_verify_lib",
"@envoy_api//envoy/extensions/filters/http/jwt_authn/v3:pkg_cc_proto",
],
Expand Down
69 changes: 15 additions & 54 deletions source/extensions/filters/http/jwt_authn/authenticator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,11 @@ class AuthenticatorImpl : public Logger::Loggable<Logger::Id::jwt>,
const absl::optional<std::string>& provider, bool allow_failed,
bool allow_missing, JwksCache& jwks_cache,
Upstream::ClusterManager& cluster_manager,
CreateJwksFetcherCb create_jwks_fetcher_cb, TimeSource& time_source,
Event::Dispatcher& dispatcher)
CreateJwksFetcherCb create_jwks_fetcher_cb, TimeSource& time_source)
: jwks_cache_(jwks_cache), cm_(cluster_manager),
create_jwks_fetcher_cb_(create_jwks_fetcher_cb), check_audience_(check_audience),
provider_(provider), is_allow_failed_(allow_failed), is_allow_missing_(allow_missing),
time_source_(time_source), dispatcher_(dispatcher) {}
time_source_(time_source) {}
// Following functions are for JwksFetcher::JwksReceiver interface
void onJwksSuccess(google::jwt_verify::JwksPtr&& jwks) override;
void onJwksError(Failure reason) override;
Expand Down Expand Up @@ -131,9 +130,6 @@ class AuthenticatorImpl : public Logger::Loggable<Logger::Id::jwt>,
const bool is_allow_missing_;
TimeSource& time_source_;
::google::jwt_verify::Jwt* jwt_{};
Event::Dispatcher& dispatcher_;
// Set to true if the JWT in this request caused a KID mismatch.
bool kid_mismatch_request_{false};
};

std::string AuthenticatorImpl::name() const {
Expand Down Expand Up @@ -272,24 +268,16 @@ void AuthenticatorImpl::startVerify() {

auto jwks_obj = jwks_data_->getJwksObj();
if (jwks_obj != nullptr && !jwks_data_->isExpired()) {
// If KID mismatch fetching is set and fetch is disallowed, trigger the
// verification process.
if (jwks_data_->getJwtProvider().remote_jwks().refetch_jwks_on_kid_mismatch() &&
jwks_data_->isRemoteJwksFetchAllowed()) {
if (!jwt_->kid_.empty()) {
for (const auto& jwk : jwks_obj->keys()) {
if (jwk->kid_ == jwt_->kid_) {
verifyKey();
return;
}
}
}
ENVOY_LOG(info, "Triggering refetch of JWKS due to KID mismatch");
kid_mismatch_request_ = true;
} else {
verifyKey();
return;
}
// TODO(qiwzhang): It would seem there's a window of error whereby if the JWT issuer
// has started signing with a new key that's not in our cache, then the
// verification will fail even though the JWT is valid. A simple fix
// would be to check the JWS kid header field; if present check we have
// the key cached, if we do proceed to verify else try a new JWKS retrieval.
// JWTs without a kid header field in the JWS we might be best to get each
// time? This all only matters for remote JWKS.

verifyKey();
return;
}

// TODO(potatop): potential optimization.
Expand All @@ -302,10 +290,6 @@ void AuthenticatorImpl::startVerify() {
fetcher_ = create_jwks_fetcher_cb_(cm_, jwks_data_->getJwtProvider().remote_jwks());
}
fetcher_->fetch(*parent_span_, *this);
// Disallow fetches across other worker threads due to in-flight fetch.
if (kid_mismatch_request_) {
dispatcher_.post([this]() { jwks_data_->allowRemoteJwksFetch(absl::nullopt, true); });
}
return;
}
// No valid keys for this issuer. This may happen as a result of incorrect local
Expand All @@ -315,14 +299,7 @@ void AuthenticatorImpl::startVerify() {

void AuthenticatorImpl::onJwksSuccess(google::jwt_verify::JwksPtr&& jwks) {
jwks_cache_.stats().jwks_fetch_success_.inc();

const Status status =
jwks_data_
->setRemoteJwks(std::move(jwks),
jwks_data_->getJwtProvider().remote_jwks().has_async_fetch() &&
kid_mismatch_request_)
->getStatus();

const Status status = jwks_data_->setRemoteJwks(std::move(jwks))->getStatus();
if (status != Status::Ok) {
doneWithStatus(status);
} else {
Expand Down Expand Up @@ -442,15 +419,6 @@ void AuthenticatorImpl::handleGoodJwt(bool cache_hit) {
// move the ownership of "owned_jwt_" into the function.
jwks_data_->getJwtCache().insert(curr_token_->token(), std::move(owned_jwt_));
}

// On successful retrieval & verification of JWKS due to KID mismatch,
// - If retrieved due to KID being empty, trigger backoff.
// - Else, shut down backoff.
if (kid_mismatch_request_ && !jwks_data_->isRemoteJwksFetchAllowed()) {
dispatcher_.post([this]() { jwks_data_->allowRemoteJwksFetch(!jwt_->kid_.empty(), false); });
kid_mismatch_request_ = false;
}

doneWithStatus(Status::Ok);
}

Expand Down Expand Up @@ -483,13 +451,6 @@ void AuthenticatorImpl::doneWithStatus(const Status& status) {
// Forward the failed status to dynamic metadata
ENVOY_LOG(debug, "status is: {}", ::google::jwt_verify::getStatusString(status));

// Trigger backoff to disallow further fetches when retrieval & verification is due to KID
// mismatch and fails.
if (kid_mismatch_request_ && !jwks_data_->isRemoteJwksFetchAllowed()) {
dispatcher_.post([this]() { jwks_data_->allowRemoteJwksFetch(false, false); });
kid_mismatch_request_ = false;
}

std::string failed_status_in_metadata;

if (jwks_data_) {
Expand Down Expand Up @@ -546,10 +507,10 @@ AuthenticatorPtr Authenticator::create(const CheckAudience* check_audience,
bool allow_failed, bool allow_missing, JwksCache& jwks_cache,
Upstream::ClusterManager& cluster_manager,
CreateJwksFetcherCb create_jwks_fetcher_cb,
TimeSource& time_source, Event::Dispatcher& dispatcher) {
TimeSource& time_source) {
return std::make_unique<AuthenticatorImpl>(check_audience, provider, allow_failed, allow_missing,
jwks_cache, cluster_manager, create_jwks_fetcher_cb,
time_source, dispatcher);
time_source);
}

} // namespace JwtAuthn
Expand Down
2 changes: 1 addition & 1 deletion source/extensions/filters/http/jwt_authn/authenticator.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class Authenticator {
bool allow_missing, JwksCache& jwks_cache,
Upstream::ClusterManager& cluster_manager,
CreateJwksFetcherCb create_jwks_fetcher_cb,
TimeSource& time_source, Event::Dispatcher& dispatcher);
TimeSource& time_source);
};

/**
Expand Down
3 changes: 1 addition & 2 deletions source/extensions/filters/http/jwt_authn/filter_config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ FilterConfigImpl::FilterConfigImpl(
const std::string& stats_prefix, Server::Configuration::FactoryContext& context)
: proto_config_(std::move(proto_config)), stats_(generateStats(stats_prefix, context.scope())),
cm_(context.serverFactoryContext().clusterManager()),
time_source_(context.serverFactoryContext().mainThreadDispatcher().timeSource()),
dispatcher_(context.serverFactoryContext().mainThreadDispatcher()) {
time_source_(context.serverFactoryContext().mainThreadDispatcher().timeSource()) {

ENVOY_LOG(debug, "Loaded JwtAuthConfig: {}", proto_config_.DebugString());

Expand Down
6 changes: 1 addition & 5 deletions source/extensions/filters/http/jwt_authn/filter_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,6 @@ class FilterConfigImpl : public Logger::Loggable<Logger::Id::jwt>,
Upstream::ClusterManager& cm() const { return cm_; }
TimeSource& timeSource() const { return time_source_; }

Event::Dispatcher& dispatcher() const { return dispatcher_; }

// FilterConfig

JwtAuthnFilterStats& stats() override { return stats_; }
Expand Down Expand Up @@ -114,8 +112,7 @@ class FilterConfigImpl : public Logger::Loggable<Logger::Id::jwt>,
const absl::optional<std::string>& provider, bool allow_failed,
bool allow_missing) const override {
return Authenticator::create(check_audience, provider, allow_failed, allow_missing,
getJwksCache(), cm(), Common::JwksFetcher::create, timeSource(),
dispatcher());
getJwksCache(), cm(), Common::JwksFetcher::create, timeSource());
}

private:
Expand Down Expand Up @@ -150,7 +147,6 @@ class FilterConfigImpl : public Logger::Loggable<Logger::Id::jwt>,
// all requirement_names for debug
std::string all_requirement_names_;
TimeSource& time_source_;
Event::Dispatcher& dispatcher_;
};

} // namespace JwtAuthn
Expand Down
28 changes: 3 additions & 25 deletions source/extensions/filters/http/jwt_authn/jwks_async_fetcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,9 @@ std::chrono::milliseconds getFailedRefetchDuration(const JwksAsyncFetch& async_f
JwksAsyncFetcher::JwksAsyncFetcher(const RemoteJwks& remote_jwks,
Server::Configuration::FactoryContext& context,
CreateJwksFetcherCb create_fetcher_fn,
JwtAuthnFilterStats& stats, JwksDoneFetched done_fn,
isRemoteJwksFetchAllowedCb is_fetch_allowed_fn,
allowRemoteJwksFetchCb allow_fetch_fn)
JwtAuthnFilterStats& stats, JwksDoneFetched done_fn)
: remote_jwks_(remote_jwks), context_(context), create_fetcher_fn_(create_fetcher_fn),
stats_(stats), done_fn_(done_fn), is_fetch_allowed_fn_(is_fetch_allowed_fn),
allow_fetch_fn_(allow_fetch_fn),
stats_(stats), done_fn_(done_fn),
debug_name_(absl::StrCat("Jwks async fetching url=", remote_jwks_.http_uri().uri())) {
// if async_fetch is not enabled, do nothing.
if (!remote_jwks_.has_async_fetch()) {
Expand Down Expand Up @@ -81,25 +78,14 @@ std::chrono::seconds JwksAsyncFetcher::getCacheDuration(const RemoteJwks& remote
return DefaultCacheExpirationSec;
}

void JwksAsyncFetcher::resetFetchTimer() { refetch_timer_->enableTimer(good_refetch_duration_); }

void JwksAsyncFetcher::fetch() {
if (remote_jwks_.refetch_jwks_on_kid_mismatch() && !is_fetch_allowed_fn_()) {
resetFetchTimer();
return;
}

if (fetcher_) {
fetcher_->cancel();
}

ENVOY_LOG(debug, "{}: started", debug_name_);
fetcher_ = create_fetcher_fn_(context_.serverFactoryContext().clusterManager(), remote_jwks_);
fetcher_->fetch(Tracing::NullSpan::instance(), *this);

if (remote_jwks_.refetch_jwks_on_kid_mismatch()) {
allow_fetch_fn_(absl::nullopt, true);
}
}

void JwksAsyncFetcher::handleFetchDone() {
Expand All @@ -111,12 +97,8 @@ void JwksAsyncFetcher::handleFetchDone() {

void JwksAsyncFetcher::onJwksSuccess(google::jwt_verify::JwksPtr&& jwks) {
done_fn_(std::move(jwks));
if (remote_jwks_.refetch_jwks_on_kid_mismatch()) {
// Don't modify backoff for async fetch.
allow_fetch_fn_(absl::nullopt, false);
}
handleFetchDone();
resetFetchTimer();
refetch_timer_->enableTimer(good_refetch_duration_);
stats_.jwks_fetch_success_.inc();

// Note: not to free fetcher_ within onJwksSuccess or onJwksError function.
Expand All @@ -131,10 +113,6 @@ void JwksAsyncFetcher::onJwksSuccess(google::jwt_verify::JwksPtr&& jwks) {

void JwksAsyncFetcher::onJwksError(Failure) {
ENVOY_LOG(warn, "{}: failed", debug_name_);
if (remote_jwks_.refetch_jwks_on_kid_mismatch()) {
// Don't modify backoff for async fetch.
allow_fetch_fn_(absl::nullopt, false);
}
handleFetchDone();
refetch_timer_->enableTimer(failed_refetch_duration_);
stats_.jwks_fetch_failed_.inc();
Expand Down
15 changes: 1 addition & 14 deletions source/extensions/filters/http/jwt_authn/jwks_async_fetcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@ using CreateJwksFetcherCb = std::function<Common::JwksFetcherPtr(
*/
using JwksDoneFetched = std::function<void(google::jwt_verify::JwksPtr&& jwks)>;

using isRemoteJwksFetchAllowedCb = std::function<bool(void)>;

using allowRemoteJwksFetchCb = std::function<void(absl::optional<bool>, bool)>;

// This class handles fetching Jwks asynchronously.
// It will be no-op if async_fetch is not enabled.
// At its constructor, it will start to fetch Jwks, register with init_manager if not fast_listener.
Expand All @@ -39,17 +35,12 @@ class JwksAsyncFetcher : public Logger::Loggable<Logger::Id::jwt>,
public:
JwksAsyncFetcher(const envoy::extensions::filters::http::jwt_authn::v3::RemoteJwks& remote_jwks,
Server::Configuration::FactoryContext& context, CreateJwksFetcherCb fetcher_fn,
JwtAuthnFilterStats& stats, JwksDoneFetched done_fn,
isRemoteJwksFetchAllowedCb is_fetch_allowed_fn,
allowRemoteJwksFetchCb allow_fetch_fn);
JwtAuthnFilterStats& stats, JwksDoneFetched done_fn);

// Get the remote Jwks cache duration.
static std::chrono::seconds
getCacheDuration(const envoy::extensions::filters::http::jwt_authn::v3::RemoteJwks& remote_jwks);

// Reset async fetch timer to default value.
void resetFetchTimer();

private:
// Fetch the Jwks
void fetch();
Expand Down Expand Up @@ -84,10 +75,6 @@ class JwksAsyncFetcher : public Logger::Loggable<Logger::Id::jwt>,
// The init target.
std::unique_ptr<Init::TargetImpl> init_target_;

const isRemoteJwksFetchAllowedCb is_fetch_allowed_fn_;

const allowRemoteJwksFetchCb allow_fetch_fn_;

// Used in logs.
const std::string debug_name_;
};
Expand Down
Loading

0 comments on commit b9c4ff2

Please sign in to comment.