Skip to content

Commit 0681c47

Browse files
committed
Fix recursive generic type alias expansion
1 parent 006b9af commit 0681c47

24 files changed

Lines changed: 728 additions & 203 deletions

crates/ty_ide/src/completion.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2618,9 +2618,9 @@ fn completion_kind_from_type<'db>(db: &'db dyn Db, ty: Type<'db>) -> Option<Comp
26182618
| Type::KnownInstance(_)
26192619
| Type::AlwaysTruthy
26202620
| Type::AlwaysFalsy => return None,
2621-
Type::TypeAlias(alias) => {
2622-
visitor.visit(ty, || imp(db, alias.value_type(db), visitor))?
2623-
}
2621+
Type::TypeAlias(alias) => visitor.visit(ty, || {
2622+
alias.visit_value(db, || None, |value_ty| imp(db, value_ty, visitor))
2623+
})?,
26242624
})
26252625
}
26262626
imp(db, ty, &CompletionKindVisitor::default())

crates/ty_python_semantic/resources/mdtest/pep695_type_aliases.md

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,7 @@ def _(x: Bar[int]):
443443

444444
```py
445445
from typing import Callable, Iterator, Literal, TypedDict, overload
446-
from ty_extensions import all_members, has_member
446+
from ty_extensions import Intersection, Not, all_members, has_member
447447

448448
class RecursiveItem: ...
449449

@@ -543,6 +543,81 @@ def decorated(x: int) -> int:
543543

544544
reveal_type(decorated) # revealed: ((int, /) -> int) | Unknown
545545

546+
type RecursiveCallableSpecialization[T] = Callable[[], RecursiveCallableSpecialization[list[T]]]
547+
548+
def reveal_recursive_callable_specialization(x: RecursiveCallableSpecialization[int]):
549+
reveal_type(x) # revealed: () -> RecursiveCallableSpecialization[list[int]]
550+
551+
type RecursiveTruthinessSpecialization[T] = int | RecursiveTruthinessSpecialization[list[T]]
552+
553+
def reveal_recursive_truthiness_specialization(x: RecursiveTruthinessSpecialization[int]):
554+
reveal_type(x) # revealed: int
555+
556+
def positive_intersection_recursive_truthiness_specialization(
557+
x: Intersection[RecursiveTruthinessSpecialization[int], object],
558+
):
559+
reveal_type(x) # revealed: int
560+
561+
def negative_recursive_truthiness_specialization(
562+
x: Not[RecursiveTruthinessSpecialization[int]],
563+
):
564+
reveal_type(x) # revealed: ~int & ~RecursiveTruthinessSpecialization[list[int]]
565+
566+
def negative_intersection_recursive_truthiness_specialization(
567+
x: Intersection[object, Not[RecursiveTruthinessSpecialization[int]]],
568+
):
569+
reveal_type(x) # revealed: ~int & ~RecursiveTruthinessSpecialization[list[list[int]]]
570+
571+
def bool_recursive_specialization(x: RecursiveTruthinessSpecialization[int]):
572+
if x:
573+
pass
574+
575+
def infer_recursive_truthiness_specialization[T](x: RecursiveTruthinessSpecialization[T]) -> T:
576+
raise NotImplementedError
577+
578+
def call_infer_recursive_truthiness_specialization(x: RecursiveTruthinessSpecialization[int]):
579+
infer_recursive_truthiness_specialization(x)
580+
581+
def accept_negative_recursive_truthiness_specialization[T](x: T) -> None:
582+
pass
583+
584+
def call_accept_negative_recursive_truthiness_specialization(
585+
x: Not[RecursiveTruthinessSpecialization[int]],
586+
):
587+
accept_negative_recursive_truthiness_specialization(x)
588+
589+
def compare_recursive_specialization(x: RecursiveTruthinessSpecialization[int]):
590+
reveal_type(x == 1) # revealed: bool
591+
592+
def binary_recursive_specialization(x: RecursiveTruthinessSpecialization[int]):
593+
reveal_type(x + 1) # revealed: int
594+
595+
type RecursiveSubscriptSpecialization[T] = list[RecursiveSubscriptSpecialization[list[T]]]
596+
597+
def head_recursive_subscript_specialization[T](x: RecursiveSubscriptSpecialization[T]) -> T:
598+
raise NotImplementedError
599+
600+
def call_head_recursive_subscript_specialization(x: RecursiveSubscriptSpecialization[int]):
601+
reveal_type(head_recursive_subscript_specialization(x)) # revealed: int
602+
603+
def subscript_recursive_specialization(x: RecursiveSubscriptSpecialization[int]):
604+
reveal_type(x[0]) # revealed: list[RecursiveSubscriptSpecialization[list[list[int]]]]
605+
606+
type RecursiveClassInfoSpecialization[T] = type[int] | RecursiveClassInfoSpecialization[list[T]]
607+
608+
def isinstance_recursive_specialization(obj: object, classinfo: RecursiveClassInfoSpecialization[int]):
609+
if isinstance(obj, classinfo):
610+
reveal_type(obj) # revealed: object
611+
612+
def issubclass_recursive_specialization(obj: type[object], classinfo: RecursiveClassInfoSpecialization[int]):
613+
if issubclass(obj, classinfo):
614+
reveal_type(obj) # revealed: type
615+
616+
def match_recursive_classinfo_specialization(obj: object, classinfo: RecursiveClassInfoSpecialization[int]):
617+
match obj:
618+
case classinfo():
619+
reveal_type(obj) # revealed: object
620+
546621
class BaseWithMethod:
547622
def method(self) -> None: ...
548623

@@ -832,6 +907,26 @@ reveal_type(CallableIs) # revealed: TypeAliasType
832907
reveal_type(CallableGuard) # revealed: TypeAliasType
833908
```
834909

910+
### Recursive generic aliases in special forms don't stack overflow
911+
912+
```py
913+
from typing import Literal, reveal_type
914+
from typing_extensions import TypeIs
915+
916+
type RecursiveLiteralSpecialization[T] = Literal[1] | RecursiveLiteralSpecialization[list[T]]
917+
type RecursiveIsSpecialization[T] = int | RecursiveIsSpecialization[list[T]]
918+
type RecursiveDuplicateSpecialization[T] = RecursiveIsSpecialization[T] | RecursiveIsSpecialization[T]
919+
920+
def is_recursive(x: object) -> TypeIs[RecursiveIsSpecialization[int]]:
921+
return True
922+
923+
reveal_type(RecursiveLiteralSpecialization) # revealed: TypeAliasType
924+
reveal_type(RecursiveDuplicateSpecialization) # revealed: TypeAliasType
925+
926+
# error: [invalid-type-form] "Type arguments for `Literal` must be `None`, a literal value (int, bool, str, or bytes), or an enum member"
927+
x: Literal[RecursiveLiteralSpecialization[int]]
928+
```
929+
835930
### Recursive alias in binary operators doesn't stack overflow
836931

837932
```py

crates/ty_python_semantic/src/semantic_model.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -476,9 +476,9 @@ impl<'db> SemanticModel<'db> {
476476
.iter()
477477
.flat_map(|element| collect(db, *element, visitor))
478478
.collect(),
479-
Type::TypeAlias(alias) => {
480-
visitor.visit(ty, || collect(db, alias.value_type(db), visitor))
481-
}
479+
Type::TypeAlias(alias) => visitor.visit(ty, || {
480+
alias.visit_value(db, Vec::new, |value_ty| collect(db, value_ty, visitor))
481+
}),
482482
_ => Vec::new(),
483483
}
484484
}

0 commit comments

Comments
 (0)