Skip to content

Commit

Permalink
Restructure some modules/code (#51)
Browse files Browse the repository at this point in the history
Some refactoring to reduce the size of main.go in preparation of adding
OTEL Tracing instrumentation
  • Loading branch information
aacoba authored Jan 30, 2024
1 parent 0075aec commit 2ed28ba
Show file tree
Hide file tree
Showing 11 changed files with 193 additions and 134 deletions.
147 changes: 17 additions & 130 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,55 +2,34 @@ package main

import (
"context"
"encoding/json"
"errors"
"flag"
"fmt"
"github.com/ardanlabs/conf/v3"
"github.com/ldebruijn/graphql-protect/internal/app/config"
"github.com/ldebruijn/graphql-protect/internal/business/aliases"
"github.com/ldebruijn/graphql-protect/internal/business/batch"
"github.com/ldebruijn/graphql-protect/internal/business/block_field_suggestions"
"github.com/ldebruijn/graphql-protect/internal/business/enforce_post"
"github.com/ldebruijn/graphql-protect/internal/business/gql"
"github.com/ldebruijn/graphql-protect/internal/business/max_depth"
middleware2 "github.com/ldebruijn/graphql-protect/internal/business/middleware"
"github.com/ldebruijn/graphql-protect/internal/business/persisted_operations"
"github.com/ldebruijn/graphql-protect/internal/business/proxy"
"github.com/ldebruijn/graphql-protect/internal/business/readiness"
"github.com/ldebruijn/graphql-protect/internal/business/protect"
"github.com/ldebruijn/graphql-protect/internal/business/schema"
"github.com/ldebruijn/graphql-protect/internal/business/tokens"
"github.com/ldebruijn/graphql-protect/internal/http/middleware"
"github.com/ldebruijn/graphql-protect/internal/http/proxy"
"github.com/ldebruijn/graphql-protect/internal/http/readiness"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/vektah/gqlparser/v2/ast"
"github.com/vektah/gqlparser/v2/gqlerror"
"github.com/vektah/gqlparser/v2/parser"
"github.com/vektah/gqlparser/v2/validator"
log2 "log"
"log/slog"
"net/http"
"net/http/httputil"
"os"
"os/signal"
"runtime"
"syscall"
"time"
)

var (
shortHash = "develop"
build = "develop"
configPath = ""

httpDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: "graphql_protect",
Subsystem: "http",
Name: "duration",
Help: "HTTP duration",
},
[]string{"route"},
)

appInfo = prometheus.NewGaugeVec(prometheus.GaugeOpts{
Namespace: "graphql_protect",
Subsystem: "app",
Expand All @@ -59,11 +38,10 @@ var (
},
[]string{"version", "go_version", "short_hash"},
)
errRedacted = errors.New("error(s) redacted")
)

func init() {
prometheus.MustRegister(httpDuration, appInfo)
prometheus.MustRegister(appInfo)
}

func main() {
Expand Down Expand Up @@ -128,13 +106,19 @@ func run(log *slog.Logger, cfg *config.Config, shutdown chan os.Signal) error {
return nil
}

protectHandler, err := protect.NewGraphQLProtect(log, cfg, po, schemaProvider, pxy)
if err != nil {
log.Error("Error initializing GraphQL Protect", "err", err)
return err
}

mux := http.NewServeMux()

mid := middleware(log, cfg, po, schemaProvider)
mid := ProtectMiddlewareChain(log)

mux.Handle("/metrics", promhttp.Handler())
mux.Handle("/internal/healthz/readiness", readiness.NewReadinessHandler())
mux.Handle(cfg.Web.Path, mid(Handler(pxy)))
mux.Handle(cfg.Web.Path, mid(protectHandler))

api := http.Server{
Addr: cfg.Web.Host,
Expand Down Expand Up @@ -174,110 +158,13 @@ func run(log *slog.Logger, cfg *config.Config, shutdown chan os.Signal) error {
return nil
}

func middleware(log *slog.Logger, cfg *config.Config, po *persisted_operations.PersistedOperationsHandler, schema *schema.Provider) func(next http.Handler) http.Handler {
rec := middleware2.Recover(log)
httpInstrumentation := HTTPInstrumentation()

aliases.NewMaxAliasesRule(cfg.MaxAliases)
max_depth.NewMaxDepthRule(cfg.MaxDepth)
tks := tokens.MaxTokens(cfg.MaxTokens)
maxBatch, err := batch.NewMaxBatch(cfg.MaxBatch)
if err != nil {
log.Warn("Error initializing maximum batch protection", err)
}

vr := ValidationRules(schema, tks, maxBatch, cfg.ObfuscateValidationErrors)
disableMethod := enforce_post.EnforcePostMethod(cfg.EnforcePost)
func ProtectMiddlewareChain(log *slog.Logger) func(next http.Handler) http.Handler {
rec := middleware.Recover(log)
httpInstrumentation := middleware.RequestMetricMiddleware()

fn := func(next http.Handler) http.Handler {
return rec(httpInstrumentation(disableMethod(po.Execute(vr(next)))))
return rec(httpInstrumentation(next))
}

return fn
}

func HTTPInstrumentation() func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
start := time.Now()

next.ServeHTTP(w, r)

httpDuration.WithLabelValues(r.URL.Path).Observe(time.Since(start).Seconds())
}
return http.HandlerFunc(fn)
}
}

