diff --git a/z/flags.go b/z/flags.go index ba853967..c5f48fb9 100644 --- a/z/flags.go +++ b/z/flags.go @@ -3,6 +3,9 @@ package z import ( "fmt" "log" + "os" + "os/user" + "path/filepath" "sort" "strconv" "strings" @@ -247,3 +250,34 @@ func (sf *SuperFlag) GetString(opt string) string { } return sf.m[opt] } + +func (sf *SuperFlag) GetPath(opt string) string { + p := sf.GetString(opt) + path, err := expandPath(p) + if err != nil { + log.Fatalf("Failed to get path: %+v", err) + } + return path +} + +// expandPath expands the paths containing ~ to /home/user. It also computes the absolute path +// from the relative paths. For example: ~/abc/../cef will be transformed to /home/user/cef. +func expandPath(path string) (string, error) { + if len(path) == 0 { + return "", nil + } + if path[0] == '~' && (len(path) == 1 || os.IsPathSeparator(path[1])) { + usr, err := user.Current() + if err != nil { + return "", errors.Wrap(err, "Failed to get the home directory of the user") + } + path = filepath.Join(usr.HomeDir, path[1:]) + } + + var err error + path, err = filepath.Abs(path) + if err != nil { + return "", errors.Wrap(err, "Failed to generate absolute path") + } + return path, nil +} diff --git a/z/flags_test.go b/z/flags_test.go index 6b8c3ac1..93f55705 100644 --- a/z/flags_test.go +++ b/z/flags_test.go @@ -1,6 +1,10 @@ package z import ( + "fmt" + "os" + "os/user" + "path/filepath" "testing" "time" @@ -39,3 +43,49 @@ func TestFlagDefault(t *testing.T) { require.Equal(t, true, f.GetBool("one")) require.Equal(t, int64(4), f.GetInt64("two")) } + +func TestGetPath(t *testing.T) { + + usr, err := user.Current() + require.NoError(t, err) + homeDir := usr.HomeDir + cwd, err := os.Getwd() + require.NoError(t, err) + + tests := []struct { + path string + expected string + }{ + { + "/home/user/file.txt", + "/home/user/file.txt", + }, + { + "~/file.txt", + filepath.Join(homeDir, "file.txt"), + }, + { + "~/abc/../file.txt", + filepath.Join(homeDir, "file.txt"), + }, + { + "~/", + homeDir, + }, + { + "~filename", + filepath.Join(cwd, "~filename"), + }, + } + + get := func(p string) string { + opt := fmt.Sprintf("file=%s", p) + sf := NewSuperFlag(opt) + return sf.GetPath("file") + } + + for _, tc := range tests { + actual := get(tc.path) + require.Equalf(t, tc.expected, actual, "Failed on testcase: %s", tc.path) + } +}