Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Orchestration Convenience] Typed model parameters #180

Merged
merged 13 commits into from
Nov 29, 2024
12 changes: 5 additions & 7 deletions docs/guides/ORCHESTRATION_CHAT_COMPLETION.md
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,10 @@ Change your LLM configuration to add model parameters:
```java
OrchestrationAiModel customGPT4O =
    OrchestrationAiModel.GPT_4O
        .withParam(MAX_TOKENS, 50)
        .withParam(TEMPERATURE, 0.1)
        .withParam(FREQUENCY_PENALTY, 0)
        .withParam(PRESENCE_PENALTY, 0)
        .withVersion("2024-05-13");
```

Expand All @@ -225,4 +223,4 @@ var prompt = new OrchestrationPrompt(Map.of("your-input-parameter", "your-param-
new OrchestrationClient().executeRequestFromJsonModuleConfig(prompt, configJson);
```

While this is not recommended for long term use, it can be useful for creating demos and PoCs.
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package com.sap.ai.sdk.orchestration;

import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig;
import java.util.LinkedHashMap;
import java.util.Map;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import lombok.AllArgsConstructor;
import lombok.RequiredArgsConstructor;
import lombok.Value;
import lombok.With;

Expand Down Expand Up @@ -114,4 +117,58 @@ public class OrchestrationAiModel {
// Assembles the low-level LLM module configuration from this model's name, params, and version.
LLMModuleConfig createConfig() {
  final var config = LLMModuleConfig.create();
  return config.modelName(name).modelParams(params).modelVersion(version);
}

/**
 * Returns a copy of this model with one additional model parameter set.
 *
 * @param key the parameter key.
 * @param value the parameter value, nullable.
 * @return A new model with the additional parameter.
 */
@Nonnull
public OrchestrationAiModel withParam(@Nonnull final String key, @Nullable final Object value) {
  // Copy into a LinkedHashMap to keep insertion order and stay immutable from the caller's view.
  final Map<String, Object> updatedParams = new LinkedHashMap<>(getParams());
  updatedParams.put(key, value);
  return withParams(updatedParams);
}

/**
 * Returns a copy of this model with one additional model parameter set, using a typed key.
 *
 * @param param the parameter key.
 * @param value the parameter value, nullable.
 * @return A new model with the additional parameter.
 */
@Nonnull
public OrchestrationAiModel withParam(
    @Nonnull final Parameter param, @Nullable final Object value) {
  // Delegate to the String overload with the enum's raw key.
  final String key = param.value;
  return withParam(key, value);
}

/** Parameter key for a model. */
@RequiredArgsConstructor
public enum Parameter {
/** The maximum number of tokens to generate. */
MAX_TOKENS("max_tokens"),
newtork marked this conversation as resolved.
Show resolved Hide resolved
/** The sampling temperature. */
TEMPERATURE("temperature"),
/** The frequency penalty. */
FREQUENCY_PENALTY("frequency_penalty"),
/** The presence penalty. */
PRESENCE_PENALTY("presence_penalty"),
/** The maximum number of tokens for completion */
MAX_COMPLETION_TOKENS("max_completion_tokens"),
/** The probability mass to be considered . */
TOP_P("top_p"),
/** The toggle to enable partial message delta. */
STREAM("stream"),
/** The options for streaming response. */
STREAM_OPTIONS("stream_options"),
/** The tokens where the API will stop generating further tokens. */
STOP("stop"),
newtork marked this conversation as resolved.
Show resolved Hide resolved
/** The number of chat completion choices to generate for each input message. */
N("n");

private final String value;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import static com.sap.ai.sdk.orchestration.AzureFilterThreshold.ALLOW_SAFE_LOW_MEDIUM;
import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.GPT_4O;
import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.Parameter.MAX_TOKENS;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

Expand Down Expand Up @@ -65,6 +66,44 @@ void testDpiMaskingConfig() {
.hasSize(1);
}

@Test
void testParams() {
  // withParams(Map<String, Object>): value equality of the parameter map drives model equality
  {
    var params = Map.<String, Object>of("foo", "bar", "fizz", "buzz");

    var original = GPT_4O.withParams(params);
    var sameParams = original.withParams(params);
    assertThat(original).isEqualTo(sameParams);

    var fewerParams = original.withParams(Map.of("foo", "bar"));
    assertThat(original).isNotEqualTo(fewerParams);

    var changedValue = original.withParams(Map.of("foo", "bazz"));
    assertThat(original).isNotEqualTo(changedValue);
  }

  // withParam(String, Object): re-applying the same key/value is idempotent
  {
    var original = GPT_4O.withParam("foo", "bar");
    var sameParam = original.withParam("foo", "bar");
    assertThat(original).isEqualTo(sameParam);

    var changedValue = original.withParam("foo", "bazz");
    assertThat(original).isNotEqualTo(changedValue);
  }

  // withParam(Parameter, Object): typed keys behave like their string counterparts
  {
    var original = GPT_4O.withParam(MAX_TOKENS, 10);
    var sameParam = original.withParam(MAX_TOKENS, 10);
    assertThat(original).isEqualTo(sameParam);

    var changedValue = original.withParam(MAX_TOKENS, 20);
    assertThat(original).isNotEqualTo(changedValue);
  }
}

@Test
void testLLMConfig() {
Map<String, Object> params = Map.of("foo", "bar");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static com.sap.ai.sdk.orchestration.AzureFilterThreshold.ALLOW_SAFE;
import static com.sap.ai.sdk.orchestration.AzureFilterThreshold.ALLOW_SAFE_LOW_MEDIUM;
import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.GPT_35_TURBO_16K;
import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.Parameter.*;
import static org.apache.hc.core5.http.HttpStatus.SC_BAD_REQUEST;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
Expand Down Expand Up @@ -52,12 +53,12 @@
@WireMockTest
class OrchestrationUnitTest {
static final OrchestrationAiModel CUSTOM_GPT_35 =
GPT_35_TURBO_16K.withParams(
Map.of(
"max_tokens", 50,
"temperature", 0.1,
"frequency_penalty", 0,
"presence_penalty", 0));
GPT_35_TURBO_16K
.withParam(MAX_TOKENS, 50)
.withParam(TEMPERATURE, 0.1)
.withParam(FREQUENCY_PENALTY, 0)
.withParam(PRESENCE_PENALTY, 0);

// Loads a named test resource from the classpath; requireNonNull fails fast if it is missing.
private final Function<String, InputStream> fileLoader =
filename -> Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream(filename));

Expand Down