Skip to content

Commit 5d7e6bf

Browse files
committed
Implement handling of #[serde(default)] on variant in internally tagged enums
If tag will not be found in the data, the default tag will be assumed
1 parent 9d46abf commit 5d7e6bf

3 files changed

Lines changed: 112 additions & 19 deletions

File tree

serde/src/private/de.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -816,17 +816,18 @@ mod content {
816816
pub struct TaggedContentVisitor<T> {
817817
tag_name: &'static str,
818818
expecting: &'static str,
819-
value: PhantomData<T>,
819+
/// If set, this tag will be used if tag will not be found in data
820+
default: Option<T>,
820821
}
821822

822823
impl<T> TaggedContentVisitor<T> {
823824
/// Visitor for the content of an internally tagged enum with the given
824825
/// tag name.
825-
pub fn new(name: &'static str, expecting: &'static str) -> Self {
826+
pub fn new(name: &'static str, expecting: &'static str, default: Option<T>) -> Self {
826827
TaggedContentVisitor {
827828
tag_name: name,
828829
expecting,
829-
value: PhantomData,
830+
default,
830831
}
831832
}
832833
}
@@ -846,6 +847,8 @@ mod content {
846847
where
847848
S: SeqAccess<'de>,
848849
{
850+
// We do not support sequence representation without tags, because that may
851+
// create ambiguity during deserialization
849852
let tag = match tri!(seq.next_element()) {
850853
Some(tag) => tag,
851854
None => {
@@ -879,7 +882,7 @@ mod content {
879882
}
880883
}
881884
}
882-
match tag {
885+
match tag.or(self.default) {
883886
None => Err(de::Error::missing_field(self.tag_name)),
884887
Some(tag) => Ok((tag, Content::Map(vec))),
885888
}

serde_derive/src/de.rs

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1387,36 +1387,45 @@ fn deserialize_internally_tagged_enum(
13871387
let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants);
13881388

13891389
// Match arms to extract a variant from a string
1390-
let variant_arms = variants
1390+
let mut variants = variants
13911391
.iter()
13921392
.enumerate()
1393-
.filter(|&(_, variant)| !variant.attrs.skip_deserializing())
1394-
.map(|(i, variant)| {
1395-
let variant_name = field_i(i);
1393+
.filter(|&(_, variant)| !variant.attrs.skip_deserializing());
1394+
let variant_arms = variants.clone().map(|(i, variant)| {
1395+
let variant_name = field_i(i);
13961396

1397-
let block = Match(deserialize_internally_tagged_variant(
1398-
params,
1399-
variant,
1400-
cattrs,
1401-
quote!(__deserializer),
1402-
));
1397+
let block = Match(deserialize_internally_tagged_variant(
1398+
params,
1399+
variant,
1400+
cattrs,
1401+
quote!(__deserializer),
1402+
));
14031403

1404-
quote! {
1405-
__Field::#variant_name => #block
1406-
}
1407-
});
1404+
quote! {
1405+
__Field::#variant_name => #block
1406+
}
1407+
});
14081408

14091409
let expecting = format!("internally tagged enum {}", params.type_name());
14101410
let expecting = cattrs.expecting().unwrap_or(&expecting);
14111411

1412+
// We check that only one variant is marked with #[serde(default)]
1413+
let default = match variants.find(|(_, variant)| variant.attrs.default()) {
1414+
Some((i, _)) => {
1415+
let default = field_i(i);
1416+
quote! { _serde::#private::Some(__Field::#default) }
1417+
}
1418+
None => quote! { _serde::#private::None },
1419+
};
1420+
14121421
quote_block! {
14131422
#variant_visitor
14141423

14151424
#variants_stmt
14161425

14171426
let (__tag, __content) = _serde::Deserializer::deserialize_any(
14181427
__deserializer,
1419-
_serde::#private::de::TaggedContentVisitor::<__Field>::new(#tag, #expecting))?;
1428+
_serde::#private::de::TaggedContentVisitor::<__Field>::new(#tag, #expecting, #default))?;
14201429
let __deserializer = _serde::#private::de::ContentDeserializer::<__D::Error>::new(__content);
14211430

14221431
match __tag {

test_suite/tests/test_enum_internally_tagged.rs

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,6 +1037,87 @@ mod struct_enum {
10371037
}
10381038
}
10391039

1040+
#[test]
1041+
fn default_variant() {
1042+
#[derive(Debug, PartialEq, Serialize, Deserialize)]
1043+
#[serde(tag = "tag")]
1044+
enum InternallyTaggedWithDefault {
1045+
Unit,
1046+
NewtypeUnit(()),
1047+
NewtypeUnitStruct(Unit),
1048+
NewtypeNewtype(Newtype),
1049+
NewtypeMap(BTreeMap<String, String>),
1050+
NewtypeStruct(Struct),
1051+
NewtypeEnum(Enum),
1052+
#[serde(default)]
1053+
Struct {
1054+
a: u8,
1055+
},
1056+
StructEnum {
1057+
enum_: Enum,
1058+
},
1059+
}
1060+
1061+
let value = InternallyTaggedWithDefault::Struct { a: 1 };
1062+
1063+
// Special case: no tag field, use enum tokens
1064+
assert_de_tokens(
1065+
&value,
1066+
&[
1067+
Token::Struct {
1068+
name: "InternallyTagged",
1069+
len: 1,
1070+
},
1071+
Token::Str("a"),
1072+
Token::U8(1),
1073+
Token::StructEnd,
1074+
],
1075+
);
1076+
assert_de_tokens(
1077+
&value,
1078+
&[
1079+
Token::Struct {
1080+
name: "InternallyTagged",
1081+
len: 1,
1082+
},
1083+
Token::BorrowedStr("a"),
1084+
Token::U8(1),
1085+
Token::StructEnd,
1086+
],
1087+
);
1088+
1089+
// Special case: no tag field, Map representation
1090+
assert_de_tokens(
1091+
&value,
1092+
&[
1093+
Token::Map { len: Some(1) },
1094+
Token::Str("a"),
1095+
Token::U8(1),
1096+
Token::MapEnd,
1097+
],
1098+
);
1099+
assert_de_tokens(
1100+
&value,
1101+
&[
1102+
Token::Map { len: Some(1) },
1103+
Token::BorrowedStr("a"),
1104+
Token::U8(1),
1105+
Token::MapEnd,
1106+
],
1107+
);
1108+
1109+
// Special case: Seq representation cannot be used without a tag due to ambiguity
1110+
assert_de_tokens_error::<InternallyTaggedWithDefault>(
1111+
&[
1112+
Token::Seq { len: Some(1) },
1113+
Token::U8(1), // tag (== NewtypeUnit)
1114+
Token::SeqEnd,
1115+
],
1116+
// The error is not very clear, because actually we got end of sequence instead of a Unit
1117+
"invalid type: sequence, expected unit",
1118+
);
1119+
}
1120+
10401121
#[test]
10411122
fn wrong_tag() {
10421123
assert_de_tokens_error::<InternallyTagged>(

0 commit comments

Comments
 (0)