From 4b987fcd392d478e9f80da9c32356760b782da09 Mon Sep 17 00:00:00 2001 From: Jorge Leitao Date: Wed, 11 Aug 2021 06:14:52 +0100 Subject: [PATCH] Added Scalar API (#56) --- src/compute/aggregate/min_max.rs | 91 +++++++++++++++++++++ src/compute/aggregate/sum.rs | 97 ++++++++++++++++++++-- src/lib.rs | 1 + src/scalar/README.md | 28 +++++++ src/scalar/binary.rs | 95 +++++++++++++++++++++ src/scalar/boolean.rs | 82 +++++++++++++++++++ src/scalar/equal.rs | 115 ++++++++++++++++++++++++++ src/scalar/list.rs | 66 +++++++++++++++ src/scalar/mod.rs | 136 +++++++++++++++++++++++++++++++ src/scalar/null.rs | 36 ++++++++ src/scalar/primitive.rs | 113 +++++++++++++++++++++++++ src/scalar/struct_.rs | 54 ++++++++++++ src/scalar/utf8.rs | 95 +++++++++++++++++++++ 13 files changed, 1003 insertions(+), 6 deletions(-) create mode 100644 src/scalar/README.md create mode 100644 src/scalar/binary.rs create mode 100644 src/scalar/boolean.rs create mode 100644 src/scalar/equal.rs create mode 100644 src/scalar/list.rs create mode 100644 src/scalar/mod.rs create mode 100644 src/scalar/null.rs create mode 100644 src/scalar/primitive.rs create mode 100644 src/scalar/struct_.rs create mode 100644 src/scalar/utf8.rs diff --git a/src/compute/aggregate/min_max.rs b/src/compute/aggregate/min_max.rs index 8e9c0577e6c..2ab1bf15346 100644 --- a/src/compute/aggregate/min_max.rs +++ b/src/compute/aggregate/min_max.rs @@ -1,4 +1,7 @@ use crate::bitmap::utils::{BitChunkIterExact, BitChunksExact}; +use crate::datatypes::{DataType, IntervalUnit}; +use crate::error::{ArrowError, Result}; +use crate::scalar::*; use crate::types::simd::*; use crate::types::NativeType; use crate::{ @@ -281,6 +284,94 @@ pub fn max_boolean(array: &BooleanArray) -> Option { .or(Some(false)) } +macro_rules! dyn_primitive { + ($ty:ty, $array:expr, $f:ident) => {{ + let array = $array + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(PrimitiveScalar::<$ty>::new( + $array.data_type().clone(), + $f::<$ty>(array), + )) + }}; +} + +macro_rules! dyn_generic { + ($array_ty:ty, $scalar_ty:ty, $array:expr, $f:ident) => {{ + let array = $array.as_any().downcast_ref::<$array_ty>().unwrap(); + Box::new(<$scalar_ty>::new($f(array))) + }}; +} + +pub fn max(array: &dyn Array) -> Result> { + Ok(match array.data_type() { + DataType::Boolean => dyn_generic!(BooleanArray, BooleanScalar, array, max_boolean), + DataType::Int8 => dyn_primitive!(i8, array, max_primitive), + DataType::Int16 => dyn_primitive!(i16, array, max_primitive), + DataType::Int32 + | DataType::Date32 + | DataType::Time32(_) + | DataType::Interval(IntervalUnit::YearMonth) => { + dyn_primitive!(i32, array, max_primitive) + } + DataType::Int64 + | DataType::Date64 + | DataType::Time64(_) + | DataType::Timestamp(_, _) + | DataType::Duration(_) => dyn_primitive!(i64, array, max_primitive), + DataType::UInt8 => dyn_primitive!(u8, array, max_primitive), + DataType::UInt16 => dyn_primitive!(u16, array, max_primitive), + DataType::UInt32 => dyn_primitive!(u32, array, max_primitive), + DataType::UInt64 => dyn_primitive!(u64, array, max_primitive), + DataType::Float16 => unreachable!(), + DataType::Float32 => dyn_primitive!(f32, array, max_primitive), + DataType::Float64 => dyn_primitive!(f64, array, max_primitive), + DataType::Utf8 => dyn_generic!(Utf8Array, Utf8Scalar, array, max_string), + DataType::LargeUtf8 => dyn_generic!(Utf8Array, Utf8Scalar, array, max_string), + _ => { + return Err(ArrowError::InvalidArgumentError(format!( + "The `max` operator does not support type `{}`", + array.data_type(), + ))) + } + }) +} + +pub fn min(array: &dyn Array) -> Result> { + Ok(match array.data_type() { + DataType::Boolean => dyn_generic!(BooleanArray, BooleanScalar, array, min_boolean), + DataType::Int8 => dyn_primitive!(i8, array, min_primitive), + DataType::Int16 => dyn_primitive!(i16, array, min_primitive), + DataType::Int32 + | DataType::Date32 + | DataType::Time32(_) + | DataType::Interval(IntervalUnit::YearMonth) => { + dyn_primitive!(i32, array, min_primitive) + } + DataType::Int64 + | DataType::Date64 + | DataType::Time64(_) + | DataType::Timestamp(_, _) + | DataType::Duration(_) => dyn_primitive!(i64, array, min_primitive), + DataType::UInt8 => dyn_primitive!(u8, array, min_primitive), + DataType::UInt16 => dyn_primitive!(u16, array, min_primitive), + DataType::UInt32 => dyn_primitive!(u32, array, min_primitive), + DataType::UInt64 => dyn_primitive!(u64, array, min_primitive), + DataType::Float16 => unreachable!(), + DataType::Float32 => dyn_primitive!(f32, array, min_primitive), + DataType::Float64 => dyn_primitive!(f64, array, min_primitive), + DataType::Utf8 => dyn_generic!(Utf8Array, Utf8Scalar, array, min_string), + DataType::LargeUtf8 => dyn_generic!(Utf8Array, Utf8Scalar, array, min_string), + _ => { + return Err(ArrowError::InvalidArgumentError(format!( + "The `max` operator does not support type `{}`", + array.data_type(), + ))) + } + }) +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/compute/aggregate/sum.rs b/src/compute/aggregate/sum.rs index b1f1cece782..eb3b3ef4d2f 100644 --- a/src/compute/aggregate/sum.rs +++ b/src/compute/aggregate/sum.rs @@ -3,6 +3,9 @@ use std::ops::Add; use multiversion::multiversion; use crate::bitmap::utils::{BitChunkIterExact, BitChunksExact}; +use crate::datatypes::{DataType, IntervalUnit}; +use crate::error::{ArrowError, Result}; +use crate::scalar::*; use crate::types::simd::*; use crate::types::NativeType; use crate::{ @@ -85,7 +88,7 @@ where /// Returns the sum of values in the array. /// /// Returns `None` if the array is empty or only contains null values. -pub fn sum(array: &PrimitiveArray) -> Option +pub fn sum_primitive(array: &PrimitiveArray) -> Option where T: NativeType + Simd, T::Simd: Add + Sum, @@ -102,6 +105,76 @@ where } } +macro_rules! dyn_sum { + ($ty:ty, $array:expr) => {{ + let array = $array + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(PrimitiveScalar::<$ty>::new( + $array.data_type().clone(), + sum_primitive::<$ty>(array), + )) + }}; +} + +pub fn can_sum(data_type: &DataType) -> bool { + use DataType::*; + matches!( + data_type, + Int8 | Int16 + | Date32 + | Time32(_) + | Interval(IntervalUnit::YearMonth) + | Int64 + | Date64 + | Time64(_) + | Timestamp(_, _) + | Duration(_) + | UInt8 + | UInt16 + | UInt32 + | UInt64 + | Float32 + | Float64 + ) +} + +/// Returns the sum of all elements in `array` as a [`Scalar`] of the same physical +/// and logical types as `array`. +/// # Error +/// Errors iff the operation is not supported. +pub fn sum(array: &dyn Array) -> Result> { + Ok(match array.data_type() { + DataType::Int8 => dyn_sum!(i8, array), + DataType::Int16 => dyn_sum!(i16, array), + DataType::Int32 + | DataType::Date32 + | DataType::Time32(_) + | DataType::Interval(IntervalUnit::YearMonth) => { + dyn_sum!(i32, array) + } + DataType::Int64 + | DataType::Date64 + | DataType::Time64(_) + | DataType::Timestamp(_, _) + | DataType::Duration(_) => dyn_sum!(i64, array), + DataType::UInt8 => dyn_sum!(u8, array), + DataType::UInt16 => dyn_sum!(u16, array), + DataType::UInt32 => dyn_sum!(u32, array), + DataType::UInt64 => dyn_sum!(u64, array), + DataType::Float16 => unreachable!(), + DataType::Float32 => dyn_sum!(f32, array), + DataType::Float64 => dyn_sum!(f64, array), + _ => { + return Err(ArrowError::InvalidArgumentError(format!( + "The `sum` operator does not support type `{}`", + array.data_type(), + ))) + } + }) +} + #[cfg(test)] mod tests { use super::super::super::arithmetics; @@ -111,25 +184,34 @@ mod tests { #[test] fn test_primitive_array_sum() { let a = Int32Array::from_slice(&[1, 2, 3, 4, 5]); - assert_eq!(15, sum(&a).unwrap()); + assert_eq!( + &PrimitiveScalar::::from(Some(15)) as &dyn Scalar, + sum(&a).unwrap().as_ref() + ); + + let a = a.to(DataType::Date32); + assert_eq!( + &PrimitiveScalar::::from(Some(15)).to(DataType::Date32) as &dyn Scalar, + sum(&a).unwrap().as_ref() + ); } #[test] fn test_primitive_array_float_sum() { let a = Float64Array::from_slice(&[1.1f64, 2.2, 3.3, 4.4, 5.5]); - assert!((16.5 - sum(&a).unwrap()).abs() < f64::EPSILON); + assert!((16.5 - sum_primitive(&a).unwrap()).abs() < f64::EPSILON); } #[test] fn test_primitive_array_sum_with_nulls() { let a = Int32Array::from(&[None, Some(2), Some(3), None, Some(5)]); - assert_eq!(10, sum(&a).unwrap()); + assert_eq!(10, sum_primitive(&a).unwrap()); } #[test] fn test_primitive_array_sum_all_nulls() { let a = Int32Array::from(&[None, None, None]); - assert_eq!(None, sum(&a)); + assert_eq!(None, sum_primitive(&a)); } #[test] @@ -142,6 +224,9 @@ mod tests { .collect(); // create an array that actually has non-zero values at the invalid indices let c = arithmetics::basic::add::add(&a, &b).unwrap(); - assert_eq!(Some((1..=100).filter(|i| i % 3 == 0).sum()), sum(&c)); + assert_eq!( + Some((1..=100).filter(|i| i % 3 == 0).sum()), + sum_primitive(&c) + ); } } diff --git a/src/lib.rs b/src/lib.rs index 060b6278883..f0b43d91fa0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,6 +4,7 @@ pub mod bitmap; pub mod buffer; mod endianess; pub mod error; +pub mod scalar; pub mod trusted_len; pub mod types; diff --git a/src/scalar/README.md b/src/scalar/README.md new file mode 100644 index 00000000000..2bac790873b --- /dev/null +++ b/src/scalar/README.md @@ -0,0 +1,28 @@ +# Scalar API + +Design choices: + +### `Scalar` is trait object + +There are three reasons: + +* a scalar should have a small memory footprint, which an enum would not ensure given the different physical types available. +* forward-compatibility: a new entry on an `enum` is backward-incompatible +* do not expose implementation details to users (reduce the surface of the public API) + +### `Scalar` MUST contain nullability information + +This is to be aligned with the general notion of arrow's `Array`. + +This API is a companion to the `Array`, and follows the same design as `Array`. +Specifically, a `Scalar` is a trait object that can be downcasted to concrete implementations. + +Like `Array`, `Scalar` implements + +* `data_type`, which is used to perform the correct downcast +* `is_valid`, to tell whether the scalar is null or not + +### There is one implementation per arrows' physical type + +* Reduces the number of `match` that users need to write +* Allows casting of logical types without changing the underlying physical type diff --git a/src/scalar/binary.rs b/src/scalar/binary.rs new file mode 100644 index 00000000000..fbd769f9d5c --- /dev/null +++ b/src/scalar/binary.rs @@ -0,0 +1,95 @@ +use crate::{array::*, buffer::Buffer, datatypes::DataType}; + +use super::Scalar; + +#[derive(Debug, Clone)] +pub struct BinaryScalar { + value: Buffer, + is_valid: bool, + phantom: std::marker::PhantomData, +} + +impl PartialEq for BinaryScalar { + fn eq(&self, other: &Self) -> bool { + self.is_valid == other.is_valid && ((!self.is_valid) | (self.value == other.value)) + } +} + +impl BinaryScalar { + #[inline] + pub fn new>(v: Option

) -> Self { + let is_valid = v.is_some(); + O::from_usize(v.as_ref().map(|x| x.as_ref().len()).unwrap_or_default()).expect("Too large"); + let value = Buffer::from(v.as_ref().map(|x| x.as_ref()).unwrap_or(&[])); + Self { + value, + is_valid, + phantom: std::marker::PhantomData, + } + } + + #[inline] + pub fn value(&self) -> &[u8] { + self.value.as_slice() + } +} + +impl> From> for BinaryScalar { + #[inline] + fn from(v: Option

) -> Self { + Self::new(v) + } +} + +impl Scalar for BinaryScalar { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn is_valid(&self) -> bool { + self.is_valid + } + + #[inline] + fn data_type(&self) -> &DataType { + if O::is_large() { + &DataType::LargeBinary + } else { + &DataType::Binary + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[allow(clippy::eq_op)] + #[test] + fn equal() { + let a = BinaryScalar::::from(Some("a")); + let b = BinaryScalar::::from(None::<&str>); + assert_eq!(a, a); + assert_eq!(b, b); + assert!(a != b); + let b = BinaryScalar::::from(Some("b")); + assert!(a != b); + assert_eq!(b, b); + } + + #[test] + fn basics() { + let a = BinaryScalar::::from(Some("a")); + + assert_eq!(a.value(), b"a"); + assert_eq!(a.data_type(), &DataType::Binary); + assert!(a.is_valid()); + + let a = BinaryScalar::::from(None::<&str>); + + assert_eq!(a.data_type(), &DataType::LargeBinary); + assert!(!a.is_valid()); + } +} diff --git a/src/scalar/boolean.rs b/src/scalar/boolean.rs new file mode 100644 index 00000000000..67137c752f1 --- /dev/null +++ b/src/scalar/boolean.rs @@ -0,0 +1,82 @@ +use crate::datatypes::DataType; + +use super::Scalar; + +#[derive(Debug, Clone)] +pub struct BooleanScalar { + value: bool, + is_valid: bool, +} + +impl PartialEq for BooleanScalar { + fn eq(&self, other: &Self) -> bool { + self.is_valid == other.is_valid && ((!self.is_valid) | (self.value == other.value)) + } +} + +impl BooleanScalar { + #[inline] + pub fn new(v: Option) -> Self { + let is_valid = v.is_some(); + Self { + value: v.unwrap_or_default(), + is_valid, + } + } + + #[inline] + pub fn value(&self) -> bool { + self.value + } +} + +impl Scalar for BooleanScalar { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn is_valid(&self) -> bool { + self.is_valid + } + + #[inline] + fn data_type(&self) -> &DataType { + &DataType::Boolean + } +} + +impl From> for BooleanScalar { + #[inline] + fn from(v: Option) -> Self { + Self::new(v) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[allow(clippy::eq_op)] + #[test] + fn equal() { + let a = BooleanScalar::from(Some(true)); + let b = BooleanScalar::from(None); + assert_eq!(a, a); + assert_eq!(b, b); + assert!(a != b); + let b = BooleanScalar::from(Some(false)); + assert!(a != b); + assert_eq!(b, b); + } + + #[test] + fn basics() { + let a = BooleanScalar::new(Some(true)); + + assert!(a.value()); + assert_eq!(a.data_type(), &DataType::Boolean); + assert!(a.is_valid()); + } +} diff --git a/src/scalar/equal.rs b/src/scalar/equal.rs new file mode 100644 index 00000000000..503f9fc6856 --- /dev/null +++ b/src/scalar/equal.rs @@ -0,0 +1,115 @@ +use super::*; + +impl PartialEq for dyn Scalar { + fn eq(&self, other: &Self) -> bool { + equal(self, other) + } +} + +macro_rules! dyn_eq { + ($ty:ty, $lhs:expr, $rhs:expr) => {{ + let lhs = $lhs + .as_any() + .downcast_ref::>() + .unwrap(); + let rhs = $rhs + .as_any() + .downcast_ref::>() + .unwrap(); + lhs == rhs + }}; +} + +fn equal(lhs: &dyn Scalar, rhs: &dyn Scalar) -> bool { + if lhs.data_type() != rhs.data_type() { + return false; + } + + match lhs.data_type() { + DataType::Null => { + let lhs = lhs.as_any().downcast_ref::().unwrap(); + let rhs = rhs.as_any().downcast_ref::().unwrap(); + lhs == rhs + } + DataType::Boolean => { + let lhs = lhs.as_any().downcast_ref::().unwrap(); + let rhs = rhs.as_any().downcast_ref::().unwrap(); + lhs == rhs + } + DataType::UInt8 => { + dyn_eq!(u8, lhs, rhs) + } + DataType::UInt16 => { + dyn_eq!(u16, lhs, rhs) + } + DataType::UInt32 => { + dyn_eq!(u32, lhs, rhs) + } + DataType::UInt64 => { + dyn_eq!(u64, lhs, rhs) + } + DataType::Int8 => { + dyn_eq!(i8, lhs, rhs) + } + DataType::Int16 => { + dyn_eq!(i16, lhs, rhs) + } + DataType::Int32 + | DataType::Date32 + | DataType::Time32(_) + | DataType::Interval(IntervalUnit::YearMonth) => { + dyn_eq!(i32, lhs, rhs) + } + DataType::Int64 + | DataType::Date64 + | DataType::Time64(_) + | DataType::Timestamp(_, _) + | DataType::Duration(_) => { + dyn_eq!(i64, lhs, rhs) + } + DataType::Decimal(_, _) => { + dyn_eq!(i128, lhs, rhs) + } + DataType::Interval(IntervalUnit::DayTime) => { + dyn_eq!(days_ms, lhs, rhs) + } + DataType::Float16 => unreachable!(), + DataType::Float32 => { + dyn_eq!(f32, lhs, rhs) + } + DataType::Float64 => { + dyn_eq!(f64, lhs, rhs) + } + DataType::Utf8 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + lhs == rhs + } + DataType::LargeUtf8 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + lhs == rhs + } + DataType::Binary => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + lhs == rhs + } + DataType::LargeBinary => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + lhs == rhs + } + DataType::List(_) => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + lhs == rhs + } + DataType::LargeList(_) => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + lhs == rhs + } + _ => unimplemented!(), + } +} diff --git a/src/scalar/list.rs b/src/scalar/list.rs new file mode 100644 index 00000000000..4623711a5b2 --- /dev/null +++ b/src/scalar/list.rs @@ -0,0 +1,66 @@ +use std::any::Any; +use std::sync::Arc; + +use crate::{array::*, datatypes::DataType}; + +use super::Scalar; + +/// The scalar equivalent of [`ListArray`]. Like [`ListArray`], this struct holds a dynamically-typed +/// [`Array`]. The only difference is that this has only one element. +#[derive(Debug, Clone)] +pub struct ListScalar { + values: Arc, + is_valid: bool, + phantom: std::marker::PhantomData, + data_type: DataType, +} + +impl PartialEq for ListScalar { + fn eq(&self, other: &Self) -> bool { + (self.data_type == other.data_type) + && (self.is_valid == other.is_valid) + && ((!self.is_valid) | (self.values.as_ref() == other.values.as_ref())) + } +} + +impl ListScalar { + /// # Panics + /// iff + /// * the `data_type` is not `List` or `LargeList` (depending on this scalar's offset `O`) + /// * the child of the `data_type` is not equal to the `values` + #[inline] + pub fn new(data_type: DataType, values: Option>) -> Self { + let inner_data_type = ListArray::::get_child_type(&data_type); + let (is_valid, values) = match values { + Some(values) => { + assert_eq!(inner_data_type, values.data_type()); + (true, values) + } + None => (false, new_empty_array(inner_data_type.clone()).into()), + }; + Self { + values, + is_valid, + phantom: std::marker::PhantomData, + data_type, + } + } + + pub fn values(&self) -> &Arc { + &self.values + } +} + +impl Scalar for ListScalar { + fn as_any(&self) -> &dyn Any { + self + } + + fn is_valid(&self) -> bool { + self.is_valid + } + + fn data_type(&self) -> &DataType { + &self.data_type + } +} diff --git a/src/scalar/mod.rs b/src/scalar/mod.rs new file mode 100644 index 00000000000..9ec6417ad66 --- /dev/null +++ b/src/scalar/mod.rs @@ -0,0 +1,136 @@ +use std::any::Any; + +use crate::{array::*, datatypes::*, types::days_ms}; + +mod equal; +mod primitive; +pub use primitive::*; +mod utf8; +pub use utf8::*; +mod binary; +pub use binary::*; +mod boolean; +pub use boolean::*; +mod list; +pub use list::*; +mod null; +pub use null::*; +mod struct_; +pub use struct_::*; + +pub trait Scalar: std::fmt::Debug { + fn as_any(&self) -> &dyn Any; + + fn is_valid(&self) -> bool; + + fn data_type(&self) -> &DataType; +} + +macro_rules! dyn_new { + ($array:expr, $index:expr, $type:ty) => {{ + let array = $array + .as_any() + .downcast_ref::>() + .unwrap(); + let value = if array.is_valid($index) { + Some(array.value($index)) + } else { + None + }; + Box::new(PrimitiveScalar::new(array.data_type().clone(), value)) + }}; +} + +macro_rules! dyn_new_utf8 { + ($array:expr, $index:expr, $type:ty) => {{ + let array = $array.as_any().downcast_ref::>().unwrap(); + let value = if array.is_valid($index) { + Some(array.value($index)) + } else { + None + }; + Box::new(Utf8Scalar::<$type>::new(value)) + }}; +} + +macro_rules! dyn_new_binary { + ($array:expr, $index:expr, $type:ty) => {{ + let array = $array + .as_any() + .downcast_ref::>() + .unwrap(); + let value = if array.is_valid($index) { + Some(array.value($index)) + } else { + None + }; + Box::new(BinaryScalar::<$type>::new(value)) + }}; +} + +macro_rules! dyn_new_list { + ($array:expr, $index:expr, $type:ty) => {{ + let array = $array.as_any().downcast_ref::>().unwrap(); + let value = if array.is_valid($index) { + Some(array.value($index).into()) + } else { + None + }; + Box::new(ListScalar::<$type>::new(array.data_type().clone(), value)) + }}; +} + +/// creates a new [`Scalar`] from an [`Array`]. +pub fn new_scalar(array: &dyn Array, index: usize) -> Box { + use DataType::*; + match array.data_type() { + Null => Box::new(NullScalar::new()), + Boolean => { + let array = array.as_any().downcast_ref::().unwrap(); + let value = if array.is_valid(index) { + Some(array.value(index)) + } else { + None + }; + Box::new(BooleanScalar::new(value)) + } + Int8 => dyn_new!(array, index, i8), + Int16 => dyn_new!(array, index, i16), + Int32 | Date32 | Time32(_) | Interval(IntervalUnit::YearMonth) => { + dyn_new!(array, index, i32) + } + Int64 | Date64 | Time64(_) | Duration(_) | Timestamp(_, _) => dyn_new!(array, index, i64), + Interval(IntervalUnit::DayTime) => dyn_new!(array, index, days_ms), + UInt8 => dyn_new!(array, index, u8), + UInt16 => dyn_new!(array, index, u16), + UInt32 => dyn_new!(array, index, u32), + UInt64 => dyn_new!(array, index, u64), + Decimal(_, _) => dyn_new!(array, index, i128), + Float16 => unreachable!(), + Float32 => dyn_new!(array, index, f32), + Float64 => dyn_new!(array, index, f64), + Utf8 => dyn_new_utf8!(array, index, i32), + LargeUtf8 => dyn_new_utf8!(array, index, i64), + Binary => dyn_new_binary!(array, index, i32), + LargeBinary => dyn_new_binary!(array, index, i64), + List(_) => dyn_new_list!(array, index, i32), + LargeList(_) => dyn_new_list!(array, index, i64), + Struct(_) => { + let array = array.as_any().downcast_ref::().unwrap(); + if array.is_valid(index) { + let values = array + .values() + .iter() + .map(|x| new_scalar(x.as_ref(), index).into()) + .collect(); + Box::new(StructScalar::new(array.data_type().clone(), Some(values))) + } else { + Box::new(StructScalar::new(array.data_type().clone(), None)) + } + } + FixedSizeBinary(_) => todo!(), + FixedSizeList(_, _) => todo!(), + Union(_) => todo!(), + Dictionary(_, _) => todo!(), + } +} diff --git a/src/scalar/null.rs b/src/scalar/null.rs new file mode 100644 index 00000000000..3751c6cfbd6 --- /dev/null +++ b/src/scalar/null.rs @@ -0,0 +1,36 @@ +use crate::datatypes::DataType; + +use super::Scalar; + +#[derive(Debug, Clone, PartialEq)] +pub struct NullScalar {} + +impl NullScalar { + #[inline] + pub fn new() -> Self { + Self {} + } +} + +impl Default for NullScalar { + fn default() -> Self { + Self::new() + } +} + +impl Scalar for NullScalar { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn is_valid(&self) -> bool { + false + } + + #[inline] + fn data_type(&self) -> &DataType { + &DataType::Null + } +} diff --git a/src/scalar/primitive.rs b/src/scalar/primitive.rs new file mode 100644 index 00000000000..b204a774ef8 --- /dev/null +++ b/src/scalar/primitive.rs @@ -0,0 +1,113 @@ +use crate::{ + datatypes::DataType, + types::{NativeType, NaturalDataType}, +}; + +use super::Scalar; + +#[derive(Debug, Clone)] +pub struct PrimitiveScalar { + // Not Option because this offers a stabler pointer offset on the struct + value: T, + is_valid: bool, + data_type: DataType, +} + +impl PartialEq for PrimitiveScalar { + fn eq(&self, other: &Self) -> bool { + self.data_type == other.data_type + && self.is_valid == other.is_valid + && ((!self.is_valid) | (self.value == other.value)) + } +} + +impl PrimitiveScalar { + #[inline] + pub fn new(data_type: DataType, v: Option) -> Self { + let is_valid = v.is_some(); + Self { + value: v.unwrap_or_default(), + is_valid, + data_type, + } + } + + #[inline] + pub fn value(&self) -> T { + self.value + } + + /// Returns a new `PrimitiveScalar` with the same value but different [`DataType`] + /// # Panic + /// This function panics if the `data_type` is not valid for self's physical type `T`. + pub fn to(self, data_type: DataType) -> Self { + let v = if self.is_valid { + Some(self.value) + } else { + None + }; + Self::new(data_type, v) + } +} + +impl From> for PrimitiveScalar { + #[inline] + fn from(v: Option) -> Self { + Self::new(T::DATA_TYPE, v) + } +} + +impl Scalar for PrimitiveScalar { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn is_valid(&self) -> bool { + self.is_valid + } + + #[inline] + fn data_type(&self) -> &DataType { + &self.data_type + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[allow(clippy::eq_op)] + #[test] + fn equal() { + let a = PrimitiveScalar::from(Some(2i32)); + let b = PrimitiveScalar::::from(None); + assert_eq!(a, a); + assert_eq!(b, b); + assert!(a != b); + let b = PrimitiveScalar::::from(Some(1i32)); + assert!(a != b); + assert_eq!(b, b); + } + + #[test] + fn basics() { + let a = PrimitiveScalar::from(Some(2i32)); + + assert_eq!(a.value(), 2i32); + assert_eq!(a.data_type(), &DataType::Int32); + assert!(a.is_valid()); + + let a = a.to(DataType::Date32); + assert_eq!(a.data_type(), &DataType::Date32); + + let a = PrimitiveScalar::::from(None); + + assert_eq!(a.data_type(), &DataType::Int32); + assert!(!a.is_valid()); + + let a = a.to(DataType::Date32); + assert_eq!(a.data_type(), &DataType::Date32); + } +} diff --git a/src/scalar/struct_.rs b/src/scalar/struct_.rs new file mode 100644 index 00000000000..eab4671f1dc --- /dev/null +++ b/src/scalar/struct_.rs @@ -0,0 +1,54 @@ +use std::sync::Arc; + +use crate::datatypes::DataType; + +use super::Scalar; + +#[derive(Debug, Clone)] +pub struct StructScalar { + values: Vec>, + is_valid: bool, + data_type: DataType, +} + +impl PartialEq for StructScalar { + fn eq(&self, other: &Self) -> bool { + (self.data_type == other.data_type) + && (self.is_valid == other.is_valid) + && ((!self.is_valid) | (self.values == other.values)) + } +} + +impl StructScalar { + #[inline] + pub fn new(data_type: DataType, values: Option>>) -> Self { + let is_valid = values.is_some(); + Self { + values: values.unwrap_or_default(), + is_valid, + data_type, + } + } + + #[inline] + pub fn values(&self) -> &[Arc] { + &self.values + } +} + +impl Scalar for StructScalar { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn is_valid(&self) -> bool { + self.is_valid + } + + #[inline] + fn data_type(&self) -> &DataType { + &self.data_type + } +} diff --git a/src/scalar/utf8.rs b/src/scalar/utf8.rs new file mode 100644 index 00000000000..32fe5d7f1bb --- /dev/null +++ b/src/scalar/utf8.rs @@ -0,0 +1,95 @@ +use crate::{array::*, buffer::Buffer, datatypes::DataType}; + +use super::Scalar; + +#[derive(Debug, Clone)] +pub struct Utf8Scalar { + value: Buffer, + is_valid: bool, + phantom: std::marker::PhantomData, +} + +impl PartialEq for Utf8Scalar { + fn eq(&self, other: &Self) -> bool { + self.is_valid == other.is_valid && ((!self.is_valid) | (self.value == other.value)) + } +} + +impl Utf8Scalar { + #[inline] + pub fn new>(v: Option

) -> Self { + let is_valid = v.is_some(); + O::from_usize(v.as_ref().map(|x| x.as_ref().len()).unwrap_or_default()).expect("Too large"); + let value = Buffer::from(v.as_ref().map(|x| x.as_ref().as_bytes()).unwrap_or(&[])); + Self { + value, + is_valid, + phantom: std::marker::PhantomData, + } + } + + #[inline] + pub fn value(&self) -> &str { + unsafe { std::str::from_utf8_unchecked(self.value.as_slice()) } + } +} + +impl> From> for Utf8Scalar { + #[inline] + fn from(v: Option

) -> Self { + Self::new(v) + } +} + +impl Scalar for Utf8Scalar { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn is_valid(&self) -> bool { + self.is_valid + } + + #[inline] + fn data_type(&self) -> &DataType { + if O::is_large() { + &DataType::LargeUtf8 + } else { + &DataType::Utf8 + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[allow(clippy::eq_op)] + #[test] + fn equal() { + let a = Utf8Scalar::::from(Some("a")); + let b = Utf8Scalar::::from(None::<&str>); + assert_eq!(a, a); + assert_eq!(b, b); + assert!(a != b); + let b = Utf8Scalar::::from(Some("b")); + assert!(a != b); + assert_eq!(b, b); + } + + #[test] + fn basics() { + let a = Utf8Scalar::::from(Some("a")); + + assert_eq!(a.value(), "a"); + assert_eq!(a.data_type(), &DataType::Utf8); + assert!(a.is_valid()); + + let a = Utf8Scalar::::from(None::<&str>); + + assert_eq!(a.data_type(), &DataType::LargeUtf8); + assert!(!a.is_valid()); + } +}