Skip to content

Commit

Permalink
Jwt: implement clear route cache (envoyproxy#30699)
Browse files Browse the repository at this point in the history
Commit Message: Follow-up to envoyproxy#30356 to implement clear_route_cache.
Risk Level: low, new field
Testing: done
Docs Changes: none
Release Notes: none
Issue: envoyproxy#29681
  • Loading branch information
kyessenov authored Nov 4, 2023
1 parent 3816b1a commit f0248a4
Show file tree
Hide file tree
Showing 10 changed files with 107 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,6 @@ message JwtProvider {
// This header is only reserved for jwt claim; any other value will be overwritten.
repeated JwtClaimToHeader claim_to_headers = 15;

// [#not-implemented-hide:]
// Clears route cache in order to allow JWT token to correctly affect
// routing decisions. Filter clears all cached routes when:
//
Expand Down
35 changes: 27 additions & 8 deletions source/extensions/filters/http/jwt_authn/authenticator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ class AuthenticatorImpl : public Logger::Loggable<Logger::Id::jwt>,
// Following functions are for Authenticator interface.
void verify(Http::HeaderMap& headers, Tracing::Span& parent_span,
std::vector<JwtLocationConstPtr>&& tokens,
SetExtractedJwtDataCallback set_extracted_jwt_data_cb,
AuthenticatorCallback callback) override;
SetExtractedJwtDataCallback set_extracted_jwt_data_cb, AuthenticatorCallback callback,
ClearRouteCacheCallback clear_route_cb) override;
void onDestroy() override;

TimeSource& timeSource() { return time_source_; }
Expand All @@ -87,8 +87,8 @@ class AuthenticatorImpl : public Logger::Loggable<Logger::Id::jwt>,
// finds one to verify with key.
void startVerify();

// Copy the JWT Claim to HTTP Header
void addJWTClaimToHeader(const std::string& claim_name, const std::string& header_name);
// Copy the JWT Claim to HTTP Header. Returns true iff header is added.
bool addJWTClaimToHeader(const std::string& claim_name, const std::string& header_name);

// The jwks cache object.
JwksCache& jwks_cache_;
Expand Down Expand Up @@ -117,6 +117,10 @@ class AuthenticatorImpl : public Logger::Loggable<Logger::Id::jwt>,
SetExtractedJwtDataCallback set_extracted_jwt_data_cb_;
// The on_done function.
AuthenticatorCallback callback_;
// Clear route cache callback function.
ClearRouteCacheCallback clear_route_cb_;
// Set to true to clear the route cache.
bool clear_route_cache_{false};
// check audience object.
const CheckAudience* check_audience_;
// specific provider or not when it is allow missing or failed.
Expand All @@ -143,13 +147,16 @@ std::string AuthenticatorImpl::name() const {
void AuthenticatorImpl::verify(Http::HeaderMap& headers, Tracing::Span& parent_span,
std::vector<JwtLocationConstPtr>&& tokens,
SetExtractedJwtDataCallback set_extracted_jwt_data_cb,
AuthenticatorCallback callback) {
AuthenticatorCallback callback,
ClearRouteCacheCallback clear_route_cb) {
ASSERT(!callback_);
headers_ = &headers;
parent_span_ = &parent_span;
tokens_ = std::move(tokens);
set_extracted_jwt_data_cb_ = std::move(set_extracted_jwt_data_cb);
callback_ = std::move(callback);
clear_route_cb_ = std::move(clear_route_cb);
clear_route_cache_ = false;

ENVOY_LOG(debug, "{}: JWT authentication starts (allow_failed={}), tokens size={}", name(),
is_allow_failed_, tokens_.size());
Expand Down Expand Up @@ -303,7 +310,7 @@ void AuthenticatorImpl::verifyKey() {
handleGoodJwt(/*cache_hit=*/false);
}

void AuthenticatorImpl::addJWTClaimToHeader(const std::string& claim_name,
bool AuthenticatorImpl::addJWTClaimToHeader(const std::string& claim_name,
const std::string& header_name) {
StructUtils payload_getter(jwt_->payload_pb_);
const ProtobufWkt::Value* claim_value;
Expand Down Expand Up @@ -342,8 +349,10 @@ void AuthenticatorImpl::addJWTClaimToHeader(const std::string& claim_name,
headers_->addCopy(Http::LowerCaseString(header_name), str_claim_value);
ENVOY_LOG(debug, "[jwt_auth] claim : {} with value : {} is added to the header : {}",
claim_name, str_claim_value, header_name);
return true;
}
}
return false;
}

void AuthenticatorImpl::handleGoodJwt(bool cache_hit) {
Expand All @@ -363,8 +372,13 @@ void AuthenticatorImpl::handleGoodJwt(bool cache_hit) {
}

// Copy JWT claim to header
bool header_added = false;
for (const auto& header_and_claim : provider.claim_to_headers()) {
addJWTClaimToHeader(header_and_claim.claim_name(), header_and_claim.header_name());
header_added |=
addJWTClaimToHeader(header_and_claim.claim_name(), header_and_claim.header_name());
}
if (provider.clear_route_cache() && header_added) {
clear_route_cache_ = true;
}

if (!provider.forward()) {
Expand Down Expand Up @@ -453,8 +467,13 @@ void AuthenticatorImpl::doneWithStatus(const Status& status) {
} else {
callback_(status);
}

callback_ = nullptr;

if (clear_route_cache_ && clear_route_cb_) {
clear_route_cb_();
}
clear_route_cb_ = nullptr;

return;
}

Expand Down
4 changes: 3 additions & 1 deletion source/extensions/filters/http/jwt_authn/authenticator.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ using AuthenticatorCallback = std::function<void(const ::google::jwt_verify::Sta
using SetExtractedJwtDataCallback =
std::function<void(const std::string&, const ProtobufWkt::Struct&)>;

using ClearRouteCacheCallback = std::function<void()>;

/**
* Authenticator object to handle all JWT authentication flow.
*/
Expand All @@ -34,7 +36,7 @@ class Authenticator {
virtual void verify(Http::HeaderMap& headers, Tracing::Span& parent_span,
std::vector<JwtLocationConstPtr>&& tokens,
SetExtractedJwtDataCallback set_extracted_jwt_data_cb,
AuthenticatorCallback callback) PURE;
AuthenticatorCallback callback, ClearRouteCacheCallback clear_route_cb) PURE;

// Called when the object is about to be destroyed.
virtual void onDestroy() PURE;
Expand Down
2 changes: 2 additions & 0 deletions source/extensions/filters/http/jwt_authn/filter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ void Filter::setExtractedData(const ProtobufWkt::Struct& extracted_data) {
extracted_data);
}

void Filter::clearRouteCache() { decoder_callbacks_->downstreamCallbacks()->clearRouteCache(); }

void Filter::onComplete(const Status& status) {
ENVOY_LOG(debug, "Jwt authentication completed with: {}",
::google::jwt_verify::getStatusString(status));
Expand Down
1 change: 1 addition & 0 deletions source/extensions/filters/http/jwt_authn/filter.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class Filter : public Http::StreamDecoderFilter,
// Following two functions are for Verifier::Callbacks interface.
// Pass the extracted data from a verified JWT as an opaque ProtobufWkt::Struct.
void setExtractedData(const ProtobufWkt::Struct& extracted_data) override;
void clearRouteCache() override;
// It will be called when its verify() call is completed.
void onComplete(const ::google::jwt_verify::Status& status) override;

Expand Down
9 changes: 6 additions & 3 deletions source/extensions/filters/http/jwt_authn/verifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ class ProviderVerifierImpl : public BaseVerifierImpl {
[&ctximpl](const std::string& name, const ProtobufWkt::Struct& extracted_data) {
ctximpl.addExtractedData(name, extracted_data);
},
[this, &ctximpl](const Status& status) { onComplete(status, ctximpl); });
[this, &ctximpl](const Status& status) { onComplete(status, ctximpl); },
[&ctximpl]() { ctximpl.callback()->clearRouteCache(); });
if (!ctximpl.getCompletionState(this).is_completed_) {
ctximpl.storeAuth(std::move(auth));
} else {
Expand Down Expand Up @@ -176,7 +177,8 @@ class AllowFailedVerifierImpl : public BaseVerifierImpl {
[&ctximpl](const std::string& name, const ProtobufWkt::Struct& extracted_data) {
ctximpl.addExtractedData(name, extracted_data);
},
[this, &ctximpl](const Status& status) { onComplete(status, ctximpl); });
[this, &ctximpl](const Status& status) { onComplete(status, ctximpl); },
[&ctximpl]() { ctximpl.callback()->clearRouteCache(); });
if (!ctximpl.getCompletionState(this).is_completed_) {
ctximpl.storeAuth(std::move(auth));
} else {
Expand Down Expand Up @@ -208,7 +210,8 @@ class AllowMissingVerifierImpl : public BaseVerifierImpl {
[&ctximpl](const std::string& name, const ProtobufWkt::Struct& extracted_data) {
ctximpl.addExtractedData(name, extracted_data);
},
[this, &ctximpl](const Status& status) { onComplete(status, ctximpl); });
[this, &ctximpl](const Status& status) { onComplete(status, ctximpl); },
[&ctximpl]() { ctximpl.callback()->clearRouteCache(); });
if (!ctximpl.getCompletionState(this).is_completed_) {
ctximpl.storeAuth(std::move(auth));
} else {
Expand Down
5 changes: 5 additions & 0 deletions source/extensions/filters/http/jwt_authn/verifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ class Verifier {
*/
virtual void setExtractedData(const ProtobufWkt::Struct& payload) PURE;

/**
* JWT payloads added to headers may require clearing the cached route.
*/
virtual void clearRouteCache() PURE;

/**
* Called on completion of request.
*
Expand Down
35 changes: 31 additions & 4 deletions test/extensions/filters/http/jwt_authn/authenticator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ class AuthenticatorTest : public testing::Test {
EXPECT_TRUE(jwks_->getStatus() == Status::Ok);
}

void expectVerifyStatus(Status expected_status, Http::RequestHeaderMap& headers) {
void expectVerifyStatus(Status expected_status, Http::RequestHeaderMap& headers,
bool expect_clear_route = false) {
std::function<void(const Status&)> on_complete_cb = [&expected_status](const Status& status) {
ASSERT_EQ(status, expected_status);
};
Expand All @@ -68,8 +69,10 @@ class AuthenticatorTest : public testing::Test {
};
initTokenExtractor();
auto tokens = extractor_->extract(headers);
bool clear_route = false;
auth_->verify(headers, parent_span_, std::move(tokens), std::move(set_extracted_jwt_data_cb),
std::move(on_complete_cb));
std::move(on_complete_cb), [&clear_route] { clear_route = true; });
EXPECT_EQ(expect_clear_route, clear_route);
}

void initTokenExtractor() {
Expand Down Expand Up @@ -163,6 +166,29 @@ TEST_F(AuthenticatorTest, TestClaimToHeader) {
Envoy::Base64::encode(expected_json.data(), expected_json.size()));
}

// This test verifies whether the claim is successfully added to header or not
TEST_F(AuthenticatorTest, TestClaimToHeaderWithClearRouteCache) {
TestUtility::loadFromYaml(ClaimToHeadersConfig, proto_config_);
createAuthenticator();
EXPECT_CALL(*raw_fetcher_, fetch(_, _))
.WillOnce(Invoke([this](Tracing::Span&, JwksFetcher::JwksReceiver& receiver) {
receiver.onJwksSuccess(std::move(jwks_));
}));

{
Http::TestRequestHeaderMapImpl headers{
{"Authorization", "Bearer " + std::string(NestedGoodToken)}};
expectVerifyStatus(Status::Ok, headers, true);
EXPECT_EQ(headers.get_("x-jwt-claim-nested"), "value1");
}

{
Http::TestRequestHeaderMapImpl headers{{"Authorization", "Bearer " + std::string(GoodToken)}};
expectVerifyStatus(Status::Ok, headers, false);
EXPECT_FALSE(headers.has("x-jwt-claim-nested"));
}
}

// This test verifies when wrong claim is passed in claim_to_headers
TEST_F(AuthenticatorTest, TestClaimToHeaderWithHeaderReplace) {
createAuthenticator();
Expand Down Expand Up @@ -854,7 +880,8 @@ TEST_F(AuthenticatorTest, TestOnDestroy) {
auto tokens = extractor_->extract(headers);
// callback should not be called.
std::function<void(const Status&)> on_complete_cb = [](const Status&) { FAIL(); };
auth_->verify(headers, parent_span_, std::move(tokens), nullptr, std::move(on_complete_cb));
auth_->verify(headers, parent_span_, std::move(tokens), nullptr, std::move(on_complete_cb),
nullptr);

// Destroy the authenticating process.
auth_->onDestroy();
Expand Down Expand Up @@ -1013,7 +1040,7 @@ class AuthenticatorJwtCacheTest : public testing::Test {
};
auto tokens = extractor_->extract(headers);
auth_->verify(headers, parent_span_, std::move(tokens), set_extracted_jwt_data_cb,
on_complete_cb);
on_complete_cb, nullptr);
}

::google::jwt_verify::JwksPtr jwks_;
Expand Down
5 changes: 3 additions & 2 deletions test/extensions/filters/http/jwt_authn/mock.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ class MockAuthenticator : public Authenticator {

void verify(Http::HeaderMap& headers, Tracing::Span& parent_span,
std::vector<JwtLocationConstPtr>&& tokens,
SetExtractedJwtDataCallback set_extracted_jwt_data_cb,
AuthenticatorCallback callback) override {
SetExtractedJwtDataCallback set_extracted_jwt_data_cb, AuthenticatorCallback callback,
ClearRouteCacheCallback) override {
doVerify(headers, parent_span, &tokens, std::move(set_extracted_jwt_data_cb),
std::move(callback));
}
Expand All @@ -47,6 +47,7 @@ class MockAuthenticator : public Authenticator {
class MockVerifierCallbacks : public Verifier::Callbacks {
public:
MOCK_METHOD(void, setExtractedData, (const ProtobufWkt::Struct& payload));
MOCK_METHOD(void, clearRouteCache, ());
MOCK_METHOD(void, onComplete, (const Status& status));
};

Expand Down
29 changes: 29 additions & 0 deletions test/extensions/filters/http/jwt_authn/test_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,35 @@ const char ExampleConfig[] = R"(
bypass_cors_preflight: true
)";

// Config with claim_to_headers and clear_route_cache.
const char ClaimToHeadersConfig[] = R"(
providers:
example_provider:
issuer: https://example.com
audiences:
- example_service
- http://example_service1
- https://example_service2/
remote_jwks:
http_uri:
uri: https://pubkey_server/pubkey_path
cluster: pubkey_cluster
timeout:
seconds: 5
cache_duration:
seconds: 600
claim_to_headers:
- header_name: "x-jwt-claim-nested"
claim_name: "nested.key-1"
clear_route_cache: true
rules:
- match:
path: "/"
requires:
provider_name: "example_provider"
bypass_cors_preflight: true
)";

const char ExampleConfigWithRegEx[] = R"(
providers:
example_provider:
Expand Down

0 comments on commit f0248a4

Please sign in to comment.