diff --git a/CHANGELOG.md b/CHANGELOG.md index 3e954c6b..094f0594 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,8 @@ ## Unreleased +IMPROVEMENTS: +* Add login field validation for subscription id, resource group name, vmss name, and vm name + ## v0.18.0 FEATURES: diff --git a/azure.go b/azure.go index 8d10fa94..f25f2ac6 100644 --- a/azure.go +++ b/azure.go @@ -11,6 +11,7 @@ import ( "io" "net/http" "os" + "regexp" "strings" "time" @@ -447,3 +448,14 @@ func graphURIFromName(name string) (string, error) { return c, nil } + +// guidRx from https://learn.microsoft.com/en-us/rest/api/defenderforcloud/tasks/get-subscription-level-task +var guidRx = regexp.MustCompile(`^[0-9A-Fa-f]{8}-([0-9A-Fa-f]{4}-){3}[0-9A-Fa-f]{12}$`) +var nameRx = regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9\-]*$`) +var rgRx = regexp.MustCompile(`^[\-_.\pL\pN]*[\-_\pL\pN]$`) + +// verify the field provided matches Azure's requirements +// (see: https://learn.microsoft.com/en-us/azure/azure-resource-manager/management/resource-name-rules). +func validateAzureField(regex *regexp.Regexp, value string) bool { + return regex.MatchString(value) +} diff --git a/azure_test.go b/azure_test.go index cada0c66..89f170cc 100644 --- a/azure_test.go +++ b/azure_test.go @@ -8,7 +8,9 @@ import ( "encoding/base64" "errors" "fmt" + "regexp" "strings" + "testing" "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" @@ -188,3 +190,58 @@ func (p *mockProvider) ProvidersClient(subscriptionID string) (client.ProvidersC providersClientFunc: p.providersClientFunc, }, nil } + +func TestValidationRegex(t *testing.T) { + cases := []struct { + name string + in string + regex *regexp.Regexp + isMatch bool + }{ + { + name: "normal subscriptionID", + in: "1234abcd-1234-1234-defa-5678fedc90ba", + regex: guidRx, + isMatch: true, + }, + { + name: "bad subscriptionID", + in: "xyzg..", + regex: guidRx, + isMatch: false, + }, + { + name: "tricky name", + in: "real/../../secret/top-secret", + regex: nameRx, + isMatch: false, + }, + { + name: "valid name", + in: "this-name-is-good-14", + regex: nameRx, + isMatch: true, + }, + { + name: "tricky resource group", + in: "real/../../secret/top-secret", + regex: rgRx, + isMatch: false, + }, + { + name: "non-ascii resource group", + in: "сыноо", + regex: rgRx, + isMatch: true, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + out := validateAzureField(tc.regex, tc.in) + if tc.isMatch != out { + t.Fail() + } + }) + } +} diff --git a/path_login.go b/path_login.go index 64053c51..0fc0c55c 100644 --- a/path_login.go +++ b/path_login.go @@ -136,6 +136,19 @@ func (b *azureAuthBackend) pathLogin(ctx context.Context, req *logical.Request, vmName := data.Get("vm_name").(string) resourceID := data.Get("resource_id").(string) + if subscriptionID != "" && !validateAzureField(guidRx, subscriptionID) { + return logical.ErrorResponse(fmt.Sprintf("invalid subscription id %q", subscriptionID)), nil + } + if resourceGroupName != "" && !validateAzureField(rgRx, resourceGroupName) { + return logical.ErrorResponse(fmt.Sprintf("invalid resource group name %q", resourceGroupName)), nil + } + if vmssName != "" && !validateAzureField(nameRx, vmssName) { + return logical.ErrorResponse(fmt.Sprintf("invalid vmss_name %q", vmssName)), nil + } + if vmName != "" && !validateAzureField(nameRx, vmName) { + return logical.ErrorResponse(fmt.Sprintf("invalid vm name %q", vmName)), nil + } + config, err := b.config(ctx, req.Storage) if err != nil { return nil, fmt.Errorf("unable to retrieve backend configuration: %w", err) diff --git a/path_login_test.go b/path_login_test.go index 6e955640..9f7bd3cd 100644 --- a/path_login_test.go +++ b/path_login_test.go @@ -487,7 +487,7 @@ func TestLogin_BoundSubscriptionID(t *testing.T) { b, s := getTestBackendWithComputeClient(t, c, v, m, nil, g) roleName := "testrole" - subID := "subID" + subID := "1234abcd-1234-abcd-1234-abcd1234ef90" roleData := map[string]interface{}{ "name": roleName, "policies": []string{"dev", "prod"}, @@ -551,7 +551,7 @@ func TestLogin_BoundResourceGroup(t *testing.T) { } testLoginFailure(t, b, s, loginData, claims, roleData) - loginData["subscription_id"] = "sub" + loginData["subscription_id"] = "1234abcd-1234-abcd-1234-abcd1234ef90" testLoginFailure(t, b, s, loginData, claims, roleData) loginData["resource_group_name"] = rg @@ -602,7 +602,7 @@ func TestLogin_BoundResourceGroupWithUserAssignedID(t *testing.T) { } testLoginFailure(t, b, s, loginData, claims, roleData) - loginData["subscription_id"] = "sub" + loginData["subscription_id"] = "1234abcd-1234-abcd-1234-abcd1234ef90" testLoginFailure(t, b, s, loginData, claims, roleData) loginData["resource_group_name"] = rg @@ -648,7 +648,7 @@ func TestLogin_BoundLocation(t *testing.T) { } testLoginFailure(t, b, s, loginData, claims, roleData) - loginData["subscription_id"] = "sub" + loginData["subscription_id"] = "1234abcd-1234-abcd-1234-abcd1234abcd" loginData["resource_group_name"] = "rg" loginData["vmss_name"] = "good" @@ -693,7 +693,7 @@ func TestLogin_BoundScaleSet(t *testing.T) { } testLoginFailure(t, b, s, loginData, claims, roleData) - loginData["subscription_id"] = "sub" + loginData["subscription_id"] = "1234abcd-1234-abcd-1234-abcd1234ef90" loginData["resource_group_name"] = "rg" loginData["vmss_name"] = "goodvmss" @@ -753,7 +753,7 @@ func TestLogin_AppID(t *testing.T) { testLoginFailure(t, b, s, loginData, claims, roleData) loginData["resource_group_name"] = resourceGroup - loginData["subscription_id"] = "sub" + loginData["subscription_id"] = "1234abcd-1234-abcd-1234-abcd1234ef90" loginData["vmss_name"] = "vmss" testLoginSuccess(t, b, s, loginData, claims, roleData) @@ -761,6 +761,47 @@ func TestLogin_AppID(t *testing.T) { testLoginFailure(t, b, s, loginData, claims, roleData) } +func TestLogin_InvalidCharacters(t *testing.T) { + b, s := getTestBackend(t) + + roleName := "testrole" + roleData := map[string]interface{}{ + "name": roleName, + "policies": []string{"dev", "prod"}, + "bound_service_principal_ids": []string{"*"}, + } + testRoleCreate(t, b, s, roleData) + + claims := map[string]interface{}{ + "exp": time.Now().Add(60 * time.Second).Unix(), + "nbf": time.Now().Add(-60 * time.Second).Unix(), + } + + loginData := map[string]interface{}{ + "role": roleName, + "subscription_id": "1234abcd-1234-1234-abcd-abcd1234abcd", + "vmss_name": "vmss", + "vm_name": "vm", + "resource_group": "rg", + } + testLoginSuccess(t, b, s, loginData, claims, roleData) + + loginData["subscription_id"] = ".." // illegal + testLoginFailure(t, b, s, loginData, claims, roleData) + + loginData["subscription_id"] = "1234abcd-1234-1234-abcd-1234-abcd1234abcd" + loginData["vmss_name"] = "a/../b" + testLoginFailure(t, b, s, loginData, claims, roleData) + + loginData["vmss_name"] = "vmss" + loginData["vm_name"] = "../a" + testLoginFailure(t, b, s, loginData, claims, roleData) + + loginData["vm_name"] = "vm" + loginData["resource_group_name"] = "a/../b" + testLoginFailure(t, b, s, loginData, claims, roleData) +} + func testLoginSuccess(t *testing.T, b *azureAuthBackend, s logical.Storage, loginData, claims, roleData map[string]interface{}) { t.Helper() if err := testLoginWithClaims(t, b, s, loginData, claims, roleData); err != nil {