Skip to content

Commit

Permalink
Add score and other metadatas at /v1/retrieve endpoint (#1055)
Browse files Browse the repository at this point in the history
* Edit /v1/retrieve endpoint return formats (includes score, filepath and file page)

* Edit the endpoint documentation

* resolve test error
  • Loading branch information
vkehfdl1 authored Dec 15, 2024
1 parent 94f51d7 commit 05f4e72
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 41 deletions.
34 changes: 14 additions & 20 deletions autorag/deploy/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class QueryRequest(BaseModel):
class RetrievedPassage(BaseModel):
content: str
doc_id: str
score: float
filepath: Optional[str] = None
file_page: Optional[int] = None
start_idx: Optional[int] = None
Expand All @@ -41,14 +42,8 @@ class RunResponse(BaseModel):
retrieved_passage: List[RetrievedPassage]


class Passage(BaseModel):
doc_id: str
content: str
score: float


class RetrievalResponse(BaseModel):
passages: List[Passage]
passages: List[RetrievedPassage]


class StreamResponse(BaseModel):
Expand Down Expand Up @@ -153,18 +148,9 @@ async def run_retrieve_only():
previous_result = pd.concat([drop_previous_result, new_result], axis=1)

# Simulate processing the query
retrieved_contents = previous_result["retrieved_contents"].tolist()[0]
retrieved_ids = previous_result["retrieved_ids"].tolist()[0]
retrieve_scores = previous_result["retrieve_scores"].tolist()[0]

retrieval_response = RetrievalResponse(
passages=[
Passage(doc_id=doc_id, content=content, score=score)
for doc_id, content, score in zip(
retrieved_ids, retrieved_contents, retrieve_scores
)
]
)
retrieved_passages = self.extract_retrieve_passage(previous_result)

retrieval_response = RetrievalResponse(passages=retrieved_passages)
return jsonify(retrieval_response.model_dump()), 200

@self.app.route("/v1/stream", methods=["POST"])
Expand Down Expand Up @@ -264,6 +250,7 @@ def run_api_server(
def extract_retrieve_passage(self, df: pd.DataFrame) -> List[RetrievedPassage]:
retrieved_ids: List[str] = df["retrieved_ids"].tolist()[0]
contents = fetch_contents(self.corpus_df, [retrieved_ids])[0]
scores = df["retrieve_scores"].tolist()[0]
if "path" in self.corpus_df.columns:
paths = fetch_contents(self.corpus_df, [retrieved_ids], column_name="path")[
0
Expand All @@ -282,16 +269,23 @@ def extract_retrieve_passage(self, df: pd.DataFrame) -> List[RetrievedPassage]:
start_end_indices = to_list(start_end_indices)
return list(
map(
lambda content, doc_id, path, metadata, start_end_idx: RetrievedPassage(
lambda content,
doc_id,
score,
path,
metadata,
start_end_idx: RetrievedPassage(
content=content,
doc_id=doc_id,
score=score,
filepath=path,
file_page=metadata.get("page", None),
start_idx=start_end_idx[0] if start_end_idx else None,
end_idx=start_end_idx[1] if start_end_idx else None,
),
contents,
retrieved_ids,
scores,
paths,
metadatas,
start_end_indices,
Expand Down
46 changes: 25 additions & 21 deletions docs/source/deploy/api_endpoint.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ This is the URL to your local server, so use it as the host at request.
- **Properties**:
- `content` (string): The content of the passage.
- `doc_id` (string): Document ID.
- `score` (string): The relevance score of the retrieved passage.
- `filepath` (string, nullable): File path.
- `file_page` (integer, nullable): File page number.
- `start_idx` (integer, nullable): Start index.
Expand Down Expand Up @@ -105,31 +106,17 @@ The request must include a JSON object with the following structure:
#### Success Response
**HTTP Status Code:** `200 OK`

#### Response Body
On a successful retrieval, the response will contain a JSON object structured as follows:

```json
{
"passages": [
{
"doc_id": "unique-document-id-1",
"content": "Content of the retrieved document.",
"score": 0.95
},
{
"doc_id": "unique-document-id-2",
"content": "Content of another retrieved document.",
"score": 0.89
}
]
}
```
### Response Body

#### Properties
- **passages** (array): An array of documents retrieved based on the query.
- **doc_id** (string): The unique identifier for each document.
- **content** (string): The content of the retrieved document.
- **score** (number, float): The relevance score of the retrieved document.
- **filepath** (string, optional): The file path of the document.
- **file_page** (integer, optional): The page number of the document.
- **start_idx** (integer, optional): The start index of the retrieved passage from the parsed data.
- **end_idx** (integer, optional): The end index of the retrieved passage from the parsed data.

#### Example Response
```json
Expand All @@ -138,12 +125,20 @@ On a successful retrieval, the response will contain a JSON object structured as
{
"doc_id": "doc123",
"content": "Artificial Intelligence is transforming industries.",
"score": 0.98
"score": 0.98,
"filepath": "path/to/file",
"file_page": 2,
"start_idx": 100,
"end_idx": 150
},
{
"doc_id": "doc456",
"content": "The future of AI includes advancements in machine learning.",
"score": 0.92
"score": 0.92,
"filepath": null,
"file_page": null,
"start_idx": null,
"end_idx": null
}
]
}
Expand Down Expand Up @@ -172,6 +167,7 @@ On a successful retrieval, the response will contain a JSON object structured as
- **Properties**:
- `content` (string): The content of the passage.
- `doc_id` (string): Document ID.
- `score` (string): The relevance score of the retrieved passage.
- `filepath` (string, nullable): File path.
- `file_page` (integer, nullable): File page number.
- `start_idx` (integer, nullable): Start index.
Expand Down Expand Up @@ -295,6 +291,14 @@ curl -X POST "http://example.com/v1/stream" \
--no-buffer
```

### `/v1/retrieve` (POST)

```bash
curl -X POST "http://example.com/v1/retrieve" \
-H "Content-Type: application/json" \
-d '{"query": "example query"}'
```

#### `/version` (GET)

```bash
Expand Down
4 changes: 4 additions & 0 deletions tests/autorag/test_deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,10 @@ async def post_to_server_retrieve():
assert "doc_id" in passages[0]
assert "content" in passages[0]
assert "score" in passages[0]
assert "filepath" in passages[0]
assert "file_page" in passages[0]
assert "start_idx" in passages[0]
assert "end_idx" in passages[0]
assert isinstance(passages[0]["doc_id"], str)
assert isinstance(passages[0]["content"], str)
assert isinstance(passages[0]["score"], float)
Expand Down

0 comments on commit 05f4e72

Please sign in to comment.