Skip to content

Commit

Permalink
Field::Index to Field::Name expr transform (#1894)
Browse files Browse the repository at this point in the history
Co-authored-by: Nicholas Gates <nick@nickgates.com>
  • Loading branch information
joseph-isaacs and gatesn authored Jan 11, 2025
1 parent b1dbfe7 commit e8228c0
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 0 deletions.
51 changes: 51 additions & 0 deletions vortex-dtype/src/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ use std::fmt::{Display, Formatter};
use std::sync::Arc;

use itertools::Itertools;
use vortex_error::{vortex_err, VortexResult};

use crate::FieldNames;

/// A selector for a field in a struct
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -46,6 +49,54 @@ impl Display for Field {
}
}

impl Field {
/// Returns true if the field is defined by Name
pub fn is_named(&self) -> bool {
matches!(self, Field::Name(_))
}

/// Returns true if the field is defined by Index
pub fn is_indexed(&self) -> bool {
matches!(self, Field::Index(_))
}

/// Convert a field to a named field
pub fn into_named_field(self, field_names: &FieldNames) -> VortexResult<Self> {
match self {
Field::Index(idx) => field_names
.get(idx)
.ok_or_else(|| {
vortex_err!(
"Field index {} out of bounds, it has names {:?}",
idx,
field_names
)
})
.cloned()
.map(Field::Name),
Field::Name(_) => Ok(self),
}
}

/// Convert a field to an indexed field
pub fn into_index_field(self, field_names: &FieldNames) -> VortexResult<Self> {
match self {
Field::Name(name) => field_names
.iter()
.position(|n| *n == name)
.ok_or_else(|| {
vortex_err!(
"Field name {} not found, it has names {:?}",
name,
field_names
)
})
.map(Field::Index),
Field::Index(_) => Ok(self),
}
}
}

/// A path through a (possibly nested) struct, composed of a sequence of field selectors
// TODO(ngates): wrap `Vec<Field>` in Option for cheaper "root" path.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
Expand Down
2 changes: 2 additions & 0 deletions vortex-expr/src/transform/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
//! A collection of transformations that can be applied to a [`crate::ExprRef`].
pub mod partition;
pub mod resolve_field_names;
pub mod simplify;
86 changes: 86 additions & 0 deletions vortex-expr/src/transform/resolve_field_names.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
use vortex_dtype::DType;
use vortex_error::{vortex_err, VortexResult};

use crate::traversal::{MutNodeVisitor, Node, TransformResult};
use crate::{ExprRef, GetItem};

/// Resolves any [`vortex_dtype::Field::Index`] nodes in the expression to
/// [`vortex_dtype::Field::Name`] nodes.
pub fn resolve_field_names(expr: ExprRef, scope_dtype: &DType) -> VortexResult<ExprRef> {
let mut visitor = FieldToNameTransform { scope_dtype };
expr.transform(&mut visitor).map(|node| node.result)
}

struct FieldToNameTransform<'a> {
scope_dtype: &'a DType,
}

impl MutNodeVisitor for FieldToNameTransform<'_> {
type NodeTy = ExprRef;

fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<Self::NodeTy>> {
if let Some(get_item) = node.as_any().downcast_ref::<GetItem>() {
if get_item.field().is_named() {
return Ok(TransformResult::no(node));
}

let child_dtype = get_item.child().return_dtype(self.scope_dtype)?;
let struct_dtype = child_dtype
.as_struct()
.ok_or_else(|| vortex_err!("get_item requires child to have struct dtype"))?;

return Ok(TransformResult::yes(GetItem::new_expr(
get_item
.field()
.clone()
.into_named_field(struct_dtype.names())?,
get_item.child().clone(),
)));
}

Ok(TransformResult::no(node))
}
}

#[cfg(test)]
mod tests {
use vortex_dtype::Nullability::NonNullable;
use vortex_dtype::PType::I32;
use vortex_dtype::{DType, StructDType};

use super::*;
use crate::{get_item, ident};

#[test]
fn test_idx_to_name_expr() {
let dtype = DType::Struct(
StructDType::new(
vec!["a".into(), "b".into()].into(),
vec![
DType::Struct(
StructDType::new(
vec!["c".into(), "d".into()].into(),
vec![I32.into(), I32.into()],
),
NonNullable,
),
DType::Struct(
StructDType::new(
vec!["e".into(), "f".into()].into(),
vec![I32.into(), I32.into()],
),
NonNullable,
),
],
),
NonNullable,
);
let expr = get_item(1, get_item("a", ident()));
let new_expr = resolve_field_names(expr, &dtype).unwrap();
assert_eq!(&new_expr, &get_item("d", get_item("a", ident())));

let expr = get_item(0, get_item(1, ident()));
let new_expr = resolve_field_names(expr, &dtype).unwrap();
assert_eq!(&new_expr, &get_item("e", get_item("b", ident())));
}
}

0 comments on commit e8228c0

Please sign in to comment.