diff --git a/src/io/ipc/read/stream.rs b/src/io/ipc/read/stream.rs index 370ea9f429d..81ede969cca 100644 --- a/src/io/ipc/read/stream.rs +++ b/src/io/ipc/read/stream.rs @@ -12,7 +12,7 @@ use crate::io::ipc::IpcSchema; use super::super::CONTINUATION_MARKER; use super::common::*; -use super::schema::fb_to_schema; +use super::schema::deserialize_stream_metadata; use super::Dictionaries; /// Metadata of an Arrow IPC stream, written at the start of the stream @@ -45,29 +45,7 @@ pub fn read_stream_metadata(reader: &mut R) -> Result { let mut meta_buffer = vec![0; meta_len as usize]; reader.read_exact(&mut meta_buffer)?; - let message = - arrow_format::ipc::MessageRef::read_as_root(meta_buffer.as_slice()).map_err(|err| { - ArrowError::OutOfSpec(format!("Unable to get root as message: {:?}", err)) - })?; - let version = message.version()?; - // message header is a Schema, so read it - let header = message - .header()? - .ok_or_else(|| ArrowError::oos("Unable to read the first IPC message"))?; - let schema = if let arrow_format::ipc::MessageHeaderRef::Schema(schema) = header { - schema - } else { - return Err(ArrowError::oos( - "The first IPC message of the stream must be a schema", - )); - }; - let (schema, ipc_schema) = fb_to_schema(schema)?; - - Ok(StreamMetadata { - schema, - version, - ipc_schema, - }) + deserialize_stream_metadata(&meta_buffer) } /// Encodes the stream's status after each read. diff --git a/src/io/ipc/read/stream_async.rs b/src/io/ipc/read/stream_async.rs index a3ccffec9ae..9e054cd07ce 100644 --- a/src/io/ipc/read/stream_async.rs +++ b/src/io/ipc/read/stream_async.rs @@ -17,12 +17,23 @@ use super::schema::deserialize_stream_metadata; use super::Dictionaries; use super::StreamMetadata; +/// A (private) state of stream messages +struct ReadState { + pub reader: R, + pub metadata: StreamMetadata, + pub dictionaries: Dictionaries, + /// The internal buffer to read data inside the messages (records and dictionaries) to + pub data_buffer: Vec, + /// The internal buffer to read messages to + pub message_buffer: Vec, +} + /// The state of an Arrow stream enum StreamState { /// The stream does not contain new chunks (and it has not been closed) - Waiting((R, StreamMetadata, Dictionaries)), + Waiting(ReadState), /// The stream contain a new chunk - Some((R, StreamMetadata, Dictionaries, Chunk>)), + Some((ReadState, Chunk>)), } /// Reads the [`StreamMetadata`] of the Arrow stream asynchronously @@ -49,24 +60,20 @@ pub async fn read_stream_metadata_async( /// Reads the next item, yielding `None` if the stream has been closed, /// or a [`StreamState`] otherwise. -async fn _read_next( - mut reader: R, - metadata: StreamMetadata, - mut dictionaries: Dictionaries, - message_buffer: &mut Vec, - data_buffer: &mut Vec, +async fn maybe_next( + mut state: ReadState, ) -> Result>> { // determine metadata length let mut meta_length: [u8; 4] = [0; 4]; - match reader.read_exact(&mut meta_length).await { + match state.reader.read_exact(&mut meta_length).await { Ok(()) => (), Err(e) => { return if e.kind() == std::io::ErrorKind::UnexpectedEof { // Handle EOF without the "0xFFFFFFFF 0x00000000" // valid according to: // https://arrow.apache.org/docs/format/Columnar.html#ipc-streaming-format - Ok(Some(StreamState::Waiting((reader, metadata, dictionaries)))) + Ok(Some(StreamState::Waiting(state))) } else { Err(ArrowError::from(e)) }; @@ -77,7 +84,7 @@ async fn _read_next( // If a continuation marker is encountered, skip over it and read // the size from the next four bytes. if meta_length == CONTINUATION_MARKER { - reader.read_exact(&mut meta_length).await?; + state.reader.read_exact(&mut meta_length).await?; } i32::from_le_bytes(meta_length) as usize }; @@ -87,13 +94,14 @@ async fn _read_next( return Ok(None); } - message_buffer.clear(); - message_buffer.resize(meta_length, 0); - reader.read_exact(message_buffer).await?; + state.message_buffer.clear(); + state.message_buffer.resize(meta_length, 0); + state.reader.read_exact(&mut state.message_buffer).await?; - let message = arrow_format::ipc::MessageRef::read_as_root(message_buffer).map_err(|err| { - ArrowError::OutOfSpec(format!("Unable to get root as message: {:?}", err)) - })?; + let message = + arrow_format::ipc::MessageRef::read_as_root(&state.message_buffer).map_err(|err| { + ArrowError::OutOfSpec(format!("Unable to get root as message: {:?}", err)) + })?; let header = message.header()?.ok_or_else(|| { ArrowError::oos("IPC: unable to fetch the message header. The file or stream is corrupted.") })?; @@ -102,40 +110,40 @@ async fn _read_next( arrow_format::ipc::MessageHeaderRef::Schema(_) => Err(ArrowError::oos("A stream ")), arrow_format::ipc::MessageHeaderRef::RecordBatch(batch) => { // read the block that makes up the record batch into a buffer - data_buffer.clear(); - data_buffer.resize(message.body_length()? as usize, 0); - reader.read_exact(data_buffer).await?; + state.data_buffer.clear(); + state.data_buffer.resize(message.body_length()? as usize, 0); + state.reader.read_exact(&mut state.data_buffer).await?; read_record_batch( batch, - &metadata.schema.fields, - &metadata.ipc_schema, + &state.metadata.schema.fields, + &state.metadata.ipc_schema, None, - &dictionaries, - metadata.version, - &mut std::io::Cursor::new(data_buffer), + &state.dictionaries, + state.metadata.version, + &mut std::io::Cursor::new(&state.data_buffer), 0, ) - .map(|x| Some(StreamState::Some((reader, metadata, dictionaries, x)))) + .map(|chunk| Some(StreamState::Some((state, chunk)))) } arrow_format::ipc::MessageHeaderRef::DictionaryBatch(batch) => { // read the block that makes up the dictionary batch into a buffer let mut buf = vec![0; message.body_length()? as usize]; - reader.read_exact(&mut buf).await?; + state.reader.read_exact(&mut buf).await?; let mut dict_reader = std::io::Cursor::new(buf); read_dictionary( batch, - &metadata.schema.fields, - &metadata.ipc_schema, - &mut dictionaries, + &state.metadata.schema.fields, + &state.metadata.ipc_schema, + &mut state.dictionaries, &mut dict_reader, 0, )?; // read the next message until we encounter a Chunk> message - Ok(Some(StreamState::Waiting((reader, metadata, dictionaries)))) + Ok(Some(StreamState::Waiting(state))) } t => Err(ArrowError::OutOfSpec(format!( "Reading types other than record batches not yet supported, unable to read {:?} ", @@ -144,22 +152,7 @@ async fn _read_next( } } -/// Reads the next item, yielding `None` if the stream is done, -/// and a [`StreamState`] otherwise. -async fn maybe_next( - reader: R, - metadata: StreamMetadata, - dictionaries: Dictionaries, -) -> Result>> { - _read_next(reader, metadata, dictionaries, &mut vec![], &mut vec![]).await -} - -/// Arrow Stream reader. -/// -/// A [`Stream`] over an Arrow stream that yields a result of [`StreamState`]s. -/// This is the recommended way to read an arrow stream (by iterating over its data). -/// -/// For a more thorough walkthrough consult [this example](/~https://github.com/jorgecarleitao/arrow2/tree/main/examples/ipc_pyarrow). +/// A [`Stream`] over an Arrow IPC stream that asynchronously yields [`Chunk`]s. pub struct AsyncStreamReader { metadata: StreamMetadata, future: Option>>>>, @@ -168,7 +161,14 @@ pub struct AsyncStreamReader { impl AsyncStreamReader { /// Creates a new [`AsyncStreamReader`] pub fn new(reader: R, metadata: StreamMetadata) -> Self { - let future = Some(Box::pin(maybe_next(reader, metadata.clone(), Default::default())) as _); + let state = ReadState { + reader, + metadata: metadata.clone(), + dictionaries: Default::default(), + data_buffer: Default::default(), + message_buffer: Default::default(), + }; + let future = Some(Box::pin(maybe_next(state)) as _); Self { metadata, future } } @@ -195,13 +195,8 @@ impl Stream for AsyncStreamReader { me.future = None; Poll::Ready(None) } - Poll::Ready(Ok(Some(StreamState::Some(( - reader, - metadata, - dictionaries, - batch, - ))))) => { - me.future = Some(Box::pin(maybe_next(reader, metadata, dictionaries))); + Poll::Ready(Ok(Some(StreamState::Some((state, batch))))) => { + me.future = Some(Box::pin(maybe_next(state))); Poll::Ready(Some(Ok(batch))) } Poll::Ready(Ok(Some(StreamState::Waiting(_)))) => Poll::Pending, diff --git a/tests/it/io/ipc/mod.rs b/tests/it/io/ipc/mod.rs index 6464685561b..e1a5f79b4f0 100644 --- a/tests/it/io/ipc/mod.rs +++ b/tests/it/io/ipc/mod.rs @@ -7,5 +7,5 @@ pub use common::read_gzip_json; #[cfg(feature = "io_ipc_write_async")] mod write_async; -//#[cfg(feature = "io_ipc_read_async")] +#[cfg(feature = "io_ipc_read_async")] mod read_stream_async;