diff --git a/docs/guides/ORCHESTRATION_CHAT_COMPLETION.md b/docs/guides/ORCHESTRATION_CHAT_COMPLETION.md
index 31adb89a5..87d6d3f18 100644
--- a/docs/guides/ORCHESTRATION_CHAT_COMPLETION.md
+++ b/docs/guides/ORCHESTRATION_CHAT_COMPLETION.md
@@ -77,7 +77,7 @@ To use the Orchestration service, create a client and a configuration object:
var client = new OrchestrationClient();
var config = new OrchestrationModuleConfig()
- .withLlmConfig(LLMModuleConfig.create().modelName("gpt-35-turbo").modelParams(Map.of()));
+ .withLlmConfig(OrchestrationAiModel.GPT_4O);
```
Please also refer to [our sample code](../../sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java) for this and all following code examples.
@@ -214,16 +214,16 @@ In this example, the input will be masked before the call to the LLM. Note that
### Set model parameters
-Change your LLM module configuration to add model parameters:
+Change your LLM configuration to add model parameters:
```java
-var llmConfig =
- LLMModuleConfig.create()
- .modelName("gpt-35-turbo")
- .modelParams(
+OrchestrationAiModel customGPT4O =
+ OrchestrationAiModel.GPT_4O
+ .withModelParams(
Map.of(
"max_tokens", 50,
"temperature", 0.1,
"frequency_penalty", 0,
- "presence_penalty", 0));
+ "presence_penalty", 0))
+ .withModelVersion("2024-05-13");
```
diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationAiModel.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationAiModel.java
new file mode 100644
index 000000000..7f806ece7
--- /dev/null
+++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationAiModel.java
@@ -0,0 +1,120 @@
+package com.sap.ai.sdk.orchestration;
+
+import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig;
+import java.util.Map;
+import javax.annotation.Nonnull;
+import lombok.AllArgsConstructor;
+import lombok.Value;
+import lombok.With;
+
+/** Large language models available in Orchestration. */
+@Value
+@With
+@AllArgsConstructor
+public class OrchestrationAiModel {
+ /** The name of the model */
+ String modelName;
+
+ /**
+ * Optional parameters on this model.
+ *
+ *
+   * <pre>{@code
+ * Map.of(
+ * "max_tokens", 50,
+ * "temperature", 0.1,
+ * "frequency_penalty", 0,
+ * "presence_penalty", 0)
+   * }</pre>
+ */
+  Map<String, Object> modelParams;
+
+ /** The version of the model, defaults to "latest". */
+ String modelVersion;
+
+ /** IBM Granite 13B chat completions model */
+ public static final OrchestrationAiModel IBM_GRANITE_13B_CHAT =
+ new OrchestrationAiModel("ibm--granite-13b-chat");
+
+ /** MistralAI Mistral Large Instruct model */
+ public static final OrchestrationAiModel MISTRAL_LARGE_INSTRUCT =
+ new OrchestrationAiModel("mistralai--mistral-large-instruct");
+
+ /** MistralAI Mixtral 8x7B Instruct v01 model */
+ public static final OrchestrationAiModel MIXTRAL_8X7B_INSTRUCT_V01 =
+ new OrchestrationAiModel("mistralai--mixtral-8x7b-instruct-v01");
+
+ /** Meta Llama3 70B Instruct model */
+ public static final OrchestrationAiModel LLAMA3_70B_INSTRUCT =
+ new OrchestrationAiModel("meta--llama3-70b-instruct");
+
+ /** Meta Llama3.1 70B Instruct model */
+ public static final OrchestrationAiModel LLAMA3_1_70B_INSTRUCT =
+ new OrchestrationAiModel("meta--llama3.1-70b-instruct");
+
+ /** Anthropic Claude 3 Sonnet model */
+ public static final OrchestrationAiModel CLAUDE_3_SONNET =
+ new OrchestrationAiModel("anthropic--claude-3-sonnet");
+
+ /** Anthropic Claude 3 Haiku model */
+ public static final OrchestrationAiModel CLAUDE_3_HAIKU =
+ new OrchestrationAiModel("anthropic--claude-3-haiku");
+
+ /** Anthropic Claude 3 Opus model */
+ public static final OrchestrationAiModel CLAUDE_3_OPUS =
+ new OrchestrationAiModel("anthropic--claude-3-opus");
+
+ /** Anthropic Claude 3.5 Sonnet model */
+ public static final OrchestrationAiModel CLAUDE_3_5_SONNET =
+ new OrchestrationAiModel("anthropic--claude-3.5-sonnet");
+
+ /** Amazon Titan Text Lite model */
+ public static final OrchestrationAiModel TITAN_TEXT_LITE =
+ new OrchestrationAiModel("amazon--titan-text-lite");
+
+ /** Amazon Titan Text Express model */
+ public static final OrchestrationAiModel TITAN_TEXT_EXPRESS =
+ new OrchestrationAiModel("amazon--titan-text-express");
+
+ /** Azure OpenAI GPT-3.5 Turbo chat completions model */
+ public static final OrchestrationAiModel GPT_35_TURBO = new OrchestrationAiModel("gpt-35-turbo");
+
+ /** Azure OpenAI GPT-3.5 Turbo chat completions model */
+ public static final OrchestrationAiModel GPT_35_TURBO_16K =
+ new OrchestrationAiModel("gpt-35-turbo-16k");
+
+ /** Azure OpenAI GPT-4 chat completions model */
+ public static final OrchestrationAiModel GPT_4 = new OrchestrationAiModel("gpt-4");
+
+ /** Azure OpenAI GPT-4-32k chat completions model */
+ public static final OrchestrationAiModel GPT_4_32K = new OrchestrationAiModel("gpt-4-32k");
+
+ /** Azure OpenAI GPT-4o chat completions model */
+ public static final OrchestrationAiModel GPT_4O = new OrchestrationAiModel("gpt-4o");
+
+ /** Azure OpenAI GPT-4o-mini chat completions model */
+ public static final OrchestrationAiModel GPT_4O_MINI = new OrchestrationAiModel("gpt-4o-mini");
+
+ /** Google Cloud Platform Gemini 1.0 Pro model */
+ public static final OrchestrationAiModel GEMINI_1_0_PRO =
+ new OrchestrationAiModel("gemini-1.0-pro");
+
+ /** Google Cloud Platform Gemini 1.5 Pro model */
+ public static final OrchestrationAiModel GEMINI_1_5_PRO =
+ new OrchestrationAiModel("gemini-1.5-pro");
+
+ /** Google Cloud Platform Gemini 1.5 Flash model */
+ public static final OrchestrationAiModel GEMINI_1_5_FLASH =
+ new OrchestrationAiModel("gemini-1.5-flash");
+
+ OrchestrationAiModel(@Nonnull final String modelName) {
+ this(modelName, Map.of(), "latest");
+ }
+
+ @Nonnull
+ LLMModuleConfig createConfig() {
+ return new LLMModuleConfig()
+ .modelName(modelName)
+ .modelParams(modelParams)
+ .modelVersion(modelVersion);
+ }
+}
diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationModuleConfig.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationModuleConfig.java
index ee8771d49..b0f60e511 100644
--- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationModuleConfig.java
+++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationModuleConfig.java
@@ -4,12 +4,14 @@
import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig;
import com.sap.ai.sdk.orchestration.client.model.MaskingModuleConfig;
import com.sap.ai.sdk.orchestration.client.model.TemplatingModuleConfig;
+import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.NoArgsConstructor;
import lombok.Value;
import lombok.With;
+import lombok.experimental.Tolerate;
/**
* Represents the configuration for the orchestration service. Allows for configuring the different
@@ -48,4 +50,16 @@ public class OrchestrationModuleConfig {
/** A content filter to filter the prompt. */
@Nullable FilteringModuleConfig filteringConfig;
+
+ /**
+ * Creates a new configuration with the given LLM configuration.
+ *
+ * @param aiModel The LLM configuration to use.
+ * @return A new configuration with the given LLM configuration.
+ */
+ @Tolerate
+ @Nonnull
+ public OrchestrationModuleConfig withLlmConfig(@Nonnull final OrchestrationAiModel aiModel) {
+ return withLlmConfig(aiModel.createConfig());
+ }
}
diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/ConfigToRequestTransformerTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/ConfigToRequestTransformerTest.java
index 141b882f1..19ba9166d 100644
--- a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/ConfigToRequestTransformerTest.java
+++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/ConfigToRequestTransformerTest.java
@@ -1,6 +1,6 @@
package com.sap.ai.sdk.orchestration;
-import static com.sap.ai.sdk.orchestration.OrchestrationUnitTest.LLM_CONFIG;
+import static com.sap.ai.sdk.orchestration.OrchestrationUnitTest.CUSTOM_GPT_35;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -71,7 +71,7 @@ void testMessagesHistory() {
var prompt = new OrchestrationPrompt("bar").messageHistory(List.of(systemMessage));
var actual =
ConfigToRequestTransformer.toCompletionPostRequest(
- prompt, new OrchestrationModuleConfig().withLlmConfig(LLM_CONFIG));
+ prompt, new OrchestrationModuleConfig().withLlmConfig(CUSTOM_GPT_35));
assertThat(actual.getMessagesHistory()).containsExactly(systemMessage);
}
diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java
index 2af45a429..298408b08 100644
--- a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java
+++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java
@@ -16,6 +16,7 @@
import static com.github.tomakehurst.wiremock.client.WireMock.stubFor;
import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo;
import static com.github.tomakehurst.wiremock.client.WireMock.verify;
+import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.GPT_35_TURBO_16K;
import static com.sap.ai.sdk.orchestration.client.model.AzureThreshold.NUMBER_0;
import static com.sap.ai.sdk.orchestration.client.model.AzureThreshold.NUMBER_4;
import static org.apache.hc.core5.http.HttpStatus.SC_BAD_REQUEST;
@@ -39,7 +40,6 @@
import com.sap.ai.sdk.orchestration.client.model.FilteringModuleConfig;
import com.sap.ai.sdk.orchestration.client.model.GenericModuleResult;
import com.sap.ai.sdk.orchestration.client.model.InputFilteringConfig;
-import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig;
import com.sap.ai.sdk.orchestration.client.model.LLMModuleResultSynchronous;
import com.sap.ai.sdk.orchestration.client.model.MaskingModuleConfig;
import com.sap.ai.sdk.orchestration.client.model.OutputFilteringConfig;
@@ -62,15 +62,13 @@
*/
@WireMockTest
class OrchestrationUnitTest {
- static final LLMModuleConfig LLM_CONFIG =
- new LLMModuleConfig()
- .modelName("gpt-35-turbo-16k")
- .modelParams(
- Map.of(
- "max_tokens", 50,
- "temperature", 0.1,
- "frequency_penalty", 0,
- "presence_penalty", 0));
+ static final OrchestrationAiModel CUSTOM_GPT_35 =
+ GPT_35_TURBO_16K.withModelParams(
+ Map.of(
+ "max_tokens", 50,
+ "temperature", 0.1,
+ "frequency_penalty", 0,
+ "presence_penalty", 0));
  private final Function<String, InputStream> fileLoader =
filename -> Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream(filename));
@@ -106,7 +104,7 @@ void setup(WireMockRuntimeInfo server) {
.forDeploymentByScenario("orchestration")
.withResourceGroup("my-resource-group");
client = new OrchestrationClient(deployment);
- config = new OrchestrationModuleConfig().withLlmConfig(LLM_CONFIG);
+ config = new OrchestrationModuleConfig().withLlmConfig(CUSTOM_GPT_35);
prompt = new OrchestrationPrompt("Hello World! Why is this phrase so famous?");
}
@@ -146,6 +144,7 @@ void testTemplating() throws IOException {
.isEqualTo("Reply with 'Orchestration Service is working!' in German");
assertThat(response.getModuleResults().getTemplating().get(0).getRole()).isEqualTo("user");
var llm = (LLMModuleResultSynchronous) response.getModuleResults().getLlm();
+ assertThat(llm).isNotNull();
assertThat(llm.getId()).isEqualTo("chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj");
assertThat(llm.getObject()).isEqualTo("chat.completion");
assertThat(llm.getCreated()).isEqualTo(1721224505);
@@ -315,6 +314,7 @@ void maskingPseudonymization() throws IOException {
assertThat(response).isNotNull();
GenericModuleResult inputMasking = response.getModuleResults().getInputMasking();
+ assertThat(inputMasking).isNotNull();
assertThat(inputMasking.getMessage()).isEqualTo("Input to LLM is masked successfully.");
assertThat(inputMasking.getData()).isNotNull();
assertThat(result.getContent()).contains("Hi Mallory");
diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java
index ccf03d943..9d4f99435 100644
--- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java
+++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java
@@ -1,5 +1,7 @@
package com.sap.ai.sdk.app.controllers;
+import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.GPT_35_TURBO;
+
import com.sap.ai.sdk.orchestration.OrchestrationChatResponse;
import com.sap.ai.sdk.orchestration.OrchestrationClient;
import com.sap.ai.sdk.orchestration.OrchestrationModuleConfig;
@@ -13,7 +15,6 @@
import com.sap.ai.sdk.orchestration.client.model.DPIEntityConfig;
import com.sap.ai.sdk.orchestration.client.model.FilteringModuleConfig;
import com.sap.ai.sdk.orchestration.client.model.InputFilteringConfig;
-import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig;
import com.sap.ai.sdk.orchestration.client.model.MaskingModuleConfig;
import com.sap.ai.sdk.orchestration.client.model.OutputFilteringConfig;
import com.sap.ai.sdk.orchestration.client.model.Template;
@@ -30,12 +31,8 @@
@RestController
@RequestMapping("/orchestration")
class OrchestrationController {
- static final LLMModuleConfig LLM_CONFIG =
- new LLMModuleConfig().modelName("gpt-35-turbo").modelParams(Map.of());
-
private final OrchestrationClient client = new OrchestrationClient();
- private final OrchestrationModuleConfig config =
- new OrchestrationModuleConfig().withLlmConfig(LLM_CONFIG);
+ OrchestrationModuleConfig config = new OrchestrationModuleConfig().withLlmConfig(GPT_35_TURBO);
/**
* Chat request to OpenAI through the Orchestration service with a simple prompt.
@@ -170,7 +167,7 @@ public OrchestrationChatResponse maskingAnonymization() {
/**
* Let the orchestration service a response to a hypothetical user who provided feedback on the AI
- * SDK. Pseydonymize the user's name and location to protect their privacy.
+ * SDK. Pseudonymize the user's name and location to protect their privacy.
*
* @return the result object
*/
diff --git a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java
index dc2d27a79..4cdc40103 100644
--- a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java
+++ b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java
@@ -10,10 +10,12 @@
import com.sap.ai.sdk.orchestration.client.model.LLMModuleResultSynchronous;
import java.util.List;
import java.util.Map;
+import lombok.extern.slf4j.Slf4j;
import org.assertj.core.api.InstanceOfAssertFactories;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
+@Slf4j
class OrchestrationTest {
OrchestrationController controller;
@@ -32,6 +34,9 @@ void testCompletion() {
@Test
void testTemplate() {
+ assertThat(controller.config.getLlmConfig()).isNotNull();
+ final var modelName = controller.config.getLlmConfig().getModelName();
+
final var response = controller.template();
final var result = response.getOriginalResponse();
@@ -43,7 +48,7 @@ void testTemplate() {
assertThat(llm.getId()).isNotEmpty();
assertThat(llm.getObject()).isEqualTo("chat.completion");
assertThat(llm.getCreated()).isGreaterThan(1);
- assertThat(llm.getModel()).isEqualTo(OrchestrationController.LLM_CONFIG.getModelName());
+ assertThat(llm.getModel()).isEqualTo(modelName);
var choices = llm.getChoices();
assertThat(choices.get(0).getIndex()).isZero();
assertThat(choices.get(0).getMessage().getContent()).isNotEmpty();
@@ -57,8 +62,7 @@ void testTemplate() {
var orchestrationResult = ((LLMModuleResultSynchronous) result.getOrchestrationResult());
assertThat(orchestrationResult.getObject()).isEqualTo("chat.completion");
assertThat(orchestrationResult.getCreated()).isGreaterThan(1);
- assertThat(orchestrationResult.getModel())
- .isEqualTo(OrchestrationController.LLM_CONFIG.getModelName());
+ assertThat(orchestrationResult.getModel()).isEqualTo(modelName);
choices = orchestrationResult.getChoices();
assertThat(choices.get(0).getIndex()).isZero();
assertThat(choices.get(0).getMessage().getContent()).isNotEmpty();
diff --git a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/ScenarioTest.java b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/ScenarioTest.java
index a13bc6cb4..6d7a417b7 100644
--- a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/ScenarioTest.java
+++ b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/ScenarioTest.java
@@ -11,12 +11,12 @@
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
-public class ScenarioTest {
+class ScenarioTest {
@Test
@DisplayName("Declared OpenAI models must match AI Core's available OpenAI models")
@SneakyThrows
- public void openAiModelAvailability() {
+ void openAiModelAvailability() {
// Gather AI Core's list of available OpenAI models
final var aiModelList = new ScenarioController().getModels().getResources();