diff --git a/gopls/internal/golang/extract.go b/gopls/internal/golang/extract.go index 0e89f74dd11..328b5b13ab3 100644 --- a/gopls/internal/golang/extract.go +++ b/gopls/internal/golang/extract.go @@ -630,6 +630,8 @@ func extractFunctionMethod(cpkg *cache.Package, pgf *parsego.File, start, end to allReturnsFinalErr = true // all ReturnStmts have final 'err' expression hasReturn = false // selection contains a ReturnStmt filter = []ast.Node{(*ast.ReturnStmt)(nil), (*ast.FuncLit)(nil)} + + origRetStmts []*ast.ReturnStmt // return stmts in source order, for type lookups ) curEnclosing.Inspect(filter, func(cur inspector.Cursor) (descend bool) { if funcLit, ok := cur.Node().(*ast.FuncLit); ok { @@ -643,6 +645,8 @@ func extractFunctionMethod(cpkg *cache.Package, pgf *parsego.File, start, end to } hasReturn = true + origRetStmts = append(origRetStmts, ret) + if cur.Parent() == curStart.Parent() { hasNonNestedReturn = true } @@ -1091,7 +1095,7 @@ func extractFunctionMethod(cpkg *cache.Package, pgf *parsego.File, start, end to declarations = initializeVars(uninitialized, retVars, seenUninitialized, seenVars) } - var declBuf, replaceBuf, newFuncBuf, ifBuf, commentBuf bytes.Buffer + var declBuf, replaceBuf, newFuncBuf, ifBuf bytes.Buffer if err := format.Node(&declBuf, fset, declarations); err != nil { return nil, nil, err } @@ -1112,6 +1116,15 @@ func extractFunctionMethod(cpkg *cache.Package, pgf *parsego.File, start, end to } } + newFuncResults := &ast.FieldList{List: append(returnTypes, getDecls(retVars)...)} + + // Expand multi-value function calls in return statements. + // If a return contains a single CallExpr that is being augmented with new + // return values, the call return values must be expanded to maintain valid syntax. + if err := expandMultiValueCallReturns(extractedBlock, info, newFuncResults, file, start, origRetStmts); err != nil { + return nil, nil, err + } + // Build the extracted function. We format the function declaration and body // separately, so that comments are printed relative to the extracted // BlockStmt. @@ -1125,7 +1138,7 @@ func extractFunctionMethod(cpkg *cache.Package, pgf *parsego.File, start, end to Name: ast.NewIdent(funName), Type: &ast.FuncType{ Params: &ast.FieldList{List: paramTypes}, - Results: &ast.FieldList{List: append(returnTypes, getDecls(retVars)...)}, + Results: newFuncResults, }, // Body handled separately -- see above. } @@ -1172,10 +1185,6 @@ func extractFunctionMethod(cpkg *cache.Package, pgf *parsego.File, start, end to var fullReplacement strings.Builder fullReplacement.Write(before) - if commentBuf.Len() > 0 { - comments := strings.ReplaceAll(commentBuf.String(), "\n", newLineIndent) - fullReplacement.WriteString(comments) - } if declBuf.Len() > 0 { // add any initializations, if needed initializations := strings.ReplaceAll(declBuf.String(), "\n", newLineIndent) + newLineIndent @@ -1216,6 +1225,126 @@ func extractFunctionMethod(cpkg *cache.Package, pgf *parsego.File, start, end to }, nil } +// expandMultiValueCallReturns expands multi-value function calls in return +// statements within the extracted block. +func expandMultiValueCallReturns(extractedBlock *ast.BlockStmt, info *types.Info, newFuncResults *ast.FieldList, file *ast.File, start token.Pos, origRetStmts []*ast.ReturnStmt) error { + // The re-parsed AST has no type information, so we pair its return stmts + // with the original (type-checked) ones to look up types for naming. + // + // The pairing is done as a separate pass because the second pass doesn't + // exactly visit the ReturnStmt in the same way as how the origRetStmts + // is collected (via ast.Inspect). + origRetMap := map[*ast.ReturnStmt]*ast.ReturnStmt{} + origIdx := 0 + ast.Inspect(extractedBlock, func(n ast.Node) bool { + switch n := n.(type) { + case *ast.ReturnStmt: + if origIdx < len(origRetStmts) { + origRetMap[n] = origRetStmts[origIdx] + origIdx++ + } else { + // The re-parsed AST may have injected returns appended + // at the end with no original counterpart but it is ok since + // we can guarantee it will not have CallExpr in it. + return false + } + case *ast.FuncLit: + return false // don't descend into closures. + } + + return true + }) + + // Traverse the extracted block again and do the actual expansion. + var expandErr error + ast.Inspect(extractedBlock, func(n ast.Node) bool { + if expandErr != nil { + return false + } + switch n := n.(type) { + case *ast.BlockStmt: + n.List, expandErr = expandFunctionCallReturnValues(n.List, info, newFuncResults, file, start, origRetMap) + case *ast.CaseClause: + n.Body, expandErr = expandFunctionCallReturnValues(n.Body, info, newFuncResults, file, start, origRetMap) + case *ast.FuncLit: + return false // don't descend into closures. + } + + return true + }) + return expandErr +} + +// expandFunctionCallReturnValues expands the return value of function calls +// in the given statement list when necessary. +func expandFunctionCallReturnValues(stmts []ast.Stmt, info *types.Info, newFuncResults *ast.FieldList, file *ast.File, start token.Pos, origRetMap map[*ast.ReturnStmt]*ast.ReturnStmt) ([]ast.Stmt, error) { + result := make([]ast.Stmt, 0, len(stmts)) + for _, stmt := range stmts { + result = append(result, stmt) + + // When we have multiple return statement results, we can't have a CallExpr in it. + // In that case, we need to splat the values of that CallExpr into variable(s) + // and return them. + retStmt, ok := stmt.(*ast.ReturnStmt) + if !ok || len(retStmt.Results) <= 1 { + continue + } + + // We can only have CallExpr in the first return statement result with the assumption that + // the original code is valid. + callExpr, ok := retStmt.Results[0].(*ast.CallExpr) + if !ok { + continue + } + + // Infer the number of function's return value using the enclosing function + // signature and the original return statement results because we don't have + // type information here. This should be correct assuming the original code + // is valid to begin with. + expandedVars := make([]ast.Expr, len(newFuncResults.List)-len(retStmt.Results)+1) // plus one to replace the CallExpr + + // Use type information from the original return statement to + // generate type-aware names and detect scope collisions. + origRet := origRetMap[retStmt] + if origRet == nil { + return nil, bug.Errorf("no original return statement for re-parsed return") + } + + scopePos := origRet.Pos() + origCallExpr := origRet.Results[0].(*ast.CallExpr) + sig := info.TypeOf(origCallExpr.Fun).Underlying().(*types.Signature) + tup := sig.Results() + + // Generate type-aware names for each expanded return values. + prevIdxByPrefix := map[string]int{} + for i := range expandedVars { + prefix := "v" + if name, ok := varNameForType(tup.At(i).Type()); ok { + prefix = name + } + + prev := prevIdxByPrefix[prefix] + name, next := freshName(info, file, scopePos, prefix, prev) + prevIdxByPrefix[prefix] = next + + expandedVars[i] = ast.NewIdent(name) + } + + result[len(result)-1] = ast.Stmt(&ast.AssignStmt{ + Lhs: expandedVars, + Tok: token.DEFINE, + Rhs: []ast.Expr{callExpr}, + TokPos: 0, + }) + result = append(result, &ast.ReturnStmt{ + Return: retStmt.Return, + Results: slices.Concat(expandedVars, retStmt.Results[1:]), + }) + } + + return result, nil +} + // isSelector reports if e is the selector expr , . It works for pointer and non-pointer selector expressions. func isSelector(e ast.Expr, x, sel string) bool { unary, ok := e.(*ast.UnaryExpr) diff --git a/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue77240.txt b/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue77240.txt new file mode 100644 index 00000000000..123e4ed4248 --- /dev/null +++ b/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue77240.txt @@ -0,0 +1,119 @@ +This test verifies the fix for golang/go#77240: type-aware variable naming +when expanding multi-value function call returns during extraction. + +-- go.mod -- +module mod.test/extract + +go 1.18 + +-- p1/p.go -- +package extract + +func Fun(v2 int) (int, int, error) { + switch v2 { //@codeaction("switch", "refactor.extract.function", end=end, result=ext) + case 1: // also a comment! + return doOne() // a comment! + case 2: + return doTwo() + } //@loc(end, "}") + + return 1, 3, nil +} + +func doOne() (int, int, error) { + return 0, 1, nil +} + +func doTwo() (int, int, error) { + return 0, 2, nil +} + +-- @ext/p1/p.go -- +package extract + +func Fun(v2 int) (int, int, error) { + i, i1, err, shouldReturn := newFunction(v2) + if shouldReturn { + return i, i1, err + } //@loc(end, "}") + + return 1, 3, nil +} + +func newFunction(v2 int) (int, int, error, bool) { + switch v2 { //@codeaction("switch", "refactor.extract.function", end=end, result=ext) + case 1: + i, // also a comment! + i1, err := doOne() + return i, i1, err, true // a comment! + case 2: + i, i1, err := doTwo() + return i, i1, err, true + } + return 0, 0, nil, false +} + +func doOne() (int, int, error) { + return 0, 1, nil +} + +func doTwo() (int, int, error) { + return 0, 2, nil +} + +-- p2/p.go -- +package extract + +import "fmt" + +func Fun(v2 int) (int, int, error) { + switch v2 { //@codeaction("switch", "refactor.extract.function", end=end2, result=ext2) + case 1: + i := v2 + 1 + i1 := v2 + 2 + err := fmt.Errorf("foo") + fmt.Println(i, i1, err) + return doOne() + case 2: + return doTwo() + } //@loc(end2, "}") + + return 1, 3, nil +} + +func doOne() (int, int, error) { return 0, 1, nil } +func doTwo() (int, int, error) { return 0, 2, nil } + +-- @ext2/p2/p.go -- +package extract + +import "fmt" + +func Fun(v2 int) (int, int, error) { + i, i1, err, shouldReturn := newFunction(v2) + if shouldReturn { + return i, i1, err + } //@loc(end2, "}") + + return 1, 3, nil +} + +func newFunction(v2 int) (int, int, error, bool) { + switch v2 { //@codeaction("switch", "refactor.extract.function", end=end2, result=ext2) + case 1: + i := v2 + 1 + i1 := v2 + 2 + err := fmt.Errorf("foo") + fmt.Println(i, i1, err) + i2, i3, err1 := doOne() + return i2, i3, err1, true + case 2: + i, i1, err := doTwo() + return i, i1, err, true + } + return 0, 0, nil, false +} + +func doOne() (int, int, error) { return 0, 1, nil } +func doTwo() (int, int, error) { return 0, 2, nil } +