diff --git a/cas/src/main/java/org/springframework/security/cas/web/CasAuthenticationEntryPoint.java b/cas/src/main/java/org/springframework/security/cas/web/CasAuthenticationEntryPoint.java index 18ecb236235..9bed4a0788f 100644 --- a/cas/src/main/java/org/springframework/security/cas/web/CasAuthenticationEntryPoint.java +++ b/cas/src/main/java/org/springframework/security/cas/web/CasAuthenticationEntryPoint.java @@ -21,8 +21,6 @@ import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; -import org.jasig.cas.client.util.CommonUtils; - import org.springframework.beans.factory.InitializingBean; import org.springframework.security.cas.ServiceProperties; import org.springframework.security.core.AuthenticationException; @@ -96,7 +94,7 @@ protected String createServiceUrl(HttpServletRequest request, HttpServletRespons */ protected String createRedirectUrl(String serviceUrl) { return CommonUtils.constructRedirectUrl(this.loginUrl, this.serviceProperties.getServiceParameter(), serviceUrl, - this.serviceProperties.isSendRenew(), false); + this.serviceProperties.isSendRenew(), false, null); } /** diff --git a/cas/src/main/java/org/springframework/security/cas/web/CasAuthenticationFilter.java b/cas/src/main/java/org/springframework/security/cas/web/CasAuthenticationFilter.java index 8e8b84700f3..1943834686c 100644 --- a/cas/src/main/java/org/springframework/security/cas/web/CasAuthenticationFilter.java +++ b/cas/src/main/java/org/springframework/security/cas/web/CasAuthenticationFilter.java @@ -24,7 +24,6 @@ import jakarta.servlet.http.HttpServletResponse; import org.jasig.cas.client.proxy.ProxyGrantingTicketStorage; -import org.jasig.cas.client.util.CommonUtils; import org.jasig.cas.client.validation.TicketValidator; import org.springframework.core.log.LogMessage; diff --git a/cas/src/main/java/org/springframework/security/cas/web/CommonUtils.java b/cas/src/main/java/org/springframework/security/cas/web/CommonUtils.java new file mode 100644 index 00000000000..8376b3f6174 --- /dev/null +++ b/cas/src/main/java/org/springframework/security/cas/web/CommonUtils.java @@ -0,0 +1,177 @@ +/* + * Licensed to Apereo under one or more contributor license + * agreements. See the NOTICE file distributed with this work + * for additional information regarding copyright ownership. + * Apereo licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file + * except in compliance with the License. You may obtain a + * copy of the License at the following location: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.springframework.security.cas.web; + +import java.io.IOException; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import org.jasig.cas.client.Protocol; +import org.jasig.cas.client.proxy.ProxyGrantingTicketStorage; +import org.jasig.cas.client.util.URIBuilder; + +import org.springframework.util.StringUtils; + +final class CommonUtils { + + private static final String PARAM_PROXY_GRANTING_TICKET_IOU = "pgtIou"; + + /** + * Constant representing the ProxyGrantingTicket Request Parameter. + */ + private static final String PARAM_PROXY_GRANTING_TICKET = "pgtId"; + + private static final String SERVICE_PARAMETER_NAMES; + + private CommonUtils() { + + } + + static { + final Set serviceParameterSet = new HashSet(4); + for (final Protocol protocol : Protocol.values()) { + serviceParameterSet.add(protocol.getServiceParameterName()); + } + SERVICE_PARAMETER_NAMES = serviceParameterSet.toString().replaceAll("\\[|\\]", "").replaceAll("\\s", ""); + } + + static String constructServiceUrl(final HttpServletRequest request, final HttpServletResponse response, + final String service, final String serverNames, final String artifactParameterName, final boolean encode) { + if (StringUtils.hasText(service)) { + return encode ? response.encodeURL(service) : service; + } + + final String serverName = findMatchingServerName(request, serverNames); + final URIBuilder originalRequestUrl = new URIBuilder(request.getRequestURL().toString(), encode); + originalRequestUrl.setParameters(request.getQueryString()); + + final URIBuilder builder; + if (!serverName.startsWith("https://") && !serverName.startsWith("http://")) { + final String scheme = request.isSecure() ? "https://" : "http://"; + builder = new URIBuilder(scheme + serverName, encode); + } + else { + builder = new URIBuilder(serverName, encode); + } + + if (builder.getPort() == -1 && !requestIsOnStandardPort(request)) { + builder.setPort(request.getServerPort()); + } + + builder.setEncodedPath(builder.getEncodedPath() + request.getRequestURI()); + + final List serviceParameterNames = Arrays.asList(SERVICE_PARAMETER_NAMES.split(",")); + if (!serviceParameterNames.isEmpty() && !originalRequestUrl.getQueryParams().isEmpty()) { + for (final URIBuilder.BasicNameValuePair pair : originalRequestUrl.getQueryParams()) { + final String name = pair.getName(); + if (!name.equals(artifactParameterName) && !serviceParameterNames.contains(name)) { + if (name.contains("&") || name.contains("=")) { + final URIBuilder encodedParamBuilder = new URIBuilder(); + encodedParamBuilder.setParameters(name); + for (final URIBuilder.BasicNameValuePair pair2 : encodedParamBuilder.getQueryParams()) { + final String name2 = pair2.getName(); + if (!name2.equals(artifactParameterName) && !serviceParameterNames.contains(name2)) { + builder.addParameter(name2, pair2.getValue()); + } + } + } + else { + builder.addParameter(name, pair.getValue()); + } + } + } + } + + final String result = builder.toString(); + final String returnValue = encode ? response.encodeURL(result) : result; + return returnValue; + } + + static String constructRedirectUrl(final String casServerLoginUrl, final String serviceParameterName, + final String serviceUrl, final boolean renew, final boolean gateway, final String method) { + return casServerLoginUrl + (casServerLoginUrl.contains("?") ? "&" : "?") + serviceParameterName + "=" + + urlEncode(serviceUrl) + (renew ? "&renew=true" : "") + (gateway ? "&gateway=true" : "") + + ((method != null) ? "&method=" + method : ""); + } + + static String urlEncode(final String value) { + return URLEncoder.encode(value, StandardCharsets.UTF_8); + } + + static void readAndRespondToProxyReceptorRequest(final HttpServletRequest request, + final HttpServletResponse response, final ProxyGrantingTicketStorage proxyGrantingTicketStorage) + throws IOException { + final String proxyGrantingTicketIou = request.getParameter(PARAM_PROXY_GRANTING_TICKET_IOU); + + final String proxyGrantingTicket = request.getParameter(PARAM_PROXY_GRANTING_TICKET); + + if (org.jasig.cas.client.util.CommonUtils.isBlank(proxyGrantingTicket) + || org.jasig.cas.client.util.CommonUtils.isBlank(proxyGrantingTicketIou)) { + response.getWriter().write(""); + return; + } + + proxyGrantingTicketStorage.save(proxyGrantingTicketIou, proxyGrantingTicket); + + response.getWriter().write(""); + response.getWriter().write(""); + } + + private static String findMatchingServerName(final HttpServletRequest request, final String serverName) { + final String[] serverNames = serverName.split(" "); + + if (serverNames.length == 0 || serverNames.length == 1) { + return serverName; + } + + final String host = request.getHeader("Host"); + final String xHost = request.getHeader("X-Forwarded-Host"); + + final String comparisonHost; + comparisonHost = (xHost != null) ? xHost : host; + + if (comparisonHost == null) { + return serverName; + } + + for (final String server : serverNames) { + final String lowerCaseServer = server.toLowerCase(); + + if (lowerCaseServer.contains(comparisonHost)) { + return server; + } + } + + return serverNames[0]; + } + + private static boolean requestIsOnStandardPort(final HttpServletRequest request) { + final int serverPort = request.getServerPort(); + return serverPort == 80 || serverPort == 443; + } + +} diff --git a/etc/checkstyle/checkstyle-suppressions.xml b/etc/checkstyle/checkstyle-suppressions.xml index e42d8124ea5..b7f5427cb63 100644 --- a/etc/checkstyle/checkstyle-suppressions.xml +++ b/etc/checkstyle/checkstyle-suppressions.xml @@ -13,6 +13,7 @@ +