Skip to content

Commit

Permalink
refactor: Reduce mode bloat (#20839)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jan 22, 2025
1 parent 3144f67 commit 20979a0
Showing 1 changed file with 8 additions and 95 deletions.
103 changes: 8 additions & 95 deletions crates/polars-ops/src/chunked_array/mode.rs
Original file line number Diff line number Diff line change
@@ -1,33 +1,5 @@
use polars_core::prelude::*;
use polars_core::{with_match_physical_integer_polars_type, POOL};

fn mode_primitive<T: PolarsDataType>(ca: &ChunkedArray<T>) -> PolarsResult<ChunkedArray<T>>
where
ChunkedArray<T>: IntoGroupsType + ChunkTake<[IdxSize]>,
{
if ca.is_empty() {
return Ok(ca.clone());
}
let parallel = !POOL.current_thread_has_pending_tasks().unwrap_or(false);
let groups = ca.group_tuples(parallel, false).unwrap();
let idx = mode_indices(groups);

// SAFETY:
// group indices are in bounds
Ok(unsafe { ca.take_unchecked(idx.as_slice()) })
}

fn mode_f32(ca: &Float32Chunked) -> PolarsResult<Float32Chunked> {
let s = ca.apply_as_ints(|v| mode(v).unwrap());
let ca = s.f32().unwrap().clone();
Ok(ca)
}

fn mode_64(ca: &Float64Chunked) -> PolarsResult<Float64Chunked> {
let s = ca.apply_as_ints(|v| mode(v).unwrap());
let ca = s.f64().unwrap().clone();
Ok(ca)
}
use polars_core::POOL;

fn mode_indices(groups: GroupsType) -> Vec<IdxSize> {
match groups {
Expand Down Expand Up @@ -55,70 +27,11 @@ fn mode_indices(groups: GroupsType) -> Vec<IdxSize> {
}

pub fn mode(s: &Series) -> PolarsResult<Series> {
let s_phys = s.to_physical_repr();
let out = match s_phys.dtype() {
DataType::Binary => mode_primitive(s_phys.binary().unwrap())?.into_series(),
DataType::Boolean => mode_primitive(s_phys.bool().unwrap())?.into_series(),
DataType::Float32 => mode_f32(s_phys.f32().unwrap())?.into_series(),
DataType::Float64 => mode_64(s_phys.f64().unwrap())?.into_series(),
DataType::String => {
let ca = mode_primitive(&s_phys.str().unwrap().as_binary())?;
unsafe { ca.to_string_unchecked() }.into_series()
},
dt if dt.is_integer() => {
with_match_physical_integer_polars_type!(dt, |$T| {
let ca: &ChunkedArray<$T> = s_phys.as_ref().as_ref().as_ref();
mode_primitive(ca)?.into_series()
})
},
_ => polars_bail!(opq = mode, s.dtype()),
};
// SAFETY: Casting back into the original from physical representation
unsafe { out.from_physical_unchecked(s.dtype()) }
}

#[cfg(test)]
mod test {
use polars_core::prelude::*;

use super::{mode, mode_primitive};

#[test]
fn mode_test() {
let ca = Int32Chunked::from_slice("test".into(), &[0, 1, 2, 3, 4, 4, 5, 6, 5, 0]);
let mut result = mode_primitive(&ca).unwrap().to_vec();
result.sort_by_key(|a| a.unwrap());
assert_eq!(&result, &[Some(0), Some(4), Some(5)]);

let ca = Int32Chunked::from_slice("test".into(), &[1, 1]);
let mut result = mode_primitive(&ca).unwrap().to_vec();
result.sort_by_key(|a| a.unwrap());
assert_eq!(&result, &[Some(1)]);

let ca = Int32Chunked::from_slice("test".into(), &[]);
let mut result = mode_primitive(&ca).unwrap().to_vec();
result.sort_by_key(|a| a.unwrap());
assert_eq!(result, &[]);

let ca = Float32Chunked::from_slice("test".into(), &[1.0f32, 2.0, 2.0, 3.0, 3.0, 3.0]);
let result = mode_primitive(&ca).unwrap().to_vec();
assert_eq!(result, &[Some(3.0f32)]);

let ca =
StringChunked::from_slice("test".into(), &["test", "test", "test", "another test"]);
let result = mode_primitive(&ca).unwrap();
let vec_result4: Vec<Option<&str>> = result.into_iter().collect();
assert_eq!(vec_result4, &[Some("test")]);

let mut ca_builder = CategoricalChunkedBuilder::new("test".into(), 5, Default::default());
ca_builder.append_value("test");
ca_builder.append_value("test");
ca_builder.append_value("test2");
ca_builder.append_value("test2");
ca_builder.append_value("test2");
let s = ca_builder.finish().into_series();
let result = mode(&s).unwrap();
assert_eq!(result.str_value(0).unwrap(), "test2");
assert_eq!(result.len(), 1);
}
let parallel = !POOL.current_thread_has_pending_tasks().unwrap_or(false);
let groups = s.group_tuples(parallel, false).unwrap();
let idx = mode_indices(groups);
let idx = IdxCa::from_vec("".into(), idx);
// SAFETY:
// group indices are in bounds
Ok(unsafe { s.take_unchecked(&idx) })
}

0 comments on commit 20979a0

Please sign in to comment.