diff --git a/crates/ty_python_semantic/resources/mdtest/typed_dict.md b/crates/ty_python_semantic/resources/mdtest/typed_dict.md index 7addaa0f5b6f2..64ef071f8fccb 100644 --- a/crates/ty_python_semantic/resources/mdtest/typed_dict.md +++ b/crates/ty_python_semantic/resources/mdtest/typed_dict.md @@ -13,7 +13,7 @@ python-version = "3.12" Here, we define a `TypedDict` using the class-based syntax: ```py -from typing import TypedDict +from typing import Optional, TypedDict class Person(TypedDict): name: str @@ -49,7 +49,7 @@ Functional `TypedDict`s with non-identifier keys should synthesize `__init__` wi keys into invalid named parameters: ```py -from typing import TypedDict +from typing import Optional, TypedDict Config = TypedDict("Config", {"in": int, "x-y": str, "ok": int}) # revealed: Overload[(self: Config, map: Config, /, *, ok: int = ..., **kwargs) -> None, (self: Config, /, *, ok: int, **kwargs) -> None] @@ -993,6 +993,55 @@ def convert_positional(src: Source) -> Target: return Target(src) ``` +Unpacking a narrower `TypedDict` into a wider `TypedDict` literal should preserve the unpacked +required keys: + +```py +from typing import Optional, TypedDict + +class MyTypedDict1(TypedDict): + aaa: int + bbb: int + +class MyTypedDict2(TypedDict): + aaa: int + bbb: int + ccc: int + +d1: MyTypedDict1 = { + "aaa": 1, + "bbb": 2, +} + +d2: MyTypedDict2 = { + **d1, + "ccc": 3, +} + +d3 = MyTypedDict2({**d1, "ccc": 3}) + +class BadTypedDict1(TypedDict): + aaa: str + bbb: int + +bad1: BadTypedDict1 = { + "aaa": "bad", + "bbb": 2, +} + +ok1: MyTypedDict2 = { + **bad1, + "aaa": 1, + "ccc": 3, +} + +ok2 = MyTypedDict2({**bad1, "aaa": 1, "ccc": 3}) + +# error: [invalid-argument-type] "Invalid argument to key "aaa" with declared type `int` on TypedDict `MyTypedDict2`: value of type `str`" +still_union: Optional[MyTypedDict2] = {**bad1, "ccc": 3} +reveal_type(still_union) # revealed: MyTypedDict2 | None +``` + Unpacking `Never` or a dynamic type (`Any`, `Unknown`) passes unconditionally, since these types can have any keys: diff --git a/crates/ty_python_semantic/src/types/typed_dict.rs b/crates/ty_python_semantic/src/types/typed_dict.rs index e768400b46e44..151a33fd0a2c1 100644 --- a/crates/ty_python_semantic/src/types/typed_dict.rs +++ b/crates/ty_python_semantic/src/types/typed_dict.rs @@ -1110,8 +1110,9 @@ fn validate_extracted_typed_dict_keys<'db, 'ast>( nodes: TypedDictAssignmentNodes<'ast>, full_object_ty: Option>, ignored_keys: &OrderSet, -) -> OrderSet { +) -> (OrderSet, bool) { let mut provided_keys = OrderSet::new(); + let mut valid = true; for (key_name, unpacked_key) in unpacked_keys { if ignored_keys.contains(key_name) { @@ -1120,7 +1121,7 @@ fn validate_extracted_typed_dict_keys<'db, 'ast>( if unpacked_key.is_required { provided_keys.insert(key_name.clone()); } - TypedDictKeyAssignment { + valid &= TypedDictKeyAssignment { context, typed_dict, full_object_ty, @@ -1135,7 +1136,7 @@ fn validate_extracted_typed_dict_keys<'db, 'ast>( .validate(); } - provided_keys + (provided_keys, valid) } /// Validates a mixed-constructor positional argument when its type can be viewed as a `TypedDict`. @@ -1160,18 +1161,21 @@ fn validate_from_typed_dict_argument<'db, 'ast>( .filter(|(key_name, _)| typed_dict_items.contains_key(key_name)) .collect(); - Some(validate_extracted_typed_dict_keys( - context, - typed_dict, - &unpacked_keys, - TypedDictAssignmentNodes { - typed_dict: typed_dict_node, - key: arg.into(), - value: arg.into(), - }, - full_object_ty_annotation(arg_ty), - ignored_keys, - )) + Some( + validate_extracted_typed_dict_keys( + context, + typed_dict, + &unpacked_keys, + TypedDictAssignmentNodes { + typed_dict: typed_dict_node, + key: arg.into(), + value: arg.into(), + }, + full_object_ty_annotation(arg_ty), + ignored_keys, + ) + .0, + ) } fn report_duplicate_typed_dict_constructor_key<'db, 'ast>( @@ -1380,22 +1384,24 @@ fn validate_from_dict_literal<'db, 'ast>( ) -> OrderSet { let mut provided_keys = OrderSet::new(); let items = typed_dict.items(context.db()); + let mut shadowed_keys = ignored_keys.clone(); if let ast::Expr::Dict(dict_expr) = &arguments.args[0] { // Validate dict entries - for dict_item in &dict_expr.items { + for dict_item in dict_expr.items.iter().rev() { if let Some(ref key_expr) = dict_item.key && let Some(key_value) = expression_type_fn(key_expr, TypeContext::default()).as_string_literal() { - let key = key_value.value(context.db()); - if ignored_keys.contains(key) { + let key = Name::new(key_value.value(context.db())); + if shadowed_keys.contains(&key) { continue; } - provided_keys.insert(Name::new(key)); + shadowed_keys.insert(key.clone()); + provided_keys.insert(key.clone()); let value_tcx = items - .get(key) + .get(key.as_str()) .map(|field| TypeContext::new(Some(field.declared_ty))) .unwrap_or_default(); let value_ty = expression_type_fn(&dict_item.value, value_tcx); @@ -1403,7 +1409,7 @@ fn validate_from_dict_literal<'db, 'ast>( context, typed_dict, full_object_ty: None, - key, + key: key.as_str(), value_ty, typed_dict_node, key_node: key_expr.into(), @@ -1412,6 +1418,28 @@ fn validate_from_dict_literal<'db, 'ast>( emit_diagnostic: true, } .validate(); + } else if dict_item.key.is_none() { + let unpacked_ty = expression_type_fn(&dict_item.value, TypeContext::default()); + if let Some(unpacked_keys) = + extract_unpacked_typed_dict_keys(context.db(), unpacked_ty) + { + let (unpacked_provided_keys, _) = validate_extracted_typed_dict_keys( + context, + typed_dict, + &unpacked_keys, + TypedDictAssignmentNodes { + typed_dict: typed_dict_node, + key: (&dict_item.value).into(), + value: (&dict_item.value).into(), + }, + full_object_ty_annotation(unpacked_ty), + &shadowed_keys, + ); + provided_keys.extend(unpacked_provided_keys); + shadowed_keys.extend(unpacked_keys.into_iter().filter_map( + |(key_name, unpacked_key)| unpacked_key.is_required.then_some(key_name), + )); + } } } } @@ -1501,7 +1529,9 @@ fn validate_from_keywords<'db, 'ast>( }, full_object_ty_annotation(unpacked_type), &OrderSet::new(), - ) { + ) + .0 + { record_guaranteed_typed_dict_constructor_key( context, typed_dict, @@ -1528,14 +1558,19 @@ pub(super) fn validate_typed_dict_dict_literal<'db>( ) -> Result, OrderSet> { let mut valid = true; let mut provided_keys = OrderSet::new(); + let mut shadowed_keys = OrderSet::new(); // Validate each key-value pair in the dictionary literal - for item in &dict_expr.items { + for item in dict_expr.items.iter().rev() { if let Some(key_expr) = &item.key && let Some(key_str) = expression_type_fn(key_expr).as_string_literal() { - let key = key_str.value(context.db()); - provided_keys.insert(Name::new(key)); + let key = Name::new(key_str.value(context.db())); + if shadowed_keys.contains(&key) { + continue; + } + shadowed_keys.insert(key.clone()); + provided_keys.insert(key.clone()); let value_ty = expression_type_fn(&item.value); @@ -1543,7 +1578,7 @@ pub(super) fn validate_typed_dict_dict_literal<'db>( context, typed_dict, full_object_ty: None, - key, + key: key.as_str(), value_ty, typed_dict_node, key_node: key_expr.into(), @@ -1552,6 +1587,28 @@ pub(super) fn validate_typed_dict_dict_literal<'db>( emit_diagnostic: true, } .validate(); + } else if item.key.is_none() { + let unpacked_ty = expression_type_fn(&item.value); + if let Some(unpacked_keys) = extract_unpacked_typed_dict_keys(context.db(), unpacked_ty) + { + let (unpacked_provided_keys, unpacked_valid) = validate_extracted_typed_dict_keys( + context, + typed_dict, + &unpacked_keys, + TypedDictAssignmentNodes { + typed_dict: typed_dict_node, + key: (&item.value).into(), + value: (&item.value).into(), + }, + full_object_ty_annotation(unpacked_ty), + &shadowed_keys, + ); + valid &= unpacked_valid; + provided_keys.extend(unpacked_provided_keys); + shadowed_keys.extend(unpacked_keys.into_iter().filter_map( + |(key_name, unpacked_key)| unpacked_key.is_required.then_some(key_name), + )); + } } }