Skip to content

Commit

Permalink
Lambda: auto decode account id (#2167)
Browse files Browse the repository at this point in the history
  • Loading branch information
meiao authored Dec 13, 2024
1 parent 925da72 commit 37ea748
Show file tree
Hide file tree
Showing 13 changed files with 299 additions and 175 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

package com.agent.instrumentation.awsjavasdk1.services.lambda;

import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.auth.AWSCredentialsProvider;

import java.lang.ref.WeakReference;
import java.util.Objects;

Expand All @@ -18,12 +21,14 @@ public class FunctionRawData {
private final String qualifier;
private final String region;
private final WeakReference<Object> sdkClient;
private final WeakReference<AWSCredentialsProvider> credentialsProvider;

public FunctionRawData(String functionRef, String qualifier, String region, Object sdkClient) {
public FunctionRawData(String functionRef, String qualifier, String region, Object sdkClient, AWSCredentialsProvider credentialsProvider) {
this.functionRef = functionRef;
this.qualifier = qualifier;
this.region = region;
this.sdkClient = new WeakReference<>(sdkClient);
this.credentialsProvider = new WeakReference<>(credentialsProvider);
}

public String getFunctionRef() {
Expand All @@ -42,6 +47,17 @@ public Object getSdkClient() {
return sdkClient.get();
}

public String getAccessKey() {
AWSCredentialsProvider credentialsProvider = this.credentialsProvider.get();
if (credentialsProvider != null) {
AWSCredentials credentials = credentialsProvider.getCredentials();
if (credentials != null) {
return credentials.getAWSAccessKeyId();
}
}
return null;
}

@Override
public boolean equals(Object o) {
if (this == o) {
Expand All @@ -52,17 +68,23 @@ public boolean equals(Object o) {
}

FunctionRawData that = (FunctionRawData) o;
if (this.sdkClient.get() == null || that.sdkClient.get() == null) {
if (this.sdkClient.get() == null || that.sdkClient.get() == null ||
this.credentialsProvider.get() == null || that.credentialsProvider.get() == null) {
return false;
}
return Objects.equals(functionRef, that.functionRef) &&
Objects.equals(qualifier, that.qualifier) &&
Objects.equals(region, that.region) &&
Objects.equals(sdkClient.get(), that.sdkClient.get());
Objects.equals(sdkClient.get(), that.sdkClient.get()) &&
Objects.equals(credentialsProvider.get(), that.credentialsProvider.get());
}

@Override
public int hashCode() {
return Objects.hash(functionRef, qualifier, region, sdkClient.get());
return Objects.hash(functionRef,
qualifier,
region,
sdkClient.get(),
credentialsProvider.get());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,23 @@

package com.agent.instrumentation.awsjavasdk1.services.lambda;

import com.amazonaws.services.lambda.model.InvokeRequest;
import com.newrelic.agent.bridge.AgentBridge;
import com.newrelic.api.agent.CloudAccountInfo;
import com.newrelic.api.agent.CloudParameters;
import com.newrelic.api.agent.NewRelic;
import com.newrelic.api.agent.Token;

import java.util.Map;
import java.util.function.Function;
import java.util.logging.Level;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class LambdaUtil {

private static final Map<InvokeRequest, Token> TOKEN_MAP = AgentBridge.collectionFactory.createConcurrentWeakKeyedMap();

private static final String PLATFORM = "aws_lambda";
private static final String NULL_ARN = "";
private static final FunctionProcessedData NULL_DATA = new FunctionProcessedData(NULL_ARN, NULL_ARN);
Expand Down Expand Up @@ -90,7 +95,7 @@ static FunctionProcessedData processData(FunctionRawData data) {

if (accountId == null) {
// if account id is not provided, we will try to get it from the config
accountId = getAccountId(data.getSdkClient());
accountId = getAccountId(data.getSdkClient(), data.getAccessKey());
}

if (region != null && accountId != null) {
Expand All @@ -109,11 +114,29 @@ static FunctionProcessedData processData(FunctionRawData data) {
return new FunctionProcessedData(functionName, arn);
}

private static String getAccountId(Object sdkClient) {
return AgentBridge.cloud.getAccountInfo(sdkClient, CloudAccountInfo.AWS_ACCOUNT_ID);
private static String getAccountId(Object sdkClient, String accessKey) {
String accountId = AgentBridge.cloud.getAccountInfo(sdkClient, CloudAccountInfo.AWS_ACCOUNT_ID);
if (accountId == null && accessKey != null) {
accountId = AgentBridge.cloud.decodeAwsAccountId(accessKey);
}
return accountId;
}

public static String getSimpleFunctionName(FunctionRawData functionRawData) {
return CACHE.apply(functionRawData).getFunctionName();
}

/*
* The following are almost the same as a Token in a @NewField.
* These are here and not in a @NewField because the weaver was misbehaving
* trying to rewrite the @NewField methods.
*/
public static Token getToken(InvokeRequest request) {
return TOKEN_MAP.remove(request);
}

public static void setTokenForRequest(InvokeRequest request) {
Token token = NewRelic.getAgent().getTransaction().getToken();
TOKEN_MAP.put(request, token);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
*
* * Copyright 2024 New Relic Corporation. All rights reserved.
* * SPDX-License-Identifier: Apache-2.0
*
*/

package com.agent.instrumentation.awsjavasdk1.services.lambda;

import com.amazonaws.handlers.AsyncHandler;
import com.amazonaws.services.lambda.model.InvokeRequest;
import com.amazonaws.services.lambda.model.InvokeResult;
import com.newrelic.api.agent.NewRelic;
import com.newrelic.api.agent.Token;
import com.newrelic.api.agent.Trace;

public class TokenLinkingAsyncHandler implements AsyncHandler<InvokeRequest, InvokeResult> {

private final AsyncHandler<InvokeRequest, InvokeResult> delegate;

private Token token;

public TokenLinkingAsyncHandler(AsyncHandler<InvokeRequest, InvokeResult> delegate) {
this.delegate = delegate;
token = NewRelic.getAgent().getTransaction().getToken();
}

@Override
@Trace(async = true)
public void onError(Exception e) {
if (token != null) {
token.linkAndExpire();
token = null;
}
NewRelic.getAgent().getTracedMethod().setMetricName("Java", delegate.getClass().getName(), "onError");
delegate.onError(e);
}

@Override
@Trace(async = true)
public void onSuccess(InvokeRequest request, InvokeResult invokeResult) {
if (token != null) {
token.linkAndExpire();
token = null;
}
NewRelic.getAgent().getTracedMethod().setMetricName("Java", delegate.getClass().getName(), "onSuccess");
delegate.onSuccess(request, invokeResult);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,12 @@

package com.amazonaws.services.lambda;

import com.agent.instrumentation.awsjavasdk1.services.lambda.FunctionRawData;
import com.agent.instrumentation.awsjavasdk1.services.lambda.LambdaUtil;
import com.agent.instrumentation.awsjavasdk1.services.lambda.TokenLinkingAsyncHandler;
import com.amazonaws.handlers.AsyncHandler;
import com.amazonaws.services.lambda.model.InvokeRequest;
import com.amazonaws.services.lambda.model.InvokeResult;
import com.newrelic.api.agent.CloudParameters;
import com.newrelic.api.agent.NewRelic;
import com.newrelic.api.agent.Segment;
import com.newrelic.api.agent.Trace;
import com.newrelic.api.agent.weaver.MatchType;
import com.newrelic.api.agent.weaver.Weave;
import com.newrelic.api.agent.weaver.Weaver;
Expand All @@ -24,48 +22,12 @@
@Weave(type = MatchType.ExactClass, originalName = "com.amazonaws.services.lambda.AWSLambdaAsyncClient")
public abstract class AWSLambdaAsyncClient_Instrumentation {

protected abstract String getSigningRegion();

@Trace
public Future<InvokeResult> invokeAsync(final InvokeRequest request, AsyncHandler<InvokeRequest, InvokeResult> asyncHandler) {
FunctionRawData functionRawData = new FunctionRawData(request.getFunctionName(), request.getQualifier(), this.getSigningRegion(), this);
CloudParameters cloudParameters = LambdaUtil.getCloudParameters(functionRawData);
String functionName = LambdaUtil.getSimpleFunctionName(functionRawData);
Segment segment = NewRelic.getAgent().getTransaction().startSegment("Lambda", "invoke/" + functionName);

try {
segment.reportAsExternal(cloudParameters);
asyncHandler = new SegmentEndingAsyncHandler(asyncHandler, segment);
return Weaver.callOriginal();
} catch (Throwable t) {
segment.end();
throw t;
}
}

private static class SegmentEndingAsyncHandler implements AsyncHandler<InvokeRequest, InvokeResult> {
private final AsyncHandler<InvokeRequest, InvokeResult> originalHandler;
private final Segment segment;

public SegmentEndingAsyncHandler(
AsyncHandler<InvokeRequest, InvokeResult> asyncHandler, Segment segment) {
this.segment = segment;
this.originalHandler = asyncHandler;
}

@Override
public void onError(Exception exception) {
segment.end();
if (originalHandler != null) {
originalHandler.onError(exception);
}
}

@Override
public void onSuccess(InvokeRequest request, InvokeResult invokeResult) {
segment.end();
if (originalHandler != null) {
originalHandler.onSuccess(request, invokeResult);
}
LambdaUtil.setTokenForRequest(request);
if (asyncHandler != null) {
asyncHandler = new TokenLinkingAsyncHandler(asyncHandler);
}
return Weaver.callOriginal();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@

import com.agent.instrumentation.awsjavasdk1.services.lambda.FunctionRawData;
import com.agent.instrumentation.awsjavasdk1.services.lambda.LambdaUtil;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.services.lambda.model.InvokeRequest;
import com.amazonaws.services.lambda.model.InvokeResult;
import com.newrelic.api.agent.CloudParameters;
import com.newrelic.api.agent.NewRelic;
import com.newrelic.api.agent.Token;
import com.newrelic.api.agent.Trace;
import com.newrelic.api.agent.TracedMethod;
import com.newrelic.api.agent.weaver.MatchType;
Expand All @@ -24,9 +26,18 @@ public abstract class AWSLambdaClient_Instrumentation {

abstract protected String getSigningRegion();

@Trace(leaf = true)
public InvokeResult invoke(InvokeRequest invokeRequest) {
FunctionRawData functionRawData = new FunctionRawData(invokeRequest.getFunctionName(), invokeRequest.getQualifier(), this.getSigningRegion(), this);
private final AWSCredentialsProvider awsCredentialsProvider = Weaver.callOriginal();

// this method needs the async because it is invoked by the async client
// it is also in the path of the sync client execution
@Trace(async = true)
final InvokeResult executeInvoke(InvokeRequest invokeRequest) {
Token token = LambdaUtil.getToken(invokeRequest);
if (token != null) {
token.linkAndExpire();
token = null;
}
FunctionRawData functionRawData = new FunctionRawData(invokeRequest.getFunctionName(), invokeRequest.getQualifier(), this.getSigningRegion(), this, awsCredentialsProvider);
CloudParameters cloudParameters = LambdaUtil.getCloudParameters(functionRawData);
TracedMethod tracedMethod = NewRelic.getAgent().getTracedMethod();
tracedMethod.reportAsExternal(cloudParameters);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@

package com.agent.instrumentation.awsjavasdk1.services.lambda;

import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.regions.Regions;
import com.newrelic.agent.bridge.AgentBridge;
import com.newrelic.agent.bridge.CloudApi;
import com.newrelic.agent.bridge.NoOpCloud;
import com.newrelic.api.agent.CloudAccountInfo;
import com.newrelic.api.agent.CloudParameters;
import org.junit.After;
import org.junit.Before;
Expand All @@ -23,13 +23,15 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class LambdaUtilTest {

private AWSCredentialsProvider credentialsProvider;

@Before
public void before() {
AgentBridge.cloud = mock(CloudApi.class);
credentialsProvider = mock(AWSCredentialsProvider.class);
}

@After
Expand Down Expand Up @@ -65,6 +67,6 @@ public void testGetCloudParamArnQualifier() {
}

private FunctionRawData data(String functionRef, String qualifier) {
return new FunctionRawData(functionRef, qualifier, Regions.US_EAST_1.getName(), this);
return new FunctionRawData(functionRef, qualifier, Regions.US_EAST_1.getName(), this, credentialsProvider);
}
}
Loading

0 comments on commit 37ea748

Please sign in to comment.