Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Added Scalar API #56

Merged
merged 3 commits into from
Aug 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions src/compute/aggregate/min_max.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::bitmap::utils::{BitChunkIterExact, BitChunksExact};
use crate::datatypes::{DataType, IntervalUnit};
use crate::error::{ArrowError, Result};
use crate::scalar::*;
use crate::types::simd::*;
use crate::types::NativeType;
use crate::{
Expand Down Expand Up @@ -281,6 +284,94 @@ pub fn max_boolean(array: &BooleanArray) -> Option<bool> {
.or(Some(false))
}

// Applies the generic aggregation `$f` to `$array` viewed as a
// `PrimitiveArray<$ty>` and boxes the outcome as a `PrimitiveScalar<$ty>`
// carrying a clone of the array's own data type.
//
// Panics (via `unwrap`) if `$array` is not actually a `PrimitiveArray<$ty>`.
// Note that `$array` is expanded twice, so it should be a cheap expression.
macro_rules! dyn_primitive {
    ($ty:ty, $array:expr, $f:ident) => {{
        let typed = $array
            .as_any()
            .downcast_ref::<PrimitiveArray<$ty>>()
            .unwrap();
        let logical_type = $array.data_type().clone();
        let aggregated = $f::<$ty>(typed);
        Box::new(PrimitiveScalar::<$ty>::new(logical_type, aggregated))
    }};
}

// Applies the aggregation `$f` to `$array` viewed as `$array_ty` and boxes
// the outcome as a `$scalar_ty` built through its `new` constructor.
//
// Panics (via `unwrap`) if `$array` is not actually an `$array_ty`.
macro_rules! dyn_generic {
    ($array_ty:ty, $scalar_ty:ty, $array:expr, $f:ident) => {{
        let typed = $array.as_any().downcast_ref::<$array_ty>().unwrap();
        let aggregated = $f(typed);
        Box::new(<$scalar_ty>::new(aggregated))
    }};
}

/// Returns the maximum of `array` as a boxed [`Scalar`].
///
/// The array is dispatched on its [`DataType`] to the matching typed kernel
/// (`max_boolean`, `max_primitive` or `max_string`). For primitive-backed
/// types the resulting scalar carries a clone of `array`'s data type, so
/// logical types such as `Date32` are preserved.
///
/// # Errors
/// Returns [`ArrowError::InvalidArgumentError`] when the data type has no
/// `max` kernel. This includes `Float16`, which previously hit an
/// `unreachable!()` and panicked instead of reporting the error.
///
/// # Panics
/// Panics if the concrete type behind `array` does not match the physical
/// type implied by `array.data_type()` (the downcast is unwrapped).
pub fn max(array: &dyn Array) -> Result<Box<dyn Scalar>> {
    Ok(match array.data_type() {
        DataType::Boolean => dyn_generic!(BooleanArray, BooleanScalar, array, max_boolean),
        DataType::Int8 => dyn_primitive!(i8, array, max_primitive),
        DataType::Int16 => dyn_primitive!(i16, array, max_primitive),
        // Logical types physically stored as `i32`.
        DataType::Int32
        | DataType::Date32
        | DataType::Time32(_)
        | DataType::Interval(IntervalUnit::YearMonth) => {
            dyn_primitive!(i32, array, max_primitive)
        }
        // Logical types physically stored as `i64`.
        DataType::Int64
        | DataType::Date64
        | DataType::Time64(_)
        | DataType::Timestamp(_, _)
        | DataType::Duration(_) => dyn_primitive!(i64, array, max_primitive),
        DataType::UInt8 => dyn_primitive!(u8, array, max_primitive),
        DataType::UInt16 => dyn_primitive!(u16, array, max_primitive),
        DataType::UInt32 => dyn_primitive!(u32, array, max_primitive),
        DataType::UInt64 => dyn_primitive!(u64, array, max_primitive),
        DataType::Float32 => dyn_primitive!(f32, array, max_primitive),
        DataType::Float64 => dyn_primitive!(f64, array, max_primitive),
        DataType::Utf8 => dyn_generic!(Utf8Array<i32>, Utf8Scalar<i32>, array, max_string),
        DataType::LargeUtf8 => dyn_generic!(Utf8Array<i64>, Utf8Scalar<i64>, array, max_string),
        // Everything else (including `Float16`) is reported as unsupported.
        _ => {
            return Err(ArrowError::InvalidArgumentError(format!(
                "The `max` operator does not support type `{}`",
                array.data_type(),
            )))
        }
    })
}

/// Returns the minimum of `array` as a boxed [`Scalar`].
///
/// The array is dispatched on its [`DataType`] to the matching typed kernel
/// (`min_boolean`, `min_primitive` or `min_string`). For primitive-backed
/// types the resulting scalar carries a clone of `array`'s data type, so
/// logical types such as `Date32` are preserved.
///
/// # Errors
/// Returns [`ArrowError::InvalidArgumentError`] when the data type has no
/// `min` kernel. This includes `Float16`, which previously hit an
/// `unreachable!()` and panicked instead of reporting the error.
///
/// # Panics
/// Panics if the concrete type behind `array` does not match the physical
/// type implied by `array.data_type()` (the downcast is unwrapped).
pub fn min(array: &dyn Array) -> Result<Box<dyn Scalar>> {
    Ok(match array.data_type() {
        DataType::Boolean => dyn_generic!(BooleanArray, BooleanScalar, array, min_boolean),
        DataType::Int8 => dyn_primitive!(i8, array, min_primitive),
        DataType::Int16 => dyn_primitive!(i16, array, min_primitive),
        // Logical types physically stored as `i32`.
        DataType::Int32
        | DataType::Date32
        | DataType::Time32(_)
        | DataType::Interval(IntervalUnit::YearMonth) => {
            dyn_primitive!(i32, array, min_primitive)
        }
        // Logical types physically stored as `i64`.
        DataType::Int64
        | DataType::Date64
        | DataType::Time64(_)
        | DataType::Timestamp(_, _)
        | DataType::Duration(_) => dyn_primitive!(i64, array, min_primitive),
        DataType::UInt8 => dyn_primitive!(u8, array, min_primitive),
        DataType::UInt16 => dyn_primitive!(u16, array, min_primitive),
        DataType::UInt32 => dyn_primitive!(u32, array, min_primitive),
        DataType::UInt64 => dyn_primitive!(u64, array, min_primitive),
        DataType::Float32 => dyn_primitive!(f32, array, min_primitive),
        DataType::Float64 => dyn_primitive!(f64, array, min_primitive),
        DataType::Utf8 => dyn_generic!(Utf8Array<i32>, Utf8Scalar<i32>, array, min_string),
        DataType::LargeUtf8 => dyn_generic!(Utf8Array<i64>, Utf8Scalar<i64>, array, min_string),
        // Everything else (including `Float16`) is reported as unsupported.
        _ => {
            // The message names `min`; it used to be a copy-paste of the `max` text.
            return Err(ArrowError::InvalidArgumentError(format!(
                "The `min` operator does not support type `{}`",
                array.data_type(),
            )))
        }
    })
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
97 changes: 91 additions & 6 deletions src/compute/aggregate/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ use std::ops::Add;
use multiversion::multiversion;

use crate::bitmap::utils::{BitChunkIterExact, BitChunksExact};
use crate::datatypes::{DataType, IntervalUnit};
use crate::error::{ArrowError, Result};
use crate::scalar::*;
use crate::types::simd::*;
use crate::types::NativeType;
use crate::{
Expand Down Expand Up @@ -85,7 +88,7 @@ where
/// Returns the sum of values in the array.
///
/// Returns `None` if the array is empty or only contains null values.
pub fn sum<T>(array: &PrimitiveArray<T>) -> Option<T>
pub fn sum_primitive<T>(array: &PrimitiveArray<T>) -> Option<T>
where
T: NativeType + Simd,
T::Simd: Add<Output = T::Simd> + Sum<T>,
Expand All @@ -102,6 +105,76 @@ where
}
}

// Sums `$array` viewed as a `PrimitiveArray<$ty>` with `sum_primitive` and
// boxes the outcome as a `PrimitiveScalar<$ty>` carrying a clone of the
// array's own data type.
//
// Panics (via `unwrap`) if `$array` is not actually a `PrimitiveArray<$ty>`.
// Note that `$array` is expanded twice, so it should be a cheap expression.
macro_rules! dyn_sum {
    ($ty:ty, $array:expr) => {{
        let typed = $array
            .as_any()
            .downcast_ref::<PrimitiveArray<$ty>>()
            .unwrap();
        let logical_type = $array.data_type().clone();
        let total = sum_primitive::<$ty>(typed);
        Box::new(PrimitiveScalar::<$ty>::new(logical_type, total))
    }};
}

/// Returns whether [`sum`] has a kernel for arrays of `data_type`.
///
/// The accepted set mirrors the non-error arms of [`sum`]: every signed and
/// unsigned integer width, `Float32`/`Float64`, and the logical types that
/// are physically stored as `i32` or `i64`.
pub fn can_sum(data_type: &DataType) -> bool {
    use DataType::*;
    matches!(
        data_type,
        Int8 | Int16
            // `Int32` was missing here even though `sum` handles it.
            | Int32
            | Date32
            | Time32(_)
            | Interval(IntervalUnit::YearMonth)
            | Int64
            | Date64
            | Time64(_)
            | Timestamp(_, _)
            | Duration(_)
            | UInt8
            | UInt16
            | UInt32
            | UInt64
            | Float32
            | Float64
    )
}

/// Returns the sum of all elements in `array` as a [`Scalar`] of the same physical
/// and logical types as `array`.
///
/// # Errors
/// Errors iff the operation is not supported for `array`'s data type, i.e.
/// whenever [`can_sum`] does not list it. This includes `Float16`, which
/// previously hit an `unreachable!()` and panicked instead of erroring.
///
/// # Panics
/// Panics if the concrete type behind `array` does not match the physical
/// type implied by `array.data_type()` (the downcast is unwrapped).
pub fn sum(array: &dyn Array) -> Result<Box<dyn Scalar>> {
    Ok(match array.data_type() {
        DataType::Int8 => dyn_sum!(i8, array),
        DataType::Int16 => dyn_sum!(i16, array),
        // Logical types physically stored as `i32`.
        DataType::Int32
        | DataType::Date32
        | DataType::Time32(_)
        | DataType::Interval(IntervalUnit::YearMonth) => {
            dyn_sum!(i32, array)
        }
        // Logical types physically stored as `i64`.
        DataType::Int64
        | DataType::Date64
        | DataType::Time64(_)
        | DataType::Timestamp(_, _)
        | DataType::Duration(_) => dyn_sum!(i64, array),
        DataType::UInt8 => dyn_sum!(u8, array),
        DataType::UInt16 => dyn_sum!(u16, array),
        DataType::UInt32 => dyn_sum!(u32, array),
        DataType::UInt64 => dyn_sum!(u64, array),
        DataType::Float32 => dyn_sum!(f32, array),
        DataType::Float64 => dyn_sum!(f64, array),
        // Everything else (including `Float16`) is reported as unsupported.
        _ => {
            return Err(ArrowError::InvalidArgumentError(format!(
                "The `sum` operator does not support type `{}`",
                array.data_type(),
            )))
        }
    })
}

#[cfg(test)]
mod tests {
use super::super::super::arithmetics;
Expand All @@ -111,25 +184,34 @@ mod tests {
#[test]
fn test_primitive_array_sum() {
let a = Int32Array::from_slice(&[1, 2, 3, 4, 5]);
assert_eq!(15, sum(&a).unwrap());
assert_eq!(
&PrimitiveScalar::<i32>::from(Some(15)) as &dyn Scalar,
sum(&a).unwrap().as_ref()
);

let a = a.to(DataType::Date32);
assert_eq!(
&PrimitiveScalar::<i32>::from(Some(15)).to(DataType::Date32) as &dyn Scalar,
sum(&a).unwrap().as_ref()
);
}

#[test]
fn test_primitive_array_float_sum() {
let a = Float64Array::from_slice(&[1.1f64, 2.2, 3.3, 4.4, 5.5]);
assert!((16.5 - sum(&a).unwrap()).abs() < f64::EPSILON);
assert!((16.5 - sum_primitive(&a).unwrap()).abs() < f64::EPSILON);
}

#[test]
fn test_primitive_array_sum_with_nulls() {
let a = Int32Array::from(&[None, Some(2), Some(3), None, Some(5)]);
assert_eq!(10, sum(&a).unwrap());
assert_eq!(10, sum_primitive(&a).unwrap());
}

#[test]
fn test_primitive_array_sum_all_nulls() {
let a = Int32Array::from(&[None, None, None]);
assert_eq!(None, sum(&a));
assert_eq!(None, sum_primitive(&a));
}

#[test]
Expand All @@ -142,6 +224,9 @@ mod tests {
.collect();
// create an array that actually has non-zero values at the invalid indices
let c = arithmetics::basic::add::add(&a, &b).unwrap();
assert_eq!(Some((1..=100).filter(|i| i % 3 == 0).sum()), sum(&c));
assert_eq!(
Some((1..=100).filter(|i| i % 3 == 0).sum()),
sum_primitive(&c)
);
}
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pub mod bitmap;
pub mod buffer;
mod endianess;
pub mod error;
pub mod scalar;
pub mod trusted_len;
pub mod types;

Expand Down
28 changes: 28 additions & 0 deletions src/scalar/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Scalar API

Design choices:

### `Scalar` is a trait object

There are three reasons:

* a scalar should have a small memory footprint, which an enum would not ensure given the different physical types available.
* forward-compatibility: a new entry on an `enum` is backward-incompatible
* do not expose implementation details to users (reduce the surface of the public API)

### `Scalar` MUST contain nullability information

This is to be aligned with the general notion of arrow's `Array`.

This API is a companion to the `Array`, and follows the same design as `Array`.
Specifically, a `Scalar` is a trait object that can be downcast to concrete implementations.

Like `Array`, `Scalar` implements

* `data_type`, which is used to perform the correct downcast
* `is_valid`, to tell whether the scalar is null or not

### There is one implementation per Arrow physical type

* Reduces the number of `match` that users need to write
* Allows casting of logical types without changing the underlying physical type
95 changes: 95 additions & 0 deletions src/scalar/binary.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
use crate::{array::*, buffer::Buffer, datatypes::DataType};

use super::Scalar;

/// A single, possibly-null binary value.
///
/// The offset type `O` determines the reported data type: `data_type`
/// returns `LargeBinary` when `O::is_large()` and `Binary` otherwise.
#[derive(Debug, Clone)]
pub struct BinaryScalar<O: Offset> {
    // The bytes of the value; `new` stores an empty buffer when given `None`.
    value: Buffer<u8>,
    // `true` when the scalar was built from `Some(..)`, `false` for null.
    is_valid: bool,
    // Zero-sized marker that ties the scalar to the offset type `O`.
    phantom: std::marker::PhantomData<O>,
}

impl<O: Offset> PartialEq for BinaryScalar<O> {
    /// Two scalars are equal when both are null, or when both are valid and
    /// hold the same bytes. A null scalar never equals a valid one.
    fn eq(&self, other: &Self) -> bool {
        match (self.is_valid, other.is_valid) {
            (true, true) => self.value == other.value,
            (false, false) => true,
            _ => false,
        }
    }
}

impl<O: Offset> BinaryScalar<O> {
    /// Creates a scalar from optional bytes; `None` produces a null scalar
    /// backed by an empty buffer.
    ///
    /// # Panics
    /// Panics with "Too large" when the byte length cannot be represented by
    /// the offset type `O`.
    #[inline]
    pub fn new<P: AsRef<[u8]>>(v: Option<P>) -> Self {
        let is_valid = v.is_some();
        let bytes: &[u8] = match v.as_ref() {
            Some(payload) => payload.as_ref(),
            None => &[],
        };
        // Only the representability check matters; the converted offset is discarded.
        O::from_usize(bytes.len()).expect("Too large");
        Self {
            value: Buffer::from(bytes),
            is_valid,
            phantom: std::marker::PhantomData,
        }
    }

    /// Returns the stored bytes (empty for a scalar created from `None`).
    #[inline]
    pub fn value(&self) -> &[u8] {
        let Self { value, .. } = self;
        value.as_slice()
    }
}

impl<O: Offset, P: AsRef<[u8]>> From<Option<P>> for BinaryScalar<O> {
#[inline]
fn from(v: Option<P>) -> Self {
Self::new(v)
}
}

impl<O: Offset> Scalar for BinaryScalar<O> {
    /// Exposes the scalar as `Any` so callers can downcast to the concrete type.
    #[inline]
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Returns `false` when the scalar is null.
    #[inline]
    fn is_valid(&self) -> bool {
        self.is_valid
    }

    /// Returns `LargeBinary` for large offsets and `Binary` otherwise.
    #[inline]
    fn data_type(&self) -> &DataType {
        match O::is_large() {
            true => &DataType::LargeBinary,
            false => &DataType::Binary,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // Comparing a value with itself is intentional: it exercises `PartialEq`.
    #[allow(clippy::eq_op)]
    #[test]
    fn equal() {
        let valid_a = BinaryScalar::<i32>::from(Some("a"));
        let null = BinaryScalar::<i32>::from(None::<&str>);
        assert_eq!(valid_a, valid_a);
        assert_eq!(null, null);
        assert_ne!(valid_a, null);

        let valid_b = BinaryScalar::<i32>::from(Some("b"));
        assert_ne!(valid_a, valid_b);
        assert_eq!(valid_b, valid_b);
    }

    #[test]
    fn basics() {
        // A valid scalar with 32-bit offsets reports `Binary`.
        let small = BinaryScalar::<i32>::from(Some("a"));
        assert_eq!(small.value(), b"a");
        assert_eq!(small.data_type(), &DataType::Binary);
        assert!(small.is_valid());

        // A null scalar with 64-bit offsets reports `LargeBinary`.
        let large_null = BinaryScalar::<i64>::from(None::<&str>);
        assert_eq!(large_null.data_type(), &DataType::LargeBinary);
        assert!(!large_null.is_valid());
    }
}
Loading