From 75be8a7ab6bfcf62f3d060b2a104c2dcdf77e4ce Mon Sep 17 00:00:00 2001
From: Francesco Romani <fromani@redhat.com>
Date: Wed, 29 Nov 2023 09:39:35 +0100
Subject: [PATCH] WIP: flagcodec: enable flag normalization

Signed-off-by: Francesco Romani <fromani@redhat.com>
---
 pkg/flagcodec/flagcodec.go      |  77 +++++++++++++-----
 pkg/flagcodec/flagcodec_test.go | 139 ++++++++++++++++++++++++++++----
 2 files changed, 184 insertions(+), 32 deletions(-)

diff --git a/pkg/flagcodec/flagcodec.go b/pkg/flagcodec/flagcodec.go
index b988b9c3..965c6dd8 100644
--- a/pkg/flagcodec/flagcodec.go
+++ b/pkg/flagcodec/flagcodec.go
@@ -38,9 +38,10 @@ type Val struct {
 }
 
 type Flags struct {
-	command string
-	args    map[string]Val
-	keys    []string
+	command         string
+	args            map[string]Val
+	keys            []string
+	processFlagName func(string) string
 }
 
 // ParseArgvKeyValue parses a clean (trimmed) argv whose components
@@ -51,9 +52,24 @@ type Flags struct {
 // "--opt=foo"
 // AND NOT
 // "--opt", "foo"
-// The value of argv[0], whatever it is, is taken at command.
-func ParseArgvKeyValue(args []string) *Flags {
-	return ParseArgvKeyValueWithCommand("", args)
+func ParseArgvKeyValue(args []string, opts ...Option) *Flags {
+	ret := &Flags{
+		command:         "",
+		args:            make(map[string]Val),
+		processFlagName: func(v string) string { return v },
+	}
+	for _, opt := range opts {
+		opt(ret)
+	}
+	for _, arg := range args {
+		fields := strings.SplitN(arg, "=", 2)
+		if len(fields) == 1 {
+			ret.SetToggle(fields[0])
+			continue
+		}
+		ret.SetOption(fields[0], fields[1])
+	}
+	return ret
 }
 
 // ParseArgvKeyValueWithCommand parses a clean (trimmed) argv whose components
@@ -64,21 +80,42 @@ func ParseArgvKeyValue(args []string) *Flags {
 // "--opt=foo"
 // AND NOT
 // "--opt", "foo"
-// The command is supplied explicitely as parameter.
+// The command is supplied explicitly as parameter.
 func ParseArgvKeyValueWithCommand(command string, args []string) *Flags {
-	ret := &Flags{
-		command: command,
-		args:    make(map[string]Val),
+	return ParseArgvKeyValue(args, WithCommand(command))
+}
+
+type Option func(*Flags) *Flags
+
+func normalizeFlagName(v string) string {
+	if len(v) == 3 && v[0] == '-' && v[1] == '-' {
+		// single char, double dash flag (ugly?), fix it
+		return v[1:]
 	}
-	for _, arg := range args {
-		fields := strings.SplitN(arg, "=", 2)
-		if len(fields) == 1 {
-			ret.SetToggle(fields[0])
-			continue
-		}
-		ret.SetOption(fields[0], fields[1])
+	// everything else pass through silently
+	return v
+}
+
+// WithFlagNormalization optionally enables flag normalization.
+// The canonical representation of flags in this package is:
+// * single-dash for one-char flags (-v, -h)
+// * double-dash for multi-char flags (--foo, --long-option)
+// pflag allows one-char to have one or two dashes. For flagcodec
+// these were different options. When normalization is enabled,
+// though, all flag names are processed to adhere to the canonical
+// representation, so flagcodec will treat `--v` and `-v` to
+// be the same flag. Since this is possibly breaking change,
+// this treatment is opt-in.
+func WithFlagNormalization(fl *Flags) *Flags {
+	fl.processFlagName = normalizeFlagName
+	return fl
+}
+
+func WithCommand(command string) Option {
+	return func(fl *Flags) *Flags {
+		fl.command = command
+		return fl
 	}
-	return ret
 }
 
 func (fl *Flags) recordFlag(name string) {
@@ -99,6 +136,7 @@ func (fl *Flags) forgetFlag(name string) {
 }
 
 func (fl *Flags) SetToggle(name string) {
+	name = fl.processFlagName(name)
 	fl.recordFlag(name)
 	fl.args[name] = Val{
 		Kind: FlagToggle,
@@ -106,6 +144,7 @@ func (fl *Flags) SetToggle(name string) {
 }
 
 func (fl *Flags) SetOption(name, data string) {
+	name = fl.processFlagName(name)
 	fl.recordFlag(name)
 	fl.args[name] = Val{
 		Kind: FlagOption,
@@ -114,6 +153,7 @@ func (fl *Flags) SetOption(name, data string) {
 }
 
 func (fl *Flags) Delete(name string) {
+	name = fl.processFlagName(name)
 	fl.forgetFlag(name)
 	delete(fl.args, name)
 }
@@ -139,6 +179,7 @@ func (fl *Flags) Argv() []string {
 }
 
 func (fl *Flags) GetFlag(name string) (Val, bool) {
+	name = fl.processFlagName(name)
 	if val, ok := fl.args[name]; ok {
 		return val, ok
 	}
diff --git a/pkg/flagcodec/flagcodec_test.go b/pkg/flagcodec/flagcodec_test.go
index 11e102af..a12e801d 100644
--- a/pkg/flagcodec/flagcodec_test.go
+++ b/pkg/flagcodec/flagcodec_test.go
@@ -64,14 +64,25 @@ func TestParseStringRoundTrip(t *testing.T) {
 		},
 	}
 
-	for _, tc := range testCases {
-		t.Run(tc.name, func(t *testing.T) {
-			fl := ParseArgvKeyValue(tc.argv)
-			got := fl.Argv()
-			if !reflect.DeepEqual(tc.expected, got) {
-				t.Errorf("expected %v got %v", tc.expected, got)
+	for _, normFlag := range []bool{false, true} {
+		for _, tc := range testCases {
+			name := tc.name
+			if normFlag {
+				name += "-norm-flag"
 			}
-		})
+			t.Run(name, func(t *testing.T) {
+				var fl *Flags
+				if normFlag {
+					fl = ParseArgvKeyValue(tc.argv, WithFlagNormalization)
+				} else {
+					fl = ParseArgvKeyValue(tc.argv)
+				}
+				got := fl.Argv()
+				if !reflect.DeepEqual(tc.expected, got) {
+					t.Errorf("expected %v got %v", tc.expected, got)
+				}
+			})
+		}
 	}
 }
 
@@ -117,9 +128,80 @@ func TestParseStringRoundTripWithCommand(t *testing.T) {
 		},
 	}
 
+	for _, normFlag := range []bool{false, true} {
+		for _, tc := range testCases {
+			name := tc.name
+			if normFlag {
+				name += "-norm-flag"
+			}
+			t.Run(name, func(t *testing.T) {
+				var fl *Flags
+				if normFlag {
+					fl = ParseArgvKeyValue(tc.args, WithCommand(tc.command), WithFlagNormalization)
+				} else {
+					fl = ParseArgvKeyValue(tc.args, WithCommand(tc.command))
+				}
+				got := fl.Argv()
+				if !reflect.DeepEqual(tc.expected, got) {
+					t.Errorf("expected %v got %v", tc.expected, got)
+				}
+			})
+		}
+	}
+}
+
+func TestAddFlags(t *testing.T) {
+	type testOpt struct {
+		name  string
+		value string
+	}
+
+	type testCase struct {
+		name     string
+		command  string
+		args     []string
+		options  []testOpt
+		expected []string
+	}
+
+	testCases := []testCase{
+		{
+			name:    "add-mixed",
+			command: "/bin/resource-topology-exporter",
+			args: []string{
+				"--sleep-interval=10s",
+				"--sysfs=/host-sys",
+				"--kubelet-state-dir=/host-var/lib/kubelet",
+				"--podresources-socket=unix:///host-var/lib/kubelet/pod-resources/kubelet.sock",
+			},
+			options: []testOpt{
+				{
+					name:  "--hostname",
+					value: "host.test.net",
+				},
+				{
+					name:  "--v",
+					value: "2",
+				},
+			},
+			expected: []string{
+				"/bin/resource-topology-exporter",
+				"--sleep-interval=10s",
+				"--sysfs=/host-sys",
+				"--kubelet-state-dir=/host-var/lib/kubelet",
+				"--podresources-socket=unix:///host-var/lib/kubelet/pod-resources/kubelet.sock",
+				"--hostname=host.test.net",
+				"--v=2",
+			},
+		},
+	}
+
 	for _, tc := range testCases {
 		t.Run(tc.name, func(t *testing.T) {
-			fl := ParseArgvKeyValueWithCommand(tc.command, tc.args)
+			fl := ParseArgvKeyValue(tc.args, WithCommand(tc.command))
+			for _, opt := range tc.options {
+				fl.SetOption(opt.name, opt.value)
+			}
 			got := fl.Argv()
 			if !reflect.DeepEqual(tc.expected, got) {
 				t.Errorf("expected %v got %v", tc.expected, got)
@@ -128,7 +210,7 @@ func TestParseStringRoundTripWithCommand(t *testing.T) {
 	}
 }
 
-func TestAddFlags(t *testing.T) {
+func TestAddFlagsNormalized(t *testing.T) {
 	type testOpt struct {
 		name  string
 		value string
@@ -144,7 +226,7 @@ func TestAddFlags(t *testing.T) {
 
 	testCases := []testCase{
 		{
-			name:    "add-mixed",
+			name:    "add-mixed-two-dashes-single-letter",
 			command: "/bin/resource-topology-exporter",
 			args: []string{
 				"--sleep-interval=10s",
@@ -169,14 +251,43 @@ func TestAddFlags(t *testing.T) {
 				"--kubelet-state-dir=/host-var/lib/kubelet",
 				"--podresources-socket=unix:///host-var/lib/kubelet/pod-resources/kubelet.sock",
 				"--hostname=host.test.net",
-				"--v=2",
+				"-v=2",
+			},
+		},
+		{
+			name:    "add-mixed-single-dashe-single-letter",
+			command: "/bin/resource-topology-exporter",
+			args: []string{
+				"--sleep-interval=10s",
+				"--sysfs=/host-sys",
+				"--kubelet-state-dir=/host-var/lib/kubelet",
+				"--podresources-socket=unix:///host-var/lib/kubelet/pod-resources/kubelet.sock",
+			},
+			options: []testOpt{
+				{
+					name:  "--hostname",
+					value: "host.test.net",
+				},
+				{
+					name:  "-v",
+					value: "2",
+				},
+			},
+			expected: []string{
+				"/bin/resource-topology-exporter",
+				"--sleep-interval=10s",
+				"--sysfs=/host-sys",
+				"--kubelet-state-dir=/host-var/lib/kubelet",
+				"--podresources-socket=unix:///host-var/lib/kubelet/pod-resources/kubelet.sock",
+				"--hostname=host.test.net",
+				"-v=2",
 			},
 		},
 	}
 
 	for _, tc := range testCases {
 		t.Run(tc.name, func(t *testing.T) {
-			fl := ParseArgvKeyValueWithCommand(tc.command, tc.args)
+			fl := ParseArgvKeyValue(tc.args, WithCommand(tc.command), WithFlagNormalization)
 			for _, opt := range tc.options {
 				fl.SetOption(opt.name, opt.value)
 			}
@@ -285,7 +396,7 @@ func TestDeleteFlags(t *testing.T) {
 
 	for _, tc := range testCases {
 		t.Run(tc.name, func(t *testing.T) {
-			fl := ParseArgvKeyValueWithCommand(tc.command, tc.args)
+			fl := ParseArgvKeyValue(tc.args, WithCommand(tc.command))
 			for _, opt := range tc.options {
 				fl.Delete(opt)
 			}
@@ -380,7 +491,7 @@ func TestGetFlags(t *testing.T) {
 
 	for _, tc := range testCases {
 		t.Run(tc.name, func(t *testing.T) {
-			fl := ParseArgvKeyValueWithCommand(tc.command, tc.args)
+			fl := ParseArgvKeyValue(tc.args, WithCommand(tc.command))
 			for idx := range tc.params {
 				param := tc.params[idx]
 				exp := tc.expected[idx]