Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Orchestration Prompt #140

Merged
merged 18 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package com.sap.ai.sdk.orchestration;

import com.sap.ai.sdk.orchestration.client.model.CompletionPostRequest;
import com.sap.ai.sdk.orchestration.client.model.ModuleConfigs;
import com.sap.ai.sdk.orchestration.client.model.TemplatingModuleConfig;
import java.util.ArrayList;
import javax.annotation.Nonnull;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import lombok.val;

/** Factory to create all DTOs from an orchestration configuration. */
@NoArgsConstructor(access = AccessLevel.NONE)
final class ModuleConfigFactory {
@Nonnull
static CompletionPostRequest toCompletionPostRequestDto(
CharlesDuboisSAP marked this conversation as resolved.
Show resolved Hide resolved
@Nonnull final OrchestrationPrompt prompt, @Nonnull final ModuleConfigs config) {
// copying is required because we have to merge the prompt into the template config
// also, users may modify the object before request execution
val configCopy = copyModuleConfigs(config);
configCopy.setTemplatingModuleConfig(
toTemplateModuleConfigDto(prompt, config.getTemplatingModuleConfig()));

return CompletionPostRequest.create()
.orchestrationConfig(
com.sap.ai.sdk.orchestration.client.model.OrchestrationConfig.create()
.moduleConfigurations(configCopy))
.inputParams(prompt.getTemplateParameters());
}

@Nonnull
static TemplatingModuleConfig toTemplateModuleConfigDto(
@Nonnull final OrchestrationPrompt prompt, @Nonnull final TemplatingModuleConfig template) {
/*
* Currently, we have to merge the prompt into the template configuration.
* This works around the limitation that the template config isn't optional.
* This comes at the risk that the prompt unintentionally contains the templating pattern "{{? .. }}".
* In this case, the request will fail, since the templating module will try to resolve the parameter.
* To be fixed with https://github.tools.sap/AI/llm-orchestration/issues/662
*/
val messagesWithPrompt = new ArrayList<>(template.getTemplate());
messagesWithPrompt.addAll(prompt.getMessages());
if (messagesWithPrompt.isEmpty()) {
throw new IllegalStateException(
"A prompt is required. Pass at least one message or configure a template with messages or a template reference.");
}
return TemplatingModuleConfig.create().template(messagesWithPrompt);
}

static ModuleConfigs copyModuleConfigs(@Nonnull final ModuleConfigs configs) {
return ModuleConfigs.create()
.llmModuleConfig(configs.getLlmModuleConfig())
.templatingModuleConfig(configs.getTemplatingModuleConfig())
.maskingModuleConfig(configs.getMaskingModuleConfig())
.filteringModuleConfig(configs.getFilteringModuleConfig());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import com.sap.ai.sdk.core.AiCoreService;
import com.sap.ai.sdk.orchestration.client.model.CompletionPostRequest;
import com.sap.ai.sdk.orchestration.client.model.CompletionPostResponse;
import com.sap.ai.sdk.orchestration.client.model.ModuleConfigs;
import com.sap.ai.sdk.orchestration.client.model.OrchestrationConfig;
import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor;
import com.sap.cloud.sdk.cloudplatform.connectivity.exception.DestinationAccessException;
import com.sap.cloud.sdk.cloudplatform.connectivity.exception.DestinationNotFoundException;
Expand Down Expand Up @@ -64,6 +66,23 @@ public OrchestrationClient(@Nonnull final AiCoreDeployment deployment) {
this.deployment = () -> deployment;
}

/**
* Generate a completion for the given prompt.
*
* @param prompt The {@link OrchestrationPrompt} to send to orchestration.
* @param config The {@link ModuleConfigs} configuration to use for the completion.
* @return the completion output
* @throws OrchestrationClientException if the request fails.
*/
@Nonnull
public CompletionPostResponse chatCompletion(
@Nonnull final OrchestrationPrompt prompt, @Nonnull final ModuleConfigs config)
MatKuhr marked this conversation as resolved.
Show resolved Hide resolved
throws OrchestrationClientException {

val request = toCompletionPostRequestDto(prompt, config);
return executeRequest(request);
}

/**
* Generate a completion for the given prompt.
*
Expand Down Expand Up @@ -112,6 +131,20 @@ public CompletionPostResponse executeRequest(@Nonnull final CompletionPostReques
return executeRequest(postRequest);
}

/**
* Convert the given prompt and config into a low-level request DTO. The DTO allows for further
* customization before sending the request.
*
* @param prompt The {@link OrchestrationPrompt} to generate a completion for.
* @param config The {@link OrchestrationConfig } configuration to use for the completion.
* @return The low-level request DTO to send to orchestration.
*/
@Nonnull
public static CompletionPostRequest toCompletionPostRequestDto(
MatKuhr marked this conversation as resolved.
Show resolved Hide resolved
@Nonnull final OrchestrationPrompt prompt, @Nonnull final ModuleConfigs config) {
return ModuleConfigFactory.toCompletionPostRequestDto(prompt, config);
}

@SuppressWarnings("UnstableApiUsage")
@Nonnull
CompletionPostResponse executeRequest(@Nonnull final BasicClassicHttpRequest request) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package com.sap.ai.sdk.orchestration;

import com.sap.ai.sdk.orchestration.client.model.ChatMessage;
import com.sap.ai.sdk.orchestration.client.model.OrchestrationConfig;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import javax.annotation.Nonnull;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Value;
import lombok.val;

/**
* Represents a request that can be sent to the orchestration service, containing messages and
* configuration for the orchestration modules. Prompts may be reused across multiple requests.
*
* @see OrchestrationClient
* @see OrchestrationConfig
*/
@Value
@Getter(AccessLevel.PACKAGE)
@AllArgsConstructor
public class OrchestrationPrompt {
@Nonnull List<ChatMessage> messages;
@Nonnull Map<String, String> templateParameters;

/**
* Initialize a prompt with the given user message.
*
* @param message A user message.
*/
public OrchestrationPrompt(@Nonnull final String message) {
this(List.of(ChatMessage.create().role("user").content(message)), Map.of());
}

/**
* Initialize a prompt from the given messages.
*
* @param message The first message.
* @param messages Optionally, more messages.
*/
public OrchestrationPrompt(
@Nonnull final ChatMessage message, @Nonnull final ChatMessage... messages) {
val allMessages = new ArrayList<ChatMessage>();
allMessages.add(message);
allMessages.addAll(Arrays.asList(messages));
this.messages = allMessages;
this.templateParameters = Map.of();
}

/**
* Initialize a prompt based on template variables.
*
* @param inputParams The input parameters as entries of template variables and their contents.
*/
public OrchestrationPrompt(@Nonnull final Map<String, String> inputParams) {
this(List.of(), inputParams);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package com.sap.ai.sdk.orchestration;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.Mockito.mock;

import com.sap.ai.sdk.orchestration.client.model.ChatMessage;
import com.sap.ai.sdk.orchestration.client.model.FilteringModuleConfig;
import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig;
import com.sap.ai.sdk.orchestration.client.model.MaskingModuleConfig;
import com.sap.ai.sdk.orchestration.client.model.ModuleConfigs;
import com.sap.ai.sdk.orchestration.client.model.TemplatingModuleConfig;
import java.util.Map;
import org.junit.jupiter.api.Test;

class ModuleConfigFactoryTest {

@Test
void testThrowsOnMissingMessages() {
var prompt = new OrchestrationPrompt(Map.of());
var templateConfig = TemplatingModuleConfig.create().template();

assertThatThrownBy(() -> ModuleConfigFactory.toTemplateModuleConfigDto(prompt, templateConfig))
.isInstanceOf(IllegalStateException.class)
.hasMessageContaining("A prompt is required");
}

@Test
void testEmptyTemplateConfig() {
var systemMessage = ChatMessage.create().role("system").content("foo");
var userMessage = ChatMessage.create().role("user").content("Hello");

var expected = TemplatingModuleConfig.create().template(systemMessage, userMessage);

var prompt = new OrchestrationPrompt(systemMessage, userMessage);
var actual =
ModuleConfigFactory.toTemplateModuleConfigDto(
prompt, TemplatingModuleConfig.create().template());

assertThat(actual).isEqualTo(expected);
assertThat(actual.getTemplate())
.describedAs(
"The template should be copied to not modify an existing config which might be reused.")
.isNotSameAs(expected.getTemplate());
}

@Test
void testMergingTemplateConfig() {
var systemMessage = ChatMessage.create().role("system").content("foo");
var userMessage = ChatMessage.create().role("user").content("Hello ");
var userMessage2 = ChatMessage.create().role("user").content("World");

var expected =
TemplatingModuleConfig.create().template(systemMessage, userMessage, userMessage2);

var prompt = new OrchestrationPrompt(userMessage2);
var templateConfig = TemplatingModuleConfig.create().template(systemMessage, userMessage);
var actual = ModuleConfigFactory.toTemplateModuleConfigDto(prompt, templateConfig);

assertThat(actual).isEqualTo(expected);
}

@Test
void testCopy() {
var moduleConfigs =
ModuleConfigs.create()
.llmModuleConfig(mock(LLMModuleConfig.class))
.templatingModuleConfig(mock(TemplatingModuleConfig.class))
.filteringModuleConfig(mock(FilteringModuleConfig.class))
.maskingModuleConfig(mock(MaskingModuleConfig.class));
assertThat(ModuleConfigFactory.copyModuleConfigs(moduleConfigs))
.isEqualTo(moduleConfigs)
.isNotSameAs(moduleConfigs);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
@WireMockTest
class OrchestrationUnitTest {
private OrchestrationClient client;
private ModuleConfigs config;
private final Function<String, InputStream> fileLoader =
filename -> Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream(filename));

Expand All @@ -71,16 +72,6 @@ class OrchestrationUnitTest {
"frequency_penalty", 0,
"presence_penalty", 0));

private static final Function<TemplatingModuleConfig, CompletionPostRequest> TEMPLATE_CONFIG =
(TemplatingModuleConfig templatingModuleConfig) ->
CompletionPostRequest.create()
.orchestrationConfig(
OrchestrationConfig.create()
.moduleConfigurations(
ModuleConfigs.create()
.llmModuleConfig(LLM_CONFIG)
.templatingModuleConfig(templatingModuleConfig)));

@BeforeEach
void setup(WireMockRuntimeInfo server) {
stubFor(
Expand Down Expand Up @@ -109,27 +100,43 @@ void setup(WireMockRuntimeInfo server) {
.forDeploymentByScenario("orchestration")
.withResourceGroup("my-resource-group");
client = new OrchestrationClient(deployment);
config =
ModuleConfigs.create()
.llmModuleConfig(LLM_CONFIG)
.templatingModuleConfig(TemplatingModuleConfig.create().template());
}

@Test
void testCompletion() {
stubFor(
post(urlPathEqualTo("/v2/inference/deployments/abcdef0123456789/completion"))
.willReturn(
aResponse()
.withBodyFile("templatingResponse.json")
.withHeader("Content-Type", "application/json")));
final var result =
client.chatCompletion(new OrchestrationPrompt("What is the capital of France?"), config);

assertThat(result).isNotNull();
assertThat(result.getOrchestrationResult().getChoices().get(0).getMessage().getContent())
.isNotEmpty();
}

@Test
void templating() throws IOException {
void testTemplating() throws IOException {
stubFor(
post(urlPathEqualTo("/v2/inference/deployments/abcdef0123456789/completion"))
.willReturn(
aResponse()
.withBodyFile("templatingResponse.json")
.withHeader("Content-Type", "application/json")));

final var template = ChatMessage.create().role("user").content("{{?input}}");
final var template = List.of(ChatMessage.create().role("user").content("{{?input}}"));
final var inputParams =
Map.of("input", "Reply with 'Orchestration Service is working!' in German");

final var config =
TEMPLATE_CONFIG
.apply(TemplatingModuleConfig.create().template(template))
.inputParams(inputParams);

final var result = client.chatCompletion(config);
final var result =
client.chatCompletion(new OrchestrationPrompt(template, inputParams), config);

assertThat(result.getRequestId()).isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91");
assertThat(result.getModuleResults().getTemplating().get(0).getContent())
Expand Down Expand Up @@ -176,7 +183,7 @@ void templating() throws IOException {
}

@Test
void templatingBadRequest() {
void testBadRequest() {
stubFor(
post(urlPathEqualTo("/v2/inference/deployments/abcdef0123456789/completion"))
.willReturn(
Expand All @@ -191,17 +198,10 @@ void templatingBadRequest() {
}
""",
SC_BAD_REQUEST)));
var message = ChatMessage.create().role("user").content("What is the capital of {{?input}}?");
final var prompt = new OrchestrationPrompt(List.of(message), Map.of());

final var template = ChatMessage.create().role("user").content("{{?input}}");
// input params are omitted on purpose to trigger an error
Map<String, String> inputParams = Map.of();

final var config =
TEMPLATE_CONFIG
.apply(TemplatingModuleConfig.create().template(template))
.inputParams(inputParams);

assertThatThrownBy(() -> client.chatCompletion(config))
assertThatThrownBy(() -> client.chatCompletion(prompt, config))
.isInstanceOf(OrchestrationClientException.class)
.hasMessage(
"Request to orchestration service failed with status 400 Bad Request and error message: 'Missing required parameters: ['input']'");
Expand Down Expand Up @@ -305,21 +305,20 @@ void messagesHistory() throws IOException {
final var message =
ChatMessage.create().role("user").content("What is the typical food there?");

final var config =
TEMPLATE_CONFIG
.apply(TemplatingModuleConfig.create().template(message))
.messagesHistory(messagesHistory);
final var prompt = new OrchestrationPrompt(message);
final var request = OrchestrationClient.toCompletionPostRequestDto(prompt, config);
request.setMessagesHistory(messagesHistory);

final var result = client.chatCompletion(config);
final var result = client.chatCompletion(request);

assertThat(result.getRequestId()).isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91");

// verify that the history is sent correctly
try (var requestInputStream = fileLoader.apply("messagesHistoryRequest.json")) {
final String request = new String(requestInputStream.readAllBytes());
final String requestBody = new String(requestInputStream.readAllBytes());
verify(
postRequestedFor(urlPathEqualTo("/v2/inference/deployments/abcdef0123456789/completion"))
.withRequestBody(equalToJson(request)));
.withRequestBody(equalToJson(requestBody)));
}
}

Expand Down
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
<enforcer.skipBanGeneratedModulesReference>false</enforcer.skipBanGeneratedModulesReference>
<!-- Test coverage -->
<coverage.instruction>74%</coverage.instruction>
<coverage.branch>68%</coverage.branch>
<coverage.branch>60%</coverage.branch>
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

temporary, will be increased again after tests have been refactored in a separate PR

<coverage.complexity>67%</coverage.complexity>
<coverage.line>75%</coverage.line>
<coverage.method>80%</coverage.method>
Expand Down
Loading