Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add question answering to inference client #1609

Merged
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/en/guides/inference.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ has a simple API that supports the most common tasks. Here is a list of the curr
| NLP | [Conversational](https://huggingface.co/tasks/conversational) | ✅ | [`~InferenceClient.conversational`] |
| | [Feature Extraction](https://huggingface.co/tasks/feature-extraction) | ✅ | [`~InferenceClient.feature_extraction`] |
| | [Fill Mask](https://huggingface.co/tasks/fill-mask) | | |
| | [Question Answering](https://huggingface.co/tasks/question-answering) | | |
| | [Question Answering](https://huggingface.co/tasks/question-answering) | ✅ | [`~InferenceClient.question_answering`] |
Wauplin marked this conversation as resolved.
Show resolved Hide resolved
| | [Sentence Similarity](https://huggingface.co/tasks/sentence-similarity) | ✅ | [`~InferenceClient.sentence_similarity`] |
| | [Summarization](https://huggingface.co/tasks/summarization) | ✅ | [`~InferenceClient.summarization`] |
| | [Table Question Answering](https://huggingface.co/tasks/table-question-answering) | | |
Expand Down
43 changes: 43 additions & 0 deletions src/huggingface_hub/inference/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
ConversationalOutput,
ImageSegmentationOutput,
ObjectDetectionOutput,
QuestionAnsweringOutput,
TokenClassificationOutput,
)
from huggingface_hub.utils import (
Expand Down Expand Up @@ -676,6 +677,48 @@ def object_detection(
raise ValueError(f"Server output must be a list. Got {type(output)}: {str(output)[:200]}...")
return output

def question_answering(
    self, question: str, context: str, *, model: Optional[str] = None
) -> QuestionAnsweringOutput:
    """
    Retrieve the answer to a question from a given text.

    Args:
        question (`str`):
            Question to be answered.
        context (`str`):
            The context of the question.
        model (`str`, *optional*):
            The model to use for the question answering task. Can be a model ID hosted on the Hugging Face
            Hub or a URL to a deployed Inference Endpoint. Defaults to None.

    Returns:
        [`QuestionAnsweringOutput`]: a dictionary containing the score, start index, end index, and answer.

    Raises:
        [`InferenceTimeoutError`]:
            If the model is unavailable or the request times out.
        `HTTPError`:
            If the request fails with an HTTP error status code other than HTTP 503.

    Example:
    ```py
    >>> from huggingface_hub import InferenceClient
    >>> client = InferenceClient()
    >>> client.question_answering(question="What's my name?", context="My name is Clara and I live in Berkeley.")
    {'score': 0.9326562285423279, 'start': 11, 'end': 16, 'answer': 'Clara'}
    ```
    """
    payload: Dict[str, Any] = {"question": question, "context": context}
    # The Inference API serves extractive QA: the answer is a span of `context`.
    response = self.post(
        json=payload,
        model=model,
        task="question-answering",
    )
    # Server returns a single JSON object matching QuestionAnsweringOutput.
    return _bytes_to_dict(response)  # type: ignore

def sentence_similarity(
self, sentence: str, other_sentences: List[str], *, model: Optional[str] = None
) -> List[float]:
Expand Down
44 changes: 44 additions & 0 deletions src/huggingface_hub/inference/_generated/_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
ConversationalOutput,
ImageSegmentationOutput,
ObjectDetectionOutput,
QuestionAnsweringOutput,
TokenClassificationOutput,
)
from huggingface_hub.utils import (
Expand Down Expand Up @@ -681,6 +682,49 @@ async def object_detection(
raise ValueError(f"Server output must be a list. Got {type(output)}: {str(output)[:200]}...")
return output

async def question_answering(
    self, question: str, context: str, *, model: Optional[str] = None
) -> QuestionAnsweringOutput:
    """
    Retrieve the answer to a question from a given text.

    Args:
        question (`str`):
            Question to be answered.
        context (`str`):
            The context of the question.
        model (`str`, *optional*):
            The model to use for the question answering task. Can be a model ID hosted on the Hugging Face
            Hub or a URL to a deployed Inference Endpoint. Defaults to None.

    Returns:
        [`QuestionAnsweringOutput`]: a dictionary containing the score, start index, end index, and answer.

    Raises:
        [`InferenceTimeoutError`]:
            If the model is unavailable or the request times out.
        `aiohttp.ClientResponseError`:
            If the request fails with an HTTP error status code other than HTTP 503.

    Example:
    ```py
    # Must be run in an async context
    >>> from huggingface_hub import AsyncInferenceClient
    >>> client = AsyncInferenceClient()
    >>> await client.question_answering(question="What's my name?", context="My name is Clara and I live in Berkeley.")
    {'score': 0.9326562285423279, 'start': 11, 'end': 16, 'answer': 'Clara'}
    ```
    """
    payload: Dict[str, Any] = {"question": question, "context": context}
    # The Inference API serves extractive QA: the answer is a span of `context`.
    response = await self.post(
        json=payload,
        model=model,
        task="question-answering",
    )
    # Server returns a single JSON object matching QuestionAnsweringOutput.
    return _bytes_to_dict(response)  # type: ignore

async def sentence_similarity(
self, sentence: str, other_sentences: List[str], *, model: Optional[str] = None
) -> List[float]:
Expand Down
20 changes: 20 additions & 0 deletions src/huggingface_hub/inference/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,26 @@ class ObjectDetectionOutput(TypedDict):
score: float


class QuestionAnsweringOutput(TypedDict):
    """Output dictionary returned by a [`~InferenceClient.question_answering`] task.

    Args:
        score (`float`):
            Confidence that the extracted answer is correct.
        start (`int`):
            Character index in the context at which the answer starts.
        end (`int`):
            Character index in the context at which the answer ends.
        answer (`str`):
            The answer span extracted from the context.
    """

    score: float
    start: int
    end: int
    answer: str


class TokenClassificationOutput(TypedDict):
"""Dictionary containing the output of a [`~InferenceClient.token_classification`] task.

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
interactions:
- request:
body: '{"question": "What is the meaning of life?", "context": "42"}'
headers:
Accept:
- '*/*'
Accept-Encoding:
- gzip, deflate, br
Connection:
- keep-alive
Content-Length:
- '61'
Content-Type:
- application/json
X-Amzn-Trace-Id:
- e3071dca-bb69-47b9-b9c7-f0ce58a69927
user-agent:
- unknown/None; hf_hub/0.17.0.dev0; python/3.10.12
method: POST
uri: https://api-inference.huggingface.co/models/deepset/roberta-base-squad2
response:
body:
string: '{"score":1.4291124728060822e-08,"start":0,"end":2,"answer":"42"}'
headers:
Connection:
- keep-alive
Content-Length:
- '64'
Content-Type:
- application/json
Date:
- Sun, 20 Aug 2023 18:17:17 GMT
access-control-allow-credentials:
- 'true'
vary:
- Origin, Access-Control-Request-Method, Access-Control-Request-Headers
x-compute-time:
- '0.094'
x-compute-type:
- cache
x-request-id:
- vY1d3zhYMs71Bmhh1OI5N
x-sha:
- e09df911dd96d8b052d2665dfbb309e9398a9d70
status:
code: 200
message: OK
version: 1
10 changes: 10 additions & 0 deletions tests/test_inference_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,16 @@ def test_object_detection(self) -> None:
self.assertIn("xmax", item["box"])
self.assertIn("ymax", item["box"])

def test_question_answering(self) -> None:
    # Extractive QA: the model should pick an answer span out of the context.
    output = self.client.question_answering(
        question="What is the meaning of life?", context="42", model="deepset/roberta-base-squad2"
    )
    self.assertIsInstance(output, dict)
    self.assertGreater(len(output), 0)
    # Each field of the QuestionAnsweringOutput dict must have the documented type.
    expected_types = {"score": float, "start": int, "end": int, "answer": str}
    for key, expected_type in expected_types.items():
        self.assertIsInstance(output[key], expected_type)

def test_sentence_similarity(self) -> None:
scores = self.client.sentence_similarity(
"Machine learning is so easy.",
Expand Down