Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements and bug fixes #47

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions pkg/dao/dinosaur.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,28 @@ import (

"github.com/openshift-online/rh-trex/pkg/api"
"github.com/openshift-online/rh-trex/pkg/db"
"github.com/openshift-online/rh-trex/pkg/util"
)

var (
dinosaurTableName = util.ToSnakeCase(api.DinosaurTypeName) + "s"
dinosaurColumns = []string{
Copy link
Contributor

@tiwillia tiwillia Jun 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This means to support searching a new column, it must be added to this list?

I see clusters-service implementing this requirement in this way, but AMS/OSDFM seem to handle this differently.

I'm not sure which is better tbh, the AMS implementation is explicit while this implementation is simpler. @markturansky can you provide your input here?

Copy link
Contributor Author

@gdbranco gdbranco Jun 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I'd like to eventually be able to move away from this to have something that uses reflect or similar to have it be more dynamic

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'd prefer to not explicitly set casing, as that negates the casing config allowed by gorm (that also does things like table name prefix, etc). can we do this through gorm somehow,?

"id",
"created_at",
"updated_at",
"species",
}
)

func DinosaurApiToModel() TableMappingRelation {
result := map[string]string{}
applyBaseMapping(result, dinosaurColumns, dinosaurTableName)
return TableMappingRelation{
Mapping: result,
relationTableName: dinosaurTableName,
}
}

