Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(rust): Allow setting and reading custom schema-level IPC metadata #20066

Merged
merged 6 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/polars-arrow/src/io/ipc/append/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ impl<R: Read + Seek + Write> FileWriter<R> {
cannot_replace: true,
},
encoded_message: Default::default(),
custom_schema_metadata: None,
})
}
}
8 changes: 6 additions & 2 deletions crates/polars-arrow/src/io/ipc/read/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use super::common::*;
use super::schema::fb_to_schema;
use super::{Dictionaries, OutOfSpecKind, SendableIterator};
use crate::array::Array;
use crate::datatypes::ArrowSchemaRef;
use crate::datatypes::{ArrowSchemaRef, Metadata};
use crate::io::ipc::IpcSchema;
use crate::record_batch::RecordBatchT;

Expand All @@ -21,6 +21,9 @@ pub struct FileMetadata {
/// The schema that is read from the file footer
pub schema: ArrowSchemaRef,

/// The custom metadata that is read from the schema
pub custom_schema_metadata: Arc<Option<Metadata>>,
nameexhaustion marked this conversation as resolved.
Show resolved Hide resolved

/// The file's [`IpcSchema`]
pub ipc_schema: IpcSchema,

Expand Down Expand Up @@ -245,14 +248,15 @@ pub fn deserialize_footer(footer_data: &[u8], size: u64) -> PolarsResult<FileMet
.map(|dicts| dicts.collect::<PolarsResult<Vec<_>>>())
.transpose()?;
let ipc_schema = deserialize_schema_ref_from_footer(footer)?;
let (schema, ipc_schema) = fb_to_schema(ipc_schema)?;
let (schema, ipc_schema, custom_schema_metadata) = fb_to_schema(ipc_schema)?;

Ok(FileMetadata {
schema: Arc::new(schema),
ipc_schema,
blocks,
dictionaries,
size,
custom_schema_metadata: Arc::new(custom_schema_metadata),
})
}

Expand Down
30 changes: 27 additions & 3 deletions crates/polars-arrow/src/io/ipc/read/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,9 @@ fn get_dtype(
}

/// Deserialize a flatbuffers-encoded Schema message into [`ArrowSchema`], [`IpcSchema`] and the optional custom schema-level metadata.
pub fn deserialize_schema(message: &[u8]) -> PolarsResult<(ArrowSchema, IpcSchema)> {
pub fn deserialize_schema(
message: &[u8],
) -> PolarsResult<(ArrowSchema, IpcSchema, Option<Metadata>)> {
let message = arrow_format::ipc::MessageRef::read_as_root(message)
.map_err(|_err| polars_err!(oos = "Unable deserialize message: {err:?}"))?;

Expand All @@ -374,7 +376,7 @@ pub fn deserialize_schema(message: &[u8]) -> PolarsResult<(ArrowSchema, IpcSchem
/// Deserialize the raw Schema table from IPC format to Schema data type
pub(super) fn fb_to_schema(
schema: arrow_format::ipc::SchemaRef,
) -> PolarsResult<(ArrowSchema, IpcSchema)> {
) -> PolarsResult<(ArrowSchema, IpcSchema, Option<Metadata>)> {
let fields = schema
.fields()?
.ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingFields))?;
Expand All @@ -393,12 +395,33 @@ pub(super) fn fb_to_schema(
arrow_format::ipc::Endianness::Big => false,
};

// Decode the schema-level custom metadata.
//
// A key/value entry that fails to decode is reported as an error instead of
// being silently dropped, so corrupt metadata is never mistaken for absent
// metadata. Entries that decode fine but lack a key or a value are skipped.
// An absent or empty metadata table is normalized to `None`.
let custom_schema_metadata = match schema.custom_metadata()? {
    None => None,
    Some(metadata) => {
        let metadata = metadata
            .into_iter()
            .filter_map(|kv_result| {
                // `Ok(None)` (missing key or value) is filtered out by the
                // `transpose`; `Err(_)` is kept and surfaces through `collect`.
                kv_result
                    .and_then(|kv_ref| Ok(kv_ref.key()?.zip(kv_ref.value()?)))
                    .transpose()
            })
            .map(|entry| entry.map(|(key, value)| (key.into(), value.into())))
            .collect::<Result<Metadata, _>>()?;

        if metadata.is_empty() {
            None
        } else {
            Some(metadata)
        }
    },
};

Ok((
arrow_schema,
IpcSchema {
fields: ipc_fields,
is_little_endian,
},
custom_schema_metadata,
))
}

Expand All @@ -415,11 +438,12 @@ pub(super) fn deserialize_stream_metadata(meta: &[u8]) -> PolarsResult<StreamMet
} else {
polars_bail!(oos = "The first IPC message of the stream must be a schema")
};
let (schema, ipc_schema) = fb_to_schema(schema)?;
let (schema, ipc_schema, custom_schema_metadata) = fb_to_schema(schema)?;

