feat: add to_polars_lazy (#195)
tshauck authored Jan 18, 2025
1 parent 972bb89 commit 7e95e6e
Showing 2 changed files with 56 additions and 0 deletions.
22 changes: 22 additions & 0 deletions python/tests/test_session.py
@@ -55,6 +55,25 @@ def test_connect_and_to_arrow():
    assert len(arrow_table) == 2


@pytest.mark.skipif(
    not importlib.util.find_spec("polars"), reason="polars not installed"
)
def test_to_polars_lazy():
    session = connect()
    fastq_path = DATA / "test.fastq.gz"

    df = session.sql(
        f"""
        SELECT quality_scores_to_list(quality_scores) quality_score_list,
            locate_regex(sequence, '[AC]AT') locate
        FROM fastq_scan('{fastq_path}')
        """
    ).to_polars_lazy()

    assert isinstance(df, pl.LazyFrame)
    assert len(df.collect()) == 2


@pytest.mark.skipif(
    not importlib.util.find_spec("polars"), reason="polars not installed"
)
@@ -148,6 +167,7 @@ def test_read_fastq():

    assert len(df) == 2


@pytest.mark.skipif(
    not importlib.util.find_spec("polars"), reason="polars not installed"
)
@@ -293,6 +313,7 @@ def test_read_fasta_gz():

    assert len(df) == 2


def test_read_fasta_bz2():
    """Test reading a fasta.bz2 file."""
    session = connect()
@@ -306,6 +327,7 @@ def test_read_fasta_bz2():

    assert len(df) == 2


@pytest.mark.skipif(
    not importlib.util.find_spec("polars"), reason="polars not installed"
)
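For orientation, here is a minimal usage sketch of the new method outside the test suite. It is hedged: it assumes `connect` is exposed at the top of the `biobear` package (as it is in the test above), and the FASTQ path and the `sequence` filter are illustrative rather than part of this commit.

```python
import biobear as bb
import polars as pl

session = bb.connect()

# to_polars_lazy returns a pl.LazyFrame, so nothing is materialized yet.
lazy = session.sql(
    "SELECT sequence, quality_scores FROM fastq_scan('data/test.fastq.gz')"
).to_polars_lazy()

# Further transformations compose lazily and only execute on collect().
result = lazy.filter(pl.col("sequence").str.contains("CAT")).collect()
print(result.shape)
```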
34 changes: 34 additions & 0 deletions src/execution_result.rs
@@ -107,6 +107,40 @@ impl ExecutionResult {
        Ok(table)
    }

    /// Convert to a Polars LazyFrame
    fn to_polars_lazy(&self, py: Python) -> PyResult<PyObject> {
        let stream = wait_for_future(py, self.df.as_ref().clone().execute_stream())
            .map_err(error::BioBearError::from)?;

        let schema = stream.schema().to_pyarrow(py)?;

        let runtime = Arc::new(Runtime::new()?);

        let dataframe_record_batch_stream = DataFrameRecordBatchStream::new(stream, runtime);

        let mut stream = FFI_ArrowArrayStream::new(Box::new(dataframe_record_batch_stream));

        let batches =
            unsafe { ArrowArrayStreamReader::from_raw(&mut stream).map_err(BioBearError::from) }?;

        let batches = batches.into_pyarrow(py)?;

        let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
        let args = (batches, schema);

        let table: PyObject = table_class.call_method1("from_batches", args)?.into();

        let dataset_class = py.import_bound("pyarrow.dataset")?;

        let dataset: PyObject = dataset_class.call_method1("dataset", (table,))?.into();

        let module = py.import_bound("polars")?;
        let args = (dataset,);
        let result = module.call_method1("scan_pyarrow_dataset", args)?.into();

        Ok(result)
    }

    /// Convert to a Polars DataFrame
    fn to_polars(&self, py: Python) -> PyResult<PyObject> {
        let stream = wait_for_future(py, self.df.as_ref().clone().execute_stream())
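The Rust method above drives a short chain of Python calls: the DataFusion record batch stream is exported over the Arrow C stream interface, collected into a `pyarrow.Table`, wrapped as a `pyarrow.dataset` dataset, and handed to `polars.scan_pyarrow_dataset`. Below is a rough Python-side sketch of that same chain, using an in-memory record batch as a stand-in for the streamed query results (the column name is made up for illustration).

```python
import polars as pl
import pyarrow as pa
import pyarrow.dataset as ds

# Stand-in for the Arrow record batches streamed out of the query.
batch = pa.RecordBatch.from_pydict({"sequence": ["ACGT", "CATT"]})

# The same chain the Rust method performs through the Python APIs:
table = pa.Table.from_batches([batch], schema=batch.schema)  # pyarrow.Table.from_batches
dataset = ds.dataset(table)                                  # wrap the table as an in-memory dataset
lazy = pl.scan_pyarrow_dataset(dataset)                      # polars LazyFrame over the dataset

print(lazy.collect())
```

Going through a dataset rather than building a polars DataFrame directly keeps the result lazy on the polars side, which is the point of the new method compared to the existing `to_polars`.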
