Skip to content

Commit

Permalink
Refactoring Deckretriever (#7671)
Browse files Browse the repository at this point in the history
Co-authored-by: Kamil Piechowiak <32928185+KamilPiechowiak@users.noreply.github.com>
GitOrigin-RevId: 0212b7d459cf1767ed68635aa42fff179c435fd4
  • Loading branch information
2 people authored and Manul from Pathway committed Nov 28, 2024
1 parent da48f52 commit d213e33
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 13 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm
## [Unreleased]

### Added
- `pw.xpacks.llm.document_store.SlidesDocumentStore`, which is a subclass of `pw.xpacks.llm.document_store.DocumentStore` customized for retrieving slides from presentations.
- `pw.temporal.inactivity_detection` and `pw.temporal.utc_now` functions allowing for alerting and other time dependent usecases

### Changed
- `pw.Table.concat`, `pw.Table.with_id`, `pw.Table.with_id_from` no longer perform checks if ids are unique. It improves memory usage.
- table operations that store values (like `pw.Table.join`, `pw.Table.update_cells`) no longer store columns that are not used downstream.
- `append_only` column property is now propagated better (there are more places where we can infer it).
- **BREAKING**: Unused arguments from the constructor `pw.xpacks.llm.question_answering.DeckRetriever` are no longer accepted.

### Fixed
- `query_as_of_now` of `pw.stdlib.indexing.DataIndex` and `pw.stdlib.indexing.HybridIndex` now work in constant memory for infinite query stream (no query-related data is kept after query is answered).
Expand Down
67 changes: 64 additions & 3 deletions python/pathway/xpacks/llm/document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""

from collections.abc import Callable
from typing import TYPE_CHECKING, Iterable
from typing import TYPE_CHECKING, Iterable, TypeAlias

import jmespath

Expand Down Expand Up @@ -185,7 +185,7 @@ class FilterSchema(pw.Schema):
default_value=None, description="An optional Glob pattern for the file path"
)

InputsQuerySchema = FilterSchema
InputsQuerySchema: TypeAlias = FilterSchema

