Skip to content

Commit

Permalink
chore: refactor parameter parsing in ListIdentities and disallow comb…
Browse files Browse the repository at this point in the history
…ining filters
  • Loading branch information
alnr committed Dec 12, 2024
1 parent 8cbb5bd commit 9b7bfad
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 67 deletions.
127 changes: 69 additions & 58 deletions identity/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,7 @@ func (h *Handler) RegisterAdminRoutes(admin *x.RouterAdmin) {
// Paginated Identity List Response
//
// swagger:response listIdentities
//
//nolint:deadcode,unused
//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type listIdentitiesResponse struct {
type _ struct {
migrationpagination.ResponseHeaderAnnotation

// List of identities
Expand All @@ -133,11 +130,10 @@ type listIdentitiesResponse struct {

// Paginated List Identity Parameters
//
// swagger:parameters listIdentities
// Note: Filters cannot be combined.
//
//nolint:deadcode,unused
//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type listIdentitiesParameters struct {
// swagger:parameters listIdentities
type _ struct {
migrationpagination.RequestParameters

// List of ids used to filter identities.
Expand Down Expand Up @@ -183,11 +179,73 @@ type listIdentitiesParameters struct {
crdbx.ConsistencyRequestParameters
}

func parseListIdentitiesParameters(r *http.Request) (params ListIdentityParameters, err error) {
query := r.URL.Query()
var requestedFilters int

params.Expand = ExpandDefault

if ids := query["ids"]; len(ids) > 0 {
requestedFilters++
for _, v := range ids {
id, err := uuid.FromString(v)
if err != nil {
return params, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Invalid UUID value `%s` for parameter `ids`.", v))
}
params.IdsFilter = append(params.IdsFilter, id)
}
}
if len(params.IdsFilter) > 500 {
return params, errors.WithStack(herodot.ErrBadRequest.WithReason("The number of ids to filter must not exceed 500."))
}

if orgID := query.Get("organization_id"); orgID != "" {
requestedFilters++
params.OrganizationID, err = uuid.FromString(orgID)
if err != nil {
return params, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Invalid UUID value `%s` for parameter `organization_id`.", orgID))
}
}

if identifier := query.Get("credentials_identifier"); identifier != "" {
requestedFilters++
params.Expand = ExpandEverything
params.CredentialsIdentifier = identifier
}

if identifier := query.Get("credentials_identifier_similar"); identifier != "" {
requestedFilters++
params.Expand = ExpandEverything
params.CredentialsIdentifierSimilar = identifier
}

for _, v := range query["include_credential"] {
params.Expand = ExpandEverything
tc, ok := ParseCredentialsType(v)
if !ok {
return params, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Invalid value `%s` for parameter `include_credential`.", v))
}
params.DeclassifyCredentials = append(params.DeclassifyCredentials, tc)
}

if requestedFilters > 1 {
return params, errors.WithStack(herodot.ErrBadRequest.WithReason("You cannot combine multiple filters in this API"))
}

params.KeySetPagination, params.PagePagination, err = x.ParseKeysetOrPagePagination(r)
if err != nil {
return params, err

Check warning on line 237 in identity/handler.go

View check run for this annotation

Codecov / codecov/patch

identity/handler.go#L237

Added line #L237 was not covered by tests
}
params.ConsistencyLevel = crdbx.ConsistencyLevelFromRequest(r)

return params, nil
}

// swagger:route GET /admin/identities identity listIdentities
//
// # List Identities
//
// Lists all [identities](https://www.ory.sh/docs/kratos/concepts/identity-user-model) in the system.
// Lists all [identities](https://www.ory.sh/docs/kratos/concepts/identity-user-model) in the system. Note: filters cannot be combined.
//
// Produces:
// - application/json
Expand All @@ -201,54 +259,7 @@ type listIdentitiesParameters struct {
// 200: listIdentities
// default: errorGeneric
func (h *Handler) list(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
includeCredentials := r.URL.Query()["include_credential"]
var err error
var declassify []CredentialsType
for _, v := range includeCredentials {
tc, ok := ParseCredentialsType(v)
if ok {
declassify = append(declassify, tc)
} else {
h.r.Writer().WriteError(w, r, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Invalid value `%s` for parameter `include_credential`.", declassify)))
return
}
}

var orgId uuid.UUID
if orgIdStr := r.URL.Query().Get("organization_id"); orgIdStr != "" {
orgId, err = uuid.FromString(r.URL.Query().Get("organization_id"))
if err != nil {
h.r.Writer().WriteError(w, r, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Invalid UUID value `%s` for parameter `organization_id`.", r.URL.Query().Get("organization_id"))))
return
}
}
var idsFilter []uuid.UUID
for _, v := range r.URL.Query()["ids"] {
id, err := uuid.FromString(v)
if err != nil {
h.r.Writer().WriteError(w, r, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Invalid UUID value `%s` for parameter `ids`.", v)))
return
}
idsFilter = append(idsFilter, id)
}

params := ListIdentityParameters{
Expand: ExpandDefault,
IdsFilter: idsFilter,
CredentialsIdentifier: r.URL.Query().Get("credentials_identifier"),
CredentialsIdentifierSimilar: r.URL.Query().Get("preview_credentials_identifier_similar"),
OrganizationID: orgId,
ConsistencyLevel: crdbx.ConsistencyLevelFromRequest(r),
DeclassifyCredentials: declassify,
}
if params.CredentialsIdentifier != "" && params.CredentialsIdentifierSimilar != "" {
h.r.Writer().WriteError(w, r, herodot.ErrBadRequest.WithReason("Cannot pass both credentials_identifier and preview_credentials_identifier_similar."))
return
}
if params.CredentialsIdentifier != "" || params.CredentialsIdentifierSimilar != "" || len(params.DeclassifyCredentials) > 0 {
params.Expand = ExpandEverything
}
params.KeySetPagination, params.PagePagination, err = x.ParseKeysetOrPagePagination(r)
params, err := parseListIdentitiesParameters(r)
if err != nil {
h.r.Writer().WriteError(w, r, err)
return
Expand All @@ -271,7 +282,7 @@ func (h *Handler) list(w http.ResponseWriter, r *http.Request, _ httprouter.Para
}
u := *r.URL
pagepagination.PaginationHeader(w, &u, total, params.PagePagination.Page, params.PagePagination.ItemsPerPage)
} else {
} else if nextPage != nil {
u := *r.URL
keysetpagination.Header(w, &u, nextPage)
}
Expand Down
35 changes: 32 additions & 3 deletions identity/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -369,21 +369,50 @@ func TestHandler(t *testing.T) {
id := x.ParseUUID(res.Get("id").String())
ids = append(ids, id)
}
require.Equal(t, len(ids), identitiesAmount)
require.Len(t, ids, identitiesAmount)
})

t.Run("case=list few identities", func(t *testing.T) {
url := "/identities?ids=" + ids[0].String()
url := "/identities?ids=" + ids[0].String() + "&ids=" + ids[0].String() // duplicate ID is deduplicated in result
for i := 1; i < listAmount; i++ {
url += "&ids=" + ids[i].String()
}
res := get(t, adminTS, url, http.StatusOK)

identities := res.Array()
require.Equal(t, len(identities), listAmount)
require.Len(t, identities, listAmount)
})
})

t.Run("case=list identities by ID is capped at 500", func(t *testing.T) {
url := "/identities?ids=" + x.NewUUID().String()
for i := 0; i < 501; i++ {
url += "&ids=" + x.NewUUID().String()
}
res := get(t, adminTS, url, http.StatusBadRequest)
assert.Contains(t, res.Get("error.reason").String(), "must not exceed 500")
})

t.Run("case=list identities cannot combine filters", func(t *testing.T) {
filters := []string{
"ids=" + x.NewUUID().String(),
"[email protected]",
"credentials_identifier_similar=bar.com",
"organization_id=" + x.NewUUID().String(),
}
for i := range filters {
for j := range filters {
if i == j {
continue // OK to use the same filter multiple times. Behavior varies by filter, though.
}

url := "/identities?" + filters[i] + "&" + filters[j]
res := get(t, adminTS, url, http.StatusBadRequest)
assert.Contains(t, res.Get("error.reason").String(), "cannot combine multiple filters")
}
}
})

t.Run("case=malformed ids should return an error", func(t *testing.T) {
res := get(t, adminTS, "/identities?ids=not-a-uuid", http.StatusBadRequest)
assert.Contains(t, res.Get("error.reason").String(), "Invalid UUID value `not-a-uuid` for parameter `ids`.", "%s", res.Raw)
Expand Down
4 changes: 2 additions & 2 deletions internal/client-go/api_identity.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions internal/httpclient/api_identity.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion spec/api.json
Original file line number Diff line number Diff line change
Expand Up @@ -3930,7 +3930,7 @@
},
"/admin/identities": {
"get": {
"description": "Lists all [identities](https://www.ory.sh/docs/kratos/concepts/identity-user-model) in the system.",
"description": "Lists all [identities](https://www.ory.sh/docs/kratos/concepts/identity-user-model) in the system. Note: filters cannot be combined.",
"operationId": "listIdentities",
"parameters": [
{
Expand Down
2 changes: 1 addition & 1 deletion spec/swagger.json
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@
"oryAccessToken": []
}
],
"description": "Lists all [identities](https://www.ory.sh/docs/kratos/concepts/identity-user-model) in the system.",
"description": "Lists all [identities](https://www.ory.sh/docs/kratos/concepts/identity-user-model) in the system. Note: filters cannot be combined.",
"produces": [
"application/json"
],
Expand Down

0 comments on commit 9b7bfad

Please sign in to comment.