Skip to content

Commit

Permalink
Do not expire invitations on GET requests
Browse files Browse the repository at this point in the history
At the moment, when the user visits:

```
/invitations/accept?code=some-code
```

the invitation code from their email is immediately expired and replaced
with a newly generated code which is put in a hidden input in the HTML
form. Each time the user submits the form, the code is expired and (if
necessary - e.g. if there's a validation issue) replaced with a new one.

This is fine so long as the user fills the form in immediately, but
there are a number of edge cases where this approach causes usability
problems:

1) If the user refreshes the page it will tell them their invitation has
   expired.
2) If the user closes the tab without submitting the form, and then
   follows the invitation link from their email later it will show as
   expired.
3) If the user's email client or web browser pre-fetches the link for
   any reason (e.g. virus scanning / spam detection / performance
   optimisation) then the link will not work when they follow it for
   real.

The third issue is the most serious.

We (GOV.UK PaaS) have had some very users working in places that
pre-fetch links in emails (for some reason or other), and this means
they're completely unable to accept invitations. Judging from the irate
support tickets we've had from these users the experience is pretty
frustrating.

This commit changes the GET request to /invitations/accept so that it
does not expire the token (unless the invitation is being auto-accepted).

The POST handler is unchanged, so if the user actually submits the form
then the token will change (as it did before), even if there's a
validation issue that prevents the invitation being accepted.

This change fixes the usability issues, and makes the behaviour more
consistent with HTTP's semantics (in the sense that GET requests should
be "safe" - should not modify the state of the server).
  • Loading branch information
