diff --git a/alts/src/main/java/io/grpc/alts/AltsContextUtil.java b/alts/src/main/java/io/grpc/alts/AltsContextUtil.java index 91b06756dc3..5f2ce353761 100644 --- a/alts/src/main/java/io/grpc/alts/AltsContextUtil.java +++ b/alts/src/main/java/io/grpc/alts/AltsContextUtil.java @@ -14,9 +14,9 @@ * limitations under the License. */ - package io.grpc.alts; +import io.grpc.Attributes; import io.grpc.ExperimentalApi; import io.grpc.ServerCall; import io.grpc.alts.internal.AltsInternalContext; @@ -35,7 +35,7 @@ private AltsContextUtil() {} * @return the created {@link AltsContext} * @throws IllegalArgumentException if the {@link ServerCall} has no ALTS information. */ - public static AltsContext createFrom(ServerCall call) { + public static AltsContext createFrom(ServerCall call) { Object authContext = call.getAttributes().get(AltsProtocolNegotiator.AUTH_CONTEXT_KEY); if (!(authContext instanceof AltsInternalContext)) { throw new IllegalArgumentException("No ALTS context information found"); @@ -49,8 +49,18 @@ public static AltsContext createFrom(ServerCall call) { * @param call the {@link ServerCall} to check * @return true, if the {@link ServerCall} contains ALTS information and false otherwise. */ - public static boolean check(ServerCall call) { - Object authContext = call.getAttributes().get(AltsProtocolNegotiator.AUTH_CONTEXT_KEY); + public static boolean check(ServerCall call) { + return check(call.getAttributes()); + } + + /** + * Checks if the {@link Attributes} contains ALTS information. + * + * @param attributes the {@link Attributes} to check + * @return true, if the {@link Attributes} contains ALTS information and false otherwise. + */ + public static boolean check(Attributes attributes) { + Object authContext = attributes.get(AltsProtocolNegotiator.AUTH_CONTEXT_KEY); return authContext instanceof AltsInternalContext; } } diff --git a/alts/src/main/java/io/grpc/alts/DualCallCredentials.java b/alts/src/main/java/io/grpc/alts/DualCallCredentials.java new file mode 100644 index 00000000000..08104712e65 --- /dev/null +++ b/alts/src/main/java/io/grpc/alts/DualCallCredentials.java @@ -0,0 +1,46 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.alts; + +import io.grpc.CallCredentials; +import java.util.concurrent.Executor; + +/** + * {@code CallCredentials} that will pick the right credentials based on whether the established + * connection is ALTS or TLS. + */ +final class DualCallCredentials extends CallCredentials { + private final CallCredentials tlsCallCredentials; + private final CallCredentials altsCallCredentials; + + public DualCallCredentials(CallCredentials tlsCallCreds, CallCredentials altsCallCreds) { + tlsCallCredentials = tlsCallCreds; + altsCallCredentials = altsCallCreds; + } + + @Override + public void applyRequestMetadata( + CallCredentials.RequestInfo requestInfo, + Executor appExecutor, + CallCredentials.MetadataApplier applier) { + if (AltsContextUtil.check(requestInfo.getTransportAttrs())) { + altsCallCredentials.applyRequestMetadata(requestInfo, appExecutor, applier); + } else { + tlsCallCredentials.applyRequestMetadata(requestInfo, appExecutor, applier); + } + } +} diff --git a/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelCredentials.java b/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelCredentials.java index d9c2ddaaed7..1b5880120a4 100644 --- a/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelCredentials.java +++ b/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelCredentials.java @@ -63,6 +63,7 @@ public static Builder newBuilder() { */ public static final class Builder { private CallCredentials callCredentials; + private CallCredentials altsCallCredentials; private Builder() {} @@ -72,23 +73,32 @@ public Builder callCredentials(CallCredentials callCreds) { return this; } + /** Constructs GoogleDefaultChannelCredentials with an ALTS-specific call credential. */ + public Builder altsCallCredentials(CallCredentials callCreds) { + altsCallCredentials = callCreds; + return this; + } + /** Builds a GoogleDefaultChannelCredentials instance. */ public ChannelCredentials build() { ChannelCredentials nettyCredentials = InternalNettyChannelCredentials.create(createClientFactory()); - if (callCredentials != null) { - return CompositeChannelCredentials.create(nettyCredentials, callCredentials); - } - CallCredentials callCreds; - try { - callCreds = MoreCallCredentials.from(GoogleCredentials.getApplicationDefault()); - } catch (IOException e) { - callCreds = - new FailingCallCredentials( - Status.UNAUTHENTICATED - .withDescription("Failed to get Google default credentials") - .withCause(e)); + CallCredentials tlsCallCreds = callCredentials; + if (tlsCallCreds == null) { + try { + tlsCallCreds = MoreCallCredentials.from(GoogleCredentials.getApplicationDefault()); + } catch (IOException e) { + tlsCallCreds = + new FailingCallCredentials( + Status.UNAUTHENTICATED + .withDescription("Failed to get Google default credentials") + .withCause(e)); + } } + CallCredentials callCreds = + altsCallCredentials == null + ? tlsCallCreds + : new DualCallCredentials(tlsCallCreds, altsCallCredentials); return CompositeChannelCredentials.create(nettyCredentials, callCreds); } diff --git a/alts/src/test/java/io/grpc/alts/DualCallCredentialsTest.java b/alts/src/test/java/io/grpc/alts/DualCallCredentialsTest.java new file mode 100644 index 00000000000..29646191be1 --- /dev/null +++ b/alts/src/test/java/io/grpc/alts/DualCallCredentialsTest.java @@ -0,0 +1,109 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.alts; + +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import io.grpc.Attributes; +import io.grpc.CallCredentials; +import io.grpc.CallCredentials.RequestInfo; +import io.grpc.MethodDescriptor; +import io.grpc.SecurityLevel; +import io.grpc.alts.internal.AltsInternalContext; +import io.grpc.alts.internal.AltsProtocolNegotiator; +import io.grpc.testing.TestMethodDescriptors; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +/** Unit tests for {@link DualCallCredentials}. */ +@RunWith(JUnit4.class) +public class DualCallCredentialsTest { + + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); + + @Mock CallCredentials tlsCallCredentials; + + @Mock CallCredentials altsCallCredentials; + + private static final String AUTHORITY = "testauthority"; + private static final SecurityLevel SECURITY_LEVEL = SecurityLevel.PRIVACY_AND_INTEGRITY; + + @Test + public void invokeTlsCallCredentials() { + DualCallCredentials callCredentials = + new DualCallCredentials(tlsCallCredentials, altsCallCredentials); + RequestInfo requestInfo = new RequestInfoImpl(false); + callCredentials.applyRequestMetadata(requestInfo, null, null); + + verify(altsCallCredentials, never()).applyRequestMetadata(any(), any(), any()); + verify(tlsCallCredentials, times(1)).applyRequestMetadata(requestInfo, null, null); + } + + @Test + public void invokeAltsCallCredentials() { + DualCallCredentials callCredentials = + new DualCallCredentials(tlsCallCredentials, altsCallCredentials); + RequestInfo requestInfo = new RequestInfoImpl(true); + callCredentials.applyRequestMetadata(requestInfo, null, null); + + verify(altsCallCredentials, times(1)).applyRequestMetadata(requestInfo, null, null); + verify(tlsCallCredentials, never()).applyRequestMetadata(any(), any(), any()); + } + + private static final class RequestInfoImpl extends CallCredentials.RequestInfo { + private Attributes attrs; + + RequestInfoImpl(boolean hasAltsContext) { + attrs = + hasAltsContext + ? Attributes.newBuilder() + .set( + AltsProtocolNegotiator.AUTH_CONTEXT_KEY, + AltsInternalContext.getDefaultInstance()) + .build() + : Attributes.EMPTY; + } + + @Override + public MethodDescriptor getMethodDescriptor() { + return TestMethodDescriptors.voidMethod(); + } + + @Override + public SecurityLevel getSecurityLevel() { + return SECURITY_LEVEL; + } + + @Override + public String getAuthority() { + return AUTHORITY; + } + + @Override + public Attributes getTransportAttrs() { + return attrs; + } + } +}