Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 51 additions & 2 deletions crates/ty_python_semantic/resources/mdtest/typed_dict.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ python-version = "3.12"
Here, we define a `TypedDict` using the class-based syntax:

```py
from typing import TypedDict
from typing import Optional, TypedDict

class Person(TypedDict):
name: str
Expand Down Expand Up @@ -49,7 +49,7 @@ Functional `TypedDict`s with non-identifier keys should synthesize `__init__` wi
keys into invalid named parameters:

```py
from typing import TypedDict
from typing import Optional, TypedDict

Config = TypedDict("Config", {"in": int, "x-y": str, "ok": int})
# revealed: Overload[(self: Config, map: Config, /, *, ok: int = ..., **kwargs) -> None, (self: Config, /, *, ok: int, **kwargs) -> None]
Expand Down Expand Up @@ -993,6 +993,55 @@ def convert_positional(src: Source) -> Target:
return Target(src)
```

Unpacking a narrower `TypedDict` into a wider `TypedDict` literal should preserve the unpacked
required keys:

```py
from typing import Optional, TypedDict

class MyTypedDict1(TypedDict):
aaa: int
bbb: int

class MyTypedDict2(TypedDict):
aaa: int
bbb: int
ccc: int

d1: MyTypedDict1 = {
"aaa": 1,
"bbb": 2,
}

d2: MyTypedDict2 = {
**d1,
"ccc": 3,
}

d3 = MyTypedDict2({**d1, "ccc": 3})

class BadTypedDict1(TypedDict):
aaa: str
bbb: int

bad1: BadTypedDict1 = {
"aaa": "bad",
"bbb": 2,
}

ok1: MyTypedDict2 = {
**bad1,
"aaa": 1,
"ccc": 3,
}

ok2 = MyTypedDict2({**bad1, "aaa": 1, "ccc": 3})

