diff --git a/joern-cli/frontends/rust2cpg/src/main/scala/io/joern/rust2cpg/astcreation/AstCreator.scala b/joern-cli/frontends/rust2cpg/src/main/scala/io/joern/rust2cpg/astcreation/AstCreator.scala index a65abd85c572..28e4fd88e2b3 100644 --- a/joern-cli/frontends/rust2cpg/src/main/scala/io/joern/rust2cpg/astcreation/AstCreator.scala +++ b/joern-cli/frontends/rust2cpg/src/main/scala/io/joern/rust2cpg/astcreation/AstCreator.scala @@ -146,6 +146,27 @@ class AstCreator(val config: Config, val parseResult: ParseResult)(implicit with ) } + protected def typeDeclForImpl(impl: RustNodeSyntax.Impl): NewTypeDecl = { + val implType = typeFullNameForType(impl.typ.last) + val name = implType.split(RustFullNames.PathSep).lastOption.getOrElse(implType) + val parent = methodAstParentStack.head + typeDeclNode( + node = impl, + name = name, + fullName = implType, + filename = parseResult.filename, + code = code(impl), + astParentType = parent.label, + astParentFullName = parent.properties(PropertyNames.FullName).toString + ) + } + + protected def enclosingTypeDeclFullName: Option[String] = { + methodAstParentStack.collectFirst { case typeDecl: NewTypeDecl => + typeDecl.properties(PropertyNames.FullName).toString + } + } + protected def operatorNameFor(binExpr: RustNodeSyntax.BinExpr): Option[String] = binExpr.op match { case Some(_: RustNodeSyntax.Pipe2Token) => Some(Operators.logicalOr) case Some(_: RustNodeSyntax.Amp2Token) => Some(Operators.logicalAnd) diff --git a/joern-cli/frontends/rust2cpg/src/main/scala/io/joern/rust2cpg/astcreation/RustVisitor.scala b/joern-cli/frontends/rust2cpg/src/main/scala/io/joern/rust2cpg/astcreation/RustVisitor.scala index 1956c6ddbfaf..d5b4a5282b58 100644 --- a/joern-cli/frontends/rust2cpg/src/main/scala/io/joern/rust2cpg/astcreation/RustVisitor.scala +++ b/joern-cli/frontends/rust2cpg/src/main/scala/io/joern/rust2cpg/astcreation/RustVisitor.scala @@ -52,7 +52,7 @@ trait RustVisitor(implicit withValidationMode: ValidationMode) { this: AstCreato case x: ExternBlock => notHandledYet(x) :: Nil case x: ExternCrate => notHandledYet(x) :: Nil case fn: Fn => visitFn(fn) :: Nil - case x: Impl => notHandledYet(x) :: Nil + case impl: Impl => visitImpl(impl) case macroCall: MacroCall => visitMacroCall(macroCall) case macroRules: MacroRules => visitMacroRules(macroRules) case macroDef: MacroDef => visitMacroDef(macroDef) @@ -288,6 +288,24 @@ trait RustVisitor(implicit withValidationMode: ValidationMode) { this: AstCreato returnAst(ret, Seq(exprAst)) } + // Impl = + // Attr* Visibility? + // 'default'? 'unsafe'? + // 'impl' GenericParamList? ('const'? '!'? Type 'for')? Type WhereClause? + // AssocItemList + // TODO: support `impl X for Y` and remove the side-effects (storeInDiffGraph). + private def visitImpl(impl: Impl): Seq[Ast] = { + if (impl.forKwToken.isDefined) { + Nil + } else { + methodAstParentStack.push(typeDeclForImpl(impl)) + val methodAsts = impl.assocItemList.assocItem.collect { case fn: Fn => visitFn(fn) } + methodAstParentStack.pop() + methodAsts.foreach(Ast.storeInDiffGraph(_, diffGraph)) + Nil + } + } + // BlockExpr = // Attr* Label? (TryBlockModifier | 'unsafe' | ('async' 'move'?) | ('gen' 'move'?) | 'const') StmtList private def visitBlockExpr(blockExpr: BlockExpr): Ast = { @@ -364,7 +382,8 @@ trait RustVisitor(implicit withValidationMode: ValidationMode) { this: AstCreato // )')' // | '|' (Param (',' Param)* ','?)? '|' private def visitParamList(paramList: ParamList): Seq[Ast] = { - paramList.param.zipWithIndex.map { case (param, paramIdx) => + val selfParamAst = paramList.selfParam.map(visitSelfParam).toList + val paramAsts = paramList.param.zipWithIndex.map { case (param, paramIdx) => val paramName = param.pat.collect { case x: IdentPat => x } val paramTypeFullName = param.typ.map(typeFullNameForType) @@ -383,6 +402,33 @@ trait RustVisitor(implicit withValidationMode: ValidationMode) { this: AstCreato case _ => notHandledYet(param) } } + + selfParamAst ++ paramAsts + } + + // SelfParam = + // Attr* ( ('&' Lifetime?)? 'mut'? Name | 'mut'? Name ':' Type ) + private def visitSelfParam(selfParam: SelfParam): Ast = { + val enclosingType = enclosingTypeDeclFullName.getOrElse(Defines.Any) + val typeFullName = selfParam.typ match { + case Some(typ) => typeFullNameForType(typ) + case None if selfParam.ampToken.isDefined => + val mut = Option.when(selfParam.mutKwToken.isDefined)("mut ").getOrElse("") + s"&$mut$enclosingType" + case None => enclosingType + } + val evaluationStrategy = + if (selfParam.ampToken.isDefined) EvaluationStrategies.BY_SHARING else EvaluationStrategies.BY_VALUE + val paramNode = parameterInNode( + node = selfParam, + name = code(selfParam.name), + code = code(selfParam), + index = 0, + isVariadic = false, + evaluationStrategy = evaluationStrategy, + typeFullName = typeFullName + ) + Ast(paramNode) } // Param = diff --git a/joern-cli/frontends/rust2cpg/src/test/scala/io/joern/rust2cpg/passes/ast/ImplTests.scala b/joern-cli/frontends/rust2cpg/src/test/scala/io/joern/rust2cpg/passes/ast/ImplTests.scala new file mode 100644 index 000000000000..ab4d733375ae --- /dev/null +++ b/joern-cli/frontends/rust2cpg/src/test/scala/io/joern/rust2cpg/passes/ast/ImplTests.scala @@ -0,0 +1,232 @@ +package io.joern.rust2cpg.passes.ast + +import io.joern.rust2cpg.testfixtures.Rust2CpgSuite +import io.shiftleft.codepropertygraph.generated.{DispatchTypes, EvaluationStrategies} +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* +import io.shiftleft.semanticcpg.utils.FileUtil.PathExt + +import java.nio.file.Paths + +class ImplTests extends Rust2CpgSuite(noSysRoot = true) { + + "an inherent method in an impl block" should { + val cpg = code(""" + |struct Foo; + |impl Foo { + | fn bar(&self) {} + |} + |""".stripMargin) + + "have correct fullName" in { + cpg.method.nameExact("bar").fullName.l shouldBe List("rust2cpgtest::Foo::bar") + } + + "have correct return typeFullName" in { + cpg.method.nameExact("bar").methodReturn.typeFullName.l shouldBe List("()") + } + + "have correct self properties" in { + inside(cpg.method.nameExact("bar").parameter.nameExact("self").l) { case self :: Nil => + self.index shouldBe 0 + self.order shouldBe 0 + self.evaluationStrategy shouldBe EvaluationStrategies.BY_SHARING + self.typeFullName shouldBe "&rust2cpgtest::Foo" + } + } + + "be an AST child of the corresponding TYPE_DECL" in { + cpg.typeDecl.nameExact("Foo").method.nameExact("bar").fullName.l shouldBe List("rust2cpgtest::Foo::bar") + } + + "not create a duplicate TYPE_DECL" in { + cpg.typeDecl.fullNameExact("rust2cpgtest::Foo").size shouldBe 1 + } + } + + "an inherent method with a `&mut self` and an explicit parameter" should { + val cpg = code(""" + |struct Foo; + |impl Foo { + | fn bar(&mut self, x: i32) {} + |} + |""".stripMargin) + + "have correct self properties" in { + inside(cpg.method.nameExact("bar").parameter.nameExact("self").l) { case self :: Nil => + self.index shouldBe 0 + self.order shouldBe 0 + self.evaluationStrategy shouldBe EvaluationStrategies.BY_SHARING + self.typeFullName shouldBe "&mut rust2cpgtest::Foo" + } + } + + "have correct explicit parameter properties" in { + inside(cpg.method.nameExact("bar").parameter.nameExact("x").l) { case param :: Nil => + param.index shouldBe 1 + param.order shouldBe 1 + param.typeFullName shouldBe "i32" + } + } + } + + "an inherent method with a by-value `self` receiver" should { + val cpg = code(""" + |struct Foo; + |impl Foo { + | fn bar(self) {} + |} + |""".stripMargin) + + "have correct self properties" in { + inside(cpg.method.nameExact("bar").parameter.nameExact("self").l) { case self :: Nil => + self.index shouldBe 0 + self.order shouldBe 0 + self.evaluationStrategy shouldBe EvaluationStrategies.BY_VALUE + self.typeFullName shouldBe "rust2cpgtest::Foo" + } + } + } + + "an associated function without a receiver" should { + val cpg = code(""" + |struct Foo; + |impl Foo { + | fn new() -> Foo { Foo } + |} + |""".stripMargin) + + "have no parameters" in { + cpg.method.nameExact("new").parameter shouldBe empty + } + + "have correct return typeFullName" in { + cpg.method.nameExact("new").methodReturn.typeFullName.l shouldBe List("rust2cpgtest::Foo") + } + } + + "multiple methods in a single impl block" should { + val cpg = code(""" + |struct Foo; + |impl Foo { + | fn a(&self) {} + | fn b(&self) {} + |} + |""".stripMargin) + + "lower each as an AST child of the corresponding TYPE_DECL" in { + cpg.typeDecl.nameExact("Foo").method.fullName.sorted.l shouldBe List( + "rust2cpgtest::Foo::a", + "rust2cpgtest::Foo::b" + ) + } + } + + "multiple impl blocks for the same type" should { + val cpg = code(""" + |struct Foo; + |impl Foo { + | fn a(&self) {} + |} + |impl Foo { + | fn b(&self) {} + |} + |""".stripMargin) + + "merge all methods under the same TYPE_DECL" in { + cpg.typeDecl.nameExact("Foo").method.fullName.sorted.l shouldBe List( + "rust2cpgtest::Foo::a", + "rust2cpgtest::Foo::b" + ) + } + + "not create a duplicate TYPE_DECL" in { + cpg.typeDecl.fullNameExact("rust2cpgtest::Foo").size shouldBe 1 + } + } + + "multiple impl blocks for the same type spread across files" should { + val cpg = code( + """ + |struct Foo; + |mod a; + |mod b; + |""".stripMargin, + fileName = (Paths.get("src") / "lib.rs").toString + ).moreCode( + """ + |impl crate::Foo { + | fn a(&self) {} + |} + |""".stripMargin, + fileName = (Paths.get("src") / "a.rs").toString + ).moreCode( + """ + |impl crate::Foo { + | fn b(&self) {} + |}""".stripMargin, + fileName = (Paths.get("src") / "b.rs").toString + ) + + "merge all methods under the same TYPE_DECL" in { + cpg.typeDecl.nameExact("Foo").method.fullName.sorted.l shouldBe List( + "rust2cpgtest::Foo::a", + "rust2cpgtest::Foo::b" + ) + } + + "not create a duplicate TYPE_DECL" in { + cpg.typeDecl.fullNameExact("rust2cpgtest::Foo").size shouldBe 1 + } + } + + "a call to an inherent method" should { + val cpg = code(""" + |struct Foo; + |impl Foo { fn bar(&self) {} } + |fn run(f: Foo) { f.bar(); } + |""".stripMargin) + + "have correct properties and callee" in { + inside(cpg.call.nameExact("bar").l) { case call :: Nil => + implicit val callResolver: NoResolve.type = NoResolve + call.methodFullName shouldBe "rust2cpgtest::Foo::bar" + call.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + call.callee.l shouldBe cpg.method.fullNameExact("rust2cpgtest::Foo::bar").l + } + } + + "have correct arguments" in { + inside(cpg.call.nameExact("bar").argument.l) { case (base: Identifier) :: Nil => + base.name shouldBe "f" + base.argumentIndex shouldBe 0 + base.typeFullName shouldBe "&rust2cpgtest::Foo" + } + } + + "have correct self typeFullName" in { + cpg.method.nameExact("bar").parameter.nameExact("self").typeFullName.l shouldBe List("&rust2cpgtest::Foo") + } + } + + "a call to an associated function" should { + val cpg = code(""" + |struct Foo; + |impl Foo { fn new() -> Foo { Foo } } + |fn run() { Foo::new(); } + |""".stripMargin) + + "have correct properties and callee" in { + inside(cpg.call.nameExact("new").l) { case call :: Nil => + implicit val callResolver: NoResolve.type = NoResolve + call.methodFullName shouldBe "rust2cpgtest::Foo::new" + call.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + call.callee.l shouldBe cpg.method.fullNameExact("rust2cpgtest::Foo::new").l + } + } + + "have no arguments" in { + cpg.call.nameExact("new").argument shouldBe empty + } + } +}