diff --git a/pkg/sql/types/types.go b/pkg/sql/types/types.go index 77731c5c9554..4730f087e731 100644 --- a/pkg/sql/types/types.go +++ b/pkg/sql/types/types.go @@ -843,6 +843,15 @@ func MakeChar(width int32) *T { Family: StringFamily, Oid: oid.T_bpchar, Width: width, Locale: &emptyLocale}} } +// oidCanBeCollatedString returns true if the given oid is can be a CollatedString. +func oidCanBeCollatedString(o oid.Oid) bool { + switch o { + case oid.T_text, oid.T_varchar, oid.T_bpchar, oid.T_char, oid.T_name: + return true + } + return false +} + // MakeCollatedString constructs a new instance of a CollatedStringFamily type // that is collated according to the given locale. The new type is based upon // the given string type, having the same oid and width values. For example: @@ -851,8 +860,7 @@ func MakeChar(width int32) *T { // VARCHAR(20) => VARCHAR(20) COLLATE EN // func MakeCollatedString(strType *T, locale string) *T { - switch strType.Oid() { - case oid.T_text, oid.T_varchar, oid.T_bpchar, oid.T_char, oid.T_name: + if oidCanBeCollatedString(strType.Oid()) { return &T{InternalType: InternalType{ Family: CollatedStringFamily, Oid: strType.Oid(), Width: strType.Width(), Locale: &locale}} } @@ -1278,50 +1286,20 @@ func (t *T) WithoutTypeModifiers() *T { return t } - switch t.Oid() { - case oid.T_bit: - return MakeBit(0) - case oid.T_bpchar, oid.T_char, oid.T_text, oid.T_varchar: - // For string-like types, we copy the type and set the width to 0 rather - // than returning typeBpChar, typeQChar, VarChar, or String so that - // we retain the locale value if the type is collated. + // For types that can be a collated string, we copy the type and set the width + // to 0 rather than returning the default OidToType type so that we retain the + // locale value if the type is collated. + if oidCanBeCollatedString(t.Oid()) { newT := *t newT.InternalType.Width = 0 return &newT - case oid.T_interval: - return Interval - case oid.T_numeric: - return Decimal - case oid.T_time: - return Time - case oid.T_timestamp: - return Timestamp - case oid.T_timestamptz: - return TimestampTZ - case oid.T_timetz: - return TimeTZ - case oid.T_varbit: - return VarBit - case oid.T_anyelement, - oid.T_bool, - oid.T_bytea, - oid.T_date, - oidext.T_box2d, - oid.T_float4, oid.T_float8, - oidext.T_geography, oidext.T_geometry, - oid.T_inet, - oid.T_int2, oid.T_int4, oid.T_int8, - oid.T_jsonb, - oid.T_name, - oid.T_oid, - oid.T_regclass, oid.T_regnamespace, oid.T_regproc, oid.T_regprocedure, oid.T_regrole, oid.T_regtype, - oid.T_unknown, - oid.T_uuid, - oid.T_void: - return t - default: + } + + t, ok := OidToType[t.Oid()] + if !ok { panic(errors.AssertionFailedf("unexpected OID: %d", t.Oid())) } + return t } // Scale is an alias method for Width, used for clarity for types in diff --git a/pkg/sql/types/types_test.go b/pkg/sql/types/types_test.go index 057497e27e62..3d5d38d8ba6a 100644 --- a/pkg/sql/types/types_test.go +++ b/pkg/sql/types/types_test.go @@ -1006,6 +1006,8 @@ func TestWithoutTypeModifiers(t *testing.T) { {MakeArray(MakeDecimal(5, 1)), DecimalArray}, {MakeTuple([]*T{MakeString(2), Time, MakeDecimal(5, 1)}), MakeTuple([]*T{String, Time, Decimal})}, + {MakeGeography(geopb.ShapeType_Point, 3857), Geography}, + {MakeGeometry(geopb.ShapeType_PointZ, 4326), Geometry}, // Types without modifiers. {Bool, Bool}, @@ -1026,8 +1028,10 @@ func TestWithoutTypeModifiers(t *testing.T) { } for _, tc := range testCases { - if actual := tc.t.WithoutTypeModifiers(); !actual.Identical(tc.expected) { - t.Errorf("expected <%v>, got <%v>", tc.expected.DebugString(), actual.DebugString()) - } + t.Run(tc.t.SQLString(), func(t *testing.T) { + if actual := tc.t.WithoutTypeModifiers(); !actual.Identical(tc.expected) { + t.Errorf("expected <%v>, got <%v>", tc.expected.DebugString(), actual.DebugString()) + } + }) } }