diff --git a/cmd/gate/root.go b/cmd/gate/root.go index c3bbc635..e5d3e646 100644 --- a/cmd/gate/root.go +++ b/cmd/gate/root.go @@ -1,10 +1,13 @@ package gate import ( + "bytes" + "encoding/json" "errors" "fmt" "math" "os" + "path" "strings" "github.com/go-logr/logr" @@ -12,8 +15,10 @@ import ( "github.com/spf13/viper" "github.com/urfave/cli/v2" "go.minekube.com/gate/pkg/gate" + "go.minekube.com/gate/pkg/gate/config" "go.uber.org/zap" "go.uber.org/zap/zapcore" + "gopkg.in/yaml.v3" ) // Execute runs App() and calls os.Exit when finished. @@ -115,7 +120,8 @@ func initViper(c *cli.Context, configFile string) (*viper.Viper, error) { v.AutomaticEnv() // read in environment variables that match v.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) // Read in config. - if err := v.ReadInConfig(); err != nil { + cfgCopy := func() config.Config { return config.DefaultConfig }() + if err := FixedReadInConfig(v, configFile, &cfgCopy); err != nil { // A config file is only required to exist when explicit config flag was specified. if !(errors.As(err, &viper.ConfigFileNotFoundError{}) || os.IsNotExist(err)) || c.IsSet("config") { return nil, fmt.Errorf("error reading config file %q: %w", v.ConfigFileUsed(), err) @@ -124,6 +130,40 @@ func initViper(c *cli.Context, configFile string) (*viper.Viper, error) { return v, nil } +// FixedReadInConfig is a workaround for https://github.com/minekube/gate/issues/218#issuecomment-1632800775 +func FixedReadInConfig[T any](v *viper.Viper, configFile string, defaultConfig *T) error { + if configFile == "" || defaultConfig == nil { + return v.ReadInConfig() + } + + var ( + unmarshal func([]byte, any) error + marshal func(any) ([]byte, error) + ) + switch path.Ext(configFile) { + case ".yaml", ".yml": + unmarshal = yaml.Unmarshal + marshal = yaml.Marshal + case ".json": + unmarshal = json.Unmarshal + marshal = json.Marshal + default: + return fmt.Errorf("unsupported config file format %q", configFile) + } + b, err := os.ReadFile(configFile) + if err != nil { + return fmt.Errorf("error reading config file %q: %w", configFile, err) + } + if err = unmarshal(b, defaultConfig); err != nil { + return fmt.Errorf("error unmarshaling config file %q to %T: %w", configFile, defaultConfig, err) + } + if b, err = marshal(defaultConfig); err != nil { + return fmt.Errorf("error marshaling config file %q: %w", configFile, err) + } + + return v.ReadConfig(bytes.NewReader(b)) +} + // newLogger returns a new zap logger with a modified production // or development default config to ensure human readability. func newLogger(debug bool, v int) (l logr.Logger, err error) { diff --git a/pkg/util/favicon/favicon.go b/pkg/util/favicon/favicon.go index b4235897..c70a95e3 100644 --- a/pkg/util/favicon/favicon.go +++ b/pkg/util/favicon/favicon.go @@ -9,6 +9,7 @@ import ( _ "image/jpeg" "image/png" "os" + "path" "strings" "github.com/nfnt/resize" @@ -95,6 +96,9 @@ func Parse(s string) (Favicon, error) { } return f, nil } + if path.Ext(s) != "" { + return "", fmt.Errorf("favicon: file %s not found", s) + } return "", fmt.Errorf("favicon: invalid format: %s", s) }