diff --git a/sdk/src/Core/Plugins/EC2Plugin.cs b/sdk/src/Core/Plugins/EC2Plugin.cs index 9f409fc4..dd96fd42 100644 --- a/sdk/src/Core/Plugins/EC2Plugin.cs +++ b/sdk/src/Core/Plugins/EC2Plugin.cs @@ -15,9 +15,13 @@ // //----------------------------------------------------------------------------- +using System; using System.Collections.Generic; +using System.Linq; +using System.Net.Http; +using System.Threading.Tasks; using Amazon.Runtime.Internal.Util; -using Amazon.Util; +using ThirdParty.LitJson; namespace Amazon.XRay.Recorder.Core.Plugins { @@ -27,6 +31,8 @@ namespace Amazon.XRay.Recorder.Core.Plugins public class EC2Plugin : IPlugin { private static readonly Logger _logger = Logger.GetLogger(typeof(EC2Plugin)); + private readonly HttpClient _client = new HttpClient(); + const string metadata_base_url = "http://169.254.169.254/latest/"; /// /// Gets the name of the origin associated with this plugin. @@ -55,41 +61,99 @@ public string ServiceName /// /// Gets the context of the runtime that this plugin is associated with. /// - /// When the method returns, contains the runtime context of the plugin, or null if the runtime context is not available. + /// When the method returns, contains the runtime context of the plugin. /// true if the runtime context is available; Otherwise, false. public bool TryGetRuntimeContext(out IDictionary context) { - context = null; + // get the token + string token = GetToken(); - var dict = new Dictionary(); - if (EC2InstanceMetadata.InstanceId != null) + // get the metadata + context = GetMetadata(token); + + if (context.Count == 0) { - dict.Add("instance_id", EC2InstanceMetadata.InstanceId); + _logger.DebugFormat("Could not get instance metadata"); + return false; } - if (EC2InstanceMetadata.AvailabilityZone != null) + return true; + } + + private string GetToken() + { + string token = null; + try { - dict.Add("availability_zone", EC2InstanceMetadata.AvailabilityZone); + Dictionary header = new Dictionary(1); + header.Add("X-aws-ec2-metadata-token-ttl-seconds", "60"); + token = DoRequest(metadata_base_url + "api/token", HttpMethod.Put, header).Result; } + catch (Exception) + { + _logger.DebugFormat("Failed to get token for IMDSv2"); + } + + return token; + } - if (EC2InstanceMetadata.InstanceType != null) + + private IDictionary GetMetadata(string token) + { + try + { + Dictionary headers = null; + if (token != null) + { + headers = new Dictionary(1); + headers.Add("X-aws-ec2-metadata-token", token); + } + string identity_doc_url = metadata_base_url + "dynamic/instance-identity/document"; + string doc_string = DoRequest(identity_doc_url, HttpMethod.Get, headers).Result; + return ParseMetadata(doc_string); + } + catch (Exception) { - dict.Add("instance_size", EC2InstanceMetadata.InstanceType); + _logger.DebugFormat("Error occurred while getting EC2 metadata"); + return new Dictionary(); } + } - if (EC2InstanceMetadata.AmiId != null) + + protected virtual async Task DoRequest(string url, HttpMethod method, Dictionary headers = null) + { + HttpRequestMessage request = new HttpRequestMessage(method, url); + if (headers != null) { - dict.Add("ami_id", EC2InstanceMetadata.AmiId); + foreach (var item in headers) + { + request.Headers.Add(item.Key, item.Value); + } } - if (dict.Count == 0) + HttpResponseMessage response = await _client.SendAsync(request); + if (response.IsSuccessStatusCode) { - _logger.DebugFormat("Unable to contact EC2 metadata service, failed to get runtime context."); - return false; + return await response.Content.ReadAsStringAsync(); + } + else + { + throw new Exception("Unable to complete the request successfully"); } + } - context = dict; - return true; + + private IDictionary ParseMetadata(string jsonString) + { + JsonData data = JsonMapper.ToObject(jsonString); + Dictionary ec2_meta_dict = new Dictionary(); + + ec2_meta_dict.Add("instance_id", data["instanceId"]); + ec2_meta_dict.Add("availability_zone", data["availabilityZone"]); + ec2_meta_dict.Add("instance_size", data["instanceType"]); + ec2_meta_dict.Add("ami_id", data["imageId"]); + + return ec2_meta_dict; } } } diff --git a/sdk/test/UnitTests/TestEC2Plugin.cs b/sdk/test/UnitTests/TestEC2Plugin.cs new file mode 100644 index 00000000..e8757ffe --- /dev/null +++ b/sdk/test/UnitTests/TestEC2Plugin.cs @@ -0,0 +1,152 @@ +//----------------------------------------------------------------------------- +// +// Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). +// You may not use this file except in compliance with the License. +// A copy of the License is located at +// +// http://aws.amazon.com/apache2.0 +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. +// +//----------------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.Net.Http; +using System.Threading.Tasks; +using Amazon.XRay.Recorder.Core.Plugins; +using Microsoft.VisualStudio.TestTools.UnitTesting; + + +namespace Amazon.XRay.Recorder.UnitTests +{ + [TestClass] + public class TestEC2Plugin + { + + [TestMethod] + public void TestV2Success() + { + // Arrange + EC2Plugin ec2_plugin = new MockEC2Plugin(failV1: false, failV2: false); + IDictionary context = new Dictionary(); + + // Act + bool ret = ec2_plugin.TryGetRuntimeContext(out context); + + // Assert + Assert.IsTrue(ret); + Assert.AreEqual(4, context.Count); + + object instance_id = ""; + context.TryGetValue("instance_id", out instance_id); + Assert.AreEqual("i-07a181803de94c666", instance_id.ToString()); + + object availability_zone = ""; + context.TryGetValue("availability_zone", out availability_zone); + Assert.AreEqual("us-east-2a", availability_zone.ToString()); + + object instance_size = ""; + context.TryGetValue("instance_size", out instance_size); + Assert.AreEqual("t3.xlarge", instance_size.ToString()); + + object ami_id = ""; + context.TryGetValue("ami_id", out ami_id); + Assert.AreEqual("ami-03cca83dd001d4666", ami_id.ToString()); + } + + [TestMethod] + public void TestV2Fail_V1Success() + { + // Arrange + EC2Plugin ec2_plugin = new MockEC2Plugin(failV1: false, failV2: true); + IDictionary context = new Dictionary(); + + // Act + bool ret = ec2_plugin.TryGetRuntimeContext(out context); + + // Assert + Assert.IsTrue(ret); + Assert.AreEqual(4, context.Count); + object instance_id = ""; + context.TryGetValue("instance_id", out instance_id); + Assert.AreEqual("i-07a181803de94c477", instance_id.ToString()); + + object availability_zone = ""; + context.TryGetValue("availability_zone", out availability_zone); + Assert.AreEqual("us-west-2a", availability_zone.ToString()); + + object instance_size = ""; + context.TryGetValue("instance_size", out instance_size); + Assert.AreEqual("t2.xlarge", instance_size.ToString()); + + object ami_id = ""; + context.TryGetValue("ami_id", out ami_id); + Assert.AreEqual("ami-03cca83dd001d4d11", ami_id.ToString()); + } + + [TestMethod] + public void TestV2Fail_V1Fail() + { + // Arrange + EC2Plugin ec2_plugin = new MockEC2Plugin(failV1: true, failV2: true); + IDictionary context = new Dictionary(); + + // Act + bool ret = ec2_plugin.TryGetRuntimeContext(out context); + + // Assert + Assert.IsFalse(ret); + Assert.AreEqual(0, context.Count); + } + } + + // This is a mock class created for the purpose of unit testing. The overridden DoRequest method returns valid values or Exception + // based on the conditions for the tests. + public class MockEC2Plugin : EC2Plugin + { + private readonly bool _failV2; + private readonly bool _failV1; + + public MockEC2Plugin(bool failV1, bool failV2) + { + _failV1 = failV1; + _failV2 = failV2; + } + + protected override Task DoRequest(string url, HttpMethod method, Dictionary headers = null) + { + if (_failV2 && url == "http://169.254.169.254/latest/api/token") + { + throw new Exception("Unable to complete the v2 request successfully"); + } + else if (!_failV2 && url == "http://169.254.169.254/latest/api/token") + { + return Task.FromResult("dummyTokenfromferg"); + } + else if (_failV1) + { + throw new Exception("Unable to complete the v1 request successfully"); + } + + string meta_string = ""; + if (headers == null) // for v1 endpoint request + { + meta_string = "{\"availabilityZone\" : \"us-west-2a\", \"imageId\" : \"ami-03cca83dd001d4d11\", \"instanceId\" : \"i-07a181803de94c477\", \"instanceType\" : \"t2.xlarge\"}"; + return Task.FromResult(meta_string); + } + else + { // for v2 endpoint + meta_string = "{\"availabilityZone\" : \"us-east-2a\", \"imageId\" : \"ami-03cca83dd001d4666\", \"instanceId\" : \"i-07a181803de94c666\", \"instanceType\" : \"t3.xlarge\"}"; + return Task.FromResult(meta_string); + } + } + + } + +}