diff --git a/docs/guides/ORCHESTRATION_CHAT_COMPLETION.md b/docs/guides/ORCHESTRATION_CHAT_COMPLETION.md index 31adb89a5..87d6d3f18 100644 --- a/docs/guides/ORCHESTRATION_CHAT_COMPLETION.md +++ b/docs/guides/ORCHESTRATION_CHAT_COMPLETION.md @@ -77,7 +77,7 @@ To use the Orchestration service, create a client and a configuration object: var client = new OrchestrationClient(); var config = new OrchestrationModuleConfig() - .withLlmConfig(LLMModuleConfig.create().modelName("gpt-35-turbo").modelParams(Map.of())); + .withLlmConfig(OrchestrationAiModel.GPT_4O); ``` Please also refer to [our sample code](../../sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java) for this and all following code examples. @@ -214,16 +214,16 @@ In this example, the input will be masked before the call to the LLM. Note that ### Set model parameters -Change your LLM module configuration to add model parameters: +Change your LLM configuration to add model parameters: ```java -var llmConfig = - LLMModuleConfig.create() - .modelName("gpt-35-turbo") - .modelParams( +OrchestrationAiModel customGPT4O = + OrchestrationAiModel.GPT_4O + .withModelParams( Map.of( "max_tokens", 50, "temperature", 0.1, "frequency_penalty", 0, - "presence_penalty", 0)); + "presence_penalty", 0)) + .withModelVersion("2024-05-13"); ``` diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationAiModel.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationAiModel.java new file mode 100644 index 000000000..7f806ece7 --- /dev/null +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationAiModel.java @@ -0,0 +1,120 @@ +package com.sap.ai.sdk.orchestration; + +import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig; +import java.util.Map; +import javax.annotation.Nonnull; +import lombok.AllArgsConstructor; +import lombok.Value; +import lombok.With; + +/** Large language models available in Orchestration. 
*/ +@Value +@With +@AllArgsConstructor +public class OrchestrationAiModel { + /** The name of the model */ + String modelName; + + /** + * Optional parameters on this model. + * + *
<pre>{@code
+   * Map.of(
+   *     "max_tokens", 50,
+   *     "temperature", 0.1,
+   *     "frequency_penalty", 0,
+   *     "presence_penalty", 0)
+   * }</pre>
+ */ + Map modelParams; + + /** The version of the model, defaults to "latest". */ + String modelVersion; + + /** IBM Granite 13B chat completions model */ + public static final OrchestrationAiModel IBM_GRANITE_13B_CHAT = + new OrchestrationAiModel("ibm--granite-13b-chat"); + + /** MistralAI Mistral Large Instruct model */ + public static final OrchestrationAiModel MISTRAL_LARGE_INSTRUCT = + new OrchestrationAiModel("mistralai--mistral-large-instruct"); + + /** MistralAI Mixtral 8x7B Instruct v01 model */ + public static final OrchestrationAiModel MIXTRAL_8X7B_INSTRUCT_V01 = + new OrchestrationAiModel("mistralai--mixtral-8x7b-instruct-v01"); + + /** Meta Llama3 70B Instruct model */ + public static final OrchestrationAiModel LLAMA3_70B_INSTRUCT = + new OrchestrationAiModel("meta--llama3-70b-instruct"); + + /** Meta Llama3.1 70B Instruct model */ + public static final OrchestrationAiModel LLAMA3_1_70B_INSTRUCT = + new OrchestrationAiModel("meta--llama3.1-70b-instruct"); + + /** Anthropic Claude 3 Sonnet model */ + public static final OrchestrationAiModel CLAUDE_3_SONNET = + new OrchestrationAiModel("anthropic--claude-3-sonnet"); + + /** Anthropic Claude 3 Haiku model */ + public static final OrchestrationAiModel CLAUDE_3_HAIKU = + new OrchestrationAiModel("anthropic--claude-3-haiku"); + + /** Anthropic Claude 3 Opus model */ + public static final OrchestrationAiModel CLAUDE_3_OPUS = + new OrchestrationAiModel("anthropic--claude-3-opus"); + + /** Anthropic Claude 3.5 Sonnet model */ + public static final OrchestrationAiModel CLAUDE_3_5_SONNET = + new OrchestrationAiModel("anthropic--claude-3.5-sonnet"); + + /** Amazon Titan Text Lite model */ + public static final OrchestrationAiModel TITAN_TEXT_LITE = + new OrchestrationAiModel("amazon--titan-text-lite"); + + /** Amazon Titan Text Express model */ + public static final OrchestrationAiModel TITAN_TEXT_EXPRESS = + new OrchestrationAiModel("amazon--titan-text-express"); + + /** Azure OpenAI GPT-3.5 Turbo chat 
completions model */ + public static final OrchestrationAiModel GPT_35_TURBO = new OrchestrationAiModel("gpt-35-turbo"); + + /** Azure OpenAI GPT-3.5 Turbo chat completions model */ + public static final OrchestrationAiModel GPT_35_TURBO_16K = + new OrchestrationAiModel("gpt-35-turbo-16k"); + + /** Azure OpenAI GPT-4 chat completions model */ + public static final OrchestrationAiModel GPT_4 = new OrchestrationAiModel("gpt-4"); + + /** Azure OpenAI GPT-4-32k chat completions model */ + public static final OrchestrationAiModel GPT_4_32K = new OrchestrationAiModel("gpt-4-32k"); + + /** Azure OpenAI GPT-4o chat completions model */ + public static final OrchestrationAiModel GPT_4O = new OrchestrationAiModel("gpt-4o"); + + /** Azure OpenAI GPT-4o-mini chat completions model */ + public static final OrchestrationAiModel GPT_4O_MINI = new OrchestrationAiModel("gpt-4o-mini"); + + /** Google Cloud Platform Gemini 1.0 Pro model */ + public static final OrchestrationAiModel GEMINI_1_0_PRO = + new OrchestrationAiModel("gemini-1.0-pro"); + + /** Google Cloud Platform Gemini 1.5 Pro model */ + public static final OrchestrationAiModel GEMINI_1_5_PRO = + new OrchestrationAiModel("gemini-1.5-pro"); + + /** Google Cloud Platform Gemini 1.5 Flash model */ + public static final OrchestrationAiModel GEMINI_1_5_FLASH = + new OrchestrationAiModel("gemini-1.5-flash"); + + OrchestrationAiModel(@Nonnull final String modelName) { + this(modelName, Map.of(), "latest"); + } + + @Nonnull + LLMModuleConfig createConfig() { + return new LLMModuleConfig() + .modelName(modelName) + .modelParams(modelParams) + .modelVersion(modelVersion); + } +} diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationModuleConfig.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationModuleConfig.java index ee8771d49..b0f60e511 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationModuleConfig.java +++ 
b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationModuleConfig.java @@ -4,12 +4,14 @@ import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig; import com.sap.ai.sdk.orchestration.client.model.MaskingModuleConfig; import com.sap.ai.sdk.orchestration.client.model.TemplatingModuleConfig; +import javax.annotation.Nonnull; import javax.annotation.Nullable; import lombok.AccessLevel; import lombok.AllArgsConstructor; import lombok.NoArgsConstructor; import lombok.Value; import lombok.With; +import lombok.experimental.Tolerate; /** * Represents the configuration for the orchestration service. Allows for configuring the different @@ -48,4 +50,16 @@ public class OrchestrationModuleConfig { /** A content filter to filter the prompt. */ @Nullable FilteringModuleConfig filteringConfig; + + /** + * Creates a new configuration with the given LLM configuration. + * + * @param aiModel The LLM configuration to use. + * @return A new configuration with the given LLM configuration. 
+ */ + @Tolerate + @Nonnull + public OrchestrationModuleConfig withLlmConfig(@Nonnull final OrchestrationAiModel aiModel) { + return withLlmConfig(aiModel.createConfig()); + } } diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/ConfigToRequestTransformerTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/ConfigToRequestTransformerTest.java index 141b882f1..19ba9166d 100644 --- a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/ConfigToRequestTransformerTest.java +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/ConfigToRequestTransformerTest.java @@ -1,6 +1,6 @@ package com.sap.ai.sdk.orchestration; -import static com.sap.ai.sdk.orchestration.OrchestrationUnitTest.LLM_CONFIG; +import static com.sap.ai.sdk.orchestration.OrchestrationUnitTest.CUSTOM_GPT_35; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -71,7 +71,7 @@ void testMessagesHistory() { var prompt = new OrchestrationPrompt("bar").messageHistory(List.of(systemMessage)); var actual = ConfigToRequestTransformer.toCompletionPostRequest( - prompt, new OrchestrationModuleConfig().withLlmConfig(LLM_CONFIG)); + prompt, new OrchestrationModuleConfig().withLlmConfig(CUSTOM_GPT_35)); assertThat(actual.getMessagesHistory()).containsExactly(systemMessage); } diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java index 2af45a429..298408b08 100644 --- a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java @@ -16,6 +16,7 @@ import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo; import static com.github.tomakehurst.wiremock.client.WireMock.verify; 
+import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.GPT_35_TURBO_16K; import static com.sap.ai.sdk.orchestration.client.model.AzureThreshold.NUMBER_0; import static com.sap.ai.sdk.orchestration.client.model.AzureThreshold.NUMBER_4; import static org.apache.hc.core5.http.HttpStatus.SC_BAD_REQUEST; @@ -39,7 +40,6 @@ import com.sap.ai.sdk.orchestration.client.model.FilteringModuleConfig; import com.sap.ai.sdk.orchestration.client.model.GenericModuleResult; import com.sap.ai.sdk.orchestration.client.model.InputFilteringConfig; -import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig; import com.sap.ai.sdk.orchestration.client.model.LLMModuleResultSynchronous; import com.sap.ai.sdk.orchestration.client.model.MaskingModuleConfig; import com.sap.ai.sdk.orchestration.client.model.OutputFilteringConfig; @@ -62,15 +62,13 @@ */ @WireMockTest class OrchestrationUnitTest { - static final LLMModuleConfig LLM_CONFIG = - new LLMModuleConfig() - .modelName("gpt-35-turbo-16k") - .modelParams( - Map.of( - "max_tokens", 50, - "temperature", 0.1, - "frequency_penalty", 0, - "presence_penalty", 0)); + static final OrchestrationAiModel CUSTOM_GPT_35 = + GPT_35_TURBO_16K.withModelParams( + Map.of( + "max_tokens", 50, + "temperature", 0.1, + "frequency_penalty", 0, + "presence_penalty", 0)); private final Function fileLoader = filename -> Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream(filename)); @@ -106,7 +104,7 @@ void setup(WireMockRuntimeInfo server) { .forDeploymentByScenario("orchestration") .withResourceGroup("my-resource-group"); client = new OrchestrationClient(deployment); - config = new OrchestrationModuleConfig().withLlmConfig(LLM_CONFIG); + config = new OrchestrationModuleConfig().withLlmConfig(CUSTOM_GPT_35); prompt = new OrchestrationPrompt("Hello World! Why is this phrase so famous?"); } @@ -146,6 +144,7 @@ void testTemplating() throws IOException { .isEqualTo("Reply with 'Orchestration Service is working!' 
in German"); assertThat(response.getModuleResults().getTemplating().get(0).getRole()).isEqualTo("user"); var llm = (LLMModuleResultSynchronous) response.getModuleResults().getLlm(); + assertThat(llm).isNotNull(); assertThat(llm.getId()).isEqualTo("chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj"); assertThat(llm.getObject()).isEqualTo("chat.completion"); assertThat(llm.getCreated()).isEqualTo(1721224505); @@ -315,6 +314,7 @@ void maskingPseudonymization() throws IOException { assertThat(response).isNotNull(); GenericModuleResult inputMasking = response.getModuleResults().getInputMasking(); + assertThat(inputMasking).isNotNull(); assertThat(inputMasking.getMessage()).isEqualTo("Input to LLM is masked successfully."); assertThat(inputMasking.getData()).isNotNull(); assertThat(result.getContent()).contains("Hi Mallory"); diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java index ccf03d943..9d4f99435 100644 --- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java +++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java @@ -1,5 +1,7 @@ package com.sap.ai.sdk.app.controllers; +import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.GPT_35_TURBO; + import com.sap.ai.sdk.orchestration.OrchestrationChatResponse; import com.sap.ai.sdk.orchestration.OrchestrationClient; import com.sap.ai.sdk.orchestration.OrchestrationModuleConfig; @@ -13,7 +15,6 @@ import com.sap.ai.sdk.orchestration.client.model.DPIEntityConfig; import com.sap.ai.sdk.orchestration.client.model.FilteringModuleConfig; import com.sap.ai.sdk.orchestration.client.model.InputFilteringConfig; -import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig; import com.sap.ai.sdk.orchestration.client.model.MaskingModuleConfig; import 
com.sap.ai.sdk.orchestration.client.model.OutputFilteringConfig; import com.sap.ai.sdk.orchestration.client.model.Template; @@ -30,12 +31,8 @@ @RestController @RequestMapping("/orchestration") class OrchestrationController { - static final LLMModuleConfig LLM_CONFIG = - new LLMModuleConfig().modelName("gpt-35-turbo").modelParams(Map.of()); - private final OrchestrationClient client = new OrchestrationClient(); - private final OrchestrationModuleConfig config = - new OrchestrationModuleConfig().withLlmConfig(LLM_CONFIG); + OrchestrationModuleConfig config = new OrchestrationModuleConfig().withLlmConfig(GPT_35_TURBO); /** * Chat request to OpenAI through the Orchestration service with a simple prompt. @@ -170,7 +167,7 @@ public OrchestrationChatResponse maskingAnonymization() { /** * Let the orchestration service a response to a hypothetical user who provided feedback on the AI - * SDK. Pseydonymize the user's name and location to protect their privacy. + * SDK. Pseudonymize the user's name and location to protect their privacy. 
* * @return the result object */ diff --git a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java index dc2d27a79..4cdc40103 100644 --- a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java +++ b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java @@ -10,10 +10,12 @@ import com.sap.ai.sdk.orchestration.client.model.LLMModuleResultSynchronous; import java.util.List; import java.util.Map; +import lombok.extern.slf4j.Slf4j; import org.assertj.core.api.InstanceOfAssertFactories; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +@Slf4j class OrchestrationTest { OrchestrationController controller; @@ -32,6 +34,9 @@ void testCompletion() { @Test void testTemplate() { + assertThat(controller.config.getLlmConfig()).isNotNull(); + final var modelName = controller.config.getLlmConfig().getModelName(); + final var response = controller.template(); final var result = response.getOriginalResponse(); @@ -43,7 +48,7 @@ void testTemplate() { assertThat(llm.getId()).isNotEmpty(); assertThat(llm.getObject()).isEqualTo("chat.completion"); assertThat(llm.getCreated()).isGreaterThan(1); - assertThat(llm.getModel()).isEqualTo(OrchestrationController.LLM_CONFIG.getModelName()); + assertThat(llm.getModel()).isEqualTo(modelName); var choices = llm.getChoices(); assertThat(choices.get(0).getIndex()).isZero(); assertThat(choices.get(0).getMessage().getContent()).isNotEmpty(); @@ -57,8 +62,7 @@ void testTemplate() { var orchestrationResult = ((LLMModuleResultSynchronous) result.getOrchestrationResult()); assertThat(orchestrationResult.getObject()).isEqualTo("chat.completion"); assertThat(orchestrationResult.getCreated()).isGreaterThan(1); - assertThat(orchestrationResult.getModel()) - .isEqualTo(OrchestrationController.LLM_CONFIG.getModelName()); + 
assertThat(orchestrationResult.getModel()).isEqualTo(modelName); choices = orchestrationResult.getChoices(); assertThat(choices.get(0).getIndex()).isZero(); assertThat(choices.get(0).getMessage().getContent()).isNotEmpty(); diff --git a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/ScenarioTest.java b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/ScenarioTest.java index a13bc6cb4..6d7a417b7 100644 --- a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/ScenarioTest.java +++ b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/ScenarioTest.java @@ -11,12 +11,12 @@ import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; -public class ScenarioTest { +class ScenarioTest { @Test @DisplayName("Declared OpenAI models must match AI Core's available OpenAI models") @SneakyThrows - public void openAiModelAvailability() { + void openAiModelAvailability() { // Gather AI Core's list of available OpenAI models final var aiModelList = new ScenarioController().getModels().getResources();