Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: impl deez completions 2nd endpoint #3

Merged
merged 1 commit into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions src/chat_compl.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
//! Reference: https://docs.x.ai/api/endpoints#chat-completions

use crate::error::XaiError;
use crate::traits::ChatCompletionsFetcher;
use crate::traits::ClientConfig;
use serde::{Deserialize, Serialize};

/// JSON body sent to the `chat/completions` endpoint by `create_chat_completion`.
///
/// `model` and `messages` are required and are supplied to
/// `ChatCompletionsRequestBuilder::new`; every other field is optional and is
/// initialised to `None` by that constructor. Parameter semantics are defined by
/// the xAI API reference linked in this module's header.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionRequest {
// Required: model identifier, passed straight through from the builder.
pub model: String,
// Required: the conversation messages to send.
pub messages: Vec<Message>,
// Optional parameters below; all default to `None` in the builder.
// NOTE(review): no `skip_serializing_if` is set, so with serde's default
// `Option` handling unset fields should go out as JSON `null` — confirm the
// API accepts that.
pub temperature: Option<f32>,
pub max_tokens: Option<u32>,
pub frequency_penalty: Option<f32>,
pub presence_penalty: Option<f32>,
pub n: Option<u32>,
pub stop: Option<Vec<String>>,
pub stream: Option<bool>,
pub logprobs: Option<bool>,
pub top_p: Option<f32>,
pub top_logprobs: Option<u32>,
pub seed: Option<u32>,
pub user: Option<String>,
}

/// A single chat message, used both in `ChatCompletionRequest::messages` and in
/// each response `Choice`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
// Free-form string; presumably values such as "system", "user" or
// "assistant" — confirm against the API reference.
pub role: String,
// Text content of the message.
pub content: String,
}

/// Body of a successful `chat/completions` response, as deserialized by
/// `create_chat_completion`.
///
/// `usage` and `system_fingerprint` are optional; all other fields must be
/// present for deserialization to succeed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionResponse {
pub id: String,
pub object: String,
// NOTE(review): presumably a Unix timestamp in seconds — confirm.
pub created: u64,
pub model: String,
// One entry per generated completion.
pub choices: Vec<Choice>,
pub usage: Option<Usage>,
pub system_fingerprint: Option<String>,
}

/// One generated completion inside `ChatCompletionResponse::choices`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
// Position of this choice in the response.
pub index: u32,
// The generated message.
pub message: Message,
// NOTE(review): declared as a required string, so a response with a `null`
// or missing `finish_reason` will fail to deserialize — confirm the API
// always sends it.
pub finish_reason: String,
}

/// Token accounting attached to a `ChatCompletionResponse` (optional there).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
// Presumably `prompt_tokens + completion_tokens` — confirm against the API
// reference; this code does not check it.
pub total_tokens: u32,
}

/// Fluent builder for a [`ChatCompletionRequest`].
///
/// It also holds the client `T` that is used to send the request when the
/// builder is driven through the `ChatCompletionsFetcher` trait.
#[derive(Debug, Clone)]
pub struct ChatCompletionsRequestBuilder<T>
where
    T: ClientConfig + Clone + Send + Sync,
{
    /// Client used to construct the outgoing HTTP request.
    client: T,
    /// Request being assembled; returned by `build`.
    request: ChatCompletionRequest,
}

impl<T> ChatCompletionsRequestBuilder<T>
where
    T: ClientConfig + Clone + Send + Sync,
{
    /// Creates a builder for a chat completion request.
    ///
    /// `model` and `messages` are the two required request fields; every
    /// optional field starts as `None` and can be filled in with the chained
    /// setter methods below.
    pub fn new(client: T, model: String, messages: Vec<Message>) -> Self {
        Self {
            client,
            request: ChatCompletionRequest {
                model,
                messages,
                temperature: None,
                max_tokens: None,
                frequency_penalty: None,
                presence_penalty: None,
                n: None,
                stop: None,
                stream: None,
                logprobs: None,
                top_p: None,
                top_logprobs: None,
                seed: None,
                user: None,
            },
        }
    }

    /// Sets the request's `temperature`.
    pub fn temperature(mut self, temperature: f32) -> Self {
        self.request.temperature = Some(temperature);
        self
    }

    /// Sets the request's `max_tokens`.
    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
        self.request.max_tokens = Some(max_tokens);
        self
    }

    /// Sets the request's `frequency_penalty`.
    pub fn frequency_penalty(mut self, frequency_penalty: f32) -> Self {
        self.request.frequency_penalty = Some(frequency_penalty);
        self
    }

    /// Sets the request's `presence_penalty`.
    pub fn presence_penalty(mut self, presence_penalty: f32) -> Self {
        self.request.presence_penalty = Some(presence_penalty);
        self
    }

    /// Sets the request's `n`.
    pub fn n(mut self, n: u32) -> Self {
        self.request.n = Some(n);
        self
    }

    /// Sets the request's `stop` sequences.
    pub fn stop(mut self, stop: Vec<String>) -> Self {
        self.request.stop = Some(stop);
        self
    }

    /// Sets the request's `stream` flag.
    ///
    /// NOTE(review): `create_chat_completion` parses the whole body as a single
    /// JSON document, so a streamed response is presumably not handled there —
    /// confirm before enabling.
    pub fn stream(mut self, stream: bool) -> Self {
        self.request.stream = Some(stream);
        self
    }

    /// Sets the request's `logprobs` flag.
    pub fn logprobs(mut self, logprobs: bool) -> Self {
        self.request.logprobs = Some(logprobs);
        self
    }

    /// Sets the request's `top_p`.
    pub fn top_p(mut self, top_p: f32) -> Self {
        self.request.top_p = Some(top_p);
        self
    }

    /// Sets the request's `top_logprobs`.
    pub fn top_logprobs(mut self, top_logprobs: u32) -> Self {
        self.request.top_logprobs = Some(top_logprobs);
        self
    }

    /// Sets the request's `seed`.
    pub fn seed(mut self, seed: u32) -> Self {
        self.request.seed = Some(seed);
        self
    }

    /// Sets the request's `user`.
    pub fn user(mut self, user: String) -> Self {
        self.request.user = Some(user);
        self
    }

    /// Consumes the builder and returns the assembled request.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` return type is kept so validation
    /// can be added without changing the signature.
    pub fn build(self) -> Result<ChatCompletionRequest, XaiError> {
        Ok(self.request)
    }
}

