Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Added input validation for explode operation in the array namespace #19163

Merged
merged 8 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions crates/polars-plan/src/dsl/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,4 +193,9 @@ impl ArrayNameSpace {
None,
)
}
/// Explode the array column so that each array element ends up on its own row.
///
/// The expression is wrapped in the `ArrayFunction::Explode` function expression.
pub fn explode(self) -> Expr {
    let function = FunctionExpr::ArrayExpr(ArrayFunction::Explode);
    self.0.map_private(function)
}
}
4 changes: 4 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub enum ArrayFunction {
#[cfg(feature = "array_count")]
CountMatches,
Shift,
Explode,
}

impl ArrayFunction {
Expand All @@ -56,6 +57,7 @@ impl ArrayFunction {
#[cfg(feature = "array_count")]
CountMatches => mapper.with_dtype(IDX_DTYPE),
Shift => mapper.with_same_dtype(),
Explode => mapper.try_map_to_array_inner_dtype(),
}
}
}
Expand Down Expand Up @@ -96,6 +98,7 @@ impl Display for ArrayFunction {
#[cfg(feature = "array_count")]
CountMatches => "count_matches",
Shift => "shift",
Explode => "explode",
};
write!(f, "arr.{name}")
}
Expand Down Expand Up @@ -129,6 +132,7 @@ impl From<ArrayFunction> for SpecialEq<Arc<dyn ColumnsUdf>> {
#[cfg(feature = "array_count")]
CountMatches => map_as_slice!(count_matches),
Shift => map_as_slice!(shift),
Explode => unreachable!(),
}
}
}
Expand Down
10 changes: 10 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,16 @@ impl<'a> FieldsMapper<'a> {
Ok(first)
}

#[cfg(feature = "dtype-array")]
/// Map the dtype to the dtype of the array elements, with type validation.
///
/// # Errors
///
/// Bails with an `InvalidOperation` error when the dtype of the first field
/// is not `DataType::Array`.
pub fn try_map_to_array_inner_dtype(&self) -> PolarsResult<Field> {
    let dtype = self.fields[0].dtype();
    // Guard clause: reject anything that is not an Array before delegating.
    if !matches!(dtype, DataType::Array(_, _)) {
        polars_bail!(InvalidOperation: "expected Array type, got: {}", dtype);
    }
    self.map_to_list_and_array_inner_dtype()
}

/// Map the dtypes to the "supertype" of a list of lists.
pub fn map_to_list_supertype(&self) -> PolarsResult<Field> {
self.try_map_dtypes(|dts| {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ pub(super) fn optimize_functions(
expr_arena: &mut Arena<AExpr>,
) -> PolarsResult<Option<AExpr>> {
let out = match function {
#[cfg(feature = "dtype-array")]
// arr.explode() -> explode()
FunctionExpr::ArrayExpr(ArrayFunction::Explode) => {
let input_node = input[0].node();
Some(AExpr::Explode(input_node))
},
// is_null().any() -> null_count() > 0
// is_not_null().any() -> null_count() < len()
// CORRECTNESS: we can ignore 'ignore_nulls' since is_null/is_not_null never produces NULLS
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-python/src/expr/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,8 @@ impl PyExpr {
// Forward to the array namespace's `shift`, passing the inner expression of `n`.
fn arr_shift(&self, n: PyExpr) -> Self {
    let periods = n.inner;
    let shifted = self.inner.clone().arr().shift(periods);
    shifted.into()
}

// Forward to the array namespace's `explode` on a clone of the inner expression.
fn arr_explode(&self) -> Self {
    let exploded = self.inner.clone().arr().explode();
    exploded.into()
}
}
2 changes: 1 addition & 1 deletion py-polars/polars/expr/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ def explode(self) -> Expr:
│ 6 │
└─────┘
"""
return wrap_expr(self._pyexpr.explode())
return wrap_expr(self._pyexpr.arr_explode())

def contains(
self, item: float | str | bool | int | date | datetime | time | IntoExprColumn
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/series/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

@expr_dispatch
class ArrayNameSpace:
"""Namespace for list related methods."""
"""Namespace for array related methods."""

_accessor = "arr"

Expand Down
13 changes: 12 additions & 1 deletion py-polars/tests/unit/operations/namespaces/array/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest

import polars as pl
from polars.exceptions import ComputeError
from polars.exceptions import ComputeError, InvalidOperationError
from polars.testing import assert_frame_equal, assert_series_equal


Expand Down Expand Up @@ -449,3 +449,14 @@ def test_array_n_unique() -> None:
{"n_unique": [2, 1, 1, None]}, schema={"n_unique": pl.UInt32}
)
assert_frame_equal(out, expected)


def test_explode_19049() -> None:
    """Check `arr.explode` on an Array column and on a non-Array (Int64) column."""
    # An Array column is flattened into one row per element.
    array_frame = pl.DataFrame(
        {"a": [[1, 2, 3]]},
        schema={"a": pl.Array(pl.Int64, 3)},
    )
    exploded = array_frame.select(pl.col.a.arr.explode())
    assert_frame_equal(
        exploded,
        pl.DataFrame({"a": [1, 2, 3]}, schema={"a": pl.Int64}),
    )

    # A plain Int64 column must be rejected with an InvalidOperationError.
    int_frame = pl.DataFrame({"a": [1, 2, 3]}, schema={"a": pl.Int64})
    with pytest.raises(InvalidOperationError, match="expected Array type, got: i64"):
        int_frame.select(pl.col.a.arr.explode())