Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "jwt_authn: Add logic to refetch JWT on KID mismatch (#36458)" #37763

Merged
merged 1 commit into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 0 additions & 19 deletions api/envoy/extensions/filters/http/jwt_authn/v3/config.proto
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,6 @@ message JwtCacheConfig {
}

// This message specifies how to fetch JWKS from remote and how to cache it.
// [#next-free-field: 6]
message RemoteJwks {
option (udpa.annotations.versioning).previous_message_type =
"envoy.config.filter.http.jwt_authn.v2alpha.RemoteJwks";
Expand Down Expand Up @@ -453,24 +452,6 @@ message RemoteJwks {
//
//
config.core.v3.RetryPolicy retry_policy = 4;

// Refetch JWKS if extracted JWT has no KID or a KID that does not match any cached JWKS's KID.
//
//
// In Envoy, if :ref:`async JWKS fetching <envoy_v3_api_field_extensions.filters.http.jwt_authn.v3.RemoteJwks.async_fetch>`
// is enabled along with this field, then a KID mismatch will trigger a new async fetch after an appropriate backoff delay.
//
//
// If async fetching is disabled, a new JWKS is fetched on demand and the cache is isolated to the worker thread that performed the fetch.
//
// There is exponential backoff built into this retrieval system for two cases to avoid DoS on JWKS Server:
//
// * If there is a request containing a JWT with no KID, a new fetch will be made for this request. Upon retrieval,
// a backoff will be triggered.
// * If there is a fetch due to KID mismatch, which results in a failed fetch or verification, a backoff will be triggered.
//
// During a backoff, no further fetches will be made due to KID mismatch.
bool refetch_jwks_on_kid_mismatch = 5;
}

// Fetch Jwks asynchronously in the main thread when the filter config is parsed.
Expand Down
5 changes: 0 additions & 5 deletions changelogs/current.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -342,11 +342,6 @@ new_features:
change: |
Added :ref:`attribute <arch_overview_attributes>` ``upstream.cx_pool_ready_duration``
to get the duration from when the upstream request was created to when the upstream connection pool is ready.
- area: jwt_authn
change: |
Added :ref:`refetch_jwks_on_kid_mismatch
<envoy_v3_api_field_extensions.filters.http.jwt_authn.v3.RemoteJwks.refetch_jwks_on_kid_mismatch>`
to allow filter to refetch JWKS when extracted JWT's KID does not match cached JWKS's KID.
- area: health_check
change: |
Added new health check filter stats including total requests, successful/failed checks, cached responses, and
Expand Down
1 change: 0 additions & 1 deletion source/extensions/filters/http/jwt_authn/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ envoy_cc_library(
"jwks_async_fetcher_lib",
":jwt_cache_lib",
"//source/common/config:datasource_lib",
"//source/common/config:utility_lib",
"@com_github_google_jwt_verify//:jwt_verify_lib",
"@envoy_api//envoy/extensions/filters/http/jwt_authn/v3:pkg_cc_proto",
],
Expand Down
69 changes: 15 additions & 54 deletions source/extensions/filters/http/jwt_authn/authenticator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,11 @@ class AuthenticatorImpl : public Logger::Loggable<Logger::Id::jwt>,
const absl::optional<std::string>& provider, bool allow_failed,
bool allow_missing, JwksCache& jwks_cache,
Upstream::ClusterManager& cluster_manager,
CreateJwksFetcherCb create_jwks_fetcher_cb, TimeSource& time_source,
Event::Dispatcher& dispatcher)
CreateJwksFetcherCb create_jwks_fetcher_cb, TimeSource& time_source)
: jwks_cache_(jwks_cache), cm_(cluster_manager),
create_jwks_fetcher_cb_(create_jwks_fetcher_cb), check_audience_(check_audience),
provider_(provider), is_allow_failed_(allow_failed), is_allow_missing_(allow_missing),
time_source_(time_source), dispatcher_(dispatcher) {}
time_source_(time_source) {}
// Following functions are for JwksFetcher::JwksReceiver interface
void onJwksSuccess(google::jwt_verify::JwksPtr&& jwks) override;
void onJwksError(Failure reason) override;
Expand Down Expand Up @@ -131,9 +130,6 @@ class AuthenticatorImpl : public Logger::Loggable<Logger::Id::jwt>,
const bool is_allow_missing_;
TimeSource& time_source_;
::google::jwt_verify::Jwt* jwt_{};
Event::Dispatcher& dispatcher_;
// Set to true if the JWT in this request caused a KID mismatch.
bool kid_mismatch_request_{false};
};

std::string AuthenticatorImpl::name() const {
Expand Down Expand Up @@ -272,24 +268,16 @@ void AuthenticatorImpl::startVerify() {

auto jwks_obj = jwks_data_->getJwksObj();
if (jwks_obj != nullptr && !jwks_data_->isExpired()) {
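// If KID mismatch refetching is enabled and a remote fetch is currently allowed,
// verify only when the JWT's KID matches a cached key; otherwise fall through to
// refetch the JWKS. If refetching is not enabled or a fetch is currently
// disallowed, trigger the verification process with the cached JWKS.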
if (jwks_data_->getJwtProvider().remote_jwks().refetch_jwks_on_kid_mismatch() &&
jwks_data_->isRemoteJwksFetchAllowed()) {
if (!jwt_->kid_.empty()) {
for (const auto& jwk : jwks_obj->keys()) {
if (jwk->kid_ == jwt_->kid_) {
verifyKey();
return;
}
}
}
ENVOY_LOG(info, "Triggering refetch of JWKS due to KID mismatch");
kid_mismatch_request_ = true;
} else {
verifyKey();
return;
}
// TODO(qiwzhang): There appears to be a window of error: if the JWT issuer
// has started signing with a new key that is not in our cache, then
// verification will fail even though the JWT is valid. A simple fix
// would be to check the JWS kid header field; if it is present, check whether
// we have the key cached. If we do, proceed to verify; otherwise try a new JWKS retrieval.
// For JWTs without a kid header field in the JWS, it might be best to fetch the
// JWKS each time? This all only matters for remote JWKS.

verifyKey();
return;
}

// TODO(potatop): potential optimization.
Expand All @@ -302,10 +290,6 @@ void AuthenticatorImpl::startVerify() {
fetcher_ = create_jwks_fetcher_cb_(cm_, jwks_data_->getJwtProvider().remote_jwks());
}
fetcher_->fetch(*parent_span_, *this);
// Disallow fetches across other worker threads due to in-flight fetch.
if (kid_mismatch_request_) {
dispatcher_.post([this]() { jwks_data_->allowRemoteJwksFetch(absl::nullopt, true); });
}
return;
}
// No valid keys for this issuer. This may happen as a result of incorrect local
Expand All @@ -315,14 +299,7 @@ void AuthenticatorImpl::startVerify() {

void AuthenticatorImpl::onJwksSuccess(google::jwt_verify::JwksPtr&& jwks) {
jwks_cache_.stats().jwks_fetch_success_.inc();

const Status status =
jwks_data_
->setRemoteJwks(std::move(jwks),
jwks_data_->getJwtProvider().remote_jwks().has_async_fetch() &&
kid_mismatch_request_)
->getStatus();

const Status status = jwks_data_->setRemoteJwks(std::move(jwks))->getStatus();
if (status != Status::Ok) {
doneWithStatus(status);
} else {
Expand Down Expand Up @@ -442,15 +419,6 @@ void AuthenticatorImpl::handleGoodJwt(bool cache_hit) {
// move the ownership of "owned_jwt_" into the function.
jwks_data_->getJwtCache().insert(curr_token_->token(), std::move(owned_jwt_));
}

// On successful retrieval & verification of JWKS due to KID mismatch,
// - If retrieved due to KID being empty, trigger backoff.
// - Else, shut down backoff.
if (kid_mismatch_request_ && !jwks_data_->isRemoteJwksFetchAllowed()) {
dispatcher_.post([this]() { jwks_data_->allowRemoteJwksFetch(!jwt_->kid_.empty(), false); });
kid_mismatch_request_ = false;
}

doneWithStatus(Status::Ok);
}

Expand Down Expand Up @@ -483,13 +451,6 @@ void AuthenticatorImpl::doneWithStatus(const Status& status) {
// Forward the failed status to dynamic metadata
ENVOY_LOG(debug, "status is: {}", ::google::jwt_verify::getStatusString(status));

// Trigger backoff to disallow further fetches when retrieval & verification is due to KID
// mismatch and fails.
if (kid_mismatch_request_ && !jwks_data_->isRemoteJwksFetchAllowed()) {
dispatcher_.post([this]() { jwks_data_->allowRemoteJwksFetch(false, false); });
kid_mismatch_request_ = false;
}

std::string failed_status_in_metadata;

if (jwks_data_) {
Expand Down Expand Up @@ -546,10 +507,10 @@ AuthenticatorPtr Authenticator::create(const CheckAudience* check_audience,
bool allow_failed, bool allow_missing, JwksCache& jwks_cache,
Upstream::ClusterManager& cluster_manager,
CreateJwksFetcherCb create_jwks_fetcher_cb,
TimeSource& time_source, Event::Dispatcher& dispatcher) {
TimeSource& time_source) {
return std::make_unique<AuthenticatorImpl>(check_audience, provider, allow_failed, allow_missing,
jwks_cache, cluster_manager, create_jwks_fetcher_cb,
time_source, dispatcher);
time_source);
}

} // namespace JwtAuthn
Expand Down
2 changes: 1 addition & 1 deletion source/extensions/filters/http/jwt_authn/authenticator.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class Authenticator {
bool allow_missing, JwksCache& jwks_cache,
Upstream::ClusterManager& cluster_manager,
CreateJwksFetcherCb create_jwks_fetcher_cb,
TimeSource& time_source, Event::Dispatcher& dispatcher);
TimeSource& time_source);
};

/**
Expand Down
3 changes: 1 addition & 2 deletions source/extensions/filters/http/jwt_authn/filter_config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ FilterConfigImpl::FilterConfigImpl(
const std::string& stats_prefix, Server::Configuration::FactoryContext& context)
: proto_config_(std::move(proto_config)), stats_(generateStats(stats_prefix, context.scope())),
cm_(context.serverFactoryContext().clusterManager()),
time_source_(context.serverFactoryContext().mainThreadDispatcher().timeSource()),
dispatcher_(context.serverFactoryContext().mainThreadDispatcher()) {
time_source_(context.serverFactoryContext().mainThreadDispatcher().timeSource()) {

ENVOY_LOG(debug, "Loaded JwtAuthConfig: {}", proto_config_.DebugString());

Expand Down
6 changes: 1 addition & 5 deletions source/extensions/filters/http/jwt_authn/filter_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,6 @@ class FilterConfigImpl : public Logger::Loggable<Logger::Id::jwt>,
Upstream::ClusterManager& cm() const { return cm_; }
TimeSource& timeSource() const { return time_source_; }

Event::Dispatcher& dispatcher() const { return dispatcher_; }

// FilterConfig

JwtAuthnFilterStats& stats() override { return stats_; }
Expand Down Expand Up @@ -114,8 +112,7 @@ class FilterConfigImpl : public Logger::Loggable<Logger::Id::jwt>,
const absl::optional<std::string>& provider, bool allow_failed,
bool allow_missing) const override {
return Authenticator::create(check_audience, provider, allow_failed, allow_missing,
getJwksCache(), cm(), Common::JwksFetcher::create, timeSource(),
dispatcher());
getJwksCache(), cm(), Common::JwksFetcher::create, timeSource());
}

private:
Expand Down Expand Up @@ -150,7 +147,6 @@ class FilterConfigImpl : public Logger::Loggable<Logger::Id::jwt>,
// all requirement_names for debug
std::string all_requirement_names_;
TimeSource& time_source_;
Event::Dispatcher& dispatcher_;
};

} // namespace JwtAuthn
Expand Down
28 changes: 3 additions & 25 deletions source/extensions/filters/http/jwt_authn/jwks_async_fetcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,9 @@ std::chrono::milliseconds getFailedRefetchDuration(const JwksAsyncFetch& async_f
JwksAsyncFetcher::JwksAsyncFetcher(const RemoteJwks& remote_jwks,
Server::Configuration::FactoryContext& context,
CreateJwksFetcherCb create_fetcher_fn,
JwtAuthnFilterStats& stats, JwksDoneFetched done_fn,
isRemoteJwksFetchAllowedCb is_fetch_allowed_fn,
allowRemoteJwksFetchCb allow_fetch_fn)
JwtAuthnFilterStats& stats, JwksDoneFetched done_fn)
: remote_jwks_(remote_jwks), context_(context), create_fetcher_fn_(create_fetcher_fn),
stats_(stats), done_fn_(done_fn), is_fetch_allowed_fn_(is_fetch_allowed_fn),
allow_fetch_fn_(allow_fetch_fn),
stats_(stats), done_fn_(done_fn),
debug_name_(absl::StrCat("Jwks async fetching url=", remote_jwks_.http_uri().uri())) {
// if async_fetch is not enabled, do nothing.
if (!remote_jwks_.has_async_fetch()) {
Expand Down Expand Up @@ -81,25 +78,14 @@ std::chrono::seconds JwksAsyncFetcher::getCacheDuration(const RemoteJwks& remote
return DefaultCacheExpirationSec;
}

void JwksAsyncFetcher::resetFetchTimer() { refetch_timer_->enableTimer(good_refetch_duration_); }

void JwksAsyncFetcher::fetch() {
if (remote_jwks_.refetch_jwks_on_kid_mismatch() && !is_fetch_allowed_fn_()) {
resetFetchTimer();
return;
}

if (fetcher_) {
fetcher_->cancel();
}

ENVOY_LOG(debug, "{}: started", debug_name_);
fetcher_ = create_fetcher_fn_(context_.serverFactoryContext().clusterManager(), remote_jwks_);
fetcher_->fetch(Tracing::NullSpan::instance(), *this);

if (remote_jwks_.refetch_jwks_on_kid_mismatch()) {
allow_fetch_fn_(absl::nullopt, true);
}
}

void JwksAsyncFetcher::handleFetchDone() {
Expand All @@ -111,12 +97,8 @@ void JwksAsyncFetcher::handleFetchDone() {

void JwksAsyncFetcher::onJwksSuccess(google::jwt_verify::JwksPtr&& jwks) {
done_fn_(std::move(jwks));
if (remote_jwks_.refetch_jwks_on_kid_mismatch()) {
// Don't modify backoff for async fetch.
allow_fetch_fn_(absl::nullopt, false);
}
handleFetchDone();
resetFetchTimer();
refetch_timer_->enableTimer(good_refetch_duration_);
stats_.jwks_fetch_success_.inc();

// Note: not to free fetcher_ within onJwksSuccess or onJwksError function.
Expand All @@ -131,10 +113,6 @@ void JwksAsyncFetcher::onJwksSuccess(google::jwt_verify::JwksPtr&& jwks) {

void JwksAsyncFetcher::onJwksError(Failure) {
ENVOY_LOG(warn, "{}: failed", debug_name_);
if (remote_jwks_.refetch_jwks_on_kid_mismatch()) {
// Don't modify backoff for async fetch.
allow_fetch_fn_(absl::nullopt, false);
}
handleFetchDone();
refetch_timer_->enableTimer(failed_refetch_duration_);
stats_.jwks_fetch_failed_.inc();
Expand Down
15 changes: 1 addition & 14 deletions source/extensions/filters/http/jwt_authn/jwks_async_fetcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@ using CreateJwksFetcherCb = std::function<Common::JwksFetcherPtr(
*/
using JwksDoneFetched = std::function<void(google::jwt_verify::JwksPtr&& jwks)>;

using isRemoteJwksFetchAllowedCb = std::function<bool(void)>;

using allowRemoteJwksFetchCb = std::function<void(absl::optional<bool>, bool)>;

// This class handles fetching Jwks asynchronously.
// It will be no-op if async_fetch is not enabled.
// At its constructor, it will start to fetch Jwks, register with init_manager if not fast_listener.
Expand All @@ -39,17 +35,12 @@ class JwksAsyncFetcher : public Logger::Loggable<Logger::Id::jwt>,
public:
JwksAsyncFetcher(const envoy::extensions::filters::http::jwt_authn::v3::RemoteJwks& remote_jwks,
Server::Configuration::FactoryContext& context, CreateJwksFetcherCb fetcher_fn,
JwtAuthnFilterStats& stats, JwksDoneFetched done_fn,
isRemoteJwksFetchAllowedCb is_fetch_allowed_fn,
allowRemoteJwksFetchCb allow_fetch_fn);
JwtAuthnFilterStats& stats, JwksDoneFetched done_fn);

// Get the remote Jwks cache duration.
static std::chrono::seconds
getCacheDuration(const envoy::extensions::filters::http::jwt_authn::v3::RemoteJwks& remote_jwks);

// Reset async fetch timer to default value.
void resetFetchTimer();

private:
// Fetch the Jwks
void fetch();
Expand Down Expand Up @@ -84,10 +75,6 @@ class JwksAsyncFetcher : public Logger::Loggable<Logger::Id::jwt>,
// The init target.
std::unique_ptr<Init::TargetImpl> init_target_;

const isRemoteJwksFetchAllowedCb is_fetch_allowed_fn_;

const allowRemoteJwksFetchCb allow_fetch_fn_;

// Used in logs.
const std::string debug_name_;
};
Expand Down
Loading