Skip to content
This repository has been archived by the owner on Nov 8, 2022. It is now read-only.

Commit

Permalink
Merge pull request #1509 from candysmurf/cors
Browse files Browse the repository at this point in the history
SDI-1659: support for CORS header
  • Loading branch information
candysmurf authored Feb 10, 2017
2 parents 74029be + 09ddbf9 commit 89bf6c0
Show file tree
Hide file tree
Showing 9 changed files with 228 additions and 15 deletions.
3 changes: 3 additions & 0 deletions docs/SNAPTELD_CONFIGURATION.md
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,9 @@ restapi:

# port sets the port to start the REST API server on. Default is 8181
port: 8181

# allowed_origins sets the allowed origins in a comma separated list. It defaults to the same origin if the value is empty.
allowed_origins: http://127.0.0.1:8080, http://snap.example.io, http://example.com
```
### snapteld tribe configurations
Expand Down
3 changes: 2 additions & 1 deletion examples/configs/snap-config-sample.json
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@
"rest_certificate":"/etc/snap/cert.pem",
"rest_key":"/etc/snap/cert.key",
"port":8282,
"addr":"127.0.0.1:12345"
"addr":"127.0.0.1:12345",
"allowed_origins": "http://127.0.0.1:8888, https://snap-telemetry.io"
},
"tribe":{
"enable":true,
Expand Down
3 changes: 3 additions & 0 deletions examples/configs/snap-config-sample.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,9 @@ restapi:
# REST API in address[:port] format
addr: 127.0.0.1:12345

# corsd sets the cors allowed domains in a comma separated list. It is the same origin if it's empty.
allowed_origins: http://127.0.0.1:88888, https://snap-telemetry.io

# tribe section contains all configuration items for the tribe module
tribe:
# enable controls enabling tribe for the snapteld instance. Default value is false.
Expand Down
6 changes: 5 additions & 1 deletion glide.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions mgmt/rest/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ const (
defaultAuthPassword string = ""
defaultPortSetByConfig bool = false
defaultPprof bool = false
defaultCorsd string = ""
)

// holds the configuration passed in through the SNAP config file
Expand All @@ -29,6 +30,7 @@ type Config struct {
RestAuthPassword string `json:"rest_auth_password"yaml:"rest_auth_password"`
portSetByConfig bool ``
Pprof bool `json:"pprof"yaml:"pprof"`
Corsd string `json:"corsd"yaml:"allowed_origins"`
}

const (
Expand Down Expand Up @@ -64,6 +66,9 @@ const (
},
"pprof": {
"type": "boolean"
},
"allowed_origins" : {
"type": "string"
}
},
"additionalProperties": false
Expand All @@ -84,6 +89,7 @@ func GetDefaultConfig() *Config {
RestAuthPassword: defaultAuthPassword,
portSetByConfig: defaultPortSetByConfig,
Pprof: defaultPprof,
Corsd: defaultCorsd,
}
}

Expand Down
6 changes: 5 additions & 1 deletion mgmt/rest/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,11 @@ var (
Name: "pprof",
Usage: "Enables profiling tools",
}
flCorsd = cli.StringFlag{
Name: "allowed_origins",
Usage: "Define Cors allowed origins",
}

// Flags consumed by snapteld
Flags = []cli.Flag{flAPIDisabled, flAPIAddr, flAPIPort, flRestHTTPS, flRestCert, flRestKey, flRestAuth, flPProf}
Flags = []cli.Flag{flAPIDisabled, flAPIAddr, flAPIPort, flRestHTTPS, flRestCert, flRestKey, flRestAuth, flPProf, flCorsd}
)
113 changes: 101 additions & 12 deletions mgmt/rest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,29 @@ import (
"fmt"
"net"
"net/http"
"net/url"
"strconv"
"sync"
"time"

log "github.com/Sirupsen/logrus"
"github.com/julienschmidt/httprouter"
"github.com/rs/cors"
"github.com/urfave/negroni"

"strings"

"github.com/intelsdi-x/snap/mgmt/rest/api"
"github.com/intelsdi-x/snap/mgmt/rest/v1"
"github.com/intelsdi-x/snap/mgmt/rest/v2"
)

const (
allowedMethods = "GET, POST, DELETE, PUT, OPTIONS"
allowedHeaders = "Origin, X-Requested-With, Content-Type, Accept"
maxAge = 3600
)

var (
ErrBadCert = errors.New("Invalid certificate given")

Expand All @@ -45,18 +56,19 @@ var (
)

type Server struct {
apis []api.API
n *negroni.Negroni
r *httprouter.Router
snapTLS *snapTLS
auth bool
pprof bool
authpwd string
addrString string
addr net.Addr
wg sync.WaitGroup
killChan chan struct{}
err chan error
apis []api.API
n *negroni.Negroni
r *httprouter.Router
snapTLS *snapTLS
auth bool
pprof bool
authpwd string
addrString string
addr net.Addr
wg sync.WaitGroup
killChan chan struct{}
err chan error
allowedOrigins map[string]bool
// the following instance variables are used to cleanly shutdown the server
serverListener net.Listener
closingChan chan bool
Expand Down Expand Up @@ -92,6 +104,23 @@ func New(cfg *Config) (*Server, error) {
negroni.HandlerFunc(s.authMiddleware),
)
s.r = httprouter.New()

// CORS has to be turned on explictly in the global config.
// Otherwise, it defauts to the same origin.
origins, err := s.getAllowedOrigins(cfg.Corsd)
if err != nil {
return nil, err
}
if len(origins) > 0 {
c := cors.New(cors.Options{
AllowedOrigins: origins,
AllowedMethods: []string{allowedMethods},
AllowedHeaders: []string{allowedHeaders},
MaxAge: maxAge,
})
s.n.Use(c)
}

// Use negroni to handle routes
s.n.UseHandler(s.r)
return s, nil
Expand Down Expand Up @@ -133,6 +162,9 @@ func (s *Server) SetAPIAuthPwd(pwd string) {

// Auth Middleware for REST API
func (s *Server) authMiddleware(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
reqOrigin := r.Header.Get("Origin")
s.setAllowedOrigins(rw, reqOrigin)

defer r.Body.Close()
if s.auth {
_, password, ok := r.BasicAuth()
Expand All @@ -149,6 +181,23 @@ func (s *Server) authMiddleware(rw http.ResponseWriter, r *http.Request, next ht
}
}

// CORS origins have to be turned on explictly in the global config.
// Otherwise, it defaults to the same origin.
func (s *Server) setAllowedOrigins(rw http.ResponseWriter, ro string) {
if len(s.allowedOrigins) > 0 {
if _, ok := s.allowedOrigins[ro]; ok {
// localhost CORS is not supported by all browsers. It has to use "*".
if strings.Contains(ro, "127.0.0.1") || strings.Contains(ro, "localhost") {
ro = "*"
}
rw.Header().Set("Access-Control-Allow-Origin", ro)
rw.Header().Set("Access-Control-Allow-Methods", allowedMethods)
rw.Header().Set("Access-Control-Allow-Headers", allowedHeaders)
rw.Header().Set("Access-Control-Max-Age", strconv.Itoa(maxAge))
}
}
}

func (s *Server) SetAddress(addrString string) {
s.addrString = addrString
restLogger.Info(fmt.Sprintf("Address used for binding: [%v]", s.addrString))
Expand Down Expand Up @@ -259,6 +308,46 @@ func (s *Server) addRoutes() {
s.addPprofRoutes()
}

func (s *Server) getAllowedOrigins(corsd string) ([]string, error) {
// Avoids panics when validating URLs.
defer func() {
if r := recover(); r != nil {
var ok bool
err, ok := r.(error)
if !ok {
err = fmt.Errorf("pkg: %v", r)
fmt.Println(err)
}
}

}()

if corsd == "" {
return []string{}, nil
}

vo := []string{}
s.allowedOrigins = map[string]bool{}

os := strings.Split(corsd, ",")
for _, o := range os {
to := strings.TrimSpace(o)

// Validates origin formation
u, err := url.Parse(to)

// Checks if scheme or host exists when no error occured.
if err != nil || u.Scheme == "" || u.Host == "" {
restLogger.Errorf("Invalid origin found %s", to)
return []string{}, fmt.Errorf("Invalid origin found: %s.", to)
}

vo = append(vo, to)
s.allowedOrigins[to] = true
}
return vo, nil
}

// Monkey patch ListenAndServe and TCP alive code from https://golang.org/src/net/http/server.go
// The built in ListenAndServe and ListenAndServeTLS include TCP keepalive
// At this point the Go team is not wanting to provide separate listen and serve methods
Expand Down
101 changes: 101 additions & 0 deletions mgmt/rest/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,14 @@ limitations under the License.
package rest

import (
"fmt"
"net/url"
"strings"
"testing"

"github.com/intelsdi-x/snap/pkg/cfgfile"
. "github.com/smartystreets/goconvey/convey"
"github.com/urfave/negroni"
)

const (
Expand Down Expand Up @@ -161,5 +165,102 @@ func TestRestAPIDefaultConfig(t *testing.T) {
Convey("RestKey should be empty", func() {
So(cfg.RestKey, ShouldEqual, "")
})
Convey("Corsd should be empty", func() {
So(cfg.Corsd, ShouldEqual, "")
})
})
}

type mockServer struct {
n *negroni.Negroni
allowedOrigins map[string]bool
}

func NewMockServer(cfg *Config) (*mockServer, []string, error) {
s := &mockServer{}
origins, err := s.getAllowedOrigins(cfg.Corsd)

return s, origins, err
}

func (s *mockServer) getAllowedOrigins(corsd string) ([]string, error) {
defer func() {
if r := recover(); r != nil {
var ok bool
err, ok := r.(error)
if !ok {
err = fmt.Errorf("pkg: %v", r)
fmt.Println(err)
}
}

}()

if corsd == "" {
return []string{}, nil
}

vo := []string{}
s.allowedOrigins = map[string]bool{}

os := strings.Split(corsd, ",")
for _, o := range os {
to := strings.TrimSpace(o)

// Validates origin formation
u, err := url.Parse(to)

// Checks if scheme or host exists when no error occured.
if err != nil || u.Scheme == "" || u.Host == "" {
restLogger.Errorf("Invalid origin found %s", to)
return []string{}, fmt.Errorf("Invalid origin found: %s.", to)
}

vo = append(vo, to)
s.allowedOrigins[to] = true
}
return vo, nil
}

func TestRestAPICorsd(t *testing.T) {
cfg := GetDefaultConfig()

Convey("Test cors origin list", t, func() {

Convey("Origins are valid", func() {
cfg.Corsd = "http://127.0.0.1:80, http://example.com"
s, o, err := NewMockServer(cfg)

So(len(s.allowedOrigins), ShouldEqual, 2)
So(len(o), ShouldEqual, 2)
So(err, ShouldBeNil)
})

Convey("Origins have a wrong separator", func() {
cfg.Corsd = "http://127.0.0.1:80; http://example.com"
s, o, err := NewMockServer(cfg)

So(err, ShouldNotBeNil)
So(len(s.allowedOrigins), ShouldEqual, 0)
So(len(o), ShouldEqual, 0)
})

Convey("Origin misses scheme", func() {
cfg.Corsd = "127.0.0.1:80, http://example.com"
s, o, err := NewMockServer(cfg)

So(err, ShouldNotBeNil)
So(len(s.allowedOrigins), ShouldEqual, 0)
So(len(o), ShouldEqual, 0)
})

Convey("Origin is malformed", func() {
cfg.Corsd = "http://127.0.0.1:80, http://snap.io, [email protected]"
s, o, err := NewMockServer(cfg)

So(err, ShouldNotBeNil)
So(len(s.allowedOrigins), ShouldEqual, 2)
So(len(o), ShouldEqual, 0)
})
})
}
2 changes: 2 additions & 0 deletions snapteld.go
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,8 @@ func applyCmdLineFlags(cfg *Config, ctx *cli.Context) {
cfg.RestAPI.RestAuth = setBoolVal(cfg.RestAPI.RestAuth, ctx, "rest-auth")
cfg.RestAPI.RestAuthPassword = setStringVal(cfg.RestAPI.RestAuthPassword, ctx, "rest-auth-pwd")
cfg.RestAPI.Pprof = setBoolVal(cfg.RestAPI.Pprof, ctx, "pprof")
cfg.RestAPI.Corsd = setStringVal(cfg.RestAPI.Corsd, ctx, "allowed_origins")

// next for the scheduler related flags
cfg.Scheduler.WorkManagerQueueSize = setUIntVal(cfg.Scheduler.WorkManagerQueueSize, ctx, "work-manager-queue-size")
cfg.Scheduler.WorkManagerPoolSize = setUIntVal(cfg.Scheduler.WorkManagerPoolSize, ctx, "work-manager-pool-size")
Expand Down

0 comments on commit 89bf6c0

Please sign in to comment.