diff --git a/src/cmd/common/setup.go b/src/cmd/common/setup.go index ae2052b882..72fd0d6722 100644 --- a/src/cmd/common/setup.go +++ b/src/cmd/common/setup.go @@ -5,8 +5,10 @@ package common import ( + "fmt" "io" "os" + "time" "github.com/defenseunicorns/zarf/src/config" "github.com/defenseunicorns/zarf/src/config/lang" @@ -51,14 +53,20 @@ func SetupCLI() { } if !config.SkipLogFile { - logFile, err := message.UseLogFile("") + ts := time.Now().Format("2006-01-02-15-04-05") + + f, err := os.CreateTemp("", fmt.Sprintf("zarf-%s-*.log", ts)) + if err != nil { + message.WarnErr(err, "Error creating a log file in a temporary directory") + return + } + logFile, err := message.UseLogFile(f) if err != nil { message.WarnErr(err, "Error saving a log file to a temporary directory") return } pterm.SetDefaultOutput(io.MultiWriter(os.Stderr, logFile)) - location := message.LogFileLocation() - message.Notef("Saving log file to %s", location) + message.Notef("Saving log file to %s", f.Name()) } } diff --git a/src/pkg/message/credentials.go b/src/pkg/message/credentials.go index 60ca971119..07b85525bc 100644 --- a/src/pkg/message/credentials.go +++ b/src/pkg/message/credentials.go @@ -32,8 +32,8 @@ func PrintCredentialTable(state *types.ZarfState, componentsToDeploy []types.Dep // Pause the logfile's output to avoid credentials being printed to the log file if logFile != nil { - logFile.pause() - defer logFile.resume() + logFile.Pause() + defer logFile.Resume() } loginData := [][]string{} @@ -95,8 +95,8 @@ func PrintComponentCredential(state *types.ZarfState, componentName string) { func PrintCredentialUpdates(oldState *types.ZarfState, newState *types.ZarfState, services []string) { // Pause the logfile's output to avoid credentials being printed to the log file if logFile != nil { - logFile.pause() - defer logFile.resume() + logFile.Pause() + defer logFile.Resume() } for _, service := range services { diff --git a/src/pkg/message/message.go b/src/pkg/message/message.go index 1801c7b9f5..713bb90cbe 100644 --- a/src/pkg/message/message.go +++ b/src/pkg/message/message.go @@ -7,7 +7,6 @@ package message import ( "encoding/json" "fmt" - "io" "net/http" "os" "runtime/debug" @@ -49,7 +48,7 @@ var RuleLine = strings.Repeat("━", TermWidth) var logLevel = InfoLevel // logFile acts as a buffer for logFile generation -var logFile *pausableLogFile +var logFile *PausableWriter // DebugWriter represents a writer interface that writes to message.Debug type DebugWriter struct{} @@ -77,32 +76,14 @@ func init() { pterm.SetDefaultOutput(os.Stderr) } -// UseLogFile writes output to stderr and a logFile. -func UseLogFile(dir string) (io.Writer, error) { - // Prepend the log filename with a timestamp. - ts := time.Now().Format("2006-01-02-15-04-05") - - f, err := os.CreateTemp(dir, fmt.Sprintf("zarf-%s-*.log", ts)) - if err != nil { - return nil, err - } - - logFile = &pausableLogFile{ - wr: f, - f: f, - } +// UseLogFile wraps a given file in a PausableWriter +// and sets it as the log file used by the message package. +func UseLogFile(f *os.File) (*PausableWriter, error) { + logFile = NewPausableWriter(f) return logFile, nil } -// LogFileLocation returns the location of the log file. -func LogFileLocation() string { - if logFile == nil { - return "" - } - return logFile.f.Name() -} - // SetLogLevel sets the log level. func SetLogLevel(lvl LogLevel) { logLevel = lvl diff --git a/src/pkg/message/pausable.go b/src/pkg/message/pausable.go index ee8d7b6a67..b9e8fae1c7 100644 --- a/src/pkg/message/pausable.go +++ b/src/pkg/message/pausable.go @@ -6,26 +6,29 @@ package message import ( "io" - "os" ) -// pausableLogFile is a pausable log file -type pausableLogFile struct { - wr io.Writer - f *os.File +// PausableWriter is a pausable writer +type PausableWriter struct { + out, wr io.Writer } -// pause the log file -func (l *pausableLogFile) pause() { - l.wr = io.Discard +// NewPausableWriter creates a new pausable writer +func NewPausableWriter(wr io.Writer) *PausableWriter { + return &PausableWriter{out: wr, wr: wr} } -// resume the log file -func (l *pausableLogFile) resume() { - l.wr = l.f +// Pause sets the output writer to io.Discard +func (pw *PausableWriter) Pause() { + pw.out = io.Discard } -// Write writes the data to the log file -func (l *pausableLogFile) Write(p []byte) (n int, err error) { - return l.wr.Write(p) +// Resume sets the output writer back to the original writer +func (pw *PausableWriter) Resume() { + pw.out = pw.wr +} + +// Write writes the data to the underlying output writer +func (pw *PausableWriter) Write(p []byte) (n int, err error) { + return pw.out.Write(p) } diff --git a/src/pkg/message/pausable_test.go b/src/pkg/message/pausable_test.go new file mode 100644 index 0000000000..2cfeb2c827 --- /dev/null +++ b/src/pkg/message/pausable_test.go @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2021-Present The Zarf Authors + +package message + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestPausableWriter(t *testing.T) { + var buf bytes.Buffer + + pw := NewPausableWriter(&buf) + + n, err := pw.Write([]byte("foo")) + require.NoError(t, err) + require.Equal(t, 3, n) + + require.Equal(t, "foo", buf.String()) + + pw.Pause() + + n, err = pw.Write([]byte("bar")) + require.NoError(t, err) + require.Equal(t, 3, n) + + require.Equal(t, "foo", buf.String()) + + pw.Resume() + + n, err = pw.Write([]byte("baz")) + require.NoError(t, err) + require.Equal(t, 3, n) + + require.Equal(t, "foobaz", buf.String()) +}