Skip to content

Commit

Permalink
Option for not failing exhaustiveness check if all implementations of…
Browse files Browse the repository at this point in the history
… the sumtype are covered by interfaces (#15)
  • Loading branch information
jvmakine authored Dec 12, 2024
1 parent dfe7244 commit 468e7aa
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 3 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ passing checks, set the `-default-signifies-exhasutive=false` flag.
As a special case, if the type switch statement contains a `default` clause
that always panics, then exhaustiveness checks are still performed.

By default, `go-check-sumtype` will not include shared interfaces in the exhaustiviness check.
This can be changed by setting the `-include-shared-interfaces=true` flag.
When this flag is set, `go-check-sumtype` will not require that all concrete structs
are listed in the switch statement, as long as the switch statement is exhaustive
with respect to interfaces the structs implement.

## Details and motivation

Sum types are otherwise known as discriminated unions. That is, a sum type is
Expand Down
6 changes: 5 additions & 1 deletion check.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ func missingVariantsInSwitch(
) (*sumTypeDef, []types.Object) {
asserted := findTypeAssertExpr(swtch)
ty := pkg.TypesInfo.TypeOf(asserted)
if ty == nil {
panic(fmt.Sprintf("no type found for asserted expression: %v", asserted))
}

def := findDef(defs, ty)
if def == nil {
// We couldn't find a corresponding sum type, so there's
Expand All @@ -107,7 +111,7 @@ func missingVariantsInSwitch(
for _, expr := range variantExprs {
variantTypes = append(variantTypes, pkg.TypesInfo.TypeOf(expr))
}
return def, def.missing(variantTypes)
return def, def.missing(variantTypes, config.IncludeSharedInterfaces)
}

// switchVariants returns all case expressions found in a type switch. This
Expand Down
78 changes: 78 additions & 0 deletions check_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,84 @@ func main() {}
assert.Equal(t, "T", errs[0].(notInterfaceError).Decl.TypeName)
}

// TestSubTypeInSwitch tests that if a shared interface is declared in the switch
// statement, we don't report an error if structs that implement the interface are not explicitly
// declared in the switch statement.
func TestSubTypeInSwitch(t *testing.T) {
code := `
package gochecksumtype
//sumtype:decl
type T1 interface { sealed1() }
type T2 interface {
T1
sealed2()
}
type A struct {}
func (a *A) sealed1() {}
type B struct {}
func (b *B) sealed1() {}
func (b *B) sealed2() {}
type C struct {}
func (c *C) sealed1() {}
func (c *C) sealed2() {}
func main() {
switch T1(nil).(type) {
case *A:
case T2:
}
}
`
pkgs := setupPackages(t, code)

errs := Run(pkgs, Config{IncludeSharedInterfaces: true})
assert.Equal(t, 0, len(errs))
}

// TestAllLeavesInSwitch tests that we do not report an error if a switch statement
// covers all leaves of the sum type, even if any SubTypes are not explicitly covered
func TestAllLeavesInSwitch(t *testing.T) {
code := `
package gochecksumtype
//sumtype:decl
type T1 interface { sealed1() }
type T2 interface {
T1
sealed2()
}
type A struct {}
func (a *A) sealed1() {}
type B struct {}
func (b *B) sealed1() {}
func (b *B) sealed2() {}
type C struct {}
func (c *C) sealed1() {}
func (c *C) sealed2() {}
func main() {
switch T1(nil).(type) {
case *A:
case *B:
case *C:
}
}
`
pkgs := setupPackages(t, code)

errs := Run(pkgs, Config{})
assert.Equal(t, 0, len(errs))
}

func missingNames(t *testing.T, err error) []string {
t.Helper()
ierr, ok := err.(inexhaustiveError)
Expand Down
7 changes: 7 additions & 0 deletions cmd/go-check-sumtype/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ func main() {
"Presence of \"default\" case in switch statements satisfies exhaustiveness, if all members are not listed.",
)

includeSharedInterfaces := flag.Bool(
"include-shared-interfaces",
false,
"Include shared interfaces in the exhaustiviness check.",
)

flag.Parse()
if flag.NArg() < 1 {
log.Fatalf("Usage: sumtype <packages>\n")
Expand All @@ -27,6 +33,7 @@ func main() {

config := gochecksumtype.Config{
DefaultSignifiesExhaustive: *defaultSignifiesExhaustive,
IncludeSharedInterfaces: *includeSharedInterfaces,
}

conf := &packages.Config{
Expand Down
3 changes: 3 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,7 @@ package gochecksumtype

type Config struct {
DefaultSignifiesExhaustive bool
// IncludeSharedInterfaces in the exhaustiviness check. If true, we do not need to list all concrete structs, as long
// as the switch statement is exhaustive with respect to interfaces the structs implement.
IncludeSharedInterfaces bool
}
26 changes: 24 additions & 2 deletions def.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func (def *sumTypeDef) String() string {

// missing returns a list of variants in this sum type that are not in the
// given list of types.
func (def *sumTypeDef) missing(tys []types.Type) []types.Object {
func (def *sumTypeDef) missing(tys []types.Type, includeSharedInterfaces bool) []types.Object {
// TODO(ag): This is O(n^2). Fix that. /shrug
var missing []types.Object
for _, v := range def.Variants {
Expand All @@ -155,19 +155,41 @@ func (def *sumTypeDef) missing(tys []types.Type) []types.Object {
ty = indirect(ty)
if types.Identical(varty, ty) {
found = true
break
}
if includeSharedInterfaces && implements(varty, ty) {
found = true
break
}
}
if !found {
if !found && !isInterface(varty) {
// we do not include interfaces extending the sumtype, as the
// all implementations of those interfaces are already covered
// by the sumtype.
missing = append(missing, v)
}
}
return missing
}

func isInterface(ty types.Type) bool {
underlying := indirect(ty).Underlying()
_, ok := underlying.(*types.Interface)
return ok
}

// indirect dereferences through an arbitrary number of pointer types.
func indirect(ty types.Type) types.Type {
if ty, ok := ty.(*types.Pointer); ok {
return indirect(ty.Elem())
}
return ty
}

func implements(varty, interfaceType types.Type) bool {
underlying := interfaceType.Underlying()
if interf, ok := underlying.(*types.Interface); ok {
return types.Implements(varty, interf) || types.Implements(types.NewPointer(varty), interf)
}
return false
}

0 comments on commit 468e7aa

Please sign in to comment.