Skip to content

Commit

Permalink
Add fast path for categorical operations
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller committed Dec 7, 2024
1 parent a6ca94d commit 1d25949
Show file tree
Hide file tree
Showing 9 changed files with 333 additions and 4 deletions.
12 changes: 12 additions & 0 deletions crates/polars-plan/src/dsl/cat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,16 @@ impl CategoricalNameSpace {
self.0
.apply_private(CategoricalFunction::GetCategories.into())
}

/// Build an expression that applies [`CategoricalFunction::LenBytes`] to the
/// wrapped expression through `map_private`.
#[cfg(feature = "strings")]
pub fn len_bytes(self) -> Expr {
    let function = FunctionExpr::Categorical(CategoricalFunction::LenBytes);
    self.0.map_private(function)
}

/// Build an expression that applies [`CategoricalFunction::LenChars`] to the
/// wrapped expression through `map_private`.
#[cfg(feature = "strings")]
pub fn len_chars(self) -> Expr {
    let function = FunctionExpr::Categorical(CategoricalFunction::LenChars);
    self.0.map_private(function)
}
}
55 changes: 55 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/cat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,21 @@ use crate::map;
/// Functions available on categorical expressions.
///
/// The length variants only exist when the `strings` feature is enabled.
#[derive(Clone, PartialEq, Debug, Eq, Hash)]
pub enum CategoricalFunction {
/// Displayed as `cat.get_categories`; its output field has dtype `String`.
GetCategories,
/// Displayed as `cat.len_bytes`; its output field has dtype `UInt32`.
#[cfg(feature = "strings")]
LenBytes,
/// Displayed as `cat.len_chars`; its output field has dtype `UInt32`.
#[cfg(feature = "strings")]
LenChars,
}

impl CategoricalFunction {
pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult<Field> {
use CategoricalFunction::*;
match self {
GetCategories => mapper.with_dtype(DataType::String),
#[cfg(feature = "strings")]
LenBytes => mapper.with_dtype(DataType::UInt32),
#[cfg(feature = "strings")]
LenChars => mapper.with_dtype(DataType::UInt32),
}
}
}
Expand All @@ -21,6 +29,10 @@ impl Display for CategoricalFunction {
use CategoricalFunction::*;
let s = match self {
GetCategories => "get_categories",
#[cfg(feature = "strings")]
LenBytes => "len_bytes",
#[cfg(feature = "strings")]
LenChars => "len_chars",
};
write!(f, "cat.{s}")
}
Expand All @@ -31,6 +43,10 @@ impl From<CategoricalFunction> for SpecialEq<Arc<dyn ColumnsUdf>> {
use CategoricalFunction::*;
match func {
GetCategories => map!(get_categories),
#[cfg(feature = "strings")]
LenBytes => map!(len_bytes),
#[cfg(feature = "strings")]
LenChars => map!(len_chars),
}
}
}
Expand All @@ -48,3 +64,42 @@ fn get_categories(s: &Column) -> PolarsResult<Column> {
let arr = rev_map.get_categories().clone().boxed();
Series::try_from((ca.name().clone(), arr)).map(Column::from)
}

/// Apply a function to the categories of a categorical column and re-broadcast the result back
/// to the array.
///
/// `op` is called exactly once, on a `StringChunked` holding the categories of the column's
/// rev-map. Its result is then gathered with `take` using the column's physical values as
/// indices, so every row receives the value computed for its category.
///
/// The categories array is named after the input column so that the output does not end up
/// with an empty name. NOTE(review): this relies on `op` carrying the name of its input over
/// to its result — confirm whenever a new `op` is added.
///
/// # Errors
///
/// Propagates the error from `s.categorical()`, from casting the physical values to
/// `IDX_DTYPE`, from `idx()`, and from `take`.
///
/// # Panics
///
/// For a global rev-map, panics if a physical value has no entry in the physical-to-local map.
// Gated like its only callers so it is not dead code when `strings` is disabled.
#[cfg(feature = "strings")]
fn apply_to_cats<F, T>(s: &Column, op: F) -> PolarsResult<Column>
where
    F: FnOnce(&StringChunked) -> ChunkedArray<T>,
    ChunkedArray<T>: IntoSeries,
    T: PolarsDataType,
{
    let ca = s.categorical()?;
    let (categories, phys) = match &**ca.get_rev_map() {
        RevMapping::Local(c, _) => (c, ca.physical().cast(&IDX_DTYPE)?),
        RevMapping::Global(physical_map, c, _) => {
            // Map physical to its local representation for use with take() later.
            let phys = ca.physical().apply(|opt_v| {
                opt_v.map(|v| {
                    *physical_map
                        .get(&v)
                        .expect("physical value of a global categorical must be in its rev-map")
                })
            });
            (c, phys.cast(&IDX_DTYPE)?)
        },
    };

    // Apply function to categories, keeping the input column's name on the way.
    let categories = StringChunked::with_chunk(ca.name().clone(), categories.clone());
    let result = op(&categories).into_series();

    let out = result.take(phys.idx()?)?;
    Ok(out.into_column())
}