func ValidationRules(schema *schema.Provider, tks *tokens.MaxTokensRule, batch *batch.MaxBatchRule, obfuscateErrors bool) func(next http.Handler) http.Handler { // nolint:funlen,cyclop
return func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
payload, err := gql.ParseRequestPayload(r)
if err != nil {
next.ServeHTTP(w, r)
return
}

var errs gqlerror.List

err = batch.Validate(payload)
if err != nil {
errs = append(errs, gqlerror.Wrap(err))
}

// only process the rest if no error yet
if err == nil {
for _, data := range payload {
operationSource := &ast.Source{
Input: data.Query,
}

err = tks.Validate(operationSource)
if err != nil {
errs = append(errs, gqlerror.Wrap(err))
continue // we could consider break-ing here. That would short-circuit on error, with the downside of not returning all potential errors
}

var query, err = parser.ParseQuery(operationSource)
if err != nil {
errs = append(errs, gqlerror.Wrap(err))
continue
}

errList := validator.Validate(schema.Get(), query)
if len(errList) > 0 {
errs = append(errs, errList...)
continue
}
}
}

if len(errs) > 0 {
if obfuscateErrors {
errs = gqlerror.List{gqlerror.Wrap(errRedacted)}
}

response := map[string]interface{}{
"data": nil,
"errors": errs,
}

err = json.NewEncoder(w).Encode(response)
if err != nil {
log2.Println(err)
}
return
}

next.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
}
}