richardTowers authored and Toby Lorne committed Jun 18, 2020
1 parent 11f5ad8 commit 82385c5
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,19 @@ public interface ExpiringCodeStore {
*/
ExpiringCode generateCode(String data, Timestamp expiresAt, String intent, String zoneId);

/**
* Retrieve a code BUT DO NOT DELETE IT.
*
* WARNING - if you intend to expire the code as soon as you read it,
* use {@link #retrieveCode(String, String)} instead.
*
* @param code the one-time code to look for
* @param zoneId
* @return code or null if the code is not found
* @throws java.lang.NullPointerException if the code is null
*/
ExpiringCode peekCode(String code, String zoneId);

/**
* Retrieve a code and delete it if it exists.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,25 @@ public ExpiringCode generateCode(String data, Timestamp expiresAt, String intent
return null;
}

@Override
public ExpiringCode peekCode(String code, String zoneId) {
cleanExpiredEntries();

if (code == null) {
throw new NullPointerException();
}

try {
ExpiringCode expiringCode = jdbcTemplate.queryForObject(selectAllFields, rowMapper, code, zoneId);
if (expiringCode.getExpiresAt().getTime() < timeService.getCurrentTimeMillis()) {
expiringCode = null;
}
return expiringCode;
} catch (EmptyResultDataAccessException x) {
return null;
}
}

@Override
public ExpiringCode retrieveCode(String code, String zoneId) {
cleanExpiredEntries();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ public void return404(HttpServletResponse response) {
@RequestMapping(value = "/accept", method = GET, params = {"code"})
public String acceptInvitePage(@RequestParam String code, Model model, HttpServletRequest request, HttpServletResponse response) {

ExpiringCode expiringCode = expiringCodeStore.retrieveCode(code, IdentityZoneHolder.get().getId());
ExpiringCode expiringCode = expiringCodeStore.peekCode(code, IdentityZoneHolder.get().getId());
if ((null == expiringCode) || (null != expiringCode.getIntent() && !INVITATION.name().equals(expiringCode.getIntent()))) {
return handleUnprocessableEntity(model, response, "error_message_code", "code_expired", "invitations/accept_invite");
}
Expand All @@ -128,28 +128,27 @@ public String acceptInvitePage(@RequestParam String code, Model model, HttpServl
String origin = codeData.get(ORIGIN);
try {
IdentityProvider provider = identityProviderProvisioning.retrieveByOrigin(origin, IdentityZoneHolder.get().getId());
final String newCode = expiringCodeStore.generateCode(expiringCode.getData(), new Timestamp(System.currentTimeMillis() + (10 * 60 * 1000)), expiringCode.getIntent(), IdentityZoneHolder.get().getId()).getCode();

UaaUser user = userDatabase.retrieveUserById(codeData.get("user_id"));
boolean isUaaUserAndVerified =
UAA.equals(provider.getType()) && user.isVerified();
boolean isExternalUserAndAcceptedInvite =
!UAA.equals(provider.getType()) && UaaHttpRequestUtils.isAcceptedInvitationAuthentication();
if (isUaaUserAndVerified || isExternalUserAndAcceptedInvite) {
AcceptedInvitation accepted = invitationsService.acceptInvitation(newCode, "");
AcceptedInvitation accepted = invitationsService.acceptInvitation(code, "");
String redirect = "redirect:" + accepted.getRedirectUri();
logger.debug(String.format("Redirecting accepted invitation for email:%s, id:%s to URL:%s", codeData.get("email"), codeData.get("user_id"), redirect));
return redirect;
} else if (SAML.equals(provider.getType())) {
setRequestAttributes(request, newCode, user);
setRequestAttributes(request, code, user);

SamlIdentityProviderDefinition definition = ObjectUtils.castInstance(provider.getConfig(), SamlIdentityProviderDefinition.class);

String redirect = "redirect:/" + SamlRedirectUtils.getIdpRedirectUrl(definition, spEntityID, IdentityZoneHolder.get());
logger.debug(String.format("Redirecting invitation for email:%s, id:%s single SAML IDP URL:%s", codeData.get("email"), codeData.get("user_id"), redirect));
return redirect;
} else if (OIDC10.equals(provider.getType()) || OAUTH20.equals(provider.getType())) {
setRequestAttributes(request, newCode, user);
setRequestAttributes(request, code, user);

AbstractExternalOAuthIdentityProviderDefinition definition = ObjectUtils.castInstance(provider.getConfig(), AbstractExternalOAuthIdentityProviderDefinition.class);

Expand All @@ -162,7 +161,7 @@ public String acceptInvitePage(@RequestParam String code, Model model, HttpServl
Collections.singletonList(UaaAuthority.UAA_INVITED));
SecurityContextHolder.getContext().setAuthentication(token);
model.addAttribute("provider", provider.getType());
model.addAttribute("code", newCode);
model.addAttribute("code", code);
model.addAttribute("email", codeData.get("email"));
logger.debug(String.format("Sending user to accept invitation page email:%s, id:%s", codeData.get("email"), codeData.get("user_id")));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,19 @@ void generateCodeWithDuplicateCode() {
() -> expiringCodeStore.generateCode(data, expiresAt, null, IdentityZone.getUaaZoneId()));
}

@Test
void peekCode() {
String data = "{}";
Timestamp expiresAt = new Timestamp(System.currentTimeMillis() + 60000);
String zoneId = IdentityZoneHolder.get().getId();

ExpiringCode generatedCode = expiringCodeStore.generateCode(data, expiresAt, null, zoneId);

Assert.assertEquals(generatedCode, expiringCodeStore.peekCode(generatedCode.getCode(), zoneId));
Assert.assertEquals(generatedCode, expiringCodeStore.peekCode(generatedCode.getCode(), zoneId));
Assert.assertEquals(generatedCode, expiringCodeStore.peekCode(generatedCode.getCode(), zoneId));
}

@Test
void retrieveCode() {
String data = "{}";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,21 @@ public ExpiringCode generateCode(String data, Timestamp expiresAt, String intent
return expiringCode;
}

@Override
public ExpiringCode peekCode(String code, String zoneId) {
if (code == null) {
throw new NullPointerException();
}

ExpiringCode expiringCode = store.get(code + zoneId);

if (expiringCode == null || isExpired(expiringCode)) {
expiringCode = null;
}

return expiringCode;
}

@Override
public ExpiringCode retrieveCode(String code, String zoneId) {
if (code == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,7 @@ public void testAcceptInvitationsPage() throws Exception {
codeData.put("email", "[email protected]");
codeData.put("client_id", "client-id");
codeData.put("redirect_uri", "blah.test.com");
when(expiringCodeStore.retrieveCode("code", IdentityZoneHolder.get().getId())).thenReturn(createCode(codeData), null);
when(expiringCodeStore.generateCode(anyString(), any(), eq(INVITATION.name()), eq(IdentityZoneHolder.get().getId()))).thenReturn(createCode(codeData));
when(expiringCodeStore.peekCode("code", IdentityZoneHolder.get().getId())).thenReturn(createCode(codeData), null);
IdentityProvider provider = new IdentityProvider();
provider.setType(OriginKeys.UAA);
when(providerProvisioning.retrieveByOrigin(any(), any())).thenReturn(provider);
Expand Down Expand Up @@ -193,8 +192,7 @@ public void incorrectCodeIntent() throws Exception {
@Test
public void acceptInvitePage_for_unverifiedSamlUser() throws Exception {
Map<String,String> codeData = getInvitationsCode("test-saml");
when(expiringCodeStore.retrieveCode("the_secret_code", IdentityZoneHolder.get().getId())).thenReturn(createCode(codeData));
when(expiringCodeStore.generateCode(anyString(), any(), eq(INVITATION.name()), eq(IdentityZoneHolder.get().getId()))).thenReturn(createCode(codeData));
when(expiringCodeStore.peekCode("the_secret_code", IdentityZoneHolder.get().getId())).thenReturn(createCode(codeData));
IdentityProvider provider = new IdentityProvider();
SamlIdentityProviderDefinition definition = new SamlIdentityProviderDefinition()
.setMetaDataLocation("http://test.saml.com")
Expand All @@ -220,8 +218,7 @@ public void acceptInvitePage_for_unverifiedSamlUser() throws Exception {
@Test
public void acceptInvitePage_for_unverifiedOIDCUser() throws Exception {
Map<String,String> codeData = getInvitationsCode("test-oidc");
when(expiringCodeStore.retrieveCode("the_secret_code", IdentityZoneHolder.get().getId())).thenReturn(createCode(codeData));
when(expiringCodeStore.generateCode(anyString(), any(), eq(INVITATION.name()), eq(IdentityZoneHolder.get().getId()))).thenReturn(createCode(codeData));
when(expiringCodeStore.peekCode("the_secret_code", IdentityZoneHolder.get().getId())).thenReturn(createCode(codeData));

OIDCIdentityProviderDefinition definition = new OIDCIdentityProviderDefinition();
definition.setAuthUrl(new URL("https://oidc10.auth.url"));
Expand All @@ -247,8 +244,7 @@ public void acceptInvitePage_for_unverifiedOIDCUser() throws Exception {
@Test
public void acceptInvitePage_for_unverifiedLdapUser() throws Exception {
Map<String, String> codeData = getInvitationsCode(LDAP);
when(expiringCodeStore.retrieveCode("the_secret_code", IdentityZoneHolder.get().getId())).thenReturn(createCode(codeData));
when(expiringCodeStore.generateCode(anyString(), any(), eq(INVITATION.name()), eq(IdentityZoneHolder.get().getId()))).thenReturn(createCode(codeData));
when(expiringCodeStore.peekCode("the_secret_code", IdentityZoneHolder.get().getId())).thenReturn(createCode(codeData));

IdentityProvider provider = new IdentityProvider();
provider.setType(LDAP);
Expand All @@ -263,7 +259,7 @@ public void acceptInvitePage_for_unverifiedLdapUser() throws Exception {
.andExpect(content().string(containsString("Email: " + "[email protected]")))
.andExpect(content().string(containsString("Sign in with enterprise credentials:")))
.andExpect(content().string(containsString("username")))
.andExpect(model().attribute("code", "code"))
.andExpect(model().attribute("code", "the_secret_code"))
.andReturn();
}

Expand Down Expand Up @@ -397,8 +393,7 @@ public void acceptInvitePage_for_verifiedUser() throws Exception {
codeData.put("email", "[email protected]");
codeData.put("origin", "some-origin");

when(expiringCodeStore.retrieveCode("the_secret_code", IdentityZoneHolder.get().getId())).thenReturn(createCode(codeData), null);
when(expiringCodeStore.generateCode(anyString(), any(), eq(INVITATION.name()), eq(IdentityZoneHolder.get().getId()))).thenReturn(createCode(codeData));
when(expiringCodeStore.peekCode("the_secret_code", IdentityZoneHolder.get().getId())).thenReturn(createCode(codeData), null);
when(invitationsService.acceptInvitation(anyString(), eq(""))).thenReturn(new AcceptedInvitation("blah.test.com", new ScimUser()));
IdentityProvider provider = new IdentityProvider();
provider.setType(OriginKeys.UAA);
Expand Down Expand Up @@ -668,10 +663,8 @@ public void testAcceptInvite_displaysConsentText() throws Exception {
Map<String,String> codeData = getInvitationsCode(OriginKeys.UAA);
String codeDataString = JsonUtils.writeValueAsString(codeData);
ExpiringCode expiringCode = new ExpiringCode("thecode", new Timestamp(1), codeDataString, INVITATION.name());
when(expiringCodeStore.retrieveCode("thecode", IdentityZoneHolder.get().getId()))
when(expiringCodeStore.peekCode("thecode", IdentityZoneHolder.get().getId()))
.thenReturn(expiringCode, null);
when(expiringCodeStore.generateCode(anyString(), any(), eq(INVITATION.name()), eq(IdentityZoneHolder.get().getId())))
.thenReturn(expiringCode);

mockMvc.perform(get("/invitations/accept")
.param("code", "thecode"))
Expand Down Expand Up @@ -714,6 +707,8 @@ public void testAcceptInvite_displaysErrorMessageIfConsentNotChecked() throws Ex
Map<String,String> codeData = getInvitationsCode(OriginKeys.UAA);
String codeDataString = JsonUtils.writeValueAsString(codeData);
ExpiringCode expiringCode = new ExpiringCode("thecode", new Timestamp(1), codeDataString, INVITATION.name());
when(expiringCodeStore.peekCode(anyString(), eq(IdentityZoneHolder.get().getId())))
.thenReturn(expiringCode);
when(expiringCodeStore.retrieveCode(anyString(), eq(IdentityZoneHolder.get().getId())))
.thenReturn(expiringCode);
when(expiringCodeStore.generateCode(anyString(), any(), eq(INVITATION.name()), eq(IdentityZoneHolder.get().getId())))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ void acceptInvitationForVerifiedUserSendsRedirect() throws Exception {
}

@Test
void acceptInvitationForUaaUserShouldExpireInvitelink() throws Exception {
void acceptInvitationForUaaUserShouldNotExpireInvitelink() throws Exception {
String email = new RandomValueStringGenerator().generate().toLowerCase() + "@test.org";
URL inviteLink = inviteUser(webApplicationContext, mockMvc, email, userInviteToken, null, clientId, OriginKeys.UAA);
assertEquals(OriginKeys.UAA, queryUserForField(jdbcTemplate, email, OriginKeys.ORIGIN, String.class));
Expand All @@ -218,9 +218,10 @@ void acceptInvitationForUaaUserShouldExpireInvitelink() throws Exception {
.accept(MediaType.TEXT_HTML);
mockMvc.perform(get)
.andExpect(status().isOk());

mockMvc.perform(get)
.andExpect(status().isUnprocessableEntity());
.andExpect(status().isOk());
mockMvc.perform(get)
.andExpect(status().isOk());
}

@Test
Expand Down

0 comments on commit 82385c5

Please sign in to comment.