Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Added remaining scalars and improved API.
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgecarleitao committed Jul 31, 2021
1 parent d086e7d commit 2ca4b16
Show file tree
Hide file tree
Showing 12 changed files with 619 additions and 221 deletions.
39 changes: 37 additions & 2 deletions src/compute/aggregate/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,32 @@ macro_rules! dyn_sum {
}};
}

pub fn can_sum(data_type: &DataType) -> bool {
use DataType::*;
matches!(
data_type,
Int8 | Int16
| Date32
| Time32(_)
| Interval(IntervalUnit::YearMonth)
| Int64
| Date64
| Time64(_)
| Timestamp(_, _)
| Duration(_)
| UInt8
| UInt16
| UInt32
| UInt64
| Float32
| Float64
)
}

/// Returns the sum of all elements in `array` as a [`Scalar`] of the same physical
/// and logical types as `array`.
/// # Error
/// Errors iff the operation is not supported.
pub fn sum(array: &dyn Array) -> Result<Box<dyn Scalar>> {
Ok(match array.data_type() {
DataType::Int8 => dyn_sum!(i8, array),
Expand Down Expand Up @@ -158,13 +184,22 @@ mod tests {
#[test]
fn test_primitive_array_sum() {
let a = Int32Array::from_slice(&[1, 2, 3, 4, 5]);
assert_eq!(15, sum(&a).unwrap());
assert_eq!(
&PrimitiveScalar::<i32>::from(Some(15)) as &dyn Scalar,
sum(&a).unwrap().as_ref()
);

let a = a.to(DataType::Date32);
assert_eq!(
&PrimitiveScalar::<i32>::from(Some(15)).to(DataType::Date32) as &dyn Scalar,
sum(&a).unwrap().as_ref()
);
}

#[test]
fn test_primitive_array_float_sum() {
let a = Float64Array::from_slice(&[1.1f64, 2.2, 3.3, 4.4, 5.5]);
assert!((16.5 - sum(&a).unwrap()).abs() < f64::EPSILON);
assert!((16.5 - sum_primitive(&a).unwrap()).abs() < f64::EPSILON);
}

#[test]
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pub mod bitmap;
pub mod buffer;
mod endianess;
pub mod error;
#[cfg(feature = "compute")]
pub mod scalar;
pub mod trusted_len;
pub mod types;
Expand Down
2 changes: 1 addition & 1 deletion src/scalar/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ There are three reasons:
* forward-compatibility: a new entry on an `enum` is backward-incompatible
* do not expose implementation details to users (reduce the surface of the public API)

### `Scalar` should contain nullability information
### `Scalar` MUST contain nullability information

This is to be aligned with the general notion of arrow's `Array`.

Expand Down
50 changes: 50 additions & 0 deletions src/scalar/binary.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
use crate::{array::*, buffer::Buffer, datatypes::DataType};

use super::Scalar;

#[derive(Debug, Clone, PartialEq)]
pub struct BinaryScalar<O: Offset> {
value: Buffer<u8>,
is_valid: bool,
phantom: std::marker::PhantomData<O>,
}

impl<O: Offset> BinaryScalar<O> {
#[inline]
pub fn new(v: Option<&[u8]>) -> Self {
let is_valid = v.is_some();
O::from_usize(v.map(|x| x.len()).unwrap_or_default()).expect("Too large");
let value = Buffer::from(v.unwrap_or(&[]));
Self {
value,
is_valid,
phantom: std::marker::PhantomData,
}
}

#[inline]
pub fn value(&self) -> &[u8] {
self.value.as_slice()
}
}

impl<O: Offset> Scalar for BinaryScalar<O> {
#[inline]
fn as_any(&self) -> &dyn std::any::Any {
self
}

#[inline]
fn is_valid(&self) -> bool {
self.is_valid
}

#[inline]
fn data_type(&self) -> &DataType {
if O::is_large() {
&DataType::LargeBinary
} else {
&DataType::Binary
}
}
}
42 changes: 42 additions & 0 deletions src/scalar/boolean.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
use crate::datatypes::DataType;

use super::Scalar;

#[derive(Debug, Clone, PartialEq)]
pub struct BooleanScalar {
value: bool,
is_valid: bool,
}

impl BooleanScalar {
#[inline]
pub fn new(v: Option<bool>) -> Self {
let is_valid = v.is_some();
Self {
value: v.unwrap_or_default(),
is_valid,
}
}

#[inline]
pub fn value(&self) -> bool {
self.value
}
}

impl Scalar for BooleanScalar {
#[inline]
fn as_any(&self) -> &dyn std::any::Any {
self
}

#[inline]
fn is_valid(&self) -> bool {
self.is_valid
}

#[inline]
fn data_type(&self) -> &DataType {
&DataType::Boolean
}
}
115 changes: 115 additions & 0 deletions src/scalar/equal.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
use super::*;

impl PartialEq for dyn Scalar {
fn eq(&self, other: &Self) -> bool {
equal(self, other)
}
}

macro_rules! dyn_eq {
($ty:ty, $lhs:expr, $rhs:expr) => {{
let lhs = $lhs
.as_any()
.downcast_ref::<PrimitiveScalar<$ty>>()
.unwrap();
let rhs = $rhs
.as_any()
.downcast_ref::<PrimitiveScalar<$ty>>()
.unwrap();
lhs == rhs
}};
}

fn equal(lhs: &dyn Scalar, rhs: &dyn Scalar) -> bool {
if lhs.data_type() != rhs.data_type() {
return false;
}

match lhs.data_type() {
DataType::Null => {
let lhs = lhs.as_any().downcast_ref::<NullScalar>().unwrap();
let rhs = rhs.as_any().downcast_ref::<NullScalar>().unwrap();
lhs == rhs
}
DataType::Boolean => {
let lhs = lhs.as_any().downcast_ref::<BooleanScalar>().unwrap();
let rhs = rhs.as_any().downcast_ref::<BooleanScalar>().unwrap();
lhs == rhs
}
DataType::UInt8 => {
dyn_eq!(u8, lhs, rhs)
}
DataType::UInt16 => {
dyn_eq!(u16, lhs, rhs)
}
DataType::UInt32 => {
dyn_eq!(u32, lhs, rhs)
}
DataType::UInt64 => {
dyn_eq!(u64, lhs, rhs)
}
DataType::Int8 => {
dyn_eq!(i8, lhs, rhs)
}
DataType::Int16 => {
dyn_eq!(i16, lhs, rhs)
}
DataType::Int32
| DataType::Date32
| DataType::Time32(_)
| DataType::Interval(IntervalUnit::YearMonth) => {
dyn_eq!(i32, lhs, rhs)
}
DataType::Int64
| DataType::Date64
| DataType::Time64(_)
| DataType::Timestamp(_, _)
| DataType::Duration(_) => {
dyn_eq!(i64, lhs, rhs)
}
DataType::Decimal(_, _) => {
dyn_eq!(i128, lhs, rhs)
}
DataType::Interval(IntervalUnit::DayTime) => {
dyn_eq!(days_ms, lhs, rhs)
}
DataType::Float16 => unreachable!(),
DataType::Float32 => {
dyn_eq!(f32, lhs, rhs)
}
DataType::Float64 => {
dyn_eq!(f64, lhs, rhs)
}
DataType::Utf8 => {
let lhs = lhs.as_any().downcast_ref::<Utf8Scalar<i32>>().unwrap();
let rhs = rhs.as_any().downcast_ref::<Utf8Scalar<i32>>().unwrap();
lhs == rhs
}
DataType::LargeUtf8 => {
let lhs = lhs.as_any().downcast_ref::<Utf8Scalar<i64>>().unwrap();
let rhs = rhs.as_any().downcast_ref::<Utf8Scalar<i64>>().unwrap();
lhs == rhs
}
DataType::Binary => {
let lhs = lhs.as_any().downcast_ref::<BinaryScalar<i32>>().unwrap();
let rhs = rhs.as_any().downcast_ref::<BinaryScalar<i32>>().unwrap();
lhs == rhs
}
DataType::LargeBinary => {
let lhs = lhs.as_any().downcast_ref::<BinaryScalar<i64>>().unwrap();
let rhs = rhs.as_any().downcast_ref::<BinaryScalar<i64>>().unwrap();
lhs == rhs
}
DataType::List(_) => {
let lhs = lhs.as_any().downcast_ref::<ListScalar<i32>>().unwrap();
let rhs = rhs.as_any().downcast_ref::<ListScalar<i32>>().unwrap();
lhs == rhs
}
DataType::LargeList(_) => {
let lhs = lhs.as_any().downcast_ref::<ListScalar<i64>>().unwrap();
let rhs = rhs.as_any().downcast_ref::<ListScalar<i64>>().unwrap();
lhs == rhs
}
_ => unimplemented!(),
}
}
62 changes: 62 additions & 0 deletions src/scalar/list.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
use std::any::Any;
use std::sync::Arc;

use crate::{array::*, datatypes::DataType};

use super::Scalar;

/// The scalar equivalent of [`ListArray`]. Like [`ListArray`], this struct holds a dynamically-typed
/// [`Array`]. The only difference is that this has only one element.
#[derive(Debug, Clone)]
pub struct ListScalar<O: Offset> {
values: Arc<dyn Array>,
is_valid: bool,
phantom: std::marker::PhantomData<O>,
data_type: DataType,
}

impl<O: Offset> PartialEq for ListScalar<O> {
fn eq(&self, other: &Self) -> bool {
(self.data_type == other.data_type)
&& (self.is_valid == other.is_valid)
&& (self.is_valid && (self.values.as_ref() == other.values.as_ref()))
}
}

pub enum ListScalarNew {
Array(Arc<dyn Array>),
DataType(DataType),
}

impl<O: Offset> ListScalar<O> {
#[inline]
pub fn new(data_type: DataType, values: Option<Arc<dyn Array>>) -> Self {
let (is_valid, values) = match values {
Some(values) => (true, values),
None => {
let data_type = ListArray::<O>::get_child_type(&data_type).clone();
(false, new_empty_array(data_type).into())
}
};
Self {
values,
is_valid,
phantom: std::marker::PhantomData,
data_type,
}
}
}

impl<O: Offset> Scalar for ListScalar<O> {
fn as_any(&self) -> &dyn Any {
self
}

fn is_valid(&self) -> bool {
self.is_valid
}

fn data_type(&self) -> &DataType {
&self.data_type
}
}
Loading

0 comments on commit 2ca4b16

Please sign in to comment.