From 10e30909c3d7e0d9f67db8e14f666dfd5968a7da Mon Sep 17 00:00:00 2001 From: Linus Torvalds Date: Mon, 18 Nov 2024 07:51:29 +0200 Subject: [PATCH] feat: impl deez completions 2nd endpoint --- src/chat_compl.rs | 138 ++++++++++++++++++++++++++++++++++++++++ src/completions.rs | 145 +++++++++++++++++++++++++++++-------------- src/lib.rs | 1 + src/traits.rs | 13 +++- tests/chat_compl.rs | 124 ++++++++++++++++++++++++++++++++++++ tests/completions.rs | 120 ++++++++++++----------------------- 6 files changed, 412 insertions(+), 129 deletions(-) create mode 100644 src/chat_compl.rs create mode 100644 tests/chat_compl.rs diff --git a/src/chat_compl.rs b/src/chat_compl.rs new file mode 100644 index 0000000..6e12a71 --- /dev/null +++ b/src/chat_compl.rs @@ -0,0 +1,138 @@ +//! Reference: https://docs.x.ai/api/endpoints#chat-completions + +use crate::error::XaiError; +use crate::traits::ChatCompletionsFetcher; +use crate::traits::ClientConfig; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChatCompletionRequest { + pub model: String, + pub messages: Vec, + pub temperature: Option, + pub max_tokens: Option, + pub frequency_penalty: Option, + pub presence_penalty: Option, + pub n: Option, + pub stop: Option>, + pub stream: Option, + pub logprobs: Option, + pub top_p: Option, + pub top_logprobs: Option, + pub seed: Option, + pub user: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Message { + pub role: String, + pub content: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChatCompletionResponse { + pub id: String, + pub object: String, + pub created: u64, + pub model: String, + pub choices: Vec, + pub usage: Option, + pub system_fingerprint: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Choice { + pub index: u32, + pub message: Message, + pub finish_reason: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Usage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, +} + +#[derive(Debug, Clone)] +pub struct ChatCompletionsRequestBuilder { + client: T, + request: ChatCompletionRequest, +} + +impl ChatCompletionsRequestBuilder +where + T: ClientConfig + Clone + Send + Sync, +{ + pub fn new(client: T, model: String, messages: Vec) -> Self { + Self { + client, + request: ChatCompletionRequest { + model, + messages, + temperature: None, + max_tokens: None, + frequency_penalty: None, + presence_penalty: None, + n: None, + stop: None, + stream: None, + logprobs: None, + top_p: None, + top_logprobs: None, + seed: None, + user: None, + }, + } + } + + pub fn temperature(mut self, temperature: f32) -> Self { + self.request.temperature = Some(temperature); + self + } + + pub fn max_tokens(mut self, max_tokens: u32) -> Self { + self.request.max_tokens = Some(max_tokens); + self + } + + pub fn n(mut self, n: u32) -> Self { + self.request.n = Some(n); + self + } + + pub fn stop(mut self, stop: Vec) -> Self { + self.request.stop = Some(stop); + self + } + + pub fn build(self) -> Result { + Ok(self.request) + } +} + +impl ChatCompletionsFetcher for ChatCompletionsRequestBuilder +where + T: ClientConfig + Clone + Send + Sync, +{ + async fn create_chat_completion( + &self, + request: ChatCompletionRequest, + ) -> Result { + let response = self + .client + .request(reqwest::Method::POST, "chat/completions")? + .json(&request) + .send() + .await?; + + if response.status().is_success() { + let chat_completion = response.json::().await?; + Ok(chat_completion) + } else { + Err(XaiError::Http( + response.error_for_status().unwrap_err().to_string(), + )) + } + } +} diff --git a/src/completions.rs b/src/completions.rs index 6e12a71..6e7658f 100644 --- a/src/completions.rs +++ b/src/completions.rs @@ -1,49 +1,46 @@ -//! Reference: https://docs.x.ai/api/endpoints#chat-completions +//! Reference: https://docs.x.ai/api/endpoints#completions use crate::error::XaiError; -use crate::traits::ChatCompletionsFetcher; -use crate::traits::ClientConfig; +use crate::traits::{ClientConfig, CompletionsFetcher}; +use reqwest::Method; use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ChatCompletionRequest { +pub struct CompletionsRequest { pub model: String, - pub messages: Vec, - pub temperature: Option, - pub max_tokens: Option, + pub prompt: String, + pub best_of: Option, + pub echo: Option, pub frequency_penalty: Option, - pub presence_penalty: Option, + pub logit_bias: Option>, + pub logprobs: Option, + pub max_tokens: Option, pub n: Option, + pub presence_penalty: Option, + pub seed: Option, pub stop: Option>, pub stream: Option, - pub logprobs: Option, + pub suffix: Option, + pub temperature: Option, pub top_p: Option, - pub top_logprobs: Option, - pub seed: Option, pub user: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Message { - pub role: String, - pub content: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ChatCompletionResponse { - pub id: String, - pub object: String, +pub struct CompletionsResponse { + pub choices: Vec, pub created: u64, + pub id: String, pub model: String, - pub choices: Vec, - pub usage: Option, + pub object: String, pub system_fingerprint: Option, + pub usage: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Choice { pub index: u32, - pub message: Message, + pub text: String, pub finish_reason: String, } @@ -55,39 +52,62 @@ pub struct Usage { } #[derive(Debug, Clone)] -pub struct ChatCompletionsRequestBuilder { +pub struct CompletionsRequestBuilder { client: T, - request: ChatCompletionRequest, + request: CompletionsRequest, } -impl ChatCompletionsRequestBuilder +impl CompletionsRequestBuilder where T: ClientConfig + Clone + Send + Sync, { - pub fn new(client: T, model: String, messages: Vec) -> Self { + pub fn new(client: T, model: String, prompt: String) -> Self { Self { client, - request: ChatCompletionRequest { + request: CompletionsRequest { model, - messages, - temperature: None, - max_tokens: None, + prompt, + best_of: None, + echo: None, frequency_penalty: None, - presence_penalty: None, + logit_bias: None, + logprobs: None, + max_tokens: None, n: None, + presence_penalty: None, + seed: None, stop: None, stream: None, - logprobs: None, + suffix: None, + temperature: None, top_p: None, - top_logprobs: None, - seed: None, user: None, }, } } - pub fn temperature(mut self, temperature: f32) -> Self { - self.request.temperature = Some(temperature); + pub fn best_of(mut self, best_of: u32) -> Self { + self.request.best_of = Some(best_of); + self + } + + pub fn echo(mut self, echo: bool) -> Self { + self.request.echo = Some(echo); + self + } + + pub fn frequency_penalty(mut self, frequency_penalty: f32) -> Self { + self.request.frequency_penalty = Some(frequency_penalty); + self + } + + pub fn logit_bias(mut self, logit_bias: std::collections::HashMap) -> Self { + self.request.logit_bias = Some(logit_bias); + self + } + + pub fn logprobs(mut self, logprobs: u32) -> Self { + self.request.logprobs = Some(logprobs); self } @@ -101,34 +121,69 @@ where self } + pub fn presence_penalty(mut self, presence_penalty: f32) -> Self { + self.request.presence_penalty = Some(presence_penalty); + self + } + + pub fn seed(mut self, seed: u32) -> Self { + self.request.seed = Some(seed); + self + } + pub fn stop(mut self, stop: Vec) -> Self { self.request.stop = Some(stop); self } - pub fn build(self) -> Result { + pub fn stream(mut self, stream: bool) -> Self { + self.request.stream = Some(stream); + self + } + + pub fn suffix(mut self, suffix: String) -> Self { + self.request.suffix = Some(suffix); + self + } + + pub fn temperature(mut self, temperature: f32) -> Self { + self.request.temperature = Some(temperature); + self + } + + pub fn top_p(mut self, top_p: f32) -> Self { + self.request.top_p = Some(top_p); + self + } + + pub fn user(mut self, user: String) -> Self { + self.request.user = Some(user); + self + } + + pub fn build(self) -> Result { Ok(self.request) } } -impl ChatCompletionsFetcher for ChatCompletionsRequestBuilder +impl CompletionsFetcher for CompletionsRequestBuilder where T: ClientConfig + Clone + Send + Sync, { - async fn create_chat_completion( + async fn create_completions( &self, - request: ChatCompletionRequest, - ) -> Result { + request: CompletionsRequest, + ) -> Result { let response = self .client - .request(reqwest::Method::POST, "chat/completions")? + .request(Method::POST, "/v1/completions")? .json(&request) .send() .await?; if response.status().is_success() { - let chat_completion = response.json::().await?; - Ok(chat_completion) + let completions = response.json::().await?; + Ok(completions) } else { Err(XaiError::Http( response.error_for_status().unwrap_err().to_string(), diff --git a/src/lib.rs b/src/lib.rs index 3213de0..8e93209 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ pub mod api_key; +pub mod chat_compl; pub mod client; pub mod completions; pub mod error; diff --git a/src/traits.rs b/src/traits.rs index c9de45c..53d5186 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -1,8 +1,10 @@ #![allow(async_fn_in_trait)] use crate::api_key::ApiKeyInfo; -use crate::completions::ChatCompletionRequest; -use crate::completions::ChatCompletionResponse; +use crate::chat_compl::ChatCompletionRequest; +use crate::chat_compl::ChatCompletionResponse; +use crate::completions::CompletionsRequest; +use crate::completions::CompletionsResponse; use crate::error::XaiError; use reqwest::{Method, RequestBuilder}; @@ -22,3 +24,10 @@ pub trait ChatCompletionsFetcher { request: ChatCompletionRequest, ) -> Result; } + +pub trait CompletionsFetcher { + async fn create_completions( + &self, + request: CompletionsRequest, + ) -> Result; +} diff --git a/tests/chat_compl.rs b/tests/chat_compl.rs new file mode 100644 index 0000000..d8cd7b9 --- /dev/null +++ b/tests/chat_compl.rs @@ -0,0 +1,124 @@ +use mockito::{Matcher, Server}; +use reqwest::Method; +use serde_json::json; +use x_ai::client::XaiClient; +use x_ai::traits::ClientConfig; + +#[tokio::test] +async fn test_chat_completions() { + let mut server = Server::new_async().await; + + let chat_completion_mock = server + .mock("POST", "/v1/chat/completions") + .match_header("Content-Type", "application/json") + .match_body(Matcher::JsonString(r#" + { + "messages": [ + { + "role": "system", + "content": "You are Grok, a chatbot inspired by the Hitchhikers Guide to the Galaxy." + }, + { + "role": "user", + "content": "What is the answer to life and universe?" + } + ], + "model": "grok-beta", + "stream": false, + "temperature": 0 + } + "#.to_string())) + .with_status(200) + .with_header("Content-Type", "application/json") + .with_body(r#" + { + "id": "304e12ef-81f4-4e93-a41c-f5f57f6a2b56", + "object": "chat.completion", + "created": 1728511727, + "model": "grok-beta", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "The answer to the ultimate question of life, the universe, and everything is **42**, according to Douglas Adams science fiction series \"The Hitchhiker's Guide to the Galaxy.\" This number is often humorously referenced in discussions about the meaning of life. However, in the context of the story, the actual question to which 42 is the answer remains unknown, symbolizing the ongoing search for understanding the purpose or meaning of existence." + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 24, + "completion_tokens": 91, + "total_tokens": 115 + }, + "system_fingerprint": "fp_3813298403" + } + "#) + .create_async() + .await; + + let client = XaiClient::builder() + .base_url(&format!("{}/", server.url())) + .build() + .expect("Failed to build XaiClient"); + + client.set_api_key("test-api-key".to_string()); + + let body = json!({ + "messages": [ + { + "role": "system", + "content": "You are Grok, a chatbot inspired by the Hitchhikers Guide to the Galaxy." + }, + { + "role": "user", + "content": "What is the answer to life and universe?" + } + ], + "model": "grok-beta", + "stream": false, + "temperature": 0 + }); + + let result = client + .request(Method::POST, "/v1/chat/completions") + .expect("body") + .json(&body) + .send() + .await; + + assert!(result.is_ok()); + let response = result.unwrap(); + assert_eq!(response.status(), 200); + + let response_text = response.text().await.unwrap(); + assert_eq!( + response_text, + r#" + { + "id": "304e12ef-81f4-4e93-a41c-f5f57f6a2b56", + "object": "chat.completion", + "created": 1728511727, + "model": "grok-beta", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "The answer to the ultimate question of life, the universe, and everything is **42**, according to Douglas Adams science fiction series \"The Hitchhiker's Guide to the Galaxy.\" This number is often humorously referenced in discussions about the meaning of life. However, in the context of the story, the actual question to which 42 is the answer remains unknown, symbolizing the ongoing search for understanding the purpose or meaning of existence." + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 24, + "completion_tokens": 91, + "total_tokens": 115 + }, + "system_fingerprint": "fp_3813298403" + } + "# + ); + + chat_completion_mock.assert_async().await; +} diff --git a/tests/completions.rs b/tests/completions.rs index d8cd7b9..becb204 100644 --- a/tests/completions.rs +++ b/tests/completions.rs @@ -5,55 +5,42 @@ use x_ai::client::XaiClient; use x_ai::traits::ClientConfig; #[tokio::test] -async fn test_chat_completions() { +async fn test_completions_endpoint() { let mut server = Server::new_async().await; - let chat_completion_mock = server - .mock("POST", "/v1/chat/completions") + let mock_response = r#" + { + "choices": [], + "created": 0, + "id": "", + "model": "", + "object": "", + "system_fingerprint": "", + "usage": null + } + "#; + + let completions_mock = server + .mock("POST", "/v1/completions") .match_header("Content-Type", "application/json") - .match_body(Matcher::JsonString(r#" + .match_body(Matcher::JsonString( + r#" { - "messages": [ - { - "role": "system", - "content": "You are Grok, a chatbot inspired by the Hitchhikers Guide to the Galaxy." - }, - { - "role": "user", - "content": "What is the answer to life and universe?" - } - ], "model": "grok-beta", - "stream": false, - "temperature": 0 + "prompt": "What is the meaning of life?", + "best_of": 1, + "echo": false, + "max_tokens": 100, + "temperature": 0.7, + "n": 1, + "top_p": 1 } - "#.to_string())) + "# + .to_string(), + )) .with_status(200) .with_header("Content-Type", "application/json") - .with_body(r#" - { - "id": "304e12ef-81f4-4e93-a41c-f5f57f6a2b56", - "object": "chat.completion", - "created": 1728511727, - "model": "grok-beta", - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": "The answer to the ultimate question of life, the universe, and everything is **42**, according to Douglas Adams science fiction series \"The Hitchhiker's Guide to the Galaxy.\" This number is often humorously referenced in discussions about the meaning of life. However, in the context of the story, the actual question to which 42 is the answer remains unknown, symbolizing the ongoing search for understanding the purpose or meaning of existence." - }, - "finish_reason": "stop" - } - ], - "usage": { - "prompt_tokens": 24, - "completion_tokens": 91, - "total_tokens": 115 - }, - "system_fingerprint": "fp_3813298403" - } - "#) + .with_body(mock_response) .create_async() .await; @@ -65,23 +52,18 @@ async fn test_chat_completions() { client.set_api_key("test-api-key".to_string()); let body = json!({ - "messages": [ - { - "role": "system", - "content": "You are Grok, a chatbot inspired by the Hitchhikers Guide to the Galaxy." - }, - { - "role": "user", - "content": "What is the answer to life and universe?" - } - ], "model": "grok-beta", - "stream": false, - "temperature": 0 + "prompt": "What is the meaning of life?", + "best_of": 1, + "echo": false, + "max_tokens": 100, + "temperature": 0.7, + "n": 1, + "top_p": 1 }); let result = client - .request(Method::POST, "/v1/chat/completions") + .request(Method::POST, "/v1/completions") .expect("body") .json(&body) .send() @@ -92,33 +74,7 @@ async fn test_chat_completions() { assert_eq!(response.status(), 200); let response_text = response.text().await.unwrap(); - assert_eq!( - response_text, - r#" - { - "id": "304e12ef-81f4-4e93-a41c-f5f57f6a2b56", - "object": "chat.completion", - "created": 1728511727, - "model": "grok-beta", - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": "The answer to the ultimate question of life, the universe, and everything is **42**, according to Douglas Adams science fiction series \"The Hitchhiker's Guide to the Galaxy.\" This number is often humorously referenced in discussions about the meaning of life. However, in the context of the story, the actual question to which 42 is the answer remains unknown, symbolizing the ongoing search for understanding the purpose or meaning of existence." - }, - "finish_reason": "stop" - } - ], - "usage": { - "prompt_tokens": 24, - "completion_tokens": 91, - "total_tokens": 115 - }, - "system_fingerprint": "fp_3813298403" - } - "# - ); + assert_eq!(response_text, mock_response); - chat_completion_mock.assert_async().await; + completions_mock.assert_async().await; }