diff --git a/internal/lsp/source/command.go b/internal/lsp/source/command.go index 66d2f1d70d3..2bc3c77bab5 100644 --- a/internal/lsp/source/command.go +++ b/internal/lsp/source/command.go @@ -131,7 +131,7 @@ var ( Title: "Extract to function", suggestedFixFn: extractFunction, appliesFn: func(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, _ *types.Package, info *types.Info) bool { - _, _, _, _, _, ok, _ := canExtractFunction(fset, rng, src, file, info) + _, ok, _ := canExtractFunction(fset, rng, src, file, info) return ok }, } diff --git a/internal/lsp/source/extract.go b/internal/lsp/source/extract.go index 84679dce5e2..411b46508f3 100644 --- a/internal/lsp/source/extract.go +++ b/internal/lsp/source/extract.go @@ -180,11 +180,12 @@ type returnVariable struct { // of the function and insert this call as well as the extracted function into // their proper locations. func extractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, pkg *types.Package, info *types.Info) (*analysis.SuggestedFix, error) { - tok, path, rng, outer, start, ok, err := canExtractFunction(fset, rng, src, file, info) + p, ok, err := canExtractFunction(fset, rng, src, file, info) if !ok { return nil, fmt.Errorf("extractFunction: cannot extract %s: %v", fset.Position(rng.Start), err) } + tok, path, rng, outer, start := p.tok, p.path, p.rng, p.outer, p.start fileScope := info.Scopes[file] if fileScope == nil { return nil, fmt.Errorf("extractFunction: file scope is empty") @@ -229,8 +230,10 @@ func extractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast. // we must determine the signature of the extracted function. We will then replace // the block with an assignment statement that calls the extracted function with // the appropriate parameters and return values. - free, vars, assigned, defined := collectFreeVars( - info, file, fileScope, pkgScope, rng, path[0]) + variables, err := collectFreeVars(info, file, fileScope, pkgScope, rng, path[0]) + if err != nil { + return nil, err + } var ( params, returns []ast.Expr // used when calling the extracted function @@ -269,42 +272,38 @@ func extractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast. // variable in the extracted function. Determine the outcome(s) for each variable // based on whether it is free, altered within the selected block, and used outside // of the selected block. - for _, obj := range vars { - if _, ok := seenVars[obj]; ok { + for _, v := range variables { + if _, ok := seenVars[v.obj]; ok { continue } - typ := analysisinternal.TypeExpr(fset, file, pkg, obj.Type()) + typ := analysisinternal.TypeExpr(fset, file, pkg, v.obj.Type()) if typ == nil { - return nil, fmt.Errorf("nil AST expression for type: %v", obj.Name()) + return nil, fmt.Errorf("nil AST expression for type: %v", v.obj.Name()) } - seenVars[obj] = typ - identifier := ast.NewIdent(obj.Name()) + seenVars[v.obj] = typ + identifier := ast.NewIdent(v.obj.Name()) // An identifier must meet three conditions to become a return value of the // extracted function. (1) its value must be defined or reassigned within // the selection (isAssigned), (2) it must be used at least once after the // selection (isUsed), and (3) its first use after the selection // cannot be its own reassignment or redefinition (objOverriden). - if obj.Parent() == nil { + if v.obj.Parent() == nil { return nil, fmt.Errorf("parent nil") } - isUsed, firstUseAfter := - objUsed(info, span.NewRange(fset, rng.End, obj.Parent().End()), obj) - _, isAssigned := assigned[obj] - _, isFree := free[obj] - if isAssigned && isUsed && !varOverridden(info, firstUseAfter, obj, isFree, outer) { + isUsed, firstUseAfter := objUsed(info, span.NewRange(fset, rng.End, v.obj.Parent().End()), v.obj) + if v.assigned && isUsed && !varOverridden(info, firstUseAfter, v.obj, v.free, outer) { returnTypes = append(returnTypes, &ast.Field{Type: typ}) returns = append(returns, identifier) - if !isFree { - uninitialized = append(uninitialized, obj) - } else if obj.Parent().Pos() == startParent.Pos() { + if !v.free { + uninitialized = append(uninitialized, v.obj) + } else if v.obj.Parent().Pos() == startParent.Pos() { canRedefineCount++ } } - _, isDefined := defined[obj] // An identifier must meet two conditions to become a parameter of the // extracted function. (1) it must be free (isFree), and (2) its first // use within the selection cannot be its own definition (isDefined). - if isFree && !isDefined { + if v.free && !v.defined { params = append(params, identifier) paramTypes = append(paramTypes, &ast.Field{ Names: []*ast.Ident{identifier}, @@ -409,8 +408,7 @@ func extractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast. // statements in the selection. Update the type signature of the extracted // function and construct the if statement that will be inserted in the enclosing // function. - retVars, ifReturn, err = generateReturnInfo( - enclosing, pkg, path, file, info, fset, rng.Start) + retVars, ifReturn, err = generateReturnInfo(enclosing, pkg, path, file, info, fset, rng.Start) if err != nil { return nil, err } @@ -500,13 +498,11 @@ func extractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast. fullReplacement.Write(newFuncBuf.Bytes()) // insert the extracted function return &analysis.SuggestedFix{ - TextEdits: []analysis.TextEdit{ - { - Pos: outer.Pos(), - End: outer.End(), - NewText: []byte(fullReplacement.String()), - }, - }, + TextEdits: []analysis.TextEdit{{ + Pos: outer.Pos(), + End: outer.End(), + NewText: []byte(fullReplacement.String()), + }}, }, nil } @@ -561,15 +557,28 @@ func findParent(start ast.Node, target ast.Node) ast.Node { return parent } +// variable describes the status of a variable within a selection. +type variable struct { + obj types.Object + + // free reports whether the variable is a free variable, meaning it should + // be a parameter to the extracted function. + free bool + + // assigned reports whether the variable is assigned to in the selection. + assigned bool + + // defined reports whether the variable is defined in the selection. + defined bool +} + // collectFreeVars maps each identifier in the given range to whether it is "free." // Given a range, a variable in that range is defined as "free" if it is declared // outside of the range and neither at the file scope nor package scope. These free // variables will be used as arguments in the extracted function. It also returns a // list of identifiers that may need to be returned by the extracted function. // Some of the code in this function has been adapted from tools/cmd/guru/freevars.go. -func collectFreeVars(info *types.Info, file *ast.File, fileScope *types.Scope, - pkgScope *types.Scope, rng span.Range, node ast.Node) (map[types.Object]struct{}, - []types.Object, map[types.Object]struct{}, map[types.Object]struct{}) { +func collectFreeVars(info *types.Info, file *ast.File, fileScope, pkgScope *types.Scope, rng span.Range, node ast.Node) ([]*variable, error) { // id returns non-nil if n denotes an object that is referenced by the span // and defined either within the span or in the lexical environment. The bool // return value acts as an indicator for where it was defined. @@ -612,7 +621,7 @@ func collectFreeVars(info *types.Info, file *ast.File, fileScope *types.Scope, } return nil, false } - free := make(map[types.Object]struct{}) + seen := make(map[types.Object]*variable) firstUseIn := make(map[types.Object]token.Pos) var vars []types.Object ast.Inspect(node, func(n ast.Node) bool { @@ -630,15 +639,16 @@ func collectFreeVars(info *types.Info, file *ast.File, fileScope *types.Scope, prune = true } if obj != nil { - if isFree { - free[obj] = struct{}{} + seen[obj] = &variable{ + obj: obj, + free: isFree, } + vars = append(vars, obj) // Find the first time that the object is used in the selection. first, ok := firstUseIn[obj] if !ok || n.Pos() < first { firstUseIn[obj] = n.Pos() } - vars = append(vars, obj) if prune { return false } @@ -657,8 +667,6 @@ func collectFreeVars(info *types.Info, file *ast.File, fileScope *types.Scope, // 3: y := 3 // 4: z := x + a // - assigned := make(map[types.Object]struct{}) - defined := make(map[types.Object]struct{}) ast.Inspect(node, func(n ast.Node) bool { if n == nil { return false @@ -677,7 +685,10 @@ func collectFreeVars(info *types.Info, file *ast.File, fileScope *types.Scope, if obj == nil { continue } - assigned[obj] = struct{}{} + if _, ok := seen[obj]; !ok { + continue + } + seen[obj].assigned = true if n.Tok != token.DEFINE { continue } @@ -697,7 +708,10 @@ func collectFreeVars(info *types.Info, file *ast.File, fileScope *types.Scope, if referencesObj(info, expr, obj) { continue } - defined[obj] = struct{}{} + if _, ok := seen[obj]; !ok { + continue + } + seen[obj].defined = true break } } @@ -717,7 +731,10 @@ func collectFreeVars(info *types.Info, file *ast.File, fileScope *types.Scope, if obj == nil { continue } - assigned[obj] = struct{}{} + if _, ok := seen[obj]; !ok { + continue + } + seen[obj].assigned = true } } return false @@ -727,12 +744,23 @@ func collectFreeVars(info *types.Info, file *ast.File, fileScope *types.Scope, } else if obj, _ := id(ident); obj == nil { return false } else { - assigned[obj] = struct{}{} + if _, ok := seen[obj]; !ok { + return false + } + seen[obj].assigned = true } } return true }) - return free, vars, assigned, defined + var variables []*variable + for _, obj := range vars { + v, ok := seen[obj] + if !ok { + return nil, fmt.Errorf("no seen types.Object for %v", obj) + } + variables = append(variables, v) + } + return variables, nil } // referencesObj checks whether the given object appears in the given expression. @@ -756,29 +784,34 @@ func referencesObj(info *types.Info, expr ast.Expr, obj types.Object) bool { return hasObj } -// canExtractFunction reports whether the code in the given range can be extracted to a function. -func canExtractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, info *types.Info) (*token.File, []ast.Node, span.Range, *ast.FuncDecl, ast.Node, bool, error) { +type fnExtractParams struct { + tok *token.File + path []ast.Node + rng span.Range + outer *ast.FuncDecl + start ast.Node +} + +// canExtractFunction reports whether the code in the given range can be +// extracted to a function. +func canExtractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, info *types.Info) (*fnExtractParams, bool, error) { if rng.Start == rng.End { - return nil, nil, span.Range{}, nil, nil, false, - fmt.Errorf("start and end are equal") + return nil, false, fmt.Errorf("start and end are equal") } tok := fset.File(file.Pos()) if tok == nil { - return nil, nil, span.Range{}, nil, nil, false, - fmt.Errorf("no file for pos %v", fset.Position(file.Pos())) + return nil, false, fmt.Errorf("no file for pos %v", fset.Position(file.Pos())) } rng = adjustRangeForWhitespace(rng, tok, src) path, _ := astutil.PathEnclosingInterval(file, rng.Start, rng.End) if len(path) == 0 { - return nil, nil, span.Range{}, nil, nil, false, - fmt.Errorf("no path enclosing interval") + return nil, false, fmt.Errorf("no path enclosing interval") } // Node that encloses the selection must be a statement. // TODO: Support function extraction for an expression. _, ok := path[0].(ast.Stmt) if !ok { - return nil, nil, span.Range{}, nil, nil, false, - fmt.Errorf("node is not a statement") + return nil, false, fmt.Errorf("node is not a statement") } // Find the function declaration that encloses the selection. @@ -790,7 +823,7 @@ func canExtractFunction(fset *token.FileSet, rng span.Range, src []byte, file *a } } if outer == nil { - return nil, nil, span.Range{}, nil, nil, false, fmt.Errorf("no enclosing function") + return nil, false, fmt.Errorf("no enclosing function") } // Find the nodes at the start and end of the selection. @@ -799,8 +832,8 @@ func canExtractFunction(fset *token.FileSet, rng span.Range, src []byte, file *a if n == nil { return false } - // Do not override 'start' with a node that begins at the same location but is - // nested further from 'outer'. + // Do not override 'start' with a node that begins at the same location + // but is nested further from 'outer'. if start == nil && n.Pos() == rng.Start && n.End() <= rng.End { start = n } @@ -810,10 +843,15 @@ func canExtractFunction(fset *token.FileSet, rng span.Range, src []byte, file *a return n.Pos() <= rng.End }) if start == nil || end == nil { - return nil, nil, span.Range{}, nil, nil, false, - fmt.Errorf("range does not map to AST nodes") + return nil, false, fmt.Errorf("range does not map to AST nodes") } - return tok, path, rng, outer, start, true, nil + return &fnExtractParams{ + tok: tok, + path: path, + rng: rng, + outer: outer, + start: start, + }, true, nil } // objUsed checks if the object is used within the range. It returns the first occurence of diff --git a/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_args_returns.go b/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_args_returns.go index 0c380113dcd..63d24df0041 100644 --- a/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_args_returns.go +++ b/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_args_returns.go @@ -5,6 +5,7 @@ func _() { a = 5 //@mark(exSt0, "a") a = a + 2 //@mark(exEn0, "2") //@extractfunc(exSt0, exEn0) - b := a * 2 - _ = 3 + 4 + b := a * 2 //@mark(exB, " b") + _ = 3 + 4 //@mark(exEnd, "4") + //@extractfunc(exB, exEnd) } diff --git a/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_args_returns.go.golden b/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_args_returns.go.golden index 04caef266bf..d31fcc1c87f 100644 --- a/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_args_returns.go.golden +++ b/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_args_returns.go.golden @@ -5,8 +5,9 @@ func _() { a := 1 a = fn0(a) //@mark(exEn0, "2") //@extractfunc(exSt0, exEn0) - b := a * 2 - _ = 3 + 4 + b := a * 2 //@mark(exB, " b") + _ = 3 + 4 //@mark(exEnd, "4") + //@extractfunc(exB, exEnd) } func fn0(a int) int { @@ -15,3 +16,20 @@ func fn0(a int) int { return a } +-- functionextraction_extract_args_returns_8_1 -- +package extract + +func _() { + a := 1 + a = 5 //@mark(exSt0, "a") + a = a + 2 //@mark(exEn0, "2") + //@extractfunc(exSt0, exEn0) + fn0(a) //@mark(exEnd, "4") + //@extractfunc(exB, exEnd) +} + +func fn0(a int) { + b := a * 2 + _ = 3 + 4 +} + diff --git a/internal/lsp/testdata/lsp/summary.txt.golden b/internal/lsp/testdata/lsp/summary.txt.golden index f625017c64a..e6e82d10e50 100644 --- a/internal/lsp/testdata/lsp/summary.txt.golden +++ b/internal/lsp/testdata/lsp/summary.txt.golden @@ -13,7 +13,7 @@ FoldingRangesCount = 2 FormatCount = 6 ImportCount = 8 SuggestedFixCount = 38 -FunctionExtractionCount = 11 +FunctionExtractionCount = 12 DefinitionsCount = 63 TypeDefinitionsCount = 2 HighlightsCount = 69