Skip to content

Commit

Permalink
opt: rename ScalarID to ScalarRank
Browse files Browse the repository at this point in the history
Previously, every scalar expression (except lists and list items) had an
ID that was said to be unique within the context of a memo. These IDs
were originally added as a way to canonically order filters. Being named
"IDs", their use later expanded to check for equality of two scalar
expressions.

Maintaining this uniqueness invariant is difficult in practice and has
dangerous implications when it is violated, as seen in cockroachdb#71002. While two
different scalar expressions with the same ID could certainly cause
problems for sorting filters, using these IDs to check for scalar
expression equality can be catastrophic. For example, a filter
expression that shares an ID with another expression could be completely
removed from the filter.

Unfortunately, there's no obvious way to add test build assertions that
scalar IDs are in fact unique, as explained in cockroachdb#71035. In order to
lessen the blast radius of breaking this invariant, this commit renames
"scalar ID" to "scalar rank". The comment for this attribute does not
explicitly guarantee its uniqueness. This renaming should urge
contributors to only use this value for ordering scalar expressions
canonically, not for scalar expression equality. Instead, pointer
equality should be used to check if two scalar expressions are the same.

Release note: None
  • Loading branch information
mgartner committed Oct 6, 2021
1 parent ecf11fe commit eae5076
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 48 deletions.
4 changes: 2 additions & 2 deletions pkg/sql/opt/memo/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,10 +202,10 @@ func (n FiltersExpr) OuterCols() opt.ColSet {
return colSet
}

// Sort sorts the FilterItems in n by the IDs of the expression.
// Sort sorts the FilterItems in n by the ranks of the expressions.
func (n *FiltersExpr) Sort() {
sort.Slice(*n, func(i, j int) bool {
return (*n)[i].Condition.(opt.ScalarExpr).ID() < (*n)[j].Condition.(opt.ScalarExpr).ID()
return (*n)[i].Condition.Rank() < (*n)[j].Condition.Rank()
})
}

