Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Added support to io::print and Display for UnionArray.
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgecarleitao committed Aug 14, 2021
1 parent ef2bcc9 commit 91a2e51
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 20 deletions.
13 changes: 12 additions & 1 deletion src/array/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,18 @@ pub fn get_value_display<'a>(array: &'a dyn Array) -> Box<dyn Fn(usize) -> Strin
string
})
}
Union(_) => todo!(),
Union(_, _, _) => {
let array = array.as_any().downcast_ref::<UnionArray>().unwrap();
let displays = array
.fields()
.iter()
.map(|x| get_display(x.as_ref()))
.collect::<Vec<_>>();
Box::new(move |row: usize| {
let (field, index) = array.index(row);
displays[field](index)
})
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ impl Display for dyn Array {
DataType::LargeList(_) => fmt_dyn!(self, ListArray::<i64>, f),
DataType::FixedSizeList(_, _) => fmt_dyn!(self, FixedSizeListArray, f),
DataType::Struct(_) => fmt_dyn!(self, StructArray, f),
DataType::Union(_, _, _) => unimplemented!(),
DataType::Union(_, _, _) => fmt_dyn!(self, UnionArray, f),
DataType::Dictionary(key_type, _) => match key_type.as_ref() {
DataType::Int8 => fmt_dyn!(self, DictionaryArray::<i8>, f),
DataType::Int16 => fmt_dyn!(self, DictionaryArray::<i16>, f),
Expand Down
104 changes: 87 additions & 17 deletions src/array/union/mod.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
use std::{collections::HashMap, sync::Arc};

use crate::{
array::{display::get_value_display, display_fmt, new_empty_array, Array},
bitmap::Bitmap,
buffer::Buffer,
datatypes::{DataType, Field},
scalar::{new_scalar, Scalar},
};

use super::{new_empty_array, Array};

mod ffi;
mod iterator;

/// A union
type FieldEntry = (usize, Arc<dyn Array>);

/// [`UnionArray`] represents an array whose each slot can contain different values.
///
// How to read a value at slot i:
// ```
// let index = self.types()[i] as usize;
Expand All @@ -24,7 +26,8 @@ mod iterator;
#[derive(Debug, Clone)]
pub struct UnionArray {
types: Buffer<i8>,
fields_hash: HashMap<i8, Arc<dyn Array>>,
// None represents when there is no typeid
fields_hash: Option<HashMap<i8, FieldEntry>>,
fields: Vec<Arc<dyn Array>>,
offsets: Option<Buffer<i32>>,
data_type: DataType,
Expand All @@ -47,7 +50,7 @@ impl UnionArray {

Self {
data_type,
fields_hash: HashMap::new(),
fields_hash: None,
fields,
offsets,
types: Buffer::new(),
Expand All @@ -65,10 +68,6 @@ impl UnionArray {
offsets: Option<Buffer<i32>>,
) -> Self {
let fields_hash = if let DataType::Union(f, ids, is_sparse) = &data_type {
let ids: Vec<i8> = ids
.as_ref()
.map(|x| x.iter().map(|x| *x as i8).collect())
.unwrap_or_else(|| (0..f.len() as i8).collect());
if f.len() != fields.len() {
panic!(
"The number of `fields` must equal the number of fields in the Union DataType"
Expand All @@ -84,7 +83,14 @@ impl UnionArray {
if offsets.is_none() != *is_sparse {
panic!("Sparsness flag must equal to noness of offsets in UnionArray")
}
ids.into_iter().zip(fields.iter().cloned()).collect()
ids.as_ref().map(|ids| {
ids.iter()
.map(|x| *x as i8)
.enumerate()
.zip(fields.iter().cloned())
.map(|((i, type_), field)| (type_, (i, field)))
.collect()
})
} else {
panic!("Union struct must be created with the corresponding Union DataType")
};
Expand Down Expand Up @@ -113,15 +119,40 @@ impl UnionArray {
&self.types
}

pub fn value(&self, index: usize) -> Box<dyn Scalar> {
let field_index = self.types()[index];
let field = self.fields_hash[&field_index].as_ref();
let offset = self
.offsets()
#[inline]
fn field(&self, type_: i8) -> &Arc<dyn Array> {
self.fields_hash
.as_ref()
.map(|x| &x[&type_].1)
.unwrap_or_else(|| &self.fields[type_ as usize])
}

#[inline]
fn field_slot(&self, index: usize) -> usize {
self.offsets()
.as_ref()
.map(|x| x[index] as usize)
.unwrap_or(index);
new_scalar(field, offset)
.unwrap_or(index)
}

/// Returns the index and slot of the field to select from `self.fields`.
pub fn index(&self, index: usize) -> (usize, usize) {
let type_ = self.types()[index];
let field_index = self
.fields_hash
.as_ref()
.map(|x| x[&type_].0)
.unwrap_or_else(|| type_ as usize);
let index = self.field_slot(index);
(field_index, index)
}

/// Returns the slot `index` as a [`Scalar`].
pub fn value(&self, index: usize) -> Box<dyn Scalar> {
let type_ = self.types()[index];
let field = self.field(type_);
let index = self.field_slot(index);
new_scalar(field.as_ref(), index)
}

/// Returns a slice of this [`UnionArray`].
Expand Down Expand Up @@ -181,3 +212,42 @@ impl UnionArray {
}
}
}

impl std::fmt::Display for UnionArray {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let display = get_value_display(self);
let new_lines = false;
let head = "UnionArray";
let iter = self
.iter()
.enumerate()
.map(|(i, x)| if x.is_valid() { Some(display(i)) } else { None });
display_fmt(iter, head, f, new_lines)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{array::*, buffer::Buffer, datatypes::*, error::Result};

#[test]
fn display() -> Result<()> {
let fields = vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Utf8, true),
];
let data_type = DataType::Union(fields, None, true);
let types = Buffer::from(&[0, 0, 1]);
let fields = vec![
Arc::new(Int32Array::from(&[Some(1), None, Some(2)])) as Arc<dyn Array>,
Arc::new(Utf8Array::<i32>::from(&[Some("a"), Some("b"), Some("c")])) as Arc<dyn Array>,
];

let array = UnionArray::from_data(data_type, types, fields, None);

assert_eq!(format!("{}", array), "UnionArray[1, , c]");

Ok(())
}
}
34 changes: 33 additions & 1 deletion src/io/print.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ fn create_table(results: &[RecordBatch]) -> Table {

#[cfg(test)]
mod tests {
use crate::{array::*, bitmap::Bitmap, datatypes::*, error::Result};
use crate::{array::*, bitmap::Bitmap, buffer::Buffer, datatypes::*, error::Result};

use super::*;
use std::sync::Arc;
Expand Down Expand Up @@ -426,4 +426,36 @@ mod tests {

Ok(())
}

#[test]
fn test_write_union() -> Result<()> {
let fields = vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Utf8, true),
];
let data_type = DataType::Union(fields, None, true);
let types = Buffer::from(&[0, 0, 1]);
let fields = vec![
Arc::new(Int32Array::from(&[Some(1), None, Some(2)])) as Arc<dyn Array>,
Arc::new(Utf8Array::<i32>::from(&[Some("a"), Some("b"), Some("c")])) as Arc<dyn Array>,
];

let array = UnionArray::from_data(data_type, types, fields, None);

let schema = Schema::new(vec![Field::new("a", array.data_type().clone(), true)]);

let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)])?;

let table = write(&[batch]);

let expected = vec![
"+---+", "| a |", "+---+", "| 1 |", "| |", "| c |", "+---+",
];

let actual: Vec<&str> = table.lines().collect();

assert_eq!(expected, actual, "Actual result:\n{}", table);

Ok(())
}
}

0 comments on commit 91a2e51

Please sign in to comment.