Skip to content

Commit

Permalink
add basic field validation on login path (#153)
Browse files Browse the repository at this point in the history
  • Loading branch information
kpcraig authored Jun 7, 2024
1 parent a68fae7 commit da893d8
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 6 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
## Unreleased

IMPROVEMENTS:
* Add login field validation for subscription id, resource group name, vmss name, and vm name

## v0.18.0

FEATURES:
Expand Down
12 changes: 12 additions & 0 deletions azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"io"
"net/http"
"os"
"regexp"
"strings"
"time"

Expand Down Expand Up @@ -447,3 +448,14 @@ func graphURIFromName(name string) (string, error) {

return c, nil
}

// guidRx from https://learn.microsoft.com/en-us/rest/api/defenderforcloud/tasks/get-subscription-level-task
var guidRx = regexp.MustCompile(`^[0-9A-Fa-f]{8}-([0-9A-Fa-f]{4}-){3}[0-9A-Fa-f]{12}$`)
var nameRx = regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9\-]*$`)
var rgRx = regexp.MustCompile(`^[\-_.\pL\pN]*[\-_\pL\pN]$`)

// verify the field provided matches Azure's requirements
// (see: https://learn.microsoft.com/en-us/azure/azure-resource-manager/management/resource-name-rules).
func validateAzureField(regex *regexp.Regexp, value string) bool {
return regex.MatchString(value)
}
57 changes: 57 additions & 0 deletions azure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ import (
"encoding/base64"
"errors"
"fmt"
"regexp"
"strings"
"testing"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"

Expand Down Expand Up @@ -188,3 +190,58 @@ func (p *mockProvider) ProvidersClient(subscriptionID string) (client.ProvidersC
providersClientFunc: p.providersClientFunc,
}, nil
}

func TestValidationRegex(t *testing.T) {
cases := []struct {
name string
in string
regex *regexp.Regexp
isMatch bool
}{
{
name: "normal subscriptionID",
in: "1234abcd-1234-1234-defa-5678fedc90ba",
regex: guidRx,
isMatch: true,
},
{
name: "bad subscriptionID",
in: "xyzg..",
regex: guidRx,
isMatch: false,
},
{
name: "tricky name",
in: "real/../../secret/top-secret",
regex: nameRx,
isMatch: false,
},
{
name: "valid name",
in: "this-name-is-good-14",
regex: nameRx,
isMatch: true,
},
{
name: "tricky resource group",
in: "real/../../secret/top-secret",
regex: rgRx,
isMatch: false,
},
{
name: "non-ascii resource group",
in: "сыноо",
regex: rgRx,
isMatch: true,
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
out := validateAzureField(tc.regex, tc.in)
if tc.isMatch != out {
t.Fail()
}
})
}
}
13 changes: 13 additions & 0 deletions path_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,19 @@ func (b *azureAuthBackend) pathLogin(ctx context.Context, req *logical.Request,
vmName := data.Get("vm_name").(string)
resourceID := data.Get("resource_id").(string)

if subscriptionID != "" && !validateAzureField(guidRx, subscriptionID) {
return logical.ErrorResponse(fmt.Sprintf("invalid subscription id %q", subscriptionID)), nil
}
if resourceGroupName != "" && !validateAzureField(rgRx, resourceGroupName) {
return logical.ErrorResponse(fmt.Sprintf("invalid resource group name %q", resourceGroupName)), nil
}
if vmssName != "" && !validateAzureField(nameRx, vmssName) {
return logical.ErrorResponse(fmt.Sprintf("invalid vmss_name %q", vmssName)), nil
}
if vmName != "" && !validateAzureField(nameRx, vmName) {
return logical.ErrorResponse(fmt.Sprintf("invalid vm name %q", vmName)), nil
}

config, err := b.config(ctx, req.Storage)
if err != nil {
return nil, fmt.Errorf("unable to retrieve backend configuration: %w", err)
Expand Down
53 changes: 47 additions & 6 deletions path_login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ func TestLogin_BoundSubscriptionID(t *testing.T) {
b, s := getTestBackendWithComputeClient(t, c, v, m, nil, g)

roleName := "testrole"
subID := "subID"
subID := "1234abcd-1234-abcd-1234-abcd1234ef90"
roleData := map[string]interface{}{
"name": roleName,
"policies": []string{"dev", "prod"},
Expand Down Expand Up @@ -551,7 +551,7 @@ func TestLogin_BoundResourceGroup(t *testing.T) {
}
testLoginFailure(t, b, s, loginData, claims, roleData)

loginData["subscription_id"] = "sub"
loginData["subscription_id"] = "1234abcd-1234-abcd-1234-abcd1234ef90"
testLoginFailure(t, b, s, loginData, claims, roleData)

loginData["resource_group_name"] = rg
Expand Down Expand Up @@ -602,7 +602,7 @@ func TestLogin_BoundResourceGroupWithUserAssignedID(t *testing.T) {
}
testLoginFailure(t, b, s, loginData, claims, roleData)

loginData["subscription_id"] = "sub"
loginData["subscription_id"] = "1234abcd-1234-abcd-1234-abcd1234ef90"
testLoginFailure(t, b, s, loginData, claims, roleData)

loginData["resource_group_name"] = rg
Expand Down Expand Up @@ -648,7 +648,7 @@ func TestLogin_BoundLocation(t *testing.T) {
}
testLoginFailure(t, b, s, loginData, claims, roleData)

loginData["subscription_id"] = "sub"
loginData["subscription_id"] = "1234abcd-1234-abcd-1234-abcd1234abcd"
loginData["resource_group_name"] = "rg"

loginData["vmss_name"] = "good"
Expand Down Expand Up @@ -693,7 +693,7 @@ func TestLogin_BoundScaleSet(t *testing.T) {
}
testLoginFailure(t, b, s, loginData, claims, roleData)

loginData["subscription_id"] = "sub"
loginData["subscription_id"] = "1234abcd-1234-abcd-1234-abcd1234ef90"
loginData["resource_group_name"] = "rg"

loginData["vmss_name"] = "goodvmss"
Expand Down Expand Up @@ -753,14 +753,55 @@ func TestLogin_AppID(t *testing.T) {
testLoginFailure(t, b, s, loginData, claims, roleData)

loginData["resource_group_name"] = resourceGroup
loginData["subscription_id"] = "sub"
loginData["subscription_id"] = "1234abcd-1234-abcd-1234-abcd1234ef90"
loginData["vmss_name"] = "vmss"
testLoginSuccess(t, b, s, loginData, claims, roleData)

claims["appid"] = badID
testLoginFailure(t, b, s, loginData, claims, roleData)
}

func TestLogin_InvalidCharacters(t *testing.T) {
b, s := getTestBackend(t)

roleName := "testrole"
roleData := map[string]interface{}{
"name": roleName,
"policies": []string{"dev", "prod"},
"bound_service_principal_ids": []string{"*"},
}
testRoleCreate(t, b, s, roleData)

claims := map[string]interface{}{
"exp": time.Now().Add(60 * time.Second).Unix(),
"nbf": time.Now().Add(-60 * time.Second).Unix(),
}

loginData := map[string]interface{}{
"role": roleName,
"subscription_id": "1234abcd-1234-1234-abcd-abcd1234abcd",
"vmss_name": "vmss",
"vm_name": "vm",
"resource_group": "rg",
}
testLoginSuccess(t, b, s, loginData, claims, roleData)

loginData["subscription_id"] = ".." // illegal
testLoginFailure(t, b, s, loginData, claims, roleData)

loginData["subscription_id"] = "1234abcd-1234-1234-abcd-1234-abcd1234abcd"
loginData["vmss_name"] = "a/../b"
testLoginFailure(t, b, s, loginData, claims, roleData)

loginData["vmss_name"] = "vmss"
loginData["vm_name"] = "../a"
testLoginFailure(t, b, s, loginData, claims, roleData)

loginData["vm_name"] = "vm"
loginData["resource_group_name"] = "a/../b"
testLoginFailure(t, b, s, loginData, claims, roleData)
}

func testLoginSuccess(t *testing.T, b *azureAuthBackend, s logical.Storage, loginData, claims, roleData map[string]interface{}) {
t.Helper()
if err := testLoginWithClaims(t, b, s, loginData, claims, roleData); err != nil {
Expand Down

0 comments on commit da893d8

Please sign in to comment.