diff --git a/README.md b/README.md index 2ccec4e..287aa68 100644 --- a/README.md +++ b/README.md @@ -92,6 +92,12 @@ passing checks, set the `-default-signifies-exhasutive=false` flag. As a special case, if the type switch statement contains a `default` clause that always panics, then exhaustiveness checks are still performed. +By default, `go-check-sumtype` will not include shared interfaces in the exhaustiviness check. +This can be changed by setting the `-include-shared-interfaces=true` flag. +When this flag is set, `go-check-sumtype` will not require that all concrete structs +are listed in the switch statement, as long as the switch statement is exhaustive +with respect to interfaces the structs implement. + ## Details and motivation Sum types are otherwise known as discriminated unions. That is, a sum type is diff --git a/check.go b/check.go index e4976ff..ff7fec7 100644 --- a/check.go +++ b/check.go @@ -92,6 +92,10 @@ func missingVariantsInSwitch( ) (*sumTypeDef, []types.Object) { asserted := findTypeAssertExpr(swtch) ty := pkg.TypesInfo.TypeOf(asserted) + if ty == nil { + panic(fmt.Sprintf("no type found for asserted expression: %v", asserted)) + } + def := findDef(defs, ty) if def == nil { // We couldn't find a corresponding sum type, so there's @@ -107,7 +111,7 @@ func missingVariantsInSwitch( for _, expr := range variantExprs { variantTypes = append(variantTypes, pkg.TypesInfo.TypeOf(expr)) } - return def, def.missing(variantTypes) + return def, def.missing(variantTypes, config.IncludeSharedInterfaces) } // switchVariants returns all case expressions found in a type switch. This diff --git a/check_test.go b/check_test.go index 70aece8..5710a21 100644 --- a/check_test.go +++ b/check_test.go @@ -217,6 +217,84 @@ func main() {} assert.Equal(t, "T", errs[0].(notInterfaceError).Decl.TypeName) } +// TestSubTypeInSwitch tests that if a shared interface is declared in the switch +// statement, we don't report an error if structs that implement the interface are not explicitly +// declared in the switch statement. +func TestSubTypeInSwitch(t *testing.T) { + code := ` +package gochecksumtype + +//sumtype:decl +type T1 interface { sealed1() } +type T2 interface { + T1 + sealed2() +} + + +type A struct {} +func (a *A) sealed1() {} + +type B struct {} +func (b *B) sealed1() {} +func (b *B) sealed2() {} + +type C struct {} +func (c *C) sealed1() {} +func (c *C) sealed2() {} + +func main() { + switch T1(nil).(type) { + case *A: + case T2: + } +} +` + pkgs := setupPackages(t, code) + + errs := Run(pkgs, Config{IncludeSharedInterfaces: true}) + assert.Equal(t, 0, len(errs)) +} + +// TestAllLeavesInSwitch tests that we do not report an error if a switch statement +// covers all leaves of the sum type, even if any SubTypes are not explicitly covered +func TestAllLeavesInSwitch(t *testing.T) { + code := ` +package gochecksumtype + +//sumtype:decl +type T1 interface { sealed1() } +type T2 interface { + T1 + sealed2() +} + + +type A struct {} +func (a *A) sealed1() {} + +type B struct {} +func (b *B) sealed1() {} +func (b *B) sealed2() {} + +type C struct {} +func (c *C) sealed1() {} +func (c *C) sealed2() {} + +func main() { + switch T1(nil).(type) { + case *A: + case *B: + case *C: + } +} +` + pkgs := setupPackages(t, code) + + errs := Run(pkgs, Config{}) + assert.Equal(t, 0, len(errs)) +} + func missingNames(t *testing.T, err error) []string { t.Helper() ierr, ok := err.(inexhaustiveError) diff --git a/cmd/go-check-sumtype/main.go b/cmd/go-check-sumtype/main.go index 4d4bec1..f482b32 100644 --- a/cmd/go-check-sumtype/main.go +++ b/cmd/go-check-sumtype/main.go @@ -19,6 +19,12 @@ func main() { "Presence of \"default\" case in switch statements satisfies exhaustiveness, if all members are not listed.", ) + includeSharedInterfaces := flag.Bool( + "include-shared-interfaces", + false, + "Include shared interfaces in the exhaustiviness check.", + ) + flag.Parse() if flag.NArg() < 1 { log.Fatalf("Usage: sumtype \n") @@ -27,6 +33,7 @@ func main() { config := gochecksumtype.Config{ DefaultSignifiesExhaustive: *defaultSignifiesExhaustive, + IncludeSharedInterfaces: *includeSharedInterfaces, } conf := &packages.Config{ diff --git a/config.go b/config.go index 759176e..5c722b7 100644 --- a/config.go +++ b/config.go @@ -2,4 +2,7 @@ package gochecksumtype type Config struct { DefaultSignifiesExhaustive bool + // IncludeSharedInterfaces in the exhaustiviness check. If true, we do not need to list all concrete structs, as long + // as the switch statement is exhaustive with respect to interfaces the structs implement. + IncludeSharedInterfaces bool } diff --git a/def.go b/def.go index df1aa4a..71bdf2f 100644 --- a/def.go +++ b/def.go @@ -145,7 +145,7 @@ func (def *sumTypeDef) String() string { // missing returns a list of variants in this sum type that are not in the // given list of types. -func (def *sumTypeDef) missing(tys []types.Type) []types.Object { +func (def *sumTypeDef) missing(tys []types.Type, includeSharedInterfaces bool) []types.Object { // TODO(ag): This is O(n^2). Fix that. /shrug var missing []types.Object for _, v := range def.Variants { @@ -155,15 +155,29 @@ func (def *sumTypeDef) missing(tys []types.Type) []types.Object { ty = indirect(ty) if types.Identical(varty, ty) { found = true + break + } + if includeSharedInterfaces && implements(varty, ty) { + found = true + break } } - if !found { + if !found && !isInterface(varty) { + // we do not include interfaces extending the sumtype, as the + // all implementations of those interfaces are already covered + // by the sumtype. missing = append(missing, v) } } return missing } +func isInterface(ty types.Type) bool { + underlying := indirect(ty).Underlying() + _, ok := underlying.(*types.Interface) + return ok +} + // indirect dereferences through an arbitrary number of pointer types. func indirect(ty types.Type) types.Type { if ty, ok := ty.(*types.Pointer); ok { @@ -171,3 +185,11 @@ func indirect(ty types.Type) types.Type { } return ty } + +func implements(varty, interfaceType types.Type) bool { + underlying := interfaceType.Underlying() + if interf, ok := underlying.(*types.Interface); ok { + return types.Implements(varty, interf) || types.Implements(types.NewPointer(varty), interf) + } + return false +}