Orchestration streaming #227

Merged
merged 11 commits into from
Dec 10, 2024
26 changes: 26 additions & 0 deletions docs/guides/ORCHESTRATION_CHAT_COMPLETION.md
@@ -217,6 +217,32 @@ Use the grounding module to provide additional context to the AI model.

In this example, the AI model is provided with additional context in the form of grounding information. Note that the grounding input must be provided via one or more input variables.

### Stream chat completion

Chat completion responses can be streamed as a sequence of delta elements, e.g. to forward them from the application backend to the frontend in real time.

#### Asynchronous Streaming

The following example blocks the calling thread while it consumes the stream and prints each delta directly to the console:

```java
String msg = "Can you give me the first 100 numbers of the Fibonacci sequence?";
// `prompt` is assumed to be an OrchestrationPrompt built from `msg`,
// and `config` the OrchestrationModuleConfig from the earlier examples

// try-with-resources on stream ensures the connection will be closed
try (Stream<String> stream = client.streamChatCompletion(prompt, config)) {
  stream.forEach(
      deltaString -> {
        System.out.print(deltaString);
        System.out.flush();
      });
}
```
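
Closing the stream matters because it holds an open HTTP connection. The following self-contained sketch does not use the SDK: it simulates the delta stream with a plain `Stream<String>` whose close handler only flips a flag. The class name `StreamCloseSketch` and the helper `fakeDeltaStream` are hypothetical and exist only for illustration.

```java
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Stream;

public class StreamCloseSketch {
  // Hypothetical stand-in for client.streamChatCompletion(prompt, config):
  // a stream of text deltas whose close handler flips a flag
  // instead of releasing a real HTTP connection.
  static Stream<String> fakeDeltaStream(AtomicBoolean closed) {
    return Stream.of("0, ", "1, ", "1, ", "2").onClose(() -> closed.set(true));
  }

  // Consumes all deltas inside try-with-resources and returns the joined text.
  static String consume(AtomicBoolean closed) {
    StringBuilder text = new StringBuilder();
    try (Stream<String> stream = fakeDeltaStream(closed)) {
      // forEach alone does not close the stream; leaving the try block does.
      stream.forEach(text::append);
    }
    return text.toString();
  }

  public static void main(String[] args) {
    AtomicBoolean closed = new AtomicBoolean(false);
    System.out.println(consume(closed));
    System.out.println("closed: " + closed.get());
  }
}
```

Running `main` prints `0, 1, 1, 2` followed by `closed: true`. A terminal operation such as `forEach` does not invoke close handlers by itself, which is why the try-with-resources block is used.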

#### Spring Boot example

Please find [an example in our Spring Boot application](../../sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java).
It shows the usage of Spring Boot's `ResponseBodyEmitter` to stream the chat completion delta messages to the frontend in real-time.

### Set model parameters

Change your LLM configuration to add model parameters:
2 changes: 1 addition & 1 deletion docs/release-notes/release_notes.md
@@ -12,7 +12,7 @@

### ✨ New Functionality

-
- Added `streamChatCompletion()` and `streamChatCompletionDeltas()` to the `OrchestrationClient`.

### 📈 Improvements

5 changes: 5 additions & 0 deletions orchestration/pom.xml
@@ -112,6 +112,11 @@
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-params</artifactId>
<scope>test</scope>
</dependency>
</dependencies>

<profiles>
@@ -0,0 +1,122 @@
package com.sap.ai.sdk.orchestration;

import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Spliterator.NONNULL;
import static java.util.Spliterator.ORDERED;

import io.vavr.control.Try;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.Spliterators;
import java.util.concurrent.Callable;
import java.util.function.Function;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import lombok.AccessLevel;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.hc.core5.http.HttpEntity;

/**
* Internal utility class to convert from a reading handler to {@link Iterable} and {@link Stream}.
*
* <p><strong>Note:</strong> All operations are sequential in nature. Thread safety is not
* guaranteed.
*
* @param <T> Iterated item type.
*/
@Slf4j
@RequiredArgsConstructor(access = AccessLevel.PRIVATE)
class IterableStreamConverter<T> implements Iterator<T> {
/** see DEFAULT_CHAR_BUFFER_SIZE in {@link BufferedReader} * */
static final int BUFFER_SIZE = 8192;

/** Read next entry for Stream or {@code null} when no further entry can be read. */
private final Callable<T> readHandler;

/** Close handler to be called when Stream terminated. */
private final Runnable stopHandler;

/** Error handler to be called when Stream is interrupted. */
private final Function<Exception, RuntimeException> errorHandler;

private boolean isDone = false;
private boolean isNextFetched = false;
private T next = null;

@SuppressWarnings("checkstyle:IllegalCatch")
@Override
public boolean hasNext() {
if (isDone) {
return false;
}
if (isNextFetched) {
return true;
}
try {
next = readHandler.call();
isNextFetched = true;
if (next == null) {
isDone = true;
stopHandler.run();
}
} catch (final Exception e) {
isDone = true;
stopHandler.run();
log.debug("Error while reading next element.", e);
throw errorHandler.apply(e);
}
return !isDone;
}

@Override
public T next() {
if (next == null && !hasNext()) {
throw new NoSuchElementException(); // normally not reached with Stream API
}
isNextFetched = false;
return next;
}

/**
* Create a sequential Stream of lines from an HTTP response string (UTF-8). The underlying {@link
* InputStream} is closed, when the resulting Stream is closed (e.g. via try-with-resources) or
* when an exception occurred.
*
* @param entity The HTTP entity object.
* @return A sequential Stream object.
* @throws OrchestrationClientException if the provided HTTP entity object is {@code null} or
* empty.
*/
@SuppressWarnings("PMD.CloseResource") // Stream is closed automatically when consumed
@Nonnull
static Stream<String> lines(@Nullable final HttpEntity entity)
throws OrchestrationClientException {
if (entity == null) {
throw new OrchestrationClientException("Orchestration service response was empty.");
}

final InputStream inputStream;
try {
inputStream = entity.getContent();
} catch (final IOException e) {
throw new OrchestrationClientException("Failed to read response content.", e);
}

final var reader = new BufferedReader(new InputStreamReader(inputStream, UTF_8), BUFFER_SIZE);
final Runnable closeHandler =
() -> Try.run(reader::close).onFailure(e -> log.error("Could not close input stream", e));
final Function<Exception, RuntimeException> errHandler =
e -> new OrchestrationClientException("Parsing response content was interrupted.", e);

final var iterator = new IterableStreamConverter<>(reader::readLine, closeHandler, errHandler);
final var spliterator = Spliterators.spliteratorUnknownSize(iterator, ORDERED | NONNULL);
return StreamSupport.stream(spliterator, /* NOT PARALLEL */ false).onClose(closeHandler);
}
}
@@ -0,0 +1,44 @@
package com.sap.ai.sdk.orchestration;

import com.sap.ai.sdk.orchestration.model.CompletionPostResponse;
import com.sap.ai.sdk.orchestration.model.LLMModuleResultSynchronous;
import java.util.Map;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import lombok.val;

/** Orchestration chat completion output delta for streaming. */
public class OrchestrationChatCompletionDelta extends CompletionPostResponse
implements StreamedDelta {

@Nonnull
@Override
// will be fixed once the generated code add a discriminator which will allow this class to extend
// CompletionPostResponseStreaming
@SuppressWarnings("unchecked")
public String getDeltaContent() {
val choices = ((LLMModuleResultSynchronous) getOrchestrationResult()).getChoices();
// Avoid the first delta: "choices":[]
if (!choices.isEmpty()
// Multiple choices are spread out on multiple deltas
// A delta only contains one choice with a variable index
&& choices.get(0).getIndex() == 0) {

final var message = (Map<String, Object>) choices.get(0).getCustomField("delta");
// Avoid the second delta: "choices":[{"delta":{"content":"","role":"assistant"}}]
if (message != null && message.get("content") != null) {
return message.get("content").toString();
}
}
return "";
}

@Nullable
@Override
public String getFinishReason() {
return ((LLMModuleResultSynchronous) getOrchestrationResult())
.getChoices()
.get(0)
.getFinishReason();
}
}
@@ -22,12 +22,14 @@
import com.sap.cloud.sdk.cloudplatform.connectivity.exception.HttpClientInstantiationException;
import java.io.IOException;
import java.util.function.Supplier;
import java.util.stream.Stream;
import javax.annotation.Nonnull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.hc.client5.http.classic.methods.HttpPost;
import org.apache.hc.core5.http.ContentType;
import org.apache.hc.core5.http.io.entity.StringEntity;
import org.apache.hc.core5.http.message.BasicClassicHttpRequest;

/** Client to execute requests to the orchestration service. */
@Slf4j
@@ -105,6 +107,33 @@ public OrchestrationChatResponse chatCompletion(
return new OrchestrationChatResponse(executeRequest(request));
}

/**
* Generate a completion for the given prompt.
*
* @param prompt a text message.
* @param config the orchestration module configuration to use.
* @return A stream of message deltas
* @throws OrchestrationClientException if the request fails or if the finish reason is
* content_filter
* @since 1.1.0
*/
@Nonnull
public Stream<String> streamChatCompletion(

Review comment (Contributor): New method names match the existing OpenAiClient counterpart in the openai module.

@Nonnull final OrchestrationPrompt prompt, @Nonnull final OrchestrationModuleConfig config)
throws OrchestrationClientException {

val request = toCompletionPostRequest(prompt, config);
return streamChatCompletionDeltas(request)
.peek(OrchestrationClient::throwOnContentFilter)
.map(OrchestrationChatCompletionDelta::getDeltaContent);
}

private static void throwOnContentFilter(@Nonnull final OrchestrationChatCompletionDelta delta) {
final String finishReason = delta.getFinishReason();
if (finishReason != null && finishReason.equals("content_filter")) {
throw new OrchestrationClientException("Content filter filtered the output.");
}
}

/**
* Serializes the given request, executes it and deserializes the response.
*
@@ -205,4 +234,53 @@ CompletionPostResponse executeRequest(@Nonnull final String request) {
throw new OrchestrationClientException("Failed to execute request", e);
}
}

/**
* Generate a completion for the given prompt.
*
* @param request the prompt, including messages and other parameters.
* @return A stream of chat completion delta elements.
* @throws OrchestrationClientException if the request fails
* @since 1.1.0
*/
@Nonnull
public Stream<OrchestrationChatCompletionDelta> streamChatCompletionDeltas(
@Nonnull final CompletionPostRequest request) throws OrchestrationClientException {
request.getOrchestrationConfig().setStream(true);
return executeStream("/completion", request, OrchestrationChatCompletionDelta.class);
}

@Nonnull
private <D extends StreamedDelta> Stream<D> executeStream(
@Nonnull final String path,
@Nonnull final Object payload,
@Nonnull final Class<D> deltaType) {
final var request = new HttpPost(path);
serializeAndSetHttpEntity(request, payload);
return streamRequest(request, deltaType);
}

private static void serializeAndSetHttpEntity(
@Nonnull final BasicClassicHttpRequest request, @Nonnull final Object payload) {
try {
final var json = JACKSON.writeValueAsString(payload);
request.setEntity(new StringEntity(json, ContentType.APPLICATION_JSON));
} catch (final JsonProcessingException e) {
throw new OrchestrationClientException("Failed to serialize request parameters", e);
}
}

@Nonnull
private <D extends StreamedDelta> Stream<D> streamRequest(
final BasicClassicHttpRequest request, @Nonnull final Class<D> deltaType) {
try {
val destination = destinationSupplier.get();
log.debug("Using destination {} to connect to orchestration service", destination);
val client = ApacheHttpClient5Accessor.getHttpClient(destination);
return new OrchestrationStreamingHandler<>(deltaType)
.handleResponse(client.executeOpen(null, request, null));
} catch (final IOException e) {
throw new OrchestrationClientException("Request to the Orchestration service failed", e);
}
}
}
@@ -0,0 +1,55 @@
package com.sap.ai.sdk.orchestration;

import static com.sap.ai.sdk.orchestration.OrchestrationClient.JACKSON;
import static com.sap.ai.sdk.orchestration.OrchestrationResponseHandler.buildExceptionAndThrow;
import static com.sap.ai.sdk.orchestration.OrchestrationResponseHandler.parseErrorAndThrow;

import java.io.IOException;
import java.util.stream.Stream;
import javax.annotation.Nonnull;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.hc.core5.http.ClassicHttpResponse;

@Slf4j
@RequiredArgsConstructor
class OrchestrationStreamingHandler<D extends StreamedDelta> {

@Nonnull private final Class<D> deltaType;

/**
* @param response The response to process
* @return A {@link Stream} of a model class instantiated from the response
*/
@SuppressWarnings("PMD.CloseResource") // Stream is closed automatically when consumed
@Nonnull
Stream<D> handleResponse(@Nonnull final ClassicHttpResponse response)
throws OrchestrationClientException {
if (response.getCode() >= 300) {
buildExceptionAndThrow(response);
}
return IterableStreamConverter.lines(response.getEntity())
// half of the lines are empty newlines, the last line is "data: [DONE]"
.peek(line -> log.info("Handler: {}", line))

Review comment (Contributor): I'm surprised we have so many info log statements. I feel like we could define a coding guideline when to use what log level. (No action required.)

.filter(line -> !line.isEmpty() && !"data: [DONE]".equals(line.trim()))
.peek(
line -> {
if (!line.startsWith("data: ")) {
final String msg = "Failed to parse response from the Orchestration service";
parseErrorAndThrow(line, new OrchestrationClientException(msg));
}
})
.map(
line -> {
final String data = line.substring(5); // remove "data:"; the JSON parser skips the remaining leading space
try {
return JACKSON.readValue(data, deltaType);
} catch (final IOException e) { // exception message e gets lost
log.error(
"Failed to parse the following response from the Orchestration service: {}",
line);
throw new OrchestrationClientException("Failed to parse delta message: " + line, e);
}
});
}
}
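
The line handling in the handler above can be illustrated without the SDK or Jackson. The following is a minimal sketch, assuming the response consists of server-sent-event style lines as described in the code comments (payload lines prefixed with `data: `, empty separator lines, and a final `data: [DONE]`). The class `SseLineSketch` and the method `payloads` are hypothetical names, and the sketch throws a plain `IllegalStateException` where the handler uses `OrchestrationClientException`.

```java
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class SseLineSketch {
  // The prefix is 6 characters long: "data:" plus one space.
  private static final String PREFIX = "data: ";

  // Drops empty lines and the terminating "data: [DONE]" marker,
  // then strips the prefix from every remaining line.
  static List<String> payloads(Stream<String> lines) {
    return lines
        .filter(line -> !line.isEmpty() && !"data: [DONE]".equals(line.trim()))
        .map(
            line -> {
              if (!line.startsWith(PREFIX)) {
                // Anything else is treated as a malformed response line.
                throw new IllegalStateException("Unexpected line: " + line);
              }
              return line.substring(PREFIX.length());
            })
        .collect(Collectors.toList());
  }

  public static void main(String[] args) {
    Stream<String> raw =
        Stream.of("data: {\"a\":1}", "", "data: {\"a\":2}", "", "data: [DONE]");
    System.out.println(payloads(raw));
  }
}
```

Running `main` prints `[{"a":1}, {"a":2}]`. In the real handler each remaining payload is then deserialized into the delta type instead of being collected as a string.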