# error: [invalid-argument-type] "Invalid argument to key "aaa" with declared type `int` on TypedDict `MyTypedDict2`: value of type `str`"
still_union: Optional[MyTypedDict2] = {**bad1, "ccc": 3}
reveal_type(still_union) # revealed: MyTypedDict2 | None
```

Unpacking `Never` or a dynamic type (`Any`, `Unknown`) passes unconditionally, since these types can
have any keys:

Expand Down
109 changes: 83 additions & 26 deletions crates/ty_python_semantic/src/types/typed_dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1110,8 +1110,9 @@ fn validate_extracted_typed_dict_keys<'db, 'ast>(
nodes: TypedDictAssignmentNodes<'ast>,
full_object_ty: Option<Type<'db>>,
ignored_keys: &OrderSet<Name>,
) -> OrderSet<Name> {
) -> (OrderSet<Name>, bool) {
let mut provided_keys = OrderSet::new();
let mut valid = true;

for (key_name, unpacked_key) in unpacked_keys {
if ignored_keys.contains(key_name) {
Expand All @@ -1120,7 +1121,7 @@ fn validate_extracted_typed_dict_keys<'db, 'ast>(
if unpacked_key.is_required {
provided_keys.insert(key_name.clone());
}
TypedDictKeyAssignment {
valid &= TypedDictKeyAssignment {
context,
typed_dict,
full_object_ty,
Expand All @@ -1135,7 +1136,7 @@ fn validate_extracted_typed_dict_keys<'db, 'ast>(
.validate();
}

provided_keys
(provided_keys, valid)
}

/// Validates a mixed-constructor positional argument when its type can be viewed as a `TypedDict`.
Expand All @@ -1160,18 +1161,21 @@ fn validate_from_typed_dict_argument<'db, 'ast>(
.filter(|(key_name, _)| typed_dict_items.contains_key(key_name))
.collect();

Some(validate_extracted_typed_dict_keys(
context,
typed_dict,
&unpacked_keys,
TypedDictAssignmentNodes {
typed_dict: typed_dict_node,
key: arg.into(),
value: arg.into(),
},
full_object_ty_annotation(arg_ty),
ignored_keys,
))
Some(
validate_extracted_typed_dict_keys(
context,
typed_dict,
&unpacked_keys,
TypedDictAssignmentNodes {
typed_dict: typed_dict_node,
key: arg.into(),
value: arg.into(),
},
full_object_ty_annotation(arg_ty),
ignored_keys,
)
.0,
)
}

fn report_duplicate_typed_dict_constructor_key<'db, 'ast>(
Expand Down Expand Up @@ -1380,30 +1384,32 @@ fn validate_from_dict_literal<'db, 'ast>(
) -> OrderSet<Name> {
let mut provided_keys = OrderSet::new();
let items = typed_dict.items(context.db());
let mut shadowed_keys = ignored_keys.clone();

if let ast::Expr::Dict(dict_expr) = &arguments.args[0] {
// Validate dict entries
for dict_item in &dict_expr.items {
for dict_item in dict_expr.items.iter().rev() {
if let Some(ref key_expr) = dict_item.key
&& let Some(key_value) =
expression_type_fn(key_expr, TypeContext::default()).as_string_literal()
{
let key = key_value.value(context.db());
if ignored_keys.contains(key) {
let key = Name::new(key_value.value(context.db()));
if shadowed_keys.contains(&key) {
continue;
}
provided_keys.insert(Name::new(key));
shadowed_keys.insert(key.clone());
provided_keys.insert(key.clone());

let value_tcx = items
.get(key)
.get(key.as_str())
.map(|field| TypeContext::new(Some(field.declared_ty)))
.unwrap_or_default();
let value_ty = expression_type_fn(&dict_item.value, value_tcx);
TypedDictKeyAssignment {
context,
typed_dict,
full_object_ty: None,
key,
key: key.as_str(),
value_ty,
typed_dict_node,
key_node: key_expr.into(),
Expand All @@ -1412,6 +1418,28 @@ fn validate_from_dict_literal<'db, 'ast>(
emit_diagnostic: true,
}
.validate();
} else if dict_item.key.is_none() {
let unpacked_ty = expression_type_fn(&dict_item.value, TypeContext::default());
if let Some(unpacked_keys) =
extract_unpacked_typed_dict_keys(context.db(), unpacked_ty)
{
let (unpacked_provided_keys, _) = validate_extracted_typed_dict_keys(
context,
typed_dict,
&unpacked_keys,
TypedDictAssignmentNodes {
typed_dict: typed_dict_node,
key: (&dict_item.value).into(),
value: (&dict_item.value).into(),
},
full_object_ty_annotation(unpacked_ty),
&shadowed_keys,
);
provided_keys.extend(unpacked_provided_keys);
shadowed_keys.extend(unpacked_keys.into_iter().filter_map(
|(key_name, unpacked_key)| unpacked_key.is_required.then_some(key_name),
));
}
}
}
}
Expand Down Expand Up @@ -1501,7 +1529,9 @@ fn validate_from_keywords<'db, 'ast>(
},
full_object_ty_annotation(unpacked_type),
&OrderSet::new(),
) {
)
.0
{
record_guaranteed_typed_dict_constructor_key(
context,
typed_dict,
Expand All @@ -1528,22 +1558,27 @@ pub(super) fn validate_typed_dict_dict_literal<'db>(
) -> Result<OrderSet<Name>, OrderSet<Name>> {
let mut valid = true;
let mut provided_keys = OrderSet::new();
let mut shadowed_keys = OrderSet::new();

// Validate each key-value pair in the dictionary literal
for item in &dict_expr.items {
for item in dict_expr.items.iter().rev() {
if let Some(key_expr) = &item.key
&& let Some(key_str) = expression_type_fn(key_expr).as_string_literal()
{
let key = key_str.value(context.db());
provided_keys.insert(Name::new(key));
let key = Name::new(key_str.value(context.db()));
if shadowed_keys.contains(&key) {
continue;
}
shadowed_keys.insert(key.clone());
provided_keys.insert(key.clone());

let value_ty = expression_type_fn(&item.value);

valid &= TypedDictKeyAssignment {
context,
typed_dict,
full_object_ty: None,
key,
key: key.as_str(),
value_ty,
typed_dict_node,
key_node: key_expr.into(),
Expand All @@ -1552,6 +1587,28 @@ pub(super) fn validate_typed_dict_dict_literal<'db>(
emit_diagnostic: true,
}
.validate();
} else if item.key.is_none() {
let unpacked_ty = expression_type_fn(&item.value);
if let Some(unpacked_keys) = extract_unpacked_typed_dict_keys(context.db(), unpacked_ty)
{
let (unpacked_provided_keys, unpacked_valid) = validate_extracted_typed_dict_keys(
context,
typed_dict,
&unpacked_keys,
TypedDictAssignmentNodes {
typed_dict: typed_dict_node,
key: (&item.value).into(),
value: (&item.value).into(),
},
full_object_ty_annotation(unpacked_ty),
&shadowed_keys,
);
valid &= unpacked_valid;
provided_keys.extend(unpacked_provided_keys);
shadowed_keys.extend(unpacked_keys.into_iter().filter_map(
|(key_name, unpacked_key)| unpacked_key.is_required.then_some(key_name),
));
}
}
}

Expand Down
Loading