Ok(StreamMetadata {
schema,
version,
ipc_schema,
custom_schema_metadata,
})
}
5 changes: 4 additions & 1 deletion crates/polars-arrow/src/io/ipc/read/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use super::common::*;
use super::schema::deserialize_stream_metadata;
use super::{Dictionaries, OutOfSpecKind};
use crate::array::Array;
use crate::datatypes::ArrowSchema;
use crate::datatypes::{ArrowSchema, Metadata};
use crate::io::ipc::IpcSchema;
use crate::record_batch::RecordBatchT;

Expand All @@ -18,6 +18,9 @@ pub struct StreamMetadata {
/// The schema that is read from the stream's first message
pub schema: ArrowSchema,

/// The custom metadata that is read from the schema
pub custom_schema_metadata: Option<Metadata>,

/// The IPC version of the stream
pub version: arrow_format::ipc::MetadataVersion,

Expand Down
17 changes: 14 additions & 3 deletions crates/polars-arrow/src/io/ipc/write/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@ use crate::datatypes::{
use crate::io::ipc::endianness::is_native_little_endian;

/// Converts an [ArrowSchema], its [IpcField]s and optional custom schema-level metadata to a flatbuffers-encoded [arrow_format::ipc::Message].
pub fn schema_to_bytes(schema: &ArrowSchema, ipc_fields: &[IpcField]) -> Vec<u8> {
let schema = serialize_schema(schema, ipc_fields);
pub fn schema_to_bytes(
schema: &ArrowSchema,
ipc_fields: &[IpcField],
custom_metadata: Option<&Metadata>,
) -> Vec<u8> {
let schema = serialize_schema(schema, ipc_fields, custom_metadata);

let message = arrow_format::ipc::Message {
version: arrow_format::ipc::MetadataVersion::V5,
Expand All @@ -24,6 +28,7 @@ pub fn schema_to_bytes(schema: &ArrowSchema, ipc_fields: &[IpcField]) -> Vec<u8>
pub fn serialize_schema(
schema: &ArrowSchema,
ipc_fields: &[IpcField],
custom_schema_metadata: Option<&Metadata>,
) -> arrow_format::ipc::Schema {
let endianness = if is_native_little_endian() {
arrow_format::ipc::Endianness::Little
Expand All @@ -37,7 +42,13 @@ pub fn serialize_schema(
.map(|(field, ipc_field)| serialize_field(field, ipc_field))
.collect::<Vec<_>>();

let custom_metadata = None;
let custom_metadata = custom_schema_metadata.and_then(|custom_meta| {
let as_kv = custom_meta
.iter()
.map(|(key, val)| key_value(key.clone().into_string(), val.clone().into_string()))
.collect::<Vec<_>>();
(!as_kv.is_empty()).then_some(as_kv)
});

arrow_format::ipc::Schema {
endianness,
Expand Down
15 changes: 14 additions & 1 deletion crates/polars-arrow/src/io/ipc/write/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
//! however the `FileWriter` expects a reader that supports `Seek`ing

use std::io::Write;
use std::sync::Arc;

use polars_error::{PolarsError, PolarsResult};

Expand All @@ -30,6 +31,8 @@ pub struct StreamWriter<W: Write> {
finished: bool,
/// Keeps track of dictionaries that have been written
dictionary_tracker: DictionaryTracker,
/// Custom schema-level metadata
custom_schema_metadata: Option<Arc<Metadata>>,

ipc_fields: Option<Vec<IpcField>>,
}
Expand All @@ -46,9 +49,15 @@ impl<W: Write> StreamWriter<W> {
cannot_replace: false,
},
ipc_fields: None,
custom_schema_metadata: None,
}
}

/// Sets the custom schema-level metadata, replacing any value set earlier.
///
/// Must be called before `start` is called: `start` is where the schema
/// message is encoded (together with this metadata) and written to the stream.
pub fn set_custom_schema_metadata(&mut self, custom_metadata: Arc<Metadata>) {
self.custom_schema_metadata = Some(custom_metadata);
}

/// Starts the stream by writing a Schema message to it.
/// Use `ipc_fields` to declare dictionary ids in the schema, for dictionary-reuse
pub fn start(
Expand All @@ -63,7 +72,11 @@ impl<W: Write> StreamWriter<W> {
});

let encoded_message = EncodedData {
ipc_message: schema_to_bytes(schema, self.ipc_fields.as_ref().unwrap()),
ipc_message: schema_to_bytes(
schema,
self.ipc_fields.as_ref().unwrap(),
self.custom_schema_metadata.as_deref(),
),
arrow_data: vec![],
};
write_message(&mut self.writer, &encoded_message)?;
Expand Down
21 changes: 19 additions & 2 deletions crates/polars-arrow/src/io/ipc/write/writer.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::io::Write;
use std::sync::Arc;

use arrow_format::ipc::planus::Builder;
use polars_error::{polars_bail, PolarsResult};
Expand Down Expand Up @@ -40,6 +41,8 @@ pub struct FileWriter<W: Write> {
pub(crate) dictionary_tracker: DictionaryTracker,
/// Buffer/scratch that is reused between writes
pub(crate) encoded_message: EncodedData,
/// Custom schema-level metadata
pub(crate) custom_schema_metadata: Option<Arc<Metadata>>,
}

impl<W: Write> FileWriter<W> {
Expand Down Expand Up @@ -83,6 +86,7 @@ impl<W: Write> FileWriter<W> {
cannot_replace: true,
},
encoded_message: Default::default(),
custom_schema_metadata: None,
}
}

Expand Down Expand Up @@ -116,7 +120,11 @@ impl<W: Write> FileWriter<W> {
// write the schema, set the written bytes to the schema

let encoded_message = EncodedData {
ipc_message: schema_to_bytes(&self.schema, &self.ipc_fields),
ipc_message: schema_to_bytes(
&self.schema,
&self.ipc_fields,
self.custom_schema_metadata.as_deref(),
nameexhaustion marked this conversation as resolved.
Show resolved Hide resolved
),
arrow_data: vec![],
};

Expand Down Expand Up @@ -210,7 +218,11 @@ impl<W: Write> FileWriter<W> {
// write EOS
write_continuation(&mut self.writer, 0)?;

let schema = schema::serialize_schema(&self.schema, &self.ipc_fields);
let schema = schema::serialize_schema(
&self.schema,
&self.ipc_fields,
self.custom_schema_metadata.as_deref(),
);

let root = arrow_format::ipc::Footer {
version: arrow_format::ipc::MetadataVersion::V5,
Expand All @@ -230,4 +242,9 @@ impl<W: Write> FileWriter<W> {

Ok(())
}

/// Sets the custom schema-level metadata, replacing any value set earlier.
///
/// Must be called before `start` is called. The metadata is passed along every
/// time this writer serializes its schema: once for the schema message and once
/// more for the schema embedded in the footer.
pub fn set_custom_schema_metadata(&mut self, custom_metadata: Arc<Metadata>) {
self.custom_schema_metadata = Some(custom_metadata);
}
}
11 changes: 10 additions & 1 deletion crates/polars-io/src/ipc/ipc_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
use std::io::{Read, Seek};
use std::path::PathBuf;

use arrow::datatypes::ArrowSchemaRef;
use arrow::datatypes::{ArrowSchemaRef, Metadata};
use arrow::io::ipc::read::{self, get_row_count};
use arrow::record_batch::RecordBatch;
use polars_core::prelude::*;
Expand Down Expand Up @@ -115,6 +115,15 @@ impl<R: MmapBytesReader> IpcReader<R> {
self.get_metadata()?;
Ok(self.schema.as_ref().unwrap().clone())
}

/// Get schema-level custom metadata of the Ipc file
pub fn custom_metadata(&mut self) -> PolarsResult<Arc<Option<Metadata>>> {
nameexhaustion marked this conversation as resolved.
Show resolved Hide resolved
self.get_metadata()?;
Ok(Arc::clone(
&self.metadata.as_ref().unwrap().custom_schema_metadata,
))
}

/// Stop reading when `n` rows are read.
pub fn with_n_rows(mut self, num_rows: Option<usize>) -> Self {
self.n_rows = num_rows;
Expand Down
32 changes: 30 additions & 2 deletions crates/polars-io/src/ipc/ipc_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
use std::io::{Read, Write};
use std::path::PathBuf;

use arrow::datatypes::Metadata;
use arrow::io::ipc::read::{StreamMetadata, StreamState};
use arrow::io::ipc::write::WriteOptions;
use arrow::io::ipc::{read, write};
Expand Down Expand Up @@ -83,6 +84,12 @@ impl<R: Read> IpcStreamReader<R> {
pub fn arrow_schema(&mut self) -> PolarsResult<ArrowSchema> {
Ok(self.metadata()?.schema)
}

/// Get schema-level custom metadata of the Ipc Stream file
pub fn custom_metadata(&mut self) -> PolarsResult<Arc<Option<Metadata>>> {
nameexhaustion marked this conversation as resolved.
Show resolved Hide resolved
Ok(Arc::new(self.metadata()?.custom_schema_metadata))
}

/// Stop reading when `n` rows are read.
pub fn with_n_rows(mut self, num_rows: Option<usize>) -> Self {
self.n_rows = num_rows;
Expand Down Expand Up @@ -198,8 +205,17 @@ where
/// fn example(df: &mut DataFrame) -> PolarsResult<()> {
/// let mut file = File::create("file.ipc").expect("could not create file");
///
/// IpcStreamWriter::new(&mut file)
/// .finish(df)
/// let mut writer = IpcStreamWriter::new(&mut file);
///
/// let custom_metadata = [
/// ("first_name".into(), "John".into()),
/// ("last_name".into(), "Doe".into()),
/// ]
/// .into_iter()
/// .collect();
/// writer.set_custom_schema_metadata(Arc::new(custom_metadata));
///
/// writer.finish(df)
/// }
///
/// ```
Expand All @@ -208,6 +224,8 @@ pub struct IpcStreamWriter<W> {
writer: W,
compression: Option<IpcCompression>,
compat_level: CompatLevel,
/// Custom schema-level metadata
custom_schema_metadata: Option<Arc<Metadata>>,
}

use arrow::record_batch::RecordBatch;
Expand All @@ -225,6 +243,11 @@ impl<W> IpcStreamWriter<W> {
self.compat_level = compat_level;
self
}

/// Sets the custom schema-level metadata, replacing any value set earlier.
///
/// Must be called before the data frame is written with `finish`: at that
/// point the metadata is handed to the underlying stream writer, right before
/// that writer starts the stream and emits the schema message.
pub fn set_custom_schema_metadata(&mut self, custom_metadata: Arc<Metadata>) {
self.custom_schema_metadata = Some(custom_metadata);
}
}

impl<W> SerWriter<W> for IpcStreamWriter<W>
Expand All @@ -236,6 +259,7 @@ where
writer,
compression: None,
compat_level: CompatLevel::oldest(),
custom_schema_metadata: None,
}
}

Expand All @@ -247,6 +271,10 @@ where
},
);

if let Some(custom_metadata) = &self.custom_schema_metadata {
ipc_stream_writer.set_custom_schema_metadata(Arc::clone(custom_metadata));
}

ipc_stream_writer.start(&df.schema().to_arrow(self.compat_level), None)?;
let df = chunk_df_for_writing(df, 512 * 512)?;
let iter = df.iter_chunks(self.compat_level, true);
Expand Down
Loading