diff --git a/testscript/testscript.go b/testscript/testscript.go index ea0dee0e..759e3cfe 100644 --- a/testscript/testscript.go +++ b/testscript/testscript.go @@ -69,6 +69,27 @@ func (e *Env) Defer(f func()) { e.ts.Defer(f) } +// Getenv retrieves the value of the environment variable named by the key. It +// returns the value, which will be empty if the variable is not present. +func (e *Env) Getenv(key string) string { + key = envvarname(key) + for i := len(e.Vars) - 1; i >= 0; i-- { + if pair := strings.SplitN(e.Vars[i], "=", 2); len(pair) == 2 && envvarname(pair[0]) == key { + return pair[1] + } + } + return "" +} + +// Setenv sets the value of the environment variable named by the key. It +// panics if key is invalid. +func (e *Env) Setenv(key, value string) { + if key == "" || strings.IndexByte(key, '=') != -1 { + panic("Setenv: invalid argument") + } + e.Vars = append(e.Vars, key+"="+value) +} + // T returns the t argument passed to the current test by the T.Run method. // Note that if the tests were started by calling Run, // the returned value will implement testing.TB. diff --git a/testscript/testscript_test.go b/testscript/testscript_test.go index a5db1699..cb66cf30 100644 --- a/testscript/testscript_test.go +++ b/testscript/testscript_test.go @@ -83,6 +83,52 @@ func TestCRLFInput(t *testing.T) { }) } +func TestEnv(t *testing.T) { + e := &Env{ + Vars: []string{ + "HOME=/no-home", + "PATH=/usr/bin", + "PATH=/usr/bin:/usr/local/bin", + "INVALID", + }, + } + + if got, want := e.Getenv("HOME"), "/no-home"; got != want { + t.Errorf("e.Getenv(\"HOME\") == %q, want %q", got, want) + } + + e.Setenv("HOME", "/home/user") + if got, want := e.Getenv("HOME"), "/home/user"; got != want { + t.Errorf("e.Getenv(\"HOME\") == %q, want %q", got, want) + } + + if got, want := e.Getenv("PATH"), "/usr/bin:/usr/local/bin"; got != want { + t.Errorf("e.Getenv(\"PATH\") == %q, want %q", got, want) + } + + if got, want := e.Getenv("INVALID"), ""; got != want { + t.Errorf("e.Getenv(\"INVALID\") == %q, want %q", got, want) + } + + for _, key := range []string{ + "", + "=", + "key=invalid", + } { + value := "" + var panicValue interface{} + func() { + defer func() { + panicValue = recover() + }() + e.Setenv(key, value) + }() + if panicValue == nil { + t.Errorf("e.Setenv(%q, %q) did not panic, want panic", key, value) + } + } +} + func TestScripts(t *testing.T) { // TODO set temp directory. testDeferCount := 0