/// Byte length of each row's category string: `str_len_bytes` is evaluated on the
/// categories and `apply_to_cats` maps the result back onto the rows.
#[cfg(feature = "strings")]
fn len_bytes(s: &Column) -> PolarsResult<Column> {
    apply_to_cats(s, |categories| categories.str_len_bytes())
}

/// Character count of each row's category string: `str_len_chars` is evaluated on the
/// categories and `apply_to_cats` maps the result back onto the rows.
#[cfg(feature = "strings")]
fn len_chars(s: &Column) -> PolarsResult<Column> {
    apply_to_cats(s, |categories| categories.str_len_chars())
}
8 changes: 8 additions & 0 deletions crates/polars-python/src/expr/categorical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,12 @@ impl PyExpr {
// Python binding: calls `cat().get_categories()` on a clone of the wrapped expression.
fn cat_get_categories(&self) -> Self {
    let expr = self.inner.clone();
    expr.cat().get_categories().into()
}

// Python binding: calls `cat().len_bytes()` on a clone of the wrapped expression.
fn cat_len_bytes(&self) -> Self {
    let expr = self.inner.clone();
    expr.cat().len_bytes().into()
}

// Python binding: calls `cat().len_chars()` on a clone of the wrapped expression.
fn cat_len_chars(&self) -> Self {
    let expr = self.inner.clone();
    expr.cat().len_chars().into()
}
}
29 changes: 26 additions & 3 deletions crates/polars-python/src/lazyframe/visitor/expr_nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use polars_ops::series::InterpolationMethod;
use polars_ops::series::SearchSortedSide;
use polars_plan::dsl::function_expr::rolling::RollingFunction;
use polars_plan::dsl::function_expr::rolling_by::RollingFunctionBy;
use polars_plan::dsl::{BooleanFunction, StringFunction, TemporalFunction};
use polars_plan::dsl::{BooleanFunction, CategoricalFunction, StringFunction, TemporalFunction};
use polars_plan::prelude::{
AExpr, FunctionExpr, GroupbyOptions, IRAggExpr, LiteralValue, Operator, PowFunction,
WindowMapping, WindowType,
Expand Down Expand Up @@ -171,6 +171,21 @@ impl PyStringFunction {
}
}

// Python-facing counterpart of `CategoricalFunction`, registered with PyO3 under the
// class name `CategoricalFunction`. One variant per categorical function handled by the
// expression visitor. Plain `//` comments are used on purpose so the Python-visible
// docstring of the class stays unchanged.
#[pyclass(name = "CategoricalFunction", eq)]
#[derive(Copy, Clone, PartialEq)]
pub enum PyCategoricalFunction {
GetCategories,
LenBytes,
LenChars,
}

#[pymethods]
impl PyCategoricalFunction {
    // Hash a variant by its discriminant, cast to `isize`.
    fn __hash__(&self) -> isize {
        let variant = *self;
        variant as isize
    }
}

