Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
DRY
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgecarleitao committed Feb 12, 2022
1 parent 2da6078 commit f9696ff
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 80 deletions.
26 changes: 2 additions & 24 deletions src/io/ipc/read/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::io::ipc::IpcSchema;

use super::super::CONTINUATION_MARKER;
use super::common::*;
use super::schema::fb_to_schema;
use super::schema::deserialize_stream_metadata;
use super::Dictionaries;

/// Metadata of an Arrow IPC stream, written at the start of the stream
Expand Down Expand Up @@ -45,29 +45,7 @@ pub fn read_stream_metadata<R: Read>(reader: &mut R) -> Result<StreamMetadata> {
let mut meta_buffer = vec![0; meta_len as usize];
reader.read_exact(&mut meta_buffer)?;

let message =
arrow_format::ipc::MessageRef::read_as_root(meta_buffer.as_slice()).map_err(|err| {
ArrowError::OutOfSpec(format!("Unable to get root as message: {:?}", err))
})?;
let version = message.version()?;
// message header is a Schema, so read it
let header = message
.header()?
.ok_or_else(|| ArrowError::oos("Unable to read the first IPC message"))?;
let schema = if let arrow_format::ipc::MessageHeaderRef::Schema(schema) = header {
schema
} else {
return Err(ArrowError::oos(
"The first IPC message of the stream must be a schema",
));
};
let (schema, ipc_schema) = fb_to_schema(schema)?;

Ok(StreamMetadata {
schema,
version,
ipc_schema,
})
deserialize_stream_metadata(&meta_buffer)
}

/// Encodes the stream's status after each read.
Expand Down
105 changes: 50 additions & 55 deletions src/io/ipc/read/stream_async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,23 @@ use super::schema::deserialize_stream_metadata;
use super::Dictionaries;
use super::StreamMetadata;

/// A (private) state of stream messages
struct ReadState<R> {
pub reader: R,
pub metadata: StreamMetadata,
pub dictionaries: Dictionaries,
/// The internal buffer to read data inside the messages (records and dictionaries) to
pub data_buffer: Vec<u8>,
/// The internal buffer to read messages to
pub message_buffer: Vec<u8>,
}

/// The state of an Arrow stream
enum StreamState<R> {
/// The stream does not contain new chunks (and it has not been closed)
Waiting((R, StreamMetadata, Dictionaries)),
Waiting(ReadState<R>),
/// The stream contain a new chunk
Some((R, StreamMetadata, Dictionaries, Chunk<Arc<dyn Array>>)),
Some((ReadState<R>, Chunk<Arc<dyn Array>>)),
}