impl<T> ChatCompletionsFetcher for ChatCompletionsRequestBuilder<T>
where
    T: ClientConfig + Clone + Send + Sync,
{
    /// POSTs `request` as JSON to `chat/completions` using the builder's client
    /// and deserializes a successful response body.
    ///
    /// # Errors
    ///
    /// * Propagates any error from building or sending the request, or from
    ///   decoding the response body.
    /// * Returns `XaiError::Http` with a descriptive message for every
    ///   non-success (non-2xx) status.
    async fn create_chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, XaiError> {
        let response = self
            .client
            .request(reqwest::Method::POST, "chat/completions")?
            .json(&request)
            .send()
            .await?;

        let status = response.status();
        if status.is_success() {
            return Ok(response.json::<ChatCompletionResponse>().await?);
        }

        // `error_for_status` only yields `Err` for 4xx/5xx responses. Other
        // non-success codes (1xx/3xx) come back as `Ok`, so calling
        // `unwrap_err()` on the result would panic; report them explicitly.
        let message = match response.error_for_status() {
            Err(err) => err.to_string(),
            Ok(_) => format!("unexpected HTTP status: {status}"),
        };
        Err(XaiError::Http(message))
    }
}
145 changes: 100 additions & 45 deletions src/completions.rs
Original file line number Diff line number Diff line change
@@ -1,49 +1,46 @@
//! Reference: https://docs.x.ai/api/endpoints#chat-completions
//! Reference: https://docs.x.ai/api/endpoints#completions

