diff --git a/src/io/parquet/read/file.rs b/src/io/parquet/read/file.rs
index 474f29f7de6..638089668af 100644
--- a/src/io/parquet/read/file.rs
+++ b/src/io/parquet/read/file.rs
@@ -179,10 +179,9 @@ pub struct RowGroupReader<R: Read + Seek> {
     reader: R,
     schema: Schema,
     groups_filter: Option<GroupFilter>,
-    row_groups: Vec<RowGroupMetaData>,
+    row_groups: std::iter::Enumerate<std::vec::IntoIter<RowGroupMetaData>>,
     chunk_size: Option<usize>,
     remaining_rows: usize,
-    current_group: usize,
 }
 
 impl<R: Read + Seek> RowGroupReader<R> {
@@ -199,10 +198,9 @@ impl<R: Read + Seek> RowGroupReader<R> {
             reader,
             schema,
             groups_filter,
-            row_groups,
+            row_groups: row_groups.into_iter().enumerate(),
             chunk_size,
             remaining_rows: limit.unwrap_or(usize::MAX),
-            current_group: 0,
         }
     }
 
@@ -216,28 +214,27 @@ impl<R: Read + Seek> RowGroupReader<R> {
         if self.schema.fields.is_empty() {
             return Ok(None);
         }
-        if self.current_group == self.row_groups.len() {
-            // reached the last row group
-            return Ok(None);
-        };
         if self.remaining_rows == 0 {
             // reached the limit
             return Ok(None);
         }
 
-        let current_row_group = self.current_group;
-        let row_group = &self.row_groups[current_row_group];
-        if let Some(groups_filter) = self.groups_filter.as_ref() {
-            if !(groups_filter)(current_row_group, row_group) {
-                self.current_group += 1;
-                return self._next();
-            }
-        }
-        self.current_group += 1;
+        let row_group = if let Some(groups_filter) = self.groups_filter.as_ref() {
+            self.row_groups
+                .by_ref()
+                .find(|(index, row_group)| (groups_filter)(*index, row_group))
+        } else {
+            self.row_groups.next()
+        };
+        let row_group = if let Some((_, row_group)) = row_group {
+            row_group
+        } else {
+            return Ok(None);
+        };
 
         let column_chunks = read_columns_many(
             &mut self.reader,
-            row_group,
+            &row_group,
             self.schema.fields.clone(),
             self.chunk_size,
             Some(self.remaining_rows),
@@ -263,7 +260,6 @@ impl<R: Read + Seek> Iterator for RowGroupReader<R> {
     }
 
     fn size_hint(&self) -> (usize, Option<usize>) {
-        let len = self.row_groups.len() - self.current_group;
-        (len, Some(len))
+        self.row_groups.size_hint()
     }
 }
diff --git a/tests/it/io/parquet/mod.rs b/tests/it/io/parquet/mod.rs
index 1cf69d432c4..8b8a3c2038d 100644
--- a/tests/it/io/parquet/mod.rs
+++ b/tests/it/io/parquet/mod.rs
@@ -1150,8 +1150,7 @@ fn integration_write(schema: &Schema, chunks: &[Chunk<Box<dyn Array>>]) -> Resul
 type IntegrationRead = (Schema, Vec<Chunk<Box<dyn Array>>>);
 
 fn integration_read(data: &[u8], limit: Option<usize>) -> Result<IntegrationRead> {
-    let reader = Cursor::new(data);
-    let reader = FileReader::try_new(reader, None, None, limit, None)?;
+    let reader = FileReader::try_new(Cursor::new(data), None, None, limit, None)?;
     let schema = reader.schema().clone();
 
     for field in &schema.fields {
@@ -1519,3 +1518,32 @@ fn nested_dict_limit() -> Result<()> {
 
     assert_roundtrip(schema, chunk, Some(2))
 }
+
+#[test]
+fn filter_chunk() -> Result<()> {
+    let chunk1 = Chunk::new(vec![PrimitiveArray::from_slice([1i16, 3]).boxed()]);
+    let chunk2 = Chunk::new(vec![PrimitiveArray::from_slice([2i16, 4]).boxed()]);
+    let schema = Schema::from(vec![Field::new("c1", DataType::Int16, true)]);
+
+    let r = integration_write(&schema, &[chunk1.clone(), chunk2.clone()])?;
+
+    let reader = FileReader::try_new(
+        Cursor::new(r),
+        None,
+        None,
+        None,
+        // select chunk 1
+        Some(std::sync::Arc::new(|i, _| i == 0)),
+    )?;
+    let new_schema = reader.schema().clone();
+
+    for field in &schema.fields {
+        let mut _statistics = deserialize(field, &reader.metadata().row_groups)?;
+    }
+
+    let new_chunks = reader.collect::<Result<Vec<_>>>()?;
+
+    assert_eq!(new_schema, schema);
+    assert_eq!(new_chunks, vec![chunk1]);
+    Ok(())
+}
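
A minimal, self-contained sketch of the iterator pattern that `_next` now relies on: `enumerate()` keeps the original row-group index available to the filter, and `by_ref()` lets `find` advance the iterator in place, so the position survives across calls and the separate `current_group` counter becomes redundant. `RowGroup` and the filter below are simplified stand-ins for illustration, not arrow2's `RowGroupMetaData` or `GroupFilter`.

    // Simplified stand-in for RowGroupMetaData.
    struct RowGroup {
        num_rows: usize,
    }

    fn main() {
        let groups = vec![
            RowGroup { num_rows: 2 },
            RowGroup { num_rows: 2 },
            RowGroup { num_rows: 2 },
        ];

        // Same shape as the new `row_groups` field:
        // std::iter::Enumerate<std::vec::IntoIter<RowGroup>>.
        let mut row_groups = groups.into_iter().enumerate();

        // Keep only even-indexed groups, analogous to a `groups_filter`.
        let filter = |index: usize, _group: &RowGroup| index % 2 == 0;

        // `by_ref()` advances the iterator in place, so each `find` call
        // resumes where the previous one stopped: the iterator itself
        // carries the state that `current_group` used to track.
        while let Some((index, group)) = row_groups.by_ref().find(|(i, g)| filter(*i, g)) {
            println!("selected row group {index} ({} rows)", group.num_rows);
        }
    }

A side effect of this representation is that `size_hint` can delegate to the inner `Enumerate`, which reports the number of row groups not yet consumed; when a filter is set, the reader may yield fewer chunks than that hint suggests.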