/// Reads the [`StreamMetadata`] of the Arrow stream asynchronously
Expand All @@ -49,24 +60,20 @@ pub async fn read_stream_metadata_async<R: AsyncRead + Unpin + Send>(

/// Reads the next item, yielding `None` if the stream has been closed,
/// or a [`StreamState`] otherwise.
async fn _read_next<R: AsyncRead + Unpin + Send>(
mut reader: R,
metadata: StreamMetadata,
mut dictionaries: Dictionaries,
message_buffer: &mut Vec<u8>,
data_buffer: &mut Vec<u8>,
async fn maybe_next<R: AsyncRead + Unpin + Send>(
mut state: ReadState<R>,
) -> Result<Option<StreamState<R>>> {
// determine metadata length
let mut meta_length: [u8; 4] = [0; 4];

match reader.read_exact(&mut meta_length).await {
match state.reader.read_exact(&mut meta_length).await {
Ok(()) => (),
Err(e) => {
return if e.kind() == std::io::ErrorKind::UnexpectedEof {
// Handle EOF without the "0xFFFFFFFF 0x00000000"
// valid according to:
// https://arrow.apache.org/docs/format/Columnar.html#ipc-streaming-format
Ok(Some(StreamState::Waiting((reader, metadata, dictionaries))))
Ok(Some(StreamState::Waiting(state)))
} else {
Err(ArrowError::from(e))
};
Expand All @@ -77,7 +84,7 @@ async fn _read_next<R: AsyncRead + Unpin + Send>(
// If a continuation marker is encountered, skip over it and read
// the size from the next four bytes.
if meta_length == CONTINUATION_MARKER {
reader.read_exact(&mut meta_length).await?;
state.reader.read_exact(&mut meta_length).await?;
}
i32::from_le_bytes(meta_length) as usize
};
Expand All @@ -87,13 +94,14 @@ async fn _read_next<R: AsyncRead + Unpin + Send>(
return Ok(None);
}

message_buffer.clear();
message_buffer.resize(meta_length, 0);
reader.read_exact(message_buffer).await?;
state.message_buffer.clear();
state.message_buffer.resize(meta_length, 0);
state.reader.read_exact(&mut state.message_buffer).await?;

let message = arrow_format::ipc::MessageRef::read_as_root(message_buffer).map_err(|err| {
ArrowError::OutOfSpec(format!("Unable to get root as message: {:?}", err))
})?;
let message =
arrow_format::ipc::MessageRef::read_as_root(&state.message_buffer).map_err(|err| {
ArrowError::OutOfSpec(format!("Unable to get root as message: {:?}", err))
})?;
let header = message.header()?.ok_or_else(|| {
ArrowError::oos("IPC: unable to fetch the message header. The file or stream is corrupted.")
})?;
Expand All @@ -102,40 +110,40 @@ async fn _read_next<R: AsyncRead + Unpin + Send>(
arrow_format::ipc::MessageHeaderRef::Schema(_) => Err(ArrowError::oos("A stream ")),
arrow_format::ipc::MessageHeaderRef::RecordBatch(batch) => {
// read the block that makes up the record batch into a buffer
data_buffer.clear();
data_buffer.resize(message.body_length()? as usize, 0);
reader.read_exact(data_buffer).await?;
state.data_buffer.clear();
state.data_buffer.resize(message.body_length()? as usize, 0);
state.reader.read_exact(&mut state.data_buffer).await?;

read_record_batch(
batch,
&metadata.schema.fields,
&metadata.ipc_schema,
&state.metadata.schema.fields,
&state.metadata.ipc_schema,
None,
&dictionaries,
metadata.version,
&mut std::io::Cursor::new(data_buffer),
&state.dictionaries,
state.metadata.version,
&mut std::io::Cursor::new(&state.data_buffer),
0,
)
.map(|x| Some(StreamState::Some((reader, metadata, dictionaries, x))))
.map(|chunk| Some(StreamState::Some((state, chunk))))
}
arrow_format::ipc::MessageHeaderRef::DictionaryBatch(batch) => {
// read the block that makes up the dictionary batch into a buffer
let mut buf = vec![0; message.body_length()? as usize];
reader.read_exact(&mut buf).await?;
state.reader.read_exact(&mut buf).await?;

let mut dict_reader = std::io::Cursor::new(buf);

read_dictionary(
batch,
&metadata.schema.fields,
&metadata.ipc_schema,
&mut dictionaries,
&state.metadata.schema.fields,
&state.metadata.ipc_schema,
&mut state.dictionaries,
&mut dict_reader,
0,
)?;

// read the next message until we encounter a Chunk<Arc<dyn Array>> message
Ok(Some(StreamState::Waiting((reader, metadata, dictionaries))))
Ok(Some(StreamState::Waiting(state)))
}
t => Err(ArrowError::OutOfSpec(format!(
"Reading types other than record batches not yet supported, unable to read {:?} ",
Expand All @@ -144,22 +152,7 @@ async fn _read_next<R: AsyncRead + Unpin + Send>(
}
}

/// Reads the next item, yielding `None` if the stream is done,
/// and a [`StreamState`] otherwise.
async fn maybe_next<R: AsyncRead + Unpin + Send>(
reader: R,
metadata: StreamMetadata,
dictionaries: Dictionaries,
) -> Result<Option<StreamState<R>>> {
_read_next(reader, metadata, dictionaries, &mut vec![], &mut vec![]).await
}

/// Arrow Stream reader.
///
/// A [`Stream`] over an Arrow stream that yields a result of [`StreamState`]s.
/// This is the recommended way to read an arrow stream (by iterating over its data).
///
/// For a more thorough walkthrough consult [this example](/~https://github.com/jorgecarleitao/arrow2/tree/main/examples/ipc_pyarrow).
/// A [`Stream`] over an Arrow IPC stream that asynchronously yields [`Chunk`]s.
pub struct AsyncStreamReader<R: AsyncRead + Unpin + Send + 'static> {
metadata: StreamMetadata,
future: Option<BoxFuture<'static, Result<Option<StreamState<R>>>>>,
Expand All @@ -168,7 +161,14 @@ pub struct AsyncStreamReader<R: AsyncRead + Unpin + Send + 'static> {
impl<R: AsyncRead + Unpin + Send + 'static> AsyncStreamReader<R> {
/// Creates a new [`AsyncStreamReader`]
pub fn new(reader: R, metadata: StreamMetadata) -> Self {
let future = Some(Box::pin(maybe_next(reader, metadata.clone(), Default::default())) as _);
let state = ReadState {
reader,
metadata: metadata.clone(),
dictionaries: Default::default(),
data_buffer: Default::default(),
message_buffer: Default::default(),
};
let future = Some(Box::pin(maybe_next(state)) as _);
Self { metadata, future }
}

Expand All @@ -195,13 +195,8 @@ impl<R: AsyncRead + Unpin + Send> Stream for AsyncStreamReader<R> {
me.future = None;
Poll::Ready(None)
}
Poll::Ready(Ok(Some(StreamState::Some((
reader,
metadata,
dictionaries,
batch,
))))) => {
me.future = Some(Box::pin(maybe_next(reader, metadata, dictionaries)));
Poll::Ready(Ok(Some(StreamState::Some((state, batch))))) => {
me.future = Some(Box::pin(maybe_next(state)));
Poll::Ready(Some(Ok(batch)))
}
Poll::Ready(Ok(Some(StreamState::Waiting(_)))) => Poll::Pending,
Expand Down
2 changes: 1 addition & 1 deletion tests/it/io/ipc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ pub use common::read_gzip_json;
#[cfg(feature = "io_ipc_write_async")]
mod write_async;

//#[cfg(feature = "io_ipc_read_async")]
#[cfg(feature = "io_ipc_read_async")]
mod read_stream_async;

0 comments on commit f9696ff

Please sign in to comment.