diff --git a/Cargo.lock b/Cargo.lock index 591f78d2c6c..07fb0fc9b49 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1537,7 +1537,6 @@ dependencies = [ "serde", "serde_json", "toml", - "utf8_iter", "yoke", "zerofrom", "zerovec", diff --git a/components/collections/Cargo.toml b/components/collections/Cargo.toml index c15f61500b4..953b7a3282e 100644 --- a/components/collections/Cargo.toml +++ b/components/collections/Cargo.toml @@ -29,7 +29,6 @@ yoke = { workspace = true, features = ["derive"] } zerofrom = { workspace = true, features = ["derive"] } zerovec = { workspace = true, features = ["derive", "yoke"] } potential_utf = { workspace = true, features = ["zerovec"] } -utf8_iter = { workspace = true } serde = { workspace = true, features = ["derive"], optional = true } databake = { workspace = true, features = ["derive"], optional = true } diff --git a/components/collections/src/codepointinvliststringlist/mod.rs b/components/collections/src/codepointinvliststringlist/mod.rs index 9fb15c64bfe..c06a9fdda8a 100644 --- a/components/collections/src/codepointinvliststringlist/mod.rs +++ b/components/collections/src/codepointinvliststringlist/mod.rs @@ -175,13 +175,22 @@ impl<'data> CodePointInversionListAndStringList<'data> { /// See [`Self::contains_str`] pub fn contains_utf8(&self, s: &[u8]) -> bool { - use utf8_iter::Utf8CharsEx; - let mut chars = s.chars(); - if let Some(first_char) = chars.next() { - if chars.next().is_none() { - return self.contains(first_char); + let single_char = s.get(0).and_then(|&first_byte| { + if first_byte.is_ascii() && s.len() == 1 { + return Some(first_byte as char); } + let utf8_len = first_byte.leading_ones() as usize; + if utf8_len != s.len() { + return None; + } + + str::from_utf8(&s[..utf8_len]).ok()?.chars().next() + }); + + if let Some(single_char) = single_char { + return self.contains(single_char); } + self.str_list .binary_search_by(|t| t.as_bytes().cmp(s)) .is_ok() diff --git a/components/collections/src/codepointtrie/cptrie.rs b/components/collections/src/codepointtrie/cptrie.rs index 086d936e2b5..206cd45d107 100644 --- a/components/collections/src/codepointtrie/cptrie.rs +++ b/components/collections/src/codepointtrie/cptrie.rs @@ -148,6 +148,8 @@ pub struct CodePointTrie<'trie, T: TrieValue> { pub(crate) index: ZeroVec<'trie, u16>, /// # Safety Invariant /// + /// `data.len()` must be 128 or greater regardless of trie type. Furthermore: + /// /// If `header.trie_type == TrieType::Fast`, `data.len()` must be greater /// than `FAST_TYPE_DATA_MASK` plus the largest value in /// `index[0..FAST_TYPE_FAST_INDEXING_MAX + 1]`. Otherwise, `data.len()` @@ -385,6 +387,11 @@ impl<'trie, T: TrieValue> CodePointTrie<'trie, T> { return Err(Error::DataTooShortForFastAccess); } + // Invariant upheld for `data`: Length must be at least 128. + if data.len() < 128 { + return Err(Error::DataTooShortForAsciiAccess); + } + // Invariant upheld for `data`: If we got this far, the length of `data` // satisfies `data`'s length invariant on the assumption that the contents // of `fast_index` subslice of `index` and `header.trie_type` will not @@ -395,7 +402,9 @@ impl<'trie, T: TrieValue> CodePointTrie<'trie, T> { /// Turns this trie into a version whose trie type is encoded in the Rust type. #[inline] - pub fn to_typed(self) -> Typed, SmallCodePointTrie<'trie, T>> { + pub const fn to_typed( + self, + ) -> Typed, SmallCodePointTrie<'trie, T>> { match self.header.trie_type { TrieType::Fast => Typed::Fast(FastCodePointTrie { inner: self }), TrieType::Small => Typed::Small(SmallCodePointTrie { inner: self }), @@ -577,6 +586,16 @@ impl<'trie, T: TrieValue> CodePointTrie<'trie, T> { ); let bit_prefix = (code_point as usize) >> FAST_TYPE_SHIFT; + let bit_suffix = (code_point & FAST_TYPE_DATA_MASK) as usize; + self.get_bit_prefix_suffix_assuming_fast_index(bit_prefix, bit_suffix) + } + + #[inline(always)] + unsafe fn get_bit_prefix_suffix_assuming_fast_index( + &self, + bit_prefix: usize, + bit_suffix: usize, + ) -> T { debug_assert!(bit_prefix < self.index.len()); // SAFETY: Relying on the length invariant of `self.index` having // been checked and on the unchangedness invariant of `self.index` @@ -584,7 +603,6 @@ impl<'trie, T: TrieValue> CodePointTrie<'trie, T> { let base_offset_to_data: usize = usize::from(u16::from_unaligned(*unsafe { self.index.as_ule_slice().get_unchecked(bit_prefix) })); - let bit_suffix = (code_point & FAST_TYPE_DATA_MASK) as usize; // SAFETY: Cannot overflow with supported (32-bit and 64-bit) `usize` // sizes, since `base_offset_to_data` was extended from `u16` and // `bit_suffix` is at most `FAST_TYPE_DATA_MASK`, which is well @@ -695,6 +713,99 @@ impl<'trie, T: TrieValue> CodePointTrie<'trie, T> { } } + /// Returns the value that is associated with `latin1` in this [`CodePointTrie`]. + #[inline(always)] + pub fn get8(&self, latin1: u8) -> T { + let code_point = u32::from(latin1); + debug_assert!(code_point <= SMALL_TYPE_FAST_INDEXING_MAX); + // SAFETY: `u8` is always below `SMALL_TYPE_FAST_INDEXING_MAX` and, + // therefore, belowe `FAST_TYPE_FAST_INDEXING_MAX`. + unsafe { self.get32_assuming_fast_index(code_point) } + } + + /// Returns the value that is associated with `ascii` in this [`CodePointTrie`]. + /// + /// # Safety + /// + /// `ascii` must be less than 128. + #[inline(always)] + pub unsafe fn get7(&self, ascii: u8) -> T { + debug_assert!(ascii < 128); + debug_assert!((ascii as usize) < self.data.len()); + // SAFETY: Allowed by the safety invariant of `self.data` guaranteering a length of at least + // 128. + T::from_unaligned(*unsafe { self.data.as_ule_slice().get_unchecked(ascii as usize) }) + } + + /// Returns the value that is associated with a two-byte UTF-8 sequence in this [`CodePointTrie`]. + /// + /// `high_five` is the low five bits of the lead byte of a two-byte UTF-8 sequence. + /// `low_six` is the low six bits of the trail byte of a two-byte UTF-8 sequence. + /// + /// # Safety + /// + /// `high_five` must not have bit positions other than the lowest 5 set to 1. + /// `low_six` must not have bit positions other than the lowest 6 set to 1. + /// + /// # Panics + /// + /// With debug assertions enabled, panics if the above safety invariants are + /// violated or `high_five` represents non-shortest form. + #[inline(always)] + pub unsafe fn get_utf8_two_byte(&self, high_five: u32, low_six: u32) -> T { + debug_assert!(low_six <= 0b111_111); // Safety invariant. + debug_assert!(high_five <= 0b11_111); // Safety invariant. + debug_assert!(high_five > 0b1); // Non-shortest form; not safety invariant. + + // SAFETY: The highest character representable as a two-byte + // UTF-8 sequence is U+07FF, eleven binary ones, which is below + // both `SMALL_TYPE_FAST_INDEXING_MAX` and `FAST_TYPE_FAST_INDEXING_MAX`. + self.get_bit_prefix_suffix_assuming_fast_index(high_five as usize, low_six as usize) + } + + /// Returns the value that is associated with a three-byte UTF-8 or WTF-8 sequence in this [`CodePointTrie`]. + /// + /// `high_ten` is the low four bits of the lead byte of three-byte UTF-8 or WTF-8 sequence shifted left by 6 followed by the low six bits of the first trail byte. + /// `low_six` is the low six bits of the last trail byte of a three-byte UTF-8 or WTF-8 sequence. + /// + /// Sequences representing surrogates (WTF-8) are allowed. + /// + /// # Safety + /// + /// `low_six` must not have bit positions other than the lowest 6 set to 1. + /// + /// # Intended Invariant + /// + /// `high_ten` must not have bit positions other than the lowest 10 set to 1. + /// + /// # Panics + /// + /// With debug assertions enabled, panics if the above safety invariant is + /// violated or `high_ten` is out of range for three-byte WTF-8 (or UTF-8) + /// sequence. + #[inline(always)] + #[allow(clippy::unusual_byte_groupings)] + pub unsafe fn get_utf8_three_byte(&self, high_ten: u32, low_six: u32) -> T { + debug_assert!(low_six <= 0b111_111); // Safety invariant. + debug_assert!(high_ten <= 0b1111_111_111); // Not actually a _safety_ invariant for this impl. + debug_assert!(high_ten > 0b11_111); // Non-shortest form; not safety invariant. + + let fast_max = match self.header.trie_type { + TrieType::Fast => FAST_TYPE_FAST_INDEXING_MAX, + TrieType::Small => SMALL_TYPE_FAST_INDEXING_MAX, + }; + // Keep only the prefix bits: + let max_bit_prefix = fast_max >> FAST_TYPE_SHIFT; + if high_ten <= max_bit_prefix { + // SAFETY: The caller is responsible for upholding the safety + // invariant for `low_six` and we just checked the safety + // invariant of `high_ten`. + self.get_bit_prefix_suffix_assuming_fast_index(high_ten as usize, low_six as usize) + } else { + self.get32_by_small_index_cold((high_ten << 6) | low_six) + } + } + /// Lookup trie value by non-Basic Multilingual Plane Scalar Value. /// /// The return value may be bogus (not necessarily `error_value`) is the argument is actually in @@ -1432,6 +1543,8 @@ impl Iterator for CodePointMapRangeIterator<'_, T> { /// All implementations of `TypedCodePointTrie` are reviewable in this module. trait Seal {} +impl<'trie, T: TrieValue> Seal for CodePointTrie<'trie, T> {} + /// Trait for writing trait bounds for monomorphizing over either /// `FastCodePointTrie` or `SmallCodePointTrie`. #[allow(private_bounds)] // Permit sealing @@ -1463,6 +1576,22 @@ pub trait TypedCodePointTrie<'trie, T: TrieValue>: Seal { } } + /// Lookup trie value by Latin1 Code Point without branching on trie type. + #[inline(always)] + fn get8(&self, latin1: u8) -> T { + self.as_untyped_ref().get8(latin1) + } + + /// Lookup trie value by ASCII Code Point without branching on trie type. + /// + /// # Safety + /// + /// `ascii` must be less than 128. + #[inline(always)] + unsafe fn get7(&self, ascii: u8) -> T { + self.as_untyped_ref().get7(ascii) + } + /// Lookup trie value by non-Basic Multilingual Plane Scalar Value without branching on trie type. #[inline(always)] fn get32_supplementary(&self, supplementary: u32) -> T { @@ -1524,6 +1653,69 @@ pub trait TypedCodePointTrie<'trie, T: TrieValue>: Seal { } } + /// Returns the value that is associated with a two-byte UTF-8 sequence. + /// + /// `high_five` is the low five bits of the lead byte of a two-byte UTF-8 sequence. + /// `low_six` is the low six bits of the trail byte of a two-byte UTF-8 sequence. + /// + /// # Safety + /// + /// `high_five` must not have bit positions other than the lowest 5 set to 1. + /// `low_six` must not have bit positions other than the lowest 6 set to 1. + /// + /// # Panics + /// + /// With debug assertions enabled, panics if the above safety invariants are + /// violated or `high_five` represents non-shortest form. + #[inline(always)] + unsafe fn get_utf8_two_byte(&self, high_five: u32, low_six: u32) -> T { + self.as_untyped_ref().get_utf8_two_byte(high_five, low_six) + } + + /// Returns the value that is associated with a three-byte UTF-8 or WTF-8 sequence. + /// + /// `high_ten` is the low four bits of the lead byte of three-byte UTF-8 or WTF-8 sequence shifted left by 6 followed by the low six bits of the first trail byte. + /// `low_six` is the low six bits of the last trail byte of a three-byte UTF-8 or WTF-8 sequence. + /// + /// Sequences representing surrogates (WTF-8) are allowed. + /// + /// # Safety + /// + /// `high_ten` must not have bit positions other than the lowest 10 set to 1. + /// `low_six` must not have bit positions other than the lowest 6 set to 1. + /// + /// # Panics + /// + /// With debug assertions enabled, panics if the above safety invariants are + /// violated or `high_ten` is out of range for three-byte WTF-8 (or UTF-8) + /// sequence. + #[inline(always)] + #[allow(clippy::unusual_byte_groupings)] + unsafe fn get_utf8_three_byte(&self, high_ten: u32, low_six: u32) -> T { + debug_assert!(low_six <= 0b111_111); // Safety invariant. + debug_assert!(high_ten <= 0b1111_111_111); // Not actually a _safety_ invariant for this impl. + debug_assert!(high_ten > 0b11_111); // Non-shortest form; not safety invariant. + + debug_assert_eq!(Self::TRIE_TYPE, self.as_untyped_ref().header.trie_type); + let fast_max = match Self::TRIE_TYPE { + TrieType::Fast => FAST_TYPE_FAST_INDEXING_MAX, + TrieType::Small => SMALL_TYPE_FAST_INDEXING_MAX, + }; + + // Keep only the prefix bits: + let max_bit_prefix = fast_max >> FAST_TYPE_SHIFT; + if high_ten <= max_bit_prefix { + // SAFETY: The caller is responsible for upholding the safety + // invariant for `low_six` and we just checked the safety + // invariant of `high_ten`. + self.as_untyped_ref() + .get_bit_prefix_suffix_assuming_fast_index(high_ten as usize, low_six as usize) + } else { + self.as_untyped_ref() + .get32_by_small_index_cold((high_ten << 6) | low_six) + } + } + /// Returns a reference to the wrapped `CodePointTrie`. fn as_untyped_ref(&self) -> &CodePointTrie<'trie, T>; @@ -1535,12 +1727,39 @@ pub trait TypedCodePointTrie<'trie, T: TrieValue>: Seal { /// the the getters don't branch on the trie type /// and for guarenteeing that `get16` is branchless /// in release builds. -#[derive(Debug)] +#[derive(Debug, Eq, PartialEq, Yokeable, ZeroFrom, Clone)] #[repr(transparent)] pub struct FastCodePointTrie<'trie, T: TrieValue> { inner: CodePointTrie<'trie, T>, } +impl<'trie, T: TrieValue> FastCodePointTrie<'trie, T> { + #[doc(hidden)] // databake internal + /// # Safety + /// + /// `header.trie_type`, `index`, and `data` must + /// satisfy the invariants for the fields of the + /// same names on `CodePointTrie`. + pub const unsafe fn from_parts_unstable_unchecked_v1( + header: CodePointTrieHeader, + index: ZeroVec<'trie, u16>, + data: ZeroVec<'trie, T>, + error_value: T, + ) -> Self { + // Field invariants upheld: The caller is responsible. + // In practice, this means that datagen in the databake + // mode upholds these invariants when constructing the + // `CodePointTrie` that is then baked. + let untyped = CodePointTrie::<'trie, T>::from_parts_unstable_unchecked_v1( + header, + index, + data, + error_value, + ); + Self { inner: untyped } + } +} + impl<'trie, T: TrieValue> TypedCodePointTrie<'trie, T> for FastCodePointTrie<'trie, T> { const TRIE_TYPE: TrieType = TrieType::Fast; @@ -1573,6 +1792,37 @@ impl<'trie, T: TrieValue> TypedCodePointTrie<'trie, T> for FastCodePointTrie<'tr // being correct and the exclusive ways of obtaining `Self`. unsafe { self.as_untyped_ref().get32_assuming_fast_index(code_point) } } + + /// Returns the value that is associated with a three-byte UTF-8 or WTF-8 sequence. + /// + /// `high_ten` is the low four bits of the lead byte of three-byte UTF-8 or WTF-8 sequence shifted left by 6 followed by the low six bits of the first trail byte. + /// `low_six` is the low six bits of the last trail byte of a three-byte UTF-8 or WTF-8 sequence. + /// + /// Sequences representing surrogates (WTF-8) are allowed. + /// + /// # Safety + /// + /// `high_ten` must not have bit positions other than the lowest 10 set to 1. + /// `low_six` must not have bit positions other than the lowest 6 set to 1. + /// + /// # Panics + /// + /// With debug assertions enabled, panics if the above safety invariants are + /// violated or `high_ten` is out of range for three-byte WTF-8 (or UTF-8) + /// sequence. + #[inline(always)] + #[allow(clippy::unusual_byte_groupings)] + unsafe fn get_utf8_three_byte(&self, high_ten: u32, low_six: u32) -> T { + debug_assert!(low_six <= 0b111_111); // Safety invariant. + debug_assert!(high_ten <= 0b1111_111_111); // Safety invariant. + debug_assert!(high_ten > 0b11_111); // Non-shortest form; not safety invariant. + debug_assert_eq!(Self::TRIE_TYPE, TrieType::Fast); + debug_assert_eq!(self.as_untyped_ref().header.trie_type, TrieType::Fast); + // SAFETY: The highest character representable as a three-byte + // UTF-8 sequence is U+FFFF, which is `FAST_TYPE_FAST_INDEXING_MAX`. + self.inner + .get_bit_prefix_suffix_assuming_fast_index(high_ten as usize, low_six as usize) + } } impl<'trie, T: TrieValue> Seal for FastCodePointTrie<'trie, T> {} @@ -1605,14 +1855,59 @@ impl<'trie, T: TrieValue> TryFrom> for FastCodePointTrie } } +#[cfg(feature = "databake")] +impl databake::Bake for FastCodePointTrie<'_, T> { + fn bake(&self, env: &databake::CrateEnv) -> databake::TokenStream { + let header = self.inner.header.bake(env); + let index = self.inner.index.bake(env); + let data = self.inner.data.bake(env); + let error_value = self.inner.error_value.bake(env); + databake::quote! { unsafe { icu_collections::codepointtrie::FastCodePointTrie::from_parts_unstable_unchecked_v1(#header, #index, #data, #error_value) } } + } +} + +#[cfg(feature = "databake")] +impl databake::BakeSize for FastCodePointTrie<'_, T> { + fn borrows_size(&self) -> usize { + self.inner.borrows_size() + } +} + /// Type-safe wrapper for a small trie guaranteeing /// the the getters don't branch on the trie type. -#[derive(Debug)] +#[derive(Debug, Eq, PartialEq, Yokeable, ZeroFrom, Clone)] #[repr(transparent)] pub struct SmallCodePointTrie<'trie, T: TrieValue> { inner: CodePointTrie<'trie, T>, } +impl<'trie, T: TrieValue> SmallCodePointTrie<'trie, T> { + #[doc(hidden)] // databake internal + /// # Safety + /// + /// `header.trie_type`, `index`, and `data` must + /// satisfy the invariants for the fields of the + /// same names on `CodePointTrie`. + pub const unsafe fn from_parts_unstable_unchecked_v1( + header: CodePointTrieHeader, + index: ZeroVec<'trie, u16>, + data: ZeroVec<'trie, T>, + error_value: T, + ) -> Self { + // Field invariants upheld: The caller is responsible. + // In practice, this means that datagen in the databake + // mode upholds these invariants when constructing the + // `CodePointTrie` that is then baked. + let untyped = CodePointTrie::<'trie, T>::from_parts_unstable_unchecked_v1( + header, + index, + data, + error_value, + ); + Self { inner: untyped } + } +} + impl<'trie, T: TrieValue> TypedCodePointTrie<'trie, T> for SmallCodePointTrie<'trie, T> { const TRIE_TYPE: TrieType = TrieType::Small; @@ -1659,6 +1954,24 @@ impl<'trie, T: TrieValue> TryFrom> for SmallCodePointTri } } +#[cfg(feature = "databake")] +impl databake::Bake for SmallCodePointTrie<'_, T> { + fn bake(&self, env: &databake::CrateEnv) -> databake::TokenStream { + let header = self.inner.header.bake(env); + let index = self.inner.index.bake(env); + let data = self.inner.data.bake(env); + let error_value = self.inner.error_value.bake(env); + databake::quote! { unsafe { icu_collections::codepointtrie::SmallCodePointTrie::from_parts_unstable_unchecked_v1(#header, #index, #data, #error_value) } } + } +} + +#[cfg(feature = "databake")] +impl databake::BakeSize for SmallCodePointTrie<'_, T> { + fn borrows_size(&self) -> usize { + self.inner.borrows_size() + } +} + /// Error indicating that the `TrieType` of an untyped trie /// does not match the requested typed trie type. #[derive(Debug)] @@ -1676,6 +1989,194 @@ pub enum Typed { Small(S), } +/// Trait for writing trait bounds for monomorphizing over either +/// `CodePointTrie`, `FastCodePointTrie`, or `SmallCodePointTrie`. +/// +/// Method naming intentionally differs from the method naming on +/// those types in order to disambiguate. +#[allow(private_bounds)] // Permit sealing +pub trait AbstractCodePointTrie<'trie, T: TrieValue>: Seal { + /// Look up trie value by an ASCII character. + /// + /// # Safety + /// + /// `ascii` must be less than 128. + unsafe fn ascii(&self, ascii: u8) -> T; + + /// Look up trie value by a two-byte UTF-8 sequence. + /// + /// `high_five` is the low five bits of the lead byte of a two-byte UTF-8 sequence. + /// `low_six` is the low six bits of the trail byte of a two-byte UTF-8 sequence. + /// + /// # Safety + /// + /// `high_five` must not have bit positions other than the lowest 5 set to 1. + /// `low_six` must not have bit positions other than the lowest 6 set to 1. + unsafe fn utf8_two_byte(&self, high_five: u32, low_six: u32) -> T; + + /// Look up trie value by a three-byte UTF-8 or WTF-8 sequence. + /// + /// `high_ten` is the low four bits of the lead byte of three-byte UTF-8 or WTF-8 sequence shifted left by 6 followed by the low six bits of the first trail byte. + /// `low_six` is the low six bits of the last trail byte of a three-byte UTF-8 or WTF-8 sequence. + /// + /// Sequences representing surrogates (WTF-8) are allowed. + /// + /// # Safety + /// + /// `high_ten` must not have bit positions other than the lowest 10 set to 1. + /// `low_six` must not have bit positions other than the lowest 6 set to 1. + unsafe fn utf8_three_byte(&self, high_ten: u32, low_six: u32) -> T; + + /// Look up trie value by a Latin1 character. + fn latin1(&self, latin1: u8) -> T; + + /// Look up trie value by a Basic Multilingual Plane character. + /// + /// Surrogate values are allowed. + fn bmp(&self, bmp: u16) -> T; + + /// Look up trie value by a non-Basic Multilingual Plane character. + /// + /// The behavior is memory-safe nonsense if the argument is not + /// actually a non-Basic Multilingual Plane character. + fn supplementary(&self, supplementary: u32) -> T; + + /// Look up trie value by a Unicode Scalar Value. + fn scalar(&self, scalar: char) -> T; + + /// Look up trie value by Unicode Code Point. + /// + /// Surrogate values are allowed. Out of range input + /// results in the error value. + fn code_point(&self, code_point: u32) -> T; +} + +impl<'trie, T: TrieValue> AbstractCodePointTrie<'trie, T> for FastCodePointTrie<'trie, T> { + #[inline(always)] + unsafe fn ascii(&self, ascii: u8) -> T { + self.get7(ascii) + } + + #[inline(always)] + unsafe fn utf8_two_byte(&self, high_five: u32, low_six: u32) -> T { + self.get_utf8_two_byte(high_five, low_six) + } + + #[inline(always)] + unsafe fn utf8_three_byte(&self, high_ten: u32, low_six: u32) -> T { + self.get_utf8_three_byte(high_ten, low_six) + } + + #[inline(always)] + fn latin1(&self, latin1: u8) -> T { + self.get8(latin1) + } + + #[inline(always)] + fn bmp(&self, bmp: u16) -> T { + self.get16(bmp) + } + + #[inline(always)] + fn supplementary(&self, supplementary: u32) -> T { + self.get32_supplementary(supplementary) + } + + #[inline(always)] + fn scalar(&self, scalar: char) -> T { + self.get(scalar) + } + + #[inline(always)] + fn code_point(&self, code_point: u32) -> T { + self.get32(code_point) + } +} + +impl<'trie, T: TrieValue> AbstractCodePointTrie<'trie, T> for SmallCodePointTrie<'trie, T> { + #[inline(always)] + unsafe fn ascii(&self, ascii: u8) -> T { + self.get7(ascii) + } + + #[inline(always)] + unsafe fn utf8_two_byte(&self, high_five: u32, low_six: u32) -> T { + self.get_utf8_two_byte(high_five, low_six) + } + + #[inline(always)] + unsafe fn utf8_three_byte(&self, high_ten: u32, low_six: u32) -> T { + self.get_utf8_three_byte(high_ten, low_six) + } + + #[inline(always)] + fn latin1(&self, latin1: u8) -> T { + self.get8(latin1) + } + + #[inline(always)] + fn bmp(&self, bmp: u16) -> T { + self.get16(bmp) + } + + #[inline(always)] + fn supplementary(&self, supplementary: u32) -> T { + self.get32_supplementary(supplementary) + } + + #[inline(always)] + fn scalar(&self, scalar: char) -> T { + self.get(scalar) + } + + #[inline(always)] + fn code_point(&self, code_point: u32) -> T { + self.get32(code_point) + } +} + +impl<'trie, T: TrieValue> AbstractCodePointTrie<'trie, T> for CodePointTrie<'trie, T> { + #[inline(always)] + unsafe fn ascii(&self, ascii: u8) -> T { + self.get7(ascii) + } + + #[inline(always)] + unsafe fn utf8_two_byte(&self, high_five: u32, low_six: u32) -> T { + self.get_utf8_two_byte(high_five, low_six) + } + + #[inline(always)] + unsafe fn utf8_three_byte(&self, high_ten: u32, low_six: u32) -> T { + self.get_utf8_three_byte(high_ten, low_six) + } + + #[inline(always)] + fn latin1(&self, latin1: u8) -> T { + self.get8(latin1) + } + + #[inline(always)] + fn bmp(&self, bmp: u16) -> T { + self.get16(bmp) + } + + #[inline(always)] + fn supplementary(&self, supplementary: u32) -> T { + self.get32_supplementary(supplementary) + } + + #[inline(always)] + fn scalar(&self, scalar: char) -> T { + self.get(scalar) + } + + #[inline(always)] + fn code_point(&self, code_point: u32) -> T { + self.get32(code_point) + } +} + #[cfg(test)] mod tests { use super::*; @@ -1828,6 +2329,144 @@ mod tests { assert_eq!(small.get('\u{10000}'), 1); } + #[test] + #[cfg(feature = "serde")] + fn test_serde_with_postcard_roundtrip_small() -> Result<(), postcard::Error> { + let untyped = planes::get_planes_trie(); + let trie = >::try_from(untyped.clone()).unwrap(); + + let trie_serialized: Vec = postcard::to_allocvec(&trie).unwrap(); + + // Assert an expected (golden data) version of the serialized trie. + const EXP_TRIE_SERIALIZED: &[u8] = &[ + 128, 128, 64, 128, 2, 2, 0, 0, 1, 160, 18, 0, 0, 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 136, + 2, 144, 2, 144, 2, 144, 2, 176, 2, 176, 2, 176, 2, 176, 2, 208, 2, 208, 2, 208, 2, 208, + 2, 240, 2, 240, 2, 240, 2, 240, 2, 16, 3, 16, 3, 16, 3, 16, 3, 48, 3, 48, 3, 48, 3, 48, + 3, 80, 3, 80, 3, 80, 3, 80, 3, 112, 3, 112, 3, 112, 3, 112, 3, 144, 3, 144, 3, 144, 3, + 144, 3, 176, 3, 176, 3, 176, 3, 176, 3, 208, 3, 208, 3, 208, 3, 208, 3, 240, 3, 240, 3, + 240, 3, 240, 3, 16, 4, 16, 4, 16, 4, 16, 4, 48, 4, 48, 4, 48, 4, 48, 4, 80, 4, 80, 4, + 80, 4, 80, 4, 112, 4, 112, 4, 112, 4, 112, 4, 0, 0, 16, 0, 32, 0, 48, 0, 64, 0, 80, 0, + 96, 0, 112, 0, 0, 0, 16, 0, 32, 0, 48, 0, 0, 0, 16, 0, 32, 0, 48, 0, 0, 0, 16, 0, 32, + 0, 48, 0, 0, 0, 16, 0, 32, 0, 48, 0, 0, 0, 16, 0, 32, 0, 48, 0, 0, 0, 16, 0, 32, 0, 48, + 0, 0, 0, 16, 0, 32, 0, 48, 0, 0, 0, 16, 0, 32, 0, 48, 0, 128, 0, 128, 0, 128, 0, 128, + 0, 128, 0, 128, 0, 128, 0, 128, 0, 128, 0, 128, 0, 128, 0, 128, 0, 128, 0, 128, 0, 128, + 0, 128, 0, 128, 0, 128, 0, 128, 0, 128, 0, 128, 0, 128, 0, 128, 0, 128, 0, 128, 0, 128, + 0, 128, 0, 128, 0, 128, 0, 128, 0, 128, 0, 128, 0, 144, 0, 144, 0, 144, 0, 144, 0, 144, + 0, 144, 0, 144, 0, 144, 0, 144, 0, 144, 0, 144, 0, 144, 0, 144, 0, 144, 0, 144, 0, 144, + 0, 144, 0, 144, 0, 144, 0, 144, 0, 144, 0, 144, 0, 144, 0, 144, 0, 144, 0, 144, 0, 144, + 0, 144, 0, 144, 0, 144, 0, 144, 0, 144, 0, 160, 0, 160, 0, 160, 0, 160, 0, 160, 0, 160, + 0, 160, 0, 160, 0, 160, 0, 160, 0, 160, 0, 160, 0, 160, 0, 160, 0, 160, 0, 160, 0, 160, + 0, 160, 0, 160, 0, 160, 0, 160, 0, 160, 0, 160, 0, 160, 0, 160, 0, 160, 0, 160, 0, 160, + 0, 160, 0, 160, 0, 160, 0, 160, 0, 176, 0, 176, 0, 176, 0, 176, 0, 176, 0, 176, 0, 176, + 0, 176, 0, 176, 0, 176, 0, 176, 0, 176, 0, 176, 0, 176, 0, 176, 0, 176, 0, 176, 0, 176, + 0, 176, 0, 176, 0, 176, 0, 176, 0, 176, 0, 176, 0, 176, 0, 176, 0, 176, 0, 176, 0, 176, + 0, 176, 0, 176, 0, 176, 0, 192, 0, 192, 0, 192, 0, 192, 0, 192, 0, 192, 0, 192, 0, 192, + 0, 192, 0, 192, 0, 192, 0, 192, 0, 192, 0, 192, 0, 192, 0, 192, 0, 192, 0, 192, 0, 192, + 0, 192, 0, 192, 0, 192, 0, 192, 0, 192, 0, 192, 0, 192, 0, 192, 0, 192, 0, 192, 0, 192, + 0, 192, 0, 192, 0, 208, 0, 208, 0, 208, 0, 208, 0, 208, 0, 208, 0, 208, 0, 208, 0, 208, + 0, 208, 0, 208, 0, 208, 0, 208, 0, 208, 0, 208, 0, 208, 0, 208, 0, 208, 0, 208, 0, 208, + 0, 208, 0, 208, 0, 208, 0, 208, 0, 208, 0, 208, 0, 208, 0, 208, 0, 208, 0, 208, 0, 208, + 0, 208, 0, 224, 0, 224, 0, 224, 0, 224, 0, 224, 0, 224, 0, 224, 0, 224, 0, 224, 0, 224, + 0, 224, 0, 224, 0, 224, 0, 224, 0, 224, 0, 224, 0, 224, 0, 224, 0, 224, 0, 224, 0, 224, + 0, 224, 0, 224, 0, 224, 0, 224, 0, 224, 0, 224, 0, 224, 0, 224, 0, 224, 0, 224, 0, 224, + 0, 240, 0, 240, 0, 240, 0, 240, 0, 240, 0, 240, 0, 240, 0, 240, 0, 240, 0, 240, 0, 240, + 0, 240, 0, 240, 0, 240, 0, 240, 0, 240, 0, 240, 0, 240, 0, 240, 0, 240, 0, 240, 0, 240, + 0, 240, 0, 240, 0, 240, 0, 240, 0, 240, 0, 240, 0, 240, 0, 240, 0, 240, 0, 240, 0, 0, + 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, + 1, 0, 1, 0, 1, 16, 1, 16, 1, 16, 1, 16, 1, 16, 1, 16, 1, 16, 1, 16, 1, 16, 1, 16, 1, + 16, 1, 16, 1, 16, 1, 16, 1, 16, 1, 16, 1, 16, 1, 16, 1, 16, 1, 16, 1, 16, 1, 16, 1, 16, + 1, 16, 1, 16, 1, 16, 1, 16, 1, 16, 1, 16, 1, 16, 1, 16, 1, 16, 1, 32, 1, 32, 1, 32, 1, + 32, 1, 32, 1, 32, 1, 32, 1, 32, 1, 32, 1, 32, 1, 32, 1, 32, 1, 32, 1, 32, 1, 32, 1, 32, + 1, 32, 1, 32, 1, 32, 1, 32, 1, 32, 1, 32, 1, 32, 1, 32, 1, 32, 1, 32, 1, 32, 1, 32, 1, + 32, 1, 32, 1, 32, 1, 32, 1, 48, 1, 48, 1, 48, 1, 48, 1, 48, 1, 48, 1, 48, 1, 48, 1, 48, + 1, 48, 1, 48, 1, 48, 1, 48, 1, 48, 1, 48, 1, 48, 1, 48, 1, 48, 1, 48, 1, 48, 1, 48, 1, + 48, 1, 48, 1, 48, 1, 48, 1, 48, 1, 48, 1, 48, 1, 48, 1, 48, 1, 48, 1, 48, 1, 64, 1, 64, + 1, 64, 1, 64, 1, 64, 1, 64, 1, 64, 1, 64, 1, 64, 1, 64, 1, 64, 1, 64, 1, 64, 1, 64, 1, + 64, 1, 64, 1, 64, 1, 64, 1, 64, 1, 64, 1, 64, 1, 64, 1, 64, 1, 64, 1, 64, 1, 64, 1, 64, + 1, 64, 1, 64, 1, 64, 1, 64, 1, 64, 1, 80, 1, 80, 1, 80, 1, 80, 1, 80, 1, 80, 1, 80, 1, + 80, 1, 80, 1, 80, 1, 80, 1, 80, 1, 80, 1, 80, 1, 80, 1, 80, 1, 80, 1, 80, 1, 80, 1, 80, + 1, 80, 1, 80, 1, 80, 1, 80, 1, 80, 1, 80, 1, 80, 1, 80, 1, 80, 1, 80, 1, 80, 1, 80, 1, + 96, 1, 96, 1, 96, 1, 96, 1, 96, 1, 96, 1, 96, 1, 96, 1, 96, 1, 96, 1, 96, 1, 96, 1, 96, + 1, 96, 1, 96, 1, 96, 1, 96, 1, 96, 1, 96, 1, 96, 1, 96, 1, 96, 1, 96, 1, 96, 1, 96, 1, + 96, 1, 96, 1, 96, 1, 96, 1, 96, 1, 96, 1, 96, 1, 128, 0, 136, 0, 136, 0, 136, 0, 136, + 0, 136, 0, 136, 0, 136, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, + 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, + 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 168, 0, 168, 0, 168, 0, 168, 0, 168, 0, + 168, 0, 168, 0, 168, 0, 168, 0, 168, 0, 168, 0, 168, 0, 168, 0, 168, 0, 168, 0, 168, 0, + 168, 0, 168, 0, 168, 0, 168, 0, 168, 0, 168, 0, 168, 0, 168, 0, 168, 0, 168, 0, 168, 0, + 168, 0, 168, 0, 168, 0, 168, 0, 168, 0, 200, 0, 200, 0, 200, 0, 200, 0, 200, 0, 200, 0, + 200, 0, 200, 0, 200, 0, 200, 0, 200, 0, 200, 0, 200, 0, 200, 0, 200, 0, 200, 0, 200, 0, + 200, 0, 200, 0, 200, 0, 200, 0, 200, 0, 200, 0, 200, 0, 200, 0, 200, 0, 200, 0, 200, 0, + 200, 0, 200, 0, 200, 0, 200, 0, 232, 0, 232, 0, 232, 0, 232, 0, 232, 0, 232, 0, 232, 0, + 232, 0, 232, 0, 232, 0, 232, 0, 232, 0, 232, 0, 232, 0, 232, 0, 232, 0, 232, 0, 232, 0, + 232, 0, 232, 0, 232, 0, 232, 0, 232, 0, 232, 0, 232, 0, 232, 0, 232, 0, 232, 0, 232, 0, + 232, 0, 232, 0, 232, 0, 8, 1, 8, 1, 8, 1, 8, 1, 8, 1, 8, 1, 8, 1, 8, 1, 8, 1, 8, 1, 8, + 1, 8, 1, 8, 1, 8, 1, 8, 1, 8, 1, 8, 1, 8, 1, 8, 1, 8, 1, 8, 1, 8, 1, 8, 1, 8, 1, 8, 1, + 8, 1, 8, 1, 8, 1, 8, 1, 8, 1, 8, 1, 8, 1, 40, 1, 40, 1, 40, 1, 40, 1, 40, 1, 40, 1, 40, + 1, 40, 1, 40, 1, 40, 1, 40, 1, 40, 1, 40, 1, 40, 1, 40, 1, 40, 1, 40, 1, 40, 1, 40, 1, + 40, 1, 40, 1, 40, 1, 40, 1, 40, 1, 40, 1, 40, 1, 40, 1, 40, 1, 40, 1, 40, 1, 40, 1, 40, + 1, 72, 1, 72, 1, 72, 1, 72, 1, 72, 1, 72, 1, 72, 1, 72, 1, 72, 1, 72, 1, 72, 1, 72, 1, + 72, 1, 72, 1, 72, 1, 72, 1, 72, 1, 72, 1, 72, 1, 72, 1, 72, 1, 72, 1, 72, 1, 72, 1, 72, + 1, 72, 1, 72, 1, 72, 1, 72, 1, 72, 1, 72, 1, 72, 1, 104, 1, 104, 1, 104, 1, 104, 1, + 104, 1, 104, 1, 104, 1, 104, 1, 104, 1, 104, 1, 104, 1, 104, 1, 104, 1, 104, 1, 104, 1, + 104, 1, 104, 1, 104, 1, 104, 1, 104, 1, 104, 1, 104, 1, 104, 1, 104, 1, 104, 1, 104, 1, + 104, 1, 104, 1, 104, 1, 104, 1, 104, 1, 104, 1, 136, 1, 136, 1, 136, 1, 136, 1, 136, 1, + 136, 1, 136, 1, 136, 1, 136, 1, 136, 1, 136, 1, 136, 1, 136, 1, 136, 1, 136, 1, 136, 1, + 136, 1, 136, 1, 136, 1, 136, 1, 136, 1, 136, 1, 136, 1, 136, 1, 136, 1, 136, 1, 136, 1, + 136, 1, 136, 1, 136, 1, 136, 1, 136, 1, 168, 1, 168, 1, 168, 1, 168, 1, 168, 1, 168, 1, + 168, 1, 168, 1, 168, 1, 168, 1, 168, 1, 168, 1, 168, 1, 168, 1, 168, 1, 168, 1, 168, 1, + 168, 1, 168, 1, 168, 1, 168, 1, 168, 1, 168, 1, 168, 1, 168, 1, 168, 1, 168, 1, 168, 1, + 168, 1, 168, 1, 168, 1, 168, 1, 200, 1, 200, 1, 200, 1, 200, 1, 200, 1, 200, 1, 200, 1, + 200, 1, 200, 1, 200, 1, 200, 1, 200, 1, 200, 1, 200, 1, 200, 1, 200, 1, 200, 1, 200, 1, + 200, 1, 200, 1, 200, 1, 200, 1, 200, 1, 200, 1, 200, 1, 200, 1, 200, 1, 200, 1, 200, 1, + 200, 1, 200, 1, 200, 1, 232, 1, 232, 1, 232, 1, 232, 1, 232, 1, 232, 1, 232, 1, 232, 1, + 232, 1, 232, 1, 232, 1, 232, 1, 232, 1, 232, 1, 232, 1, 232, 1, 232, 1, 232, 1, 232, 1, + 232, 1, 232, 1, 232, 1, 232, 1, 232, 1, 232, 1, 232, 1, 232, 1, 232, 1, 232, 1, 232, 1, + 232, 1, 232, 1, 8, 2, 8, 2, 8, 2, 8, 2, 8, 2, 8, 2, 8, 2, 8, 2, 8, 2, 8, 2, 8, 2, 8, 2, + 8, 2, 8, 2, 8, 2, 8, 2, 8, 2, 8, 2, 8, 2, 8, 2, 8, 2, 8, 2, 8, 2, 8, 2, 8, 2, 8, 2, 8, + 2, 8, 2, 8, 2, 8, 2, 8, 2, 8, 2, 40, 2, 40, 2, 40, 2, 40, 2, 40, 2, 40, 2, 40, 2, 40, + 2, 40, 2, 40, 2, 40, 2, 40, 2, 40, 2, 40, 2, 40, 2, 40, 2, 40, 2, 40, 2, 40, 2, 40, 2, + 40, 2, 40, 2, 40, 2, 40, 2, 40, 2, 40, 2, 40, 2, 40, 2, 40, 2, 40, 2, 40, 2, 40, 2, 72, + 2, 72, 2, 72, 2, 72, 2, 72, 2, 72, 2, 72, 2, 72, 2, 72, 2, 72, 2, 72, 2, 72, 2, 72, 2, + 72, 2, 72, 2, 72, 2, 72, 2, 72, 2, 72, 2, 72, 2, 72, 2, 72, 2, 72, 2, 72, 2, 72, 2, 72, + 2, 72, 2, 72, 2, 72, 2, 72, 2, 72, 2, 72, 2, 104, 2, 104, 2, 104, 2, 104, 2, 104, 2, + 104, 2, 104, 2, 104, 2, 104, 2, 104, 2, 104, 2, 104, 2, 104, 2, 104, 2, 104, 2, 104, 2, + 104, 2, 104, 2, 104, 2, 104, 2, 104, 2, 104, 2, 104, 2, 104, 2, 104, 2, 104, 2, 104, 2, + 104, 2, 104, 2, 104, 2, 104, 2, 104, 2, 244, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 10, + 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, + 11, 11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, + 12, 12, 12, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 14, 14, 14, + 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 15, 15, 15, + 15, 15, 15, 15, 15, 15, 15, 16, 16, 16, 0, + ]; + assert_eq!(trie_serialized, EXP_TRIE_SERIALIZED); + + let trie_deserialized = postcard::from_bytes::>(&trie_serialized)?; + + let trie_deserialized_untyped = trie_deserialized.as_untyped_ref(); + assert_eq!(&untyped.index, &trie_deserialized_untyped.index); + assert_eq!(&untyped.data, &trie_deserialized_untyped.data); + + assert!(!trie_deserialized_untyped.index.is_owned()); + assert!(!trie_deserialized_untyped.data.is_owned()); + + Ok(()) + } + #[test] fn test_get_range() { let planes_trie = planes::get_planes_trie(); @@ -1894,4 +2533,56 @@ mod tests { [zerovec], ); } + + #[test] + #[allow(unused_unsafe)] // `unsafe` below is both necessary and unnecessary + fn databake_small() { + databake::test_bake!( + SmallCodePointTrie<'static, u32>, + const, + unsafe { + crate::codepointtrie::SmallCodePointTrie::from_parts_unstable_unchecked_v1( + crate::codepointtrie::CodePointTrieHeader { + high_start: 1u32, + shifted12_high_start: 2u16, + index3_null_offset: 3u16, + data_null_offset: 4u32, + null_value: 5u32, + trie_type: crate::codepointtrie::TrieType::Small, + }, + zerovec::ZeroVec::new(), + zerovec::ZeroVec::new(), + 0u32, + ) + }, + icu_collections, + [zerovec], + ); + } + + #[test] + #[allow(unused_unsafe)] // `unsafe` below is both necessary and unnecessary + fn databake_fast() { + databake::test_bake!( + FastCodePointTrie<'static, u32>, + const, + unsafe { + crate::codepointtrie::FastCodePointTrie::from_parts_unstable_unchecked_v1( + crate::codepointtrie::CodePointTrieHeader { + high_start: 1u32, + shifted12_high_start: 2u16, + index3_null_offset: 3u16, + data_null_offset: 4u32, + null_value: 5u32, + trie_type: crate::codepointtrie::TrieType::Fast, + }, + zerovec::ZeroVec::new(), + zerovec::ZeroVec::new(), + 0u32, + ) + }, + icu_collections, + [zerovec], + ); + } } diff --git a/components/collections/src/codepointtrie/error.rs b/components/collections/src/codepointtrie/error.rs index 383949d9561..4168c3f1de8 100644 --- a/components/collections/src/codepointtrie/error.rs +++ b/components/collections/src/codepointtrie/error.rs @@ -25,6 +25,9 @@ pub enum Error { /// [`CodePointTrie`](super::CodePointTrie) must be constructed from data vector long enough to accommodate fast-path access #[displaydoc("CodePointTrie must be constructed from data vector long enough to accommodate fast-path access")] DataTooShortForFastAccess, + /// [`CodePointTrie`](super::CodePointTrie) must be constructed from data vector long enough to accommodate direct ASCII access + #[displaydoc("CodePointTrie must be constructed from data vector long enough to accommodate direct ASCII access")] + DataTooShortForAsciiAccess, } impl core::error::Error for Error {} diff --git a/components/collections/src/codepointtrie/iter.rs b/components/collections/src/codepointtrie/iter.rs new file mode 100644 index 00000000000..900fa5bd029 --- /dev/null +++ b/components/collections/src/codepointtrie/iter.rs @@ -0,0 +1,1288 @@ +// This file is part of ICU4X. For terms of use, please see the file +// called LICENSE at the top level of the ICU4X source tree +// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ). + +use core::iter::FusedIterator; +use core::marker::PhantomData; + +use crate::codepointtrie::AbstractCodePointTrie; +use crate::codepointtrie::TrieValue; + +/// Provides a trie accessor for types (likely iterators) +/// that are holding a reference to a type that implements +/// `AbstractCodePointTrie`. +pub trait WithTrie<'trie, T, V> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, +{ + /// Get a reference to the trie. + fn trie(&self) -> &'trie T; +} + +/// Iterator over `str` by `char` and `TrieValue`. +#[derive(Debug)] +pub struct CharsWithTrie<'slice, 'trie, T, V> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, +{ + /// # Safety Invariant + /// + /// `delegate` must represent a slice of well-formed UTF-8, + /// except temporarily within the implementation of `next()` and + /// `next_back()`. In particular, `next()` and `next_back()` may + /// assume that this invariant holds upon entry into `next()`/`next_back()` + /// and the invariant must hold upon return from `next()`/`next_back()`. + delegate: core::slice::Iter<'slice, u8>, + trie: &'trie T, + phantom: PhantomData, +} + +impl<'slice, 'trie, T, V> CharsWithTrie<'slice, 'trie, T, V> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, +{ + /// Construct a new `CharsWithTrie`. + #[inline] + pub fn new(s: &'slice str, trie: &'trie T) -> Self { + Self { + // Field invariant upheld: `s` is well-formed UTF-8, + // so `delegate` is as well. + delegate: s.as_bytes().iter(), + trie, + phantom: PhantomData, + } + } + + /// Obtains the remainder of the iterator as a string slice. + #[inline] + pub fn as_str(&self) -> &'slice str { + // SAFETY: OK due to field invariant of `delegate` guaranteeing + // well-formed UTF-8. + unsafe { core::str::from_utf8_unchecked(self.delegate.as_slice()) } + } +} + +impl<'slice, 'trie, T, V> Clone for CharsWithTrie<'slice, 'trie, T, V> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, +{ + #[inline] + fn clone(&self) -> Self { + Self { + // Field invariant upheld: Clone of well-formed UTF-8 is well-formed UTF-8. + delegate: self.delegate.clone(), + trie: self.trie, + phantom: PhantomData, + } + } +} + +impl<'slice, 'trie, T, V> WithTrie<'trie, T, V> for CharsWithTrie<'slice, 'trie, T, V> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, +{ + #[inline] + fn trie(&self) -> &'trie T { + self.trie + } +} + +impl<'slice, 'trie, T, V> Iterator for CharsWithTrie<'slice, 'trie, T, V> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, +{ + type Item = (char, V); + + #[inline] + fn next(&mut self) -> Option { + let lead = *self.delegate.next()?; + if lead < 0x80 { + // SAFETY: We checked the invariant of `ascii` immediately + // above. + // Field invariant upheld: `delegate` is again well-formed UTF-8, because we consumed + // a complete one-byte sequence. + return Some((char::from(lead), unsafe { self.trie.ascii(lead) })); + } + // SAFETY: Due to the field invariant of `delegate`, we may assume that we + // have a valid lead byte. Not need to check for other cases. + if lead < 0xE0 { + // Two-byte sequence. + // SAFETY: The field invariant of `delegate` providing UTF-8 well-formedness upon entry into this method and + // having seen a two-byte lead we may assume the presence of a trail byte. + let trail = *unsafe { self.delegate.next().unwrap_unchecked() }; + let high_five = u32::from(lead & 0b11_111); + let low_six = u32::from(trail & 0b111_111); + // SAFETY: By construction and the two above safety remarks, `high_five` and + // `low_six` conform to the invariant of `utf8_two_byte`. + let v = unsafe { self.trie.utf8_two_byte(high_five, low_six) }; + // SAFETY: Due to the field invariant of `delegate`, `lead` must be a + // valid (not overlong) two-byte lead and `trail` must be a valid + // trail. Therefore, the following shift and OR stays in the + // scalar value range. + let c = unsafe { char::from_u32_unchecked((high_five << 6) | low_six) }; + // Field invariant upheld: `delegate` is again well-formed UTF-8, because we consumed + // a complete two-byte sequence. + return Some((c, v)); + } + if lead < 0xF0 { + // Three-byte sequence. + // SAFETY: The field invariant of `delegate` providing UTF-8 well-formedness upon entry into this method and + // having seen a two-byte lead we may assume the presence of two trail bytes. + let second = *unsafe { self.delegate.next().unwrap_unchecked() }; + let third = *unsafe { self.delegate.next().unwrap_unchecked() }; + let high_ten = (u32::from(lead & 0b1111) << 6) | u32::from(second & 0b111_111); + let low_six = u32::from(third & 0b111_111); + // SAFETY: By construction and the safety remarks on the path to this point, + // `high_ten` and `low_six` conform to the invariant of `utf8_three_byte`. + let v = unsafe { self.trie.utf8_three_byte(high_ten, low_six) }; + // SAFETY: Due to the field invariant of `delegate`, `lead` must be a + // valid (not overlong) three-byte lead and `second` and `third` + // must be valid trails. Therefore, the following shift and OR + // stays in the scalar value range. + let c = unsafe { char::from_u32_unchecked((high_ten << 6) | low_six) }; + // Field invariant upheld: `delegate` is again well-formed UTF-8, because we consumed + // a complete three-byte sequence. + return Some((c, v)); + } + // Four-byte sequence + // SAFETY: The field invariant of `delegate` providing UTF-8 well-formedness upon entry into this method and + // having seen a two-byte lead we may assume the presence of three trail bytes. + let second = *unsafe { self.delegate.next().unwrap_unchecked() }; + let third = *unsafe { self.delegate.next().unwrap_unchecked() }; + let fourth = *unsafe { self.delegate.next().unwrap_unchecked() }; + // SAFETY: Due to the field invariant of `delegate`, `lead` must be a + // valid (not overlong or out-of-range) four-byte lead and `second`, + // `third`, and `fourth` must be valid trails. Therefore, the + // following shift and OR stays in the scalar value range. + let c = unsafe { + char::from_u32_unchecked( + (u32::from(lead & 0b111) << 18) + | (u32::from(second & 0b111_111) << 12) + | (u32::from(third & 0b111_111) << 6) + | u32::from(fourth & 0b111_111), + ) + }; + // Field invariant upheld: `delegate` is again well-formed UTF-8, because we consumed + // a complete four-byte sequence. + Some((c, self.trie.supplementary(c as u32))) + } + + #[inline] + fn count(self) -> usize { + self.as_str().chars().count() + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.as_str().chars().size_hint() + } + + #[inline] + fn last(mut self) -> Option { + self.next_back() + } + + // TODO: Delegate advance_by to `Chars` once stabilized. +} + +impl<'slice, 'trie, T, V> DoubleEndedIterator for CharsWithTrie<'slice, 'trie, T, V> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, +{ + #[inline] + fn next_back(&mut self) -> Option { + let last = *self.delegate.next_back()?; + if last < 0x80 { + // SAFETY: We checked the invariant of `ascii` immediately + // above. + // Field invariant upheld: `delegate` is again well-formed UTF-8, because we consumed + // a complete one-byte sequence. + return Some((char::from(last), unsafe { self.trie.ascii(last) })); + } + // SAFETY Due to the field invariant of `delegate`, + // `last` must be a valid trail byte and it is preceded either by a lead byte for a + // two-byte sequence or by another trail byte. + let second_last = *unsafe { self.delegate.next_back().unwrap_unchecked() }; + if second_last >= 0b1100_0000 { + // Two-byte sequence. + let high_five = u32::from(second_last & 0b11_111); + let low_six = u32::from(last & 0b111_111); + // SAFETY: By construction and the two above safety remarks, `high_five` and `low_six` conform + // to the invariant of `utf8_two_byte`. + let v = unsafe { self.trie.utf8_two_byte(high_five, low_six) }; + // SAFETY: Due to the field invariant of `delegate`, `second_last` must be a + // valid (not overlong) two-byte lead and `last` must be a valid + // trail. Therefore, the following shift and OR stays in the + // scalar value range. + let c = unsafe { char::from_u32_unchecked((high_five << 6) | low_six) }; + // Field invariant upheld: `delegate` is again well-formed UTF-8, because we consumed + // a complete two-byte sequence. + return Some((c, v)); + } + // SAFETY: Due to the field invariant of `delegate`, + // `second_last` must be a valid trail byte and it is preceded either by a lead byte for a + // three-byte sequence or by another trail byte. + let third_last = *unsafe { self.delegate.next_back().unwrap_unchecked() }; + if third_last >= 0b1100_0000 { + // Three-byte sequence + let high_ten = + (u32::from(third_last & 0b1111) << 6) | u32::from(second_last & 0b111_111); + let low_six = u32::from(last & 0b111_111); + // SAFETY: By construction and the safety remarks on the path to this point, `high_ten` and `low_six` conform + // to the invariant of `utf8_three_byte`. + let v = unsafe { self.trie.utf8_three_byte(high_ten, low_six) }; + // SAFETY: Due to the field invariant of `delegate`, `third_last` must be a + // valid (not overlong) three-byte lead and `second_last` and `last` + // must be valid trails. Therefore, the following shift and OR + // stays in the scalar value range. + let c = unsafe { char::from_u32_unchecked((high_ten << 6) | low_six) }; + // Field invariant upheld: `delegate` is again well-formed UTF-8, because we consumed + // a complete three-byte sequence. + return Some((c, v)); + } + // Four-byte sequence + // SAFETY: Due to the field invariant of `delegate`, we may assume the + // presence of a lead byte. + let lead = *unsafe { self.delegate.next_back().unwrap_unchecked() }; + // SAFETY: Due to the field invariant of `delegate`, `lead` must be a + // valid (not overlong or out-of-range) four-byte lead and `third_last`, + // `second_last`, and `last` must be valid trails. Therefore, the + // following shift and OR stays in the scalar value range. + let c = unsafe { + char::from_u32_unchecked( + (u32::from(lead & 0b111) << 18) + | (u32::from(third_last & 0b111_111) << 12) + | (u32::from(second_last & 0b111_111) << 6) + | u32::from(last & 0b111_111), + ) + }; + // Field invariant upheld: `delegate` is again well-formed UTF-8, because we consumed + // a complete four-byte sequence. + Some((c, self.trie.supplementary(c as u32))) + } +} + +impl<'slice, 'trie, T, V> FusedIterator for CharsWithTrie<'slice, 'trie, T, V> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, +{ +} +// -- + +/// Iterator over `str` by `char` and `TrieValue`. +#[derive(Debug)] +pub struct CharIndicesWithTrie<'slice, 'trie, T, V> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, +{ + offset: usize, + delegate: CharsWithTrie<'slice, 'trie, T, V>, +} + +impl<'slice, 'trie, T, V> CharIndicesWithTrie<'slice, 'trie, T, V> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, +{ + /// Construct a new `CharIndicesWithTrie`. + #[inline] + pub fn new(s: &'slice str, trie: &'trie T) -> Self { + Self { + offset: 0, + delegate: CharsWithTrie::new(s, trie), + } + } + + /// Obtains the remainder of the iterator as a string slice. + #[inline] + pub fn as_str(&self) -> &'slice str { + self.delegate.as_str() + } +} + +impl<'slice, 'trie, T, V> Clone for CharIndicesWithTrie<'slice, 'trie, T, V> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, +{ + #[inline] + fn clone(&self) -> Self { + Self { + offset: self.offset, + delegate: self.delegate.clone(), + } + } +} + +impl<'slice, 'trie, T, V> WithTrie<'trie, T, V> for CharIndicesWithTrie<'slice, 'trie, T, V> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, +{ + #[inline] + fn trie(&self) -> &'trie T { + self.delegate.trie() + } +} + +impl<'slice, 'trie, T, V> Iterator for CharIndicesWithTrie<'slice, 'trie, T, V> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, +{ + type Item = (usize, char, V); + + #[inline] + fn next(&mut self) -> Option { + let old_len = self.as_str().len(); + let (c, v) = self.delegate.next()?; + let old_offset = self.offset; + self.offset += old_len - self.as_str().len(); + Some((old_offset, c, v)) + } + + #[inline] + fn count(self) -> usize { + self.as_str().chars().count() + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.as_str().chars().size_hint() + } + + #[inline] + fn last(mut self) -> Option { + self.next_back() + } + + // TODO: Delegate advance_by to `Chars` once stabilized. +} + +impl<'slice, 'trie, T, V> DoubleEndedIterator for CharIndicesWithTrie<'slice, 'trie, T, V> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, +{ + #[inline] + fn next_back(&mut self) -> Option { + let (c, v) = self.delegate.next_back()?; + Some((self.offset + self.as_str().len(), c, v)) + } +} + +impl<'slice, 'trie, T, V> FusedIterator for CharIndicesWithTrie<'slice, 'trie, T, V> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, +{ +} + +// -- + +/// Adds convenience methods to `str`. +pub trait CharsWithTrieEx<'slice, 'trie, T, V> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, +{ + /// Method for easily creating `CharsWithTrie` on `str` analogously to `chars()`. + fn chars_with_trie(&'slice self, trie: &'trie T) -> CharsWithTrie<'slice, 'trie, T, V>; + + /// Method for easily creating `CharIndicesWithTrie` on `str` analogously to `char_indices()`. + fn char_indices_with_trie( + &'slice self, + trie: &'trie T, + ) -> CharIndicesWithTrie<'slice, 'trie, T, V>; +} + +impl<'slice, 'trie, T, V> CharsWithTrieEx<'slice, 'trie, T, V> for str +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, +{ + /// Method for easily creating `CharsWithTrie` on `str` analogously to `chars()`. + #[inline] + fn chars_with_trie(&'slice self, trie: &'trie T) -> CharsWithTrie<'slice, 'trie, T, V> { + CharsWithTrie::new(self, trie) + } + + /// Method for easily creating `CharIndicesWithTrie` on `str` analogously to `char_indices()`. + #[inline] + fn char_indices_with_trie( + &'slice self, + trie: &'trie T, + ) -> CharIndicesWithTrie<'slice, 'trie, T, V> { + CharIndicesWithTrie::new(self, trie) + } +} + +// -- + +/// Iterator over `str` by `char` and `TrieValue` but +/// the trie value for ASCII is `V::default()` instead of +/// reading from the trie. (`V::default()` can be optimized +/// on at compile time while reading the trie's default value +/// is a run-time operation.) +#[derive(Debug)] +pub struct CharsWithTrieDefaultForAscii<'slice, 'trie, T, V> +where + V: TrieValue + Default, + T: AbstractCodePointTrie<'trie, V>, +{ + /// # Safety Invariant + /// + /// `delegate` must represent a slice of well-formed UTF-8, + /// except temporarily within the implementation of `next()` and + /// `next_back()`. In particular, `next()` and `next_back()` may + /// assume that this invariant holds upon entry into `next()`/`next_back()` + /// and the invariant must hold upon return from `next()`/`next_back()`. + /// + /// Note: All-safety-relavant code is copypaste from `CharsWithTrie`. + delegate: core::slice::Iter<'slice, u8>, + trie: &'trie T, + phantom: PhantomData, +} + +impl<'slice, 'trie, T, V> CharsWithTrieDefaultForAscii<'slice, 'trie, T, V> +where + V: TrieValue + Default, + T: AbstractCodePointTrie<'trie, V>, +{ + /// Construct a new `CharsWithTrieDefaultForAscii`. + #[inline] + pub fn new(s: &'slice str, trie: &'trie T) -> Self { + Self { + // Field invariant upheld: `s` is well-formed UTF-8, + // so `delegate` is as well. + delegate: s.as_bytes().iter(), + trie, + phantom: PhantomData, + } + } + + /// Obtains the remainder of the iterator as a string slice. + #[inline] + pub fn as_str(&self) -> &'slice str { + // SAFETY: OK due to field invariant of `delegate` guaranteeing + // well-formed UTF-8. + unsafe { core::str::from_utf8_unchecked(self.delegate.as_slice()) } + } +} + +impl<'slice, 'trie, T, V> Clone for CharsWithTrieDefaultForAscii<'slice, 'trie, T, V> +where + V: TrieValue + Default, + T: AbstractCodePointTrie<'trie, V>, +{ + #[inline] + fn clone(&self) -> Self { + Self { + // Field invariant upheld: Clone of well-formed UTF-8 is well-formed UTF-8. + delegate: self.delegate.clone(), + trie: self.trie, + phantom: PhantomData, + } + } +} + +impl<'slice, 'trie, T, V> WithTrie<'trie, T, V> + for CharsWithTrieDefaultForAscii<'slice, 'trie, T, V> +where + V: TrieValue + Default, + T: AbstractCodePointTrie<'trie, V>, +{ + #[inline] + fn trie(&self) -> &'trie T { + self.trie + } +} + +impl<'slice, 'trie, T, V> Iterator for CharsWithTrieDefaultForAscii<'slice, 'trie, T, V> +where + V: TrieValue + Default, + T: AbstractCodePointTrie<'trie, V>, +{ + type Item = (char, V); + + #[inline] + fn next(&mut self) -> Option { + let lead = *self.delegate.next()?; + if lead < 0x80 { + // Field invariant upheld: `delegate` is again well-formed UTF-8, because we consumed + // a complete one-byte sequence. + return Some((char::from(lead), V::default())); + } + // SAFETY: Due to the field invariant of `delegate`, we may assume that we + // have a valid lead byte. Not need to check for other cases. + if lead < 0xE0 { + // Two-byte sequence. + // SAFETY: The field invariant of `delegate` providing UTF-8 well-formedness upon entry into this method and + // having seen a two-byte lead we may assume the presence of a trail byte. + let trail = *unsafe { self.delegate.next().unwrap_unchecked() }; + let high_five = u32::from(lead & 0b11_111); + let low_six = u32::from(trail & 0b111_111); + // SAFETY: By construction and the two above safety remarks, `high_five` and + // `low_six` conform to the invariant of `utf8_two_byte`. + let v = unsafe { self.trie.utf8_two_byte(high_five, low_six) }; + // SAFETY: Due to the field invariant of `delegate`, `lead` must be a + // valid (not overlong) two-byte lead and `trail` must be a valid + // trail. Therefore, the following shift and OR stays in the + // scalar value range. + let c = unsafe { char::from_u32_unchecked((high_five << 6) | low_six) }; + // Field invariant upheld: `delegate` is again well-formed UTF-8, because we consumed + // a complete two-byte sequence. + return Some((c, v)); + } + if lead < 0xF0 { + // Three-byte sequence. + // SAFETY: The field invariant of `delegate` providing UTF-8 well-formedness upon entry into this method and + // having seen a two-byte lead we may assume the presence of two trail bytes. + let second = *unsafe { self.delegate.next().unwrap_unchecked() }; + let third = *unsafe { self.delegate.next().unwrap_unchecked() }; + let high_ten = (u32::from(lead & 0b1111) << 6) | u32::from(second & 0b111_111); + let low_six = u32::from(third & 0b111_111); + // SAFETY: By construction and the safety remarks on the path to this point, + // `high_ten` and `low_six` conform to the invariant of `utf8_three_byte`. + let v = unsafe { self.trie.utf8_three_byte(high_ten, low_six) }; + // SAFETY: Due to the field invariant of `delegate`, `lead` must be a + // valid (not overlong) three-byte lead and `second` and `third` + // must be valid trails. Therefore, the following shift and OR + // stays in the scalar value range. + let c = unsafe { char::from_u32_unchecked((high_ten << 6) | low_six) }; + // Field invariant upheld: `delegate` is again well-formed UTF-8, because we consumed + // a complete three-byte sequence. + return Some((c, v)); + } + // Four-byte sequence + // SAFETY: The field invariant of `delegate` providing UTF-8 well-formedness upon entry into this method and + // having seen a two-byte lead we may assume the presence of three trail bytes. + let second = *unsafe { self.delegate.next().unwrap_unchecked() }; + let third = *unsafe { self.delegate.next().unwrap_unchecked() }; + let fourth = *unsafe { self.delegate.next().unwrap_unchecked() }; + // SAFETY: Due to the field invariant of `delegate`, `lead` must be a + // valid (not overlong or out-of-range) four-byte lead and `second`, + // `third`, and `fourth` must be valid trails. Therefore, the + // following shift and OR stays in the scalar value range. + let c = unsafe { + char::from_u32_unchecked( + (u32::from(lead & 0b111) << 18) + | (u32::from(second & 0b111_111) << 12) + | (u32::from(third & 0b111_111) << 6) + | u32::from(fourth & 0b111_111), + ) + }; + // Field invariant upheld: `delegate` is again well-formed UTF-8, because we consumed + // a complete four-byte sequence. + Some((c, self.trie.supplementary(c as u32))) + } + + #[inline] + fn count(self) -> usize { + self.as_str().chars().count() + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.as_str().chars().size_hint() + } + + #[inline] + fn last(mut self) -> Option { + self.next_back() + } + + // TODO: Delegate advance_by to `Chars` once stabilized. +} + +impl<'slice, 'trie, T, V> DoubleEndedIterator for CharsWithTrieDefaultForAscii<'slice, 'trie, T, V> +where + V: TrieValue + Default, + T: AbstractCodePointTrie<'trie, V>, +{ + #[inline] + fn next_back(&mut self) -> Option { + let last = *self.delegate.next_back()?; + if last < 0x80 { + // Field invariant upheld: `delegate` is again well-formed UTF-8, because we consumed + // a complete one-byte sequence. + return Some((char::from(last), V::default())); + } + // SAFETY Due to the field invariant of `delegate`, + // `last` must be a valid trail byte and it is preceded either by a lead byte for a + // two-byte sequence or by another trail byte. + let second_last = *unsafe { self.delegate.next_back().unwrap_unchecked() }; + if second_last >= 0b1100_0000 { + // Two-byte sequence. + let high_five = u32::from(second_last & 0b11_111); + let low_six = u32::from(last & 0b111_111); + // SAFETY: By construction and the two above safety remarks, `high_five` and `low_six` conform + // to the invariant of `utf8_two_byte`. + let v = unsafe { self.trie.utf8_two_byte(high_five, low_six) }; + // SAFETY: Due to the field invariant of `delegate`, `second_last` must be a + // valid (not overlong) two-byte lead and `last` must be a valid + // trail. Therefore, the following shift and OR stays in the + // scalar value range. + let c = unsafe { char::from_u32_unchecked((high_five << 6) | low_six) }; + // Field invariant upheld: `delegate` is again well-formed UTF-8, because we consumed + // a complete two-byte sequence. + return Some((c, v)); + } + // SAFETY: Due to the field invariant of `delegate`, + // `second_last` must be a valid trail byte and it is preceded either by a lead byte for a + // three-byte sequence or by another trail byte. + let third_last = *unsafe { self.delegate.next_back().unwrap_unchecked() }; + if third_last >= 0b1100_0000 { + // Three-byte sequence + let high_ten = + (u32::from(third_last & 0b1111) << 6) | u32::from(second_last & 0b111_111); + let low_six = u32::from(last & 0b111_111); + // SAFETY: By construction and the safety remarks on the path to this point, `high_ten` and `low_six` conform + // to the invariant of `utf8_three_byte`. + let v = unsafe { self.trie.utf8_three_byte(high_ten, low_six) }; + // SAFETY: Due to the field invariant of `delegate`, `third_last` must be a + // valid (not overlong) three-byte lead and `second_last` and `last` + // must be valid trails. Therefore, the following shift and OR + // stays in the scalar value range. + let c = unsafe { char::from_u32_unchecked((high_ten << 6) | low_six) }; + // Field invariant upheld: `delegate` is again well-formed UTF-8, because we consumed + // a complete three-byte sequence. + return Some((c, v)); + } + // Four-byte sequence + // SAFETY: Due to the field invariant of `delegate`, we may assume the + // presence of a lead byte. + let lead = *unsafe { self.delegate.next_back().unwrap_unchecked() }; + // SAFETY: Due to the field invariant of `delegate`, `lead` must be a + // valid (not overlong or out-of-range) four-byte lead and `third_last`, + // `second_last`, and `last` must be valid trails. Therefore, the + // following shift and OR stays in the scalar value range. + let c = unsafe { + char::from_u32_unchecked( + (u32::from(lead & 0b111) << 18) + | (u32::from(third_last & 0b111_111) << 12) + | (u32::from(second_last & 0b111_111) << 6) + | u32::from(last & 0b111_111), + ) + }; + // Field invariant upheld: `delegate` is again well-formed UTF-8, because we consumed + // a complete four-byte sequence. + Some((c, self.trie.supplementary(c as u32))) + } +} + +impl<'slice, 'trie, T, V> FusedIterator for CharsWithTrieDefaultForAscii<'slice, 'trie, T, V> +where + V: TrieValue + Default, + T: AbstractCodePointTrie<'trie, V>, +{ +} +// -- + +/// Iterator over `str` by `char` and `TrieValue`. +#[derive(Debug)] +pub struct CharIndicesWithTrieDefaultForAscii<'slice, 'trie, T, V> +where + V: TrieValue + Default, + T: AbstractCodePointTrie<'trie, V>, +{ + offset: usize, + delegate: CharsWithTrieDefaultForAscii<'slice, 'trie, T, V>, +} + +impl<'slice, 'trie, T, V> CharIndicesWithTrieDefaultForAscii<'slice, 'trie, T, V> +where + V: TrieValue + Default, + T: AbstractCodePointTrie<'trie, V>, +{ + /// Construct a new `CharIndicesWithTrieDefaultForAscii`. + #[inline] + pub fn new(s: &'slice str, trie: &'trie T) -> Self { + Self { + offset: 0, + delegate: CharsWithTrieDefaultForAscii::new(s, trie), + } + } + + /// Obtains the remainder of the iterator as a string slice. + #[inline] + pub fn as_str(&self) -> &'slice str { + self.delegate.as_str() + } +} + +impl<'slice, 'trie, T, V> Clone for CharIndicesWithTrieDefaultForAscii<'slice, 'trie, T, V> +where + V: TrieValue + Default, + T: AbstractCodePointTrie<'trie, V>, +{ + #[inline] + fn clone(&self) -> Self { + Self { + offset: self.offset, + delegate: self.delegate.clone(), + } + } +} + +impl<'slice, 'trie, T, V> WithTrie<'trie, T, V> + for CharIndicesWithTrieDefaultForAscii<'slice, 'trie, T, V> +where + V: TrieValue + Default, + T: AbstractCodePointTrie<'trie, V>, +{ + #[inline] + fn trie(&self) -> &'trie T { + self.delegate.trie() + } +} + +impl<'slice, 'trie, T, V> Iterator for CharIndicesWithTrieDefaultForAscii<'slice, 'trie, T, V> +where + V: TrieValue + Default, + T: AbstractCodePointTrie<'trie, V>, +{ + type Item = (usize, char, V); + + #[inline] + fn next(&mut self) -> Option { + let old_len = self.as_str().len(); + let (c, v) = self.delegate.next()?; + let old_offset = self.offset; + self.offset += old_len - self.as_str().len(); + Some((old_offset, c, v)) + } + + #[inline] + fn count(self) -> usize { + self.as_str().chars().count() + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.as_str().chars().size_hint() + } + + #[inline] + fn last(mut self) -> Option { + self.next_back() + } + + // TODO: Delegate advance_by to `Chars` once stabilized. +} + +impl<'slice, 'trie, T, V> DoubleEndedIterator + for CharIndicesWithTrieDefaultForAscii<'slice, 'trie, T, V> +where + V: TrieValue + Default, + T: AbstractCodePointTrie<'trie, V>, +{ + #[inline] + fn next_back(&mut self) -> Option { + let (c, v) = self.delegate.next_back()?; + Some((self.offset + self.as_str().len(), c, v)) + } +} + +impl<'slice, 'trie, T, V> FusedIterator for CharIndicesWithTrieDefaultForAscii<'slice, 'trie, T, V> +where + V: TrieValue + Default, + T: AbstractCodePointTrie<'trie, V>, +{ +} + +// -- + +/// Adds convenience methods to `str`. +pub trait CharsWithTrieDefaultForAsciiEx<'slice, 'trie, T, V> +where + V: TrieValue + Default, + T: AbstractCodePointTrie<'trie, V>, +{ + /// Method for easily creating `CharsWithTrie` on `str` analogously to `chars()`. + fn chars_with_trie_default_for_ascii( + &'slice self, + trie: &'trie T, + ) -> CharsWithTrieDefaultForAscii<'slice, 'trie, T, V>; + + /// Method for easily creating `CharIndicesWithTrie` on `str` analogously to `char_indices()`. + fn char_indices_with_trie_default_for_ascii( + &'slice self, + trie: &'trie T, + ) -> CharIndicesWithTrieDefaultForAscii<'slice, 'trie, T, V>; +} + +impl<'slice, 'trie, T, V> CharsWithTrieDefaultForAsciiEx<'slice, 'trie, T, V> for str +where + V: TrieValue + Default, + T: AbstractCodePointTrie<'trie, V>, +{ + /// Method for easily creating `CharsWithTrie` on `str` analogously to `chars()`. + #[inline] + fn chars_with_trie_default_for_ascii( + &'slice self, + trie: &'trie T, + ) -> CharsWithTrieDefaultForAscii<'slice, 'trie, T, V> { + CharsWithTrieDefaultForAscii::new(self, trie) + } + + /// Method for easily creating `CharIndicesWithTrie` on `str` analogously to `char_indices()`. + #[inline] + fn char_indices_with_trie_default_for_ascii( + &'slice self, + trie: &'trie T, + ) -> CharIndicesWithTrieDefaultForAscii<'slice, 'trie, T, V> { + CharIndicesWithTrieDefaultForAscii::new(self, trie) + } +} + +// -- + +/// Iterator over Latin1 `[u8]` by `char` and `TrieValue`. +#[derive(Debug)] +pub struct Latin1CharsWithTrie<'slice, 'trie, T, V> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, +{ + delegate: core::slice::Iter<'slice, u8>, + trie: &'trie T, + phantom: PhantomData, +} + +impl<'slice, 'trie, T, V> Latin1CharsWithTrie<'slice, 'trie, T, V> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, +{ + /// Construct a new `Latin1CharsWithTrie`. + #[inline] + pub fn new(s: &'slice [u8], trie: &'trie T) -> Self { + Self { + delegate: s.iter(), + trie, + phantom: PhantomData, + } + } + + /// Obtains the remainder of the iterator as a slice. + #[inline] + pub fn as_slice(&self) -> &'slice [u8] { + self.delegate.as_slice() + } +} + +impl<'slice, 'trie, T, V> Clone for Latin1CharsWithTrie<'slice, 'trie, T, V> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, +{ + #[inline] + fn clone(&self) -> Self { + Self { + delegate: self.delegate.clone(), + trie: self.trie, + phantom: PhantomData, + } + } +} + +impl<'slice, 'trie, T, V> WithTrie<'trie, T, V> for Latin1CharsWithTrie<'slice, 'trie, T, V> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, +{ + #[inline] + fn trie(&self) -> &'trie T { + self.trie + } +} + +impl<'slice, 'trie, T, V> Iterator for Latin1CharsWithTrie<'slice, 'trie, T, V> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, +{ + type Item = (char, V); + + #[inline] + fn next(&mut self) -> Option { + let b = *self.delegate.next()?; + Some((char::from(b), self.trie.latin1(b))) + } + + #[inline] + fn count(self) -> usize { + self.delegate.count() + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.delegate.size_hint() + } + + #[inline] + fn last(mut self) -> Option { + self.next_back() + } + + // TODO: Delegate advance_by to `delegate` once stabilized. +} + +impl<'slice, 'trie, T, V> DoubleEndedIterator for Latin1CharsWithTrie<'slice, 'trie, T, V> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, +{ + #[inline] + fn next_back(&mut self) -> Option { + let b = *self.delegate.next_back()?; + Some((char::from(b), self.trie.latin1(b))) + } +} + +impl<'slice, 'trie, T, V> FusedIterator for Latin1CharsWithTrie<'slice, 'trie, T, V> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, +{ +} + +// -- + +/// Iterator over `str` by `char` and `TrieValue`. +#[derive(Debug)] +pub struct Latin1CharIndicesWithTrie<'slice, 'trie, T, V> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, +{ + offset: usize, + delegate: core::slice::Iter<'slice, u8>, + trie: &'trie T, + phantom: PhantomData, +} + +impl<'slice, 'trie, T, V> Latin1CharIndicesWithTrie<'slice, 'trie, T, V> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, +{ + /// Construct a new `Latin1CharIndicesWithTrie`. + #[inline] + pub fn new(s: &'slice [u8], trie: &'trie T) -> Self { + Self { + offset: 0, + delegate: s.iter(), + trie, + phantom: PhantomData, + } + } + + /// Obtains the remainder of the iterator as a slice. + #[inline] + pub fn as_slice(&self) -> &'slice [u8] { + self.delegate.as_slice() + } +} + +impl<'slice, 'trie, T, V> Clone for Latin1CharIndicesWithTrie<'slice, 'trie, T, V> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, +{ + #[inline] + fn clone(&self) -> Self { + Self { + offset: self.offset, + delegate: self.delegate.clone(), + trie: self.trie, + phantom: PhantomData, + } + } +} + +impl<'slice, 'trie, T, V> WithTrie<'trie, T, V> for Latin1CharIndicesWithTrie<'slice, 'trie, T, V> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, +{ + #[inline] + fn trie(&self) -> &'trie T { + self.trie + } +} + +impl<'slice, 'trie, T, V> Iterator for Latin1CharIndicesWithTrie<'slice, 'trie, T, V> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, +{ + type Item = (usize, char, V); + + #[inline] + fn next(&mut self) -> Option { + let b = *self.delegate.next()?; + let old_offset = self.offset; + self.offset += 1; + Some((old_offset, char::from(b), self.trie.latin1(b))) + } + + #[inline] + fn count(self) -> usize { + self.delegate.count() + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.delegate.size_hint() + } + + #[inline] + fn last(mut self) -> Option { + self.next_back() + } + + // TODO: Delegate advance_by to `delegate` once stabilized. +} + +impl<'slice, 'trie, T, V> DoubleEndedIterator for Latin1CharIndicesWithTrie<'slice, 'trie, T, V> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, +{ + #[inline] + fn next_back(&mut self) -> Option { + let b = *self.delegate.next_back()?; + Some(( + self.offset + self.as_slice().len(), + char::from(b), + self.trie.latin1(b), + )) + } +} + +impl<'slice, 'trie, T, V> FusedIterator for Latin1CharIndicesWithTrie<'slice, 'trie, T, V> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, +{ +} + +// -- + +/// Adds convenience methods to `[u8]`. +pub trait Latin1CharsWithTrieEx<'slice, 'trie, T, V> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, +{ + /// Method for easily creating `Latin1CharsWithTrie` on `[u8]` analogously to `chars()` on `str`. + /// (The name is prefixed with `latin1_` to avoid ambiguity with interpreting [u8] as UTF-8.) + fn latin1_chars_with_trie( + &'slice self, + trie: &'trie T, + ) -> Latin1CharsWithTrie<'slice, 'trie, T, V>; + + /// Method for easily creating `Latin1CharIndicesWithTrie` on `str` analogously to `char_indices()` on `str`. + /// (The name is prefixed with `latin1_` to avoid ambiguity with interpreting [u8] as UTF-8.) + fn latin1_char_indices_with_trie( + &'slice self, + trie: &'trie T, + ) -> Latin1CharIndicesWithTrie<'slice, 'trie, T, V>; +} + +impl<'slice, 'trie, T, V> Latin1CharsWithTrieEx<'slice, 'trie, T, V> for [u8] +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, +{ + /// Method for easily creating `Latin1CharsWithTrie` on `[u8]` analogously to `chars()` on `str`. + /// (The name is prefixed with `latin1_` to avoid ambiguity with interpreting [u8] as UTF-8.) + #[inline] + fn latin1_chars_with_trie( + &'slice self, + trie: &'trie T, + ) -> Latin1CharsWithTrie<'slice, 'trie, T, V> { + Latin1CharsWithTrie::new(self, trie) + } + + /// Method for easily creating `Latin1CharIndicesWithTrie` on `str` analogously to `char_indices()` on `str`. + /// (The name is prefixed with `latin1_` to avoid ambiguity with interpreting [u8] as UTF-8.) + #[inline] + fn latin1_char_indices_with_trie( + &'slice self, + trie: &'trie T, + ) -> Latin1CharIndicesWithTrie<'slice, 'trie, T, V> { + Latin1CharIndicesWithTrie::new(self, trie) + } +} + +// -- + +/// Wraps an `Iterator` with a reference to +/// an `AbstractCodePointTrie`. +#[derive(Debug)] +pub struct CharIterWithTrie<'trie, T, V, I> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, + I: Iterator, +{ + delegate: I, + trie: &'trie T, + phantom: PhantomData, +} + +impl<'trie, T, V, I> CharIterWithTrie<'trie, T, V, I> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, + I: Iterator, +{ + /// Constructs a new `CharIterWithTrie`. + #[inline] + pub fn new(iter: I, trie: &'trie T) -> Self { + Self { + delegate: iter, + trie, + phantom: PhantomData, + } + } +} + +impl<'trie, T, V, I> WithTrie<'trie, T, V> for CharIterWithTrie<'trie, T, V, I> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, + I: Iterator, +{ + #[inline] + fn trie(&self) -> &'trie T { + self.trie + } +} + +impl<'trie, T, V, I> Iterator for CharIterWithTrie<'trie, T, V, I> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, + I: Iterator, +{ + type Item = (char, V); + + #[inline] + fn next(&mut self) -> Option { + let c = self.delegate.next()?; + Some((c, self.trie.scalar(c))) + } + + #[inline] + fn count(self) -> usize { + self.delegate.count() + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.delegate.size_hint() + } + + // Looks like conditionally implementing `last()` is not allowed. + + // TODO: Delegate advance_by to `delegate` once stabilized. +} + +impl<'trie, T, V, I> DoubleEndedIterator for CharIterWithTrie<'trie, T, V, I> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, + I: DoubleEndedIterator, +{ + #[inline] + fn next_back(&mut self) -> Option { + let c = self.delegate.next_back()?; + Some((c, self.trie.scalar(c))) + } +} + +impl<'trie, T, V, I> FusedIterator for CharIterWithTrie<'trie, T, V, I> +where + V: TrieValue, + T: AbstractCodePointTrie<'trie, V>, + I: FusedIterator, +{ +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_forward() { + let trie = crate::codepointtrie::planes::get_planes_trie(); + let s = "abäαあ🥳𧉧"; + let mut iter = s.chars_with_trie(&trie); + assert_eq!(iter.next(), Some(('a', 0))); + assert_eq!(iter.next(), Some(('b', 0))); + assert_eq!(iter.next(), Some(('ä', 0))); + assert_eq!(iter.next(), Some(('α', 0))); + assert_eq!(iter.next(), Some(('あ', 0))); + assert_eq!(iter.next(), Some(('🥳', 1))); + assert_eq!(iter.next(), Some(('𧉧', 2))); + assert_eq!(iter.next(), None); + } + + #[test] + fn test_backwards() { + let trie = crate::codepointtrie::planes::get_planes_trie(); + let s = "abäαあ🥳𧉧"; + let mut iter = s.chars_with_trie(&trie); + assert_eq!(iter.next_back(), Some(('𧉧', 2))); + assert_eq!(iter.next_back(), Some(('🥳', 1))); + assert_eq!(iter.next_back(), Some(('あ', 0))); + assert_eq!(iter.next_back(), Some(('α', 0))); + assert_eq!(iter.next_back(), Some(('ä', 0))); + assert_eq!(iter.next_back(), Some(('b', 0))); + assert_eq!(iter.next_back(), Some(('a', 0))); + assert_eq!(iter.next(), None); + } + + #[test] + fn test_indices_forward() { + let trie = crate::codepointtrie::planes::get_planes_trie(); + let s = "abäαあ🥳𧉧"; + let mut iter = s.char_indices_with_trie(&trie); + assert_eq!(iter.next(), Some((0, 'a', 0))); + assert_eq!(iter.next(), Some((1, 'b', 0))); + assert_eq!(iter.next(), Some((2, 'ä', 0))); + assert_eq!(iter.next(), Some((4, 'α', 0))); + assert_eq!(iter.next(), Some((6, 'あ', 0))); + assert_eq!(iter.next(), Some((9, '🥳', 1))); + assert_eq!(iter.next(), Some((13, '𧉧', 2))); + assert_eq!(iter.next(), None); + } + + #[test] + fn test_indices_backwards() { + let trie = crate::codepointtrie::planes::get_planes_trie(); + let s = "abäαあ🥳𧉧"; + let mut iter = s.char_indices_with_trie(&trie); + assert_eq!(iter.next_back(), Some((13, '𧉧', 2))); + assert_eq!(iter.next_back(), Some((9, '🥳', 1))); + assert_eq!(iter.next_back(), Some((6, 'あ', 0))); + assert_eq!(iter.next_back(), Some((4, 'α', 0))); + assert_eq!(iter.next_back(), Some((2, 'ä', 0))); + assert_eq!(iter.next_back(), Some((1, 'b', 0))); + assert_eq!(iter.next_back(), Some((0, 'a', 0))); + assert_eq!(iter.next(), None); + } +} diff --git a/components/collections/src/codepointtrie/mod.rs b/components/collections/src/codepointtrie/mod.rs index dfc5a29ffc2..2cdee95e39c 100644 --- a/components/collections/src/codepointtrie/mod.rs +++ b/components/collections/src/codepointtrie/mod.rs @@ -32,6 +32,7 @@ mod cptrie; mod error; mod impl_const; +mod iter; pub mod planes; #[cfg(feature = "serde")] @@ -40,6 +41,7 @@ pub mod toml; #[cfg(feature = "serde")] mod serde; +pub use cptrie::AbstractCodePointTrie; pub use cptrie::CodePointMapRange; pub use cptrie::CodePointMapRangeIterator; pub use cptrie::CodePointTrie; @@ -51,3 +53,14 @@ pub use cptrie::TrieValue; pub use cptrie::Typed; pub use cptrie::TypedCodePointTrie; pub use error::Error as CodePointTrieError; +pub use iter::CharIndicesWithTrie; +pub use iter::CharIndicesWithTrieDefaultForAscii; +pub use iter::CharIterWithTrie; +pub use iter::CharsWithTrie; +pub use iter::CharsWithTrieDefaultForAscii; +pub use iter::CharsWithTrieDefaultForAsciiEx; +pub use iter::CharsWithTrieEx; +pub use iter::Latin1CharIndicesWithTrie; +pub use iter::Latin1CharsWithTrie; +pub use iter::Latin1CharsWithTrieEx; +pub use iter::WithTrie; diff --git a/components/collections/src/codepointtrie/serde.rs b/components/collections/src/codepointtrie/serde.rs index adfb2d8aa08..43c398a8a2f 100644 --- a/components/collections/src/codepointtrie/serde.rs +++ b/components/collections/src/codepointtrie/serde.rs @@ -2,7 +2,10 @@ // called LICENSE at the top level of the ICU4X source tree // (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ). -use crate::codepointtrie::{CodePointTrie, CodePointTrieHeader, TrieValue}; +use crate::codepointtrie::{ + CodePointTrie, CodePointTrieHeader, FastCodePointTrie, SmallCodePointTrie, TrieValue, + TypedCodePointTrie, +}; use serde::{de::Error, Deserialize, Deserializer, Serialize, Serializer}; use zerofrom::ZeroFrom; use zerovec::ZeroVec; @@ -30,6 +33,36 @@ impl Serialize for CodePointTrie<'_, T> { } } +impl Serialize for SmallCodePointTrie<'_, T> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let untyped = self.as_untyped_ref(); + let ser = CodePointTrieSerde { + header: untyped.header, + index: ZeroFrom::zero_from(&untyped.index), + data: ZeroFrom::zero_from(&untyped.data), + }; + ser.serialize(serializer) + } +} + +impl Serialize for FastCodePointTrie<'_, T> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let untyped = self.as_untyped_ref(); + let ser = CodePointTrieSerde { + header: untyped.header, + index: ZeroFrom::zero_from(&untyped.index), + data: ZeroFrom::zero_from(&untyped.data), + }; + ser.serialize(serializer) + } +} + impl<'de, 'trie, T: TrieValue + Deserialize<'de>> Deserialize<'de> for CodePointTrie<'trie, T> where 'de: 'trie, @@ -60,6 +93,9 @@ where super::CodePointTrieError::DataTooShortForFastAccess => { return Err(D::Error::custom("CodePointTrie must be constructed from data vector long enough to accommodate fast-path access")); } + super::CodePointTrieError::DataTooShortForAsciiAccess => { + return Err(D::Error::custom("CodePointTrie must be constructed from data vector long enough to accommodate direct ASCII access")); + } } } }; @@ -72,3 +108,39 @@ where }) } } + +impl<'de, 'trie, T: TrieValue + Deserialize<'de>> Deserialize<'de> for SmallCodePointTrie<'trie, T> +where + 'de: 'trie, +{ + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let untyped_de = CodePointTrie::deserialize(deserializer)?; + let Ok(de) = >::try_from(untyped_de) else { + return Err(D::Error::custom( + "SmallCodePointTrie must have small-mode data", + )); + }; + Ok(de) + } +} + +impl<'de, 'trie, T: TrieValue + Deserialize<'de>> Deserialize<'de> for FastCodePointTrie<'trie, T> +where + 'de: 'trie, +{ + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let untyped_de = CodePointTrie::deserialize(deserializer)?; + let Ok(de) = >::try_from(untyped_de) else { + return Err(D::Error::custom( + "FastCodePointTrie must have fast-mode data", + )); + }; + Ok(de) + } +}