diff --git a/vortex-dtype/src/field.rs b/vortex-dtype/src/field.rs
index d6b4d60364..ac083ad435 100644
--- a/vortex-dtype/src/field.rs
+++ b/vortex-dtype/src/field.rs
@@ -8,6 +8,9 @@ use std::fmt::{Display, Formatter};
 use std::sync::Arc;
 
 use itertools::Itertools;
+use vortex_error::{vortex_err, VortexResult};
+
+use crate::FieldNames;
 
 /// A selector for a field in a struct
 #[derive(Clone, Debug, PartialEq, Eq, Hash)]
@@ -46,6 +49,68 @@ impl Display for Field {
     }
 }
 
+impl Field {
+    /// Returns true if the field is defined by Name
+    pub fn is_named(&self) -> bool {
+        matches!(self, Field::Name(_))
+    }
+
+    /// Returns true if the field is defined by Index
+    pub fn is_indexed(&self) -> bool {
+        matches!(self, Field::Index(_))
+    }
+
+    /// Convert a field to a named field.
+    ///
+    /// A [`Field::Index`] is replaced by the name stored at that position in `field_names`.
+    /// A [`Field::Name`] is returned unchanged, without being checked against `field_names`.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the index is out of bounds for `field_names`.
+    pub fn into_named_field(self, field_names: &FieldNames) -> VortexResult<Self> {
+        match self {
+            Field::Index(idx) => field_names
+                .get(idx)
+                .ok_or_else(|| {
+                    vortex_err!(
+                        "Field index {} out of bounds, it has names {:?}",
+                        idx,
+                        field_names
+                    )
+                })
+                .cloned()
+                .map(Field::Name),
+            Field::Name(_) => Ok(self),
+        }
+    }
+
+    /// Convert a field to an indexed field.
+    ///
+    /// A [`Field::Name`] is replaced by the position of the first matching entry of
+    /// `field_names`. A [`Field::Index`] is returned unchanged, without a bounds check.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the name is not present in `field_names`.
+    pub fn into_index_field(self, field_names: &FieldNames) -> VortexResult<Self> {
+        match self {
+            Field::Name(name) => field_names
+                .iter()
+                .position(|n| *n == name)
+                .ok_or_else(|| {
+                    vortex_err!(
+                        "Field name {} not found, it has names {:?}",
+                        name,
+                        field_names
+                    )
+                })
+                .map(Field::Index),
+            Field::Index(_) => Ok(self),
+        }
+    }
+}
+
 /// A path through a (possibly nested) struct, composed of a sequence of field selectors
 // TODO(ngates): wrap `Vec` in Option for cheaper "root" path.
 #[derive(Clone, Debug, PartialEq, Eq, Hash)]
diff --git a/vortex-expr/src/transform/mod.rs b/vortex-expr/src/transform/mod.rs
index 465b4a019a..2f6c4dc96b 100644
--- a/vortex-expr/src/transform/mod.rs
+++ b/vortex-expr/src/transform/mod.rs
@@ -1,2 +1,4 @@
+//! A collection of transformations that can be applied to a [`crate::ExprRef`].
 pub mod partition;
+pub mod resolve_field_names;
 pub mod simplify;
diff --git a/vortex-expr/src/transform/resolve_field_names.rs b/vortex-expr/src/transform/resolve_field_names.rs
new file mode 100644
index 0000000000..9ccaddbc1e
--- /dev/null
+++ b/vortex-expr/src/transform/resolve_field_names.rs
@@ -0,0 +1,101 @@
+use vortex_dtype::DType;
+use vortex_error::{vortex_err, VortexResult};
+
+use crate::traversal::{MutNodeVisitor, Node, TransformResult};
+use crate::{ExprRef, GetItem};
+
+/// Resolves any [`vortex_dtype::Field::Index`] nodes in the expression to
+/// [`vortex_dtype::Field::Name`] nodes.
+///
+/// `scope_dtype` is passed to `return_dtype` on the child of each index-based `GetItem`; the
+/// names used for the lookup are those of the resulting struct dtype.
+///
+/// # Errors
+///
+/// Returns an error if computing a child's dtype fails, if that dtype is not a struct, or if
+/// the index is out of bounds for the struct's field names.
+pub fn resolve_field_names(expr: ExprRef, scope_dtype: &DType) -> VortexResult<ExprRef> {
+    let mut visitor = FieldToNameTransform { scope_dtype };
+    expr.transform(&mut visitor).map(|node| node.result)
+}
+
+/// Visitor that replaces index-based `GetItem` nodes with name-based ones.
+struct FieldToNameTransform<'a> {
+    /// Dtype against which each `GetItem` child's return dtype is computed.
+    scope_dtype: &'a DType,
+}
+
+impl MutNodeVisitor for FieldToNameTransform<'_> {
+    type NodeTy = ExprRef;
+
+    /// Rewrites a `GetItem` node whose field is an index; any other node is passed through.
+    fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<Self::NodeTy>> {
+        if let Some(get_item) = node.as_any().downcast_ref::<GetItem>() {
+            // Already addressed by name: nothing to resolve.
+            if get_item.field().is_named() {
+                return Ok(TransformResult::no(node));
+            }
+
+            // The index refers to the field names of the child's struct dtype.
+            let child_dtype = get_item.child().return_dtype(self.scope_dtype)?;
+            let struct_dtype = child_dtype
+                .as_struct()
+                .ok_or_else(|| vortex_err!("get_item requires child to have struct dtype"))?;
+
+            return Ok(TransformResult::yes(GetItem::new_expr(
+                get_item
+                    .field()
+                    .clone()
+                    .into_named_field(struct_dtype.names())?,
+                get_item.child().clone(),
+            )));
+        }
+
+        Ok(TransformResult::no(node))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use vortex_dtype::Nullability::NonNullable;
+    use vortex_dtype::PType::I32;
+    use vortex_dtype::{DType, StructDType};
+
+    use super::*;
+    use crate::{get_item, ident};
+
+    #[test]
+    fn test_idx_to_name_expr() {
+        let dtype = DType::Struct(
+            StructDType::new(
+                vec!["a".into(), "b".into()].into(),
+                vec![
+                    DType::Struct(
+                        StructDType::new(
+                            vec!["c".into(), "d".into()].into(),
+                            vec![I32.into(), I32.into()],
+                        ),
+                        NonNullable,
+                    ),
+                    DType::Struct(
+                        StructDType::new(
+                            vec!["e".into(), "f".into()].into(),
+                            vec![I32.into(), I32.into()],
+                        ),
+                        NonNullable,
+                    ),
+                ],
+            ),
+            NonNullable,
+        );
+        // Index 1 within struct "a" (fields "c", "d") resolves to "d".
+        let expr = get_item(1, get_item("a", ident()));
+        let new_expr = resolve_field_names(expr, &dtype).unwrap();
+        assert_eq!(&new_expr, &get_item("d", get_item("a", ident())));
+
+        // Index 1 at the root is "b", whose index 0 is "e".
+        let expr = get_item(0, get_item(1, ident()));
+        let new_expr = resolve_field_names(expr, &dtype).unwrap();
+        assert_eq!(&new_expr, &get_item("e", get_item("b", ident())));
+    }
+}