diff --git a/lib/column/enum.go b/lib/column/enum.go index 935dd2d4eb..026f82f376 100644 --- a/lib/column/enum.go +++ b/lib/column/enum.go @@ -18,89 +18,157 @@ package column import ( + "bytes" "errors" - "github.com/ClickHouse/ch-go/proto" "math" "strconv" - "strings" + + "github.com/ClickHouse/ch-go/proto" ) func Enum(chType Type, name string) (Interface, error) { - var ( - payload string - columnType = string(chType) - ) - if len(columnType) < 8 { - return nil, &Error{ - ColumnType: string(chType), - Err: errors.New("invalid Enum"), - } - } - switch { - case strings.HasPrefix(columnType, "Enum8"): - payload = columnType[6:] - case strings.HasPrefix(columnType, "Enum16"): - payload = columnType[7:] - default: + enumType, values, indexes, valid := extractEnumNamedValues(chType) + if !valid { return nil, &Error{ ColumnType: string(chType), Err: errors.New("invalid Enum"), } } - var ( - idents []string - indexes []int64 - ) - for _, block := range strings.Split(payload[:len(payload)-1], ",") { - parts := strings.Split(block, "=") - if len(parts) != 2 { - return nil, &Error{ - ColumnType: string(chType), - Err: errors.New("invalid Enum"), - } - } - var ( - ident = strings.TrimSpace(parts[0]) - index, err = strconv.ParseInt(strings.TrimSpace(parts[1]), 10, 16) - ) - if err != nil || len(ident) < 2 { - return nil, &Error{ - ColumnType: string(chType), - Err: errors.New("invalid Enum"), - } - } - ident = ident[1 : len(ident)-1] - idents, indexes = append(idents, ident), append(indexes, index) - } - if strings.HasPrefix(columnType, "Enum8") { + + if enumType == enum8Type { enum := Enum8{ - iv: make(map[string]proto.Enum8, len(idents)), - vi: make(map[proto.Enum8]string, len(idents)), + iv: make(map[string]proto.Enum8, len(values)), + vi: make(map[proto.Enum8]string, len(values)), chType: chType, name: name, } - for i := range idents { - if indexes[i] > math.MaxUint8 { - return nil, &Error{ - ColumnType: string(chType), - Err: errors.New("invalid Enum"), - } - } + for i := range values { v := int8(indexes[i]) - enum.iv[idents[i]] = proto.Enum8(v) - enum.vi[proto.Enum8(v)] = idents[i] + enum.iv[values[i]] = proto.Enum8(v) + enum.vi[proto.Enum8(v)] = values[i] } return &enum, nil } enum := Enum16{ - iv: make(map[string]proto.Enum16, len(idents)), - vi: make(map[proto.Enum16]string, len(idents)), + iv: make(map[string]proto.Enum16, len(values)), + vi: make(map[proto.Enum16]string, len(values)), chType: chType, name: name, } - for i := range idents { - enum.iv[idents[i]] = proto.Enum16(indexes[i]) - enum.vi[proto.Enum16(indexes[i])] = idents[i] + + for i := range values { + enum.iv[values[i]] = proto.Enum16(indexes[i]) + enum.vi[proto.Enum16(indexes[i])] = values[i] } return &enum, nil } + +const ( + enum8Type = "Enum8" + E + enum16Type = "Enum16" +) + +func extractEnumNamedValues(chType Type) (typ string, values []string, indexes []int, valid bool) { + src := []byte(chType) + + var bracketOpen, stringOpen bool + + var foundValueOffset int + var foundValueLen int + var skippedValueTokens []int + var indexFound bool + var valueIndex = 0 + + for c := 0; c < len(src); c++ { + token := src[c] + + switch { + // open bracket found, capture the type + case token == '(' && !stringOpen: + typ = string(src[:c]) + + // Ignore everything captured as non-enum type + if typ != enum8Type && typ != enum16Type { + return + } + + bracketOpen = true + break + // when inside a bracket, we can start capture value inside single quotes + case bracketOpen && token == '\'' && !stringOpen: + foundValueOffset = c + 1 + stringOpen = true + break + // close the string and capture the value + case token == '\'' && stringOpen: + stringOpen = false + foundValueLen = c - foundValueOffset + break + // escape character, skip the next character + case token == '\\' && stringOpen: + skippedValueTokens = append(skippedValueTokens, c-foundValueOffset) + c++ + break + // capture optional index. `=` token is followed with an integer index + case token == '=' && !stringOpen: + if foundValueLen == 0 { + return + } + + indexStart := c + 1 + // find the end of the index, it's either a comma or a closing bracket + for _, token := range src[indexStart:] { + if token == ',' || token == ')' { + break + } + c++ + } + + idx, err := strconv.Atoi(string(bytes.TrimSpace(src[indexStart : c+1]))) + if err != nil { + return + } + valueIndex = idx + indexFound = true + break + // capture the value and index when a comma or closing bracket is found + case (token == ',' || token == ')') && !stringOpen: + if foundValueLen == 0 { + return + } + + // if no index was found for current value, increment the value index + // e.g. Enum8('a','b') is equivalent to Enum8('a'=1,'b'=2) + // or Enum8('a'=3,'b') is equivalent to Enum8('a'=3,'b'=4) + // so if no index is provided, we increment the value index + if !indexFound { + valueIndex++ + } + + // if the index is out of range, return + if (typ == enum8Type && valueIndex > math.MaxUint8) || + (typ == enum16Type && valueIndex > math.MaxUint16) { + return + } + + foundName := src[foundValueOffset : foundValueOffset+foundValueLen] + for _, skipped := range skippedValueTokens { + foundName = append(foundName[:skipped], foundName[skipped+1:]...) + } + + indexes = append(indexes, valueIndex) + values = append(values, string(foundName)) + indexFound = false + break + } + } + + // Enum type must have at least one value + if valueIndex == 0 { + return + } + + valid = true + return +} diff --git a/lib/column/enum_test.go b/lib/column/enum_test.go new file mode 100644 index 0000000000..467687f596 --- /dev/null +++ b/lib/column/enum_test.go @@ -0,0 +1,148 @@ +package column + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestExtractEnumNamedValues(t *testing.T) { + tests := []struct { + name string + chType Type + expectedType string + expectedValues map[int]string + isNotValid bool + }{ + { + name: "Enum8", + chType: "Enum8('a'=1,'b'=2)", + expectedType: "Enum8", + expectedValues: map[int]string{ + 1: "a", + 2: "b", + }, + }, + { + name: "Enum16", + chType: "Enum16('a'=1,'b'=2)", + expectedType: "Enum16", + expectedValues: map[int]string{ + 1: "a", + 2: "b", + }, + }, + { + name: "Enum8 with comma in value", + chType: "Enum8('a'=1,'b'=2,'c,d'=3)", + expectedType: "Enum8", + expectedValues: map[int]string{ + 1: "a", + 2: "b", + 3: "c,d", + }, + }, + { + name: "Enum8 with spaces", + chType: "Enum8('a' = 1, 'b' = 2)", + expectedType: "Enum8", + expectedValues: map[int]string{ + 1: "a", + 2: "b", + }, + }, + { + name: "Enum8 without indexes", + chType: "Enum8('a','b')", + expectedType: "Enum8", + expectedValues: map[int]string{ + 1: "a", + 2: "b", + }, + }, + { + name: "Enum8 with a first index only", + chType: "Enum8('a'=1,'b')", + expectedType: "Enum8", + expectedValues: map[int]string{ + 1: "a", + 2: "b", + }, + }, + { + name: "Enum8 with a last index only", + chType: "Enum8('a','b'=5)", + expectedType: "Enum8", + expectedValues: map[int]string{ + 1: "a", + 5: "b", + }, + }, + { + name: "Enum8 with a first index only higher than 1", + chType: "Enum8('a'=5,'b')", + expectedType: "Enum8", + expectedValues: map[int]string{ + 5: "a", + 6: "b", + }, + }, + { + name: "Enum8 with index with spaces", + chType: "Enum8( 'a' , 'b' = 5 )", + expectedType: "Enum8", + expectedValues: map[int]string{ + 1: "a", + 5: "b", + }, + }, + { + name: "Enum8 with escaped quotes", + chType: `Enum8('a\'b'=1)`, + expectedType: "Enum8", + expectedValues: map[int]string{ + 1: "a'b", + }, + }, + { + name: "Enum8 with invalid index", + chType: "Enum8('a'=1,'b'=256)", + isNotValid: true, + }, + { + name: "Enum8 with invalid non-integer index", + chType: "Enum8('a'=1,'b'='c')", + isNotValid: true, + }, + { + name: "Empty Enum8", + chType: "Enum8()", + isNotValid: true, + }, + { + name: "Empty Enum8 without brackets", + chType: "Enum8", + isNotValid: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actualType, actualValues, actualIndexes, valid := extractEnumNamedValues(tt.chType) + + if tt.isNotValid { + assert.False(t, valid, "%s is valid enum", tt.chType) + return + } + + actualValuesMap := make(map[int]string) + for i, v := range actualValues { + actualValuesMap[actualIndexes[i]] = v + } + + assert.Equal(t, tt.expectedType, actualType) + assert.Equal(t, tt.expectedValues, actualValuesMap) + + assert.True(t, valid, "%s is not valid enum", tt.chType) + }) + } +} diff --git a/tests/enum_test.go b/tests/enum_test.go index 9569addd58..994b7087ff 100644 --- a/tests/enum_test.go +++ b/tests/enum_test.go @@ -21,9 +21,10 @@ import ( "context" "database/sql/driver" "fmt" - "github.com/stretchr/testify/require" "testing" + "github.com/stretchr/testify/require" + "github.com/ClickHouse/clickhouse-go/v2" "github.com/stretchr/testify/assert" ) @@ -110,7 +111,7 @@ func TestEnum(t *testing.T) { require.NoError(t, err) const ddl = ` CREATE TABLE test_enum ( - Col1 Enum ('hello' = 1, 'world' = 2) + Col1 Enum ('hello,hello' = 1, 'world' = 2) , Col2 Enum8 ('click' = 5, 'house' = 25) , Col3 Enum16('house' = 10, 'value' = 50) , Col4 Array(Enum8 ('click' = 1, 'house' = 2)) @@ -126,7 +127,7 @@ func TestEnum(t *testing.T) { batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_enum") require.NoError(t, err) var ( - col1Data = "hello" + col1Data = "hello,hello" col2Data = "click" col3Data = "house" col4Data = []string{"click", "house"} diff --git a/tests/issues/1299_test.go b/tests/issues/1299_test.go new file mode 100644 index 0000000000..fb16e1ebc0 --- /dev/null +++ b/tests/issues/1299_test.go @@ -0,0 +1,38 @@ +package issues + +import ( + "context" + "testing" + + "github.com/ClickHouse/clickhouse-go/v2/tests" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestIssue1299(t *testing.T) { + ctx := context.Background() + conn, err := tests.GetConnection("issues", nil, nil, nil) + require.NoError(t, err) + defer conn.Close() + + expectedEnumValue := "raw:48h',1h:63d,1d:5y" + + const ddl = ` + CREATE TABLE test_1299 ( + Col1 Enum ('raw:48h\',1h:63d,1d:5y' = 1, 'raw:8h,1m:48h,1h:63d,1d:5y' = 2) + ) Engine MergeTree() ORDER BY tuple() + ` + err = conn.Exec(ctx, ddl) + require.NoError(t, err) + defer conn.Exec(ctx, "DROP TABLE test_1299") + + batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_1299") + require.NoError(t, err) + require.NoError(t, batch.Append(expectedEnumValue)) + require.NoError(t, batch.Send()) + + var actualEnumValue string + require.NoError(t, conn.QueryRow(ctx, "SELECT * FROM test_1299").Scan(&actualEnumValue)) + + assert.Equal(t, expectedEnumValue, actualEnumValue) +}