-
Notifications
You must be signed in to change notification settings - Fork 588
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add text classification to inference client #1606
Add text classification to inference client #1606
Conversation
The documentation is not available anymore as the PR was closed or merged. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey @martinbrose, I'm starting to review all your inference-related PRs. Thanks a ton for the massive work, it's good quality for what I saw! 👏 🙏 I have left some comments on this PR that are also relevant to the other PRs (especially simplifying as much as possible the methods signature + the merge conflict issue). In the meantime, I'll take the time to thoroughly review the other PRs.
FYI, I'm off this Thursday/Friday and be fully back on the project starting from next week :)
Thanks for the review! |
Codecov ReportPatch coverage:
Additional details and impacted files@@ Coverage Diff @@
## main #1606 +/- ##
==========================================
- Coverage 82.30% 81.78% -0.53%
==========================================
Files 62 60 -2
Lines 6964 6785 -179
==========================================
- Hits 5732 5549 -183
- Misses 1232 1236 +4
☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for making the changes @martinbrose :) I left some comments but most of them are due to the change between a multi-text input and a single-text input. Once those are addressed, I think we'll be good to merge 🚀
>>> output | ||
{'label': 'POSITIVE', 'score': 0.9998695850372314} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Revert back to list output
>>> output | |
{'label': 'POSITIVE', 'score': 0.9998695850372314} | |
[[{'label': 'POSITIVE', 'score': 0.9998695850372314}, {'label': 'NEGATIVE', 'score': 0.0001304351753788069}]] |
payload: Dict[str, Any] = {"inputs": text} | ||
response = self.post( | ||
json=payload, | ||
model=model, | ||
task="text-classification", | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(nit) having a separate payload
variable is fine as well but when it's tiny like this one (no parameter other than inputs), I prefer to pass it directly to .post
method. No big deal anyway.
payload: Dict[str, Any] = {"inputs": text} | |
response = self.post( | |
json=payload, | |
model=model, | |
task="text-classification", | |
) | |
response = self.post(json={"inputs": text}, model=model, task="text-classification") |
>>> output | ||
{'label': 'POSITIVE', 'score': 0.9998695850372314} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
>>> output | |
{'label': 'POSITIVE', 'score': 0.9998695850372314} | |
[{'label': 'POSITIVE', 'score': 0.9998695850372314}, {'label': 'NEGATIVE', 'score': 0.0001304351753788069}] |
(same as sync version)
payload: Dict[str, Any] = {"inputs": text} | ||
response = await self.post( | ||
json=payload, | ||
model=model, | ||
task="text-classification", | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
payload: Dict[str, Any] = {"inputs": text} | |
response = await self.post( | |
json=payload, | |
model=model, | |
task="text-classification", | |
) | |
response = await self.post(json={"inputs": text}, model=model, task="text-classification") |
(same as sync version)
@@ -0,0 +1,48 @@ | |||
interactions: | |||
- request: | |||
body: '{"inputs": ["I like you", "I love you."]}' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
body: '{"inputs": ["I like you", "I love you."]}' | |
body: '{"inputs": ["I like you"]}' |
should be only 1 sample now
uri: https://api-inference.huggingface.co/models/distilbert-base-uncased-finetuned-sst-2-english | ||
response: | ||
body: | ||
string: '[[{"label":"POSITIVE","score":0.9998695850372314},{"label":"NEGATIVE","score":0.0001304351753788069}],[{"label":"POSITIVE","score":0.9998705387115479},{"label":"NEGATIVE","score":0.00012938841246068478}]]' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
string: '[[{"label":"POSITIVE","score":0.9998695850372314},{"label":"NEGATIVE","score":0.0001304351753788069}],[{"label":"POSITIVE","score":0.9998705387115479},{"label":"NEGATIVE","score":0.00012938841246068478}]]' | |
string: '[[{"label":"POSITIVE","score":0.9998695850372314},{"label":"NEGATIVE","score":0.0001304351753788069}]]' |
... and therefore only 1 response
tests/test_inference_client.py
Outdated
self.assertIsInstance(item[0]["score"], float) | ||
self.assertIsInstance(item[0]["label"], str) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self.assertIsInstance(item[0]["score"], float) | |
self.assertIsInstance(item[0]["label"], str) | |
self.assertIsInstance(item["score"], float) | |
self.assertIsInstance(item["label"], str) |
1 level less
model=model, | ||
task="text-classification", | ||
) | ||
return _bytes_to_list(response) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
return _bytes_to_list(response) | |
return _bytes_to_list(response)[0] |
Since we take as input only a str
(not a List[str]
), we need to output the first item returned (since the server returns a list of list of items).
model=model, | ||
task="text-classification", | ||
) | ||
return _bytes_to_list(response) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
return _bytes_to_list(response) | |
return _bytes_to_list(response)[0] |
(same as sync version)
I've merged the suggested changes (see above) and tried it locally. It works great! :) |
Add text-classification to HuggingFace🤗 Hub
References #1539
This is an ongoing list of model tasks to implement in the Hugging Face Hub inference client. Each task is planned to be its own PR. The task for this is text-classification.
Key Features