Skip to content

Commit

Permalink
Imdsv2 support (#134)
Browse files Browse the repository at this point in the history
* Refactored the EC2Plugin to use IMDSv2 endpoint via HTTP and fallback to v1 if v2 fails

* Reafctored code to non-static members. Added unit tests for EC2Plugin

* Refactored the logic to fetch metadata in a single flow irrespective of v1 and v2

* Refactored to return empty dict instead of null to drop reduntant null check

* Removed some unnecessary variables and assignments
  • Loading branch information
srprash authored Jun 1, 2020
1 parent b1ade64 commit 1b90091
Show file tree
Hide file tree
Showing 2 changed files with 233 additions and 17 deletions.
98 changes: 81 additions & 17 deletions sdk/src/Core/Plugins/EC2Plugin.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@
// </copyright>
//-----------------------------------------------------------------------------

using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Threading.Tasks;
using Amazon.Runtime.Internal.Util;
using Amazon.Util;
using ThirdParty.LitJson;

namespace Amazon.XRay.Recorder.Core.Plugins
{
Expand All @@ -27,6 +31,8 @@ namespace Amazon.XRay.Recorder.Core.Plugins
public class EC2Plugin : IPlugin
{
private static readonly Logger _logger = Logger.GetLogger(typeof(EC2Plugin));
private readonly HttpClient _client = new HttpClient();
const string metadata_base_url = "http://169.254.169.254/latest/";

/// <summary>
/// Gets the name of the origin associated with this plugin.
Expand Down Expand Up @@ -55,41 +61,99 @@ public string ServiceName
/// <summary>
/// Gets the context of the runtime that this plugin is associated with.
/// </summary>
/// <param name="context">When the method returns, contains the runtime context of the plugin, or null if the runtime context is not available.</param>
/// <param name="context">When the method returns, contains the runtime context of the plugin.</param>
/// <returns>true if the runtime context is available; Otherwise, false.</returns>
public bool TryGetRuntimeContext(out IDictionary<string, object> context)
{
context = null;
// get the token
string token = GetToken();

var dict = new Dictionary<string, object>();
if (EC2InstanceMetadata.InstanceId != null)
// get the metadata
context = GetMetadata(token);

if (context.Count == 0)
{
dict.Add("instance_id", EC2InstanceMetadata.InstanceId);
_logger.DebugFormat("Could not get instance metadata");
return false;
}

if (EC2InstanceMetadata.AvailabilityZone != null)
return true;
}

private string GetToken()
{
string token = null;
try
{
dict.Add("availability_zone", EC2InstanceMetadata.AvailabilityZone);
Dictionary<string, string> header = new Dictionary<string, string>(1);
header.Add("X-aws-ec2-metadata-token-ttl-seconds", "60");
token = DoRequest(metadata_base_url + "api/token", HttpMethod.Put, header).Result;
}
catch (Exception)
{
_logger.DebugFormat("Failed to get token for IMDSv2");
}

return token;
}

if (EC2InstanceMetadata.InstanceType != null)

private IDictionary<string, object> GetMetadata(string token)
{
try
{
Dictionary<string, string> headers = null;
if (token != null)
{
headers = new Dictionary<string, string>(1);
headers.Add("X-aws-ec2-metadata-token", token);
}
string identity_doc_url = metadata_base_url + "dynamic/instance-identity/document";
string doc_string = DoRequest(identity_doc_url, HttpMethod.Get, headers).Result;
return ParseMetadata(doc_string);
}
catch (Exception)
{
dict.Add("instance_size", EC2InstanceMetadata.InstanceType);
_logger.DebugFormat("Error occurred while getting EC2 metadata");
return new Dictionary<string, object>();
}
}

if (EC2InstanceMetadata.AmiId != null)

protected virtual async Task<string> DoRequest(string url, HttpMethod method, Dictionary<string, string> headers = null)
{
HttpRequestMessage request = new HttpRequestMessage(method, url);
if (headers != null)
{
dict.Add("ami_id", EC2InstanceMetadata.AmiId);
foreach (var item in headers)
{
request.Headers.Add(item.Key, item.Value);
}
}

if (dict.Count == 0)
HttpResponseMessage response = await _client.SendAsync(request);
if (response.IsSuccessStatusCode)
{
_logger.DebugFormat("Unable to contact EC2 metadata service, failed to get runtime context.");
return false;
return await response.Content.ReadAsStringAsync();
}
else
{
throw new Exception("Unable to complete the request successfully");
}
}

context = dict;
return true;

private IDictionary<string, object> ParseMetadata(string jsonString)
{
JsonData data = JsonMapper.ToObject(jsonString);
Dictionary<string, object> ec2_meta_dict = new Dictionary<string, object>();

ec2_meta_dict.Add("instance_id", data["instanceId"]);
ec2_meta_dict.Add("availability_zone", data["availabilityZone"]);
ec2_meta_dict.Add("instance_size", data["instanceType"]);
ec2_meta_dict.Add("ami_id", data["imageId"]);

return ec2_meta_dict;
}
}
}
152 changes: 152 additions & 0 deletions sdk/test/UnitTests/TestEC2Plugin.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
//-----------------------------------------------------------------------------
// <copyright file="TestPlugins.cs" company="Amazon.com">
// Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License").
// You may not use this file except in compliance with the License.
// A copy of the License is located at
//
// http://aws.amazon.com/apache2.0
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.
// </copyright>
//-----------------------------------------------------------------------------

using System;
using System.Collections.Generic;
using System.Net.Http;
using System.Threading.Tasks;
using Amazon.XRay.Recorder.Core.Plugins;
using Microsoft.VisualStudio.TestTools.UnitTesting;


namespace Amazon.XRay.Recorder.UnitTests
{
[TestClass]
public class TestEC2Plugin
{

[TestMethod]
public void TestV2Success()
{
// Arrange
EC2Plugin ec2_plugin = new MockEC2Plugin(failV1: false, failV2: false);
IDictionary<string, object> context = new Dictionary<string, object>();

// Act
bool ret = ec2_plugin.TryGetRuntimeContext(out context);

// Assert
Assert.IsTrue(ret);
Assert.AreEqual(4, context.Count);

object instance_id = "";
context.TryGetValue("instance_id", out instance_id);
Assert.AreEqual("i-07a181803de94c666", instance_id.ToString());

object availability_zone = "";
context.TryGetValue("availability_zone", out availability_zone);
Assert.AreEqual("us-east-2a", availability_zone.ToString());

object instance_size = "";
context.TryGetValue("instance_size", out instance_size);
Assert.AreEqual("t3.xlarge", instance_size.ToString());

object ami_id = "";
context.TryGetValue("ami_id", out ami_id);
Assert.AreEqual("ami-03cca83dd001d4666", ami_id.ToString());
}

[TestMethod]
public void TestV2Fail_V1Success()
{
// Arrange
EC2Plugin ec2_plugin = new MockEC2Plugin(failV1: false, failV2: true);
IDictionary<string, object> context = new Dictionary<string, object>();

// Act
bool ret = ec2_plugin.TryGetRuntimeContext(out context);

// Assert
Assert.IsTrue(ret);
Assert.AreEqual(4, context.Count);
object instance_id = "";
context.TryGetValue("instance_id", out instance_id);
Assert.AreEqual("i-07a181803de94c477", instance_id.ToString());

object availability_zone = "";
context.TryGetValue("availability_zone", out availability_zone);
Assert.AreEqual("us-west-2a", availability_zone.ToString());

object instance_size = "";
context.TryGetValue("instance_size", out instance_size);
Assert.AreEqual("t2.xlarge", instance_size.ToString());

object ami_id = "";
context.TryGetValue("ami_id", out ami_id);
Assert.AreEqual("ami-03cca83dd001d4d11", ami_id.ToString());
}

[TestMethod]
public void TestV2Fail_V1Fail()
{
// Arrange
EC2Plugin ec2_plugin = new MockEC2Plugin(failV1: true, failV2: true);
IDictionary<string, object> context = new Dictionary<string, object>();

// Act
bool ret = ec2_plugin.TryGetRuntimeContext(out context);

// Assert
Assert.IsFalse(ret);
Assert.AreEqual(0, context.Count);
}
}

// This is a mock class created for the purpose of unit testing. The overridden DoRequest method returns valid values or Exception
// based on the conditions for the tests.
public class MockEC2Plugin : EC2Plugin
{
private readonly bool _failV2;
private readonly bool _failV1;

public MockEC2Plugin(bool failV1, bool failV2)
{
_failV1 = failV1;
_failV2 = failV2;
}

protected override Task<string> DoRequest(string url, HttpMethod method, Dictionary<string, string> headers = null)
{
if (_failV2 && url == "http://169.254.169.254/latest/api/token")
{
throw new Exception("Unable to complete the v2 request successfully");
}
else if (!_failV2 && url == "http://169.254.169.254/latest/api/token")
{
return Task.FromResult("dummyTokenfromferg");
}
else if (_failV1)
{
throw new Exception("Unable to complete the v1 request successfully");
}

string meta_string = "";
if (headers == null) // for v1 endpoint request
{
meta_string = "{\"availabilityZone\" : \"us-west-2a\", \"imageId\" : \"ami-03cca83dd001d4d11\", \"instanceId\" : \"i-07a181803de94c477\", \"instanceType\" : \"t2.xlarge\"}";
return Task.FromResult(meta_string);
}
else
{ // for v2 endpoint
meta_string = "{\"availabilityZone\" : \"us-east-2a\", \"imageId\" : \"ami-03cca83dd001d4666\", \"instanceId\" : \"i-07a181803de94c666\", \"instanceType\" : \"t3.xlarge\"}";
return Task.FromResult(meta_string);
}
}

}

}

0 comments on commit 1b90091

Please sign in to comment.