type DinosaurDao interface {
Get(ctx context.Context, id string) (*api.Dinosaur, error)
Create(ctx context.Context, dinosaur *api.Dinosaur) (*api.Dinosaur, error)
Expand Down
36 changes: 36 additions & 0 deletions pkg/dao/generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package dao

import (
"context"
"fmt"
"strings"

"github.com/jinzhu/inflection"
Expand All @@ -10,6 +11,41 @@ import (
"github.com/openshift-online/rh-trex/pkg/db"
)

type TableMappingRelation struct {
Mapping map[string]string
relationTableName string
}

type relationMapping func() TableMappingRelation

func applyBaseMapping(result map[string]string, columns []string, tableName string) {
for _, c := range columns {
mappingKey := c
mappingValue := fmt.Sprintf("%s.%s", tableName, c)
columnParts := strings.Split(c, ".")
if len(columnParts) == 1 {
mappingKey = mappingValue
}
if len(columnParts) == 2 {
mappingValue = strings.Split(mappingKey, ".")[1]
}
result[mappingKey] = mappingValue
}
}

func applyRelationMapping(result map[string]string, relations []relationMapping) {
for _, relation := range relations {
tableMappingRelation := relation()
for k, v := range tableMappingRelation.Mapping {
if _, ok := result[k]; ok {
result[tableMappingRelation.relationTableName+"."+k] = v
} else {
result[k] = v
}
}
}
}

type Where struct {
sql string
values []any
Expand Down
63 changes: 63 additions & 0 deletions pkg/dao/generic_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package dao

import (
"fmt"
"strings"

. "github.com/onsi/ginkgo/v2/dsl/core"
. "github.com/onsi/gomega"
)

var _ = Describe("applyBaseMapping", func() {
It("generates base mapping", func() {
result := map[string]string{}
applyBaseMapping(result, []string{"id", "created_at", "column1", "nested.field"}, "test_table")
for k, v := range result {
if strings.HasPrefix(k, "test_table") {
Expect(k).To(Equal(v))
continue
}
// nested fields from table
i := strings.Index(k, ".")
Expect(k[i+1:]).To(Equal(v))
}
})
})

var _ = Describe("applyRelationMapping", func() {
It("generates relation mapping", func() {
result := map[string]string{}
applyBaseMapping(result, []string{"id", "created_at", "column1", "nested.field"}, "base_table")
applyRelationMapping(result, []relationMapping{
func() TableMappingRelation {
result := map[string]string{}
applyBaseMapping(result, []string{"id", "created_at", "column1", "nested.field"}, "relation_table")
return TableMappingRelation{
relationTableName: "relation_table",
Mapping: result,
}
},
})
for k, v := range result {
if strings.HasPrefix(k, "base_table") {
Expect(k).To(Equal(v))
continue
}
if strings.HasPrefix(k, "relation_table") {
if c := strings.Count(k, "."); c > 1 {
i := strings.Index(k, ".")
i = strings.Index(k[i+1:], ".") + i
Expect(k[i+2:]).To(Equal(v))
continue
}
Expect(k).To(Equal(v))
continue
}

// nested fields from base table
i := strings.Index(k, ".")
Expect(k[i+1:]).To(Equal(v))
fmt.Println(k, v)
}
})
})
71 changes: 41 additions & 30 deletions pkg/db/sql_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package db
import (
"fmt"
"reflect"
"slices"
"strings"

"github.com/jinzhu/inflection"
Expand All @@ -11,6 +12,11 @@ import (
"gorm.io/gorm"
)

const (
invalidFieldNameMsg = "%s is not a valid field name"
disallowedFieldNameMsg = "%s is a disallowed field name"
)

// Check if a field name starts with properties.
func startsWithProperties(s string) bool {
return strings.HasPrefix(s, "properties.")
Expand All @@ -33,34 +39,33 @@ func hasProperty(n tsl.Node) bool {
}

// getField gets the sql field associated with a name.
func getField(name string, disallowedFields map[string]string) (field string, err *errors.ServiceError) {
func getField(
name string,
disallowedFields []string,
apiToModel map[string]string,
) (field string, err *errors.ServiceError) {
// We want to accept names with trailing and leading spaces
trimmedName := strings.Trim(name, " ")

// Check for properties ->> '<some field name>'
if strings.HasPrefix(trimmedName, "properties ->>") {
field = trimmedName
return
mappedField, ok := apiToModel[trimmedName]
if !ok {
return "", errors.BadRequest(invalidFieldNameMsg, name)
}

// Check for nested field, e.g., subscription_labels.key
checkName := trimmedName
fieldParts := strings.Split(trimmedName, ".")
checkName := mappedField
fieldParts := strings.Split(checkName, ".")
if len(fieldParts) > 2 {
err = errors.BadRequest("%s is not a valid field name", name)
err = errors.BadRequest(invalidFieldNameMsg, name)
return
}
if len(fieldParts) > 1 {
checkName = fieldParts[1]
}

// Check for allowed fields
_, ok := disallowedFields[checkName]
if ok {
err = errors.BadRequest("%s is not a valid field name", name)
if slices.Contains(disallowedFields, checkName) {
err = errors.BadRequest(disallowedFieldNameMsg, name)
return
}
field = trimmedName
field = checkName
return
}

Expand Down Expand Up @@ -102,7 +107,8 @@ func propertiesNodeConverter(n tsl.Node) tsl.Node {
// b. replace the field name with the SQL column name.
func FieldNameWalk(
n tsl.Node,
disallowedFields map[string]string) (newNode tsl.Node, err *errors.ServiceError) {
disallowedFields []string,
apiToModel map[string]string) (newNode tsl.Node, err *errors.ServiceError) {

var field string
var l, r tsl.Node
Expand All @@ -124,7 +130,7 @@ func FieldNameWalk(
}

// Check field name in the disallowedFields field names.
field, err = getField(userFieldName, disallowedFields)
field, err = getField(userFieldName, disallowedFields, apiToModel)
if err != nil {
return
}
Expand All @@ -137,7 +143,7 @@ func FieldNameWalk(
default:
// o/w continue walking the tree.
if n.Left != nil {
l, err = FieldNameWalk(n.Left.(tsl.Node), disallowedFields)
l, err = FieldNameWalk(n.Left.(tsl.Node), disallowedFields, apiToModel)
if err != nil {
return
}
Expand All @@ -148,7 +154,7 @@ func FieldNameWalk(
switch v := n.Right.(type) {
case tsl.Node:
// It's a regular node, just add it.
r, err = FieldNameWalk(v, disallowedFields)
r, err = FieldNameWalk(v, disallowedFields, apiToModel)
if err != nil {
return
}
Expand All @@ -162,7 +168,7 @@ func FieldNameWalk(

// Add all nodes in the right side array.
for _, e := range v {
r, err = FieldNameWalk(e, disallowedFields)
r, err = FieldNameWalk(e, disallowedFields, apiToModel)
if err != nil {
return
}
Expand All @@ -189,23 +195,26 @@ func FieldNameWalk(
}

// cleanOrderBy takes the orderBy arg and cleans it.
func cleanOrderBy(userArg string, disallowedFields map[string]string) (orderBy string, err *errors.ServiceError) {
func cleanOrderBy(userArg string,
disallowedFields []string,
apiToModel map[string]string,
tableName string) (orderBy string, err *errors.ServiceError) {
var orderField string

// We want to accept user params with trailing and leading spaces
trimedName := strings.Trim(userArg, " ")

// Each OrderBy can be a "<field-name>" or a "<field-name> asc|desc"
order := strings.Split(trimedName, " ")
direction := "none valid"

if len(order) == 1 {
orderField, err = getField(order[0], disallowedFields)
direction = "asc"
} else if len(order) == 2 {
orderField, err = getField(order[0], disallowedFields)
direction := "asc"
if len(order) == 2 {
direction = order[1]
}
field := order[0]
if orderParts := strings.Split(order[0], "."); len(orderParts) == 1 {
field = fmt.Sprintf("%s.%s", tableName, field)
}
orderField, err = getField(field, disallowedFields, apiToModel)
if err != nil || (direction != "asc" && direction != "desc") {
err = errors.BadRequest("bad order value '%s'", userArg)
return
Expand All @@ -218,13 +227,15 @@ func cleanOrderBy(userArg string, disallowedFields map[string]string) (orderBy s
// ArgsToOrderBy returns cleaned orderBy list.
func ArgsToOrderBy(
orderByArgs []string,
disallowedFields map[string]string) (orderBy []string, err *errors.ServiceError) {
disallowedFields []string,
apiToModel map[string]string,
tableName string) (orderBy []string, err *errors.ServiceError) {

var order string
if len(orderByArgs) != 0 {
orderBy = []string{}
for _, o := range orderByArgs {
order, err = cleanOrderBy(o, disallowedFields)
order, err = cleanOrderBy(o, disallowedFields, apiToModel, tableName)
if err != nil {
return
}
Expand Down
Loading