diff --git a/docs/guides/ORCHESTRATION_CHAT_COMPLETION.md b/docs/guides/ORCHESTRATION_CHAT_COMPLETION.md index d336a47f5..31adb89a5 100644 --- a/docs/guides/ORCHESTRATION_CHAT_COMPLETION.md +++ b/docs/guides/ORCHESTRATION_CHAT_COMPLETION.md @@ -91,8 +91,7 @@ var prompt = new OrchestrationPrompt("Hello world! Why is this phrase so famous? var result = client.chatCompletion(prompt, config); -String messageResult = - result.getOrchestrationResult().getChoices().get(0).getMessage().getContent(); +String messageResult = result.getContent(); ``` In this example, the Orchestration service generates a response to the user message "Hello world! Why is this phrase so famous?". diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java new file mode 100644 index 000000000..7eb830e4a --- /dev/null +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java @@ -0,0 +1,41 @@ +package com.sap.ai.sdk.orchestration; + +import static lombok.AccessLevel.PACKAGE; + +import com.sap.ai.sdk.orchestration.client.model.CompletionPostResponse; +import com.sap.ai.sdk.orchestration.client.model.LLMModuleResultSynchronous; +import javax.annotation.Nonnull; +import lombok.RequiredArgsConstructor; +import lombok.Value; + +/** Orchestration chat completion output. */ +@Value +@RequiredArgsConstructor(access = PACKAGE) +public class OrchestrationChatResponse { + CompletionPostResponse originalResponse; + + /** + * Get the message content from the output. + * + *
Note: If there are multiple choices, only the first one is returned.
+ *
+ * @return the message content, or an empty string if the response contains no choices.
+ * @throws OrchestrationClientException if the content filter filtered the output.
+ */
+ @Nonnull
+ public String getContent() throws OrchestrationClientException {
+ final var choices = // NOTE(review): downcast assumes a synchronous LLM result — confirm
+ ((LLMModuleResultSynchronous) originalResponse.getOrchestrationResult()).getChoices();
+ // No choices in the orchestration result: report empty content instead of failing.
+ if (choices.isEmpty()) {
+ return "";
+ }
+ // Only the first choice is inspected; any further choices are ignored.
+ final var choice = choices.get(0);
+ // A "content_filter" finish reason is surfaced as an exception, not returned as content.
+ if ("content_filter".equals(choice.getFinishReason())) {
+ throw new OrchestrationClientException("Content filter filtered the output.");
+ }
+ return choice.getMessage().getContent();
+ }
+}
diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java
index b6622edd7..6053a8089 100644
--- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java
+++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java
@@ -101,12 +101,12 @@ public static CompletionPostRequest toCompletionPostRequest(
* @throws OrchestrationClientException if the request fails.
*/
@Nonnull
- public CompletionPostResponse chatCompletion(
+ public OrchestrationChatResponse chatCompletion(
@Nonnull final OrchestrationPrompt prompt, @Nonnull final OrchestrationModuleConfig config)
throws OrchestrationClientException {
val request = toCompletionPostRequest(prompt, config);
- return executeRequest(request);
+ return new OrchestrationChatResponse(executeRequest(request));
}
/**
diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java
index 5b6044a00..2af45a429 100644
--- a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java
+++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java
@@ -121,8 +121,7 @@ void testCompletion() {
final var result = client.chatCompletion(prompt, config);
assertThat(result).isNotNull();
- var orchestrationResult = (LLMModuleResultSynchronous) result.getOrchestrationResult();
- assertThat(orchestrationResult.getChoices().get(0).getMessage().getContent()).isNotEmpty();
+ assertThat(result.getContent()).isNotEmpty();
}
@Test
@@ -141,11 +140,12 @@ void testTemplating() throws IOException {
final var result =
client.chatCompletion(new OrchestrationPrompt(inputParams, template), config);
- assertThat(result.getRequestId()).isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91");
- assertThat(result.getModuleResults().getTemplating().get(0).getContent())
+ final var response = result.getOriginalResponse();
+ assertThat(response.getRequestId()).isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91");
+ assertThat(response.getModuleResults().getTemplating().get(0).getContent())
.isEqualTo("Reply with 'Orchestration Service is working!' in German");
- assertThat(result.getModuleResults().getTemplating().get(0).getRole()).isEqualTo("user");
- var llm = (LLMModuleResultSynchronous) result.getModuleResults().getLlm();
+ assertThat(response.getModuleResults().getTemplating().get(0).getRole()).isEqualTo("user");
+ var llm = (LLMModuleResultSynchronous) response.getModuleResults().getLlm();
assertThat(llm.getId()).isEqualTo("chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj");
assertThat(llm.getObject()).isEqualTo("chat.completion");
assertThat(llm.getCreated()).isEqualTo(1721224505);
@@ -160,7 +160,7 @@ void testTemplating() throws IOException {
assertThat(usage.getCompletionTokens()).isEqualTo(7);
assertThat(usage.getPromptTokens()).isEqualTo(19);
assertThat(usage.getTotalTokens()).isEqualTo(26);
- var orchestrationResult = (LLMModuleResultSynchronous) result.getOrchestrationResult();
+ var orchestrationResult = (LLMModuleResultSynchronous) response.getOrchestrationResult();
assertThat(orchestrationResult.getId()).isEqualTo("chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj");
assertThat(orchestrationResult.getObject()).isEqualTo("chat.completion");
assertThat(orchestrationResult.getCreated()).isEqualTo(1721224505);
@@ -286,7 +286,8 @@ void messagesHistory() throws IOException {
final var result = client.chatCompletion(prompt, config);
- assertThat(result.getRequestId()).isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91");
+ assertThat(result.getOriginalResponse().getRequestId())
+ .isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91");
// verify that the history is sent correctly
try (var requestInputStream = fileLoader.apply("messagesHistoryRequest.json")) {
@@ -310,13 +311,13 @@ void maskingPseudonymization() throws IOException {
createMaskingConfig(DPIConfig.MethodEnum.PSEUDONYMIZATION, DPIEntities.PHONE);
final var result = client.chatCompletion(prompt, config.withMaskingConfig(maskingConfig));
+ final var response = result.getOriginalResponse();
- assertThat(result).isNotNull();
- GenericModuleResult inputMasking = result.getModuleResults().getInputMasking();
+ assertThat(response).isNotNull();
+ GenericModuleResult inputMasking = response.getModuleResults().getInputMasking();
assertThat(inputMasking.getMessage()).isEqualTo("Input to LLM is masked successfully.");
assertThat(inputMasking.getData()).isNotNull();
- final var choices = ((LLMModuleResultSynchronous) result.getOrchestrationResult()).getChoices();
- assertThat(choices.get(0).getMessage().getContent()).contains("Hi Mallory");
+ assertThat(result.getContent()).contains("Hi Mallory");
// verify that the request is sent correctly
try (var requestInputStream = fileLoader.apply("maskingRequest.json")) {
@@ -412,4 +413,17 @@ void testErrorHandling() {
softly.assertAll();
}
+
+ @Test
+ void testEmptyChoicesResponse() {
+ stubFor(
+ post(urlPathEqualTo("/v2/inference/deployments/abcdef0123456789/completion"))
+ .willReturn(
+ aResponse()
+ .withBodyFile("emptyChoicesResponse.json")
+ .withHeader("Content-Type", "application/json")));
+ final var result = client.chatCompletion(prompt, config);
+ // The stubbed body has an empty choices list, so the content must be an empty string.
+ assertThat(result.getContent()).isEmpty();
+ }
}
diff --git a/orchestration/src/test/resources/__files/emptyChoicesResponse.json b/orchestration/src/test/resources/__files/emptyChoicesResponse.json
new file mode 100644
index 000000000..3d36bdcd2
--- /dev/null
+++ b/orchestration/src/test/resources/__files/emptyChoicesResponse.json
@@ -0,0 +1,35 @@
+{
+ "request_id": "26ea36b5-c196-4806-a9a6-a686f0c6ad91",
+ "module_results": {
+ "templating": [
+ {
+ "role": "user",
+ "content": "Reply with 'Orchestration Service is working!' in German"
+ }
+ ],
+ "llm": {
+ "id": "chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj",
+ "object": "chat.completion",
+ "created": 1721224505,
+ "model": "gpt-35-turbo-16k",
+ "choices": [],
+ "usage": {
+ "completion_tokens": 7,
+ "prompt_tokens": 19,
+ "total_tokens": 26
+ }
+ }
+ },
+ "orchestration_result": {
+ "id": "chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj",
+ "object": "chat.completion",
+ "created": 1721224505,
+ "model": "gpt-35-turbo-16k",
+ "choices": [],
+ "usage": {
+ "completion_tokens": 7,
+ "prompt_tokens": 19,
+ "total_tokens": 26
+ }
+ }
+}
diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java
index 812192769..ccf03d943 100644
--- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java
+++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java
@@ -1,5 +1,6 @@
package com.sap.ai.sdk.app.controllers;
+import com.sap.ai.sdk.orchestration.OrchestrationChatResponse;
import com.sap.ai.sdk.orchestration.OrchestrationClient;
import com.sap.ai.sdk.orchestration.OrchestrationModuleConfig;
import com.sap.ai.sdk.orchestration.OrchestrationPrompt;
@@ -7,7 +8,6 @@
import com.sap.ai.sdk.orchestration.client.model.AzureContentSafetyFilterConfig;
import com.sap.ai.sdk.orchestration.client.model.AzureThreshold;
import com.sap.ai.sdk.orchestration.client.model.ChatMessage;
-import com.sap.ai.sdk.orchestration.client.model.CompletionPostResponse;
import com.sap.ai.sdk.orchestration.client.model.DPIConfig;
import com.sap.ai.sdk.orchestration.client.model.DPIEntities;
import com.sap.ai.sdk.orchestration.client.model.DPIEntityConfig;
@@ -44,7 +44,7 @@ class OrchestrationController {
*/
@GetMapping("/completion")
@Nonnull
- public CompletionPostResponse completion() {
+ public OrchestrationChatResponse completion() {
final var prompt = new OrchestrationPrompt("Hello world! Why is this phrase so famous?");
return client.chatCompletion(prompt, config);
@@ -57,7 +57,7 @@ public CompletionPostResponse completion() {
*/
@GetMapping("/template")
@Nonnull
- public CompletionPostResponse template() {
+ public OrchestrationChatResponse template() {
final var template =
new ChatMessage()
.role("user")
@@ -78,7 +78,7 @@ public CompletionPostResponse template() {
*/
@GetMapping("/messagesHistory")
@Nonnull
- public CompletionPostResponse messagesHistory() {
+ public OrchestrationChatResponse messagesHistory() {
final List