diff --git a/docs/guides/ORCHESTRATION_CHAT_COMPLETION.md b/docs/guides/ORCHESTRATION_CHAT_COMPLETION.md index d336a47f5..31adb89a5 100644 --- a/docs/guides/ORCHESTRATION_CHAT_COMPLETION.md +++ b/docs/guides/ORCHESTRATION_CHAT_COMPLETION.md @@ -91,8 +91,7 @@ var prompt = new OrchestrationPrompt("Hello world! Why is this phrase so famous? var result = client.chatCompletion(prompt, config); -String messageResult = - result.getOrchestrationResult().getChoices().get(0).getMessage().getContent(); +String messageResult = result.getContent(); ``` In this example, the Orchestration service generates a response to the user message "Hello world! Why is this phrase so famous?". diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java new file mode 100644 index 000000000..7eb830e4a --- /dev/null +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java @@ -0,0 +1,41 @@ +package com.sap.ai.sdk.orchestration; + +import static lombok.AccessLevel.PACKAGE; + +import com.sap.ai.sdk.orchestration.client.model.CompletionPostResponse; +import com.sap.ai.sdk.orchestration.client.model.LLMModuleResultSynchronous; +import javax.annotation.Nonnull; +import lombok.RequiredArgsConstructor; +import lombok.Value; + +/** Orchestration chat completion output. */ +@Value +@RequiredArgsConstructor(access = PACKAGE) +public class OrchestrationChatResponse { + CompletionPostResponse originalResponse; + + /** + * Get the message content from the output. + * + *

Note: If there are multiple choices only the first one is returned + * + * @return the message content or empty string. + * @throws OrchestrationClientException if the content filter filtered the output. + */ + @Nonnull + public String getContent() throws OrchestrationClientException { + final var choices = + ((LLMModuleResultSynchronous) originalResponse.getOrchestrationResult()).getChoices(); + + if (choices.isEmpty()) { + return ""; + } + + final var choice = choices.get(0); + + if ("content_filter".equals(choice.getFinishReason())) { + throw new OrchestrationClientException("Content filter filtered the output."); + } + return choice.getMessage().getContent(); + } +} diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java index b6622edd7..6053a8089 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java @@ -101,12 +101,12 @@ public static CompletionPostRequest toCompletionPostRequest( * @throws OrchestrationClientException if the request fails. */ @Nonnull - public CompletionPostResponse chatCompletion( + public OrchestrationChatResponse chatCompletion( @Nonnull final OrchestrationPrompt prompt, @Nonnull final OrchestrationModuleConfig config) throws OrchestrationClientException { val request = toCompletionPostRequest(prompt, config); - return executeRequest(request); + return new OrchestrationChatResponse(executeRequest(request)); } /** diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java index 5b6044a00..2af45a429 100644 --- a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java @@ -121,8 +121,7 @@ void testCompletion() { final var result = client.chatCompletion(prompt, config); assertThat(result).isNotNull(); - var orchestrationResult = (LLMModuleResultSynchronous) result.getOrchestrationResult(); - assertThat(orchestrationResult.getChoices().get(0).getMessage().getContent()).isNotEmpty(); + assertThat(result.getContent()).isNotEmpty(); } @Test @@ -141,11 +140,12 @@ void testTemplating() throws IOException { final var result = client.chatCompletion(new OrchestrationPrompt(inputParams, template), config); - assertThat(result.getRequestId()).isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91"); - assertThat(result.getModuleResults().getTemplating().get(0).getContent()) + final var response = result.getOriginalResponse(); + assertThat(response.getRequestId()).isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91"); + assertThat(response.getModuleResults().getTemplating().get(0).getContent()) .isEqualTo("Reply with 'Orchestration Service is working!' in German"); - assertThat(result.getModuleResults().getTemplating().get(0).getRole()).isEqualTo("user"); - var llm = (LLMModuleResultSynchronous) result.getModuleResults().getLlm(); + assertThat(response.getModuleResults().getTemplating().get(0).getRole()).isEqualTo("user"); + var llm = (LLMModuleResultSynchronous) response.getModuleResults().getLlm(); assertThat(llm.getId()).isEqualTo("chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj"); assertThat(llm.getObject()).isEqualTo("chat.completion"); assertThat(llm.getCreated()).isEqualTo(1721224505); @@ -160,7 +160,7 @@ void testTemplating() throws IOException { assertThat(usage.getCompletionTokens()).isEqualTo(7); assertThat(usage.getPromptTokens()).isEqualTo(19); assertThat(usage.getTotalTokens()).isEqualTo(26); - var orchestrationResult = (LLMModuleResultSynchronous) result.getOrchestrationResult(); + var orchestrationResult = (LLMModuleResultSynchronous) response.getOrchestrationResult(); assertThat(orchestrationResult.getId()).isEqualTo("chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj"); assertThat(orchestrationResult.getObject()).isEqualTo("chat.completion"); assertThat(orchestrationResult.getCreated()).isEqualTo(1721224505); @@ -286,7 +286,8 @@ void messagesHistory() throws IOException { final var result = client.chatCompletion(prompt, config); - assertThat(result.getRequestId()).isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91"); + assertThat(result.getOriginalResponse().getRequestId()) + .isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91"); // verify that the history is sent correctly try (var requestInputStream = fileLoader.apply("messagesHistoryRequest.json")) { @@ -310,13 +311,13 @@ void maskingPseudonymization() throws IOException { createMaskingConfig(DPIConfig.MethodEnum.PSEUDONYMIZATION, DPIEntities.PHONE); final var result = client.chatCompletion(prompt, config.withMaskingConfig(maskingConfig)); + final var response = result.getOriginalResponse(); - assertThat(result).isNotNull(); - GenericModuleResult inputMasking = result.getModuleResults().getInputMasking(); + assertThat(response).isNotNull(); + GenericModuleResult inputMasking = response.getModuleResults().getInputMasking(); assertThat(inputMasking.getMessage()).isEqualTo("Input to LLM is masked successfully."); assertThat(inputMasking.getData()).isNotNull(); - final var choices = ((LLMModuleResultSynchronous) result.getOrchestrationResult()).getChoices(); - assertThat(choices.get(0).getMessage().getContent()).contains("Hi Mallory"); + assertThat(result.getContent()).contains("Hi Mallory"); // verify that the request is sent correctly try (var requestInputStream = fileLoader.apply("maskingRequest.json")) { @@ -412,4 +413,17 @@ void testErrorHandling() { softly.assertAll(); } + + @Test + void testEmptyChoicesResponse() { + stubFor( + post(urlPathEqualTo("/v2/inference/deployments/abcdef0123456789/completion")) + .willReturn( + aResponse() + .withBodyFile("emptyChoicesResponse.json") + .withHeader("Content-Type", "application/json"))); + final var result = client.chatCompletion(prompt, config); + + assertThat(result.getContent()).isEmpty(); + } } diff --git a/orchestration/src/test/resources/__files/emptyChoicesResponse.json b/orchestration/src/test/resources/__files/emptyChoicesResponse.json new file mode 100644 index 000000000..3d36bdcd2 --- /dev/null +++ b/orchestration/src/test/resources/__files/emptyChoicesResponse.json @@ -0,0 +1,35 @@ +{ + "request_id": "26ea36b5-c196-4806-a9a6-a686f0c6ad91", + "module_results": { + "templating": [ + { + "role": "user", + "content": "Reply with 'Orchestration Service is working!' in German" + } + ], + "llm": { + "id": "chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj", + "object": "chat.completion", + "created": 1721224505, + "model": "gpt-35-turbo-16k", + "choices": [], + "usage": { + "completion_tokens": 7, + "prompt_tokens": 19, + "total_tokens": 26 + } + } + }, + "orchestration_result": { + "id": "chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj", + "object": "chat.completion", + "created": 1721224505, + "model": "gpt-35-turbo-16k", + "choices": [], + "usage": { + "completion_tokens": 7, + "prompt_tokens": 19, + "total_tokens": 26 + } + } +} diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java index 812192769..ccf03d943 100644 --- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java +++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java @@ -1,5 +1,6 @@ package com.sap.ai.sdk.app.controllers; +import com.sap.ai.sdk.orchestration.OrchestrationChatResponse; import com.sap.ai.sdk.orchestration.OrchestrationClient; import com.sap.ai.sdk.orchestration.OrchestrationModuleConfig; import com.sap.ai.sdk.orchestration.OrchestrationPrompt; @@ -7,7 +8,6 @@ import com.sap.ai.sdk.orchestration.client.model.AzureContentSafetyFilterConfig; import com.sap.ai.sdk.orchestration.client.model.AzureThreshold; import com.sap.ai.sdk.orchestration.client.model.ChatMessage; -import com.sap.ai.sdk.orchestration.client.model.CompletionPostResponse; import com.sap.ai.sdk.orchestration.client.model.DPIConfig; import com.sap.ai.sdk.orchestration.client.model.DPIEntities; import com.sap.ai.sdk.orchestration.client.model.DPIEntityConfig; @@ -44,7 +44,7 @@ class OrchestrationController { */ @GetMapping("/completion") @Nonnull - public CompletionPostResponse completion() { + public OrchestrationChatResponse completion() { final var prompt = new OrchestrationPrompt("Hello world! Why is this phrase so famous?"); return client.chatCompletion(prompt, config); @@ -57,7 +57,7 @@ public CompletionPostResponse completion() { */ @GetMapping("/template") @Nonnull - public CompletionPostResponse template() { + public OrchestrationChatResponse template() { final var template = new ChatMessage() .role("user") @@ -78,7 +78,7 @@ public CompletionPostResponse template() { */ @GetMapping("/messagesHistory") @Nonnull - public CompletionPostResponse messagesHistory() { + public OrchestrationChatResponse messagesHistory() { final List messagesHistory = List.of( new ChatMessage().role("user").content("What is the capital of France?"), @@ -98,7 +98,7 @@ public CompletionPostResponse messagesHistory() { */ @GetMapping("/filter/{threshold}") @Nonnull - public CompletionPostResponse filter( + public OrchestrationChatResponse filter( @Nonnull @PathVariable("threshold") final AzureThreshold threshold) { final var prompt = new OrchestrationPrompt( @@ -145,7 +145,7 @@ private static FilteringModuleConfig createAzureContentFilter( */ @GetMapping("/maskingAnonymization") @Nonnull - public CompletionPostResponse maskingAnonymization() { + public OrchestrationChatResponse maskingAnonymization() { final var systemMessage = new ChatMessage() .role("system") @@ -176,7 +176,7 @@ public CompletionPostResponse maskingAnonymization() { */ @GetMapping("/maskingPseudonymization") @Nonnull - public CompletionPostResponse maskingPseudonymization() { + public OrchestrationChatResponse maskingPseudonymization() { final var systemMessage = new ChatMessage() .role("system") diff --git a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java index 31aa71edb..dc2d27a79 100644 --- a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java +++ b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java @@ -27,18 +27,13 @@ void testCompletion() { final var result = controller.completion(); assertThat(result).isNotNull(); - assertThat( - ((LLMModuleResultSynchronous) result.getOrchestrationResult()) - .getChoices() - .get(0) - .getMessage() - .getContent()) - .isNotEmpty(); + assertThat(result.getContent()).isNotEmpty(); } @Test void testTemplate() { - final var result = controller.template(); + final var response = controller.template(); + final var result = response.getOriginalResponse(); assertThat(result.getRequestId()).isNotEmpty(); assertThat(result.getModuleResults().getTemplating().get(0).getContent()) @@ -64,12 +59,12 @@ void testTemplate() { assertThat(orchestrationResult.getCreated()).isGreaterThan(1); assertThat(orchestrationResult.getModel()) .isEqualTo(OrchestrationController.LLM_CONFIG.getModelName()); - choices = ((LLMModuleResultSynchronous) orchestrationResult).getChoices(); + choices = orchestrationResult.getChoices(); assertThat(choices.get(0).getIndex()).isZero(); assertThat(choices.get(0).getMessage().getContent()).isNotEmpty(); assertThat(choices.get(0).getMessage().getRole()).isEqualTo("assistant"); assertThat(choices.get(0).getFinishReason()).isEqualTo("stop"); - usage = ((LLMModuleResultSynchronous) orchestrationResult).getUsage(); + usage = orchestrationResult.getUsage(); assertThat(usage.getCompletionTokens()).isGreaterThan(1); assertThat(usage.getPromptTokens()).isGreaterThan(1); assertThat(usage.getTotalTokens()).isGreaterThan(1); @@ -77,7 +72,8 @@ void testTemplate() { @Test void testLenientContentFilter() { - var result = controller.filter(AzureThreshold.NUMBER_4); + var response = controller.filter(AzureThreshold.NUMBER_4); + var result = response.getOriginalResponse(); var llmChoice = ((LLMModuleResultSynchronous) result.getOrchestrationResult()).getChoices().get(0); assertThat(llmChoice.getFinishReason()).isEqualTo("stop"); @@ -97,7 +93,7 @@ void testStrictContentFilter() { @Test void testMessagesHistory() { - CompletionPostResponse result = controller.messagesHistory(); + CompletionPostResponse result = controller.messagesHistory().getOriginalResponse(); final var choices = ((LLMModuleResultSynchronous) result.getOrchestrationResult()).getChoices(); assertThat(choices.get(0).getMessage().getContent()).isNotEmpty(); } @@ -105,7 +101,8 @@ void testMessagesHistory() { @SuppressWarnings("unchecked") @Test void testMaskingAnonymization() { - var result = controller.maskingAnonymization(); + var response = controller.maskingAnonymization(); + var result = response.getOriginalResponse(); var llmChoice = ((LLMModuleResultSynchronous) result.getOrchestrationResult()).getChoices().get(0); assertThat(llmChoice.getFinishReason()).isEqualTo("stop"); @@ -124,7 +121,8 @@ void testMaskingAnonymization() { @SuppressWarnings("unchecked") @Test void testMaskingPseudonymization() { - var result = controller.maskingPseudonymization(); + var response = controller.maskingPseudonymization(); + var result = response.getOriginalResponse(); var llmChoice = ((LLMModuleResultSynchronous) result.getOrchestrationResult()).getChoices().get(0); assertThat(llmChoice.getFinishReason()).isEqualTo("stop");