Skip to content

Commit

Permalink
Change in line with review points on other PRs
Browse files Browse the repository at this point in the history
  • Loading branch information
martinbrose committed Sep 5, 2023
1 parent ac4e75e commit 56b314c
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 38 deletions.
20 changes: 4 additions & 16 deletions src/huggingface_hub/inference/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,8 +677,8 @@ def object_detection(
return output

def question_answering(
self, question: str, context: str, model: str, *, parameters: Optional[Dict[str, Any]] = None
) -> List[QuestionAnsweringOutput]:
self, question: str, context: str, *, model: Optional[str] = None
) -> QuestionAnsweringOutput:
"""
Retrieve the answer to a question from a given text.
Expand All @@ -690,17 +690,9 @@ def question_answering(
model (`str`):
The model to use for the question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to
a deployed Inference Endpoint.
parameters (`Dict[str, Any]`, *optional*):
Additional parameters for the question answering task. Defaults to None. For more details about the available
parameters, please refer to [this page](https://huggingface.co/docs/api-inference/detailed_parameters#question-answering-task)
Returns:
`Dict`: a dictionary containing:
- answer: A string that’s the answer within the text.
- score: A float that represents how likely that the answer is correct
- start: The index (string wise) of the start of the answer within context.
- stop: The index (string wise) of the stop of the answer within context.
`Dict`: a dictionary of question answering output containing the score, start index, end index, and answer.
Raises:
[`InferenceTimeoutError`]:
Expand All @@ -717,18 +709,14 @@ def question_answering(
{'score': 0.9326562285423279, 'start': 11, 'end': 16, 'answer': 'Clara'}
```
"""
if model is None:
raise ValueError("You must specify a model. Task question-answering has no recommended standard model.")

payload: Dict[str, Any] = {"question": question, "context": context}
if parameters is not None:
payload["parameters"] = parameters
response = self.post(
json=payload,
model=model,
task="question-answering",
)
return _bytes_to_dict(response)
return _bytes_to_dict(response) # type: ignore

def sentence_similarity(
self, sentence: str, other_sentences: List[str], *, model: Optional[str] = None
Expand Down
20 changes: 4 additions & 16 deletions src/huggingface_hub/inference/_generated/_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,8 +682,8 @@ async def object_detection(
return output

async def question_answering(
self, question: str, context: str, model: str, *, parameters: Optional[Dict[str, Any]] = None
) -> List[QuestionAnsweringOutput]:
self, question: str, context: str, *, model: Optional[str] = None
) -> QuestionAnsweringOutput:
"""
Retrieve the answer to a question from a given text.
Expand All @@ -695,17 +695,9 @@ async def question_answering(
model (`str`):
The model to use for the question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to
a deployed Inference Endpoint.
parameters (`Dict[str, Any]`, *optional*):
Additional parameters for the question answering task. Defaults to None. For more details about the available
parameters, please refer to [this page](https://huggingface.co/docs/api-inference/detailed_parameters#question-answering-task)
Returns:
`Dict`: a dictionary containing:
- answer: A string that’s the answer within the text.
- score: A float that represents how likely that the answer is correct
- start: The index (string wise) of the start of the answer within context.
- stop: The index (string wise) of the stop of the answer within context.
`Dict`: a dictionary of question answering output containing the score, start index, end index, and answer.
Raises:
[`InferenceTimeoutError`]:
Expand All @@ -723,18 +715,14 @@ async def question_answering(
{'score': 0.9326562285423279, 'start': 11, 'end': 16, 'answer': 'Clara'}
```
"""
if model is None:
raise ValueError("You must specify a model. Task question-answering has no recommended standard model.")

payload: Dict[str, Any] = {"question": question, "context": context}
if parameters is not None:
payload["parameters"] = parameters
response = await self.post(
json=payload,
model=model,
task="question-answering",
)
return _bytes_to_dict(response)
return _bytes_to_dict(response) # type: ignore

async def sentence_similarity(
self, sentence: str, other_sentences: List[str], *, model: Optional[str] = None
Expand Down
13 changes: 7 additions & 6 deletions src/huggingface_hub/inference/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,14 @@ class QuestionAnsweringOutput(TypedDict):
"""Dictionary containing information about a [`~InferenceClient.question_answering`] task.
Args:
label (`str`):
The label corresponding to the detected object.
box (`dict`):
A dict response of bounding box coordinates of
the detected object: xmin, ymin, xmax, ymax
score (`float`):
The score corresponding to the detected object.
A float that represents how likely that the answer is correct.
start (`int`):
The index (string wise) of the start of the answer within context.
end (`int`):
The index (string wise) of the end of the answer within context.
answer (`str`):
A string that is the answer within the text.
"""

score: float
Expand Down

0 comments on commit 56b314c

Please sign in to comment.