diff --git a/plugin/key_column_qual_map.go b/plugin/key_column_qual_map.go index 05ca4fa1..22f54d10 100644 --- a/plugin/key_column_qual_map.go +++ b/plugin/key_column_qual_map.go @@ -39,22 +39,22 @@ func (m KeyColumnQualMap) String() string { return strings.Join(strs, "\n") } -func (m KeyColumnQualMap) SatisfiesKeyColumns(columns KeyColumnSlice) (bool, KeyColumnSlice) { - log.Printf("[TRACE] SatisfiesKeyColumns %v", columns) +func (m KeyColumnQualMap) GetUnsatisfiedKeyColumns(columns KeyColumnSlice) KeyColumnSlice { + log.Printf("[TRACE] GetUnsatisfiedKeyColumns %v", columns) if columns == nil { - return true, nil + return nil } var unsatisfiedKeyColumns KeyColumnSlice - satisfiedCount := map[string]int{ - Required: 0, - AnyOf: 0, - Optional: 0, + satisfiedMap := map[string]KeyColumnSlice{ + Required: {}, + AnyOf: {}, + Optional: {}, } - unsatisfiedCount := map[string]int{ - Required: 0, - AnyOf: 0, - Optional: 0, + unsatisfiedMap := map[string]KeyColumnSlice{ + Required: {}, + AnyOf: {}, + Optional: {}, } for _, keyColumn := range columns { @@ -62,25 +62,33 @@ func (m KeyColumnQualMap) SatisfiesKeyColumns(columns KeyColumnSlice) (bool, Key k := m[keyColumn.Name] satisfied := k != nil && k.SatisfiesKeyColumn(keyColumn) if satisfied { - satisfiedCount[keyColumn.Require]++ - + satisfiedMap[keyColumn.Require] = append(satisfiedMap[keyColumn.Require], keyColumn) log.Printf("[TRACE] key column satisfied %v", keyColumn) } else { - unsatisfiedCount[keyColumn.Require]++ - unsatisfiedKeyColumns = append(unsatisfiedKeyColumns, keyColumn) + unsatisfiedMap[keyColumn.Require] = append(unsatisfiedMap[keyColumn.Require], keyColumn) log.Printf("[TRACE] key column NOT satisfied %v", keyColumn) - // if this was NOT an optional key column, we are not satisfied } } // we are satisfied if: // all Required key columns are satisfied // either there is at least 1 satisfied AnyOf key columns, or there are no AnyOf columns - res := unsatisfiedCount[Required] == 0 && (satisfiedCount[AnyOf] > 0 || unsatisfiedCount[AnyOf] == 0) + anyOfSatisfied := len(satisfiedMap[AnyOf]) > 0 || len(unsatisfiedMap[AnyOf]) == 0 + if !anyOfSatisfied { + unsatisfiedKeyColumns = unsatisfiedMap[AnyOf] + } + // if any 'required' are unsatisfied, we are unsatisfied + requiredSatisfied := len(unsatisfiedMap[Required]) == 0 + if !requiredSatisfied { + unsatisfiedKeyColumns = append(unsatisfiedKeyColumns, unsatisfiedMap[Required]...) + } + + log.Printf("[TRACE] satisfied: %v", satisfiedMap) + log.Printf("[TRACE] unsatisfied: %v", unsatisfiedMap) + log.Printf("[TRACE] unsatisfied required KeyColumns %v", unsatisfiedKeyColumns) - log.Printf("[TRACE] SatisfiesKeyColumns result: %v\nsatisfiedCount %v\nunsatisfiedCount %v\nunsatisfiedKeyColumns %v", res, satisfiedCount, unsatisfiedCount, unsatisfiedKeyColumns) - return res, unsatisfiedKeyColumns + return unsatisfiedKeyColumns } // ToQualMap converts the map into a simpler map of column to []Quals diff --git a/plugin/query_data.go b/plugin/query_data.go index 81699feb..6bb1a10b 100644 --- a/plugin/query_data.go +++ b/plugin/query_data.go @@ -195,7 +195,7 @@ func (d *QueryData) setFetchType(table *Table) { // build a qual map from Get key columns qualMap := NewKeyColumnQualValueMap(d.QueryContext.UnsafeQuals, table.Get.KeyColumns) // now see whether the qual map has everything required for the get call - if satisfied, _ := qualMap.SatisfiesKeyColumns(table.Get.KeyColumns); satisfied { + if unsatisfiedColumns := qualMap.GetUnsatisfiedKeyColumns(table.Get.KeyColumns); len(unsatisfiedColumns) == 0 { log.Printf("[TRACE] Set fetchType to fetchTypeGet") d.KeyColumnQuals = qualMap.ToEqualsQualValueMap() d.Quals = qualMap diff --git a/plugin/table_fetch.go b/plugin/table_fetch.go index cfd4a99f..e7bf6829 100644 --- a/plugin/table_fetch.go +++ b/plugin/table_fetch.go @@ -300,9 +300,9 @@ func (t *Table) executeListCall(ctx context.Context, queryData *QueryData) { }() // verify we have the necessary quals - isSatisfied, unsatisfiedColumns := queryData.Quals.SatisfiesKeyColumns(t.List.KeyColumns) - if !isSatisfied { - err := status.Error(codes.Internal, fmt.Sprintf("'List' call is missing required quals: \n%s", unsatisfiedColumns.String())) + unsatisfiedColumns := queryData.Quals.GetUnsatisfiedKeyColumns(t.List.KeyColumns) + if len(unsatisfiedColumns) > 0 { + err := status.Error(codes.Internal, fmt.Sprintf("'List' call is missing required quals: %s", unsatisfiedColumns.String())) queryData.streamError(err) return }