Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,38 @@ class NodeBuilder(diffGraph: DiffGraphBuilder) {
case Some(hint) =>
val nameSequence = hint match {
case n: ast.Name => Option(n.id)
// TODO: Definitely a place for follow up handling of generics - currently only take the polymorphic type
// without type args. To see the type arguments, see ast.Subscript.slice
case attr: ast.Attribute =>
extractTypesFromHint(Some(attr.value)).map { x => x + "." + attr.attr }
case n: ast.Subscript if n.value.isInstanceOf[ast.Name] => Option(n.value.asInstanceOf[ast.Name].id)
case n: ast.Subscript if n.value.isInstanceOf[ast.Name] =>
val outerName = n.value.asInstanceOf[ast.Name].id
outerName match {
case "Optional" =>
extractTypesFromHint(Some(n.slice)).map(inner => s"$inner|${builtinPrefix}None")
case "List" | "Set" | "FrozenSet" | "Deque" =>
Some(s"${builtinPrefix}${outerName.toLowerCase}")
case "Dict" | "DefaultDict" | "OrderedDict" =>
Some(s"${builtinPrefix}dict")
case "Tuple" =>
Some(s"${builtinPrefix}tuple")
case "Union" =>
n.slice match {
case t: ast.Tuple =>
Some(t.elts.flatMap(e => extractTypesFromHint(Some(e))).mkString("|"))
case _ => Some(s"${typingPrefix}$outerName")
}
case _ => Option(outerName)
}
case n: ast.Constant if n.value.isInstanceOf[ast.StringConstant] =>
Option(n.value.asInstanceOf[ast.StringConstant].value)
case n: ast.BinOp if n.op == ast.BitOr =>
val left = extractTypesFromHint(Some(n.left))
val right = extractTypesFromHint(Some(n.right))
(left, right) match {
case (Some(l), Some(r)) => Some(s"$l|$r")
case (Some(l), _) => Some(l)
case (_, Some(r)) => Some(r)
case _ => None
}
case _ => None
}
nameSequence.map { typeName =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ class FunctionDefCpgTests extends PySrc2CpgFixture with Matchers {
.parameter
.typeFullName
.dedup
.l shouldBe Seq("typing.Optional")
.l shouldBe Seq("__builtin.str|__builtin.None")
}

"test return hint of method definition using built-in types" in {
Expand All @@ -201,7 +201,7 @@ class FunctionDefCpgTests extends PySrc2CpgFixture with Matchers {
.methodReturn
.typeFullName
.dedup
.l shouldBe Seq("typing.List")
.l shouldBe Seq("__builtin.list")
}

"test parameter hint of the form abc.def" in {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1661,4 +1661,34 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) {
}
}
}

"type recovery from annotations and stubs" should {
  // Fixture: `greet` carries plain builtin annotations, `maybe` takes an Optional[int],
  // and a module-level assignment captures the result of calling `greet`.
  lazy val cpg = code("""from typing import Optional
    |
    |def greet(name: str) -> str:
    | return "Hello, " + name
    |
    |def maybe(x: Optional[int]) -> str:
    | return str(x)
    |
    |result = greet("world")
    |""".stripMargin).cpg

  "resolve annotated parameter type for 'name'" in {
    val nameParamType = cpg.method.name("greet").parameter.name("name").typeFullName.head
    nameParamType shouldBe "__builtin.str"
  }

  "resolve annotated return type for 'greet'" in {
    val greetReturnType = cpg.method.name("greet").methodReturn.typeFullName.head
    greetReturnType shouldBe "__builtin.str"
  }

  "resolve Optional[int] parameter to union type" in {
    // Optional[T] is expected to surface as the union "T|None" in the type full name.
    val optionalParamType = cpg.method.name("maybe").parameter.name("x").typeFullName.head
    optionalParamType shouldBe "__builtin.int|__builtin.None"
  }

  "propagate return type to call site 'result'" in {
    // Only membership is checked: other recovered types may be present alongside str.
    cpg.identifier.name("result").typeFullName.toSet should contain("__builtin.str")
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,11 @@ private class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder
case "<operator>.tupleLiteral" => Set(s"${Constants.builtinPrefix}tuple")
case "<operator>.dictLiteral" => Set(s"${Constants.builtinPrefix}dict")
case "<operator>.setLiteral" => Set(s"${Constants.builtinPrefix}set")
case _ => super.getTypesFromCall(c)
case _ =>
PythonTypeStubs.returnTypeFor(c.methodFullName).orElse(PythonTypeStubs.returnTypeFor(c.name)) match {
case Some(rt) => Set(rt)
case None => super.getTypesFromCall(c)
}
}

override def getFieldParents(fa: FieldAccess): Set[String] = {
Expand Down Expand Up @@ -196,6 +200,25 @@ private class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder
}

override def prepopulateSymbolTable(): Unit = {
// Seed symbol table from Python type annotations on parameters
cu.method.parameter
.filterNot(_.typeFullName == Constants.ANY)
.filterNot(_.name == "self")
.foreach { param =>
symbolTable.append(LocalVar(param.name), Set(param.typeFullName))
}
// Seed from local variable type annotations
cu.method.local
.filterNot(_.typeFullName == Constants.ANY)
.foreach { local =>
symbolTable.append(LocalVar(local.name), Set(local.typeFullName))
}
// Seed from method return type annotations
cu.method
.filterNot(_.methodReturn.typeFullName == Constants.ANY)
.foreach { m =>
symbolTable.append(CallAlias(m.name, None), Set(m.methodReturn.typeFullName))
}
cu.ast.isMethodRef.where(_.astSiblings.isIdentifier.nameExact("classmethod")).referencedMethod.foreach {
classMethod =>
classMethod.parameter
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package io.joern.x2cpg.frontendspecific.pysrc2cpg

/** Hand-maintained lookup table mapping the names of some well-known Python callables to the
  * type name recorded as their return type.
  *
  * Keys are matched exactly. Every builtin callable is registered under two keys: its bare name
  * (e.g. `len`) and the same name qualified with `Constants.builtinPrefix`. A few `os` and `json`
  * functions are registered under their dotted module path only.
  *
  * NOTE(review): several entries are approximations. For example `abs`, `sum` and `round` are
  * recorded as returning `int`, `json.loads`/`json.load` as `dict`, and `open` as
  * `io.TextIOWrapper`, although Python may return other types for these calls. Confirm that this
  * is acceptable for the consumers of this table.
  */
object PythonTypeStubs {

  private val prefix = Constants.builtinPrefix

  /** Qualifies a builtin name with the builtin prefix. */
  private def builtin(name: String): String = s"$prefix$name"

  /** Builtin callables, keyed by bare name, with the full type name recorded as their result. */
  private val builtinCallables: Seq[(String, String)] = Seq(
    "len"        -> builtin("int"),
    "int"        -> builtin("int"),
    "str"        -> builtin("str"),
    "float"      -> builtin("float"),
    "bool"       -> builtin("bool"),
    "bytes"      -> builtin("bytes"),
    "list"       -> builtin("list"),
    "dict"       -> builtin("dict"),
    "set"        -> builtin("set"),
    "tuple"      -> builtin("tuple"),
    "type"       -> builtin("type"),
    "repr"       -> builtin("str"),
    "abs"        -> builtin("int"),
    "hash"       -> builtin("int"),
    "id"         -> builtin("int"),
    "ord"        -> builtin("int"),
    "chr"        -> builtin("str"),
    "hex"        -> builtin("str"),
    "oct"        -> builtin("str"),
    "bin"        -> builtin("str"),
    "range"      -> builtin("range"),
    "enumerate"  -> builtin("enumerate"),
    "zip"        -> builtin("zip"),
    "map"        -> builtin("map"),
    "filter"     -> builtin("filter"),
    "sorted"     -> builtin("list"),
    "reversed"   -> builtin("reversed"),
    "isinstance" -> builtin("bool"),
    "issubclass" -> builtin("bool"),
    "hasattr"    -> builtin("bool"),
    "callable"   -> builtin("bool"),
    "sum"        -> builtin("int"),
    "round"      -> builtin("int"),
    "input"      -> builtin("str"),
    // The only builtin whose recorded result type is not itself a prefixed builtin type.
    "open"       -> "io.TextIOWrapper",
    "format"     -> builtin("str")
  )

  /** Standard-library functions, registered under their dotted module path only. */
  private val libraryCallables: Seq[(String, String)] = Seq(
    "os.path.join"     -> builtin("str"),
    "os.path.exists"   -> builtin("bool"),
    "os.path.isfile"   -> builtin("bool"),
    "os.path.isdir"    -> builtin("bool"),
    "os.path.basename" -> builtin("str"),
    "os.path.dirname"  -> builtin("str"),
    "os.getcwd"        -> builtin("str"),
    "os.listdir"       -> builtin("list"),
    "json.loads"       -> builtin("dict"),
    "json.dumps"       -> builtin("str"),
    "json.load"        -> builtin("dict")
  )

  private val stubReturnTypes: Map[String, String] = {
    // Each builtin is reachable both by its prefixed and by its bare name.
    val prefixedBuiltins = builtinCallables.map { case (name, returnType) => builtin(name) -> returnType }
    (prefixedBuiltins ++ builtinCallables ++ libraryCallables).toMap
  }

  /** Looks up the recorded return type for a callable.
    *
    * @param name
    *   the exact key to look up: a prefixed builtin name, a bare builtin name, or a dotted
    *   module path such as `os.path.join`
    * @return
    *   the recorded return type name, or `None` when the table has no entry for `name`
    */
  def returnTypeFor(name: String): Option[String] = stubReturnTypes.get(name)
}