generated from TBD54566975/tbd-project-template
-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add support for ingress path params
- Loading branch information
1 parent
338989b
commit 64d57e1
Showing
11 changed files
with
450 additions
and
50 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,230 @@ | ||
package controller | ||
|
||
import ( | ||
"context" | ||
"encoding/json" | ||
"fmt" | ||
"math/rand" | ||
"net/http" | ||
"reflect" | ||
"strconv" | ||
"strings" | ||
|
||
"github.com/TBD54566975/ftl/backend/common/slices" | ||
"github.com/TBD54566975/ftl/backend/controller/dal" | ||
"github.com/TBD54566975/ftl/backend/schema" | ||
"github.com/alecthomas/errors" | ||
) | ||
|
||
func (s *Service) getIngressRoute(ctx context.Context, method string, path string) (*dal.IngressRoute, error) { | ||
routes, err := s.dal.GetIngressRoutes(ctx, method) | ||
if err != nil { | ||
return nil, err | ||
} | ||
var matchedRoutes = slices.Filter(routes, func(route dal.IngressRoute) bool { | ||
return matchAndExtractAllSegments(route.Path, path, func(segment, value string) {}) | ||
}) | ||
|
||
if len(matchedRoutes) == 0 { | ||
return nil, dal.ErrNotFound | ||
} | ||
|
||
// TODO: add load balancing at some point | ||
route := matchedRoutes[rand.Intn(len(matchedRoutes))] //nolint:gosec | ||
return &route, nil | ||
} | ||
|
||
func matchAndExtractAllSegments(pattern, urlPath string, onMatch func(segment, value string)) bool { | ||
patternSegments := strings.Split(strings.Trim(pattern, "/"), "/") | ||
urlSegments := strings.Split(strings.Trim(urlPath, "/"), "/") | ||
|
||
if len(patternSegments) != len(urlSegments) { | ||
return false | ||
} | ||
|
||
for i, segment := range patternSegments { | ||
if segment == "" && urlSegments[i] == "" { | ||
continue // Skip empty segments | ||
} | ||
|
||
if strings.HasPrefix(segment, "{") && strings.HasSuffix(segment, "}") { | ||
key := strings.Trim(segment, "{}") // Dynamic segment | ||
onMatch(key, urlSegments[i]) | ||
} else if segment != urlSegments[i] { | ||
return false | ||
} | ||
} | ||
return true | ||
} | ||
|
||
func (s *Service) validateAndExtractBody(route *dal.IngressRoute, r *http.Request, sch *schema.Schema) ([]byte, error) { | ||
requestMap, err := buildRequestMap(route, r) | ||
if err != nil { | ||
return nil, errors.WithStack(err) | ||
} | ||
|
||
verb := sch.ResolveVerbRef(&schema.VerbRef{Name: route.Verb, Module: route.Module}) | ||
if verb == nil { | ||
return nil, errors.Errorf("unknown verb %s", route.Verb) | ||
} | ||
|
||
dataRef := verb.Request | ||
if dataRef.Module == "" { | ||
dataRef.Module = route.Module | ||
} | ||
|
||
err = validateRequestMap(dataRef, "", requestMap, sch) | ||
if err != nil { | ||
return nil, errors.WithStack(err) | ||
} | ||
|
||
body, err := json.Marshal(requestMap) | ||
if err != nil { | ||
return nil, errors.WithStack(err) | ||
} | ||
|
||
return body, nil | ||
} | ||
|
||
func buildRequestMap(route *dal.IngressRoute, r *http.Request) (map[string]any, error) { | ||
requestMap := map[string]any{} | ||
matchAndExtractAllSegments(route.Path, r.URL.Path, func(segment, value string) { | ||
requestMap[segment] = value | ||
}) | ||
|
||
switch r.Method { | ||
case http.MethodPost, http.MethodPut: | ||
var bodyMap map[string]any | ||
err := json.NewDecoder(r.Body).Decode(&bodyMap) | ||
if err != nil { | ||
return nil, errors.WithStack(err) | ||
} | ||
|
||
// Merge bodyMap into params | ||
for k, v := range bodyMap { | ||
requestMap[k] = v | ||
} | ||
default: | ||
// TODO: Support query params correctly for map and array | ||
for key, value := range r.URL.Query() { | ||
requestMap[key] = value[len(value)-1] | ||
} | ||
} | ||
|
||
return requestMap, nil | ||
} | ||
|
||
func validateRequestMap(dataRef *schema.DataRef, parentName string, request map[string]any, sch *schema.Schema) error { | ||
data := sch.ResolveDataRef(dataRef) | ||
if data == nil { | ||
return errors.Errorf("unknown data %v", dataRef) | ||
} | ||
|
||
for _, field := range data.Fields { | ||
err := validateField(field.Name, parentName, field.Type, request, sch) | ||
if err != nil { | ||
return errors.WithStack(err) | ||
} | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func validateField(fieldName string, parentFieldName string, fieldType schema.Type, request map[string]any, sch *schema.Schema) error { | ||
value, ok := request[fieldName] | ||
if !ok { | ||
if parentFieldName == "" { | ||
return errors.Errorf("missing field %s", fieldName) | ||
} | ||
return errors.Errorf("missing field %s.%s", parentFieldName, fieldName) | ||
} | ||
|
||
var typeMatches bool | ||
switch fieldType := fieldType.(type) { | ||
case *schema.Int: | ||
switch value := value.(type) { | ||
case float64: | ||
typeMatches = true | ||
case string: | ||
if _, err := strconv.ParseInt(value, 10, 64); err == nil { | ||
typeMatches = true | ||
} | ||
} | ||
case *schema.Float: | ||
switch value := value.(type) { | ||
case float64: | ||
typeMatches = true | ||
case string: | ||
if _, err := strconv.ParseFloat(value, 64); err == nil { | ||
typeMatches = true | ||
} | ||
} | ||
case *schema.String: | ||
_, typeMatches = value.(string) | ||
case *schema.Bool: | ||
switch value := value.(type) { | ||
case bool: | ||
typeMatches = true | ||
case string: | ||
if _, err := strconv.ParseBool(value); err == nil { | ||
typeMatches = true | ||
} | ||
} | ||
case *schema.Array: | ||
rv := reflect.ValueOf(value) | ||
if rv.Kind() != reflect.Slice { | ||
return errors.Errorf("field %s is not a slice", fieldName) | ||
} | ||
elementType := fieldType.Element | ||
for i := 0; i < rv.Len(); i++ { | ||
elem := rv.Index(i).Interface() | ||
elemFieldName := fmt.Sprintf("%s[%d]", fieldName, i) | ||
if err := validateField(elemFieldName, fieldName, elementType, map[string]any{elemFieldName: elem}, sch); err != nil { | ||
return err | ||
} | ||
} | ||
typeMatches = true | ||
case *schema.Map: | ||
rv := reflect.ValueOf(value) | ||
if rv.Kind() != reflect.Map { | ||
return errors.Errorf("field %s is not a map", fieldName) | ||
} | ||
keyType := fieldType.Key | ||
valueType := fieldType.Value | ||
for _, key := range rv.MapKeys() { | ||
elem := rv.MapIndex(key).Interface() | ||
elemFieldName := fmt.Sprintf("%s[%v]", fieldName, key) | ||
if err := validateField(elemFieldName, fieldName, keyType, map[string]any{elemFieldName: key.Interface()}, sch); err != nil { | ||
return err | ||
} | ||
if err := validateField(elemFieldName, fieldName, valueType, map[string]any{elemFieldName: elem}, sch); err != nil { | ||
return err | ||
} | ||
} | ||
typeMatches = true | ||
case *schema.DataRef: | ||
if valueMap, ok := value.(map[string]any); ok { | ||
// HACK to get around schema extraction issues. | ||
if fieldType.Module == "" { | ||
if fieldType.Name == "Address" { | ||
fieldType.Module = "shipping" | ||
} | ||
if fieldType.Name == "CreditCardInfo" { | ||
fieldType.Module = "payment" | ||
} | ||
} | ||
if err := validateRequestMap(fieldType, fieldName, valueMap, sch); err != nil { | ||
return err | ||
} | ||
typeMatches = true | ||
} | ||
|
||
default: | ||
return errors.Errorf("field %s has unsupported type %T", fieldName, value) | ||
} | ||
|
||
if !typeMatches { | ||
return errors.Errorf("field %s has wrong type. expected %s found %T", fieldName, fieldType, value) | ||
} | ||
return nil | ||
} |
Oops, something went wrong.