Skip to content

Commit

Permalink
Makes worker database value conversion to Go types strictly enforce e…
Browse files Browse the repository at this point in the history
…rror handling (#3170)
  • Loading branch information
alishakawaguchi authored Feb 6, 2025
1 parent f2a4457 commit ad4b856
Show file tree
Hide file tree
Showing 22 changed files with 293 additions and 198 deletions.
43 changes: 26 additions & 17 deletions internal/database-record-mapper/dynamodb/mapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,67 +26,76 @@ func (m *DynamoDBMapper) MapRecordWithKeyType(item map[string]types.AttributeVal
standardJSON := make(map[string]any)
ktm := make(map[string]neosync_types.KeyType)
for k, v := range item {
val := parseAttributeValue(k, v, ktm)
val, err := parseAttributeValue(k, v, ktm)
if err != nil {
return nil, nil, fmt.Errorf("failed to parse attribute value for key %q: %w", k, err)
}
standardJSON[k] = val
}
return standardJSON, ktm, nil
}

// ParseAttributeValue converts a DynamoDB AttributeValue to a standard value
func parseAttributeValue(key string, v types.AttributeValue, keyTypeMap map[string]neosync_types.KeyType) any {
func parseAttributeValue(key string, v types.AttributeValue, keyTypeMap map[string]neosync_types.KeyType) (any, error) {
switch t := v.(type) {
case *types.AttributeValueMemberB:
return t.Value
return t.Value, nil
case *types.AttributeValueMemberBOOL:
return t.Value
return t.Value, nil
case *types.AttributeValueMemberBS:
return t.Value
return t.Value, nil
case *types.AttributeValueMemberL:
lAny := make([]any, len(t.Value))
for i, v := range t.Value {
val := parseAttributeValue(fmt.Sprintf("%s[%d]", key, i), v, keyTypeMap)
val, err := parseAttributeValue(fmt.Sprintf("%s[%d]", key, i), v, keyTypeMap)
if err != nil {
return nil, fmt.Errorf("failed to parse list value at index %d for key %q: %w", i, key, err)
}
lAny[i] = val
}
return lAny
return lAny, nil
case *types.AttributeValueMemberM:
mAny := make(map[string]any, len(t.Value))
for k, v := range t.Value {
path := k
if key != "" {
path = fmt.Sprintf("%s.%s", key, k)
}
val := parseAttributeValue(path, v, keyTypeMap)
val, err := parseAttributeValue(path, v, keyTypeMap)
if err != nil {
return nil, fmt.Errorf("failed to parse map value for key %q: %w", path, err)
}
mAny[k] = val
}
return mAny
return mAny, nil
case *types.AttributeValueMemberN:
n, err := gotypeutil.ParseStringAsNumber(t.Value)
if err != nil {
return t.Value
return nil, fmt.Errorf("failed to parse number value for key %q: %w", key, err)
}
return n
return n, nil
case *types.AttributeValueMemberNS:
keyTypeMap[key] = neosync_types.NumberSet
lAny := make([]any, len(t.Value))
for i, v := range t.Value {
n, err := gotypeutil.ParseStringAsNumber(v)
if err != nil {
return v
return nil, fmt.Errorf("failed to parse number set value at index %d for key %q: %w", i, key, err)
}
lAny[i] = n
}
return lAny
return lAny, nil
case *types.AttributeValueMemberNULL:
return nil
return nil, nil
case *types.AttributeValueMemberS:
return t.Value
return t.Value, nil
case *types.AttributeValueMemberSS:
keyTypeMap[key] = neosync_types.StringSet
lAny := make([]any, len(t.Value))
for i, v := range t.Value {
lAny[i] = v
}
return lAny
return lAny, nil
}
return nil
return nil, fmt.Errorf("unsupported DynamoDB attribute type for key %q", key)
}
3 changes: 2 additions & 1 deletion internal/database-record-mapper/dynamodb/mapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ func Test_ParseAttributeValue(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ktm := map[string]neosync_types.KeyType{}
actual := parseAttributeValue(tt.name, tt.input, ktm)
actual, err := parseAttributeValue(tt.name, tt.input, ktm)
require.NoError(t, err)
require.True(t, reflect.DeepEqual(actual, tt.expected), fmt.Sprintf("expected %v %v, got %v %v", tt.expected, reflect.TypeOf(tt.expected), actual, reflect.TypeOf(actual)))
})
}
Expand Down
39 changes: 26 additions & 13 deletions internal/database-record-mapper/mongodb/mapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,46 +27,59 @@ func (m *MongoDBMapper) MapRecordWithKeyType(item map[string]any) (valuemap map[
result := make(map[string]any)
ktm := make(map[string]neosync_types.KeyType)
for k, v := range item {
result[k] = parsePrimitives(k, v, ktm)
val, err := parsePrimitives(k, v, ktm)
if err != nil {
return nil, nil, fmt.Errorf("failed to parse primitive value for key %q: %w", k, err)
}
result[k] = val
}
return result, ktm, nil
}

func parsePrimitives(key string, value any, keyTypeMap map[string]neosync_types.KeyType) any {
func parsePrimitives(key string, value any, keyTypeMap map[string]neosync_types.KeyType) (any, error) {
switch v := value.(type) {
case primitive.Decimal128:
keyTypeMap[key] = neosync_types.Decimal128
floatVal, _, err := big.ParseFloat(v.String(), 10, 128, big.ToNearestEven)
if err == nil {
return floatVal
if err != nil {
return nil, fmt.Errorf("failed to parse decimal128 value for key %q: %w", key, err)
}
return v
return floatVal, nil
case primitive.Binary:
keyTypeMap[key] = neosync_types.Binary
return v
return v, nil
case primitive.ObjectID:
keyTypeMap[key] = neosync_types.ObjectID
return v
return v, nil
case primitive.Timestamp:
keyTypeMap[key] = neosync_types.Timestamp
return v
return v, nil
case bson.D:
m := make(map[string]any)
for _, elem := range v {
path := elem.Key
if key != "" {
path = fmt.Sprintf("%s.%s", key, elem.Key)
}
m[elem.Key] = parsePrimitives(path, elem.Value, keyTypeMap)
val, err := parsePrimitives(path, elem.Value, keyTypeMap)
if err != nil {
return nil, fmt.Errorf("failed to parse bson.D value for key %q: %w", path, err)
}
m[elem.Key] = val
}
return m
return m, nil
case bson.A:
result := make([]any, len(v))
for i, item := range v {
result[i] = parsePrimitives(fmt.Sprintf("%s[%d]", key, i), item, keyTypeMap)
path := fmt.Sprintf("%s[%d]", key, i)
val, err := parsePrimitives(path, item, keyTypeMap)
if err != nil {
return nil, fmt.Errorf("failed to parse bson.A value at index %d for key %q: %w", i, key, err)
}
result[i] = val
}
return result
return result, nil
default:
return v
return v, nil
}
}
3 changes: 2 additions & 1 deletion internal/database-record-mapper/mongodb/mapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ func Test_ParsePrimitives(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ktm := make(map[string]neosync_types.KeyType)
result := parsePrimitives(tc.key, tc.value, ktm)
result, err := parsePrimitives(tc.key, tc.value, ktm)
require.NoError(t, err)
require.Equal(t, tc.expectedKTM, ktm)
require.Equal(t, tc.expected, result)
})
Expand Down
16 changes: 7 additions & 9 deletions internal/database-record-mapper/mssql/mapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package mssql
import (
"database/sql"
"errors"
"fmt"
"strings"
"time"

Expand Down Expand Up @@ -55,10 +56,10 @@ func (m *MSSQLMapper) MapRecord(rows *sql.Rows) (map[string]any, error) {
return nil, err
}

return parseRowValues(values, columnNames, columnDbTypes), nil
return parseRowValues(values, columnNames, columnDbTypes)
}

func parseRowValues(values []any, columnNames, columnDbTypes []string) map[string]any {
func parseRowValues(values []any, columnNames, columnDbTypes []string) (map[string]any, error) {
jObj := map[string]any{}
for i, v := range values {
col := columnNames[i]
Expand All @@ -67,8 +68,7 @@ func parseRowValues(values []any, columnNames, columnDbTypes []string) map[strin
case time.Time:
dt, err := neosynctypes.NewDateTimeFromMssql(t)
if err != nil {
jObj[col] = t
continue
return nil, fmt.Errorf("failed to convert time.Time to DateTime for column %s: %w", col, err)
}
jObj[col] = dt
case *mssql.UniqueIdentifier:
Expand All @@ -78,15 +78,13 @@ func parseRowValues(values []any, columnNames, columnDbTypes []string) map[strin
case strings.EqualFold(colType, "binary"):
binary, err := neosynctypes.NewBinaryFromMssql(t)
if err != nil {
jObj[col] = t
continue
return nil, fmt.Errorf("failed to convert binary data for column %s: %w", col, err)
}
jObj[col] = binary
case strings.EqualFold(colType, "varbinary"):
bits, err := neosynctypes.NewBitsFromMssql(t)
if err != nil {
jObj[col] = t
continue
return nil, fmt.Errorf("failed to convert varbinary data for column %s: %w", col, err)
}
jObj[col] = bits
default:
Expand All @@ -96,5 +94,5 @@ func parseRowValues(values []any, columnNames, columnDbTypes []string) map[strin
jObj[col] = t
}
}
return jObj
return jObj, nil
}
4 changes: 3 additions & 1 deletion internal/database-record-mapper/mssql/mapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
mssql "github.com/microsoft/go-mssqldb"
neosynctypes "github.com/nucleuscloud/neosync/internal/neosync-types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func Test_parseRowValues(t *testing.T) {
Expand Down Expand Up @@ -45,7 +46,8 @@ func Test_parseRowValues(t *testing.T) {
"INT",
}

result := parseRowValues(values, columnNames, columnDbTypes)
result, err := parseRowValues(values, columnNames, columnDbTypes)
require.NoError(t, err)

// Test datetime handling
dt, err := neosynctypes.NewDateTimeFromMssql(testTime)
Expand Down
25 changes: 13 additions & 12 deletions internal/database-record-mapper/mysql/mapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"strings"
"time"

Expand Down Expand Up @@ -46,12 +47,15 @@ func (m *MySQLMapper) MapRecord(rows *sql.Rows) (map[string]any, error) {
if err := rows.Scan(valuesWrapped...); err != nil {
return nil, err
}
jObj := parseMysqlRowValues(values, columnNames, columnDbTypes)
jObj, err := parseMysqlRowValues(values, columnNames, columnDbTypes)
if err != nil {
return nil, err
}

return jObj, nil
}

func parseMysqlRowValues(values []any, columnNames, columnDbTypes []string) map[string]any {
func parseMysqlRowValues(values []any, columnNames, columnDbTypes []string) (map[string]any, error) {
jObj := map[string]any{}
for i, v := range values {
col := columnNames[i]
Expand All @@ -60,29 +64,26 @@ func parseMysqlRowValues(values []any, columnNames, columnDbTypes []string) map[
case time.Time:
dt, err := neosynctypes.NewDateTimeFromMysql(t)
if err != nil {
jObj[col] = t
continue
return nil, fmt.Errorf("failed to parse datetime value: %w", err)
}
jObj[col] = dt
case []byte:
if strings.EqualFold(colDataType, "json") {
var js any
if err := json.Unmarshal(t, &js); err == nil {
jObj[col] = js
continue
if err := json.Unmarshal(t, &js); err != nil {
return nil, err
}
jObj[col] = js
} else if strings.EqualFold(colDataType, "binary") {
binary, err := neosynctypes.NewBinaryFromMysql(t)
if err != nil {
jObj[col] = t
continue
return nil, fmt.Errorf("failed to parse binary value: %w", err)
}
jObj[col] = binary
} else if strings.EqualFold(colDataType, "bit") || strings.EqualFold(colDataType, "varbit") {
bits, err := neosynctypes.NewBitsFromMysql(t)
if err != nil {
jObj[col] = t
continue
return nil, fmt.Errorf("failed to parse bit/varbit value: %w", err)
}
jObj[col] = bits
} else {
Expand All @@ -92,5 +93,5 @@ func parseMysqlRowValues(values []any, columnNames, columnDbTypes []string) map[
jObj[col] = t
}
}
return jObj
return jObj, nil
}
6 changes: 4 additions & 2 deletions internal/database-record-mapper/mysql/mapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ func Test_parseMysqlRowValues(t *testing.T) {
"json",
"binary",
}
result := parseMysqlRowValues(values, columnNames, cTypes)
result, err := parseMysqlRowValues(values, columnNames, cTypes)
require.NoError(t, err)
expected := map[string]any{
"text_col": "Hello",
"int_col": int64(42),
Expand All @@ -56,7 +57,8 @@ func Test_parseMysqlRowValues(t *testing.T) {
columnNames := []string{"text_col", "bool_col", "null_col", "int_col", "json_col", "array_col"}
cTypes := []string{"json", "json", "json", "json", "json", "json"}

result := parseMysqlRowValues(values, columnNames, cTypes)
result, err := parseMysqlRowValues(values, columnNames, cTypes)
require.NoError(t, err)

expected := map[string]any{
"text_col": "Hello",
Expand Down
Loading

0 comments on commit ad4b856

Please sign in to comment.