diff --git a/Sources/SwiftLanguageService/CodeActions/MoveMembersToExtension.swift b/Sources/SwiftLanguageService/CodeActions/MoveMembersToExtension.swift new file mode 100644 index 000000000..1abc49334 --- /dev/null +++ b/Sources/SwiftLanguageService/CodeActions/MoveMembersToExtension.swift @@ -0,0 +1,169 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2026 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// + +@_spi(SourceKitLSP) import LanguageServerProtocol +import SwiftRefactor +import SwiftSyntax + +private enum ValidationResult: CustomStringConvertible { + case accessor + case deinitializer + case enumCase + case storedProperty + + var description: String { + switch self { + case .accessor: return "accessor" + case .deinitializer: return "deinitializer" + case .enumCase: return "enum case" + case .storedProperty: return "stored property" + } + } + + /// Validates that `member` can be moved to an extension. If it can, return `nil`, otherwise return the reason why + /// `member` cannot be moved to an extension. + init?(_ member: MemberBlockItemSyntax) { + switch member.decl.kind { + case .accessorDecl: + self = .accessor + case .deinitializerDecl: + self = .deinitializer + case .enumCaseDecl: + self = .enumCase + default: + if let varDecl = member.decl.as(VariableDeclSyntax.self), + varDecl.bindings.contains(where: { $0.accessorBlock == nil || $0.initializer != nil }) + { + self = .storedProperty + } else { + return nil + } + } + } +} + +struct MoveMembersToExtension: SyntaxRefactoringProvider { + struct Context { + let range: Range + + init(range: Range) { + self.range = range + } + } + + static func refactor(syntax: SourceFileSyntax, in context: Context) throws -> SourceFileSyntax { + guard + let statement = syntax.statements.first(where: { $0.item.range.contains(context.range) }), + let decl = statement.item.asProtocol((any NamedDeclSyntax).self), + let declGroup = statement.item.asProtocol((any DeclGroupSyntax).self), + let statementIndex = syntax.statements.index(of: statement) + else { + throw RefactoringNotApplicableError("Type declaration not found") + } + + let selectedMembers = Array(declGroup.memberBlock.members).filter { context.range.overlaps($0.trimmedRange) } + .map { (member: $0, validationResult: ValidationResult($0)) } + + var membersToMove = selectedMembers.filter({ $0.validationResult == nil }).map(\.member) + + guard !membersToMove.isEmpty else { + let notMovedMembers = Set(selectedMembers.compactMap(\.validationResult)) + .map(\.description) + .sorted().joined(separator: ", ") + throw RefactoringNotApplicableError( + "Cannot move \(notMovedMembers) to extension" + ) + } + + var updatedDeclGroup = declGroup + var remainingMembers = Array(declGroup.memberBlock.members).filter { !membersToMove.contains($0) } + membersToMove[0].decl.leadingTrivia = membersToMove[0].decl.leadingTrivia.trimmingPrefix(while: \.isSpaceOrTab) + + if remainingMembers.isEmpty { + updatedDeclGroup.memberBlock.rightBrace.leadingTrivia = Trivia() + } else { + remainingMembers[0].leadingTrivia = .newline.merging( + remainingMembers[0].leadingTrivia.trimmingPrefix(while: \.isNewline) + ) + remainingMembers[remainingMembers.count - 1].trailingTrivia = remainingMembers[remainingMembers.count - 1] + .trailingTrivia.trimmingSuffix(while: \.isNewline) + } + + updatedDeclGroup.memberBlock.members = MemberBlockItemListSyntax(remainingMembers) + membersToMove[0].leadingTrivia = .newline.merging(membersToMove[0].leadingTrivia.trimmingPrefix(while: \.isNewline)) + let extensionMemberBlockSyntax = declGroup.memberBlock.with(\.members, MemberBlockItemListSyntax(membersToMove)) + + var declName = decl.name + declName.trailingTrivia = declName.trailingTrivia.merging(.space) + + let extensionDecl = ExtensionDeclSyntax( + leadingTrivia: .newlines(2), + extendedType: IdentifierTypeSyntax( + leadingTrivia: .space, + name: declName + ), + memberBlock: extensionMemberBlockSyntax + ) + + var syntax = syntax + let updatedStatement = statement.with(\.item, .decl(DeclSyntax(updatedDeclGroup))) + syntax.statements[statementIndex] = updatedStatement + syntax.statements.insert( + CodeBlockItemSyntax(item: .decl(DeclSyntax(extensionDecl))), + at: syntax.statements.index(after: statementIndex) + ) + return syntax + } +} + +extension MoveMembersToExtension: SyntaxRefactoringCodeActionProvider { + static var title: String { "Move to extension" } + + static func refactoringContext(for scope: SyntaxCodeActionScope) -> Context { + Context(range: scope.range) + } + + static func nodeToRefactor(in scope: SyntaxCodeActionScope) -> SourceFileSyntax? { + guard scope.request.range.lowerBound != scope.request.range.upperBound else { + return nil + } + + return scope.file + } + + static func textRefactor(syntax: SourceFileSyntax, in context: Context) throws -> [SourceEdit] { + let updatedSyntax = try self.refactor(syntax: syntax, in: context) + + return [ + .replace(syntax, with: updatedSyntax.description) + ] + } +} + +fileprivate extension Trivia { + func trimmingPrefix( + while predicate: (TriviaPiece) -> Bool + ) -> Trivia { + Trivia(pieces: self.drop(while: predicate)) + } + + func trimmingSuffix( + while predicate: (TriviaPiece) -> Bool + ) -> Trivia { + Trivia( + pieces: self[...] + .reversed() + .drop(while: predicate) + .reversed() + ) + } +} diff --git a/Sources/SwiftLanguageService/CodeActions/SyntaxCodeActions.swift b/Sources/SwiftLanguageService/CodeActions/SyntaxCodeActions.swift index a481a5950..3a5a856ef 100644 --- a/Sources/SwiftLanguageService/CodeActions/SyntaxCodeActions.swift +++ b/Sources/SwiftLanguageService/CodeActions/SyntaxCodeActions.swift @@ -28,6 +28,7 @@ let allSyntaxCodeActions: [any SyntaxCodeActionProvider.Type] = { ConvertZeroParameterFunctionToComputedProperty.self, FormatRawStringLiteral.self, MigrateToNewIfLetSyntax.self, + MoveMembersToExtension.self, OpaqueParameterToGeneric.self, RemoveRedundantParentheses.self, RemoveSeparatorsFromIntegerLiteral.self, @@ -41,5 +42,6 @@ let allSyntaxCodeActions: [any SyntaxCodeActionProvider.Type] = { let supersededSourcekitdRefactoringActions: Set = [ "source.refactoring.kind.convert.to.computed.property", // Superseded by ConvertStoredPropertyToComputed + "source.refactoring.kind.move.members.to.extension", // Superseded by MoveMembersToExtension "source.refactoring.kind.simplify.long.number.literal", // Superseded by AddSeparatorsToIntegerLiteral ] diff --git a/Tests/SourceKitLSPTests/MoveMembersToExtensionTests.swift b/Tests/SourceKitLSPTests/MoveMembersToExtensionTests.swift new file mode 100644 index 000000000..07141d972 --- /dev/null +++ b/Tests/SourceKitLSPTests/MoveMembersToExtensionTests.swift @@ -0,0 +1,340 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2026 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// + +@_spi(SourceKitLSP) import LanguageServerProtocol +import SKLogging +import SKTestSupport +import SourceKitLSP +import SwiftExtensions +@_spi(Testing) import SwiftLanguageService +import SwiftParser +import SwiftRefactor +import SwiftSyntax +import SwiftSyntaxBuilder +import XCTest + +private typealias CodeActionCapabilities = TextDocumentClientCapabilities.CodeAction +private typealias CodeActionLiteralSupport = CodeActionCapabilities.CodeActionLiteralSupport +private typealias CodeActionKindCapabilities = CodeActionLiteralSupport.CodeActionKindValueSet + +private let clientCapabilitiesWithCodeActionSupport: ClientCapabilities = { + var documentCapabilities = TextDocumentClientCapabilities() + var codeActionCapabilities = CodeActionCapabilities() + let codeActionKinds = CodeActionKindCapabilities(valueSet: [.refactor, .quickFix]) + let codeActionLiteralSupport = CodeActionLiteralSupport(codeActionKind: codeActionKinds) + codeActionCapabilities.codeActionLiteralSupport = codeActionLiteralSupport + documentCapabilities.codeAction = codeActionCapabilities + documentCapabilities.completion = .init(completionItem: .init(snippetSupport: true)) + return ClientCapabilities(workspace: nil, textDocument: documentCapabilities) +}() + +final class MoveMembersToExtensionTests: SourceKitLSPTestCase { + func testMoveMembersToExtension() async throws { + try await assertMoveMembersToExtensionCodeAction( + """ + 1️⃣class Foo { + 2️⃣func foo() { + print("Hello world!") + }3️⃣ + + func bar() { + print("Hello world!") + } + }4️⃣ + """, + expected: + """ + class Foo { + func bar() { + print("Hello world!") + } + } + + extension Foo { + func foo() { + print("Hello world!") + } + } + """ + ) + } + + func testMoveParticiallySelectedFunctionFromClass() async throws { + try await assertMoveMembersToExtensionCodeAction( + """ + 1️⃣class Foo { + func foo() { + print("Hello world!") + } + + func bar() { + 2️⃣print("Hello world!") + }3️⃣ + } + + struct Bar { + func foo() {} + }4️⃣ + """, + expected: + """ + class Foo { + func foo() { + print("Hello world!") + } + } + + extension Foo { + func bar() { + print("Hello world!") + } + } + + struct Bar { + func foo() {} + } + """ + ) + } + + func testMoveSelectedFromClass() async throws { + try await assertMoveMembersToExtensionCodeAction( + """ + 1️⃣class Foo {2️⃣ + func foo() { + print("Hello world!") + } + + deinit() {} + + func bar() { + print("Hello world!") + }3️⃣ + } + + struct Bar { + func foo() {} + }4️⃣ + """, + expected: + """ + class Foo { + deinit() {} + } + + extension Foo { + func foo() { + print("Hello world!") + } + + func bar() { + print("Hello world!") + } + } + + struct Bar { + func foo() {} + } + """ + ) + } + + func testMoveNestedFromStruct() async throws { + try await assertMoveMembersToExtensionCodeAction( + """ + 1️⃣struct Outer {2️⃣ + struct Inner { + func moveThis() {} + }3️⃣ + }4️⃣ + """, + expected: + """ + struct Outer {} + + extension Outer { + struct Inner { + func moveThis() {} + } + } + """ + ) + } + + func testMoveNestedFromStruct2() async throws { + try await assertMoveMembersToExtensionCodeAction( + """ + 1️⃣struct Outer {2️⃣ + struct Inner { + func moveThis() {} + }3️⃣ + }4️⃣ + """, + expected: + """ + struct Outer {} + + extension Outer { + struct Inner { + func moveThis() {} + } + } + """ + ) + } + + func testMoveSelectedFunctionName() async throws { + try await assertMoveMembersToExtensionCodeAction( + """ + 1️⃣struct Outer { + struct Inner { + func 2️⃣moveThis()3️⃣ {} + } + }4️⃣ + """, + expected: + """ + struct Outer {} + + extension Outer { + struct Inner { + func moveThis() {} + } + } + """ + ) + } + + func testSelectedDeinitializerMember() async throws { + let source = """ + 1️⃣class Foo { + func foo() { + print("Hello world!") + } + + 2️⃣deinit() {}3️⃣ + + func bar() { + print("Hello world!") + } + } + + struct Bar { + func foo() {} + }4️⃣ + """ + + let testClient = try await TestSourceKitLSPClient(capabilities: clientCapabilitiesWithCodeActionSupport) + let uri = DocumentURI(for: .swift) + + let positions = testClient.openDocument(source, uri: uri) + + let request = CodeActionRequest( + range: positions["2️⃣"]..