Skip to content

Commit

Permalink
refactor(prompt): 分离 w 和 Prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
caixw committed May 27, 2024
1 parent b0b756c commit c79513c
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 70 deletions.
32 changes: 14 additions & 18 deletions colors/colors.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ import (
// Color 定义了控制台能接受的所有颜色值
//
// 颜色定义分为以下几种:
// - 默认色: math.MaxInt32 定义为 Default
// - 基本色: 0-7 定义从 Black 至 White
// - 增强色: 8-15 定义从 BrightBlack 至 BrightWhite
// - 默认色: [math.MaxInt32] 定义为 [Default]
// - 基本色: 0-7 定义从 [Black][White]
// - 增强色: 8-15 定义从 [BrightBlack][BrightWhite]
// - 256 色: 0-256,数值,其中 0-15 的数据会被转换成以上的色彩;
// - 真彩色: 负数,可由 [RGB] 函数生成;
//
Expand Down Expand Up @@ -75,15 +75,15 @@ type Type int
//
// NOTE: 并不是所有的终端都支持这些所有特性。
const (
Bold Type = iota + 1
Faint // 弱化
Italic // 斜体
Underline // 下划线
Blink // 闪烁
RapidBlink // 快速闪烁
ReverseVideo // 反显
Conceal // 隐藏
Delete // 删除线
Bold Type = iota + 1 // 粗体
Faint // 弱化
Italic // 斜体
Underline // 下划线
Blink // 闪烁
RapidBlink // 快速闪烁
ReverseVideo // 反显
Conceal // 隐藏
Delete // 删除线
maxType

Normal Type = -1 // 正常显示
Expand Down Expand Up @@ -133,9 +133,7 @@ func (c Color) String() string {
}

// RGB 根据 RGB 生成真色彩
func RGB(r, g, b uint8) Color {
return Color(-(int32(r)<<16 + int32(g)<<8 + int32(b)))
}
func RGB(r, g, b uint8) Color { return Color(-(int32(r)<<16 + int32(g)<<8 + int32(b))) }

// HEX 以 16 进制的形式转换成颜色
//
Expand Down Expand Up @@ -214,9 +212,7 @@ func (c Color) bColorCode() []int {
}
}

func isValidType(t Type) bool {
return t == Normal || (t >= Bold && t < maxType)
}
func isValidType(t Type) bool { return t == Normal || (t >= Bold && t < maxType) }

func sgr(t Type, foreground, background Color) ansi.ESC {
codes := make([]int, 0, 10)
Expand Down
2 changes: 1 addition & 1 deletion colors/colors_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import "golang.org/x/sys/windows"

// EnableVirtualTerminalProcessing 是否启用 ENABLE_VIRTUAL_TERMINAL_PROCESSING 模式
//
// enable 表示设置之前值,之后可调用 RestoreVirtualTerminalProcessing 恢复:
// enable 表示设置之前值,之后可调用 [RestoreVirtualTerminalProcessing] 恢复:
//
// enable, err := EnableVirtualTerminalProcessing(windows.Stdout)
// RestoreVirtualTerminalProcessing(enable) // 恢复
Expand Down
46 changes: 23 additions & 23 deletions prompt/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type Prompt struct {
defaultColor colors.Color
}

// New 声明 Prompt 变量
// New 声明 [Prompt] 变量
//
// delim 从 input 读取内容时的分隔符,如果为空,则采用 \n;
// defaultColor 默认值的颜色,如果该值无效,则会 panic。
Expand All @@ -48,14 +48,14 @@ func New(delim byte, input io.Reader, output io.Writer, defaultColor colors.Colo
// q 显示的问题内容;
// def 表示默认值。
func (p *Prompt) String(q, def string) (string, error) {
w := &w{}
w.print(p.output, colors.Default, q)
w := &w{output: p.output}
w.print(colors.Default, q)
if def != "" {
w.print(p.output, p.defaultColor, "(", def, ")")
w.print(p.defaultColor, "(", def, ")")
}
w.print(p.output, colors.Default, ":")
w.print(colors.Default, ":")

v := w.read(p)
v := w.read(p.reader, p.delim)
if w.err != nil {
return "", w.err
}
Expand All @@ -68,16 +68,16 @@ func (p *Prompt) String(q, def string) (string, error) {

// Bool 输出 bool 问题并获取用户的回答内容
func (p *Prompt) Bool(q string, def bool) (bool, error) {
w := &w{}
w.print(p.output, colors.Default, q)
w := &w{output: p.output}
w.print(colors.Default, q)
str := "Y"
if !def {
str = "N"
}
w.print(p.output, p.defaultColor, "(", str, ")")
w.print(p.output, colors.Default, ":")
w.print(p.defaultColor, "(", str, ")")
w.print(colors.Default, ":")

val := w.read(p)
val := w.read(p.reader, p.delim)
if w.err != nil {
return false, w.err
}
Expand All @@ -98,19 +98,19 @@ func (p *Prompt) Bool(q string, def bool) (bool, error) {
// slice 表示可选的问题列表;
// def 表示默认项的索引,必须在 slice 之内。
func (p *Prompt) Slice(q string, slice []string, def ...int) (selected []int, err error) {
w := &w{}
w.println(p.output, colors.Default, q)
w := &w{output: p.output}
w.println(colors.Default, q)
for i, v := range slice {
c := colors.Default
if inIntSlice(i, def) {
c = p.defaultColor
}
w.printf(p.output, c, "(%d)", i)
w.printf(p.output, colors.Default, "%s\n", v)
w.printf(c, "(%d)", i)
w.printf(colors.Default, "%s\n", v)
}
w.print(p.output, colors.Default, "请输入你的选择项,多项请用半角逗号(,)分隔:")
w.print(colors.Default, "请输入你的选择项,多项请用半角逗号(,)分隔:")

val := w.read(p)
val := w.read(p.reader, p.delim)
if w.err != nil {
return nil, w.err
}
Expand All @@ -135,19 +135,19 @@ func (p *Prompt) Slice(q string, slice []string, def ...int) (selected []int, er
// maps 表示可选的问题列表;
// def 表示默认项的索引,必须在 maps 之内。
func (p *Prompt) Map(q string, maps map[string]string, def ...string) (selected []string, err error) {
w := &w{}
w.println(p.output, colors.Default, q)
w := &w{output: p.output}
w.println(colors.Default, q)
for k, v := range maps {
c := colors.Default
if sliceutil.Count(def, func(i string, _ int) bool { return i == k }) > 0 {
c = p.defaultColor
}
w.printf(p.output, c, "(%s)", k)
w.printf(p.output, colors.Default, "%s\n", v)
w.printf(c, "(%s)", k)
w.printf(colors.Default, "%s\n", v)
}
w.print(p.output, colors.Default, "请输入你的选择项,多项请用半角逗号(,)分隔:")
w.print(colors.Default, "请输入你的选择项,多项请用半角逗号(,)分隔:")

val := w.read(p)
val := w.read(p.reader, p.delim)
if w.err != nil {
return nil, w.err
}
Expand Down
20 changes: 11 additions & 9 deletions prompt/w.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,40 +5,42 @@
package prompt

import (
"bufio"
"io"

"github.com/issue9/term/v3/colors"
)

type w struct {
err error
output io.Writer
err error
}

func (w *w) println(output io.Writer, c colors.Color, v ...any) {
func (w *w) println(c colors.Color, v ...any) {
if w.err == nil {
_, w.err = colors.Fprintln(output, colors.Normal, c, colors.Default, v...)
_, w.err = colors.Fprintln(w.output, colors.Normal, c, colors.Default, v...)
}
}

func (w *w) print(output io.Writer, c colors.Color, v ...any) {
func (w *w) print(c colors.Color, v ...any) {
if w.err == nil {
_, w.err = colors.Fprint(output, colors.Normal, c, colors.Default, v...)
_, w.err = colors.Fprint(w.output, colors.Normal, c, colors.Default, v...)
}
}

func (w *w) printf(output io.Writer, c colors.Color, format string, v ...any) {
func (w *w) printf(c colors.Color, format string, v ...any) {
if w.err == nil {
_, w.err = colors.Fprintf(output, colors.Normal, c, colors.Default, format, v...)
_, w.err = colors.Fprintf(w.output, colors.Normal, c, colors.Default, format, v...)
}
}

// 从输入端读取一行内容
func (w *w) read(p *Prompt) (v string) {
func (w *w) read(reader *bufio.Reader, delim byte) (v string) {
if w.err != nil {
return ""
}

v, w.err = p.reader.ReadString(p.delim)
v, w.err = reader.ReadString(delim)
if w.err != nil {
return ""
}
Expand Down
35 changes: 16 additions & 19 deletions prompt/w_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package prompt

import (
"bufio"
"bytes"
"io"
"testing"
Expand All @@ -19,55 +20,51 @@ func TestW_print(t *testing.T) {
a := assert.New(t, false)

r := new(bytes.Buffer)
w := &w{}
w := &w{output: r}
p := New(0, r, io.Discard, colors.Red)
a.NotNil(p)

w.print(r, colors.Default, "print")
w.print(colors.Default, "print")
a.Contains(r.String(), "print")

r.Reset()
w.println(r, colors.Default, "println")
w.println(colors.Default, "println")
a.Contains(r.String(), "println")

r.Reset()
w.printf(r, colors.Default, "printf %s", "printf")
w.printf(colors.Default, "printf %s", "printf")
a.Contains(r.String(), "printf printf")
}

func TestW_read(t *testing.T) {
a := assert.New(t, false)

r := new(bytes.Buffer)
rr := bufio.NewReader(r)
w := &w{}
p := New(0, r, io.Discard, colors.Red)
a.NotNil(p)

r.WriteString("hello\nworld\n\n")
a.Equal(w.read(p), "hello")
a.Equal(w.read(p), "world")
a.Equal(w.read(p), "")
a.Equal(w.read(p), "")
a.Equal(w.read(rr, '\n'), "hello")
a.Equal(w.read(rr, '\n'), "world")
a.Equal(w.read(rr, '\n'), "")
a.Equal(w.read(rr, '\n'), "")
a.NotNil(w.err)

// 没有读到指定分隔符,则读取所有
r.Reset()
rr.Reset(r)
w.err = nil
p = New('x', r, io.Discard, colors.Red)
a.NotNil(p)
r.WriteString("hello\nworld\n\n")
a.Equal(w.read(p), "").
a.Equal(w.read(rr, 'x'), "").
NotNil(w.err)

// 返回错误信息
r.Reset()
rr.Reset(r)
w.err = nil
p = New(0, iotest.TimeoutReader(r), io.Discard, colors.Red)
a.NotNil(p)
rr = bufio.NewReader(iotest.TimeoutReader(r))
r.WriteString("hello")
a.Equal(w.read(p), "").
a.Equal(w.read(rr, '\n'), "").
NotNil(w.err)
r.WriteString("world\n\n")
a.Equal(w.read(p), "").
a.Equal(w.read(rr, '\n'), "").
NotNil(w.err)
}

0 comments on commit c79513c

Please sign in to comment.