Skip to content

Commit

Permalink
Fix Enum column definition parse logic to match ClickHouse spec (#1380)
Browse files Browse the repository at this point in the history
* Fix Enum column definition parse logic to match ClickHouse spec

Enum column value is allowed to contain a comma or escaped single quote: https://clickhouse.com/docs/en/sql-reference/data-types/enum

* Add single-quote string escape support
  • Loading branch information
jkaflik authored Aug 22, 2024
1 parent a192162 commit 9b995de
Show file tree
Hide file tree
Showing 4 changed files with 318 additions and 63 deletions.
188 changes: 128 additions & 60 deletions lib/column/enum.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,89 +18,157 @@
package column

import (
"bytes"
"errors"
"github.com/ClickHouse/ch-go/proto"
"math"
"strconv"
"strings"

"github.com/ClickHouse/ch-go/proto"
)

func Enum(chType Type, name string) (Interface, error) {
var (
payload string
columnType = string(chType)
)
if len(columnType) < 8 {
return nil, &Error{
ColumnType: string(chType),
Err: errors.New("invalid Enum"),
}
}
switch {
case strings.HasPrefix(columnType, "Enum8"):
payload = columnType[6:]
case strings.HasPrefix(columnType, "Enum16"):
payload = columnType[7:]
default:
enumType, values, indexes, valid := extractEnumNamedValues(chType)
if !valid {
return nil, &Error{
ColumnType: string(chType),
Err: errors.New("invalid Enum"),
}
}
var (
idents []string
indexes []int64
)
for _, block := range strings.Split(payload[:len(payload)-1], ",") {
parts := strings.Split(block, "=")
if len(parts) != 2 {
return nil, &Error{
ColumnType: string(chType),
Err: errors.New("invalid Enum"),
}
}
var (
ident = strings.TrimSpace(parts[0])
index, err = strconv.ParseInt(strings.TrimSpace(parts[1]), 10, 16)
)
if err != nil || len(ident) < 2 {
return nil, &Error{
ColumnType: string(chType),
Err: errors.New("invalid Enum"),
}
}
ident = ident[1 : len(ident)-1]
idents, indexes = append(idents, ident), append(indexes, index)
}
if strings.HasPrefix(columnType, "Enum8") {

if enumType == enum8Type {
enum := Enum8{
iv: make(map[string]proto.Enum8, len(idents)),
vi: make(map[proto.Enum8]string, len(idents)),
iv: make(map[string]proto.Enum8, len(values)),
vi: make(map[proto.Enum8]string, len(values)),
chType: chType,
name: name,
}
for i := range idents {
if indexes[i] > math.MaxUint8 {
return nil, &Error{
ColumnType: string(chType),
Err: errors.New("invalid Enum"),
}
}
for i := range values {
v := int8(indexes[i])
enum.iv[idents[i]] = proto.Enum8(v)
enum.vi[proto.Enum8(v)] = idents[i]
enum.iv[values[i]] = proto.Enum8(v)
enum.vi[proto.Enum8(v)] = values[i]
}
return &enum, nil
}
enum := Enum16{
iv: make(map[string]proto.Enum16, len(idents)),
vi: make(map[proto.Enum16]string, len(idents)),
iv: make(map[string]proto.Enum16, len(values)),
vi: make(map[proto.Enum16]string, len(values)),
chType: chType,
name: name,
}
for i := range idents {
enum.iv[idents[i]] = proto.Enum16(indexes[i])
enum.vi[proto.Enum16(indexes[i])] = idents[i]

for i := range values {
enum.iv[values[i]] = proto.Enum16(indexes[i])
enum.vi[proto.Enum16(indexes[i])] = values[i]
}
return &enum, nil
}

// Type-name prefixes that introduce an enum column definition, i.e. the text
// that precedes the opening bracket in Enum8(...) and Enum16(...).
const (
	enum8Type  = "Enum8"
	enum16Type = "Enum16"
)

// extractEnumNamedValues parses a ClickHouse Enum8/Enum16 column type
// definition, e.g. Enum8('a' = 1, 'b,c', 'it\'s' = 5), into its parts.
//
// It returns the type name found before the opening bracket (typ), the
// unescaped value names in declaration order (values) and the index of each
// name (indexes, parallel to values). valid reports whether the definition was
// parsed successfully; when it is false, values and indexes are nil.
//
// Parsing rules:
//   - names are single-quoted; inside a name a backslash escapes the following
//     byte, so a name may contain commas, brackets, '=' and escaped quotes.
//     The empty name '' is accepted;
//   - "= N" after a name assigns an explicit index; a name without one gets
//     the previous index plus one (the first name defaults to 1);
//   - every index must fit the signed storage type: [-128, 127] for Enum8 and
//     [-32768, 32767] for Enum16;
//   - at least one name and the closing bracket are required. Scanning stops
//     at the closing bracket; bytes after it are ignored.
func extractEnumNamedValues(chType Type) (typ string, values []string, indexes []int, valid bool) {
	src := []byte(chType)

	var (
		bracketOpen bool  // the '(' that follows the type name has been seen
		stringOpen  bool  // currently inside a single-quoted name
		nameFound   bool  // a complete quoted name is waiting to be committed
		indexFound  bool  // the pending name has an explicit "= N" index
		nameOffset  int   // position in src of the first byte of the pending name
		nameLen     int   // raw (still escaped) length of the pending name
		escapes     []int // offsets, relative to nameOffset, of escaping backslashes
		valueIndex  int   // index of the most recently parsed name
		minIndex    int   // inclusive lower bound for indexes of typ
		maxIndex    int   // inclusive upper bound for indexes of typ
	)

	for c := 0; c < len(src); c++ {
		token := src[c]

		switch {
		// open bracket found, capture the type
		case token == '(' && !stringOpen:
			// A second unquoted '(' cannot be part of an enum definition.
			if bracketOpen {
				return typ, nil, nil, false
			}
			typ = string(src[:c])

			switch typ {
			case enum8Type:
				minIndex, maxIndex = math.MinInt8, math.MaxInt8
			case enum16Type:
				minIndex, maxIndex = math.MinInt16, math.MaxInt16
			default:
				// Ignore everything captured as non-enum type.
				return typ, nil, nil, false
			}
			bracketOpen = true

		// when inside the bracket, an opening quote starts a new name
		case bracketOpen && token == '\'' && !stringOpen:
			nameOffset = c + 1
			// Escape offsets belong to a single name; drop the ones recorded
			// for the previous name so they are not applied again.
			escapes = escapes[:0]
			stringOpen = true

		// closing quote: remember the raw extent of the name
		case token == '\'' && stringOpen:
			stringOpen = false
			nameLen = c - nameOffset
			// Tracked separately from nameLen so that '' is a legal name.
			nameFound = true

		// escape character: remember the backslash and skip the escaped byte
		case token == '\\' && stringOpen:
			escapes = append(escapes, c-nameOffset)
			c++

		// explicit index: `=` is followed by an integer up to ',' or ')'
		case token == '=' && !stringOpen:
			if !nameFound {
				return typ, nil, nil, false
			}

			end := bytes.IndexAny(src[c+1:], ",)")
			if end < 0 {
				// The definition ends before the index is terminated.
				return typ, nil, nil, false
			}

			idx, err := strconv.Atoi(string(bytes.TrimSpace(src[c+1 : c+1+end])))
			if err != nil {
				return typ, nil, nil, false
			}
			valueIndex = idx
			indexFound = true
			// Leave c on the last byte of the index so that the loop
			// increment lands on the terminating ',' or ')'.
			c += end

		// a comma or the closing bracket commits the pending name
		case (token == ',' || token == ')') && !stringOpen:
			// Covers Enum8(), ",," and a trailing comma: there is no new
			// name to commit.
			if !nameFound {
				return typ, nil, nil, false
			}

			// Without an explicit index the previous one is incremented,
			// e.g. Enum8('a','b') is equivalent to Enum8('a'=1,'b'=2)
			// and Enum8('a'=3,'b') is equivalent to Enum8('a'=3,'b'=4).
			if !indexFound {
				valueIndex++
			}
			if valueIndex < minIndex || valueIndex > maxIndex {
				return typ, nil, nil, false
			}

			name := src[nameOffset : nameOffset+nameLen]
			if len(escapes) > 0 {
				// Copy every byte except the escaping backslashes. Building a
				// fresh slice keeps the recorded offsets valid when a name
				// contains more than one escape.
				unescaped := make([]byte, 0, len(name)-len(escapes))
				prev := 0
				for _, off := range escapes {
					unescaped = append(unescaped, name[prev:off]...)
					prev = off + 1
				}
				name = append(unescaped, name[prev:]...)
			}

			values = append(values, string(name))
			indexes = append(indexes, valueIndex)
			nameFound, indexFound = false, false

			if token == ')' {
				// At least one name has been committed at this point.
				return typ, values, indexes, true
			}
		}
	}

	// No bracket at all, or the closing bracket is missing.
	return typ, nil, nil, false
}
148 changes: 148 additions & 0 deletions lib/column/enum_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
package column

import (
"testing"

"github.com/stretchr/testify/assert"
)

// TestExtractEnumNamedValues feeds a table of Enum type definitions to
// extractEnumNamedValues. Definitions flagged as invalid only have to be
// rejected; for all others the reported type name and the index -> name
// mapping are compared with the expectation.
func TestExtractEnumNamedValues(t *testing.T) {
	type testCase struct {
		name        string
		chType      Type
		wantType    string
		wantValues  map[int]string
		wantInvalid bool
	}

	cases := []testCase{
		{
			name:       "Enum8",
			chType:     "Enum8('a'=1,'b'=2)",
			wantType:   "Enum8",
			wantValues: map[int]string{1: "a", 2: "b"},
		},
		{
			name:       "Enum16",
			chType:     "Enum16('a'=1,'b'=2)",
			wantType:   "Enum16",
			wantValues: map[int]string{1: "a", 2: "b"},
		},
		{
			name:       "Enum8 with comma in value",
			chType:     "Enum8('a'=1,'b'=2,'c,d'=3)",
			wantType:   "Enum8",
			wantValues: map[int]string{1: "a", 2: "b", 3: "c,d"},
		},
		{
			name:       "Enum8 with spaces",
			chType:     "Enum8('a' = 1, 'b' = 2)",
			wantType:   "Enum8",
			wantValues: map[int]string{1: "a", 2: "b"},
		},
		{
			name:       "Enum8 without indexes",
			chType:     "Enum8('a','b')",
			wantType:   "Enum8",
			wantValues: map[int]string{1: "a", 2: "b"},
		},
		{
			name:       "Enum8 with a first index only",
			chType:     "Enum8('a'=1,'b')",
			wantType:   "Enum8",
			wantValues: map[int]string{1: "a", 2: "b"},
		},
		{
			name:       "Enum8 with a last index only",
			chType:     "Enum8('a','b'=5)",
			wantType:   "Enum8",
			wantValues: map[int]string{1: "a", 5: "b"},
		},
		{
			name:       "Enum8 with a first index only higher than 1",
			chType:     "Enum8('a'=5,'b')",
			wantType:   "Enum8",
			wantValues: map[int]string{5: "a", 6: "b"},
		},
		{
			name:       "Enum8 with index with spaces",
			chType:     "Enum8( 'a' , 'b' = 5 )",
			wantType:   "Enum8",
			wantValues: map[int]string{1: "a", 5: "b"},
		},
		{
			name:       "Enum8 with escaped quotes",
			chType:     `Enum8('a\'b'=1)`,
			wantType:   "Enum8",
			wantValues: map[int]string{1: "a'b"},
		},
		{
			name:        "Enum8 with invalid index",
			chType:      "Enum8('a'=1,'b'=256)",
			wantInvalid: true,
		},
		{
			name:        "Enum8 with invalid non-integer index",
			chType:      "Enum8('a'=1,'b'='c')",
			wantInvalid: true,
		},
		{
			name:        "Empty Enum8",
			chType:      "Enum8()",
			wantInvalid: true,
		},
		{
			name:        "Empty Enum8 without brackets",
			chType:      "Enum8",
			wantInvalid: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			gotType, gotValues, gotIndexes, ok := extractEnumNamedValues(tc.chType)

			// Rejected definitions carry no further expectations.
			if tc.wantInvalid {
				assert.False(t, ok, "%s is valid enum", tc.chType)
				return
			}

			// Pair every parsed name with its index so the comparison does
			// not depend on two parallel slices.
			byIndex := make(map[int]string)
			for pos := range gotValues {
				byIndex[gotIndexes[pos]] = gotValues[pos]
			}

			assert.Equal(t, tc.wantType, gotType)
			assert.Equal(t, tc.wantValues, byIndex)

			assert.True(t, ok, "%s is not valid enum", tc.chType)
		})
	}
}
7 changes: 4 additions & 3 deletions tests/enum_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ import (
"context"
"database/sql/driver"
"fmt"
"github.com/stretchr/testify/require"
"testing"

"github.com/stretchr/testify/require"

"github.com/ClickHouse/clickhouse-go/v2"
"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -110,7 +111,7 @@ func TestEnum(t *testing.T) {
require.NoError(t, err)
const ddl = `
CREATE TABLE test_enum (
Col1 Enum ('hello' = 1, 'world' = 2)
Col1 Enum ('hello,hello' = 1, 'world' = 2)
, Col2 Enum8 ('click' = 5, 'house' = 25)
, Col3 Enum16('house' = 10, 'value' = 50)
, Col4 Array(Enum8 ('click' = 1, 'house' = 2))
Expand All @@ -126,7 +127,7 @@ func TestEnum(t *testing.T) {
batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_enum")
require.NoError(t, err)
var (
col1Data = "hello"
col1Data = "hello,hello"
col2Data = "click"
col3Data = "house"
col4Data = []string{"click", "house"}
Expand Down
Loading

0 comments on commit 9b995de

Please sign in to comment.