diff --git a/rolling-shutter/medley/configuration/command/command.go b/rolling-shutter/medley/configuration/command/command.go index f06a22a1d..e6071c35c 100644 --- a/rolling-shutter/medley/configuration/command/command.go +++ b/rolling-shutter/medley/configuration/command/command.go @@ -117,11 +117,16 @@ func Build[T configuration.Config]( if err != nil { return err } - return WriteConfig(builder.filesystem, cfg, outPath) + overwrite, err := cmd.Flags().GetBool("force") + if err != nil { + return err + } + return WriteConfig(builder.filesystem, cfg, outPath, overwrite) }, } genConfigCmd.PersistentFlags().String("output", "", "output file") genConfigCmd.MarkPersistentFlagRequired("output") + genConfigCmd.PersistentFlags().BoolP("force", "f", false, "overwrite existing file") cb.cobraCommand.AddCommand(genConfigCmd) } if builder.dumpConfig { @@ -145,13 +150,18 @@ func Build[T configuration.Config]( log.Debug(). Interface("config", cfg). Msg("dumping config") - return WriteConfig(builder.filesystem, cfg, outPath) + overwrite, err := cmd.Flags().GetBool("force") + if err != nil { + return err + } + return WriteConfig(builder.filesystem, cfg, outPath, overwrite) }, } dumpConfigCmd.PersistentFlags().String("output", "", "output file") dumpConfigCmd.MarkPersistentFlagRequired("output") dumpConfigCmd.PersistentFlags().String("config", "", "config file") dumpConfigCmd.MarkPersistentFlagFilename("config") + dumpConfigCmd.PersistentFlags().BoolP("force", "f", false, "overwrite existing file") cb.cobraCommand.AddCommand(dumpConfigCmd) } return cb diff --git a/rolling-shutter/medley/configuration/command/parse.go b/rolling-shutter/medley/configuration/command/parse.go index a01bf663c..5090258b2 100644 --- a/rolling-shutter/medley/configuration/command/parse.go +++ b/rolling-shutter/medley/configuration/command/parse.go @@ -24,12 +24,12 @@ func CommandAddConfigFileFlag(cmd *cobra.Command) { cmd.MarkPersistentFlagFilename("config") } -func WriteConfig(fs afero.Fs, config configuration.Config, outPath string) error { +func WriteConfig(fs afero.Fs, config configuration.Config, outPath string, overwrite bool) error { buf := &bytes.Buffer{} if err := configuration.WriteTOML(buf, config); err != nil { return errors.Wrap(err, "failed to write config file") } - return medley.SecureSpit(fs, outPath, buf.Bytes()) + return medley.SecureSpit(fs, outPath, buf.Bytes(), overwrite) } // ParseCLI reads in the CLI argument context from the diff --git a/rolling-shutter/medley/configuration/test/config_test.go b/rolling-shutter/medley/configuration/test/config_test.go index 904829619..fb7ffa27f 100644 --- a/rolling-shutter/medley/configuration/test/config_test.go +++ b/rolling-shutter/medley/configuration/test/config_test.go @@ -45,7 +45,7 @@ func TestConfiguration(t *testing.T) { err = afs.MkdirAll(dirPath, os.ModeDir) assert.NilError(t, err) - err = command.WriteConfig(afs, config, configFile) + err = command.WriteConfig(afs, config, configFile, false) assert.NilError(t, err) file, err := afero.ReadFile(afs, configFile) diff --git a/rolling-shutter/medley/configuration/test/helper.go b/rolling-shutter/medley/configuration/test/helper.go index 31165ea27..0dbf51876 100644 --- a/rolling-shutter/medley/configuration/test/helper.go +++ b/rolling-shutter/medley/configuration/test/helper.go @@ -28,7 +28,7 @@ func RoundtripParseConfig[T configuration.Config]( err := afs.MkdirAll(dirPath, os.ModeDir) assert.NilError(t, err) - err = command.WriteConfig(afs, config, configFile) + err = command.WriteConfig(afs, config, configFile, false) assert.NilError(t, err) var parsedConfig T diff --git a/rolling-shutter/medley/spit.go b/rolling-shutter/medley/spit.go index a4b0d1009..931f27e26 100644 --- a/rolling-shutter/medley/spit.go +++ b/rolling-shutter/medley/spit.go @@ -7,10 +7,14 @@ import ( ) // SecureSpit creates a new file with the given path and writes the given content to it. The file -// is created with with mode 0600. SecureSpit will not overwrite an existing file. -func SecureSpit(fs afero.Fs, path string, content []byte) error { +// is created with with mode 0600. SecureSpit will not overwrite an existing file unless asked. +func SecureSpit(fs afero.Fs, path string, content []byte, overwrite bool) error { var err error - file, err := fs.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0o600) + flags := os.O_RDWR | os.O_CREATE + if !overwrite { + flags |= os.O_EXCL + } + file, err := fs.OpenFile(path, flags, 0o600) if err != nil { return err }