Skip to content

Commit

Permalink
feat: Support cumulative aggregations for Decimal dtype (#20802)
Browse files Browse the repository at this point in the history
  • Loading branch information
lukemanley authored Jan 20, 2025
1 parent 7f75fc1 commit 374bce7
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 10 deletions.
41 changes: 31 additions & 10 deletions crates/polars-ops/src/series/ops/cum_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,13 @@ pub fn cum_sum(s: &Series, reverse: bool) -> PolarsResult<Series> {
Int128 => cum_sum_numeric(s.i128()?, reverse).into_series(),
Float32 => cum_sum_numeric(s.f32()?, reverse).into_series(),
Float64 => cum_sum_numeric(s.f64()?, reverse).into_series(),
#[cfg(feature = "dtype-decimal")]
Decimal(precision, scale) => {
let ca = s.decimal().unwrap().as_ref();
cum_sum_numeric(ca, reverse)
.into_decimal_unchecked(*precision, scale.unwrap())
.into_series()
},
#[cfg(feature = "dtype-duration")]
Duration(tu) => {
let s = s.to_physical_repr();
Expand All @@ -232,16 +239,23 @@ pub fn cum_sum(s: &Series, reverse: bool) -> PolarsResult<Series> {

/// Get an array with the cumulative min computed at every element.
pub fn cum_min(s: &Series, reverse: bool) -> PolarsResult<Series> {
let original_type = s.dtype();
let s = s.to_physical_repr();
match s.dtype() {
DataType::Boolean => Ok(cum_min_bool(s.bool()?, reverse).into_series()),
dt if dt.is_primitive_numeric() => {
#[cfg(feature = "dtype-decimal")]
DataType::Decimal(precision, scale) => {
let ca = s.decimal().unwrap().as_ref();
let out = cum_min_numeric(ca, reverse)
.into_decimal_unchecked(*precision, scale.unwrap())
.into_series();
Ok(out)
},
dt if dt.to_physical().is_primitive_numeric() => {
let s = s.to_physical_repr();
with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
let out = cum_min_numeric(ca, reverse).into_series();
if original_type.is_logical() {
out.cast(original_type)
if dt.is_logical() {
out.cast(dt)
} else {
Ok(out)
}
Expand All @@ -253,16 +267,23 @@ pub fn cum_min(s: &Series, reverse: bool) -> PolarsResult<Series> {

/// Get an array with the cumulative max computed at every element.
pub fn cum_max(s: &Series, reverse: bool) -> PolarsResult<Series> {
let original_type = s.dtype();
let s = s.to_physical_repr();
match s.dtype() {
DataType::Boolean => Ok(cum_max_bool(s.bool()?, reverse).into_series()),
dt if dt.is_primitive_numeric() => {
#[cfg(feature = "dtype-decimal")]
DataType::Decimal(precision, scale) => {
let ca = s.decimal().unwrap().as_ref();
let out = cum_max_numeric(ca, reverse)
.into_decimal_unchecked(*precision, scale.unwrap())
.into_series();
Ok(out)
},
dt if dt.to_physical().is_primitive_numeric() => {
let s = s.to_physical_repr();
with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
let out = cum_max_numeric(ca, reverse).into_series();
if original_type.is_logical() {
out.cast(original_type)
if dt.is_logical() {
out.cast(dt)
} else {
Ok(out)
}
Expand Down
17 changes: 17 additions & 0 deletions py-polars/tests/unit/datatypes/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,23 @@ def test_decimal_aggregations() -> None:
assert_frame_equal(df.describe(), description)


def test_decimal_cumulative_aggregations() -> None:
    """Check cum_sum, cum_min and cum_max on a single column built from `D` values.

    The three cumulative expressions are evaluated on column "a" in one
    `select` and the resulting frame is compared against a hand-written
    expected frame.
    """
    values = [D("2.2"), D("1.1"), D("3.3")]
    frame = pl.Series("a", values).to_frame()

    # One aliased expression per cumulative aggregation under test.
    actual = frame.select(
        pl.col("a").cum_sum().alias("cum_sum"),
        pl.col("a").cum_min().alias("cum_min"),
        pl.col("a").cum_max().alias("cum_max"),
    )

    # Running totals / minima / maxima for 2.2, 1.1, 3.3, in column order.
    running_sum = [D("2.2"), D("3.3"), D("6.6")]
    running_min = [D("2.2"), D("1.1"), D("1.1")]
    running_max = [D("2.2"), D("2.2"), D("3.3")]
    wanted = pl.DataFrame(
        {
            "cum_sum": running_sum,
            "cum_min": running_min,
            "cum_max": running_max,
        }
    )

    assert_frame_equal(actual, wanted)


def test_decimal_df_vertical_sum() -> None:
df = pl.DataFrame({"a": [D("1.1"), D("2.2")]})
expected = pl.DataFrame({"a": [D("3.3")]})
Expand Down

0 comments on commit 374bce7

Please sign in to comment.