Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

types: make WithoutTypeModifier less error prone #75839

Merged
merged 1 commit into from
Feb 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 19 additions & 41 deletions pkg/sql/types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -843,6 +843,15 @@ func MakeChar(width int32) *T {
Family: StringFamily, Oid: oid.T_bpchar, Width: width, Locale: &emptyLocale}}
}

// oidCanBeCollatedString returns true if the given oid is can be a CollatedString.
func oidCanBeCollatedString(o oid.Oid) bool {
switch o {
case oid.T_text, oid.T_varchar, oid.T_bpchar, oid.T_char, oid.T_name:
return true
}
return false
}

// MakeCollatedString constructs a new instance of a CollatedStringFamily type
// that is collated according to the given locale. The new type is based upon
// the given string type, having the same oid and width values. For example:
Expand All @@ -851,8 +860,7 @@ func MakeChar(width int32) *T {
// VARCHAR(20) => VARCHAR(20) COLLATE EN
//
func MakeCollatedString(strType *T, locale string) *T {
switch strType.Oid() {
case oid.T_text, oid.T_varchar, oid.T_bpchar, oid.T_char, oid.T_name:
if oidCanBeCollatedString(strType.Oid()) {
return &T{InternalType: InternalType{
Family: CollatedStringFamily, Oid: strType.Oid(), Width: strType.Width(), Locale: &locale}}
}
Expand Down Expand Up @@ -1278,50 +1286,20 @@ func (t *T) WithoutTypeModifiers() *T {
return t
}

switch t.Oid() {
case oid.T_bit:
return MakeBit(0)
case oid.T_bpchar, oid.T_char, oid.T_text, oid.T_varchar:
// For string-like types, we copy the type and set the width to 0 rather
// than returning typeBpChar, typeQChar, VarChar, or String so that
// we retain the locale value if the type is collated.
// For types that can be a collated string, we copy the type and set the width
// to 0 rather than returning the default OidToType type so that we retain the
// locale value if the type is collated.
if oidCanBeCollatedString(t.Oid()) {
newT := *t
newT.InternalType.Width = 0
return &newT
case oid.T_interval:
return Interval
case oid.T_numeric:
return Decimal
case oid.T_time:
return Time
case oid.T_timestamp:
return Timestamp
case oid.T_timestamptz:
return TimestampTZ
case oid.T_timetz:
return TimeTZ
case oid.T_varbit:
return VarBit
case oid.T_anyelement,
oid.T_bool,
oid.T_bytea,
oid.T_date,
oidext.T_box2d,
oid.T_float4, oid.T_float8,
oidext.T_geography, oidext.T_geometry,
oid.T_inet,
oid.T_int2, oid.T_int4, oid.T_int8,
oid.T_jsonb,
oid.T_name,
oid.T_oid,
oid.T_regclass, oid.T_regnamespace, oid.T_regproc, oid.T_regprocedure, oid.T_regrole, oid.T_regtype,
oid.T_unknown,
oid.T_uuid,
oid.T_void:
return t
default:
}

t, ok := OidToType[t.Oid()]
if !ok {
panic(errors.AssertionFailedf("unexpected OID: %d", t.Oid()))
}
return t
}

// Scale is an alias method for Width, used for clarity for types in
Expand Down
10 changes: 7 additions & 3 deletions pkg/sql/types/types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1006,6 +1006,8 @@ func TestWithoutTypeModifiers(t *testing.T) {
{MakeArray(MakeDecimal(5, 1)), DecimalArray},
{MakeTuple([]*T{MakeString(2), Time, MakeDecimal(5, 1)}),
MakeTuple([]*T{String, Time, Decimal})},
{MakeGeography(geopb.ShapeType_Point, 3857), Geography},
{MakeGeometry(geopb.ShapeType_PointZ, 4326), Geometry},

// Types without modifiers.
{Bool, Bool},
Expand All @@ -1026,8 +1028,10 @@ func TestWithoutTypeModifiers(t *testing.T) {
}

for _, tc := range testCases {
if actual := tc.t.WithoutTypeModifiers(); !actual.Identical(tc.expected) {
t.Errorf("expected <%v>, got <%v>", tc.expected.DebugString(), actual.DebugString())
}
t.Run(tc.t.SQLString(), func(t *testing.T) {
if actual := tc.t.WithoutTypeModifiers(); !actual.Identical(tc.expected) {
t.Errorf("expected <%v>, got <%v>", tc.expected.DebugString(), actual.DebugString())
}
})
}
}