From 9f70f99449e6e137ad0724192824dfc607363db1 Mon Sep 17 00:00:00 2001 From: Jonas Israel Date: Wed, 13 Nov 2024 15:21:04 +0100 Subject: [PATCH 1/9] Added Response Convenience --- docs/guides/ORCHESTRATION_CHAT_COMPLETION.md | 3 +-- .../orchestration/OrchestrationClient.java | 9 +++---- .../orchestration/OrchestrationResponse.java | 26 +++++++++++++++++++ .../orchestration/OrchestrationUnitTest.java | 3 +-- pom.xml | 2 +- .../controllers/OrchestrationController.java | 14 +++++----- .../app/controllers/OrchestrationTest.java | 3 +-- 7 files changed, 41 insertions(+), 19 deletions(-) create mode 100644 orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationResponse.java diff --git a/docs/guides/ORCHESTRATION_CHAT_COMPLETION.md b/docs/guides/ORCHESTRATION_CHAT_COMPLETION.md index d336a47f5..31adb89a5 100644 --- a/docs/guides/ORCHESTRATION_CHAT_COMPLETION.md +++ b/docs/guides/ORCHESTRATION_CHAT_COMPLETION.md @@ -91,8 +91,7 @@ var prompt = new OrchestrationPrompt("Hello world! Why is this phrase so famous? var result = client.chatCompletion(prompt, config); -String messageResult = - result.getOrchestrationResult().getChoices().get(0).getMessage().getContent(); +String messageResult = result.getContent(); ``` In this example, the Orchestration service generates a response to the user message "Hello world! Why is this phrase so famous?". 
diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java index cf2a75ed9..03c024720 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java @@ -9,7 +9,6 @@ import com.sap.ai.sdk.core.AiCoreDeployment; import com.sap.ai.sdk.core.AiCoreService; import com.sap.ai.sdk.orchestration.client.model.CompletionPostRequest; -import com.sap.ai.sdk.orchestration.client.model.CompletionPostResponse; import com.sap.ai.sdk.orchestration.client.model.ModuleConfigs; import com.sap.ai.sdk.orchestration.client.model.OrchestrationConfig; import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor; @@ -75,7 +74,7 @@ public OrchestrationClient(@Nonnull final AiCoreDeployment deployment) { * @throws OrchestrationClientException if the request fails. */ @Nonnull - public CompletionPostResponse chatCompletion( + public OrchestrationResponse chatCompletion( @Nonnull final OrchestrationPrompt prompt, @Nonnull final OrchestrationModuleConfig config) throws OrchestrationClientException { @@ -104,7 +103,7 @@ public CompletionPostResponse chatCompletion( * @throws OrchestrationClientException If the request fails. 
*/ @Nonnull - public CompletionPostResponse executeRequest(@Nonnull final CompletionPostRequest request) + public OrchestrationResponse executeRequest(@Nonnull final CompletionPostRequest request) throws OrchestrationClientException { final BasicClassicHttpRequest postRequest = new HttpPost("/completion"); try { @@ -133,13 +132,13 @@ public static CompletionPostRequest toCompletionPostRequest( } @Nonnull - CompletionPostResponse executeRequest(@Nonnull final BasicClassicHttpRequest request) { + OrchestrationResponse executeRequest(@Nonnull final BasicClassicHttpRequest request) { try { val destination = deployment.get().destination(); log.debug("Using destination {} to connect to orchestration service", destination); val client = ApacheHttpClient5Accessor.getHttpClient(destination); return client.execute( - request, new OrchestrationResponseHandler<>(CompletionPostResponse.class)); + request, new OrchestrationResponseHandler<>(OrchestrationResponse.class)); } catch (NoSuchElementException | DestinationAccessException | DestinationNotFoundException diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationResponse.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationResponse.java new file mode 100644 index 000000000..d67956132 --- /dev/null +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationResponse.java @@ -0,0 +1,26 @@ +package com.sap.ai.sdk.orchestration; + +import com.sap.ai.sdk.orchestration.client.model.CompletionPostResponse; +import javax.annotation.Nonnull; + +/** Orchestration chat completion output. */ +public class OrchestrationResponse extends CompletionPostResponse { + /** + * Get the message content from the output. + * + *

Note: If there are multiple choices only the first one is returned + * + * @return the message content or empty string. + * @throws OrchestrationClientException if the content filter filtered the output. + */ + @Nonnull + public String getContent() throws OrchestrationClientException { + if (getOrchestrationResult().getChoices().isEmpty()) { + return ""; + } + if ("content_filter".equals(getOrchestrationResult().getChoices().get(0).getFinishReason())) { + throw new OrchestrationClientException("Content filter filtered the output."); + } + return getOrchestrationResult().getChoices().get(0).getMessage().getContent(); + } +} diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java index 3898b8650..1e2a71b10 100644 --- a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java @@ -115,8 +115,7 @@ void testCompletion() { final var result = client.chatCompletion(prompt, config); assertThat(result).isNotNull(); - assertThat(result.getOrchestrationResult().getChoices().get(0).getMessage().getContent()) - .isNotEmpty(); + assertThat(result.getContent()).isNotEmpty(); } @Test diff --git a/pom.xml b/pom.xml index a7342d72f..43b1fa0be 100644 --- a/pom.xml +++ b/pom.xml @@ -74,7 +74,7 @@ false 74% - 62% + 60% 67% 75% 80% diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java index e6e1fc8c3..7f8767860 100644 --- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java +++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java @@ -3,10 +3,10 @@ import com.sap.ai.sdk.orchestration.OrchestrationClient; 
import com.sap.ai.sdk.orchestration.OrchestrationModuleConfig; import com.sap.ai.sdk.orchestration.OrchestrationPrompt; +import com.sap.ai.sdk.orchestration.OrchestrationResponse; import com.sap.ai.sdk.orchestration.client.model.AzureContentSafety; import com.sap.ai.sdk.orchestration.client.model.AzureThreshold; import com.sap.ai.sdk.orchestration.client.model.ChatMessage; -import com.sap.ai.sdk.orchestration.client.model.CompletionPostResponse; import com.sap.ai.sdk.orchestration.client.model.DPIEntities; import com.sap.ai.sdk.orchestration.client.model.DPIEntityConfig; import com.sap.ai.sdk.orchestration.client.model.FilterConfig; @@ -44,7 +44,7 @@ class OrchestrationController { */ @GetMapping("/completion") @Nonnull - public CompletionPostResponse completion() { + public OrchestrationResponse completion() { final var prompt = new OrchestrationPrompt("Hello world! Why is this phrase so famous?"); return client.chatCompletion(prompt, config); @@ -57,7 +57,7 @@ public CompletionPostResponse completion() { */ @GetMapping("/template") @Nonnull - public CompletionPostResponse template() { + public OrchestrationResponse template() { final var template = ChatMessage.create() .role("user") @@ -78,7 +78,7 @@ public CompletionPostResponse template() { */ @GetMapping("/messagesHistory") @Nonnull - public CompletionPostResponse messagesHistory() { + public OrchestrationResponse messagesHistory() { final List messagesHistory = List.of( ChatMessage.create().role("user").content("What is the capital of France?"), @@ -99,7 +99,7 @@ public CompletionPostResponse messagesHistory() { */ @GetMapping("/filter/{threshold}") @Nonnull - public CompletionPostResponse filter( + public OrchestrationResponse filter( @Nonnull @PathVariable("threshold") final AzureThreshold threshold) { final var prompt = new OrchestrationPrompt( @@ -146,7 +146,7 @@ private static FilteringModuleConfig createAzureContentFilter( */ @GetMapping("/maskingAnonymization") @Nonnull - public CompletionPostResponse 
maskingAnonymization() { + public OrchestrationResponse maskingAnonymization() { final var systemMessage = ChatMessage.create() .role("system") @@ -177,7 +177,7 @@ public CompletionPostResponse maskingAnonymization() { */ @GetMapping("/maskingPseudonymization") @Nonnull - public CompletionPostResponse maskingPseudonymization() { + public OrchestrationResponse maskingPseudonymization() { final var systemMessage = ChatMessage.create() .role("system") diff --git a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java index d5a0f044a..18eac5c9a 100644 --- a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java +++ b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java @@ -25,8 +25,7 @@ void testCompletion() { final var result = controller.completion(); assertThat(result).isNotNull(); - assertThat(result.getOrchestrationResult().getChoices().get(0).getMessage().getContent()) - .isNotEmpty(); + assertThat(result.getContent()).isNotEmpty(); } @Test From c657c061e986f6b691c9f01e81abd16832fc0c35 Mon Sep 17 00:00:00 2001 From: Jonas Israel Date: Thu, 14 Nov 2024 14:51:56 +0100 Subject: [PATCH 2/9] Applied review comments --- .../model/OpenAiChatCompletionOutput.java | 3 --- .../orchestration/OrchestrationClient.java | 8 +++---- .../orchestration/OrchestrationResponse.java | 19 +++++++++++------ .../orchestration/OrchestrationUnitTest.java | 21 ++++++++++--------- pom.xml | 2 +- .../app/controllers/OrchestrationTest.java | 18 +++++++++------- 6 files changed, 40 insertions(+), 31 deletions(-) diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatCompletionOutput.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatCompletionOutput.java index 7d172aa0d..b9c8f814a 100644 --- 
a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatCompletionOutput.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatCompletionOutput.java @@ -40,9 +40,6 @@ public class OpenAiChatCompletionOutput extends OpenAiCompletionOutput */ @Nonnull public String getContent() throws OpenAiClientException { - if (getChoices().isEmpty()) { - return ""; - } if ("content_filter".equals(getChoices().get(0).getFinishReason())) { throw new OpenAiClientException("Content filter filtered the output."); } diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java index 86a5fa3ce..3b5c7ac0c 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java @@ -104,7 +104,7 @@ public OrchestrationResponse chatCompletion( throws OrchestrationClientException { val request = toCompletionPostRequest(prompt, config); - return executeRequest(request); + return new OrchestrationResponse(executeRequest(request)); } /** @@ -128,7 +128,7 @@ public OrchestrationResponse chatCompletion( * @throws OrchestrationClientException If the request fails. 
*/ @Nonnull - public OrchestrationResponse executeRequest(@Nonnull final CompletionPostRequest request) + public CompletionPostResponse executeRequest(@Nonnull final CompletionPostRequest request) throws OrchestrationClientException { final BasicClassicHttpRequest postRequest = new HttpPost("/completion"); try { @@ -143,13 +143,13 @@ public OrchestrationResponse executeRequest(@Nonnull final CompletionPostRequest } @Nonnull - OrchestrationResponse executeRequest(@Nonnull final BasicClassicHttpRequest request) { + CompletionPostResponse executeRequest(@Nonnull final BasicClassicHttpRequest request) { try { val destination = deployment.get().destination(); log.debug("Using destination {} to connect to orchestration service", destination); val client = ApacheHttpClient5Accessor.getHttpClient(destination); return client.execute( - request, new OrchestrationResponseHandler<>(OrchestrationResponse.class)); + request, new OrchestrationResponseHandler<>(CompletionPostResponse.class)); } catch (NoSuchElementException | DestinationAccessException | DestinationNotFoundException diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationResponse.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationResponse.java index d67956132..b8c1a11f8 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationResponse.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationResponse.java @@ -1,10 +1,17 @@ package com.sap.ai.sdk.orchestration; import com.sap.ai.sdk.orchestration.client.model.CompletionPostResponse; +import com.sap.ai.sdk.orchestration.client.model.LLMModuleResultSynchronous; import javax.annotation.Nonnull; +import lombok.Getter; +import lombok.RequiredArgsConstructor; /** Orchestration chat completion output. 
*/ -public class OrchestrationResponse extends CompletionPostResponse { +@RequiredArgsConstructor +@Getter +public class OrchestrationResponse { + private final CompletionPostResponse data; + /** * Get the message content from the output. * @@ -15,12 +22,12 @@ public class OrchestrationResponse extends CompletionPostResponse { */ @Nonnull public String getContent() throws OrchestrationClientException { - if (getOrchestrationResult().getChoices().isEmpty()) { - return ""; - } - if ("content_filter".equals(getOrchestrationResult().getChoices().get(0).getFinishReason())) { + final var choice = + ((LLMModuleResultSynchronous) data.getOrchestrationResult()).getChoices().get(0); + + if ("content_filter".equals(choice.getFinishReason())) { throw new OrchestrationClientException("Content filter filtered the output."); } - return getOrchestrationResult().getChoices().get(0).getMessage().getContent(); + return choice.getMessage().getContent(); } } diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java index db78d3336..ab2ac0d47 100644 --- a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java @@ -140,11 +140,12 @@ void testTemplating() throws IOException { final var result = client.chatCompletion(new OrchestrationPrompt(inputParams, template), config); - assertThat(result.getRequestId()).isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91"); - assertThat(result.getModuleResults().getTemplating().get(0).getContent()) + final var response = result.getData(); + assertThat(response.getRequestId()).isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91"); + assertThat(response.getModuleResults().getTemplating().get(0).getContent()) .isEqualTo("Reply with 'Orchestration Service is working!' 
in German"); - assertThat(result.getModuleResults().getTemplating().get(0).getRole()).isEqualTo("user"); - var llm = (LLMModuleResultSynchronous) result.getModuleResults().getLlm(); + assertThat(response.getModuleResults().getTemplating().get(0).getRole()).isEqualTo("user"); + var llm = (LLMModuleResultSynchronous) response.getModuleResults().getLlm(); assertThat(llm.getId()).isEqualTo("chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj"); assertThat(llm.getObject()).isEqualTo("chat.completion"); assertThat(llm.getCreated()).isEqualTo(1721224505); @@ -159,7 +160,7 @@ void testTemplating() throws IOException { assertThat(usage.getCompletionTokens()).isEqualTo(7); assertThat(usage.getPromptTokens()).isEqualTo(19); assertThat(usage.getTotalTokens()).isEqualTo(26); - var orchestrationResult = (LLMModuleResultSynchronous) result.getOrchestrationResult(); + var orchestrationResult = (LLMModuleResultSynchronous) response.getOrchestrationResult(); assertThat(orchestrationResult.getId()).isEqualTo("chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj"); assertThat(orchestrationResult.getObject()).isEqualTo("chat.completion"); assertThat(orchestrationResult.getCreated()).isEqualTo(1721224505); @@ -285,7 +286,7 @@ void messagesHistory() throws IOException { final var result = client.chatCompletion(prompt, config); - assertThat(result.getRequestId()).isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91"); + assertThat(result.getData().getRequestId()).isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91"); // verify that the history is sent correctly try (var requestInputStream = fileLoader.apply("messagesHistoryRequest.json")) { @@ -309,13 +310,13 @@ void maskingAnonymization() throws IOException { createMaskingConfig(DPIConfig.MethodEnum.ANONYMIZATION, DPIEntities.PHONE); final var result = client.chatCompletion(prompt, config.withMaskingConfig(maskingConfig)); + final var response = result.getData(); - assertThat(result).isNotNull(); - GenericModuleResult inputMasking = 
result.getModuleResults().getInputMasking(); + assertThat(response).isNotNull(); + GenericModuleResult inputMasking = response.getModuleResults().getInputMasking(); assertThat(inputMasking.getMessage()).isEqualTo("Input to LLM is masked successfully."); assertThat(inputMasking.getData()).isNotNull(); - final var choices = ((LLMModuleResultSynchronous) result.getOrchestrationResult()).getChoices(); - assertThat(choices.get(0).getMessage().getContent()) + assertThat(result.getContent()) .isEqualTo( "I'm sorry, I cannot provide information about specific individuals, including their nationality."); diff --git a/pom.xml b/pom.xml index bda0646e3..b6d7adc84 100644 --- a/pom.xml +++ b/pom.xml @@ -74,7 +74,7 @@ false 74% - 66% + 67% 67% 75% 80% diff --git a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java index 209404c45..95a7f8683 100644 --- a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java +++ b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java @@ -32,7 +32,8 @@ void testCompletion() { @Test void testTemplate() { - final var result = controller.template(); + final var response = controller.template(); + final var result = response.getData(); assertThat(result.getRequestId()).isNotEmpty(); assertThat(result.getModuleResults().getTemplating().get(0).getContent()) @@ -58,12 +59,12 @@ void testTemplate() { assertThat(orchestrationResult.getCreated()).isGreaterThan(1); assertThat(orchestrationResult.getModel()) .isEqualTo(OrchestrationController.LLM_CONFIG.getModelName()); - choices = ((LLMModuleResultSynchronous) orchestrationResult).getChoices(); + choices = orchestrationResult.getChoices(); assertThat(choices.get(0).getIndex()).isZero(); assertThat(choices.get(0).getMessage().getContent()).isNotEmpty(); 
assertThat(choices.get(0).getMessage().getRole()).isEqualTo("assistant"); assertThat(choices.get(0).getFinishReason()).isEqualTo("stop"); - usage = ((LLMModuleResultSynchronous) orchestrationResult).getUsage(); + usage = orchestrationResult.getUsage(); assertThat(usage.getCompletionTokens()).isGreaterThan(1); assertThat(usage.getPromptTokens()).isGreaterThan(1); assertThat(usage.getTotalTokens()).isGreaterThan(1); @@ -71,7 +72,8 @@ void testTemplate() { @Test void testLenientContentFilter() { - var result = controller.filter(AzureThreshold.NUMBER_4); + var response = controller.filter(AzureThreshold.NUMBER_4); + var result = response.getData(); var llmChoice = ((LLMModuleResultSynchronous) result.getOrchestrationResult()).getChoices().get(0); assertThat(llmChoice.getFinishReason()).isEqualTo("stop"); @@ -91,7 +93,7 @@ void testStrictContentFilter() { @Test void testMessagesHistory() { - CompletionPostResponse result = controller.messagesHistory(); + CompletionPostResponse result = controller.messagesHistory().getData(); final var choices = ((LLMModuleResultSynchronous) result.getOrchestrationResult()).getChoices(); assertThat(choices.get(0).getMessage().getContent()).isNotEmpty(); } @@ -99,7 +101,8 @@ void testMessagesHistory() { @SuppressWarnings("unchecked") @Test void testMaskingAnonymization() { - var result = controller.maskingAnonymization(); + var response = controller.maskingAnonymization(); + var result = response.getData(); var llmChoice = ((LLMModuleResultSynchronous) result.getOrchestrationResult()).getChoices().get(0); assertThat(llmChoice.getFinishReason()).isEqualTo("stop"); @@ -118,7 +121,8 @@ void testMaskingAnonymization() { @SuppressWarnings("unchecked") @Test void testMaskingPseudonymization() { - var result = controller.maskingPseudonymization(); + var response = controller.maskingPseudonymization(); + var result = response.getData(); var llmChoice = ((LLMModuleResultSynchronous) result.getOrchestrationResult()).getChoices().get(0); 
assertThat(llmChoice.getFinishReason()).isEqualTo("stop"); From 1768a98af17e0d891dbb6fbe5e35a8bb34674f5e Mon Sep 17 00:00:00 2001 From: Jonas Israel Date: Mon, 18 Nov 2024 10:06:08 +0100 Subject: [PATCH 3/9] Use Lombok's Value --- .../ai/sdk/orchestration/OrchestrationResponse.java | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationResponse.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationResponse.java index b8c1a11f8..8cd3e8356 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationResponse.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationResponse.java @@ -1,16 +1,18 @@ package com.sap.ai.sdk.orchestration; +import static lombok.AccessLevel.PACKAGE; + import com.sap.ai.sdk.orchestration.client.model.CompletionPostResponse; import com.sap.ai.sdk.orchestration.client.model.LLMModuleResultSynchronous; import javax.annotation.Nonnull; -import lombok.Getter; import lombok.RequiredArgsConstructor; +import lombok.Value; /** Orchestration chat completion output. */ -@RequiredArgsConstructor -@Getter +@Value +@RequiredArgsConstructor(access = PACKAGE) public class OrchestrationResponse { - private final CompletionPostResponse data; + CompletionPostResponse data; /** * Get the message content from the output. From f5fb9fdf3559eae4dff81c857a780900cc0f279e Mon Sep 17 00:00:00 2001 From: Jonas Israel Date: Mon, 18 Nov 2024 14:48:50 +0100 Subject: [PATCH 4/9] Renaming and adding tests. 
--- .../model/OpenAiChatCompletionOutput.java | 3 ++ ...se.java => OrchestrationChatResponse.java} | 14 +++++--- .../orchestration/OrchestrationClient.java | 4 +-- .../orchestration/OrchestrationUnitTest.java | 20 +++++++++-- .../__files/emptyChoicesResponse.json | 35 +++++++++++++++++++ .../controllers/OrchestrationController.java | 14 ++++---- .../app/controllers/OrchestrationTest.java | 10 +++--- 7 files changed, 79 insertions(+), 21 deletions(-) rename orchestration/src/main/java/com/sap/ai/sdk/orchestration/{OrchestrationResponse.java => OrchestrationChatResponse.java} (76%) create mode 100644 orchestration/src/test/resources/__files/emptyChoicesResponse.json diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatCompletionOutput.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatCompletionOutput.java index b9c8f814a..7d172aa0d 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatCompletionOutput.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatCompletionOutput.java @@ -40,6 +40,9 @@ public class OpenAiChatCompletionOutput extends OpenAiCompletionOutput */ @Nonnull public String getContent() throws OpenAiClientException { + if (getChoices().isEmpty()) { + return ""; + } if ("content_filter".equals(getChoices().get(0).getFinishReason())) { throw new OpenAiClientException("Content filter filtered the output."); } diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationResponse.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java similarity index 76% rename from orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationResponse.java rename to orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java index 8cd3e8356..7eb830e4a 100644 --- 
a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationResponse.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java @@ -11,8 +11,8 @@ /** Orchestration chat completion output. */ @Value @RequiredArgsConstructor(access = PACKAGE) -public class OrchestrationResponse { - CompletionPostResponse data; +public class OrchestrationChatResponse { + CompletionPostResponse originalResponse; /** * Get the message content from the output. @@ -24,8 +24,14 @@ public class OrchestrationResponse { */ @Nonnull public String getContent() throws OrchestrationClientException { - final var choice = - ((LLMModuleResultSynchronous) data.getOrchestrationResult()).getChoices().get(0); + final var choices = + ((LLMModuleResultSynchronous) originalResponse.getOrchestrationResult()).getChoices(); + + if (choices.isEmpty()) { + return ""; + } + + final var choice = choices.get(0); if ("content_filter".equals(choice.getFinishReason())) { throw new OrchestrationClientException("Content filter filtered the output."); diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java index 9a1183541..6053a8089 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java @@ -101,12 +101,12 @@ public static CompletionPostRequest toCompletionPostRequest( * @throws OrchestrationClientException if the request fails. 
*/ @Nonnull - public OrchestrationResponse chatCompletion( + public OrchestrationChatResponse chatCompletion( @Nonnull final OrchestrationPrompt prompt, @Nonnull final OrchestrationModuleConfig config) throws OrchestrationClientException { val request = toCompletionPostRequest(prompt, config); - return new OrchestrationResponse(executeRequest(request)); + return new OrchestrationChatResponse(executeRequest(request)); } /** diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java index b2a0a0b1e..2af45a429 100644 --- a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java @@ -140,7 +140,7 @@ void testTemplating() throws IOException { final var result = client.chatCompletion(new OrchestrationPrompt(inputParams, template), config); - final var response = result.getData(); + final var response = result.getOriginalResponse(); assertThat(response.getRequestId()).isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91"); assertThat(response.getModuleResults().getTemplating().get(0).getContent()) .isEqualTo("Reply with 'Orchestration Service is working!' 
in German"); @@ -286,7 +286,8 @@ void messagesHistory() throws IOException { final var result = client.chatCompletion(prompt, config); - assertThat(result.getData().getRequestId()).isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91"); + assertThat(result.getOriginalResponse().getRequestId()) + .isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91"); // verify that the history is sent correctly try (var requestInputStream = fileLoader.apply("messagesHistoryRequest.json")) { @@ -310,7 +311,7 @@ void maskingPseudonymization() throws IOException { createMaskingConfig(DPIConfig.MethodEnum.PSEUDONYMIZATION, DPIEntities.PHONE); final var result = client.chatCompletion(prompt, config.withMaskingConfig(maskingConfig)); - final var response = result.getData(); + final var response = result.getOriginalResponse(); assertThat(response).isNotNull(); GenericModuleResult inputMasking = response.getModuleResults().getInputMasking(); @@ -412,4 +413,17 @@ void testErrorHandling() { softly.assertAll(); } + + @Test + void testEmptyChoicesResponse() { + stubFor( + post(urlPathEqualTo("/v2/inference/deployments/abcdef0123456789/completion")) + .willReturn( + aResponse() + .withBodyFile("emptyChoicesResponse.json") + .withHeader("Content-Type", "application/json"))); + final var result = client.chatCompletion(prompt, config); + + assertThat(result.getContent()).isEmpty(); + } } diff --git a/orchestration/src/test/resources/__files/emptyChoicesResponse.json b/orchestration/src/test/resources/__files/emptyChoicesResponse.json new file mode 100644 index 000000000..3d36bdcd2 --- /dev/null +++ b/orchestration/src/test/resources/__files/emptyChoicesResponse.json @@ -0,0 +1,35 @@ +{ + "request_id": "26ea36b5-c196-4806-a9a6-a686f0c6ad91", + "module_results": { + "templating": [ + { + "role": "user", + "content": "Reply with 'Orchestration Service is working!' 
in German" + } + ], + "llm": { + "id": "chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj", + "object": "chat.completion", + "created": 1721224505, + "model": "gpt-35-turbo-16k", + "choices": [], + "usage": { + "completion_tokens": 7, + "prompt_tokens": 19, + "total_tokens": 26 + } + } + }, + "orchestration_result": { + "id": "chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj", + "object": "chat.completion", + "created": 1721224505, + "model": "gpt-35-turbo-16k", + "choices": [], + "usage": { + "completion_tokens": 7, + "prompt_tokens": 19, + "total_tokens": 26 + } + } +} diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java index 4c1fb8ab2..ccf03d943 100644 --- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java +++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java @@ -1,9 +1,9 @@ package com.sap.ai.sdk.app.controllers; +import com.sap.ai.sdk.orchestration.OrchestrationChatResponse; import com.sap.ai.sdk.orchestration.OrchestrationClient; import com.sap.ai.sdk.orchestration.OrchestrationModuleConfig; import com.sap.ai.sdk.orchestration.OrchestrationPrompt; -import com.sap.ai.sdk.orchestration.OrchestrationResponse; import com.sap.ai.sdk.orchestration.client.model.AzureContentSafety; import com.sap.ai.sdk.orchestration.client.model.AzureContentSafetyFilterConfig; import com.sap.ai.sdk.orchestration.client.model.AzureThreshold; @@ -44,7 +44,7 @@ class OrchestrationController { */ @GetMapping("/completion") @Nonnull - public OrchestrationResponse completion() { + public OrchestrationChatResponse completion() { final var prompt = new OrchestrationPrompt("Hello world! 
Why is this phrase so famous?"); return client.chatCompletion(prompt, config); @@ -57,7 +57,7 @@ public OrchestrationResponse completion() { */ @GetMapping("/template") @Nonnull - public OrchestrationResponse template() { + public OrchestrationChatResponse template() { final var template = new ChatMessage() .role("user") @@ -78,7 +78,7 @@ public OrchestrationResponse template() { */ @GetMapping("/messagesHistory") @Nonnull - public OrchestrationResponse messagesHistory() { + public OrchestrationChatResponse messagesHistory() { final List messagesHistory = List.of( new ChatMessage().role("user").content("What is the capital of France?"), @@ -98,7 +98,7 @@ public OrchestrationResponse messagesHistory() { */ @GetMapping("/filter/{threshold}") @Nonnull - public OrchestrationResponse filter( + public OrchestrationChatResponse filter( @Nonnull @PathVariable("threshold") final AzureThreshold threshold) { final var prompt = new OrchestrationPrompt( @@ -145,7 +145,7 @@ private static FilteringModuleConfig createAzureContentFilter( */ @GetMapping("/maskingAnonymization") @Nonnull - public OrchestrationResponse maskingAnonymization() { + public OrchestrationChatResponse maskingAnonymization() { final var systemMessage = new ChatMessage() .role("system") @@ -176,7 +176,7 @@ public OrchestrationResponse maskingAnonymization() { */ @GetMapping("/maskingPseudonymization") @Nonnull - public OrchestrationResponse maskingPseudonymization() { + public OrchestrationChatResponse maskingPseudonymization() { final var systemMessage = new ChatMessage() .role("system") diff --git a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java index 95a7f8683..dc2d27a79 100644 --- a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java +++ b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java @@ 
-33,7 +33,7 @@ void testCompletion() { @Test void testTemplate() { final var response = controller.template(); - final var result = response.getData(); + final var result = response.getOriginalResponse(); assertThat(result.getRequestId()).isNotEmpty(); assertThat(result.getModuleResults().getTemplating().get(0).getContent()) @@ -73,7 +73,7 @@ void testTemplate() { @Test void testLenientContentFilter() { var response = controller.filter(AzureThreshold.NUMBER_4); - var result = response.getData(); + var result = response.getOriginalResponse(); var llmChoice = ((LLMModuleResultSynchronous) result.getOrchestrationResult()).getChoices().get(0); assertThat(llmChoice.getFinishReason()).isEqualTo("stop"); @@ -93,7 +93,7 @@ void testStrictContentFilter() { @Test void testMessagesHistory() { - CompletionPostResponse result = controller.messagesHistory().getData(); + CompletionPostResponse result = controller.messagesHistory().getOriginalResponse(); final var choices = ((LLMModuleResultSynchronous) result.getOrchestrationResult()).getChoices(); assertThat(choices.get(0).getMessage().getContent()).isNotEmpty(); } @@ -102,7 +102,7 @@ void testMessagesHistory() { @Test void testMaskingAnonymization() { var response = controller.maskingAnonymization(); - var result = response.getData(); + var result = response.getOriginalResponse(); var llmChoice = ((LLMModuleResultSynchronous) result.getOrchestrationResult()).getChoices().get(0); assertThat(llmChoice.getFinishReason()).isEqualTo("stop"); @@ -122,7 +122,7 @@ void testMaskingAnonymization() { @Test void testMaskingPseudonymization() { var response = controller.maskingPseudonymization(); - var result = response.getData(); + var result = response.getOriginalResponse(); var llmChoice = ((LLMModuleResultSynchronous) result.getOrchestrationResult()).getChoices().get(0); assertThat(llmChoice.getFinishReason()).isEqualTo("stop"); From 3f5f04325215e0ca3efc8d2677b1d6479dcac8c7 Mon Sep 17 00:00:00 2001 From: Jonas Israel Date: Tue, 19 Nov 
2024 13:50:09 +0100 Subject: [PATCH 5/9] Add and test getTokenUsage --- .../sdk/orchestration/OrchestrationChatResponse.java | 11 +++++++++++ .../ai/sdk/orchestration/OrchestrationUnitTest.java | 4 ++-- .../sap/ai/sdk/app/controllers/OrchestrationTest.java | 4 ++-- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java index 7eb830e4a..cb0cd70d7 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java @@ -4,6 +4,7 @@ import com.sap.ai.sdk.orchestration.client.model.CompletionPostResponse; import com.sap.ai.sdk.orchestration.client.model.LLMModuleResultSynchronous; +import com.sap.ai.sdk.orchestration.client.model.TokenUsage; import javax.annotation.Nonnull; import lombok.RequiredArgsConstructor; import lombok.Value; @@ -38,4 +39,14 @@ public String getContent() throws OrchestrationClientException { } return choice.getMessage().getContent(); } + + /** + * Get the token usage. + * + * @return The token usage. 
+ */ + @Nonnull + public TokenUsage getTokenUsage() { + return ((LLMModuleResultSynchronous) originalResponse.getOrchestrationResult()).getUsage(); + } } diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java index 2af45a429..d5bf82bd6 100644 --- a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java @@ -156,7 +156,7 @@ void testTemplating() throws IOException { .isEqualTo("Orchestration Service funktioniert!"); assertThat(choices.get(0).getMessage().getRole()).isEqualTo("assistant"); assertThat(choices.get(0).getFinishReason()).isEqualTo("stop"); - var usage = llm.getUsage(); + var usage = result.getTokenUsage(); assertThat(usage.getCompletionTokens()).isEqualTo(7); assertThat(usage.getPromptTokens()).isEqualTo(19); assertThat(usage.getTotalTokens()).isEqualTo(26); @@ -171,7 +171,7 @@ void testTemplating() throws IOException { .isEqualTo("Orchestration Service funktioniert!"); assertThat(choices.get(0).getMessage().getRole()).isEqualTo("assistant"); assertThat(choices.get(0).getFinishReason()).isEqualTo("stop"); - usage = orchestrationResult.getUsage(); + usage = result.getTokenUsage(); assertThat(usage.getCompletionTokens()).isEqualTo(7); assertThat(usage.getPromptTokens()).isEqualTo(19); assertThat(usage.getTotalTokens()).isEqualTo(26); diff --git a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java index dc2d27a79..387c4abc4 100644 --- a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java +++ b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java @@ -49,7 +49,7 @@ void testTemplate() { 
assertThat(choices.get(0).getMessage().getContent()).isNotEmpty(); assertThat(choices.get(0).getMessage().getRole()).isEqualTo("assistant"); assertThat(choices.get(0).getFinishReason()).isEqualTo("stop"); - var usage = llm.getUsage(); + var usage = response.getTokenUsage(); assertThat(usage.getCompletionTokens()).isGreaterThan(1); assertThat(usage.getPromptTokens()).isGreaterThan(1); assertThat(usage.getTotalTokens()).isGreaterThan(1); @@ -64,7 +64,7 @@ void testTemplate() { assertThat(choices.get(0).getMessage().getContent()).isNotEmpty(); assertThat(choices.get(0).getMessage().getRole()).isEqualTo("assistant"); assertThat(choices.get(0).getFinishReason()).isEqualTo("stop"); - usage = orchestrationResult.getUsage(); + usage = response.getTokenUsage(); assertThat(usage.getCompletionTokens()).isGreaterThan(1); assertThat(usage.getPromptTokens()).isGreaterThan(1); assertThat(usage.getTotalTokens()).isGreaterThan(1); From 8244f64dd8fc3f94da2239fb55d853ed90e90d28 Mon Sep 17 00:00:00 2001 From: Jonas Israel Date: Tue, 19 Nov 2024 14:14:00 +0100 Subject: [PATCH 6/9] Add and test getAllMessages --- .../OrchestrationChatResponse.java | 13 +++++++++++++ .../orchestration/OrchestrationUnitTest.java | 4 ++-- .../sdk/app/controllers/OrchestrationTest.java | 18 +++++++++--------- 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java index cb0cd70d7..68fcb4220 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java @@ -2,10 +2,13 @@ import static lombok.AccessLevel.PACKAGE; +import com.sap.ai.sdk.orchestration.client.model.ChatMessage; import com.sap.ai.sdk.orchestration.client.model.CompletionPostResponse; import 
com.sap.ai.sdk.orchestration.client.model.LLMModuleResultSynchronous; import com.sap.ai.sdk.orchestration.client.model.TokenUsage; +import java.util.List; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import lombok.RequiredArgsConstructor; import lombok.Value; @@ -49,4 +52,14 @@ public String getContent() throws OrchestrationClientException { public TokenUsage getTokenUsage() { return ((LLMModuleResultSynchronous) originalResponse.getOrchestrationResult()).getUsage(); } + + /** + * Get all messages. + * + * @return A list of all messages. + */ + @Nullable + public List<ChatMessage> getAllMessages() { + return originalResponse.getModuleResults().getTemplating(); + } } diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java index d5bf82bd6..d467cc569 100644 --- a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java @@ -142,9 +142,9 @@ void testTemplating() throws IOException { final var response = result.getOriginalResponse(); assertThat(response.getRequestId()).isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91"); - assertThat(response.getModuleResults().getTemplating().get(0).getContent()) + assertThat(result.getAllMessages().get(0).getContent()) .isEqualTo("Reply with 'Orchestration Service is working!' 
in German"); - assertThat(response.getModuleResults().getTemplating().get(0).getRole()).isEqualTo("user"); + assertThat(result.getAllMessages().get(0).getRole()).isEqualTo("user"); var llm = (LLMModuleResultSynchronous) response.getModuleResults().getLlm(); assertThat(llm.getId()).isEqualTo("chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj"); assertThat(llm.getObject()).isEqualTo("chat.completion"); diff --git a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java index 387c4abc4..d6ba3c034 100644 --- a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java +++ b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java @@ -32,14 +32,14 @@ void testCompletion() { @Test void testTemplate() { - final var response = controller.template(); - final var result = response.getOriginalResponse(); + final var result = controller.template(); + final var response = result.getOriginalResponse(); - assertThat(result.getRequestId()).isNotEmpty(); - assertThat(result.getModuleResults().getTemplating().get(0).getContent()) + assertThat(response.getRequestId()).isNotEmpty(); + assertThat(result.getAllMessages().get(0).getContent()) .isEqualTo("Reply with 'Orchestration Service is working!' 
in German"); - assertThat(result.getModuleResults().getTemplating().get(0).getRole()).isEqualTo("user"); - var llm = (LLMModuleResultSynchronous) result.getModuleResults().getLlm(); + assertThat(result.getAllMessages().get(0).getRole()).isEqualTo("user"); + var llm = (LLMModuleResultSynchronous) response.getModuleResults().getLlm(); assertThat(llm.getId()).isNotEmpty(); assertThat(llm.getObject()).isEqualTo("chat.completion"); assertThat(llm.getCreated()).isGreaterThan(1); @@ -49,12 +49,12 @@ void testTemplate() { assertThat(choices.get(0).getMessage().getContent()).isNotEmpty(); assertThat(choices.get(0).getMessage().getRole()).isEqualTo("assistant"); assertThat(choices.get(0).getFinishReason()).isEqualTo("stop"); - var usage = response.getTokenUsage(); + var usage = result.getTokenUsage(); assertThat(usage.getCompletionTokens()).isGreaterThan(1); assertThat(usage.getPromptTokens()).isGreaterThan(1); assertThat(usage.getTotalTokens()).isGreaterThan(1); - var orchestrationResult = ((LLMModuleResultSynchronous) result.getOrchestrationResult()); + var orchestrationResult = ((LLMModuleResultSynchronous) response.getOrchestrationResult()); assertThat(orchestrationResult.getObject()).isEqualTo("chat.completion"); assertThat(orchestrationResult.getCreated()).isGreaterThan(1); assertThat(orchestrationResult.getModel()) @@ -64,7 +64,7 @@ void testTemplate() { assertThat(choices.get(0).getMessage().getContent()).isNotEmpty(); assertThat(choices.get(0).getMessage().getRole()).isEqualTo("assistant"); assertThat(choices.get(0).getFinishReason()).isEqualTo("stop"); - usage = response.getTokenUsage(); + usage = result.getTokenUsage(); assertThat(usage.getCompletionTokens()).isGreaterThan(1); assertThat(usage.getPromptTokens()).isGreaterThan(1); assertThat(usage.getTotalTokens()).isGreaterThan(1); From bb8141ebdb6923e1761980a61d73291b6ed55201 Mon Sep 17 00:00:00 2001 From: Jonas Israel Date: Tue, 19 Nov 2024 15:51:32 +0100 Subject: [PATCH 7/9] Apply requested changes --- 
.../sdk/orchestration/OrchestrationChatResponse.java | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java index 68fcb4220..0ffd3f8f5 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java @@ -8,7 +8,6 @@ import com.sap.ai.sdk.orchestration.client.model.TokenUsage; import java.util.List; import javax.annotation.Nonnull; -import javax.annotation.Nullable; import lombok.RequiredArgsConstructor; import lombok.Value; @@ -58,8 +57,13 @@ public TokenUsage getTokenUsage() { * * @return A list of all messages. */ - @Nullable + @Nonnull public List<ChatMessage> getAllMessages() { - return originalResponse.getModuleResults().getTemplating(); + final var allMessages = originalResponse.getModuleResults().getTemplating(); + + if (allMessages == null) { + return List.of(); + } + return allMessages; } } From ec15a5a95fb7f3c5bd27d7d6c8f4491c4ae56ffb Mon Sep 17 00:00:00 2001 From: Jonas Israel Date: Wed, 20 Nov 2024 13:10:14 +0100 Subject: [PATCH 8/9] Apply requested changes --- .../OrchestrationChatResponse.java | 32 +++++++++-------- .../orchestration/OrchestrationUnitTest.java | 13 ------- .../__files/emptyChoicesResponse.json | 35 ------------------- 3 files changed, 18 insertions(+), 62 deletions(-) delete mode 100644 orchestration/src/test/resources/__files/emptyChoicesResponse.json diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java index 0ffd3f8f5..5560e3862 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java +++ 
b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java @@ -4,6 +4,7 @@ import com.sap.ai.sdk.orchestration.client.model.ChatMessage; import com.sap.ai.sdk.orchestration.client.model.CompletionPostResponse; +import com.sap.ai.sdk.orchestration.client.model.LLMChoice; import com.sap.ai.sdk.orchestration.client.model.LLMModuleResultSynchronous; import com.sap.ai.sdk.orchestration.client.model.TokenUsage; import java.util.List; @@ -27,14 +28,7 @@ public class OrchestrationChatResponse { */ @Nonnull public String getContent() throws OrchestrationClientException { - final var choices = - ((LLMModuleResultSynchronous) originalResponse.getOrchestrationResult()).getChoices(); - - if (choices.isEmpty()) { - return ""; - } - - final var choice = choices.get(0); + final var choice = getCurrentChoice(); if ("content_filter".equals(choice.getFinishReason())) { throw new OrchestrationClientException("Content filter filtered the output."); @@ -53,17 +47,27 @@ public TokenUsage getTokenUsage() { } /** - * Get all messages. + * Get all messages. This can be used for subsequent prompts as a message history. * * @return A list of all messages. */ @Nonnull public List<ChatMessage> getAllMessages() { - final var allMessages = originalResponse.getModuleResults().getTemplating(); + final var messages = originalResponse.getModuleResults().getTemplating(); - if (allMessages == null) { - return List.of(); - } - return allMessages; + messages.add(getCurrentChoice().getMessage()); + return messages; + } + + /** + * Get list of choices. + * + * @return A list of choices. 
+ */ + @Nonnull + private LLMChoice getCurrentChoice() { + return ((LLMModuleResultSynchronous) originalResponse.getOrchestrationResult()) + .getChoices() + .get(0); } } diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java index 1ea58ed4b..931edbf28 100644 --- a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java @@ -413,17 +413,4 @@ void testErrorHandling() { softly.assertAll(); } - - @Test - void testEmptyChoicesResponse() { - stubFor( - post(urlPathEqualTo("/v2/inference/deployments/abcdef0123456789/completion")) - .willReturn( - aResponse() - .withBodyFile("emptyChoicesResponse.json") - .withHeader("Content-Type", "application/json"))); - final var result = client.chatCompletion(prompt, config); - - assertThat(result.getContent()).isEmpty(); - } } diff --git a/orchestration/src/test/resources/__files/emptyChoicesResponse.json b/orchestration/src/test/resources/__files/emptyChoicesResponse.json deleted file mode 100644 index 3d36bdcd2..000000000 --- a/orchestration/src/test/resources/__files/emptyChoicesResponse.json +++ /dev/null @@ -1,35 +0,0 @@ -{ - "request_id": "26ea36b5-c196-4806-a9a6-a686f0c6ad91", - "module_results": { - "templating": [ - { - "role": "user", - "content": "Reply with 'Orchestration Service is working!' 
in German" - } - ], - "llm": { - "id": "chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj", - "object": "chat.completion", - "created": 1721224505, - "model": "gpt-35-turbo-16k", - "choices": [], - "usage": { - "completion_tokens": 7, - "prompt_tokens": 19, - "total_tokens": 26 - } - } - }, - "orchestration_result": { - "id": "chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj", - "object": "chat.completion", - "created": 1721224505, - "model": "gpt-35-turbo-16k", - "choices": [], - "usage": { - "completion_tokens": 7, - "prompt_tokens": 19, - "total_tokens": 26 - } - } -} From da98a2926e27b9e6dd69425bb9cdb83de6e9153e Mon Sep 17 00:00:00 2001 From: Jonas Israel Date: Thu, 21 Nov 2024 13:41:19 +0100 Subject: [PATCH 9/9] Apply requested changes --- .../sdk/orchestration/OrchestrationChatResponse.java | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java index 5560e3862..de0e786c4 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java @@ -7,7 +7,9 @@ import com.sap.ai.sdk.orchestration.client.model.LLMChoice; import com.sap.ai.sdk.orchestration.client.model.LLMModuleResultSynchronous; import com.sap.ai.sdk.orchestration.client.model.TokenUsage; +import java.util.ArrayList; import java.util.List; +import java.util.Objects; import javax.annotation.Nonnull; import lombok.RequiredArgsConstructor; import lombok.Value; @@ -53,19 +55,20 @@ public TokenUsage getTokenUsage() { */ @Nonnull public List<ChatMessage> getAllMessages() { - final var messages = originalResponse.getModuleResults().getTemplating(); - + final var items = Objects.requireNonNull(originalResponse.getModuleResults().getTemplating()); + final var messages = new ArrayList<>(items); 
messages.add(getCurrentChoice().getMessage()); return messages; } /** - * Get list of choices. + * Get current choice. * - * @return A list of choices. + * @return The current choice. */ @Nonnull private LLMChoice getCurrentChoice() { + // We expect choices to be defined and never empty. return ((LLMModuleResultSynchronous) originalResponse.getOrchestrationResult()) .getChoices() .get(0);