fix: RP feedback
roeap committed Jan 31, 2024
1 parent 6b5c85e commit aa44f36
Showing 6 changed files with 73 additions and 39 deletions.
1 change: 1 addition & 0 deletions kernel/Cargo.toml
@@ -15,6 +15,7 @@ bytes = "1.4"
 chrono = { version = "0.4" }
 either = "1.8"
 fix-hidden-lifetime-bug = "0.2"
+indexmap = "2.2.1"
 itertools = "0.12"
 lazy_static = "1.4"
 regex = "1.8"
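The new dependency is what makes the schema change in this commit possible: unlike HashMap, IndexMap keeps entries in insertion order while still offering O(1) lookup by key, which matters because Delta schema fields are order-sensitive. A minimal sketch of the behavior relied on, using the indexmap crate declared above (field names are illustrative):

use indexmap::IndexMap;

fn main() {
    // Insertion order is preserved, unlike HashMap.
    let mut fields: IndexMap<String, u32> = IndexMap::new();
    fields.insert("id".to_string(), 0);
    fields.insert("value".to_string(), 1);
    fields.insert("ts".to_string(), 2);

    // O(1) lookup by name...
    assert_eq!(fields.get("value"), Some(&1));

    // ...and iteration yields entries in the order they were inserted.
    let names: Vec<&str> = fields.keys().map(String::as_str).collect();
    assert_eq!(names, ["id", "value", "ts"]);
}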
11 changes: 4 additions & 7 deletions kernel/src/client/conversion.rs
@@ -22,8 +22,7 @@ impl TryFrom<&StructType> for ArrowSchema {
     fn try_from(s: &StructType) -> Result<Self, ArrowError> {
         let fields = s
             .fields()
-            .iter()
-            .map(|f| <ArrowField as TryFrom<&StructField>>::try_from(*f))
+            .map(|f| <ArrowField as TryFrom<&StructField>>::try_from(f))
             .collect::<Result<Vec<ArrowField>, ArrowError>>()?;
 
         Ok(ArrowSchema::new(fields))
@@ -105,11 +104,10 @@ impl TryFrom<&DataType> for ArrowDataType {
             PrimitiveType::Decimal(precision, scale) => {
                 if precision <= &38 {
                     Ok(ArrowDataType::Decimal128(*precision, *scale))
-                } else if precision <= &76 {
-                    Ok(ArrowDataType::Decimal256(*precision, *scale))
                 } else {
+                    // NOTE: since we are converting from delta, we should never get here.
                     Err(ArrowError::SchemaError(format!(
-                        "Precision too large to be represented in Arrow: {}",
+                        "Precision too large to be represented as Delta type: {} > 38",
                         precision
                     )))
                 }
@@ -127,8 +125,7 @@
             }
             DataType::Struct(s) => Ok(ArrowDataType::Struct(
                 s.fields()
-                    .iter()
-                    .map(|f| <ArrowField as TryFrom<&StructField>>::try_from(*f))
+                    .map(|f| <ArrowField as TryFrom<&StructField>>::try_from(f))
                     .collect::<Result<Vec<ArrowField>, ArrowError>>()?
                     .into(),
             )),
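The removed Decimal256 branch was dead code for valid inputs: the Delta protocol caps decimal precision at 38, which always fits Arrow's Decimal128, so the error arm should be unreachable when converting from a well-formed Delta schema. A standalone sketch of the same check, using toy types rather than the kernel's real ones:

// Toy stand-in for the Arrow data type; names are illustrative only.
#[derive(Debug, PartialEq)]
enum ToyArrowType {
    Decimal128(u8, i8),
}

fn decimal_to_arrow(precision: u8, scale: i8) -> Result<ToyArrowType, String> {
    // Delta decimals are limited to precision <= 38, so Decimal128 suffices.
    if precision <= 38 {
        Ok(ToyArrowType::Decimal128(precision, scale))
    } else {
        // Unreachable for schemas that respect the Delta protocol.
        Err(format!(
            "Precision too large to be represented as Delta type: {} > 38",
            precision
        ))
    }
}

fn main() {
    assert_eq!(decimal_to_arrow(10, 2), Ok(ToyArrowType::Decimal128(10, 2)));
    assert!(decimal_to_arrow(40, 2).is_err());
}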
18 changes: 9 additions & 9 deletions kernel/src/client/expression.rs
@@ -65,8 +65,7 @@ impl Scalar {
             PrimitiveType::Binary => Arc::new(BinaryArray::new_null(num_rows)),
             PrimitiveType::Decimal(precision, scale) => Arc::new(
                 Decimal128Array::new_null(num_rows)
-                    .with_precision_and_scale(*precision, *scale)
-                    .unwrap(),
+                    .with_precision_and_scale(*precision, *scale)?,
             ),
         },
         DataType::Array(_) => unimplemented!(),
@@ -199,27 +198,28 @@ fn evaluate_expression(
 
             eval(&left_arr, &right_arr).map_err(Error::generic_err)
         }
-        (VariadicOperation { op, exprs }, _) => {
+        (VariadicOperation { op, exprs }, Some(&DataType::BOOLEAN)) => {
             type Operation = fn(&BooleanArray, &BooleanArray) -> Result<BooleanArray, ArrowError>;
             let (reducer, default): (Operation, _) = match op {
                 VariadicOperator::And => (and, true),
                 VariadicOperator::Or => (or, false),
             };
             exprs
                 .iter()
-                .map(|expr| evaluate_expression(expr, batch, Some(&DataType::BOOLEAN)))
+                .map(|expr| evaluate_expression(expr, batch, result_type))
                 .reduce(|l, r| {
                     Ok(reducer(downcast_to_bool(&l?)?, downcast_to_bool(&r?)?)
                         .map(wrap_comparison_result)?)
                 })
                 .unwrap_or_else(|| {
-                    evaluate_expression(
-                        &Expression::literal(default),
-                        batch,
-                        Some(&DataType::BOOLEAN),
-                    )
+                    evaluate_expression(&Expression::literal(default), batch, result_type)
                 })
         }
+        (VariadicOperation { .. }, _) => {
+            // NOTE: this panics as it would be a bug in our code if we get here. However it does swallow
+            // the error message from the compiler if we add variants to the enum and forget to add them here.
+            unreachable!("We only support variadic operations for boolean expressions right now.")
+        }
         (NullIf { expr, if_expr }, _) => {
             let expr_arr = evaluate_expression(expr.as_ref(), batch, None)?;
             let if_expr_arr =
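The And/Or reduction above seeds an empty operand list with the operator's identity element, true for And and false for Or, so AND() evaluates to true and OR() to false; the new fallback arm turns a non-boolean result type into an explicit unreachable! instead of a silent mis-evaluation. A minimal sketch of the fold over plain bools (not the kernel's array-based version):

fn reduce_variadic(op_is_and: bool, values: &[bool]) -> bool {
    // Identity element: true for AND, false for OR.
    let default = op_is_and;
    values
        .iter()
        .copied()
        .reduce(|l, r| if op_is_and { l && r } else { l || r })
        // An empty operand list falls back to the identity, matching the
        // `unwrap_or_else` branch in the diff above.
        .unwrap_or(default)
}

fn main() {
    assert!(reduce_variadic(true, &[]));   // AND() == true
    assert!(!reduce_variadic(false, &[])); // OR() == false
    assert!(!reduce_variadic(true, &[true, false]));
    assert!(reduce_variadic(false, &[true, false]));
}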
3 changes: 1 addition & 2 deletions kernel/src/scan/data_skipping.rs
@@ -133,8 +133,7 @@ impl DataSkippingFilter {
         // Build the stats read schema by extracting the column names referenced by the predicate,
         // extracting the corresponding field from the table schema, and inserting that field.
         let data_fields: Vec<_> = table_schema
-            .fields
-            .iter()
+            .fields()
             .filter(|field| field_names.contains(&field.name.as_str()))
             .cloned()
             .collect();
12 changes: 5 additions & 7 deletions kernel/src/scan/mod.rs
@@ -173,7 +173,6 @@ impl Scan {
 
         let select_fields = read_schema
             .fields()
-            .iter()
             .map(|f| Expression::Column(f.name().to_string()))
             .collect_vec();
 
@@ -202,13 +201,12 @@
             let mut fields =
                 Vec::with_capacity(partition_fields.len() + batch.num_columns());
             for field in &partition_fields {
-                let value_expression =
-                    if let Some(Some(value)) = add.partition_values.get(field.name()) {
+                let value_expression = match add.partition_values.get(field.name()) {
+                    Some(Some(value)) => {
                         Expression::Literal(get_partition_value(value, field.data_type())?)
-                    } else {
-                        // TODO: is it allowed to assume null for missing partition values?
-                        Expression::Literal(Scalar::Null(field.data_type().clone()))
-                    };
+                    }
+                    _ => Expression::Literal(Scalar::Null(field.data_type().clone())),
+                };
                 fields.push(value_expression);
             }
             fields.extend(select_fields.clone());
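partition_values.get(...) yields a nested Option<Option<String>>: the outer level says whether the column appears in the map at all, the inner level whether its value is NULL. The rewritten match makes explicit that both the missing-key and explicit-NULL cases fall back to a null literal. A small sketch of the pattern with toy types (the helper and names are illustrative):

use std::collections::HashMap;

// Missing key and explicit NULL both map to "null"; only a concrete
// value survives, mirroring the match arms in the diff above.
fn partition_value(values: &HashMap<String, Option<String>>, name: &str) -> String {
    match values.get(name) {
        Some(Some(value)) => format!("literal({value})"),
        _ => "null".to_string(),
    }
}

fn main() {
    let mut values = HashMap::new();
    values.insert("year".to_string(), Some("2024".to_string()));
    values.insert("region".to_string(), None);

    assert_eq!(partition_value(&values, "year"), "literal(2024)");
    assert_eq!(partition_value(&values, "region"), "null"); // explicit NULL
    assert_eq!(partition_value(&values, "month"), "null"); // missing key
}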
67 changes: 53 additions & 14 deletions kernel/src/schema.rs
@@ -2,12 +2,13 @@ use std::fmt::Formatter;
 use std::sync::Arc;
 use std::{collections::HashMap, fmt::Display};
 
+use indexmap::IndexMap;
 use serde::{Deserialize, Serialize};
 
 pub type Schema = StructType;
 pub type SchemaRef = Arc<StructType>;
 
-#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
+#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Eq)]
 #[serde(untagged)]
 pub enum MetadataValue {
     Number(i32),
@@ -59,7 +60,7 @@ impl AsRef<str> for ColumnMetadataKey {
     }
 }
 
-#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
+#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Eq)]
 pub struct StructField {
     /// Name of this (possibly nested) column
     pub name: String,
@@ -121,32 +122,70 @@ impl StructField {
 
 /// A struct is used to represent both the top-level schema of the table
 /// as well as struct columns that contain nested columns.
-#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
+#[derive(Debug, PartialEq, Clone, Eq)]
 pub struct StructType {
-    #[serde(rename = "type")]
     pub type_name: String,
     /// The type of element stored in this array
-    pub fields: Vec<StructField>,
+    pub fields: IndexMap<String, StructField>,
 }
 
 impl StructType {
     pub fn new(fields: Vec<StructField>) -> Self {
         Self {
             type_name: "struct".into(),
-            fields,
+            fields: fields.into_iter().map(|f| (f.name.clone(), f)).collect(),
         }
     }
 
     pub fn field(&self, name: impl AsRef<str>) -> Option<&StructField> {
-        self.fields.iter().find(|f| f.name == name.as_ref())
+        self.fields.get(name.as_ref())
     }
 
-    pub fn fields(&self) -> Vec<&StructField> {
-        self.fields.iter().collect()
+    pub fn fields(&self) -> impl Iterator<Item = &StructField> {
+        self.fields.values()
+    }
+}
+
+#[derive(Debug, Deserialize, Serialize)]
+#[serde(rename_all = "camelCase")]
+struct StructTypeSerDeHelper {
+    #[serde(rename = "type")]
+    type_name: String,
+    fields: Vec<StructField>,
+}
+
+impl Serialize for StructType {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: serde::Serializer,
+    {
+        StructTypeSerDeHelper {
+            type_name: self.type_name.clone(),
+            fields: self.fields.values().cloned().collect(),
+        }
+        .serialize(serializer)
+    }
+}
+
+impl<'de> Deserialize<'de> for StructType {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+        Self: Sized,
+    {
+        let helper = StructTypeSerDeHelper::deserialize(deserializer)?;
+        Ok(Self {
+            type_name: helper.type_name,
+            fields: helper
+                .fields
+                .into_iter()
+                .map(|f| (f.name.clone(), f))
+                .collect(),
+        })
     }
 }
 
-#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
+#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Eq)]
 #[serde(rename_all = "camelCase")]
 pub struct ArrayType {
     #[serde(rename = "type")]
@@ -177,7 +216,7 @@ impl ArrayType {
     }
 }
 
-#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
+#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Eq)]
 #[serde(rename_all = "camelCase")]
 pub struct MapType {
     #[serde(rename = "type")]
@@ -221,7 +260,7 @@ fn default_true() -> bool {
     true
 }
 
-#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
+#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Eq)]
 #[serde(rename_all = "camelCase")]
 pub enum PrimitiveType {
     /// UTF-8 encoded string of characters
@@ -311,7 +350,7 @@ impl Display for PrimitiveType {
     }
 }
 
-#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
+#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Eq)]
 #[serde(untagged, rename_all = "camelCase")]
 pub enum DataType {
     /// UTF-8 encoded string of characters
@@ -369,7 +408,7 @@ impl Display for DataType {
             DataType::Array(a) => write!(f, "array<{}>", a.element_type),
             DataType::Struct(s) => {
                 write!(f, "struct<")?;
-                for (i, field) in s.fields.iter().enumerate() {
+                for (i, (_, field)) in s.fields.iter().enumerate() {
                     if i > 0 {
                         write!(f, ", ")?;
                     }
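Because an IndexMap keyed by field name no longer matches the Delta JSON layout, where "fields" is an array, StructType now serializes through StructTypeSerDeHelper, converting map to Vec and back so the wire format is unchanged. A round-trip sketch of that helper pattern with toy types (assuming the serde, serde_json, and indexmap crates; names are illustrative):

use indexmap::IndexMap;
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
struct Field {
    name: String,
}

// In-memory form: name -> field map for O(1) lookup.
#[derive(Debug, PartialEq)]
struct Struct {
    fields: IndexMap<String, Field>,
}

// Wire form: a plain array of fields, as in the Delta schema JSON.
#[derive(Serialize, Deserialize)]
struct Helper {
    fields: Vec<Field>,
}

impl Serialize for Struct {
    fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
        // Flatten the map back into the array layout before writing.
        Helper { fields: self.fields.values().cloned().collect() }.serialize(s)
    }
}

impl<'de> Deserialize<'de> for Struct {
    fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
        // Read the array layout, then index the fields by name.
        let helper = Helper::deserialize(d)?;
        Ok(Struct {
            fields: helper.fields.into_iter().map(|f| (f.name.clone(), f)).collect(),
        })
    }
}

fn main() {
    let json = r#"{"fields":[{"name":"id"},{"name":"value"}]}"#;
    let s: Struct = serde_json::from_str(json).unwrap();
    assert_eq!(s.fields.get("id"), Some(&Field { name: "id".to_string() }));
    // Round-trips back to the same array layout, field order preserved.
    assert_eq!(serde_json::to_string(&s).unwrap(), json);
}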
