Skip to content

Commit

Permalink
Retreive subscription by id from server if subscription id is provided
Browse files Browse the repository at this point in the history
  • Loading branch information
msJinLei committed Oct 10, 2022
1 parent 3f3e61b commit ccd58da
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 24 deletions.
30 changes: 17 additions & 13 deletions src/Accounts/Accounts.Test/AzureRMProfileTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -445,9 +445,11 @@ public void MultipleTenantsAndSubscriptionsSucceed()
Assert.Equal(2, tenantResults.Count());
tenantResults = client.ListTenants(DefaultTenant.ToString());
Assert.Single(tenantResults);
IAzureSubscription subValue;
Assert.True(client.TryGetSubscriptionById(DefaultTenant.ToString(), DefaultSubscription.ToString(), out subValue));
Assert.Equal(DefaultSubscription.ToString(), subValue.Id.ToString());
var subValues = client.TryGetSubscriptionById(DefaultTenant.ToString(), DefaultSubscription.ToString());
Assert.True(subValues != null && subValues.Count() > 0);
Assert.Equal(DefaultSubscription.ToString(), subValues.FirstOrDefault().Id.ToString());

IAzureSubscription subValue = null;
Assert.True(client.TryGetSubscriptionByName(DefaultTenant.ToString(),
MockSubscriptionClientFactory.GetSubscriptionNameFromId(DefaultSubscription.ToString()),
out subValue));
Expand All @@ -474,14 +476,14 @@ public void MultipleTenantsSubscriptionListSucceed()
Assert.Equal(2, tenantResults.Count());
tenantResults = client.ListTenants(DefaultTenant.ToString());
Assert.Single(tenantResults);
IAzureSubscription subValue;
IEnumerable<IAzureSubscription> subValueList;
Assert.True(client.TryGetSubscriptionById(DefaultTenant.ToString(), DefaultSubscription.ToString(), out subValue));
Assert.Equal(DefaultSubscription.ToString(), subValue.Id.ToString());
subValueList = client.TryGetSubscriptionById(DefaultTenant.ToString(), DefaultSubscription.ToString());
Assert.True(subValueList != null && subValueList.Count() > 0);
Assert.Equal(DefaultSubscription.ToString(), subValueList.FirstOrDefault().Id.ToString());
Assert.True(client.TryGetSubscriptionListByName(DefaultTenant.ToString(),
MockSubscriptionClientFactory.GetSubscriptionNameFromId(DefaultSubscription.ToString()),
out subValueList));
Assert.Equal(DefaultSubscription.ToString(), subValue.Id.ToString());
Assert.Equal(DefaultSubscription.ToString(), subValueList.FirstOrDefault().Id.ToString());
}

[Fact]
Expand All @@ -500,9 +502,10 @@ public void SingleTenantAndSubscriptionSucceeds()
Assert.Single(tenantResults);
tenantResults = client.ListTenants(DefaultTenant.ToString());
Assert.Single(tenantResults);
var subValues = client.TryGetSubscriptionById(DefaultTenant.ToString(), DefaultSubscription.ToString());
Assert.True(subValues != null && subValues.Count() > 0);
Assert.Equal(DefaultSubscription.ToString(), subValues.FirstOrDefault().Id.ToString());
IAzureSubscription subValue;
Assert.True(client.TryGetSubscriptionById(DefaultTenant.ToString(), DefaultSubscription.ToString(), out subValue));
Assert.Equal(DefaultSubscription.ToString(), subValue.Id.ToString());
Assert.True(client.TryGetSubscriptionByName(DefaultTenant.ToString(),
MockSubscriptionClientFactory.GetSubscriptionNameFromId(DefaultSubscription.ToString()),
out subValue));
Expand Down Expand Up @@ -585,9 +588,9 @@ public void SubscriptionNotFoundDoesNotThrow()
thirdList, fourthList);
var subResults = new List<IAzureSubscription>(client.ListSubscriptions());
Assert.Equal(2, subResults.Count);
var subValues = client.TryGetSubscriptionById(DefaultTenant.ToString(), DefaultSubscription.ToString());
Assert.True(subValues == null || subValues.Count() == 0);
IAzureSubscription subValue;

Assert.False(client.TryGetSubscriptionById(DefaultTenant.ToString(), DefaultSubscription.ToString(), out subValue));
Assert.False(client.TryGetSubscriptionByName("random-tenant", "random-subscription", out subValue));
}

