-
Notifications
You must be signed in to change notification settings - Fork 170
feat: sequence cast compute #8403
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -28,7 +28,6 @@ use vortex_array::dtype::PType; | |
| use vortex_array::expr::stats::Precision as StatPrecision; | ||
| use vortex_array::expr::stats::Stat; | ||
| use vortex_array::match_each_integer_ptype; | ||
| use vortex_array::match_each_native_ptype; | ||
| use vortex_array::match_each_pvalue; | ||
| use vortex_array::scalar::PValue; | ||
| use vortex_array::scalar::Scalar; | ||
|
|
@@ -61,6 +60,8 @@ pub struct SequenceMetadata { | |
| base: Option<vortex_proto::scalar::ScalarValue>, | ||
| #[prost(message, tag = "2")] | ||
| multiplier: Option<vortex_proto::scalar::ScalarValue>, | ||
| #[prost(enumeration = "PType", optional, tag = "3")] | ||
| calculation_ptype: Option<i32>, | ||
| } | ||
|
|
||
| pub(super) const SLOT_NAMES: [&str; 0] = []; | ||
|
|
@@ -70,11 +71,16 @@ pub(super) const SLOT_NAMES: [&str; 0] = []; | |
| pub struct SequenceData { | ||
| base: PValue, | ||
| multiplier: PValue, | ||
| calculation_ptype: PType, | ||
|
Comment on lines
73
to
+74
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a bit outside of what this PR is trying to do, but could you document this now that we are adding a third type here that is different from the base and multiplier? With just base and multiplier it is obvious what this is doing, but with the addition of Edit: I am interested to see if your idea about not storing this at all works. That would probably be better for us since that is not a breaking change.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My understanding of this issue was to use two types. I'l work on the idea of not storing this at all. Thanks a lot ! |
||
| } | ||
|
|
||
| impl Display for SequenceData { | ||
| fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { | ||
| write!(f, "base: {}, multiplier: {}", self.base, self.multiplier) | ||
| write!( | ||
| f, | ||
| "base: {}, multiplier: {}, calculation_ptype: {}", | ||
| self.base, self.multiplier, self.calculation_ptype | ||
| ) | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -91,51 +97,52 @@ impl SequenceData { | |
| nullability: Nullability, | ||
| length: usize, | ||
| ) -> VortexResult<Self> { | ||
| Self::try_new( | ||
| base.into(), | ||
| multiplier.into(), | ||
| T::PTYPE, | ||
| nullability, | ||
| length, | ||
| ) | ||
| let dtype = DType::Primitive(T::PTYPE, nullability); | ||
| Self::try_new(base.into(), multiplier.into(), T::PTYPE, &dtype, length) | ||
| } | ||
|
|
||
| /// Constructs a sequence array using two integer values (with the same ptype). | ||
| /// Constructs a sequence array using `calculation_ptype` for arithmetic. | ||
| pub(crate) fn try_new( | ||
| base: PValue, | ||
| multiplier: PValue, | ||
| ptype: PType, | ||
| nullability: Nullability, | ||
| calculation_ptype: PType, | ||
| dtype: &DType, | ||
| length: usize, | ||
| ) -> VortexResult<Self> { | ||
| let dtype = DType::Primitive(ptype, nullability); | ||
| Self::validate(base, multiplier, &dtype, length)?; | ||
| let (base, multiplier) = Self::normalize(base, multiplier, ptype)?; | ||
| Self::validate(base, multiplier, calculation_ptype, dtype, length)?; | ||
| let (base, multiplier) = Self::normalize(base, multiplier, calculation_ptype)?; | ||
|
|
||
| Ok(unsafe { Self::new_unchecked(base, multiplier) }) | ||
| Ok(unsafe { Self::new_unchecked(base, multiplier, calculation_ptype) }) | ||
| } | ||
|
|
||
| pub fn validate( | ||
| base: PValue, | ||
| multiplier: PValue, | ||
| dtype: &DType, | ||
| calculation_ptype: PType, | ||
| output_dtype: &DType, | ||
| length: usize, | ||
| ) -> VortexResult<()> { | ||
| let DType::Primitive(ptype, _) = dtype else { | ||
| let DType::Primitive(output_ptype, _) = output_dtype else { | ||
| vortex_bail!("only primitive dtypes are supported in SequenceArray currently"); | ||
| }; | ||
|
|
||
| if !ptype.is_int() { | ||
| vortex_bail!("only integer ptype are supported in SequenceArray currently") | ||
| if !calculation_ptype.is_int() || !output_ptype.is_int() { | ||
| vortex_bail!("only integer ptypes are supported in SequenceArray currently") | ||
| } | ||
|
|
||
| vortex_ensure!(length > 0, "SequenceArray length must be greater than zero"); | ||
| Self::try_last(base, multiplier, *ptype, length).map_err(|e| { | ||
| let last = Self::try_last(base, multiplier, calculation_ptype, length).map_err(|e| { | ||
| e.with_context(format!( | ||
| "final value not expressible, base = {base:?}, multiplier = {multiplier:?}, len = {length} ", | ||
| )) | ||
| })?; | ||
|
|
||
| match_each_integer_ptype!(*output_ptype, |P| { | ||
| base.cast::<P>()?; | ||
| last.cast::<P>()?; | ||
| VortexResult::Ok(()) | ||
| })?; | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
||
|
|
@@ -153,14 +160,23 @@ impl SequenceData { | |
| /// # Safety | ||
| /// | ||
| /// The caller must ensure that: | ||
| /// - `base` and `multiplier` are both normalized to the same integer `ptype`. | ||
| /// - `base` and `multiplier` are both normalized to `calculation_ptype`. | ||
| /// - `calculation_ptype` is an integer type. | ||
| /// - they are logically compatible with the outer dtype and len. | ||
| pub(crate) unsafe fn new_unchecked(base: PValue, multiplier: PValue) -> Self { | ||
| Self { base, multiplier } | ||
| pub(crate) unsafe fn new_unchecked( | ||
| base: PValue, | ||
| multiplier: PValue, | ||
| calculation_ptype: PType, | ||
| ) -> Self { | ||
| Self { | ||
| base, | ||
| multiplier, | ||
| calculation_ptype, | ||
| } | ||
| } | ||
|
|
||
| pub fn ptype(&self) -> PType { | ||
| self.base.ptype() | ||
| pub fn calculation_ptype(&self) -> PType { | ||
| self.calculation_ptype | ||
| } | ||
|
|
||
| pub fn base(&self) -> PValue { | ||
|
|
@@ -171,6 +187,10 @@ impl SequenceData { | |
| self.multiplier | ||
| } | ||
|
|
||
| pub(crate) fn cast_value(value: PValue, output_ptype: PType) -> VortexResult<PValue> { | ||
| match_each_integer_ptype!(output_ptype, |O| { Ok(PValue::from(value.cast::<O>()?)) }) | ||
| } | ||
|
|
||
| pub fn into_parts(self) -> SequenceDataParts { | ||
| SequenceDataParts { | ||
| base: self.base, | ||
|
|
@@ -200,7 +220,7 @@ impl SequenceData { | |
| } | ||
|
|
||
| pub(crate) fn index_value(&self, idx: usize) -> PValue { | ||
| match_each_native_ptype!(self.ptype(), |P| { | ||
| match_each_integer_ptype!(self.calculation_ptype(), |P| { | ||
| let base = self.base.cast::<P>().vortex_expect("must be able to cast"); | ||
| let multiplier = self | ||
| .multiplier | ||
|
|
@@ -217,12 +237,15 @@ impl ArrayHash for SequenceData { | |
| fn array_hash<H: Hasher>(&self, state: &mut H, _accuracy: EqMode) { | ||
| self.base.hash(state); | ||
| self.multiplier.hash(state); | ||
| self.calculation_ptype.hash(state); | ||
| } | ||
| } | ||
|
|
||
| impl ArrayEq for SequenceData { | ||
| fn array_eq(&self, other: &Self, _accuracy: EqMode) -> bool { | ||
| self.base == other.base && self.multiplier == other.multiplier | ||
| self.base == other.base | ||
| && self.multiplier == other.multiplier | ||
| && self.calculation_ptype == other.calculation_ptype | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -244,7 +267,13 @@ impl VTable for Sequence { | |
| len: usize, | ||
| _slots: &[Option<ArrayRef>], | ||
| ) -> VortexResult<()> { | ||
| SequenceData::validate(data.base, data.multiplier, dtype, len) | ||
| SequenceData::validate( | ||
| data.base, | ||
| data.multiplier, | ||
| data.calculation_ptype, | ||
| dtype, | ||
| len, | ||
| ) | ||
| } | ||
|
|
||
| fn nbuffers(_array: ArrayView<'_, Self>) -> usize { | ||
|
|
@@ -266,6 +295,7 @@ impl VTable for Sequence { | |
| let metadata = SequenceMetadata { | ||
| base: Some((&array.base()).into()), | ||
| multiplier: Some((&array.multiplier()).into()), | ||
| calculation_ptype: Some(array.calculation_ptype() as i32), | ||
| }; | ||
|
|
||
| Ok(Some(metadata.encode_to_vec())) | ||
|
|
@@ -292,15 +322,19 @@ impl VTable for Sequence { | |
| ); | ||
| let metadata = SequenceMetadata::decode(metadata)?; | ||
|
|
||
| let ptype = dtype.as_ptype(); | ||
| let calculation_ptype = metadata | ||
| .calculation_ptype | ||
| .map(|p| PType::try_from(p).map_err(|e| vortex_err!("invalid PType {p}: {e}"))) | ||
| .transpose()? | ||
| .unwrap_or_else(|| dtype.as_ptype()); | ||
|
|
||
| // We go via Scalar to validate that the value is valid for the ptype. | ||
| let base = Scalar::from_proto_value( | ||
| metadata | ||
| .base | ||
| .as_ref() | ||
| .ok_or_else(|| vortex_err!("base required"))?, | ||
| &DType::Primitive(ptype, NonNullable), | ||
| &DType::Primitive(calculation_ptype, NonNullable), | ||
| session, | ||
| )? | ||
| .as_primitive() | ||
|
|
@@ -312,14 +346,14 @@ impl VTable for Sequence { | |
| .multiplier | ||
| .as_ref() | ||
| .ok_or_else(|| vortex_err!("multiplier required"))?, | ||
| &DType::Primitive(ptype, NonNullable), | ||
| &DType::Primitive(calculation_ptype, NonNullable), | ||
| session, | ||
| )? | ||
| .as_primitive() | ||
| .pvalue() | ||
| .vortex_expect("sequence array multiplier should be a non-nullable primitive"); | ||
|
|
||
| let data = SequenceData::try_new(base, multiplier, ptype, dtype.nullability(), len)?; | ||
| let data = SequenceData::try_new(base, multiplier, calculation_ptype, dtype, len)?; | ||
| Ok(ArrayParts::new(self.clone(), dtype.clone(), len, data)) | ||
| } | ||
|
|
||
|
|
@@ -355,10 +389,8 @@ impl OperationsVTable<Sequence> for Sequence { | |
| index: usize, | ||
| _ctx: &mut ExecutionCtx, | ||
| ) -> VortexResult<Scalar> { | ||
| Scalar::try_new( | ||
| array.dtype().clone(), | ||
| Some(ScalarValue::Primitive(array.index_value(index))), | ||
| ) | ||
| let value = SequenceData::cast_value(array.index_value(index), array.dtype().as_ptype())?; | ||
| Scalar::try_new(array.dtype().clone(), Some(ScalarValue::Primitive(value))) | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -402,15 +434,16 @@ impl Sequence { | |
| pub(crate) unsafe fn new_unchecked( | ||
| base: PValue, | ||
| multiplier: PValue, | ||
| ptype: PType, | ||
| calculation_ptype: PType, | ||
| output_ptype: PType, | ||
| nullability: Nullability, | ||
| length: usize, | ||
| ) -> SequenceArray { | ||
| let dtype = DType::Primitive(ptype, nullability); | ||
| let (base, multiplier) = SequenceData::normalize(base, multiplier, ptype) | ||
| let dtype = DType::Primitive(output_ptype, nullability); | ||
| let (base, multiplier) = SequenceData::normalize(base, multiplier, calculation_ptype) | ||
| .vortex_expect("SequenceArray parts must be normalized to the target ptype"); | ||
| let stats = Self::stats(multiplier); | ||
| let data = unsafe { SequenceData::new_unchecked(base, multiplier) }; | ||
| let data = unsafe { SequenceData::new_unchecked(base, multiplier, calculation_ptype) }; | ||
| unsafe { Array::from_parts_unchecked(ArrayParts::new(Sequence, dtype, length, data)) } | ||
| .with_stats_set(stats) | ||
| } | ||
|
|
@@ -419,12 +452,13 @@ impl Sequence { | |
| pub fn try_new( | ||
| base: PValue, | ||
| multiplier: PValue, | ||
| ptype: PType, | ||
| calculation_ptype: PType, | ||
| output_ptype: PType, | ||
| nullability: Nullability, | ||
| length: usize, | ||
| ) -> VortexResult<SequenceArray> { | ||
| let dtype = DType::Primitive(ptype, nullability); | ||
| let data = SequenceData::try_new(base, multiplier, ptype, nullability, length)?; | ||
| let dtype = DType::Primitive(output_ptype, nullability); | ||
| let data = SequenceData::try_new(base, multiplier, calculation_ptype, &dtype, length)?; | ||
| let stats = Self::stats(data.multiplier()); | ||
| Ok( | ||
| unsafe { Array::from_parts_unchecked(ArrayParts::new(Sequence, dtype, length, data)) } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you explain why in a doc str why we need this, if we need this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was trying to decode the base and multiplier as
calculation_ptypeduring deserialization, so I added this. But after Gates comments, I found that I could usescalar_value::Kindto get the type instead. I'll remove this in the next changeThanks for the review :)