Added Orchestration LLM Config Convenience #152

Merged 23 commits on Nov 19, 2024
18 changes: 8 additions & 10 deletions docs/guides/ORCHESTRATION_CHAT_COMPLETION.md
@@ -77,7 +77,7 @@ To use the Orchestration service, create a client and a configuration object:
var client = new OrchestrationClient();

var config = new OrchestrationModuleConfig()
.withLlmConfig(LLMModuleConfig.create().modelName("gpt-35-turbo").modelParams(Map.of()));
.withLlmConfig(OrchestrationAiModel.GPT_4O);
```

Please also refer to [our sample code](../../sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java) for this and all following code examples.
@@ -218,13 +218,11 @@ In this example, the input will be masked before the call to the LLM. Note that
Change your LLM module configuration to add model parameters:

```java
var llmConfig =
LLMModuleConfig.create()
.modelName("gpt-35-turbo")
.modelParams(
Map.of(
"max_tokens", 50,
"temperature", 0.1,
"frequency_penalty", 0,
"presence_penalty", 0));
LLMModuleConfig llmConfig =
OrchestrationAiModel.GPT_4O.modelParams(
Map.of(
"max_tokens", 50,
"temperature", 0.1,
"frequency_penalty", 0,
"presence_penalty", 0));
```
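The parameter map in the snippet above is a plain `Map<String, Object>`. As a standalone sketch (independent of the SDK, class name hypothetical), the same map can be built and inspected like this:

```java
import java.util.Map;

// Standalone sketch of the model-parameter map used above. The keys are the
// OpenAI-style parameter names from the diff; primitive values are boxed.
public class ModelParamsSketch {
    static Map<String, Object> defaultParams() {
        return Map.of(
            "max_tokens", 50,       // cap on generated tokens
            "temperature", 0.1,     // low temperature -> more deterministic output
            "frequency_penalty", 0, // no penalty on token frequency
            "presence_penalty", 0); // no penalty on token presence
    }

    public static void main(String[] args) {
        System.out.println(defaultParams().get("max_tokens")); // prints 50
    }
}
```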
@@ -0,0 +1,121 @@
package com.sap.ai.sdk.orchestration;

import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig;
import java.util.Map;
import javax.annotation.Nonnull;

/** Large language models available in Orchestration. */
// https://help.sap.com/docs/sap-ai-core/sap-ai-core-service-guide/models-and-scenarios-in-generative-ai-hub
@SuppressWarnings("unused") // tested in OrchestrationTest.orchestrationModelAvailability()
public class OrchestrationAiModel extends LLMModuleConfig {

/** IBM Granite 13B chat completions model */
public static final OrchestrationAiModel IBM_GRANITE_13B_CHAT =
new OrchestrationAiModel("ibm--granite-13b-chat");

/** MistralAI Mistral Large Instruct model */
public static final OrchestrationAiModel MISTRAL_LARGE_INSTRUCT =
new OrchestrationAiModel("mistralai--mistral-large-instruct");

/** MistralAI Mixtral 8x7B Instruct v01 model */
public static final OrchestrationAiModel MIXTRAL_8X7B_INSTRUCT_V01 =
new OrchestrationAiModel("mistralai--mixtral-8x7b-instruct-v01");

/** Meta Llama3 70B Instruct model */
public static final OrchestrationAiModel LLAMA3_70B_INSTRUCT =
new OrchestrationAiModel("meta--llama3-70b-instruct");

/** Meta Llama3.1 70B Instruct model */
public static final OrchestrationAiModel LLAMA3_1_70B_INSTRUCT =
new OrchestrationAiModel("meta--llama3.1-70b-instruct");

/** Anthropic Claude 3 Sonnet model */
public static final OrchestrationAiModel CLAUDE_3_SONNET =
new OrchestrationAiModel("anthropic--claude-3-sonnet");

/** Anthropic Claude 3 Haiku model */
public static final OrchestrationAiModel CLAUDE_3_HAIKU =
new OrchestrationAiModel("anthropic--claude-3-haiku");

/** Anthropic Claude 3 Opus model */
public static final OrchestrationAiModel CLAUDE_3_OPUS =
new OrchestrationAiModel("anthropic--claude-3-opus");

/** Anthropic Claude 3.5 Sonnet model */
public static final OrchestrationAiModel CLAUDE_3_5_SONNET =
new OrchestrationAiModel("anthropic--claude-3.5-sonnet");

/** Amazon Titan Embed Text model */
public static final OrchestrationAiModel TITAN_EMBED_TEXT =
new OrchestrationAiModel("amazon--titan-embed-text");

/** Amazon Titan Text Lite model */
public static final OrchestrationAiModel TITAN_TEXT_LITE =
new OrchestrationAiModel("amazon--titan-text-lite");

/** Amazon Titan Text Express model */
public static final OrchestrationAiModel TITAN_TEXT_EXPRESS =
new OrchestrationAiModel("amazon--titan-text-express");

/** Azure OpenAI GPT-3.5 Turbo chat completions model */
public static final OrchestrationAiModel GPT_35_TURBO = new OrchestrationAiModel("gpt-35-turbo");

/** Azure OpenAI GPT-3.5 Turbo chat completions model */
public static final OrchestrationAiModel GPT_35_TURBO_16K =
new OrchestrationAiModel("gpt-35-turbo-16k");

/** Azure OpenAI GPT-4 chat completions model */
public static final OrchestrationAiModel GPT_4 = new OrchestrationAiModel("gpt-4");

/** Azure OpenAI GPT-4-32k chat completions model */
public static final OrchestrationAiModel GPT_4_32K = new OrchestrationAiModel("gpt-4-32k");

/** Azure OpenAI Text Embedding Ada 002 model */
public static final OrchestrationAiModel TEXT_EMBEDDING_ADA_002 =
new OrchestrationAiModel("text-embedding-ada-002");

/** Azure OpenAI Text Embedding 3 Small model */
public static final OrchestrationAiModel TEXT_EMBEDDING_3_SMALL =
new OrchestrationAiModel("text-embedding-3-small");

/** Azure OpenAI Text Embedding 3 Large model */
public static final OrchestrationAiModel TEXT_EMBEDDING_3_LARGE =
new OrchestrationAiModel("text-embedding-3-large");

/** Azure OpenAI GPT-4o chat completions model */
public static final OrchestrationAiModel GPT_4O = new OrchestrationAiModel("gpt-4o");

/** Azure OpenAI GPT-4o-mini chat completions model */
public static final OrchestrationAiModel GPT_4O_MINI = new OrchestrationAiModel("gpt-4o-mini");

/** Google Cloud Platform Text Bison model */
public static final OrchestrationAiModel TEXT_BISON = new OrchestrationAiModel("text-bison");

/** Google Cloud Platform Chat Bison model */
public static final OrchestrationAiModel CHAT_BISON = new OrchestrationAiModel("chat-bison");

/** Google Cloud Platform Text Embedding Gecko model */
public static final OrchestrationAiModel TEXT_EMBEDDING_GECKO =
new OrchestrationAiModel("textembedding-gecko");

/** Google Cloud Platform Text Embedding Gecko Multilingual model */
public static final OrchestrationAiModel TEXT_EMBEDDING_GECKO_MULTILINGUAL =
new OrchestrationAiModel("textembedding-gecko-multilingual");

/** Google Cloud Platform Gemini 1.0 Pro model */
public static final OrchestrationAiModel GEMINI_1_0_PRO =
new OrchestrationAiModel("gemini-1.0-pro");

/** Google Cloud Platform Gemini 1.5 Pro model */
public static final OrchestrationAiModel GEMINI_1_5_PRO =
new OrchestrationAiModel("gemini-1.5-pro");

/** Google Cloud Platform Gemini 1.5 Flash model */
public static final OrchestrationAiModel GEMINI_1_5_FLASH =
new OrchestrationAiModel("gemini-1.5-flash");

OrchestrationAiModel(@Nonnull final String modelName) {
setModelName(modelName);
setModelParams(Map.of());
}
}
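The new class above applies a small pattern: a generated config type is subclassed so that preconfigured instances can be exposed as reusable constants. A self-contained sketch of that pattern, with a hypothetical `BaseConfig` standing in for the generated `LLMModuleConfig`:

```java
import java.util.Map;

// BaseConfig is a hypothetical stand-in for the generated LLMModuleConfig.
class BaseConfig {
    private String modelName;
    private Map<String, Object> modelParams;

    void setModelName(String name) { this.modelName = name; }
    void setModelParams(Map<String, Object> params) { this.modelParams = params; }
    String getModelName() { return modelName; }
}

// Subclass exposing ready-to-use instances as constants, mirroring
// OrchestrationAiModel.GPT_4O and friends.
final class AiModel extends BaseConfig {
    static final AiModel GPT_4O = new AiModel("gpt-4o");

    AiModel(String modelName) {
        setModelName(modelName);
        setModelParams(Map.of()); // no model parameters by default
    }
}

public class ModelConstantsSketch {
    public static void main(String[] args) {
        System.out.println(AiModel.GPT_4O.getModelName()); // prints gpt-4o
    }
}
```

Because the constant is an instance of the config type itself, it can be passed anywhere an `LLMModuleConfig` is expected, which is what makes `withLlmConfig(OrchestrationAiModel.GPT_4O)` in the docs diff compile.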
@@ -57,14 +57,12 @@
@WireMockTest
class OrchestrationUnitTest {
static final LLMModuleConfig LLM_CONFIG =
LLMModuleConfig.create()
.modelName("gpt-35-turbo-16k")
.modelParams(
Map.of(
"max_tokens", 50,
"temperature", 0.1,
"frequency_penalty", 0,
"presence_penalty", 0));
OrchestrationAiModel.GPT_35_TURBO_16K.modelParams(
Map.of(
"max_tokens", 50,
"temperature", 0.1,
"frequency_penalty", 0,
"presence_penalty", 0));
private final Function<String, InputStream> fileLoader =
filename -> Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream(filename));

@@ -1,5 +1,7 @@
package com.sap.ai.sdk.app.controllers;

import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.GPT_35_TURBO;

import com.sap.ai.sdk.orchestration.OrchestrationClient;
import com.sap.ai.sdk.orchestration.OrchestrationModuleConfig;
import com.sap.ai.sdk.orchestration.OrchestrationPrompt;
@@ -30,12 +32,11 @@
@RestController
@RequestMapping("/orchestration")
class OrchestrationController {
static final LLMModuleConfig LLM_CONFIG =
LLMModuleConfig.create().modelName("gpt-35-turbo").modelParams(Map.of());
LLMModuleConfig llmConfig = GPT_35_TURBO;

private final OrchestrationClient client = new OrchestrationClient();
private final OrchestrationModuleConfig config =
new OrchestrationModuleConfig().withLlmConfig(LLM_CONFIG);
new OrchestrationModuleConfig().withLlmConfig(llmConfig);

/**
* Chat request to OpenAI through the Orchestration service with a simple prompt.
@@ -171,7 +172,7 @@ public CompletionPostResponse maskingAnonymization() {

/**
* Let the orchestration service a response to a hypothetical user who provided feedback on the AI
* SDK. Pseydonymize the user's name and location to protect their privacy.
* SDK. Pseudonymize the user's name and location to protect their privacy.
*
* @return the result object
*/
@@ -3,15 +3,22 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

import com.sap.ai.sdk.orchestration.OrchestrationAiModel;
import com.sap.ai.sdk.orchestration.OrchestrationClientException;
import com.sap.ai.sdk.orchestration.client.model.AzureThreshold;
import com.sap.ai.sdk.orchestration.client.model.CompletionPostResponse;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.assertj.core.api.InstanceOfAssertFactories;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

@Slf4j
class OrchestrationTest {
OrchestrationController controller;

@@ -41,7 +48,7 @@ void testTemplate() {
assertThat(llm.getId()).isNotEmpty();
assertThat(llm.getObject()).isEqualTo("chat.completion");
assertThat(llm.getCreated()).isGreaterThan(1);
assertThat(llm.getModel()).isEqualTo(OrchestrationController.LLM_CONFIG.getModelName());
assertThat(llm.getModel()).isEqualTo(controller.llmConfig.getModelName());
var choices = llm.getChoices();
assertThat(choices.get(0).getIndex()).isZero();
assertThat(choices.get(0).getMessage().getContent()).isNotEmpty();
@@ -54,7 +61,7 @@ void testTemplate() {
assertThat(result.getOrchestrationResult().getObject()).isEqualTo("chat.completion");
assertThat(result.getOrchestrationResult().getCreated()).isGreaterThan(1);
assertThat(result.getOrchestrationResult().getModel())
.isEqualTo(OrchestrationController.LLM_CONFIG.getModelName());
.isEqualTo(controller.llmConfig.getModelName());
choices = result.getOrchestrationResult().getChoices();
assertThat(choices.get(0).getIndex()).isZero();
assertThat(choices.get(0).getMessage().getContent()).isNotEmpty();
@@ -137,4 +144,29 @@ void testMaskingPseudonymization() {
.doesNotContain("MASKED_PERSON")
.contains("Mallory");
}

@Test
@DisplayName("Declared OrchestrationAiModels must match Orchestration's list of available models")
@SneakyThrows
void orchestrationModelAvailability() {
// TODO: Change this test to be like ScenarioTest.openAiModelAvailability() once the
// Orchestration service has made the "available models endpoint".
// Right now this test cannot tell if we are lacking models.

// Gather our declared Orchestration models
List<OrchestrationAiModel> declaredOrchestrationModelList = new ArrayList<>();
for (Field field : OrchestrationAiModel.class.getFields()) {
if (field.getType().equals(OrchestrationAiModel.class)) {
declaredOrchestrationModelList.add(((OrchestrationAiModel) field.get(null)));
}
}

declaredOrchestrationModelList.parallelStream()
.forEach(
model -> {
controller.llmConfig = model;
log.info("Testing completion for model: {}", model.getModelName());
assertThat(controller.completion()).isNotNull();
});
}
}
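The `orchestrationModelAvailability()` test above enumerates the declared constants via reflection over the class's public fields, filtering by field type. A self-contained sketch of that technique (catalog class and values hypothetical):

```java
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;

// Catalog with public static constants, analogous to OrchestrationAiModel.
class Catalog {
    public static final String GPT_4O = "gpt-4o";
    public static final String GPT_4 = "gpt-4";
    public static final int UNRELATED = 42; // filtered out by the type check below

    // Collect all public static fields of the matching type, as the test does
    // with field.getType().equals(OrchestrationAiModel.class).
    static List<String> declaredModels() throws IllegalAccessException {
        List<String> models = new ArrayList<>();
        for (Field field : Catalog.class.getFields()) {
            if (field.getType().equals(String.class)) {
                models.add((String) field.get(null)); // static field: null receiver
            }
        }
        return models; // note: getFields() returns fields in no guaranteed order
    }
}

public class ReflectionSketch {
    public static void main(String[] args) throws Exception {
        System.out.println(Catalog.declaredModels().size()); // prints 2
    }
}
```

As the TODO in the test notes, this enumeration can only verify that every declared constant works against the service; without an "available models" endpoint it cannot detect models the SDK has not yet declared.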
@@ -11,12 +11,12 @@
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

public class ScenarioTest {
class ScenarioTest {

@Test
@DisplayName("Declared OpenAI models must match AI Core's available OpenAI models")
@SneakyThrows
public void openAiModelAvailability() {
void openAiModelAvailability() {

// Gather AI Core's list of available OpenAI models
final var aiModelList = new ScenarioController().getModels().getResources();