Expand Down
18 changes: 9 additions & 9 deletions pkg/sql/opt/memo/memo.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ type Memo struct {
dateStyle pgdate.DateStyle
intervalStyle duration.IntervalStyle

// curID is the highest currently in-use scalar expression ID.
curID opt.ScalarID
// curRank is the highest currently in-use scalar expression rank.
curRank opt.ScalarRank

// curWithID is the highest currently in-use WITH ID.
curWithID opt.WithID
Expand Down Expand Up @@ -368,15 +368,15 @@ func (m *Memo) IsOptimized() bool {
return ok && rel.RequiredPhysical() != nil
}

// NextID returns a new unique ScalarID to number expressions with.
func (m *Memo) NextID() opt.ScalarID {
m.curID++
return m.curID
// NextRank returns a new rank that can be assigned to a scalar expression.
func (m *Memo) NextRank() opt.ScalarRank {
m.curRank++
return m.curRank
}

// CopyNextIDFrom copies the next ScalarID from the other memo.
func (m *Memo) CopyNextIDFrom(other *Memo) {
m.curID = other.curID
// CopyNextRankFrom copies the next ScalarRank from the other memo.
func (m *Memo) CopyNextRankFrom(other *Memo) {
m.curRank = other.curRank
}

// RequestColStat calculates and returns the column statistic calculated on the
Expand Down
8 changes: 4 additions & 4 deletions pkg/sql/opt/norm/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,10 +236,10 @@ func (f *Factory) CopyAndReplace(
panic(errors.AssertionFailedf("destination memo must be empty"))
}

// Copy the next scalar ID to the target memo so that new scalar expressions
// built with the new memo will not share scalar IDs with existing
// expressions.
f.mem.CopyNextIDFrom(from.Memo())
// Copy the next scalar rank to the target memo so that new scalar
// expressions built with the new memo will not share scalar ranks with
// existing expressions.
f.mem.CopyNextRankFrom(from.Memo())

// Copy all metadata to the target memo so that referenced tables and
// columns can keep the same ids they had in the "from" memo. Scalar
Expand Down
16 changes: 10 additions & 6 deletions pkg/sql/opt/norm/select_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,11 +235,11 @@ func (c *CustomFuncs) mergeSortedAnds(left, right opt.ScalarExpr) opt.ScalarExpr
nextRight = and.Right
}

if nextLeft.ID() == nextRight.ID() {
if nextLeft == nextRight {
// Eliminate duplicates.
return c.mergeSortedAnds(left, remainingRight)
}
if nextLeft.ID() < nextRight.ID() {
if nextLeft.Rank() < nextRight.Rank() {
return c.f.ConstructAnd(c.mergeSortedAnds(left, remainingRight), nextRight)
}
return c.f.ConstructAnd(c.mergeSortedAnds(remainingLeft, right), nextLeft)
Expand All @@ -249,7 +249,7 @@ func (c *CustomFuncs) mergeSortedAnds(left, right opt.ScalarExpr) opt.ScalarExpr
func (c *CustomFuncs) HasDuplicateFilters(f memo.FiltersExpr) bool {
for i := 0; i < len(f); i++ {
for j := i + 1; j < len(f); j++ {
if f[i].Condition.ID() == f[j].Condition.ID() {
if f[i].Condition == f[j].Condition {
return true
}
}
Expand All @@ -259,10 +259,14 @@ func (c *CustomFuncs) HasDuplicateFilters(f memo.FiltersExpr) bool {

// DeduplicateFilters returns the input filters with duplicates removed.
func (c *CustomFuncs) DeduplicateFilters(f memo.FiltersExpr) memo.FiltersExpr {
// Here we sort the filters by their scalar rank, though we don't really
// care that they are fully sorted. To remove duplicates we only care that
// duplicate expressions are grouped together, which they will be since
// their scalar rank must be equal.
result := c.SortFilters(f)
j := 1
for i := 1; i < len(result); i++ {
if result[i].Condition.ID() != result[i-1].Condition.ID() {
if result[i].Condition != result[i-1].Condition {
result[j] = result[i]
j++
}
Expand All @@ -271,10 +275,10 @@ func (c *CustomFuncs) DeduplicateFilters(f memo.FiltersExpr) memo.FiltersExpr {
}

// AreFiltersSorted determines whether the expressions in a FiltersExpr are
// ordered by their expression IDs.
// ordered by their expression ranks.
func (c *CustomFuncs) AreFiltersSorted(f memo.FiltersExpr) bool {
for i := 1; i < len(f); i++ {
if f[i-1].Condition.ID() > f[i].Condition.ID() {
if f[i-1].Condition.Rank() > f[i].Condition.Rank() {
return false
}
}
Expand Down
11 changes: 6 additions & 5 deletions pkg/sql/opt/operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,18 +83,19 @@ type Expr interface {
String() string
}

// ScalarID is the type of the memo-unique identifier given to every scalar
// ScalarRank is the type of the sort order given to every scalar
// expression.
type ScalarID int
type ScalarRank int

// ScalarExpr is a scalar expression, which is an expression that returns a
// primitive-typed value like boolean or string rather than rows and columns.
type ScalarExpr interface {
Expr

// ID is a unique (within the context of a memo) ID that can be
// used to define a total order over ScalarExprs.
ID() ScalarID
// Rank is a value that defines how the scalar expression should be ordered
// among a collection of scalar expressions. It defines a total order over
// ScalarExprs within the context of a memo.
Rank() ScalarRank

// DataType is the SQL type of the expression.
DataType() *types.T
Expand Down
14 changes: 7 additions & 7 deletions pkg/sql/opt/optgen/cmd/optgen/exprs_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ func (g *exprsGen) genExprStruct(define *lang.DefineExpr) {
fmt.Fprintf(g.w, " scalar props.Scalar\n")
}
} else {
fmt.Fprintf(g.w, " id opt.ScalarID\n")
fmt.Fprintf(g.w, " rank opt.ScalarRank\n")
}
} else if define.Tags.Contains("Enforcer") {
fmt.Fprintf(g.w, " Input RelExpr\n")
Expand All @@ -233,11 +233,11 @@ func (g *exprsGen) genExprFuncs(define *lang.DefineExpr) {
fmt.Fprintf(g.w, "var _ opt.ScalarExpr = &%s{}\n\n", opTyp.name)

// Generate the ID method.
fmt.Fprintf(g.w, "func (e *%s) ID() opt.ScalarID {\n", opTyp.name)
fmt.Fprintf(g.w, "func (e *%s) Rank() opt.ScalarRank {\n", opTyp.name)
if define.Tags.Contains("ListItem") {
fmt.Fprintf(g.w, " return 0\n")
fmt.Fprintf(g.w, " panic(errors.AssertionFailedf(\"list items have no rank\"))")
} else {
fmt.Fprintf(g.w, " return e.id\n")
fmt.Fprintf(g.w, " return e.rank\n")
}
fmt.Fprintf(g.w, "}\n\n")
} else {
Expand Down Expand Up @@ -519,8 +519,8 @@ func (g *exprsGen) genListExprFuncs(define *lang.DefineExpr) {
fmt.Fprintf(g.w, "var _ opt.ScalarExpr = &%s{}\n\n", opTyp.name)

// Generate the ID method.
fmt.Fprintf(g.w, "func (e *%s) ID() opt.ScalarID {\n", opTyp.name)
fmt.Fprintf(g.w, " panic(errors.AssertionFailedf(\"lists have no id\"))")
fmt.Fprintf(g.w, "func (e *%s) Rank() opt.ScalarRank {\n", opTyp.name)
fmt.Fprintf(g.w, " panic(errors.AssertionFailedf(\"lists have no rank\"))")
fmt.Fprintf(g.w, "}\n\n")

// Generate the Op method.
Expand Down Expand Up @@ -619,7 +619,7 @@ func (g *exprsGen) genMemoizeFuncs() {
}

if define.Tags.Contains("Scalar") {
fmt.Fprintf(g.w, " id: m.NextID(),\n")
fmt.Fprintf(g.w, " rank: m.NextRank(),\n")
fmt.Fprintf(g.w, " }\n")
if g.needsDataTypeField(define) {
fmt.Fprintf(g.w, " e.Typ = InferType(m, e)\n")
Expand Down
30 changes: 15 additions & 15 deletions pkg/sql/opt/optgen/cmd/optgen/testdata/exprs
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,8 @@ var EmptyProjectionsExpr = ProjectionsExpr{}

var _ opt.ScalarExpr = &ProjectionsExpr{}

func (e *ProjectionsExpr) ID() opt.ScalarID {
panic(errors.AssertionFailedf("lists have no id"))
func (e *ProjectionsExpr) Rank() opt.ScalarRank {
panic(errors.AssertionFailedf("lists have no rank"))
}

func (e *ProjectionsExpr) Op() opt.Operator {
Expand Down Expand Up @@ -242,8 +242,8 @@ type ProjectionsItem struct {

var _ opt.ScalarExpr = &ProjectionsItem{}

func (e *ProjectionsItem) ID() opt.ScalarID {
return 0
func (e *ProjectionsItem) Rank() opt.ScalarRank {
panic(errors.AssertionFailedf("list items have no rank"))
}

func (e *ProjectionsItem) Op() opt.Operator {
Expand Down Expand Up @@ -539,14 +539,14 @@ import (
type VariableExpr struct {
Col opt.ColumnID

Typ *types.T
id opt.ScalarID
Typ *types.T
rank opt.ScalarRank
}

var _ opt.ScalarExpr = &VariableExpr{}

func (e *VariableExpr) ID() opt.ScalarID {
return e.id
func (e *VariableExpr) Rank() opt.ScalarRank {
return e.rank
}

func (e *VariableExpr) Op() opt.Operator {
Expand Down Expand Up @@ -582,14 +582,14 @@ func (e *VariableExpr) DataType() *types.T {
type MaxExpr struct {
Input *VariableExpr

Typ *types.T
id opt.ScalarID
Typ *types.T
rank opt.ScalarRank
}

var _ opt.ScalarExpr = &MaxExpr{}

func (e *MaxExpr) ID() opt.ScalarID {
return e.id
func (e *MaxExpr) Rank() opt.ScalarRank {
return e.rank
}

func (e *MaxExpr) Op() opt.Operator {
Expand Down Expand Up @@ -636,8 +636,8 @@ func (m *Memo) MemoizeVariable(
) *VariableExpr {
const size = int64(unsafe.Sizeof(VariableExpr{}))
e := &VariableExpr{
Col: col,
id: m.NextID(),
Col: col,
rank: m.NextRank(),
}
e.Typ = InferType(m, e)
interned := m.interner.InternVariable(e)
Expand All @@ -657,7 +657,7 @@ func (m *Memo) MemoizeMax(
const size = int64(unsafe.Sizeof(MaxExpr{}))
e := &MaxExpr{
Input: input,
id: m.NextID(),
rank: m.NextRank(),
}
e.Typ = InferType(m, e)
interned := m.interner.InternMax(e)
Expand Down

0 comments on commit eae5076

Please sign in to comment.