Expand Down Expand Up @@ -629,8 +632,9 @@ public void NoSubscriptionsInListDoesNotThrow()
subscriptions, subscriptions,
subscriptions, subscriptions);
Assert.Empty(client.ListSubscriptions());
IAzureSubscription subValue;
Assert.False(client.TryGetSubscriptionById(DefaultTenant.ToString(), DefaultSubscription.ToString(), out subValue));
var subValues = client.TryGetSubscriptionById(DefaultTenant.ToString(), DefaultSubscription.ToString());
Assert.True(subValues == null || subValues.Count() == 0);
IAzureSubscription subValue = null;
Assert.False(client.TryGetSubscriptionByName(DefaultTenant.ToString(), "random-name", out subValue));
}

Expand Down
41 changes: 34 additions & 7 deletions src/Accounts/Accounts/Models/RMProfileClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ public IAzureContext SetCurrentContext(string subscriptionNameOrId, string tenan
{
if (Guid.TryParse(subscriptionNameOrId, out subscriptionId))
{
TryGetSubscriptionById(tenantId, subscriptionNameOrId, out subscription);
subscription = TryGetSubscriptionById(tenantId, subscriptionNameOrId)?.FirstOrDefault();
}
else
{
Expand Down Expand Up @@ -401,17 +401,44 @@ public List<AzureTenant> ListTenants(string tenant = "")
.ToList();
}

public bool TryGetSubscriptionById(string tenantId, string subscriptionId, out IAzureSubscription subscription)
public IEnumerable<IAzureSubscription> TryGetSubscriptionById(string tenantId, string subscriptionId)
{
Guid subscriptionIdGuid;
subscription = null;
var subscriptions = new List<IAzureSubscription>();
if (Guid.TryParse(subscriptionId, out subscriptionIdGuid))
{
var subscriptionList = ListSubscriptions(tenantId).Where(s => s.GetId() == subscriptionIdGuid);
subscription = subscriptionList.FirstOrDefault(s => s.GetTenant() == s.GetHomeTenant()) ??
subscriptionList.FirstOrDefault();
var tenants = string.IsNullOrEmpty(tenantId) ? ListTenants() : new List<AzureTenant>() { CreateTenant(tenantId) };

IAzureAccount account = _profile.DefaultContext.Account;
IAzureEnvironment environment = _profile.DefaultContext.Environment;
string promptBehavior = ShowDialog.Never;

foreach (var tenant in tenants)
{
try
{
SecureString password = null;
IAccessToken accessToken = null;
try
{
accessToken = AcquireAccessToken(account, environment, tenant.Id, password, promptBehavior, null);
}
catch (Exception e)
{
WriteWarningMessage(string.Format(ProfileMessages.UnableToAqcuireToken, tenantId, e.Message));
WriteDebugMessage(string.Format(ProfileMessages.UnableToAqcuireToken, tenantId, e.ToString()));
continue;
}
subscriptions.Add(SubscriptionAndTenantClient?.GetSubscriptionById(subscriptionId, accessToken, account, environment));
break;
}
catch (CloudException e)
{
WriteDebugMessage(e.ToString());
}
}
}
return subscription != null;
return subscriptions;
}

public bool TryGetSubscriptionByName(string tenantId, string subscriptionName, out IAzureSubscription subscription)
Expand Down
8 changes: 4 additions & 4 deletions src/Accounts/Accounts/Subscription/GetAzureRMSubscription.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,15 @@ public override void ExecuteCmdlet()
}
else if (!string.IsNullOrWhiteSpace(this.SubscriptionId))
{
IAzureSubscription result;
IEnumerable<IAzureSubscription> result = null;
try
{
if (!this._client.TryGetSubscriptionById(TenantId, this.SubscriptionId, out result))
result = this._client.TryGetSubscriptionById(TenantId, this.SubscriptionId);
if (result == null || result.Count() == 0)
{
ThrowSubscriptionNotFoundError(this.TenantId, this.SubscriptionId);
}

WriteObject( new PSAzureSubscription(result));
WriteObject(new PSAzureSubscription(result.FirstOrDefault()));
}
catch (AadAuthenticationException exception)
{
Expand Down

0 comments on commit ccd58da

Please sign in to comment.