diff --git a/pkg/config/config.go b/pkg/config/config.go index f744fe2a0..fbb4bb833 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -192,9 +192,9 @@ func (c *Config) Environ() []string { fmt.Sprintf("SOFT_SERVE_HTTP_TLS_KEY_PATH=%s", c.HTTP.TLSKeyPath), fmt.Sprintf("SOFT_SERVE_HTTP_TLS_CERT_PATH=%s", c.HTTP.TLSCertPath), fmt.Sprintf("SOFT_SERVE_HTTP_PUBLIC_URL=%s", c.HTTP.PublicURL), - fmt.Sprintf("SOFT_SERVE_HTTP_ALLOWED_HEADERS=%s", strings.Join(c.HTTP.CORS.AllowedHeaders, "\n")), - fmt.Sprintf("SOFT_SERVE_HTTP_ALLOWED_ORIGINS=%s", strings.Join(c.HTTP.CORS.AllowedOrigins, "\n")), - fmt.Sprintf("SOFT_SERVE_HTTP_ALLOWED_METHODS=%s", strings.Join(c.HTTP.CORS.AllowedMethods, "\n")), + fmt.Sprintf("SOFT_SERVE_HTTP_CORS_ALLOWED_HEADERS=%s", strings.Join(c.HTTP.CORS.AllowedHeaders, "\n")), + fmt.Sprintf("SOFT_SERVE_HTTP_CORS_ALLOWED_ORIGINS=%s", strings.Join(c.HTTP.CORS.AllowedOrigins, "\n")), + fmt.Sprintf("SOFT_SERVE_HTTP_CORS_ALLOWED_METHODS=%s", strings.Join(c.HTTP.CORS.AllowedMethods, "\n")), fmt.Sprintf("SOFT_SERVE_STATS_LISTEN_ADDR=%s", c.Stats.ListenAddr), fmt.Sprintf("SOFT_SERVE_LOG_FORMAT=%s", c.Log.Format), fmt.Sprintf("SOFT_SERVE_LOG_TIME_FORMAT=%s", c.Log.TimeFormat), @@ -260,24 +260,6 @@ func parseEnv(cfg *Config) error { cfg.InitialAdminKeys = append(cfg.InitialAdminKeys, initialAdminKeys...) } - // split allowed headers and append to cfg - if allowedHeadersEnv := os.Getenv("SOFT_SERVE_HTTP_CORS_ALLOWED_HEADERS"); allowedHeadersEnv != "" { - allowedHeaders := strings.Split(allowedHeadersEnv, " ") - cfg.HTTP.CORS.AllowedHeaders = append(cfg.HTTP.CORS.AllowedHeaders, allowedHeaders...) - } - - // split allowed origins and append to cfg - if allowedOriginsEnv := os.Getenv("SOFT_SERVE_HTTP_CORS_ALLOWED_ORIGINS"); allowedOriginsEnv != "" { - allowedOrigins := strings.Split(allowedOriginsEnv, " ") - cfg.HTTP.CORS.AllowedOrigins = append(cfg.HTTP.CORS.AllowedOrigins, allowedOrigins...) - } - - // split allowed methods and append to cfg - if allowedMethodsEnv := os.Getenv("SOFT_SERVE_HTTP_CORS_ALLOWED_METHODS"); allowedMethodsEnv != "" { - allowedMethods := strings.Split(allowedMethodsEnv, " ") - cfg.HTTP.CORS.AllowedMethods = append(cfg.HTTP.CORS.AllowedMethods, allowedMethods...) - } - return cfg.Validate() } diff --git a/pkg/web/server.go b/pkg/web/server.go index 6685feb3c..ab336e89a 100644 --- a/pkg/web/server.go +++ b/pkg/web/server.go @@ -5,9 +5,9 @@ import ( "net/http" "github.com/charmbracelet/log" + "github.com/charmbracelet/soft-serve/pkg/config" "github.com/gorilla/handlers" "github.com/gorilla/mux" - "github.com/charmbracelet/soft-serve/pkg/config" ) // NewRouter returns a new HTTP router. @@ -29,25 +29,10 @@ func NewRouter(ctx context.Context) http.Handler { cfg := config.FromContext(ctx) - CORSHeaders := []string{"Accept", "Accept-Language", "Content-Language", "Origin"} - - if len(cfg.HTTP.CORS.AllowedHeaders) != 0 { - CORSHeaders = cfg.HTTP.CORS.AllowedHeaders - } - - CORSOrigins := []string{} - - if len(cfg.HTTP.CORS.AllowedOrigins) != 0 { - CORSOrigins = cfg.HTTP.CORS.AllowedOrigins - } - - CORSMethods := []string{http.MethodGet, http.MethodHead, http.MethodPost} - - if len(cfg.HTTP.CORS.AllowedMethods) != 0 { - CORSMethods = cfg.HTTP.CORS.AllowedMethods - } - - h = handlers.CORS(handlers.AllowedHeaders(CORSHeaders),handlers.AllowedOrigins(CORSOrigins),handlers.AllowedMethods(CORSMethods))(h) + h = handlers.CORS(handlers.AllowedHeaders(cfg.HTTP.CORS.AllowedHeaders), + handlers.AllowedOrigins(cfg.HTTP.CORS.AllowedOrigins), + handlers.AllowedMethods(cfg.HTTP.CORS.AllowedMethods), + )(h) return h }