Skip to content

Commit

Permalink
add -cors.origins flag to "zed serve"
Browse files Browse the repository at this point in the history
The -cors.origins flag accepts a comma-separated list of CORS allowed
origins.

Closes #4297.
  • Loading branch information
nwt committed Jan 25, 2023
1 parent 4aa010c commit 990520e
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 10 deletions.
23 changes: 23 additions & 0 deletions cli/commastringsflag.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package cli

import "strings"

// CommaStringsFlag is a [flag.Value] representing a comma-separated list of
// strings.
type CommaStringsFlag []string

func (c *CommaStringsFlag) String() string {
if *c == nil {
return ""
}
return strings.Join(*c, ",")
}

func (c *CommaStringsFlag) Set(value string) error {
if value == "" {
*c = nil
} else {
*c = strings.Split(value, ",")
}
return nil
}
2 changes: 2 additions & 0 deletions cmd/zed/serve/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ func New(parent charm.Command, f *flag.FlagSet) (charm.Command, error) {
c.conf.Version = cli.Version
c.logflags.SetFlags(f)
f.IntVar(&c.brimfd, "brimfd", -1, "pipe read fd passed by brim to signal brim closure")
c.conf.CORSAllowedOrigins = []string{"*.observableusercontent.com", "localhost"}
f.Var((*cli.CommaStringsFlag)(&c.conf.CORSAllowedOrigins), "cors.origins", "comma-separated list of CORS allowed origins")
f.StringVar(&c.listenAddr, "l", ":9867", "[addr]:port to listen on")
f.StringVar(&c.portFile, "portfile", "", "write listen port to file")
f.StringVar(&c.rootContentFile, "rootcontentfile", "", "file to serve for GET /")
Expand Down
15 changes: 8 additions & 7 deletions service/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,12 @@ const indexPage = `
</html>`

type Config struct {
Auth AuthConfig
Root *storage.URI
RootContent io.ReadSeeker
Version string
Logger *zap.Logger
Auth AuthConfig
CORSAllowedOrigins []string
Root *storage.URI
RootContent io.ReadSeeker
Version string
Logger *zap.Logger
}

type Core struct {
Expand Down Expand Up @@ -105,7 +106,7 @@ func NewCore(ctx context.Context, conf Config) (*Core, error) {
}

routerAux := mux.NewRouter()
routerAux.Use(corsMiddleware())
routerAux.Use(corsMiddleware(conf.CORSAllowedOrigins))

routerAux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
http.ServeContent(w, r, "", time.Time{}, conf.RootContent)
Expand All @@ -131,7 +132,7 @@ func NewCore(ctx context.Context, conf Config) (*Core, error) {
routerAPI.Use(requestIDMiddleware())
routerAPI.Use(accessLogMiddleware(conf.Logger))
routerAPI.Use(panicCatchMiddleware(conf.Logger))
routerAPI.Use(corsMiddleware())
routerAPI.Use(corsMiddleware(conf.CORSAllowedOrigins))

c := &Core{
auth: authenticator,
Expand Down
4 changes: 1 addition & 3 deletions service/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@ func requestIDMiddleware() mux.MiddlewareFunc {
}
}

var allowedOrigins = []string{"*.observableusercontent.com", "localhost"}

func corsMiddleware() mux.MiddlewareFunc {
func corsMiddleware(allowedOrigins []string) mux.MiddlewareFunc {
return cors.New(cors.Options{
AllowedOrigins: allowedOrigins,
AllowedMethods: []string{
Expand Down

0 comments on commit 990520e

Please sign in to comment.