-
Notifications
You must be signed in to change notification settings - Fork 84
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use retrable http client in Azure authz provider
Signed-off-by: Bin Xia <[email protected]>
- Loading branch information
1 parent
286709a
commit da82048
Showing
5 changed files
with
134 additions
and
53 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,9 +18,13 @@ package azure | |
|
||
import ( | ||
"context" | ||
"fmt" | ||
"io/fs" | ||
"net" | ||
"net/http" | ||
"net/http/httptest" | ||
"os" | ||
"strconv" | ||
"testing" | ||
"time" | ||
|
||
|
@@ -32,12 +36,14 @@ import ( | |
errutils "go.kubeguard.dev/guard/util/error" | ||
|
||
"github.com/go-chi/chi/v5" | ||
"github.com/google/uuid" | ||
"github.com/stretchr/testify/assert" | ||
authzv1 "k8s.io/api/authorization/v1" | ||
) | ||
|
||
const ( | ||
loginResp = `{ "token_type": "Bearer", "expires_in": 8459, "access_token": "%v"}` | ||
loginResp = `{ "token_type": "Bearer", "expires_in": 8459, "access_token": "%v"}` | ||
httpClientRetryCount = 2 | ||
) | ||
|
||
func clientSetup(serverUrl, mode string) (*Authorizer, error) { | ||
|
@@ -52,9 +58,10 @@ func clientSetup(serverUrl, mode string) (*Authorizer, error) { | |
} | ||
|
||
authOpts := auth.Options{ | ||
ClientID: "client_id", | ||
ClientSecret: "client_secret", | ||
TenantID: "tenant_id", | ||
ClientID: "client_id", | ||
ClientSecret: "client_secret", | ||
TenantID: "tenant_id", | ||
HttpClientRetryCount: httpClientRetryCount, | ||
} | ||
|
||
authzInfo := rbac.AuthzInfo{ | ||
|
@@ -70,7 +77,7 @@ func clientSetup(serverUrl, mode string) (*Authorizer, error) { | |
return c, nil | ||
} | ||
|
||
func serverSetup(loginResp, checkaccessResp string, loginStatus, checkaccessStatus int, sleepFor time.Duration) (*httptest.Server, error) { | ||
func serverSetup(loginResp, checkaccessResp string, loginStatus, checkaccessStatus int, sleepFor time.Duration, calledTimesFile string) (*httptest.Server, error) { | ||
listener, err := net.Listen("tcp", "127.0.0.1:") | ||
if err != nil { | ||
return nil, err | ||
|
@@ -85,6 +92,9 @@ func serverSetup(loginResp, checkaccessResp string, loginStatus, checkaccessStat | |
|
||
m.Post("/arm/*", func(w http.ResponseWriter, r *http.Request) { | ||
time.Sleep(sleepFor) | ||
if calledTimesFile != "" { | ||
_ = incCalledTimes(calledTimesFile) | ||
} | ||
w.WriteHeader(checkaccessStatus) | ||
_, _ = w.Write([]byte(checkaccessResp)) | ||
}) | ||
|
@@ -98,8 +108,8 @@ func serverSetup(loginResp, checkaccessResp string, loginStatus, checkaccessStat | |
return srv, nil | ||
} | ||
|
||
func getServerAndClient(t *testing.T, loginResp, checkaccessResp string, checkaccessStatus int, sleepFor time.Duration) (*httptest.Server, *Authorizer, authz.Store) { | ||
srv, err := serverSetup(loginResp, checkaccessResp, http.StatusOK, checkaccessStatus, sleepFor) | ||
func getServerAndClient(t *testing.T, loginResp, checkaccessResp string, checkaccessStatus int, sleepFor time.Duration, calledTimesFile string) (*httptest.Server, *Authorizer, authz.Store) { // nolint: unparam | ||
srv, err := serverSetup(loginResp, checkaccessResp, http.StatusOK, checkaccessStatus, sleepFor, calledTimesFile) | ||
if err != nil { | ||
t.Fatalf("Error when creating server, reason: %v", err) | ||
} | ||
|
@@ -123,13 +133,32 @@ func getServerAndClient(t *testing.T, loginResp, checkaccessResp string, checkac | |
return srv, client, dataStore | ||
} | ||
|
||
func createCalledTimesFile() (string, error) { | ||
calledTimesFile := uuid.New().String() | ||
err := os.WriteFile(calledTimesFile, []byte(strconv.Itoa(0)), fs.ModeTemporary) | ||
if err != nil { | ||
return "", err | ||
} | ||
return calledTimesFile, nil | ||
} | ||
|
||
func incCalledTimes(calledTimesFile string) error { | ||
content, _ := os.ReadFile(calledTimesFile) | ||
calledTimes, _ := strconv.Atoi(string(content)) | ||
return os.WriteFile(calledTimesFile, []byte(strconv.Itoa(calledTimes+1)), fs.ModeTemporary) | ||
} | ||
|
||
func deleteCalledTimesFile(calledTimesFile string) error { | ||
return os.Remove(calledTimesFile) | ||
} | ||
|
||
func TestCheck(t *testing.T) { | ||
t.Run("successful request", func(t *testing.T) { | ||
validBody := `[{"accessDecision":"Allowed", | ||
"actionId":"Microsoft.Kubernetes/connectedClusters/pods/delete", | ||
"isDataAction":true,"roleAssignment":null,"denyAssignment":null,"timeToLiveInMs":300000}]` | ||
|
||
srv, client, store := getServerAndClient(t, loginResp, validBody, http.StatusOK, 1*time.Second) | ||
srv, client, store := getServerAndClient(t, loginResp, validBody, http.StatusOK, 1*time.Second, "") | ||
defer srv.Close() | ||
defer store.Close() | ||
|
||
|
@@ -154,7 +183,7 @@ func TestCheck(t *testing.T) { | |
|
||
t.Run("unsuccessful request", func(t *testing.T) { | ||
validBody := `""` | ||
srv, client, store := getServerAndClient(t, loginResp, validBody, http.StatusInternalServerError, 1*time.Second) | ||
srv, client, store := getServerAndClient(t, loginResp, validBody, http.StatusInternalServerError, 1*time.Second, "") | ||
defer srv.Close() | ||
defer store.Close() | ||
|
||
|
@@ -170,15 +199,49 @@ func TestCheck(t *testing.T) { | |
resp, err := client.Check(ctx, request, store) | ||
assert.Nilf(t, resp, "response should be nil") | ||
assert.NotNilf(t, err, "should get error") | ||
assert.Contains(t, err.Error(), "Error occured during authorization check") | ||
assert.Contains(t, err.Error(), "Error occured during authorization checkdfdf") | ||
if v, ok := err.(errutils.HttpStatusCode); ok { | ||
assert.Equal(t, v.Code(), http.StatusInternalServerError) | ||
} | ||
}) | ||
|
||
t.Run("unsuccessful request - check retry count", func(t *testing.T) { | ||
calledTimesFile, err := createCalledTimesFile() | ||
assert.Nilf(t, err, "Should not have got error") | ||
|
||
validBody := `""` | ||
srv, client, store := getServerAndClient(t, loginResp, validBody, http.StatusInternalServerError, 1*time.Second, calledTimesFile) | ||
defer srv.Close() | ||
defer store.Close() | ||
|
||
request := &authzv1.SubjectAccessReviewSpec{ | ||
User: "[email protected]", | ||
ResourceAttributes: &authzv1.ResourceAttributes{ | ||
Namespace: "dev", Group: "", Resource: "pods", | ||
Subresource: "status", Version: "v1", Name: "test", Verb: "delete", | ||
}, Extra: map[string]authzv1.ExtraValue{"oid": {"00000000-0000-0000-0000-000000000000"}}, | ||
} | ||
|
||
ctx := context.Background() | ||
resp, err := client.Check(ctx, request, store) | ||
assert.Nilf(t, resp, "response should be nil") | ||
assert.NotNilf(t, err, "should get error") | ||
assert.Contains(t, err.Error(), "Error occured during authorization checkdfdf") | ||
if v, ok := err.(errutils.HttpStatusCode); ok { | ||
assert.Equal(t, v.Code(), http.StatusInternalServerError) | ||
} | ||
|
||
content, _ := os.ReadFile(calledTimesFile) | ||
calledTimes, _ := strconv.Atoi(string(content)) | ||
assert.Equal(t, httpClientRetryCount+1, calledTimes, fmt.Sprintf("The server should be called %d times", httpClientRetryCount+1)) | ||
|
||
err = deleteCalledTimesFile(calledTimesFile) | ||
assert.Nilf(t, err, "Should not have got error") | ||
}) | ||
|
||
t.Run("context timeout request", func(t *testing.T) { | ||
validBody := `""` | ||
srv, client, store := getServerAndClient(t, loginResp, validBody, http.StatusInternalServerError, 25*time.Second) | ||
srv, client, store := getServerAndClient(t, loginResp, validBody, http.StatusInternalServerError, 25*time.Second, "") | ||
defer srv.Close() | ||
defer store.Close() | ||
|
||
|
@@ -194,7 +257,7 @@ func TestCheck(t *testing.T) { | |
resp, err := client.Check(ctx, request, store) | ||
assert.Nilf(t, resp, "response should be nil") | ||
assert.NotNilf(t, err, "should get error") | ||
assert.Contains(t, err.Error(), "Checkaccess requests have timed out") | ||
assert.Contains(t, err.Error(), "context deadline exceeded") | ||
if v, ok := err.(errutils.HttpStatusCode); ok { | ||
assert.Equal(t, v.Code(), http.StatusInternalServerError) | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.