func Handler(p *httputil.ReverseProxy) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
p.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
}
5 changes: 3 additions & 2 deletions cmd/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"github.com/ldebruijn/graphql-protect/internal/app/config"
"github.com/ldebruijn/graphql-protect/internal/business/protect"
"github.com/stretchr/testify/assert"
"io"
log2 "log"
Expand Down Expand Up @@ -331,7 +332,7 @@ input ImageInput {
expected := map[string]interface{}{
"errors": []map[string]interface{}{
{
"message": errRedacted.Error(),
"message": protect.ErrRedacted.Error(),
},
},
}
Expand All @@ -340,7 +341,7 @@ input ImageInput {
assert.NoError(t, err)
// perform string comparisons as map[string]interface seems incomparable
fmt.Println(string(actual))
assert.True(t, errorsContainsMessage(errRedacted.Error(), actual))
assert.True(t, errorsContainsMessage(protect.ErrRedacted.Error(), actual))
},
},
{
Expand Down
2 changes: 1 addition & 1 deletion internal/app/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ import (
"github.com/ldebruijn/graphql-protect/internal/business/enforce_post"
"github.com/ldebruijn/graphql-protect/internal/business/max_depth"
"github.com/ldebruijn/graphql-protect/internal/business/persisted_operations"
"github.com/ldebruijn/graphql-protect/internal/business/proxy"
"github.com/ldebruijn/graphql-protect/internal/business/schema"
"github.com/ldebruijn/graphql-protect/internal/business/tokens"
"github.com/ldebruijn/graphql-protect/internal/http/proxy"
"os"
"time"
)
Expand Down
2 changes: 1 addition & 1 deletion internal/app/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ import (
"github.com/ldebruijn/graphql-protect/internal/business/enforce_post"
"github.com/ldebruijn/graphql-protect/internal/business/max_depth"
"github.com/ldebruijn/graphql-protect/internal/business/persisted_operations"
"github.com/ldebruijn/graphql-protect/internal/business/proxy"
"github.com/ldebruijn/graphql-protect/internal/business/schema"
"github.com/ldebruijn/graphql-protect/internal/business/tokens"
"github.com/ldebruijn/graphql-protect/internal/http/proxy"
"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v3"
"os"
Expand Down
136 changes: 136 additions & 0 deletions internal/business/protect/protect.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
package protect

import (
"encoding/json"
"errors"
"github.com/ldebruijn/graphql-protect/internal/app/config"
"github.com/ldebruijn/graphql-protect/internal/business/aliases"
"github.com/ldebruijn/graphql-protect/internal/business/batch"
"github.com/ldebruijn/graphql-protect/internal/business/enforce_post"
"github.com/ldebruijn/graphql-protect/internal/business/gql"
"github.com/ldebruijn/graphql-protect/internal/business/max_depth"
"github.com/ldebruijn/graphql-protect/internal/business/persisted_operations"
"github.com/ldebruijn/graphql-protect/internal/business/schema"
"github.com/ldebruijn/graphql-protect/internal/business/tokens"
"github.com/vektah/gqlparser/v2/ast"
"github.com/vektah/gqlparser/v2/gqlerror"
"github.com/vektah/gqlparser/v2/parser"
"github.com/vektah/gqlparser/v2/validator"
"log/slog"
"net/http"
)

var (
ErrRedacted = errors.New("error(s) redacted")
)

type GraphQLProtect struct {
log *slog.Logger
cfg *config.Config
po *persisted_operations.PersistedOperationsHandler
schema *schema.Provider
tokens *tokens.MaxTokensRule
maxBatch *batch.MaxBatchRule
next http.Handler
preFilterChain func(handler http.Handler) http.Handler
}

func NewGraphQLProtect(log *slog.Logger, cfg *config.Config, po *persisted_operations.PersistedOperationsHandler, schema *schema.Provider, upstreamHandler http.Handler) (*GraphQLProtect, error) {
aliases.NewMaxAliasesRule(cfg.MaxAliases)
max_depth.NewMaxDepthRule(cfg.MaxDepth)
maxBatch, err := batch.NewMaxBatch(cfg.MaxBatch)
if err != nil {
log.Warn("Error initializing maximum batch protection", err)
}

disableMethod := enforce_post.EnforcePostMethod(cfg.EnforcePost)

return &GraphQLProtect{
log: log,
cfg: cfg,
po: po,
schema: schema,
tokens: tokens.MaxTokens(cfg.MaxTokens),
maxBatch: maxBatch,
preFilterChain: func(next http.Handler) http.Handler {
return disableMethod(po.Execute(next))
},
next: upstreamHandler,
}, nil
}

func (p *GraphQLProtect) ServeHTTP(w http.ResponseWriter, r *http.Request) {
p.preFilterChain(http.HandlerFunc(p.handle)).ServeHTTP(w, r)
}

func (p *GraphQLProtect) handle(w http.ResponseWriter, r *http.Request) {
errs := p.validateRequest(r)

if len(errs) > 0 {
if p.cfg.ObfuscateValidationErrors {
errs = gqlerror.List{gqlerror.Wrap(ErrRedacted)}
}

response := map[string]interface{}{
"data": nil,
"errors": errs,
}

w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(response)
if err != nil {
p.log.Error("could not encode error", "err", err)
}
return
}

p.next.ServeHTTP(w, r)
}

func (p *GraphQLProtect) validateRequest(r *http.Request) gqlerror.List {
payload, err := gql.ParseRequestPayload(r)
if err != nil {
return gqlerror.List{gqlerror.Wrap(err)}
}

var errs gqlerror.List

err = p.maxBatch.Validate(payload)
if err != nil {
errs = append(errs, gqlerror.Wrap(err))
}

if err != nil {
return errs
}

// only process the rest if no error yet
if err == nil {
for _, data := range payload {
validationErrors := p.validateQuery(data)
if len(validationErrors) > 0 {
errs = append(errs, validationErrors...)
}
}
}

return errs
}

func (p *GraphQLProtect) validateQuery(data gql.RequestData) gqlerror.List {
operationSource := &ast.Source{
Input: data.Query,
}

err := p.tokens.Validate(operationSource)
if err != nil {
return gqlerror.List{gqlerror.Wrap(err)}
}

query, err := parser.ParseQuery(operationSource)
if err != nil {
return gqlerror.List{gqlerror.Wrap(err)}
}

return validator.Validate(p.schema.Get(), query)
}
Loading

0 comments on commit 2ed28ba

Please sign in to comment.