Skip to content

Commit

Permalink
fix: fix IdTokenVerifier so it does not cache empty entries (#892)
Browse files Browse the repository at this point in the history
Couple fixes to the IdTokenVerifier:
- Cache will not save empty entry in case of public key fetch failure
- Payload verification moved to a separate protected method so child classes can call it directly and avoid duplicate signature verification

Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
Co-authored-by: Tomo Suzuki <[email protected]>
  • Loading branch information
3 people authored Jun 2, 2022
1 parent c1b1468 commit 773b388
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -233,12 +233,9 @@ public final Collection<String> getAudience() {
* @return {@code true} if verified successfully or {@code false} if failed
*/
public boolean verify(IdToken idToken) {
boolean tokenFieldsValid =
(issuers == null || idToken.verifyIssuer(issuers))
&& (audience == null || idToken.verifyAudience(audience))
&& idToken.verifyTime(clock.currentTimeMillis(), acceptableTimeSkewSeconds);
boolean payloadValid = verifyPayload(idToken);

if (!tokenFieldsValid) {
if (!payloadValid) {
return false;
}

Expand All @@ -254,6 +251,35 @@ public boolean verify(IdToken idToken) {
}
}

/**
* Verifies the payload of the given ID token
*
* <p>It verifies:
*
* <ul>
* <li>The issuer is one of {@link #getIssuers()} by calling {@link
* IdToken#verifyIssuer(String)}.
* <li>The audience is one of {@link #getAudience()} by calling {@link
* IdToken#verifyAudience(Collection)}.
* <li>The current time against the issued at and expiration time, using the {@link #getClock()}
* and allowing for a time skew specified in {@link #getAcceptableTimeSkewSeconds()} , by
* calling {@link IdToken#verifyTime(long, long)}.
* </ul>
*
* <p>Overriding is allowed, but it must call the super implementation.
*
* @param idToken ID token
* @return {@code true} if verified successfully or {@code false} if failed
*/
protected boolean verifyPayload(IdToken idToken) {
boolean tokenPayloadValid =
(issuers == null || idToken.verifyIssuer(issuers))
&& (audience == null || idToken.verifyAudience(audience))
&& idToken.verifyTime(clock.currentTimeMillis(), acceptableTimeSkewSeconds);

return tokenPayloadValid;
}

@VisibleForTesting
boolean verifySignature(IdToken idToken) throws VerificationException {
if (Boolean.parseBoolean(environment.getVariable(SKIP_SIGNATURE_ENV_VAR))) {
Expand All @@ -272,12 +298,12 @@ boolean verifySignature(IdToken idToken) throws VerificationException {
publicKeyToUse = publicKeyCache.get(certificateLocation).get(idToken.getHeader().getKeyId());
} catch (ExecutionException | UncheckedExecutionException e) {
throw new VerificationException(
"Error fetching PublicKey from certificate location " + certificatesLocation, e);
"Error fetching public key from certificate location " + certificatesLocation, e);
}

if (publicKeyToUse == null) {
throw new VerificationException(
"Could not find PublicKey for provided keyId: " + idToken.getHeader().getKeyId());
"Could not find public key for provided keyId: " + idToken.getHeader().getKeyId());
}

try {
Expand Down Expand Up @@ -380,7 +406,7 @@ public Builder setIssuer(String issuer) {
}

/**
* Override the location URL that contains published public keys. Defaults to well-known Google
* Overrides the location URL that contains published public keys. Defaults to well-known Google
* locations.
*
* @param certificatesLocation URL to published public keys
Expand Down Expand Up @@ -534,7 +560,7 @@ public Map<String, PublicKey> load(String certificateUrl) throws Exception {
Level.WARNING,
"Failed to get a certificate from certificate location " + certificateUrl,
io);
return ImmutableMap.of();
throw io;
}

ImmutableMap.Builder<String, PublicKey> keyCacheBuilder = new ImmutableMap.Builder<>();
Expand All @@ -556,6 +582,13 @@ public Map<String, PublicKey> load(String certificateUrl) throws Exception {
}
}

ImmutableMap<String, PublicKey> keyCache = keyCacheBuilder.build();

if (keyCache.isEmpty()) {
throw new VerificationException(
"No valid public key returned by the keystore: " + certificateUrl);
}

return keyCacheBuilder.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,15 @@
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import junit.framework.TestCase;
import org.junit.Assert;

/**
* Tests {@link IdTokenVerifier}.
Expand Down Expand Up @@ -101,7 +104,7 @@ public void testBuilder() throws Exception {
assertEquals(TRUSTED_CLIENT_IDS, Lists.newArrayList(verifier.getAudience()));
}

public void testVerify() throws Exception {
public void testVerifyPayload() throws Exception {
MockClock clock = new MockClock();
MockEnvironment testEnvironment = new MockEnvironment();
testEnvironment.setVariable(IdTokenVerifier.SKIP_SIGNATURE_ENV_VAR, "true");
Expand All @@ -121,21 +124,31 @@ public void testVerify() throws Exception {
clock.timeMillis = 1500000L;
IdToken idToken = newIdToken(ISSUER, CLIENT_ID);
assertTrue(verifier.verify(idToken));
assertTrue(verifier.verifyPayload(idToken));
assertTrue(verifierFlexible.verify(newIdToken(ISSUER2, CLIENT_ID)));
assertTrue(verifierFlexible.verifyPayload(newIdToken(ISSUER2, CLIENT_ID)));
assertFalse(verifier.verify(newIdToken(ISSUER2, CLIENT_ID)));
assertFalse(verifier.verifyPayload(newIdToken(ISSUER2, CLIENT_ID)));
assertTrue(verifier.verify(newIdToken(ISSUER3, CLIENT_ID)));
assertTrue(verifier.verifyPayload(newIdToken(ISSUER3, CLIENT_ID)));
// audience
assertTrue(verifierFlexible.verify(newIdToken(ISSUER, CLIENT_ID2)));
assertTrue(verifierFlexible.verifyPayload(newIdToken(ISSUER, CLIENT_ID2)));
assertFalse(verifier.verify(newIdToken(ISSUER, CLIENT_ID2)));
assertFalse(verifier.verifyPayload(newIdToken(ISSUER, CLIENT_ID2)));
// time
clock.timeMillis = 700000L;
assertTrue(verifier.verify(idToken));
assertTrue(verifier.verifyPayload(idToken));
clock.timeMillis = 2300000L;
assertTrue(verifier.verify(idToken));
assertTrue(verifier.verifyPayload(idToken));
clock.timeMillis = 699999L;
assertFalse(verifier.verify(idToken));
assertFalse(verifier.verifyPayload(idToken));
clock.timeMillis = 2300001L;
assertFalse(verifier.verify(idToken));
assertFalse(verifier.verifyPayload(idToken));
}

public void testEmptyIssuersFails() throws Exception {
Expand Down Expand Up @@ -187,28 +200,52 @@ public void testMissingAudience() throws VerificationException {

public void testVerifyEs256TokenPublicKeyMismatch() throws Exception {
// Mock HTTP requests
HttpTransportFactory httpTransportFactory =
new HttpTransportFactory() {
MockLowLevelHttpRequest failedRequest =
new MockLowLevelHttpRequest() {
@Override
public HttpTransport create() {
return new MockHttpTransport() {
@Override
public LowLevelHttpRequest buildRequest(String method, String url)
throws IOException {
return new MockLowLevelHttpRequest() {
@Override
public LowLevelHttpResponse execute() throws IOException {
MockLowLevelHttpResponse response = new MockLowLevelHttpResponse();
response.setStatusCode(200);
response.setContentType("application/json");
response.setContent("");
return response;
}
};
}
};
public LowLevelHttpResponse execute() throws IOException {
throw new IOException("test io exception");
}
};

MockLowLevelHttpRequest badRequest =
new MockLowLevelHttpRequest() {
@Override
public LowLevelHttpResponse execute() throws IOException {
MockLowLevelHttpResponse response = new MockLowLevelHttpResponse();
response.setStatusCode(404);
response.setContentType("application/json");
response.setContent("");
return response;
}
};

MockLowLevelHttpRequest emptyRequest =
new MockLowLevelHttpRequest() {
@Override
public LowLevelHttpResponse execute() throws IOException {
MockLowLevelHttpResponse response = new MockLowLevelHttpResponse();
response.setStatusCode(200);
response.setContentType("application/json");
response.setContent("{\"keys\":[]}");
return response;
}
};

MockLowLevelHttpRequest goodRequest =
new MockLowLevelHttpRequest() {
@Override
public LowLevelHttpResponse execute() throws IOException {
MockLowLevelHttpResponse response = new MockLowLevelHttpResponse();
response.setStatusCode(200);
response.setContentType("application/json");
response.setContent(readResourceAsString("iap_keys.json"));
return response;
}
};

HttpTransportFactory httpTransportFactory =
mockTransport(failedRequest, badRequest, emptyRequest, goodRequest);
IdTokenVerifier tokenVerifier =
new IdTokenVerifier.Builder()
.setClock(FIXED_CLOCK)
Expand All @@ -219,8 +256,24 @@ public LowLevelHttpResponse execute() throws IOException {
tokenVerifier.verifySignature(IdToken.parse(JSON_FACTORY, ES256_TOKEN));
fail("Should have failed verification");
} catch (VerificationException ex) {
assertTrue(ex.getMessage().contains("Error fetching PublicKey"));
assertTrue(ex.getMessage().contains("Error fetching public key"));
}

try {
tokenVerifier.verifySignature(IdToken.parse(JSON_FACTORY, ES256_TOKEN));
fail("Should have failed verification");
} catch (VerificationException ex) {
assertTrue(ex.getMessage().contains("Error fetching public key"));
}

try {
tokenVerifier.verifySignature(IdToken.parse(JSON_FACTORY, ES256_TOKEN));
fail("Should have failed verification");
} catch (VerificationException ex) {
assertTrue(ex.getCause().getMessage().contains("No valid public key returned"));
}

Assert.assertTrue(tokenVerifier.verifySignature(IdToken.parse(JSON_FACTORY, ES256_TOKEN)));
}

public void testVerifyEs256Token() throws VerificationException, IOException {
Expand Down Expand Up @@ -284,6 +337,25 @@ static String readResourceAsString(String resourceName) throws IOException {
}
}

static HttpTransportFactory mockTransport(LowLevelHttpRequest... requests) {
final LowLevelHttpRequest firstRequest = requests[0];
final Queue<LowLevelHttpRequest> requestQueue = new ArrayDeque<>();
for (LowLevelHttpRequest request : requests) {
requestQueue.add(request);
}
return new HttpTransportFactory() {
@Override
public HttpTransport create() {
return new MockHttpTransport() {
@Override
public LowLevelHttpRequest buildRequest(String method, String url) throws IOException {
return requestQueue.poll();
}
};
}
};
}

static HttpTransportFactory mockTransport(String url, String certificates) {
final String certificatesContent = certificates;
final String certificatesUrl = url;
Expand Down

0 comments on commit 773b388

Please sign in to comment.