Skip to content

Commit

Permalink
feat: add support for ingress path params
Browse files Browse the repository at this point in the history
  • Loading branch information
wesbillman committed Dec 1, 2023
1 parent 338989b commit 64d57e1
Show file tree
Hide file tree
Showing 11 changed files with 450 additions and 50 deletions.
47 changes: 22 additions & 25 deletions backend/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package controller

import (
"context"
"encoding/json"
"fmt"
"io"
"math/rand"
Expand Down Expand Up @@ -164,7 +163,7 @@ func New(ctx context.Context, db *dal.DAL, config Config, runnerScaling scaling.
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
logger := log.FromContext(r.Context())
logger.Infof("%s %s", r.Method, r.URL.Path)
routes, err := s.dal.GetIngressRoutes(r.Context(), r.Method, r.URL.Path)
route, err := s.getIngressRoute(r.Context(), r.Method, r.URL.Path)
if err != nil {
if errors.Is(err, dal.ErrNotFound) {
http.NotFound(w, r)
Expand All @@ -173,37 +172,35 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
route := routes[rand.Intn(len(routes))] //nolint:gosec
var body []byte
switch r.Method {
case http.MethodPost, http.MethodPut:
body, err = io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

default:
// TODO: Transcode query parameters into JSON.
payload := map[string]string{}
for key, value := range r.URL.Query() {
payload[key] = value[len(value)-1]
}
body, err = json.Marshal(payload)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
deployments, err := s.dal.GetActiveDeployments(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
requestName, err := s.dal.CreateIngressRequest(r.Context(), fmt.Sprintf("%s %s", r.Method, r.URL.Path), r.RemoteAddr)
sch := &schema.Schema{
Modules: slices.Map(deployments, func(d dal.Deployment) *schema.Module {
return d.Schema
}),
}

body, err := s.validateAndExtractBody(route, r, sch)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

creq := connect.NewRequest(&ftlv1.CallRequest{
Verb: &schemapb.VerbRef{Module: route.Module, Name: route.Verb},
Body: body,
Metadata: &ftlv1.Metadata{},
Verb: &schemapb.VerbRef{Module: route.Module, Name: route.Verb},
Body: body,
})

requestName, err := s.dal.CreateIngressRequest(r.Context(), fmt.Sprintf("%s %s", r.Method, r.URL.Path), r.RemoteAddr)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
headers.SetRequestName(creq.Header(), requestName)
resp, err := s.Call(r.Context(), creq)
if err != nil {
Expand Down
24 changes: 14 additions & 10 deletions backend/controller/dal/dal.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,12 @@ var (
)

type IngressRoute struct {
Runner model.RunnerKey
Endpoint string
Module string
Verb string
Runner model.RunnerKey
Deployment model.DeploymentName
Endpoint string
Path string
Module string
Verb string
}

type IngressRouteEntry struct {
Expand Down Expand Up @@ -936,8 +938,8 @@ func (d *DAL) CreateIngressRequest(ctx context.Context, route, addr string) (mod
return name, errors.WithStack(err)
}

func (d *DAL) GetIngressRoutes(ctx context.Context, method string, path string) ([]IngressRoute, error) {
routes, err := d.db.GetIngressRoutes(ctx, method, path)
func (d *DAL) GetIngressRoutes(ctx context.Context, method string) ([]IngressRoute, error) {
routes, err := d.db.GetIngressRoutes(ctx, method)
if err != nil {
return nil, errors.WithStack(translatePGError(err))
}
Expand All @@ -946,10 +948,12 @@ func (d *DAL) GetIngressRoutes(ctx context.Context, method string, path string)
}
return slices.Map(routes, func(row sql.GetIngressRoutesRow) IngressRoute {
return IngressRoute{
Runner: model.RunnerKey(row.RunnerKey),
Endpoint: row.Endpoint,
Module: row.Module,
Verb: row.Verb,
Runner: model.RunnerKey(row.RunnerKey),
Deployment: row.DeploymentName,
Endpoint: row.Endpoint,
Path: row.Path,
Module: row.Module,
Verb: row.Verb,
}
}), nil
}
Expand Down
230 changes: 230 additions & 0 deletions backend/controller/ingress.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
package controller

import (
"context"
"encoding/json"
"fmt"
"math/rand"
"net/http"
"reflect"
"strconv"
"strings"

"github.com/TBD54566975/ftl/backend/common/slices"
"github.com/TBD54566975/ftl/backend/controller/dal"
"github.com/TBD54566975/ftl/backend/schema"
"github.com/alecthomas/errors"
)

func (s *Service) getIngressRoute(ctx context.Context, method string, path string) (*dal.IngressRoute, error) {
routes, err := s.dal.GetIngressRoutes(ctx, method)
if err != nil {
return nil, err
}
var matchedRoutes = slices.Filter(routes, func(route dal.IngressRoute) bool {
return matchAndExtractAllSegments(route.Path, path, func(segment, value string) {})
})

if len(matchedRoutes) == 0 {
return nil, dal.ErrNotFound
}

// TODO: add load balancing at some point
route := matchedRoutes[rand.Intn(len(matchedRoutes))] //nolint:gosec
return &route, nil
}

func matchAndExtractAllSegments(pattern, urlPath string, onMatch func(segment, value string)) bool {
patternSegments := strings.Split(strings.Trim(pattern, "/"), "/")
urlSegments := strings.Split(strings.Trim(urlPath, "/"), "/")

if len(patternSegments) != len(urlSegments) {
return false
}

for i, segment := range patternSegments {
if segment == "" && urlSegments[i] == "" {
continue // Skip empty segments
}

if strings.HasPrefix(segment, "{") && strings.HasSuffix(segment, "}") {
key := strings.Trim(segment, "{}") // Dynamic segment
onMatch(key, urlSegments[i])
} else if segment != urlSegments[i] {
return false
}
}
return true
}

func (s *Service) validateAndExtractBody(route *dal.IngressRoute, r *http.Request, sch *schema.Schema) ([]byte, error) {
requestMap, err := buildRequestMap(route, r)
if err != nil {
return nil, errors.WithStack(err)
}

verb := sch.ResolveVerbRef(&schema.VerbRef{Name: route.Verb, Module: route.Module})
if verb == nil {
return nil, errors.Errorf("unknown verb %s", route.Verb)
}

dataRef := verb.Request
if dataRef.Module == "" {
dataRef.Module = route.Module
}

err = validateRequestMap(dataRef, "", requestMap, sch)
if err != nil {
return nil, errors.WithStack(err)
}

body, err := json.Marshal(requestMap)
if err != nil {
return nil, errors.WithStack(err)
}

return body, nil
}

func buildRequestMap(route *dal.IngressRoute, r *http.Request) (map[string]any, error) {
requestMap := map[string]any{}
matchAndExtractAllSegments(route.Path, r.URL.Path, func(segment, value string) {
requestMap[segment] = value
})

switch r.Method {
case http.MethodPost, http.MethodPut:
var bodyMap map[string]any
err := json.NewDecoder(r.Body).Decode(&bodyMap)
if err != nil {
return nil, errors.WithStack(err)
}

// Merge bodyMap into params
for k, v := range bodyMap {
requestMap[k] = v
}
default:
// TODO: Support query params correctly for map and array
for key, value := range r.URL.Query() {
requestMap[key] = value[len(value)-1]
}
}

return requestMap, nil
}

func validateRequestMap(dataRef *schema.DataRef, parentName string, request map[string]any, sch *schema.Schema) error {
data := sch.ResolveDataRef(dataRef)
if data == nil {
return errors.Errorf("unknown data %v", dataRef)
}

for _, field := range data.Fields {
err := validateField(field.Name, parentName, field.Type, request, sch)
if err != nil {
return errors.WithStack(err)
}
}

return nil
}

func validateField(fieldName string, parentFieldName string, fieldType schema.Type, request map[string]any, sch *schema.Schema) error {
value, ok := request[fieldName]
if !ok {
if parentFieldName == "" {
return errors.Errorf("missing field %s", fieldName)
}
return errors.Errorf("missing field %s.%s", parentFieldName, fieldName)
}

var typeMatches bool
switch fieldType := fieldType.(type) {
case *schema.Int:
switch value := value.(type) {
case float64:
typeMatches = true
case string:
if _, err := strconv.ParseInt(value, 10, 64); err == nil {
typeMatches = true
}
}
case *schema.Float:
switch value := value.(type) {
case float64:
typeMatches = true
case string:
if _, err := strconv.ParseFloat(value, 64); err == nil {
typeMatches = true
}
}
case *schema.String:
_, typeMatches = value.(string)
case *schema.Bool:
switch value := value.(type) {
case bool:
typeMatches = true
case string:
if _, err := strconv.ParseBool(value); err == nil {
typeMatches = true
}
}
case *schema.Array:
rv := reflect.ValueOf(value)
if rv.Kind() != reflect.Slice {
return errors.Errorf("field %s is not a slice", fieldName)
}
elementType := fieldType.Element
for i := 0; i < rv.Len(); i++ {
elem := rv.Index(i).Interface()
elemFieldName := fmt.Sprintf("%s[%d]", fieldName, i)
if err := validateField(elemFieldName, fieldName, elementType, map[string]any{elemFieldName: elem}, sch); err != nil {
return err
}
}
typeMatches = true
case *schema.Map:
rv := reflect.ValueOf(value)
if rv.Kind() != reflect.Map {
return errors.Errorf("field %s is not a map", fieldName)
}
keyType := fieldType.Key
valueType := fieldType.Value
for _, key := range rv.MapKeys() {
elem := rv.MapIndex(key).Interface()
elemFieldName := fmt.Sprintf("%s[%v]", fieldName, key)
if err := validateField(elemFieldName, fieldName, keyType, map[string]any{elemFieldName: key.Interface()}, sch); err != nil {
return err
}
if err := validateField(elemFieldName, fieldName, valueType, map[string]any{elemFieldName: elem}, sch); err != nil {
return err
}
}
typeMatches = true
case *schema.DataRef:
if valueMap, ok := value.(map[string]any); ok {
// HACK to get around schema extraction issues.
if fieldType.Module == "" {
if fieldType.Name == "Address" {
fieldType.Module = "shipping"
}
if fieldType.Name == "CreditCardInfo" {
fieldType.Module = "payment"
}
}
if err := validateRequestMap(fieldType, fieldName, valueMap, sch); err != nil {
return err
}
typeMatches = true
}

default:
return errors.Errorf("field %s has unsupported type %T", fieldName, value)
}

if !typeMatches {
return errors.Errorf("field %s has wrong type. expected %s found %T", fieldName, fieldType, value)
}
return nil
}
Loading

0 comments on commit 64d57e1

Please sign in to comment.