Skip to content

Commit

Permalink
[Orchestration] OSS Generator Improvements (#158)
Browse files Browse the repository at this point in the history
* Simplifications and cleanup

* Remove parsing for streaming
  • Loading branch information
MatKuhr authored Nov 14, 2024
1 parent 67840db commit 5b4cdf6
Show file tree
Hide file tree
Showing 29 changed files with 262 additions and 2,616 deletions.
4 changes: 0 additions & 4 deletions orchestration/.openapi-generator-ignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,3 @@ settings.gradle
src/main/AndroidManifest.xml
api/
.openapi-generator/


src/main/java/com/sap/ai/sdk/orchestration/client/model/LLMModuleResult.java
src/main/java/com/sap/ai/sdk/orchestration/client/model/ModuleResultsOutputUnmaskingInner.java
151 changes: 87 additions & 64 deletions orchestration/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,14 @@
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
</dependency>
<dependency>
<groupId>javax.annotation</groupId>
<artifactId>javax.annotation-api</artifactId>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
</dependency>

<!-- TODO: only needed for JsonObjectMapperBuilder, maybe we can use Jackson natively to avoid this dependency -->
<dependency>
Expand Down Expand Up @@ -118,73 +126,88 @@
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>

<!-- Additional dependencies for OSS code generator resttemplate -->
<dependency>
<groupId>org.openapitools</groupId>
<artifactId>jackson-databind-nullable</artifactId>
</dependency>
<dependency>
<groupId>javax.annotation</groupId>
<artifactId>javax.annotation-api</artifactId>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
</dependency>
</dependencies>

<build>
<plugins>
<plugin>
<groupId>org.openapitools</groupId>
<artifactId>openapi-generator-maven-plugin</artifactId>
<executions>
<execution>
<goals>
<goal>generate</goal>
</goals>
<phase>generate-sources</phase>
<configuration>
<!-- Specify the input OpenAPI spec file -->
<inputSpec>${project.basedir}/src/main/resources/spec/orchestration.yaml</inputSpec>
<output>${project.basedir}</output>
<!-- Specify the generator to use, e.g., java, spring, kotlin, etc. -->
<generatorName>java</generatorName>
<!-- Specify the package names for models, APIs, and invokers -->
<modelPackage>com.sap.ai.sdk.orchestration.client.model</modelPackage>
<apiPackage>com.sap.ai.sdk.orchestration.client.api</apiPackage>
<invokerPackage>com.sap.ai.sdk.orchestration.client.invoker</invokerPackage>
<profiles>
<profile>
<id>generate</id>
<activation>
<activeByDefault>false</activeByDefault>
<property>
<name>generate</name>
</property>
</activation>
<build>
<plugins>
<plugin>
<groupId>org.openapitools</groupId>
<artifactId>openapi-generator-maven-plugin</artifactId>
<executions>
<execution>
<goals>
<goal>generate</goal>
</goals>
<phase>generate-sources</phase>
<configuration>
<!-- Specify the input OpenAPI spec file -->
<inputSpec>${project.basedir}/src/main/resources/spec/orchestration.yaml</inputSpec>
<output>${project.basedir}</output>
<!-- Specify the generator to use, e.g., java, spring, kotlin, etc. -->
<generatorName>java</generatorName>
<!-- Specify the package names for models, APIs, and invokers -->
<modelPackage>com.sap.ai.sdk.orchestration.client.model</modelPackage>
<apiPackage>com.sap.ai.sdk.orchestration.client.api</apiPackage>
<invokerPackage>com.sap.ai.sdk.orchestration.client.invoker</invokerPackage>

<!-- Global properties level; can be unpacked with 'generate' prefix-->
<globalProperties>
<apiDocs>false</apiDocs>
<modelDocs>false</modelDocs>
<modelTests>false</modelTests>
<apiTests>false</apiTests>
<minimalUpdate>true</minimalUpdate>
</globalProperties>
<!-- Global properties level; can be unpacked with 'generate' prefix-->
<globalProperties>
<apiDocs>false</apiDocs>
<modelDocs>false</modelDocs>
<modelTests>false</modelTests>
<apiTests>false</apiTests>
<minimalUpdate>true</minimalUpdate>
</globalProperties>

<!-- Generator Specific properties level; some can be unpacked-->
<configOptions>
<generateBuilders>true</generateBuilders>
<failOnUnknownProperties>false</failOnUnknownProperties>
<hideGenerationTimestamp>true</hideGenerationTimestamp>
<disallowAdditionalPropertiesIfNotPresent>false</disallowAdditionalPropertiesIfNotPresent>
<enumUnknownDefaultCase>true</enumUnknownDefaultCase>
<useBeanValidation>false</useBeanValidation>
<useOneOfInterfaces>true</useOneOfInterfaces>
<additionalModelTypeAnnotations>@com.google.common.annotations.Beta</additionalModelTypeAnnotations>
</configOptions>
<generateModels>true</generateModels>
<generateSupportingFiles>false</generateSupportingFiles>
<generateApis>false</generateApis>
<library>resttemplate</library>
<!--<configHelp>true</configHelp>-->
<!-- Generator Specific properties level; some can be unpacked-->
<configOptions>
<generateBuilders>true</generateBuilders>
<failOnUnknownProperties>false</failOnUnknownProperties>
<hideGenerationTimestamp>true</hideGenerationTimestamp>
<disallowAdditionalPropertiesIfNotPresent>false</disallowAdditionalPropertiesIfNotPresent>
<enumUnknownDefaultCase>true</enumUnknownDefaultCase>
<useBeanValidation>false</useBeanValidation>
<useOneOfInterfaces>true</useOneOfInterfaces>
<additionalModelTypeAnnotations>@com.google.common.annotations.Beta</additionalModelTypeAnnotations>
</configOptions>
<generateModels>true</generateModels>
<generateSupportingFiles>false</generateSupportingFiles>
<generateApis>false</generateApis>
<library>resttemplate</library>
<!--<configHelp>true</configHelp>-->
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<artifactId>maven-clean-plugin</artifactId>
<configuration>
<filesets>
<fileset>
<directory>${project.basedir}/src/main/java/com/sap/ai/sdk/orchestration/client</directory>
<includes>
<include>**/*</include>
</includes>
</fileset>
</filesets>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
<executions>
<execution>
<id>delete-orchestration-generated-client</id>
</execution>
</executions>
</plugin>
</plugins>
</build>
</profile>
</profiles>
</project>
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,12 @@ static TemplatingModuleConfig toTemplateModuleConfig(
* In this case, the request will fail, since the templating module will try to resolve the parameter.
* To be fixed with https://github.tools.sap/AI/llm-orchestration/issues/662
*/
val messages = Option.of(template).map(t -> ((Template) t).getTemplate()).getOrElse(List::of);
val messages =
Option.of(template)
.filter(Template.class::isInstance)
.map(Template.class::cast)
.map(Template::getTemplate)
.getOrElse(List::of);
val messagesWithPrompt = new ArrayList<>(messages);
messagesWithPrompt.addAll(prompt.getMessages());
if (messagesWithPrompt.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,82 +1,45 @@
package com.sap.ai.sdk.orchestration;

import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.*;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.deser.std.StdDeserializer;
import com.sap.ai.sdk.orchestration.client.model.LLMModuleResult;
import com.sap.ai.sdk.orchestration.client.model.LLMModuleResultStreaming;
import com.sap.ai.sdk.orchestration.client.model.LLMModuleResultSynchronous;
import java.io.IOException;
import java.io.Serial;
import javax.annotation.Nonnull;
import lombok.val;

/**
* A deserializer for {@link LLMModuleResult} that determines the concrete implementation based on
* the structure of the JSON object.
*/
public class LLMModuleResultDeserializer extends StdDeserializer<LLMModuleResult> {
class LLMModuleResultDeserializer extends StdDeserializer<LLMModuleResult> {
// checkstyle requires a serialVersionUid since StdDeserializer implements Serializable
@Serial private static final long serialVersionUID = 1L;

public LLMModuleResultDeserializer() {
/** Default constructor. */
LLMModuleResultDeserializer() {
super(LLMModuleResult.class);
}

/**
* Deserialize the JSON object into one of the subtypes of the base type.
*
* <ul>
* <li>If elements of "choices" array contains "delta", deserialize into {@link
* LLMModuleResultStreaming}.
* <li>Otherwise, deserialize into {@link LLMModuleResultSynchronous}.
* </ul>
* Always deserialize into {@link LLMModuleResultSynchronous} since streaming isn't supported yet.
*
* @param parser The JSON parser.
* @param context The deserialization context.
* @return The deserialized object.
* @throws IOException If an I/O error occurs.
*/
@Nonnull
@Override
public LLMModuleResult deserialize(JsonParser parser, @Nonnull DeserializationContext context)
public LLMModuleResult deserialize(
@Nonnull final JsonParser parser, @Nonnull final DeserializationContext context)
throws IOException {
val mapper = (ObjectMapper) parser.getCodec();
val rootNode = mapper.readTree(parser);

// Check if the target type is a concrete class
JavaType targetType = context.getContextualType();
if (targetType != null && !LLMModuleResult.class.equals(targetType.getRawClass())) {
return delegateToDefaultDeserializer(parser, context, targetType);
}

// Custom deserialization logic for LLMModuleResult interface
var mapper = (ObjectMapper) parser.getCodec();
var rootNode = mapper.readTree(parser);
Class<? extends LLMModuleResult> concreteClass = LLMModuleResultSynchronous.class;

// Inspect the "choices" field
var choicesNode = rootNode.get("choices");
if (choicesNode != null && choicesNode.isArray()) {
var firstChoice = (JsonNode) choicesNode.get(0);
if (firstChoice != null && firstChoice.has("delta")) {
concreteClass = LLMModuleResultStreaming.class;
}
}

// Create a new parser for the root node
var rootParser = rootNode.traverse(mapper);
rootParser.nextToken(); // Advance to the first token

// Use the default deserializer for the concrete class
return delegateToDefaultDeserializer(rootParser, context, mapper.constructType(concreteClass));
}

/**
* Delegate deserialization to the default deserializer for the given concrete type.
*
* @param parser The JSON parser.
* @param context The deserialization context.
* @param concreteType The concrete type to deserialize into.
* @return The deserialized object.
* @throws IOException If an I/O error occurs.
*/
private LLMModuleResult delegateToDefaultDeserializer(
JsonParser parser, DeserializationContext context, JavaType concreteType) throws IOException {
var defaultDeserializer = context.findRootValueDeserializer(concreteType);
return (LLMModuleResult) defaultDeserializer.deserialize(parser, context);
return mapper.readValue(rootNode.toString(), LLMModuleResultSynchronous.class);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,11 @@
import com.sap.ai.sdk.orchestration.client.model.CompletionPostRequest;
import com.sap.ai.sdk.orchestration.client.model.CompletionPostResponse;
import com.sap.ai.sdk.orchestration.client.model.FilterConfig;
import com.sap.ai.sdk.orchestration.client.model.GroundingModuleConfigConfigFiltersInner;
import com.sap.ai.sdk.orchestration.client.model.LLMModuleResult;
import com.sap.ai.sdk.orchestration.client.model.MaskingProviderConfig;
import com.sap.ai.sdk.orchestration.client.model.ModuleConfigs;
import com.sap.ai.sdk.orchestration.client.model.ModuleResultsOutputUnmaskingInner;
import com.sap.ai.sdk.orchestration.client.model.OrchestrationConfig;
import com.sap.ai.sdk.orchestration.client.model.TemplateRefTemplateRef;
import com.sap.ai.sdk.orchestration.client.model.TemplatingModuleConfig;
import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor;
import com.sap.cloud.sdk.cloudplatform.connectivity.exception.DestinationAccessException;
Expand Down Expand Up @@ -47,10 +46,10 @@ public class OrchestrationClient {
.visibility(PropertyAccessor.SETTER, JsonAutoDetect.Visibility.NONE)
.serializationInclusion(JsonInclude.Include.NON_NULL)
.deserializerByType(LLMModuleResult.class, new LLMModuleResultDeserializer())
.mixIn(LLMModuleResult.class, NoTypeInfoMixin.class)
.mixIn(ModuleResultsOutputUnmaskingInner.class, NoTypeInfoMixin.class)
.mixIn(FilterConfig.class, NoTypeInfoMixin.class)
.mixIn(GroundingModuleConfigConfigFiltersInner.class, NoTypeInfoMixin.class)
.mixIn(MaskingProviderConfig.class, NoTypeInfoMixin.class)
.mixIn(TemplateRefTemplateRef.class, NoTypeInfoMixin.class)
.mixIn(TemplatingModuleConfig.class, NoTypeInfoMixin.class)
.build();
}
Expand Down
Loading

0 comments on commit 5b4cdf6

Please sign in to comment.