diff --git a/docs/guides/ORCHESTRATION_CHAT_COMPLETION.md b/docs/guides/ORCHESTRATION_CHAT_COMPLETION.md index 552be328d..01bc97477 100644 --- a/docs/guides/ORCHESTRATION_CHAT_COMPLETION.md +++ b/docs/guides/ORCHESTRATION_CHAT_COMPLETION.md @@ -200,11 +200,11 @@ Change your LLM configuration to add model parameters: ```java OrchestrationAiModel customGPT4O = OrchestrationAiModel.GPT_4O - .withModelParams( + .withParams( Map.of( "max_tokens", 50, "temperature", 0.1, "frequency_penalty", 0, "presence_penalty", 0)) - .withModelVersion("2024-05-13"); + .withVersion("2024-05-13"); ``` diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationAiModel.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationAiModel.java index 7f806ece7..0342dd82c 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationAiModel.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationAiModel.java @@ -13,7 +13,7 @@ @AllArgsConstructor public class OrchestrationAiModel { /** The name of the model */ - String modelName; + String name; /** * Optional parameters on this model. @@ -26,10 +26,10 @@ public class OrchestrationAiModel { * "presence_penalty", 0) * } */ - Map<String, Object> modelParams; + Map<String, Object> params; /** The version of the model, defaults to "latest". 
*/ - String modelVersion; + String version; /** IBM Granite 13B chat completions model */ public static final OrchestrationAiModel IBM_GRANITE_13B_CHAT = @@ -106,15 +106,12 @@ public class OrchestrationAiModel { public static final OrchestrationAiModel GEMINI_1_5_FLASH = new OrchestrationAiModel("gemini-1.5-flash"); - OrchestrationAiModel(@Nonnull final String modelName) { - this(modelName, Map.of(), "latest"); + OrchestrationAiModel(@Nonnull final String name) { + this(name, Map.of(), "latest"); } @Nonnull LLMModuleConfig createConfig() { - return new LLMModuleConfig() - .modelName(modelName) - .modelParams(modelParams) - .modelVersion(modelVersion); + return new LLMModuleConfig().modelName(name).modelParams(params).modelVersion(version); } } diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationModuleConfigTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationModuleConfigTest.java index 137598cc5..1d0c0527c 100644 --- a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationModuleConfigTest.java +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationModuleConfigTest.java @@ -69,18 +69,16 @@ void testDpiMaskingConfig() { void testLLMConfig() { Map<String, Object> params = Map.of("foo", "bar"); String version = "2024-05-13"; - OrchestrationAiModel aiModel = GPT_4O.withModelParams(params).withModelVersion(version); + OrchestrationAiModel aiModel = GPT_4O.withParams(params).withVersion(version); var config = new OrchestrationModuleConfig().withLlmConfig(aiModel); assertThat(config.getLlmConfig()).isNotNull(); - assertThat(config.getLlmConfig().getModelName()).isEqualTo(GPT_4O.getModelName()); + assertThat(config.getLlmConfig().getModelName()).isEqualTo(GPT_4O.getName()); assertThat(config.getLlmConfig().getModelParams()).isEqualTo(params); assertThat(config.getLlmConfig().getModelVersion()).isEqualTo(version); - assertThat(GPT_4O.getModelParams()) - .withFailMessage("Static models should be 
unchanged") - .isEmpty(); - assertThat(GPT_4O.getModelVersion()) + assertThat(GPT_4O.getParams()).withFailMessage("Static models should be unchanged").isEmpty(); + assertThat(GPT_4O.getVersion()) .withFailMessage("Static models should be unchanged") .isEqualTo("latest"); } diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java index db785ab78..ca37f456b 100644 --- a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java @@ -52,7 +52,7 @@ @WireMockTest class OrchestrationUnitTest { static final OrchestrationAiModel CUSTOM_GPT_35 = - GPT_35_TURBO_16K.withModelParams( + GPT_35_TURBO_16K.withParams( Map.of( "max_tokens", 50, "temperature", 0.1,