Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Orchestration Convenience] Typed model parameters #180

Merged
merged 13 commits into from
Nov 29, 2024
6 changes: 1 addition & 5 deletions docs/guides/ORCHESTRATION_CHAT_COMPLETION.md
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,6 @@ Change your LLM configuration to add model parameters:
OrchestrationAiModel customGPT4O =
OrchestrationAiModel.GPT_4O
.withParams(
Map.of(
"max_tokens", 50,
"temperature", 0.1,
"frequency_penalty", 0,
"presence_penalty", 0))
params().maxTokens(50).temperature(0.1).frequencyPenalty(0).presencePenalty(0))
.withVersion("2024-05-13");
```
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,21 @@
import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig;
import java.util.Map;
import javax.annotation.Nonnull;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Value;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.With;
import lombok.experimental.Tolerate;

/** Large language models available in Orchestration. */
@Value
@With
@AllArgsConstructor
@Getter(AccessLevel.PACKAGE)
@EqualsAndHashCode
public class OrchestrationAiModel {
/** The name of the model */
String name;
private final String name;

/**
* Optional parameters on this model.
Expand All @@ -26,10 +30,10 @@ public class OrchestrationAiModel {
* "presence_penalty", 0)
* }</pre>
*/
Map<String, Object> params;
private final Map<String, Object> params;

/** The version of the model, defaults to "latest". */
String version;
private final String version;

/** IBM Granite 13B chat completions model */
public static final OrchestrationAiModel IBM_GRANITE_13B_CHAT =
Expand Down Expand Up @@ -76,23 +80,28 @@ public class OrchestrationAiModel {
new OrchestrationAiModel("amazon--titan-text-express");

/** Azure OpenAI GPT-3.5 Turbo chat completions model */
public static final OrchestrationAiModel GPT_35_TURBO = new OrchestrationAiModel("gpt-35-turbo");
public static final Parameterized<OrchestrationAiModelParameters.GPT> GPT_35_TURBO =
new Parameterized<>("gpt-35-turbo");

/** Azure OpenAI GPT-3.5 Turbo 16k chat completions model */
public static final OrchestrationAiModel GPT_35_TURBO_16K =
new OrchestrationAiModel("gpt-35-turbo-16k");
public static final Parameterized<OrchestrationAiModelParameters.GPT> GPT_35_TURBO_16K =
new Parameterized<>("gpt-35-turbo-16k");

/** Azure OpenAI GPT-4 chat completions model */
public static final OrchestrationAiModel GPT_4 = new OrchestrationAiModel("gpt-4");
public static final Parameterized<OrchestrationAiModelParameters.GPT> GPT_4 =
new Parameterized<>("gpt-4");

/** Azure OpenAI GPT-4-32k chat completions model */
public static final OrchestrationAiModel GPT_4_32K = new OrchestrationAiModel("gpt-4-32k");
public static final Parameterized<OrchestrationAiModelParameters.GPT> GPT_4_32K =
new Parameterized<>("gpt-4-32k");

/** Azure OpenAI GPT-4o chat completions model */
public static final OrchestrationAiModel GPT_4O = new OrchestrationAiModel("gpt-4o");
public static final Parameterized<OrchestrationAiModelParameters.GPT> GPT_4O =
new Parameterized<>("gpt-4o");

/** Azure OpenAI GPT-4o-mini chat completions model */
public static final OrchestrationAiModel GPT_4O_MINI = new OrchestrationAiModel("gpt-4o-mini");
public static final Parameterized<OrchestrationAiModelParameters.GPT> GPT_4O_MINI =
new Parameterized<>("gpt-4o-mini");

/** Google Cloud Platform Gemini 1.0 Pro model */
public static final OrchestrationAiModel GEMINI_1_0_PRO =
Expand All @@ -114,4 +123,28 @@ public class OrchestrationAiModel {
LLMModuleConfig createConfig() {
  // Start from an empty module config and copy this model's name, parameters and version into it.
  final LLMModuleConfig config = new LLMModuleConfig();
  return config.modelName(name).modelParams(params).modelVersion(version);
}

/**
* Subclass to allow for parameterized models.
*
* @param <T> The type of parameters for this model.
*/
public static final class Parameterized<T extends OrchestrationAiModelParameters>
newtork marked this conversation as resolved.
Show resolved Hide resolved
extends OrchestrationAiModel {
private Parameterized(@Nonnull final String name) {
super(name);
}

/**
* Set the typed parameters for this model.
*
* @param params The parameters for this model.
* @return The model with the parameters set.
*/
@Tolerate
@Nonnull
public OrchestrationAiModel withParams(@Nonnull final T params) {
return super.withParams(params.getParams());
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package com.sap.ai.sdk.orchestration;

import java.util.Map;
import javax.annotation.Nonnull;

/** Helper interface to define typed parameters. */
@FunctionalInterface
public interface OrchestrationAiModelParameters {
/**
* Get the parameters.
*
* @return the parameters.
*/
@Nonnull
Map<String, Object> getParams();

/** GPT model parameters. */
interface GPT extends OrchestrationAiModelParameters {
newtork marked this conversation as resolved.
Show resolved Hide resolved
/**
* Create a new builder.
*
* @return the builder.
*/
@Nonnull
static GPT.Builder0 params() {
newtork marked this conversation as resolved.
Show resolved Hide resolved
return maxTokens ->
temperature ->
frequencyPenalty ->
presencePenalty ->
() ->
Map.of(
"max_tokens", maxTokens,
newtork marked this conversation as resolved.
Show resolved Hide resolved
"temperature", temperature,
"frequency_penalty", frequencyPenalty,
"presence_penalty", presencePenalty);
}

/** Builder for GPT model parameters. */
interface Builder0 {
/**
* Set the max tokens.
*
* @param maxTokens the max tokens.
* @return the next builder.
*/
@Nonnull
GPT.Builder1 maxTokens(@Nonnull final Number maxTokens);
}

/** Builder for GPT model parameters. */
interface Builder1 {
/**
* Set the temperature.
*
* @param temperature the temperature.
* @return the next builder.
*/
@Nonnull
GPT.Builder2 temperature(@Nonnull final Number temperature);
}

/** Builder for GPT model parameters. */
interface Builder2 {
/**
* Set the frequency penalty.
*
* @param frequencyPenalty the frequency penalty.
* @return the next builder.
*/
@Nonnull
GPT.Builder3 frequencyPenalty(@Nonnull final Number frequencyPenalty);
}

/** Builder for GPT model parameters. */
interface Builder3 {
/**
* Set the presence penalty.
*
* @param presencePenalty the presence penalty.
* @return the final typed parameter object.
*/
@Nonnull
GPT presencePenalty(@Nonnull final Number presencePenalty);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static com.sap.ai.sdk.orchestration.AzureFilterThreshold.ALLOW_SAFE;
import static com.sap.ai.sdk.orchestration.AzureFilterThreshold.ALLOW_SAFE_LOW_MEDIUM;
import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.GPT_35_TURBO_16K;
import static com.sap.ai.sdk.orchestration.OrchestrationAiModelParameters.GPT.params;
import static org.apache.hc.core5.http.HttpStatus.SC_BAD_REQUEST;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
Expand Down Expand Up @@ -53,11 +54,8 @@
class OrchestrationUnitTest {
static final OrchestrationAiModel CUSTOM_GPT_35 =
GPT_35_TURBO_16K.withParams(
Map.of(
"max_tokens", 50,
"temperature", 0.1,
"frequency_penalty", 0,
"presence_penalty", 0));
params().maxTokens(50).temperature(0.1).frequencyPenalty(0).presencePenalty(0));

private final Function<String, InputStream> fileLoader =
filename -> Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream(filename));

Expand Down