From c326b7da5adb0a9fcabee427c377182e52d1aafd Mon Sep 17 00:00:00 2001 From: Charles Dubois <103174266+CharlesDuboisSAP@users.noreply.github.com> Date: Tue, 10 Dec 2024 17:13:12 +0100 Subject: [PATCH] Orchestration streaming (#227) * Orchestration streaming first version * Added unit tests * Formatting * Added documentation * Added tests * Release notes * Applied Alex's review comments --------- Co-authored-by: SAP Cloud SDK Bot --- docs/guides/ORCHESTRATION_CHAT_COMPLETION.md | 26 +++ docs/release-notes/release_notes.md | 2 +- orchestration/pom.xml | 5 + .../IterableStreamConverter.java | 122 +++++++++++ .../OrchestrationChatCompletionDelta.java | 44 ++++ .../orchestration/OrchestrationClient.java | 78 +++++++ .../OrchestrationStreamingHandler.java | 55 +++++ .../ai/sdk/orchestration/StreamedDelta.java | 43 ++++ .../orchestration/OrchestrationUnitTest.java | 195 +++++++++++++++++- .../streamChatCompletionInputFilter.json | 22 ++ .../test/resources/streamChatCompletion.txt | 4 + .../streamChatCompletionOutputFilter.txt | 2 + pom.xml | 10 +- .../sdk/app/controllers/OpenAiController.java | 3 +- .../controllers/OrchestrationController.java | 39 ++++ .../src/main/resources/static/index.html | 1 + .../app/controllers/OrchestrationTest.java | 24 +++ 17 files changed, 660 insertions(+), 15 deletions(-) create mode 100644 orchestration/src/main/java/com/sap/ai/sdk/orchestration/IterableStreamConverter.java create mode 100644 orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatCompletionDelta.java create mode 100644 orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationStreamingHandler.java create mode 100644 orchestration/src/main/java/com/sap/ai/sdk/orchestration/StreamedDelta.java create mode 100644 orchestration/src/test/resources/__files/streamChatCompletionInputFilter.json create mode 100644 orchestration/src/test/resources/streamChatCompletion.txt create mode 100644 
orchestration/src/test/resources/streamChatCompletionOutputFilter.txt diff --git a/docs/guides/ORCHESTRATION_CHAT_COMPLETION.md b/docs/guides/ORCHESTRATION_CHAT_COMPLETION.md index d1a23f187..ccd033da8 100644 --- a/docs/guides/ORCHESTRATION_CHAT_COMPLETION.md +++ b/docs/guides/ORCHESTRATION_CHAT_COMPLETION.md @@ -217,6 +217,32 @@ Use the grounding module to provide additional context to the AI model. In this example, the AI model is provided with additional context in the form of grounding information. Note, that it is necessary to provide the grounding input via one or more input variables. +### Stream chat completion + +It's possible to pass a stream of chat completion delta elements, e.g. from the application backend to the frontend in real-time. + +#### Asynchronous Streaming + +This example blocks the calling thread while it consumes the stream and prints each delta directly to the console: + +```java +var prompt = new OrchestrationPrompt("Can you give me the first 100 numbers of the Fibonacci sequence?"); + +// try-with-resources on stream ensures the connection will be closed +try (Stream<String> stream = client.streamChatCompletion(prompt, config)) { + stream.forEach( + deltaString -> { + System.out.print(deltaString); + System.out.flush(); + }); +} +``` + +#### Spring Boot example + +Please find [an example in our Spring Boot application](../../sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java). +It shows the usage of Spring Boot's `ResponseBodyEmitter` to stream the chat completion delta messages to the frontend in real-time. + ### Set model parameters Change your LLM configuration to add model parameters: diff --git a/docs/release-notes/release_notes.md b/docs/release-notes/release_notes.md index 30c6ae851..275f7c850 100644 --- a/docs/release-notes/release_notes.md +++ b/docs/release-notes/release_notes.md @@ -12,7 +12,7 @@ ### ✨ New Functionality -- +- Added `streamChatCompletion()` and `streamChatCompletionDeltas()` to the `OrchestrationClient`. 
### 📈 Improvements diff --git a/orchestration/pom.xml b/orchestration/pom.xml index 846359337..4635360a1 100644 --- a/orchestration/pom.xml +++ b/orchestration/pom.xml @@ -112,6 +112,11 @@ mockito-core test + + org.junit.jupiter + junit-jupiter-params + test + diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/IterableStreamConverter.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/IterableStreamConverter.java new file mode 100644 index 000000000..4d267829a --- /dev/null +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/IterableStreamConverter.java @@ -0,0 +1,122 @@ +package com.sap.ai.sdk.orchestration; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.Spliterator.NONNULL; +import static java.util.Spliterator.ORDERED; + +import io.vavr.control.Try; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.util.Iterator; +import java.util.NoSuchElementException; +import java.util.Spliterators; +import java.util.concurrent.Callable; +import java.util.function.Function; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import lombok.AccessLevel; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.apache.hc.core5.http.HttpEntity; + +/** + * Internal utility class to convert from a reading handler to {@link Iterable} and {@link Stream}. + * + *

Note: All operations are sequential in nature. Thread safety is not + * guaranteed. + * + * @param Iterated item type. + */ +@Slf4j +@RequiredArgsConstructor(access = AccessLevel.PRIVATE) +class IterableStreamConverter implements Iterator { + /** see DEFAULT_CHAR_BUFFER_SIZE in {@link BufferedReader} * */ + static final int BUFFER_SIZE = 8192; + + /** Read next entry for Stream or {@code null} when no further entry can be read. */ + private final Callable readHandler; + + /** Close handler to be called when Stream terminated. */ + private final Runnable stopHandler; + + /** Error handler to be called when Stream is interrupted. */ + private final Function errorHandler; + + private boolean isDone = false; + private boolean isNextFetched = false; + private T next = null; + + @SuppressWarnings("checkstyle:IllegalCatch") + @Override + public boolean hasNext() { + if (isDone) { + return false; + } + if (isNextFetched) { + return true; + } + try { + next = readHandler.call(); + isNextFetched = true; + if (next == null) { + isDone = true; + stopHandler.run(); + } + } catch (final Exception e) { + isDone = true; + stopHandler.run(); + log.debug("Error while reading next element.", e); + throw errorHandler.apply(e); + } + return !isDone; + } + + @Override + public T next() { + if (next == null && !hasNext()) { + throw new NoSuchElementException(); // normally not reached with Stream API + } + isNextFetched = false; + return next; + } + + /** + * Create a sequential Stream of lines from an HTTP response string (UTF-8). The underlying {@link + * InputStream} is closed, when the resulting Stream is closed (e.g. via try-with-resources) or + * when an exception occurred. + * + * @param entity The HTTP entity object. + * @return A sequential Stream object. + * @throws OrchestrationClientException if the provided HTTP entity object is {@code null} or + * empty. 
+ */ + @SuppressWarnings("PMD.CloseResource") // Stream is closed automatically when consumed + @Nonnull + static Stream lines(@Nullable final HttpEntity entity) + throws OrchestrationClientException { + if (entity == null) { + throw new OrchestrationClientException("Orchestration service response was empty."); + } + + final InputStream inputStream; + try { + inputStream = entity.getContent(); + } catch (final IOException e) { + throw new OrchestrationClientException("Failed to read response content.", e); + } + + final var reader = new BufferedReader(new InputStreamReader(inputStream, UTF_8), BUFFER_SIZE); + final Runnable closeHandler = + () -> Try.run(reader::close).onFailure(e -> log.error("Could not close input stream", e)); + final Function errHandler = + e -> new OrchestrationClientException("Parsing response content was interrupted.", e); + + final var iterator = new IterableStreamConverter<>(reader::readLine, closeHandler, errHandler); + final var spliterator = Spliterators.spliteratorUnknownSize(iterator, ORDERED | NONNULL); + return StreamSupport.stream(spliterator, /* NOT PARALLEL */ false).onClose(closeHandler); + } +} diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatCompletionDelta.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatCompletionDelta.java new file mode 100644 index 000000000..36ae5e209 --- /dev/null +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatCompletionDelta.java @@ -0,0 +1,44 @@ +package com.sap.ai.sdk.orchestration; + +import com.sap.ai.sdk.orchestration.model.CompletionPostResponse; +import com.sap.ai.sdk.orchestration.model.LLMModuleResultSynchronous; +import java.util.Map; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import lombok.val; + +/** Orchestration chat completion output delta for streaming. 
*/ +public class OrchestrationChatCompletionDelta extends CompletionPostResponse + implements StreamedDelta { + + @Nonnull + @Override + // will be fixed once the generated code add a discriminator which will allow this class to extend + // CompletionPostResponseStreaming + @SuppressWarnings("unchecked") + public String getDeltaContent() { + val choices = ((LLMModuleResultSynchronous) getOrchestrationResult()).getChoices(); + // Avoid the first delta: "choices":[] + if (!choices.isEmpty() + // Multiple choices are spread out on multiple deltas + // A delta only contains one choice with a variable index + && choices.get(0).getIndex() == 0) { + + final var message = (Map) choices.get(0).getCustomField("delta"); + // Avoid the second delta: "choices":[{"delta":{"content":"","role":"assistant"}}] + if (message != null && message.get("content") != null) { + return message.get("content").toString(); + } + } + return ""; + } + + @Nullable + @Override + public String getFinishReason() { + return ((LLMModuleResultSynchronous) getOrchestrationResult()) + .getChoices() + .get(0) + .getFinishReason(); + } +} diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java index 3a1ab932d..986730149 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java @@ -22,12 +22,14 @@ import com.sap.cloud.sdk.cloudplatform.connectivity.exception.HttpClientInstantiationException; import java.io.IOException; import java.util.function.Supplier; +import java.util.stream.Stream; import javax.annotation.Nonnull; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.apache.hc.client5.http.classic.methods.HttpPost; import org.apache.hc.core5.http.ContentType; import org.apache.hc.core5.http.io.entity.StringEntity; +import 
org.apache.hc.core5.http.message.BasicClassicHttpRequest; /** Client to execute requests to the orchestration service. */ @Slf4j @@ -105,6 +107,33 @@ public OrchestrationChatResponse chatCompletion( return new OrchestrationChatResponse(executeRequest(request)); } + /** + * Generate a completion for the given prompt. + * + * @param prompt a text message. + * @return A stream of message deltas + * @throws OrchestrationClientException if the request fails or if the finish reason is + * content_filter + * @since 1.1.0 + */ + @Nonnull + public Stream streamChatCompletion( + @Nonnull final OrchestrationPrompt prompt, @Nonnull final OrchestrationModuleConfig config) + throws OrchestrationClientException { + + val request = toCompletionPostRequest(prompt, config); + return streamChatCompletionDeltas(request) + .peek(OrchestrationClient::throwOnContentFilter) + .map(OrchestrationChatCompletionDelta::getDeltaContent); + } + + private static void throwOnContentFilter(@Nonnull final OrchestrationChatCompletionDelta delta) { + final String finishReason = delta.getFinishReason(); + if (finishReason != null && finishReason.equals("content_filter")) { + throw new OrchestrationClientException("Content filter filtered the output."); + } + } + /** * Serializes the given request, executes it and deserializes the response. * @@ -205,4 +234,53 @@ CompletionPostResponse executeRequest(@Nonnull final String request) { throw new OrchestrationClientException("Failed to execute request", e); } } + + /** + * Generate a completion for the given prompt. + * + * @param request the prompt, including messages and other parameters. + * @return A stream of chat completion delta elements. 
+ * @throws OrchestrationClientException if the request fails + * @since 1.1.0 + */ + @Nonnull + public Stream streamChatCompletionDeltas( + @Nonnull final CompletionPostRequest request) throws OrchestrationClientException { + request.getOrchestrationConfig().setStream(true); + return executeStream("/completion", request, OrchestrationChatCompletionDelta.class); + } + + @Nonnull + private Stream executeStream( + @Nonnull final String path, + @Nonnull final Object payload, + @Nonnull final Class deltaType) { + final var request = new HttpPost(path); + serializeAndSetHttpEntity(request, payload); + return streamRequest(request, deltaType); + } + + private static void serializeAndSetHttpEntity( + @Nonnull final BasicClassicHttpRequest request, @Nonnull final Object payload) { + try { + final var json = JACKSON.writeValueAsString(payload); + request.setEntity(new StringEntity(json, ContentType.APPLICATION_JSON)); + } catch (final JsonProcessingException e) { + throw new OrchestrationClientException("Failed to serialize request parameters", e); + } + } + + @Nonnull + private Stream streamRequest( + final BasicClassicHttpRequest request, @Nonnull final Class deltaType) { + try { + val destination = destinationSupplier.get(); + log.debug("Using destination {} to connect to orchestration service", destination); + val client = ApacheHttpClient5Accessor.getHttpClient(destination); + return new OrchestrationStreamingHandler<>(deltaType) + .handleResponse(client.executeOpen(null, request, null)); + } catch (final IOException e) { + throw new OrchestrationClientException("Request to the Orchestration service failed", e); + } + } } diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationStreamingHandler.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationStreamingHandler.java new file mode 100644 index 000000000..5a20002d5 --- /dev/null +++ 
b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationStreamingHandler.java @@ -0,0 +1,55 @@ +package com.sap.ai.sdk.orchestration; + +import static com.sap.ai.sdk.orchestration.OrchestrationClient.JACKSON; +import static com.sap.ai.sdk.orchestration.OrchestrationResponseHandler.buildExceptionAndThrow; +import static com.sap.ai.sdk.orchestration.OrchestrationResponseHandler.parseErrorAndThrow; + +import java.io.IOException; +import java.util.stream.Stream; +import javax.annotation.Nonnull; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.apache.hc.core5.http.ClassicHttpResponse; + +@Slf4j +@RequiredArgsConstructor +class OrchestrationStreamingHandler { + + @Nonnull private final Class deltaType; + + /** + * @param response The response to process + * @return A {@link Stream} of a model class instantiated from the response + */ + @SuppressWarnings("PMD.CloseResource") // Stream is closed automatically when consumed + @Nonnull + Stream handleResponse(@Nonnull final ClassicHttpResponse response) + throws OrchestrationClientException { + if (response.getCode() >= 300) { + buildExceptionAndThrow(response); + } + return IterableStreamConverter.lines(response.getEntity()) + // half of the lines are empty newlines, the last line is "data: [DONE]" + .peek(line -> log.info("Handler: {}", line)) + .filter(line -> !line.isEmpty() && !"data: [DONE]".equals(line.trim())) + .peek( + line -> { + if (!line.startsWith("data: ")) { + final String msg = "Failed to parse response from the Orchestration service"; + parseErrorAndThrow(line, new OrchestrationClientException(msg)); + } + }) + .map( + line -> { + final String data = line.substring(5); // remove "data: " + try { + return JACKSON.readValue(data, deltaType); + } catch (final IOException e) { // exception message e gets lost + log.error( + "Failed to parse the following response from the Orchestration service: {}", + line); + throw new OrchestrationClientException("Failed 
to parse delta message: " + line, e); + } + }); + } +} diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/StreamedDelta.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/StreamedDelta.java new file mode 100644 index 000000000..1f3e033a9 --- /dev/null +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/StreamedDelta.java @@ -0,0 +1,43 @@ +package com.sap.ai.sdk.orchestration; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +/** + * Interface for streamed delta classes. + * + *

This interface defines methods to retrieve the content and the finish reason from a delta, + * which is a chunk in a stream of data. Implementations of this interface should provide the + * logic to extract them from the delta. + */ +public interface StreamedDelta { + + /** + * Get the message content from the delta. + * + *

Note: If there are multiple choices, only the content of the first one is returned. + * + *

Note: Some deltas do not contain any content + * + * @return the message content or empty string. + */ + @Nonnull + String getDeltaContent(); + + /** + * Reason for finish. The possible values are: + * + *

{@code stop}: API returned complete message, or a message terminated by one of the stop + * sequences provided via the stop parameter + * + *

{@code length}: Incomplete model output due to max_tokens parameter or token limit + * + *

{@code function_call}: The model decided to call a function + * + *

{@code content_filter}: Omitted content due to a flag from our content filters + * + *

{@code null}: API response still in progress or incomplete + */ + @Nullable + String getFinishReason(); +} diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java index 8cb554f03..29240790f 100644 --- a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java @@ -21,16 +21,23 @@ import static org.apache.hc.core5.http.HttpStatus.SC_BAD_REQUEST; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.when; import com.fasterxml.jackson.core.JsonParseException; import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; import com.github.tomakehurst.wiremock.junit5.WireMockTest; import com.github.tomakehurst.wiremock.stubbing.Scenario; -import com.sap.ai.sdk.orchestration.model.CompletionPostRequest; +import com.sap.ai.sdk.orchestration.model.ChatMessage; import com.sap.ai.sdk.orchestration.model.DPIEntities; import com.sap.ai.sdk.orchestration.model.GenericModuleResult; import com.sap.ai.sdk.orchestration.model.LLMModuleResultSynchronous; +import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor; +import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Cache; import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination; import java.io.IOException; import java.io.InputStream; @@ -38,9 +45,19 @@ import java.util.Map; import java.util.Objects; import java.util.function.Function; +import java.util.stream.Stream; +import javax.annotation.Nonnull; +import 
org.apache.hc.client5.http.classic.HttpClient; +import org.apache.hc.core5.http.ContentType; +import org.apache.hc.core5.http.io.entity.InputStreamEntity; +import org.apache.hc.core5.http.message.BasicClassicHttpResponse; import org.assertj.core.api.SoftAssertions; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mockito; /** * Test that queries are on the right URL, with the right headers. Also check that the received @@ -58,9 +75,9 @@ class OrchestrationUnitTest { private final Function fileLoader = filename -> Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream(filename)); - private OrchestrationClient client; - private OrchestrationModuleConfig config; - private OrchestrationPrompt prompt; + private static OrchestrationClient client; + private static OrchestrationModuleConfig config; + private static OrchestrationPrompt prompt; @BeforeEach void setup(WireMockRuntimeInfo server) { @@ -69,6 +86,13 @@ void setup(WireMockRuntimeInfo server) { client = new OrchestrationClient(destination); config = new OrchestrationModuleConfig().withLlmConfig(CUSTOM_GPT_35); prompt = new OrchestrationPrompt("Hello World! 
Why is this phrase so famous?"); + ApacheHttpClient5Accessor.setHttpClientCache(ApacheHttpClient5Cache.DISABLED); + } + + @AfterEach + void reset() { + ApacheHttpClient5Accessor.setHttpClientCache(null); + ApacheHttpClient5Accessor.setHttpClientFactory(null); } @Test @@ -304,8 +328,20 @@ void maskingPseudonymization() throws IOException { } } - @Test - void testErrorHandling() { + private static Runnable[] errorHandlingCalls() { + return new Runnable[] { + () -> client.chatCompletion(new OrchestrationPrompt(""), config), + () -> + client + .streamChatCompletion(new OrchestrationPrompt(""), config) + // the stream needs to be consumed to parse the response + .forEach(System.out::println) + }; + } + + @ParameterizedTest + @MethodSource("errorHandlingCalls") + void testErrorHandling(@Nonnull final Runnable request) { stubFor( post(anyUrl()) .inScenario("Errors") @@ -339,7 +375,6 @@ void testErrorHandling() { stubFor(post(anyUrl()).inScenario("Errors").whenScenarioStateIs("4").willReturn(noContent())); final var softly = new SoftAssertions(); - final Runnable request = () -> client.executeRequest(mock(CompletionPostRequest.class)); softly .assertThatThrownBy(request::run) @@ -433,4 +468,150 @@ void testExecuteRequestFromJsonThrows() { .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("not valid JSON"); } + + @Test + void testThrowsOnContentFilter() { + var mock = mock(OrchestrationClient.class); + when(mock.streamChatCompletion(any(), any())).thenCallRealMethod(); + + var deltaWithContentFilter = mock(OrchestrationChatCompletionDelta.class); + when(deltaWithContentFilter.getFinishReason()).thenReturn("content_filter"); + when(mock.streamChatCompletionDeltas(any())).thenReturn(Stream.of(deltaWithContentFilter)); + + // this must not throw, since the stream is lazily evaluated + var stream = mock.streamChatCompletion(new OrchestrationPrompt(""), config); + assertThatThrownBy(stream::toList) + .isInstanceOf(OrchestrationClientException.class) + 
.hasMessageContaining("Content filter"); + } + + @Test + void streamChatCompletionOutputFilterErrorHandling() throws IOException { + try (var inputStream = spy(fileLoader.apply("streamChatCompletionOutputFilter.txt"))) { + + final var httpClient = mock(HttpClient.class); + ApacheHttpClient5Accessor.setHttpClientFactory(destination -> httpClient); + + // Create a mock response + final var mockResponse = new BasicClassicHttpResponse(200, "OK"); + final var inputStreamEntity = new InputStreamEntity(inputStream, ContentType.TEXT_PLAIN); + mockResponse.setEntity(inputStreamEntity); + mockResponse.setHeader("Content-Type", "text/event-stream"); + + // Configure the HttpClient mock to return the mock response + doReturn(mockResponse).when(httpClient).executeOpen(any(), any(), any()); + + try (Stream stream = client.streamChatCompletion(prompt, config)) { + assertThatThrownBy(() -> stream.forEach(System.out::println)) + .isInstanceOf(OrchestrationClientException.class) + .hasMessage("Content filter filtered the output."); + } + + Mockito.verify(inputStream, times(1)).close(); + } + } + + @Test + void streamChatCompletionDeltas() throws IOException { + try (var inputStream = spy(fileLoader.apply("streamChatCompletion.txt"))) { + + final var httpClient = mock(HttpClient.class); + ApacheHttpClient5Accessor.setHttpClientFactory(destination -> httpClient); + + // Create a mock response + final var mockResponse = new BasicClassicHttpResponse(200, "OK"); + final var inputStreamEntity = new InputStreamEntity(inputStream, ContentType.TEXT_PLAIN); + mockResponse.setEntity(inputStreamEntity); + mockResponse.setHeader("Content-Type", "text/event-stream"); + + // Configure the HttpClient mock to return the mock response + doReturn(mockResponse).when(httpClient).executeOpen(any(), any(), any()); + + var prompt = + new OrchestrationPrompt( + "Can you give me the first 100 numbers of the Fibonacci sequence?"); + var request = OrchestrationClient.toCompletionPostRequest(prompt, config); + 
+ try (Stream stream = + client.streamChatCompletionDeltas(request)) { + var deltaList = stream.toList(); + + assertThat(deltaList).hasSize(3); + // the first delta doesn't have any content + assertThat(deltaList.get(0).getDeltaContent()).isEqualTo(""); + assertThat(deltaList.get(1).getDeltaContent()).isEqualTo("Sure"); + assertThat(deltaList.get(2).getDeltaContent()).isEqualTo("!"); + + assertThat(deltaList.get(0).getRequestId()) + .isEqualTo("5bd87b41-6368-4c18-aaae-47ab82e9475b"); + assertThat(deltaList.get(1).getRequestId()) + .isEqualTo("5bd87b41-6368-4c18-aaae-47ab82e9475b"); + assertThat(deltaList.get(2).getRequestId()) + .isEqualTo("5bd87b41-6368-4c18-aaae-47ab82e9475b"); + + assertThat(deltaList.get(0).getFinishReason()).isEqualTo(""); + assertThat(deltaList.get(1).getFinishReason()).isEqualTo(""); + assertThat(deltaList.get(2).getFinishReason()).isEqualTo("stop"); + + // should be of type LLMModuleResultStreaming, will be fixed with a discriminator + var result0 = (LLMModuleResultSynchronous) deltaList.get(0).getOrchestrationResult(); + var result1 = (LLMModuleResultSynchronous) deltaList.get(1).getOrchestrationResult(); + var result2 = (LLMModuleResultSynchronous) deltaList.get(2).getOrchestrationResult(); + + assertThat(result0.getSystemFingerprint()).isEmpty(); + assertThat(result0.getId()).isEmpty(); + assertThat(result0.getCreated()).isEqualTo(0); + assertThat(result0.getModel()).isEmpty(); + assertThat(result0.getObject()).isEmpty(); + // BUG: usage is absent from the request + assertThat(result0.getUsage()).isNull(); + assertThat(result0.getChoices()).hasSize(1); + final var choices0 = result0.getChoices().get(0); + assertThat(choices0.getIndex()).isEqualTo(0); + assertThat(choices0.getFinishReason()).isEmpty(); + assertThat(choices0.getCustomField("delta")).isNotNull(); + // this should be getDelta(), only when the result is of type LLMModuleResultStreaming + final var message0 = (Map) choices0.getCustomField("delta"); + 
assertThat(message0.get("role")).isEqualTo(""); + assertThat(message0.get("content")).isEqualTo(""); + List templating = deltaList.get(0).getModuleResults().getTemplating(); + assertThat(templating).hasSize(1); + assertThat(templating.get(0).getRole()).isEqualTo("user"); + assertThat(templating.get(0).getContent()) + .isEqualTo("Hello world! Why is this phrase so famous?"); + + assertThat(result1.getSystemFingerprint()).isEqualTo("fp_808245b034"); + assertThat(result1.getId()).isEqualTo("chatcmpl-AYZSQQwWv7ajJsyDBpMG4X01BBJxq"); + assertThat(result1.getCreated()).isEqualTo(1732802814); + assertThat(result1.getModel()).isEqualTo("gpt-35-turbo"); + assertThat(result1.getObject()).isEqualTo("chat.completion.chunk"); + assertThat(result1.getUsage()).isNull(); + assertThat(result1.getChoices()).hasSize(1); + final var choices1 = result1.getChoices().get(0); + assertThat(choices1.getIndex()).isEqualTo(0); + assertThat(choices1.getFinishReason()).isEmpty(); + assertThat(choices1.getCustomField("delta")).isNotNull(); + final var message1 = (Map) choices1.getCustomField("delta"); + assertThat(message1.get("role")).isEqualTo("assistant"); + assertThat(message1.get("content")).isEqualTo("Sure"); + + assertThat(result2.getSystemFingerprint()).isEqualTo("fp_808245b034"); + assertThat(result2.getId()).isEqualTo("chatcmpl-AYZSQQwWv7ajJsyDBpMG4X01BBJxq"); + assertThat(result2.getCreated()).isEqualTo(1732802814); + assertThat(result2.getModel()).isEqualTo("gpt-35-turbo"); + assertThat(result2.getObject()).isEqualTo("chat.completion.chunk"); + assertThat(result2.getUsage()).isNull(); + assertThat(result2.getChoices()).hasSize(1); + final var choices2 = result2.getChoices().get(0); + assertThat(choices2.getIndex()).isEqualTo(0); + assertThat(choices2.getFinishReason()).isEqualTo("stop"); + // this should be getDelta(), only when the result is of type LLMModuleResultStreaming + assertThat(choices2.getCustomField("delta")).isNotNull(); + final var message2 = (Map) 
choices2.getCustomField("delta"); + assertThat(message2.get("role")).isEqualTo("assistant"); + assertThat(message2.get("content")).isEqualTo("!"); + } + Mockito.verify(inputStream, times(1)).close(); + } + } } diff --git a/orchestration/src/test/resources/__files/streamChatCompletionInputFilter.json b/orchestration/src/test/resources/__files/streamChatCompletionInputFilter.json new file mode 100644 index 000000000..4d46693de --- /dev/null +++ b/orchestration/src/test/resources/__files/streamChatCompletionInputFilter.json @@ -0,0 +1,22 @@ +{ + "request_id": "b589de57-512e-4e11-9b69-8601453b3296", + "code": 400, + "message": "Content filtered due to safety violations. Please modify the prompt and try again.", + "location": "Filtering Module - Input Filter", + "module_results": { + "templating": [ + { + "role": "user", + "content": "Fuck you" + } + ], + "input_filtering": { + "message": "Content filtered due to safety violations. Please modify the prompt and try again.", + "data": { + "azure_content_safety": { + "Hate": 2 + } + } + } + } +} diff --git a/orchestration/src/test/resources/streamChatCompletion.txt b/orchestration/src/test/resources/streamChatCompletion.txt new file mode 100644 index 000000000..15a7bbd88 --- /dev/null +++ b/orchestration/src/test/resources/streamChatCompletion.txt @@ -0,0 +1,4 @@ +data: {"request_id": "5bd87b41-6368-4c18-aaae-47ab82e9475b", "module_results": {"templating": [{"role": "user", "content": "Hello world! 
Why is this phrase so famous?"}]}, "orchestration_result": {"id": "", "object": "", "created": 0, "model": "", "system_fingerprint": "", "choices": [{"index": 0, "delta": {"role": "", "content": ""}, "finish_reason": ""}]}} +data: {"request_id": "5bd87b41-6368-4c18-aaae-47ab82e9475b", "module_results": {"llm": {"id": "chatcmpl-AYZSQQwWv7ajJsyDBpMG4X01BBJxq", "object": "chat.completion.chunk", "created": 1732802814, "model": "gpt-35-turbo", "system_fingerprint": "fp_808245b034", "choices": [{"index": 0, "delta": {"role": "assistant", "content": "Sure"}, "finish_reason": ""}]}}, "orchestration_result": {"id": "chatcmpl-AYZSQQwWv7ajJsyDBpMG4X01BBJxq", "object": "chat.completion.chunk", "created": 1732802814, "model": "gpt-35-turbo", "system_fingerprint": "fp_808245b034", "choices": [{"index": 0, "delta": {"role": "assistant", "content": "Sure"}, "finish_reason": ""}]}} +data: {"request_id": "5bd87b41-6368-4c18-aaae-47ab82e9475b", "module_results": {"llm": {"id": "chatcmpl-AYZSQQwWv7ajJsyDBpMG4X01BBJxq", "object": "chat.completion.chunk", "created": 1732802814, "model": "gpt-35-turbo", "system_fingerprint": "fp_808245b034", "choices": [{"index": 0, "delta": {"role": "assistant", "content": "!"}, "finish_reason": "stop"}]}}, "orchestration_result": {"id": "chatcmpl-AYZSQQwWv7ajJsyDBpMG4X01BBJxq", "object": "chat.completion.chunk", "created": 1732802814, "model": "gpt-35-turbo", "system_fingerprint": "fp_808245b034", "choices": [{"index": 0, "delta": {"role": "assistant", "content": "!"}, "finish_reason": "stop"}]}} +data: [DONE] diff --git a/orchestration/src/test/resources/streamChatCompletionOutputFilter.txt b/orchestration/src/test/resources/streamChatCompletionOutputFilter.txt new file mode 100644 index 000000000..cb473481a --- /dev/null +++ b/orchestration/src/test/resources/streamChatCompletionOutputFilter.txt @@ -0,0 +1,2 @@ +data: {"request_id": "eec90bca-a43e-43fa-864e-1d8962341350", "module_results": {"templating": [{"role": "user", "content": "Create 3 
paraphrases of the following text: 'I hate you.'"}]}, "orchestration_result": {"id": "", "object": "", "created": 0, "model": "", "system_fingerprint": "", "choices": [{"index": 0, "delta": {"role": "", "content": ""}, "finish_reason": ""}]}} +data: {"request_id": "eec90bca-a43e-43fa-864e-1d8962341350", "module_results": {"llm": {"id": "chatcmpl-Ab4mSDp5DXFu7hfbs2DkCsVJaM4IP", "object": "chat.completion.chunk", "created": 1733399876, "model": "gpt-35-turbo", "system_fingerprint": "fp_808245b034", "choices": [{"index": 0, "delta": {"role": "assistant", "content": "1. I can't stand you.\n2. You are detestable to me.\n3. I have a strong aversion towards you."}, "finish_reason": "stop"}]}, "output_filtering": {"message": "Content filtered due to safety violations. Model returned a result violating the safety threshold. Please modify the prompt and try again.", "data": {"original_service_response": {"azure_content_safety": {"content_allowed": false, "original_service_response": {"Hate": 2}, "checked_text": "1. I can't stand you. 2. You are detestable to me. 3. 
I have a strong aversion towards you."}}}}}, "orchestration_result": {"id": "chatcmpl-Ab4mSDp5DXFu7hfbs2DkCsVJaM4IP", "object": "chat.completion.chunk", "created": 1733399876, "model": "gpt-35-turbo", "system_fingerprint": "fp_808245b034", "choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": "content_filter"}]}} diff --git a/pom.xml b/pom.xml index a4bc65f57..7ed778c37 100644 --- a/pom.xml +++ b/pom.xml @@ -76,11 +76,11 @@ false false - 75% - 67% - 69% - 76% - 85% + 77% + 68% + 71% + 79% + 100% 85% diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OpenAiController.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OpenAiController.java index 20e69c246..2c843c2a9 100644 --- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OpenAiController.java +++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OpenAiController.java @@ -123,8 +123,7 @@ ResponseEntity streamChatCompletion() { return ResponseEntity.ok().contentType(MediaType.TEXT_EVENT_STREAM).body(emitter); } - private static void send( - @Nonnull final ResponseBodyEmitter emitter, @Nonnull final String chunk) { + static void send(@Nonnull final ResponseBodyEmitter emitter, @Nonnull final String chunk) { try { emitter.send(chunk); } catch (final IOException e) { diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java index f6f1b5496..a1209a2c5 100644 --- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java +++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java @@ -1,5 +1,6 @@ package com.sap.ai.sdk.app.controllers; +import static com.sap.ai.sdk.app.controllers.OpenAiController.send; import static 
com.sap.ai.sdk.orchestration.OrchestrationAiModel.GPT_35_TURBO; import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.Parameter.TEMPERATURE; @@ -18,14 +19,18 @@ import com.sap.ai.sdk.orchestration.model.GroundingModuleConfig; import com.sap.ai.sdk.orchestration.model.GroundingModuleConfigConfig; import com.sap.ai.sdk.orchestration.model.Template; +import com.sap.cloud.sdk.cloudplatform.thread.ThreadContextExecutors; import java.util.List; import java.util.Map; import javax.annotation.Nonnull; import lombok.extern.slf4j.Slf4j; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.servlet.mvc.method.annotation.ResponseBodyEmitter; /** Endpoints for the Orchestration service */ @RestController @@ -53,6 +58,40 @@ OrchestrationChatResponse completion() { return result; } + /** + * Asynchronous stream of an OpenAI chat request through the Orchestration service + * + * @return the emitter that streams the assistant message response + */ + @SuppressWarnings("unused") // The end-to-end test doesn't use this method + @GetMapping("/streamChatCompletion") + @Nonnull + public ResponseEntity streamChatCompletion() { + final var prompt = + new OrchestrationPrompt("Can you give me the first 100 numbers of the Fibonacci sequence?"); + final var stream = client.streamChatCompletion(prompt, config); + + final var emitter = new ResponseBodyEmitter(); + + final Runnable consumeStream = + () -> { + try (stream) { + stream.forEach( + deltaMessage -> { + log.info("Controller: {}", deltaMessage); + send(emitter, deltaMessage); + }); + } finally { + emitter.complete(); + } + }; + + ThreadContextExecutors.getExecutor().execute(consumeStream); + + // TEXT_EVENT_STREAM allows the browser to display the
content as it is streamed + return ResponseEntity.ok().contentType(MediaType.TEXT_EVENT_STREAM).body(emitter); + } + /** * Chat request to OpenAI through the Orchestration service with a template. * diff --git a/sample-code/spring-app/src/main/resources/static/index.html b/sample-code/spring-app/src/main/resources/static/index.html index 497c382e3..468874f58 100644 --- a/sample-code/spring-app/src/main/resources/static/index.html +++ b/sample-code/spring-app/src/main/resources/static/index.html @@ -67,6 +67,7 @@

Endpoints

  • Orchestration

    • /orchestration/completion
    • +
    • /orchestration/streamChatCompletion
    • /orchestration/template
    • /orchestration/messagesHistory
    • /orchestration/filter/NUMBER_4 Loose filter
    • diff --git a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java index 35944f7f8..258aa156f 100644 --- a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java +++ b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java @@ -4,11 +4,14 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; import com.sap.ai.sdk.orchestration.AzureFilterThreshold; +import com.sap.ai.sdk.orchestration.OrchestrationClient; import com.sap.ai.sdk.orchestration.OrchestrationClientException; +import com.sap.ai.sdk.orchestration.OrchestrationPrompt; import com.sap.ai.sdk.orchestration.model.CompletionPostResponse; import com.sap.ai.sdk.orchestration.model.LLMChoice; import com.sap.ai.sdk.orchestration.model.LLMModuleResultSynchronous; import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -31,6 +34,27 @@ void testCompletion() { assertThat(result.getContent()).isNotEmpty(); } + @Test + void testStreamChatCompletion() { + final var prompt = new OrchestrationPrompt("Who is the prettiest?"); + final var stream = new OrchestrationClient().streamChatCompletion(prompt, controller.config); + + final var filledDeltaCount = new AtomicInteger(0); + stream + // foreach consumes all elements, closing the stream at the end + .forEach( + delta -> { + log.info("delta: {}", delta); + if (!delta.isEmpty()) { + filledDeltaCount.incrementAndGet(); + } + }); + + // some deltas carry no content, e.g. the leading delta that only echoes the templating result, + // so only the non-empty ones are counted and at least one of them is expected + assertThat(filledDeltaCount.get()).isGreaterThan(0); + } + @Test void testTemplate() { assertThat(controller.config.getLlmConfig()).isNotNull();