Skip to content

Commit

Permalink
Merge branch '6.1.x'
Browse files Browse the repository at this point in the history
Closes gh-13701
  • Loading branch information
jzheaux committed Aug 18, 2023
2 parents 9c599fa + 321deb3 commit 3540dee
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
Expand Down Expand Up @@ -126,21 +124,19 @@ private Saml2MetadataResponse responseByIterable(HttpServletRequest request,
Iterable<RelyingPartyRegistration> registrations) {
Map<String, RelyingPartyRegistration> results = new LinkedHashMap<>();
for (RelyingPartyRegistration registration : registrations) {
results.put(registration.getEntityId(), registration);
}
Collection<RelyingPartyRegistration> resolved = new ArrayList<>();
for (RelyingPartyRegistration registration : results.values()) {
UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration);
String entityId = uriResolver.resolve(registration.getEntityId());
String ssoLocation = uriResolver.resolve(registration.getAssertionConsumerServiceLocation());
String sloLocation = uriResolver.resolve(registration.getSingleLogoutServiceLocation());
String sloResponseLocation = uriResolver.resolve(registration.getSingleLogoutServiceResponseLocation());
resolved.add(registration.mutate().entityId(entityId).assertionConsumerServiceLocation(ssoLocation)
.singleLogoutServiceLocation(sloLocation).singleLogoutServiceResponseLocation(sloResponseLocation)
.build());
results.computeIfAbsent(entityId, (e) -> {
String ssoLocation = uriResolver.resolve(registration.getAssertionConsumerServiceLocation());
String sloLocation = uriResolver.resolve(registration.getSingleLogoutServiceLocation());
String sloResponseLocation = uriResolver.resolve(registration.getSingleLogoutServiceResponseLocation());
return registration.mutate().entityId(entityId).assertionConsumerServiceLocation(ssoLocation)
.singleLogoutServiceLocation(sloLocation)
.singleLogoutServiceResponseLocation(sloResponseLocation).build();
});
}
String metadata = this.metadata.resolve(resolved);
String value = (resolved.size() == 1) ? resolved.iterator().next().getRegistrationId()
String metadata = this.metadata.resolve(results.values());
String value = (results.size() == 1) ? results.values().iterator().next().getRegistrationId()
: UUID.randomUUID().toString();
String fileName = this.filename.replace("{registrationId}", value);
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

Expand Down Expand Up @@ -101,15 +102,32 @@ void resolveWhenRequestDoesNotMatchThenNull() {
assertThat(resolver.resolve(new MockHttpServletRequest())).isNull();
}

// gh-13700
@Test
void resolveWhenNoRegistrationIdThenResolvesEntityIds() {
RelyingPartyRegistration one = withEntityId("one");
RelyingPartyRegistration two = withEntityId("two");
RelyingPartyRegistrationRepository registrations = new InMemoryRelyingPartyRegistrationRepository(one, two);
RequestMatcherMetadataResponseResolver resolver = new RequestMatcherMetadataResponseResolver(registrations,
this.metadataFactory);
given(this.metadataFactory.resolve(any(Collection.class))).willReturn("metadata");
resolver.resolve(get("/saml2/metadata"));
ArgumentCaptor<Collection<RelyingPartyRegistration>> captor = ArgumentCaptor.forClass(Collection.class);
verify(this.metadataFactory).resolve(captor.capture());
Collection<RelyingPartyRegistration> resolved = captor.getValue();
assertThat(resolved).hasSize(2);
assertThat(resolved.iterator().next().getEntityId()).isEqualTo("one");
}

private MockHttpServletRequest get(String uri) {
MockHttpServletRequest request = new MockHttpServletRequest("GET", uri);
request.setServletPath(uri);
return request;
}

private RelyingPartyRegistration withEntityId(String entityId) {
return TestRelyingPartyRegistrations.relyingPartyRegistration().registrationId(entityId).entityId(entityId)
.build();
return TestRelyingPartyRegistrations.relyingPartyRegistration().registrationId(entityId)
.entityId("{registrationId}").build();
}

}

0 comments on commit 3540dee

Please sign in to comment.