Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Option for not failing exhaustiveness check if all implementations of the sumtype are covered by interfaces #15

Merged
merged 3 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ passing checks, set the `-default-signifies-exhasutive=false` flag.
As a special case, if the type switch statement contains a `default` clause
that always panics, then exhaustiveness checks are still performed.

By default, `go-check-sumtype` will not include shared interfaces in the exhaustiviness check.
This can be changed by setting the `-include-shared-interfaces=true` flag.
When this flag is set, `go-check-sumtype` will not require that all concrete structs
are listed in the switch statement, as long as the switch statement is exhaustive
with respect to interfaces the structs implement.

## Details and motivation

Sum types are otherwise known as discriminated unions. That is, a sum type is
Expand Down
6 changes: 5 additions & 1 deletion check.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ func missingVariantsInSwitch(
) (*sumTypeDef, []types.Object) {
asserted := findTypeAssertExpr(swtch)
ty := pkg.TypesInfo.TypeOf(asserted)
if ty == nil {
panic(fmt.Sprintf("no type found for asserted expression: %v", asserted))
}

def := findDef(defs, ty)
if def == nil {
// We couldn't find a corresponding sum type, so there's
Expand All @@ -107,7 +111,7 @@ func missingVariantsInSwitch(
for _, expr := range variantExprs {
variantTypes = append(variantTypes, pkg.TypesInfo.TypeOf(expr))
}
return def, def.missing(variantTypes)
return def, def.missing(variantTypes, config.IncludeSharedInterfaces)
}

// switchVariants returns all case expressions found in a type switch. This
Expand Down
78 changes: 78 additions & 0 deletions check_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,84 @@ func main() {}
assert.Equal(t, "T", errs[0].(notInterfaceError).Decl.TypeName)
}

// TestSubTypeInSwitch tests that if a shared interface is declared in the switch
// statement, we don't report an error if structs that implement the interface are not explicitly
// declared in the switch statement.
func TestSubTypeInSwitch(t *testing.T) {
code := `
package gochecksumtype

//sumtype:decl
type T1 interface { sealed1() }
type T2 interface {
T1
sealed2()
}


type A struct {}
func (a *A) sealed1() {}

type B struct {}
func (b *B) sealed1() {}
func (b *B) sealed2() {}

type C struct {}
func (c *C) sealed1() {}
func (c *C) sealed2() {}

func main() {
switch T1(nil).(type) {
case *A:
case T2:
}
}
`
pkgs := setupPackages(t, code)

errs := Run(pkgs, Config{IncludeSharedInterfaces: true})
assert.Equal(t, 0, len(errs))
}

// TestAllLeavesInSwitch tests that we do not report an error if a switch statement
// covers all leaves of the sum type, even if any SubTypes are not explicitly covered
func TestAllLeavesInSwitch(t *testing.T) {
code := `
package gochecksumtype

//sumtype:decl
type T1 interface { sealed1() }
type T2 interface {
T1
sealed2()
}


type A struct {}
func (a *A) sealed1() {}

type B struct {}
func (b *B) sealed1() {}
func (b *B) sealed2() {}

type C struct {}
func (c *C) sealed1() {}
func (c *C) sealed2() {}

func main() {
switch T1(nil).(type) {
case *A:
case *B:
case *C:
}
}
`
pkgs := setupPackages(t, code)

errs := Run(pkgs, Config{})
assert.Equal(t, 0, len(errs))
}

func missingNames(t *testing.T, err error) []string {
t.Helper()
ierr, ok := err.(inexhaustiveError)
Expand Down
7 changes: 7 additions & 0 deletions cmd/go-check-sumtype/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ func main() {
"Presence of \"default\" case in switch statements satisfies exhaustiveness, if all members are not listed.",
)

includeSharedInterfaces := flag.Bool(
"include-shared-interfaces",
false,
"Include shared interfaces in the exhaustiviness check.",
)

flag.Parse()
if flag.NArg() < 1 {
log.Fatalf("Usage: sumtype <packages>\n")
Expand All @@ -27,6 +33,7 @@ func main() {

config := gochecksumtype.Config{
DefaultSignifiesExhaustive: *defaultSignifiesExhaustive,
IncludeSharedInterfaces: *includeSharedInterfaces,
}

conf := &packages.Config{
Expand Down
3 changes: 3 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,7 @@ package gochecksumtype

type Config struct {
DefaultSignifiesExhaustive bool
// IncludeSharedInterfaces in the exhaustiviness check. If true, we do not need to list all concrete structs, as long
// as the switch statement is exhaustive with respect to interfaces the structs implement.
IncludeSharedInterfaces bool
}
26 changes: 24 additions & 2 deletions def.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func (def *sumTypeDef) String() string {

// missing returns a list of variants in this sum type that are not in the
// given list of types.
func (def *sumTypeDef) missing(tys []types.Type) []types.Object {
func (def *sumTypeDef) missing(tys []types.Type, includeSharedInterfaces bool) []types.Object {
// TODO(ag): This is O(n^2). Fix that. /shrug
var missing []types.Object
for _, v := range def.Variants {
Expand All @@ -155,19 +155,41 @@ func (def *sumTypeDef) missing(tys []types.Type) []types.Object {
ty = indirect(ty)
if types.Identical(varty, ty) {
found = true
break
}
if includeSharedInterfaces && implements(varty, ty) {
found = true
break
}
}
if !found {
if !found && !isInterface(varty) {
// we do not include interfaces extending the sumtype, as the
// all implementations of those interfaces are already covered
// by the sumtype.
missing = append(missing, v)
}
}
return missing
}

func isInterface(ty types.Type) bool {
underlying := indirect(ty).Underlying()
_, ok := underlying.(*types.Interface)
return ok
}

// indirect dereferences through an arbitrary number of pointer types.
func indirect(ty types.Type) types.Type {
if ty, ok := ty.(*types.Pointer); ok {
return indirect(ty.Elem())
}
return ty
}

func implements(varty, interfaceType types.Type) bool {
underlying := interfaceType.Underlying()
if interf, ok := underlying.(*types.Interface); ok {
return types.Implements(varty, interf) || types.Implements(types.NewPointer(varty), interf)
}
return false
}
Loading