#[pyclass(name = "BooleanFunction", eq)]
#[derive(Copy, Clone, PartialEq)]
pub enum PyBooleanFunction {
Expand Down Expand Up @@ -793,8 +808,16 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult<PyObject> {
FunctionExpr::BinaryExpr(_) => {
return Err(PyNotImplementedError::new_err("binary expr"))
},
FunctionExpr::Categorical(_) => {
return Err(PyNotImplementedError::new_err("categorical expr"))
FunctionExpr::Categorical(catfun) => match catfun {
CategoricalFunction::GetCategories => {
(PyCategoricalFunction::GetCategories.into_py(py),).to_object(py)
},
CategoricalFunction::LenBytes => {
(PyCategoricalFunction::LenBytes.into_py(py),).to_object(py)
},
CategoricalFunction::LenChars => {
(PyCategoricalFunction::LenChars.into_py(py),).to_object(py)
},
},
FunctionExpr::ListExpr(_) => {
return Err(PyNotImplementedError::new_err("list expr"))
Expand Down
2 changes: 2 additions & 0 deletions py-polars/docs/source/reference/expressions/categories.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,5 @@ The following methods are available under the `expr.cat` attribute.
:template: autosummary/accessor_method.rst

Expr.cat.get_categories
Expr.cat.len_bytes
Expr.cat.len_chars
2 changes: 2 additions & 0 deletions py-polars/docs/source/reference/series/categories.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,7 @@ The following methods are available under the `Series.cat` attribute.

Series.cat.get_categories
Series.cat.is_local
Series.cat.len_bytes
Series.cat.len_chars
Series.cat.to_local
Series.cat.uses_lexical_ordering
91 changes: 91 additions & 0 deletions py-polars/polars/expr/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,94 @@ def get_categories(self) -> Expr:
└──────┘
"""
return wrap_expr(self._pyexpr.cat_get_categories())

def len_bytes(self) -> Expr:
    """
    Return the byte-length of the string representation of each value.

    Returns
    -------
    Expr
        Expression of data type :class:`UInt32`.

    See Also
    --------
    len_chars

    Notes
    -----
    When working with non-ASCII text, the length in bytes is not the same as the
    length in characters. You may want to use :func:`len_chars` instead.
    Note that :func:`len_bytes` is much more performant (*O(1)*) than
    :func:`len_chars` (*O(n)*).

    Examples
    --------
    >>> df = pl.DataFrame(
    ...     {"a": pl.Series(["Café", "345", "東京", None], dtype=pl.Categorical)}
    ... )
    >>> df.with_columns(
    ...     pl.col("a").cat.len_bytes().alias("n_bytes"),
    ...     pl.col("a").cat.len_chars().alias("n_chars"),
    ... )
    shape: (4, 3)
    ┌──────┬─────────┬─────────┐
    │ a    ┆ n_bytes ┆ n_chars │
    │ ---  ┆ ---     ┆ ---     │
    │ cat  ┆ u32     ┆ u32     │
    ╞══════╪═════════╪═════════╡
    │ Café ┆ 5       ┆ 4       │
    │ 345  ┆ 3       ┆ 3       │
    │ 東京 ┆ 6       ┆ 2       │
    │ null ┆ null    ┆ null    │
    └──────┴─────────┴─────────┘
    """
    return wrap_expr(self._pyexpr.cat_len_bytes())

def len_chars(self) -> Expr:
    """
    Return the number of characters of the string representation of each value.

    Returns
    -------
    Expr
        Expression of data type :class:`UInt32`.

    See Also
    --------
    len_bytes

    Notes
    -----
    When working with ASCII text, use :func:`len_bytes` instead to achieve
    equivalent output with much better performance:
    :func:`len_bytes` runs in *O(1)*, while :func:`len_chars` runs in *O(n)*.

    A character is defined as a `Unicode scalar value`_. A single character is
    represented by a single byte when working with ASCII text, and a maximum of
    4 bytes otherwise.

    .. _Unicode scalar value: https://www.unicode.org/glossary/#unicode_scalar_value

    Examples
    --------
    >>> df = pl.DataFrame(
    ...     {"a": pl.Series(["Café", "345", "東京", None], dtype=pl.Categorical)}
    ... )
    >>> df.with_columns(
    ...     pl.col("a").cat.len_chars().alias("n_chars"),
    ...     pl.col("a").cat.len_bytes().alias("n_bytes"),
    ... )
    shape: (4, 3)
    ┌──────┬─────────┬─────────┐
    │ a    ┆ n_chars ┆ n_bytes │
    │ ---  ┆ ---     ┆ ---     │
    │ cat  ┆ u32     ┆ u32     │
    ╞══════╪═════════╪═════════╡
    │ Café ┆ 4       ┆ 5       │
    │ 345  ┆ 3       ┆ 3       │
    │ 東京 ┆ 2       ┆ 6       │
    │ null ┆ null    ┆ null    │
    └──────┴─────────┴─────────┘
    """
    return wrap_expr(self._pyexpr.cat_len_chars())
73 changes: 73 additions & 0 deletions py-polars/polars/series/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,76 @@ def uses_lexical_ordering(self) -> bool:
True
"""
return self._s.cat_uses_lexical_ordering()

def len_bytes(self) -> Series:
    """
    Return the byte-length of the string representation of each value.

    Returns
    -------
    Series
        Series of data type :class:`UInt32`.

    See Also
    --------
    len_chars

    Notes
    -----
    When working with non-ASCII text, the length in bytes is not the same as the
    length in characters. You may want to use :func:`len_chars` instead.
    Note that :func:`len_bytes` is much more performant (*O(1)*) than
    :func:`len_chars` (*O(n)*).

    Examples
    --------
    >>> s = pl.Series(["Café", "345", "東京", None], dtype=pl.Categorical)
    >>> s.cat.len_bytes()
    shape: (4,)
    Series: '' [u32]
    [
            5
            3
            6
            null
    ]
    """

def len_chars(self) -> Series:
    """
    Return the number of characters of the string representation of each value.

    Returns
    -------
    Series
        Series of data type :class:`UInt32`.

    See Also
    --------
    len_bytes

    Notes
    -----
    When working with ASCII text, use :func:`len_bytes` instead to achieve
    equivalent output with much better performance:
    :func:`len_bytes` runs in *O(1)*, while :func:`len_chars` runs in *O(n)*.

    A character is defined as a `Unicode scalar value`_. A single character is
    represented by a single byte when working with ASCII text, and a maximum of
    4 bytes otherwise.

    .. _Unicode scalar value: https://www.unicode.org/glossary/#unicode_scalar_value

    Examples
    --------
    >>> s = pl.Series(["Café", "345", "東京", None], dtype=pl.Categorical)
    >>> s.cat.len_chars()
    shape: (4,)
    Series: '' [u32]
    [
            4
            3
            2
            null
    ]
    """
Loading

0 comments on commit 1d25949

Please sign in to comment.