
Added Orchestration LLM Config Convenience #152

Merged: 23 commits, Nov 19, 2024 (showing changes from 15 commits)
14 changes: 7 additions & 7 deletions docs/guides/ORCHESTRATION_CHAT_COMPLETION.md
@@ -77,7 +77,7 @@ To use the Orchestration service, create a client and a configuration object:
var client = new OrchestrationClient();

var config = new OrchestrationModuleConfig()
-    .withLlmConfig(LLMModuleConfig.create().modelName("gpt-35-turbo").modelParams(Map.of()));
+    .withLlmConfig(OrchestrationAiModel.GPT_4O);
```

Please also refer to [our sample code](../../sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java) for this and all following code examples.
@@ -214,16 +214,16 @@ In this example, the input will be masked before the call to the LLM. Note that

### Set model parameters

-Change your LLM module configuration to add model parameters:
+Change your LLM configuration to add model parameters:

```java
-var llmConfig =
-    LLMModuleConfig.create()
-        .modelName("gpt-35-turbo")
-        .modelParams(
+OrchestrationAiModel customGPT4O =
+    OrchestrationAiModel.GPT_4O
+        .withModelParams(
             Map.of(
                 "max_tokens", 50,
                 "temperature", 0.1,
                 "frequency_penalty", 0,
-                "presence_penalty", 0));
+                "presence_penalty", 0))
+        .withModelVersion("2024-05-13");
```

> **Review comment (Contributor), on lines +220 to +228 (minor/question):** Could/should we shorten the method names?
>
> ```diff
>   OrchestrationAiModel.GPT_4O
> -   .withModelParams(...)
> -   .withModelVersion("2024-05-13");
> +   .withParams(...)
> +   .withVersion("2024-05-13");
> ```
>
> **Reply (Contributor):** We can discuss here: #170
New file: `OrchestrationAiModel.java`
@@ -0,0 +1,135 @@
package com.sap.ai.sdk.orchestration;

import com.sap.ai.sdk.core.AiModel;
import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig;
import java.util.Map;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.With;

/** Large language models available in Orchestration. */
@With
@AllArgsConstructor(access = AccessLevel.PRIVATE)
public class OrchestrationAiModel implements AiModel {
/** The name of the model. */
private String modelName;

/** The version of the model, defaults to "latest". */
private String modelVersion = "latest";

/**
* Optional parameters on this model.
*
* <pre>{@code
* Map.of(
* "max_tokens", 50,
* "temperature", 0.1,
* "frequency_penalty", 0,
* "presence_penalty", 0)
* }</pre>
*/
private Map<String, Object> modelParams;

/** IBM Granite 13B chat completions model */
public static final OrchestrationAiModel IBM_GRANITE_13B_CHAT =
new OrchestrationAiModel("ibm--granite-13b-chat");

/** MistralAI Mistral Large Instruct model */
public static final OrchestrationAiModel MISTRAL_LARGE_INSTRUCT =
new OrchestrationAiModel("mistralai--mistral-large-instruct");

/** MistralAI Mixtral 8x7B Instruct v01 model */
public static final OrchestrationAiModel MIXTRAL_8X7B_INSTRUCT_V01 =
new OrchestrationAiModel("mistralai--mixtral-8x7b-instruct-v01");

/** Meta Llama3 70B Instruct model */
public static final OrchestrationAiModel LLAMA3_70B_INSTRUCT =
new OrchestrationAiModel("meta--llama3-70b-instruct");

/** Meta Llama3.1 70B Instruct model */
public static final OrchestrationAiModel LLAMA3_1_70B_INSTRUCT =
new OrchestrationAiModel("meta--llama3.1-70b-instruct");

/** Anthropic Claude 3 Sonnet model */
public static final OrchestrationAiModel CLAUDE_3_SONNET =
new OrchestrationAiModel("anthropic--claude-3-sonnet");

/** Anthropic Claude 3 Haiku model */
public static final OrchestrationAiModel CLAUDE_3_HAIKU =
new OrchestrationAiModel("anthropic--claude-3-haiku");

/** Anthropic Claude 3 Opus model */
public static final OrchestrationAiModel CLAUDE_3_OPUS =
new OrchestrationAiModel("anthropic--claude-3-opus");

/** Anthropic Claude 3.5 Sonnet model */
public static final OrchestrationAiModel CLAUDE_3_5_SONNET =
new OrchestrationAiModel("anthropic--claude-3.5-sonnet");

/** Amazon Titan Text Lite model */
public static final OrchestrationAiModel TITAN_TEXT_LITE =
new OrchestrationAiModel("amazon--titan-text-lite");

/** Amazon Titan Text Express model */
public static final OrchestrationAiModel TITAN_TEXT_EXPRESS =
new OrchestrationAiModel("amazon--titan-text-express");

/** Azure OpenAI GPT-3.5 Turbo chat completions model */
public static final OrchestrationAiModel GPT_35_TURBO = new OrchestrationAiModel("gpt-35-turbo");

/** Azure OpenAI GPT-3.5 Turbo chat completions model */
public static final OrchestrationAiModel GPT_35_TURBO_16K =
new OrchestrationAiModel("gpt-35-turbo-16k");

/** Azure OpenAI GPT-4 chat completions model */
public static final OrchestrationAiModel GPT_4 = new OrchestrationAiModel("gpt-4");

/** Azure OpenAI GPT-4-32k chat completions model */
public static final OrchestrationAiModel GPT_4_32K = new OrchestrationAiModel("gpt-4-32k");

/** Azure OpenAI GPT-4o chat completions model */
public static final OrchestrationAiModel GPT_4O = new OrchestrationAiModel("gpt-4o");

/** Azure OpenAI GPT-4o-mini chat completions model */
public static final OrchestrationAiModel GPT_4O_MINI = new OrchestrationAiModel("gpt-4o-mini");

/** Google Cloud Platform Gemini 1.0 Pro model */
public static final OrchestrationAiModel GEMINI_1_0_PRO =
new OrchestrationAiModel("gemini-1.0-pro");

/** Google Cloud Platform Gemini 1.5 Pro model */
public static final OrchestrationAiModel GEMINI_1_5_PRO =
new OrchestrationAiModel("gemini-1.5-pro");

/** Google Cloud Platform Gemini 1.5 Flash model */
public static final OrchestrationAiModel GEMINI_1_5_FLASH =
new OrchestrationAiModel("gemini-1.5-flash");

OrchestrationAiModel(@Nonnull final String modelName) {
this.modelName = modelName;
}

@Nonnull
LLMModuleConfig createConfig() {
return new LLMModuleConfig()
.modelName(modelName)
.modelParams(modelParams)
.modelVersion(modelVersion);
}

/** {@inheritDoc} */
@Nonnull
@Override
public String name() {
return modelName;
}

/** {@inheritDoc} */
@Nullable
@Override
public String version() {
return modelVersion;
}
}
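Lombok's `@With` on this class generates copy-on-write setters, which is what makes the shared constants such as `GPT_4O` safe to customize: `withModelParams` and `withModelVersion` return a modified copy and leave the constant untouched. A minimal self-contained sketch of that generated behavior (plain Java, hypothetical `Model` class, not the SDK API):

```java
import java.util.Map;

public class WithDemo {
    // Hypothetical stand-in for what Lombok's @With generates:
    // each with* method returns a new instance instead of mutating this one.
    static final class Model {
        final String name;
        final String version;
        final Map<String, Object> params;

        Model(String name, String version, Map<String, Object> params) {
            this.name = name;
            this.version = version;
            this.params = params;
        }

        Model withParams(Map<String, Object> params) {
            return new Model(name, version, params);
        }

        Model withVersion(String version) {
            return new Model(name, version, params);
        }
    }

    public static void main(String[] args) {
        Model gpt4o = new Model("gpt-4o", "latest", Map.of());
        Model custom = gpt4o
            .withParams(Map.of("temperature", 0.1))
            .withVersion("2024-05-13");

        System.out.println(gpt4o.version);   // latest: the shared constant is unchanged
        System.out.println(custom.version);  // 2024-05-13
    }
}
```

This also explains why the review thread on the docs change debates only naming (`withModelParams` vs. `withParams`), not semantics: either way the constants act as immutable templates.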
File: `OrchestrationModuleConfig.java`
@@ -4,6 +4,7 @@
import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig;
import com.sap.ai.sdk.orchestration.client.model.MaskingModuleConfig;
import com.sap.ai.sdk.orchestration.client.model.TemplatingModuleConfig;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
@@ -28,7 +29,6 @@
* </ul>
*/
@Value
-@With
@AllArgsConstructor(access = AccessLevel.PRIVATE)
@NoArgsConstructor(force = true)
public class OrchestrationModuleConfig {
@@ -41,11 +41,23 @@ public class OrchestrationModuleConfig {
* A template to be populated with input parameters. Upon request execution, this template will be
* enhanced with any messages and parameter values from {@link OrchestrationPrompt}.
*/
-  @Nullable TemplatingModuleConfig templateConfig;
+  @With @Nullable TemplatingModuleConfig templateConfig;

/** A masking configuration to pseudonymous or anonymize sensitive data in the input. */
-  @Nullable MaskingModuleConfig maskingConfig;
+  @With @Nullable MaskingModuleConfig maskingConfig;

/** A content filter to filter the prompt. */
-  @Nullable FilteringModuleConfig filteringConfig;
+  @With @Nullable FilteringModuleConfig filteringConfig;

/**
* Creates a new configuration with the given LLM configuration.
*
* @param aiModel The LLM configuration to use.
* @return A new configuration with the given LLM configuration.
*/
@Nonnull
public OrchestrationModuleConfig withLlmConfig(@Nonnull final OrchestrationAiModel aiModel) {
return new OrchestrationModuleConfig(
aiModel.createConfig(), templateConfig, maskingConfig, filteringConfig);
}
}
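The hand-written `withLlmConfig` above exists because the Lombok-generated wither (the class-level `@With` removed in this diff) could only accept an `LLMModuleConfig`; the convenience method instead adapts the user-facing `OrchestrationAiModel` to the wire-level config on the way in. A self-contained sketch of that adapter pattern (all type and method names hypothetical, not the SDK API):

```java
import java.util.Map;

public class ConfigDemo {
    // Wire-level config, as sent to the service (hypothetical stand-in).
    record LlmConfig(String modelName, String modelVersion, Map<String, Object> params) {}

    // User-facing model handle; knows how to convert itself to the wire format.
    record AiModel(String name, String version, Map<String, Object> params) {
        LlmConfig toConfig() { return new LlmConfig(name, version, params); }
    }

    // Immutable module config with a hand-written "wither" that adapts the type,
    // keeping the user-facing API decoupled from the generated client model.
    record ModuleConfig(LlmConfig llmConfig, String templateConfig) {
        ModuleConfig withLlmConfig(AiModel model) {
            return new ModuleConfig(model.toConfig(), templateConfig);
        }
    }

    public static void main(String[] args) {
        var base = new ModuleConfig(null, "my-template");
        var config = base.withLlmConfig(new AiModel("gpt-4o", "latest", Map.of()));

        System.out.println(config.llmConfig().modelName()); // gpt-4o
        System.out.println(config.templateConfig());        // my-template
    }
}
```

The design choice: a custom wither can change the parameter type, which a generated one cannot, so users never touch `LLMModuleConfig` directly.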
@@ -1,6 +1,6 @@
package com.sap.ai.sdk.orchestration;

-import static com.sap.ai.sdk.orchestration.OrchestrationUnitTest.LLM_CONFIG;
+import static com.sap.ai.sdk.orchestration.OrchestrationUnitTest.CUSTOM_GPT_35;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

@@ -71,7 +71,7 @@ void testMessagesHistory() {
var prompt = new OrchestrationPrompt("bar").messageHistory(List.of(systemMessage));
var actual =
ConfigToRequestTransformer.toCompletionPostRequest(
-            prompt, new OrchestrationModuleConfig().withLlmConfig(LLM_CONFIG));
+            prompt, new OrchestrationModuleConfig().withLlmConfig(CUSTOM_GPT_35));

assertThat(actual.getMessagesHistory()).containsExactly(systemMessage);
}
@@ -16,6 +16,7 @@
import static com.github.tomakehurst.wiremock.client.WireMock.stubFor;
import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo;
import static com.github.tomakehurst.wiremock.client.WireMock.verify;
import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.GPT_35_TURBO_16K;
import static com.sap.ai.sdk.orchestration.client.model.AzureThreshold.NUMBER_0;
import static com.sap.ai.sdk.orchestration.client.model.AzureThreshold.NUMBER_4;
import static org.apache.hc.core5.http.HttpStatus.SC_BAD_REQUEST;
@@ -39,7 +40,6 @@
import com.sap.ai.sdk.orchestration.client.model.FilteringModuleConfig;
import com.sap.ai.sdk.orchestration.client.model.GenericModuleResult;
import com.sap.ai.sdk.orchestration.client.model.InputFilteringConfig;
import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig;
import com.sap.ai.sdk.orchestration.client.model.LLMModuleResultSynchronous;
import com.sap.ai.sdk.orchestration.client.model.MaskingModuleConfig;
import com.sap.ai.sdk.orchestration.client.model.OutputFilteringConfig;
@@ -62,10 +62,9 @@
*/
@WireMockTest
class OrchestrationUnitTest {
-  static final LLMModuleConfig LLM_CONFIG =
-      new LLMModuleConfig()
-          .modelName("gpt-35-turbo-16k")
-          .modelParams(
+  static final OrchestrationAiModel CUSTOM_GPT_35 =
+      GPT_35_TURBO_16K
+          .withModelParams(
Map.of(
"max_tokens", 50,
"temperature", 0.1,
@@ -106,7 +105,7 @@ void setup(WireMockRuntimeInfo server) {
.forDeploymentByScenario("orchestration")
.withResourceGroup("my-resource-group");
client = new OrchestrationClient(deployment);
-    config = new OrchestrationModuleConfig().withLlmConfig(LLM_CONFIG);
+    config = new OrchestrationModuleConfig().withLlmConfig(CUSTOM_GPT_35);
prompt = new OrchestrationPrompt("Hello World! Why is this phrase so famous?");
}

@@ -146,6 +145,7 @@ void testTemplating() throws IOException {
.isEqualTo("Reply with 'Orchestration Service is working!' in German");
assertThat(response.getModuleResults().getTemplating().get(0).getRole()).isEqualTo("user");
var llm = (LLMModuleResultSynchronous) response.getModuleResults().getLlm();
assertThat(llm).isNotNull();
assertThat(llm.getId()).isEqualTo("chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj");
assertThat(llm.getObject()).isEqualTo("chat.completion");
assertThat(llm.getCreated()).isEqualTo(1721224505);
@@ -315,6 +315,7 @@ void maskingPseudonymization() throws IOException {

assertThat(response).isNotNull();
GenericModuleResult inputMasking = response.getModuleResults().getInputMasking();
assertThat(inputMasking).isNotNull();
assertThat(inputMasking.getMessage()).isEqualTo("Input to LLM is masked successfully.");
assertThat(inputMasking.getData()).isNotNull();
assertThat(result.getContent()).contains("Hi Mallory");
@@ -1,5 +1,7 @@
package com.sap.ai.sdk.app.controllers;

import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.GPT_35_TURBO;

import com.sap.ai.sdk.orchestration.OrchestrationChatResponse;
import com.sap.ai.sdk.orchestration.OrchestrationClient;
import com.sap.ai.sdk.orchestration.OrchestrationModuleConfig;
@@ -13,7 +15,6 @@
import com.sap.ai.sdk.orchestration.client.model.DPIEntityConfig;
import com.sap.ai.sdk.orchestration.client.model.FilteringModuleConfig;
import com.sap.ai.sdk.orchestration.client.model.InputFilteringConfig;
import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig;
import com.sap.ai.sdk.orchestration.client.model.MaskingModuleConfig;
import com.sap.ai.sdk.orchestration.client.model.OutputFilteringConfig;
import com.sap.ai.sdk.orchestration.client.model.Template;
@@ -30,12 +31,8 @@
@RestController
@RequestMapping("/orchestration")
class OrchestrationController {
-  static final LLMModuleConfig LLM_CONFIG =
-      new LLMModuleConfig().modelName("gpt-35-turbo").modelParams(Map.of());
-
   private final OrchestrationClient client = new OrchestrationClient();
-  private final OrchestrationModuleConfig config =
-      new OrchestrationModuleConfig().withLlmConfig(LLM_CONFIG);
+  OrchestrationModuleConfig config = new OrchestrationModuleConfig().withLlmConfig(GPT_35_TURBO);

/**
* Chat request to OpenAI through the Orchestration service with a simple prompt.
@@ -170,7 +167,7 @@ public OrchestrationChatResponse maskingAnonymization() {

/**
 * Let the orchestration service generate a response to a hypothetical user who provided feedback on the AI
-* SDK. Pseydonymize the user's name and location to protect their privacy.
+* SDK. Pseudonymize the user's name and location to protect their privacy.
*
* @return the result object
*/
@@ -10,10 +10,12 @@
import com.sap.ai.sdk.orchestration.client.model.LLMModuleResultSynchronous;
import java.util.List;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
import org.assertj.core.api.InstanceOfAssertFactories;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

@Slf4j
class OrchestrationTest {
OrchestrationController controller;

@@ -32,6 +34,9 @@ void testCompletion() {

@Test
void testTemplate() {
assertThat(controller.config.getLlmConfig()).isNotNull();
final var modelName = controller.config.getLlmConfig().getModelName();

final var response = controller.template();
final var result = response.getOriginalResponse();

@@ -43,7 +48,7 @@
assertThat(llm.getId()).isNotEmpty();
assertThat(llm.getObject()).isEqualTo("chat.completion");
assertThat(llm.getCreated()).isGreaterThan(1);
-    assertThat(llm.getModel()).isEqualTo(OrchestrationController.LLM_CONFIG.getModelName());
+    assertThat(llm.getModel()).isEqualTo(modelName);
var choices = llm.getChoices();
assertThat(choices.get(0).getIndex()).isZero();
assertThat(choices.get(0).getMessage().getContent()).isNotEmpty();
@@ -57,8 +62,7 @@
var orchestrationResult = ((LLMModuleResultSynchronous) result.getOrchestrationResult());
assertThat(orchestrationResult.getObject()).isEqualTo("chat.completion");
assertThat(orchestrationResult.getCreated()).isGreaterThan(1);
-    assertThat(orchestrationResult.getModel())
-        .isEqualTo(OrchestrationController.LLM_CONFIG.getModelName());
+    assertThat(orchestrationResult.getModel()).isEqualTo(modelName);
choices = orchestrationResult.getChoices();
assertThat(choices.get(0).getIndex()).isZero();
assertThat(choices.get(0).getMessage().getContent()).isNotEmpty();
@@ -11,12 +11,12 @@
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

-public class ScenarioTest {
+class ScenarioTest {

@Test
@DisplayName("Declared OpenAI models must match AI Core's available OpenAI models")
@SneakyThrows
-  public void openAiModelAvailability() {
+  void openAiModelAvailability() {

// Gather AI Core's list of available OpenAI models
final var aiModelList = new ScenarioController().getModels().getResources();