Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Added Scalar API (#56)
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgecarleitao authored Aug 11, 2021
1 parent eb7b34f commit 4b987fc
Show file tree
Hide file tree
Showing 13 changed files with 1,003 additions and 6 deletions.
91 changes: 91 additions & 0 deletions src/compute/aggregate/min_max.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::bitmap::utils::{BitChunkIterExact, BitChunksExact};
use crate::datatypes::{DataType, IntervalUnit};
use crate::error::{ArrowError, Result};
use crate::scalar::*;
use crate::types::simd::*;
use crate::types::NativeType;
use crate::{
Expand Down Expand Up @@ -281,6 +284,94 @@ pub fn max_boolean(array: &BooleanArray) -> Option<bool> {
.or(Some(false))
}

// Applies the generic aggregation `$f` to `$array` viewed as a `PrimitiveArray<$ty>`
// and boxes the outcome as a `PrimitiveScalar<$ty>` that carries a clone of the
// array's data type (so logical types such as `Date32` are kept).
//
// Panics (through `unwrap`) when `$array` is not a `PrimitiveArray<$ty>`.
macro_rules! dyn_primitive {
    ($ty:ty, $array:expr, $f:ident) => {{
        let typed = $array
            .as_any()
            .downcast_ref::<PrimitiveArray<$ty>>()
            .unwrap();
        let data_type = $array.data_type().clone();
        let value = $f::<$ty>(typed);
        Box::new(PrimitiveScalar::<$ty>::new(data_type, value))
    }};
}

// Applies the aggregation `$f` to `$array` viewed as the concrete `$array_ty`
// and boxes the outcome as a `$scalar_ty`.
//
// Panics (through `unwrap`) when `$array` is not an `$array_ty`.
macro_rules! dyn_generic {
    ($array_ty:ty, $scalar_ty:ty, $array:expr, $f:ident) => {{
        let concrete = $array.as_any().downcast_ref::<$array_ty>().unwrap();
        let value = $f(concrete);
        Box::new(<$scalar_ty>::new(value))
    }};
}

/// Computes the maximum of `array` and returns it as a boxed [`Scalar`].
///
/// The kernel is chosen from `array.data_type()`; logical types that share a
/// physical representation (e.g. `Date32` and `Int32`) use the same kernel.
///
/// # Errors
/// Returns [`ArrowError::InvalidArgumentError`] for data types without an arm below.
///
/// # Panics
/// Panics on `DataType::Float16` and when the array's concrete type does not
/// correspond to its data type (the downcast is unwrapped).
pub fn max(array: &dyn Array) -> Result<Box<dyn Scalar>> {
    let data_type = array.data_type();

    let maximum: Box<dyn Scalar> = match data_type {
        // non-primitive kernels
        DataType::Boolean => dyn_generic!(BooleanArray, BooleanScalar, array, max_boolean),
        DataType::Utf8 => dyn_generic!(Utf8Array<i32>, Utf8Scalar<i32>, array, max_string),
        DataType::LargeUtf8 => dyn_generic!(Utf8Array<i64>, Utf8Scalar<i64>, array, max_string),

        // signed integers and the logical types stored as them
        DataType::Int8 => dyn_primitive!(i8, array, max_primitive),
        DataType::Int16 => dyn_primitive!(i16, array, max_primitive),
        DataType::Int32
        | DataType::Date32
        | DataType::Time32(_)
        | DataType::Interval(IntervalUnit::YearMonth) => {
            dyn_primitive!(i32, array, max_primitive)
        }
        DataType::Int64
        | DataType::Date64
        | DataType::Time64(_)
        | DataType::Timestamp(_, _)
        | DataType::Duration(_) => dyn_primitive!(i64, array, max_primitive),

        // unsigned integers
        DataType::UInt8 => dyn_primitive!(u8, array, max_primitive),
        DataType::UInt16 => dyn_primitive!(u16, array, max_primitive),
        DataType::UInt32 => dyn_primitive!(u32, array, max_primitive),
        DataType::UInt64 => dyn_primitive!(u64, array, max_primitive),

        // floats
        DataType::Float16 => unreachable!(),
        DataType::Float32 => dyn_primitive!(f32, array, max_primitive),
        DataType::Float64 => dyn_primitive!(f64, array, max_primitive),

        unsupported => {
            return Err(ArrowError::InvalidArgumentError(format!(
                "The `max` operator does not support type `{}`",
                unsupported,
            )))
        }
    };

    Ok(maximum)
}

/// Computes the minimum of `array` and returns it as a boxed [`Scalar`].
///
/// The kernel is chosen from `array.data_type()`; logical types that share a
/// physical representation (e.g. `Date32` and `Int32`) use the same kernel.
///
/// # Errors
/// Returns [`ArrowError::InvalidArgumentError`] for data types without an arm below.
///
/// # Panics
/// Panics on `DataType::Float16` and when the array's concrete type does not
/// correspond to its data type (the downcast is unwrapped).
pub fn min(array: &dyn Array) -> Result<Box<dyn Scalar>> {
    Ok(match array.data_type() {
        DataType::Boolean => dyn_generic!(BooleanArray, BooleanScalar, array, min_boolean),
        DataType::Int8 => dyn_primitive!(i8, array, min_primitive),
        DataType::Int16 => dyn_primitive!(i16, array, min_primitive),
        DataType::Int32
        | DataType::Date32
        | DataType::Time32(_)
        | DataType::Interval(IntervalUnit::YearMonth) => {
            dyn_primitive!(i32, array, min_primitive)
        }
        DataType::Int64
        | DataType::Date64
        | DataType::Time64(_)
        | DataType::Timestamp(_, _)
        | DataType::Duration(_) => dyn_primitive!(i64, array, min_primitive),
        DataType::UInt8 => dyn_primitive!(u8, array, min_primitive),
        DataType::UInt16 => dyn_primitive!(u16, array, min_primitive),
        DataType::UInt32 => dyn_primitive!(u32, array, min_primitive),
        DataType::UInt64 => dyn_primitive!(u64, array, min_primitive),
        DataType::Float16 => unreachable!(),
        DataType::Float32 => dyn_primitive!(f32, array, min_primitive),
        DataType::Float64 => dyn_primitive!(f64, array, min_primitive),
        DataType::Utf8 => dyn_generic!(Utf8Array<i32>, Utf8Scalar<i32>, array, min_string),
        DataType::LargeUtf8 => dyn_generic!(Utf8Array<i64>, Utf8Scalar<i64>, array, min_string),
        _ => {
            // The message names this operator; it previously said `max` (copy-paste slip).
            return Err(ArrowError::InvalidArgumentError(format!(
                "The `min` operator does not support type `{}`",
                array.data_type(),
            )))
        }
    })
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
97 changes: 91 additions & 6 deletions src/compute/aggregate/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ use std::ops::Add;
use multiversion::multiversion;

use crate::bitmap::utils::{BitChunkIterExact, BitChunksExact};
use crate::datatypes::{DataType, IntervalUnit};
use crate::error::{ArrowError, Result};
use crate::scalar::*;
use crate::types::simd::*;
use crate::types::NativeType;
use crate::{
Expand Down Expand Up @@ -85,7 +88,7 @@ where
/// Returns the sum of values in the array.
///
/// Returns `None` if the array is empty or only contains null values.
pub fn sum<T>(array: &PrimitiveArray<T>) -> Option<T>
pub fn sum_primitive<T>(array: &PrimitiveArray<T>) -> Option<T>
where
T: NativeType + Simd,
T::Simd: Add<Output = T::Simd> + Sum<T>,
Expand All @@ -102,6 +105,76 @@ where
}
}

// Sums `$array` viewed as a `PrimitiveArray<$ty>` with `sum_primitive` and boxes the
// outcome as a `PrimitiveScalar<$ty>` that carries a clone of the array's data type.
//
// Panics (through `unwrap`) when `$array` is not a `PrimitiveArray<$ty>`.
macro_rules! dyn_sum {
    ($ty:ty, $array:expr) => {{
        let typed = $array
            .as_any()
            .downcast_ref::<PrimitiveArray<$ty>>()
            .unwrap();
        let data_type = $array.data_type().clone();
        let total = sum_primitive::<$ty>(typed);
        Box::new(PrimitiveScalar::<$ty>::new(data_type, total))
    }};
}

/// Returns whether [`sum`] supports arrays of the given `data_type`.
///
/// The accepted set mirrors the arms of [`sum`]: every data type that `sum`
/// dispatches to a primitive kernel is reported as summable here.
pub fn can_sum(data_type: &DataType) -> bool {
    use DataType::*;
    matches!(
        data_type,
        Int8 | Int16
            // `Int32` was missing although `sum` handles it alongside `Date32`/`Time32`.
            | Int32
            | Date32
            | Time32(_)
            | Interval(IntervalUnit::YearMonth)
            | Int64
            | Date64
            | Time64(_)
            | Timestamp(_, _)
            | Duration(_)
            | UInt8
            | UInt16
            | UInt32
            | UInt64
            | Float32
            | Float64
    )
}

/// Adds up every element of `array` and returns the total as a boxed [`Scalar`].
///
/// The returned scalar is a `PrimitiveScalar` of the array's physical type that
/// carries a clone of the array's data type, so logical types are kept.
///
/// # Errors
/// Returns [`ArrowError::InvalidArgumentError`] for data types without an arm below.
///
/// # Panics
/// Panics on `DataType::Float16` and when the array's concrete type does not
/// correspond to its data type (the downcast is unwrapped).
pub fn sum(array: &dyn Array) -> Result<Box<dyn Scalar>> {
    let data_type = array.data_type();

    let total: Box<dyn Scalar> = match data_type {
        // signed integers and the logical types stored as them
        DataType::Int8 => dyn_sum!(i8, array),
        DataType::Int16 => dyn_sum!(i16, array),
        DataType::Int32
        | DataType::Date32
        | DataType::Time32(_)
        | DataType::Interval(IntervalUnit::YearMonth) => {
            dyn_sum!(i32, array)
        }
        DataType::Int64
        | DataType::Date64
        | DataType::Time64(_)
        | DataType::Timestamp(_, _)
        | DataType::Duration(_) => dyn_sum!(i64, array),

        // unsigned integers
        DataType::UInt8 => dyn_sum!(u8, array),
        DataType::UInt16 => dyn_sum!(u16, array),
        DataType::UInt32 => dyn_sum!(u32, array),
        DataType::UInt64 => dyn_sum!(u64, array),

        // floats
        DataType::Float16 => unreachable!(),
        DataType::Float32 => dyn_sum!(f32, array),
        DataType::Float64 => dyn_sum!(f64, array),

        unsupported => {
            return Err(ArrowError::InvalidArgumentError(format!(
                "The `sum` operator does not support type `{}`",
                unsupported,
            )))
        }
    };

    Ok(total)
}

#[cfg(test)]
mod tests {
use super::super::super::arithmetics;
Expand All @@ -111,25 +184,34 @@ mod tests {
#[test]
fn test_primitive_array_sum() {
let a = Int32Array::from_slice(&[1, 2, 3, 4, 5]);
assert_eq!(15, sum(&a).unwrap());
assert_eq!(
&PrimitiveScalar::<i32>::from(Some(15)) as &dyn Scalar,
sum(&a).unwrap().as_ref()
);

let a = a.to(DataType::Date32);
assert_eq!(
&PrimitiveScalar::<i32>::from(Some(15)).to(DataType::Date32) as &dyn Scalar,
sum(&a).unwrap().as_ref()
);
}

#[test]
fn test_primitive_array_float_sum() {
let a = Float64Array::from_slice(&[1.1f64, 2.2, 3.3, 4.4, 5.5]);
assert!((16.5 - sum(&a).unwrap()).abs() < f64::EPSILON);
assert!((16.5 - sum_primitive(&a).unwrap()).abs() < f64::EPSILON);
}

#[test]
fn test_primitive_array_sum_with_nulls() {
let a = Int32Array::from(&[None, Some(2), Some(3), None, Some(5)]);
assert_eq!(10, sum(&a).unwrap());
assert_eq!(10, sum_primitive(&a).unwrap());
}

#[test]
fn test_primitive_array_sum_all_nulls() {
let a = Int32Array::from(&[None, None, None]);
assert_eq!(None, sum(&a));
assert_eq!(None, sum_primitive(&a));
}

#[test]
Expand All @@ -142,6 +224,9 @@ mod tests {
.collect();
// create an array that actually has non-zero values at the invalid indices
let c = arithmetics::basic::add::add(&a, &b).unwrap();
assert_eq!(Some((1..=100).filter(|i| i % 3 == 0).sum()), sum(&c));
assert_eq!(
Some((1..=100).filter(|i| i % 3 == 0).sum()),
sum_primitive(&c)
);
}
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pub mod bitmap;
pub mod buffer;
mod endianess;
pub mod error;
pub mod scalar;
pub mod trusted_len;
pub mod types;

Expand Down
28 changes: 28 additions & 0 deletions src/scalar/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Scalar API

Design choices:

### `Scalar` is a trait object

There are three reasons:

* a scalar should have a small memory footprint, which an enum would not ensure given the different physical types available.
* forward-compatibility: a new entry on an `enum` is backward-incompatible
* do not expose implementation details to users (reduce the surface of the public API)

### `Scalar` MUST contain nullability information

This is to be aligned with the general notion of arrow's `Array`.

This API is a companion to the `Array`, and follows the same design as `Array`.
Specifically, a `Scalar` is a trait object that can be downcast to concrete implementations.

Like `Array`, `Scalar` implements

* `data_type`, which is used to perform the correct downcast
* `is_valid`, to tell whether the scalar is null or not

### There is one implementation per Arrow's physical type

* Reduces the number of `match` that users need to write
* Allows casting of logical types without changing the underlying physical type
95 changes: 95 additions & 0 deletions src/scalar/binary.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
use crate::{array::*, buffer::Buffer, datatypes::DataType};

use super::Scalar;

/// A single, possibly-null binary value.
///
/// The offset type `O` selects the reported data type: `LargeBinary` when
/// `O::is_large()` is true, `Binary` otherwise.
#[derive(Debug, Clone)]
pub struct BinaryScalar<O: Offset> {
    // the bytes of the value; `new` stores an empty buffer when built from `None`
    value: Buffer<u8>,
    // `false` when the scalar is null
    is_valid: bool,
    // ties the scalar to its offset type without storing an `O`
    phantom: std::marker::PhantomData<O>,
}

impl<O: Offset> PartialEq for BinaryScalar<O> {
    /// Two scalars are equal when both are null, or when both are valid and hold
    /// the same bytes. A null scalar never equals a valid one.
    fn eq(&self, other: &Self) -> bool {
        match (self.is_valid, other.is_valid) {
            (true, true) => self.value == other.value,
            (false, false) => true,
            _ => false,
        }
    }
}

impl<O: Offset> BinaryScalar<O> {
    /// Creates a scalar from optional bytes; `None` yields a null scalar whose
    /// stored value is empty.
    ///
    /// # Panics
    /// Panics with "Too large" when the length of the bytes does not fit in `O`.
    #[inline]
    pub fn new<P: AsRef<[u8]>>(v: Option<P>) -> Self {
        let bytes: &[u8] = match v.as_ref() {
            Some(payload) => payload.as_ref(),
            None => &[],
        };
        // only the check matters here: the converted offset itself is not kept
        O::from_usize(bytes.len()).expect("Too large");
        Self {
            value: Buffer::from(bytes),
            is_valid: v.is_some(),
            phantom: std::marker::PhantomData,
        }
    }

    /// Returns the stored bytes (empty when the scalar was built from `None`).
    #[inline]
    pub fn value(&self) -> &[u8] {
        let bytes = &self.value;
        bytes.as_slice()
    }
}

impl<O: Offset, P: AsRef<[u8]>> From<Option<P>> for BinaryScalar<O> {
#[inline]
fn from(v: Option<P>) -> Self {
Self::new(v)
}
}

impl<O: Offset> Scalar for BinaryScalar<O> {
    /// Exposes the scalar as [`std::any::Any`] so it can be downcast.
    #[inline]
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Whether the scalar holds a value (`false` means null).
    #[inline]
    fn is_valid(&self) -> bool {
        self.is_valid
    }

    /// `LargeBinary` for large offsets, `Binary` otherwise.
    #[inline]
    fn data_type(&self) -> &DataType {
        if !O::is_large() {
            return &DataType::Binary;
        }
        &DataType::LargeBinary
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // Comparing a value with itself is the point of this test.
    #[allow(clippy::eq_op)]
    #[test]
    fn equal() {
        let valid_a = BinaryScalar::<i32>::from(Some("a"));
        let null = BinaryScalar::<i32>::from(None::<&str>);
        assert_eq!(valid_a, valid_a);
        assert_eq!(null, null);
        // valid vs null
        assert_ne!(valid_a, null);

        let valid_b = BinaryScalar::<i32>::from(Some("b"));
        // valid vs valid with different bytes
        assert_ne!(valid_a, valid_b);
        assert_eq!(valid_b, valid_b);
    }

    #[test]
    fn basics() {
        let valid = BinaryScalar::<i32>::from(Some("a"));

        assert_eq!(valid.value(), b"a");
        assert_eq!(valid.data_type(), &DataType::Binary);
        assert!(valid.is_valid());

        let large_null = BinaryScalar::<i64>::from(None::<&str>);

        assert_eq!(large_null.data_type(), &DataType::LargeBinary);
        assert!(!large_null.is_valid());
    }
}
Loading

0 comments on commit 4b987fc

Please sign in to comment.