class InputsResultSchema(pw.Schema):
result: list[pw.Json]
Expand Down Expand Up @@ -365,7 +365,7 @@ def _get_jmespath_filter(

@pw.table_transformer
def inputs_query(
self, input_queries: pw.Table[InputsQuerySchema] # type: ignore
self, input_queries: pw.Table[InputsQuerySchema]
) -> pw.Table[InputsResultSchema]:
"""
Query ``DocumentStore`` for the list of input documents.
Expand Down Expand Up @@ -448,3 +448,64 @@ def retrieve_query(
@property
def index(self) -> DataIndex:
return self._retriever


class SlidesDocumentStore(DocumentStore):
"""
Document store for the ``slide-search`` application.
Builds a document indexing pipeline and starts an HTTP REST server.
Adds to the ``DocumentStore`` a new method ``parsed_documents`` a set of
documents metadata after the parsing and document post processing stages.
"""

excluded_response_metadata = ["b64_image"]

@pw.table_transformer
def parsed_documents_query(
self,
parse_docs_queries: pw.Table[DocumentStore.InputsQuerySchema],
) -> pw.Table:
"""
Query the SlidesDocumentStore for the list of documents with the associated
metadata after the parsing stage.
"""
docs = self.parsed_docs

all_metas = docs.reduce(metadatas=pw.reducers.tuple(pw.this.metadata))

parse_docs_queries = self.merge_filters(parse_docs_queries)

@pw.udf
def format_inputs(
metadatas: list[pw.Json] | None,
metadata_filter: str | None,
) -> list[pw.Json]:
metadatas = metadatas if metadatas is not None else []
if metadata_filter:
metadatas = [
m
for m in metadatas
if jmespath.search(
metadata_filter, m.value, options=_knn_lsh._glob_options
)
]

metadata_list: list[dict] = [m.as_dict() for m in metadatas]

for metadata in metadata_list:
for metadata_key in self.excluded_response_metadata:
metadata.pop(metadata_key, None)

return [pw.Json(m) for m in metadata_list]

input_results = parse_docs_queries.join_left(
all_metas, id=parse_docs_queries.id
).select(
all_metas.metadatas,
parse_docs_queries.metadata_filter,
)
input_results = input_results.select(
result=format_inputs(pw.this.metadatas, pw.this.metadata_filter)
)
return input_results
96 changes: 90 additions & 6 deletions python/pathway/xpacks/llm/question_answering.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,17 @@
from pathway.internals import ColumnReference, Table, udfs
from pathway.stdlib.indexing import DataIndex
from pathway.xpacks.llm import Doc, llms, prompts
from pathway.xpacks.llm.document_store import DocumentStore
from pathway.xpacks.llm.document_store import DocumentStore, SlidesDocumentStore
from pathway.xpacks.llm.llms import BaseChat, prompt_chat_single_qa
from pathway.xpacks.llm.prompts import prompt_qa_geometric_rag
from pathway.xpacks.llm.vector_store import VectorStoreClient, VectorStoreServer
from pathway.xpacks.llm.vector_store import (
SlidesVectorStoreServer,
VectorStoreClient,
VectorStoreServer,
)

if TYPE_CHECKING:
from pathway.xpacks.llm.servers import QASummaryRestServer
from pathway.xpacks.llm.servers import QARestServer, QASummaryRestServer


@pw.udf
Expand Down Expand Up @@ -455,14 +459,23 @@ def summarize_query(self, summarize_queries: pw.Table) -> pw.Table:

@pw.table_transformer
def retrieve(self, retrieve_queries: pw.Table) -> pw.Table:
"""
Retrieve documents from the index.
"""
return self.indexer.retrieve_query(retrieve_queries)

@pw.table_transformer
def statistics(self, statistics_queries: pw.Table) -> pw.Table:
"""
Get statistics about indexed files.
"""
return self.indexer.statistics_query(statistics_queries)

@pw.table_transformer
def list_documents(self, list_documents_queries: pw.Table) -> pw.Table:
"""
Get list of documents from the retriever.
"""
return self.indexer.inputs_query(list_documents_queries)

def build_server(
Expand Down Expand Up @@ -682,14 +695,45 @@ def answer_query(self, pw_ai_queries: pw.Table) -> pw.Table:
return result


class DeckRetriever(BaseRAGQuestionAnswerer):
"""Class for slides search."""
class DeckRetriever(BaseQuestionAnswerer):
"""
Builds the logic for the Retriever of slides.
Args:
indexer: document store for parsing and indexing slides.
search_topk: Number of slides to be returned by the `answer_query` method.
"""

excluded_response_metadata = ["b64_image"]

def __init__(
self,
indexer: SlidesDocumentStore | SlidesVectorStoreServer,
*,
search_topk: int = 6,
) -> None:
self.indexer = indexer
self._init_schemas()
self.search_topk = search_topk

self.server: None | QARestServer = None
self._pending_endpoints: list[tuple] = []

def _init_schemas(
self,
) -> None:
class PWAIQuerySchema(pw.Schema):
prompt: str
filters: str | None = pw.column_definition(default_value=None)

self.AnswerQuerySchema = PWAIQuerySchema
self.RetrieveQuerySchema = self.indexer.RetrieveQuerySchema
self.StatisticsQuerySchema = self.indexer.StatisticsQuerySchema
self.InputsQuerySchema = self.indexer.InputsQuerySchema

@pw.table_transformer
def answer_query(self, pw_ai_queries: pw.Table) -> pw.Table:
"""Return similar docs from the index."""
"""Return slides similar to the given query."""

pw_ai_results = pw_ai_queries + self.indexer.retrieve_query(
pw_ai_queries.select(
Expand Down Expand Up @@ -720,6 +764,46 @@ def _format_results(docs: pw.Json) -> pw.Json:

return pw_ai_results

@pw.table_transformer
def retrieve(self, retrieve_queries: pw.Table) -> pw.Table:
return self.indexer.retrieve_query(retrieve_queries)

@pw.table_transformer
def statistics(self, statistics_queries: pw.Table) -> pw.Table:
return self.indexer.statistics_query(statistics_queries)

@pw.table_transformer
def list_documents(self, list_documents_queries: pw.Table) -> pw.Table:
return self.indexer.parsed_documents_query(list_documents_queries)

def build_server(
self,
host: str,
port: int,
**rest_kwargs,
):
warn(
"build_server method is deprecated. Instead, use explicitly a server from pw.xpacks.llm.servers.",
DeprecationWarning,
stacklevel=2,
)
# circular import
from pathway.xpacks.llm.servers import QARestServer

self.server = QARestServer(host, port, self, **rest_kwargs)

def run_server(self, *args, **kwargs):
warn(
"run_server method is deprecated. Instead, use explicitly a server from pw.xpacks.llm.servers.",
DeprecationWarning,
stacklevel=2,
)
if self.server is None:
raise ValueError(
"HTTP server is not built, initialize it with `build_server`"
)
self.server.run(*args, **kwargs)


def send_post_request(
url: str, data: dict, headers: dict = {}, timeout: int | None = None
Expand Down
15 changes: 11 additions & 4 deletions python/pathway/xpacks/llm/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import logging
import threading
from collections.abc import Callable, Coroutine
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, TypeAlias, cast

import jmespath
import requests
Expand Down Expand Up @@ -330,7 +330,7 @@ class FilterSchema(pw.Schema):
default_value=None, description="An optional Glob pattern for the file path"
)

InputsQuerySchema = FilterSchema
InputsQuerySchema: TypeAlias = FilterSchema

@staticmethod
def merge_filters(queries: pw.Table):
Expand Down Expand Up @@ -363,7 +363,7 @@ def _get_jmespath_filter(

@pw.table_transformer
def inputs_query(
self, input_queries: pw.Table[InputsQuerySchema] # type:ignore
self, input_queries: pw.Table[InputsQuerySchema]
) -> pw.Table[InputResultSchema]:
docs = self._graph["docs"]
# TODO: compare this approach to first joining queries to dicuments, then filtering,
Expand Down Expand Up @@ -576,7 +576,7 @@ class SlidesVectorStoreServer(VectorStoreServer):
@pw.table_transformer
def inputs_query(
self,
input_queries: pw.Table[VectorStoreServer.InputsQuerySchema], # type:ignore
input_queries: pw.Table[VectorStoreServer.InputsQuerySchema],
) -> pw.Table:
docs = self._graph["parsed_docs"]

Expand Down Expand Up @@ -617,6 +617,13 @@ def format_inputs(
)
return input_results

@pw.table_transformer
def parsed_documents_query(
self,
parse_docs_queries: pw.Table[VectorStoreServer.InputsQuerySchema],
) -> pw.Table:
return self.inputs_query(parse_docs_queries)


class VectorStoreClient:
"""
Expand Down

0 comments on commit d213e33

Please sign in to comment.