use crate::error::XaiError;
use crate::traits::ChatCompletionsFetcher;
use crate::traits::ClientConfig;
use crate::traits::{ClientConfig, CompletionsFetcher};
use reqwest::Method;
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionRequest {
pub struct CompletionsRequest {
pub model: String,
pub messages: Vec<Message>,
pub temperature: Option<f32>,
pub max_tokens: Option<u32>,
pub prompt: String,
pub best_of: Option<u32>,
pub echo: Option<bool>,
pub frequency_penalty: Option<f32>,
pub presence_penalty: Option<f32>,
pub logit_bias: Option<std::collections::HashMap<String, i32>>,
pub logprobs: Option<u32>,
pub max_tokens: Option<u32>,
pub n: Option<u32>,
pub presence_penalty: Option<f32>,
pub seed: Option<u32>,
pub stop: Option<Vec<String>>,
pub stream: Option<bool>,
pub logprobs: Option<bool>,
pub suffix: Option<String>,
pub temperature: Option<f32>,
pub top_p: Option<f32>,
pub top_logprobs: Option<u32>,
pub seed: Option<u32>,
pub user: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
pub role: String,
pub content: String,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionResponse {
pub id: String,
pub object: String,
pub struct CompletionsResponse {
pub choices: Vec<Choice>,
pub created: u64,
pub id: String,
pub model: String,
pub choices: Vec<Choice>,
pub usage: Option<Usage>,
pub object: String,
pub system_fingerprint: Option<String>,
pub usage: Option<Usage>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
pub index: u32,
pub message: Message,
pub text: String,
pub finish_reason: String,
}

Expand All @@ -55,39 +52,62 @@ pub struct Usage {
}

#[derive(Debug, Clone)]
pub struct ChatCompletionsRequestBuilder<T: ClientConfig + Clone + Send + Sync> {
pub struct CompletionsRequestBuilder<T: ClientConfig + Clone + Send + Sync> {
client: T,
request: ChatCompletionRequest,
request: CompletionsRequest,
}

impl<T> ChatCompletionsRequestBuilder<T>
impl<T> CompletionsRequestBuilder<T>
where
T: ClientConfig + Clone + Send + Sync,
{
pub fn new(client: T, model: String, messages: Vec<Message>) -> Self {
pub fn new(client: T, model: String, prompt: String) -> Self {
Self {
client,
request: ChatCompletionRequest {
request: CompletionsRequest {
model,
messages,
temperature: None,
max_tokens: None,
prompt,
best_of: None,
echo: None,
frequency_penalty: None,
presence_penalty: None,
logit_bias: None,
logprobs: None,
max_tokens: None,
n: None,
presence_penalty: None,
seed: None,
stop: None,
stream: None,
logprobs: None,
suffix: None,
temperature: None,
top_p: None,
top_logprobs: None,
seed: None,
user: None,
},
}
}

pub fn temperature(mut self, temperature: f32) -> Self {
self.request.temperature = Some(temperature);
pub fn best_of(mut self, best_of: u32) -> Self {
self.request.best_of = Some(best_of);
self
}

pub fn echo(mut self, echo: bool) -> Self {
self.request.echo = Some(echo);
self
}

pub fn frequency_penalty(mut self, frequency_penalty: f32) -> Self {
self.request.frequency_penalty = Some(frequency_penalty);
self
}

pub fn logit_bias(mut self, logit_bias: std::collections::HashMap<String, i32>) -> Self {
self.request.logit_bias = Some(logit_bias);
self
}

pub fn logprobs(mut self, logprobs: u32) -> Self {
self.request.logprobs = Some(logprobs);
self
}

Expand All @@ -101,34 +121,69 @@ where
self
}

pub fn presence_penalty(mut self, presence_penalty: f32) -> Self {
self.request.presence_penalty = Some(presence_penalty);
self
}

pub fn seed(mut self, seed: u32) -> Self {
self.request.seed = Some(seed);
self
}

pub fn stop(mut self, stop: Vec<String>) -> Self {
self.request.stop = Some(stop);
self
}

pub fn build(self) -> Result<ChatCompletionRequest, XaiError> {
pub fn stream(mut self, stream: bool) -> Self {
self.request.stream = Some(stream);
self
}

pub fn suffix(mut self, suffix: String) -> Self {
self.request.suffix = Some(suffix);
self
}

pub fn temperature(mut self, temperature: f32) -> Self {
self.request.temperature = Some(temperature);
self
}

pub fn top_p(mut self, top_p: f32) -> Self {
self.request.top_p = Some(top_p);
self
}

pub fn user(mut self, user: String) -> Self {
self.request.user = Some(user);
self
}

pub fn build(self) -> Result<CompletionsRequest, XaiError> {
Ok(self.request)
}
}

impl<T> ChatCompletionsFetcher for ChatCompletionsRequestBuilder<T>
impl<T> CompletionsFetcher for CompletionsRequestBuilder<T>
where
T: ClientConfig + Clone + Send + Sync,
{
async fn create_chat_completion(
async fn create_completions(
&self,
request: ChatCompletionRequest,
) -> Result<ChatCompletionResponse, XaiError> {
request: CompletionsRequest,
) -> Result<CompletionsResponse, XaiError> {
let response = self
.client
.request(reqwest::Method::POST, "chat/completions")?
.request(Method::POST, "/v1/completions")?
.json(&request)
.send()
.await?;

if response.status().is_success() {
let chat_completion = response.json::<ChatCompletionResponse>().await?;
Ok(chat_completion)
let completions = response.json::<CompletionsResponse>().await?;
Ok(completions)
} else {
Err(XaiError::Http(
response.error_for_status().unwrap_err().to_string(),
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod api_key;
pub mod chat_compl;
pub mod client;
pub mod completions;
pub mod error;
Expand Down
13 changes: 11 additions & 2 deletions src/traits.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#![allow(async_fn_in_trait)]

use crate::api_key::ApiKeyInfo;
use crate::completions::ChatCompletionRequest;
use crate::completions::ChatCompletionResponse;
use crate::chat_compl::ChatCompletionRequest;
use crate::chat_compl::ChatCompletionResponse;
use crate::completions::CompletionsRequest;
use crate::completions::CompletionsResponse;
use crate::error::XaiError;
use reqwest::{Method, RequestBuilder};

Expand All @@ -22,3 +24,10 @@ pub trait ChatCompletionsFetcher {
request: ChatCompletionRequest,
) -> Result<ChatCompletionResponse, XaiError>;
}

/// Abstraction over something that can execute a text-completions request.
pub trait CompletionsFetcher {
/// Sends `request` and returns the parsed `CompletionsResponse`, or an
/// `XaiError` if the call fails.
async fn create_completions(
&self,
request: CompletionsRequest,
) -> Result<CompletionsResponse, XaiError>;
}
Loading