diff --git a/src/main/java/org/devlive/sdk/common/DefaultClient.java b/src/main/java/org/devlive/sdk/common/DefaultClient.java index 0e01c2f..ae1c930 100644 --- a/src/main/java/org/devlive/sdk/common/DefaultClient.java +++ b/src/main/java/org/devlive/sdk/common/DefaultClient.java @@ -266,7 +266,7 @@ public AssistantsEntity createAssistants(AssistantsEntity configure) } public AssistantsFileEntity createAssistantsFile(String fileId, - String assistantId) + String assistantId) { String url = String.format(ProviderUtils.getUrl(provider, UrlModel.FETCH_ASSISTANTS_FILES), assistantId); Map configure = Maps.newHashMap(); diff --git a/src/main/java/org/devlive/sdk/openai/OpenAiClient.java b/src/main/java/org/devlive/sdk/openai/OpenAiClient.java index ccb508b..6858ce5 100644 --- a/src/main/java/org/devlive/sdk/openai/OpenAiClient.java +++ b/src/main/java/org/devlive/sdk/openai/OpenAiClient.java @@ -185,6 +185,12 @@ public OpenAiClientBuilder model(CompletionModel model) return this; } + public OpenAiClientBuilder model(String model) + { + this.model = model; + return this; + } + private String getDefaultHost() { if (ObjectUtils.isEmpty(this.provider)) { diff --git a/src/main/java/org/devlive/sdk/openai/entity/ChatEntity.java b/src/main/java/org/devlive/sdk/openai/entity/ChatEntity.java index 4bfd366..09787fb 100644 --- a/src/main/java/org/devlive/sdk/openai/entity/ChatEntity.java +++ b/src/main/java/org/devlive/sdk/openai/entity/ChatEntity.java @@ -13,6 +13,9 @@ import org.devlive.sdk.openai.utils.EnumsUtils; import java.util.List; +import java.util.Objects; + +import static org.devlive.sdk.openai.model.CompletionModel.GPT_35_TURBO; @Data @Builder @@ -47,7 +50,7 @@ public class ChatEntity private ChatEntity(ChatEntityBuilder builder) { if (ObjectUtils.isEmpty(builder.model)) { - builder.model(CompletionModel.GPT_35_TURBO); + builder.model(GPT_35_TURBO); } this.model = builder.model; this.messages = builder.messages; @@ -73,7 +76,7 @@ public static class ChatEntityBuilder public ChatEntityBuilder model(CompletionModel model) { if (ObjectUtils.isEmpty(model)) { - model = CompletionModel.GPT_35_TURBO; + model = GPT_35_TURBO; } switch (model) { case GPT_35_TURBO: @@ -96,6 +99,12 @@ public ChatEntityBuilder model(CompletionModel model) return this; } + public ChatEntityBuilder model(String model) + { + this.model = model; + return this; + } + public ChatEntityBuilder temperature(Double temperature) { if (temperature < 0 || temperature > 2) { @@ -108,11 +117,20 @@ public ChatEntityBuilder temperature(Double temperature) public ChatEntityBuilder maxTokens(Integer maxTokens) { CompletionModel completionModel = EnumsUtils.getCompleteModel(this.model); - if (ObjectUtils.isNotEmpty(this.model) && maxTokens > completionModel.getMaxTokens()) { - throw new ParamException(String.format("Invalid maxTokens: %s, Cannot be larger than the model default configuration %s", maxTokens, completionModel.getMaxTokens())); + if (Objects.isNull(completionModel)) { + this.maxTokens = maxTokens; + return this; + } + else { + if (ObjectUtils.isNotEmpty(this.model) + && maxTokens > completionModel.getMaxTokens()) { + throw new ParamException(String.format( + "Invalid maxTokens: %s, Cannot be larger than the model default configuration %s", + maxTokens, completionModel.getMaxTokens())); + } + this.maxTokens = maxTokens; + return this; } - this.maxTokens = maxTokens; - return this; } private ChatEntityBuilder stream() diff --git a/src/test/java/org/devlive/sdk/openai/OpenAiClientTest.java b/src/test/java/org/devlive/sdk/openai/OpenAiClientTest.java index cd19897..5266e08 100644 --- a/src/test/java/org/devlive/sdk/openai/OpenAiClientTest.java +++ b/src/test/java/org/devlive/sdk/openai/OpenAiClientTest.java @@ -15,6 +15,7 @@ import org.devlive.sdk.common.exception.RequestException; import org.devlive.sdk.openai.model.CompletionModel; import org.devlive.sdk.openai.model.EditModel; +import org.devlive.sdk.openai.response.ChatResponse; import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -94,6 +95,30 @@ public void testCreateCompletion() Assert.assertTrue(client.createCompletion(configure).getChoices().size() > 0); } + @Test + public void testCustomizedModel() + { + client = OpenAiClient.builder() + .apiHost(System.getProperty("proxy.host")) + .apiKey(System.getProperty("openai.token")) + .model("text-davinci-003") + .build(); + + List messages = Lists.newArrayList(); + messages.add(MessageEntity.builder() + .content("Hello, please show me a jok!") + .build()); + + ChatEntity configure = ChatEntity.builder() + .messages(messages) + .model("text-davinci-003") + .build(); + ChatResponse chatCompletion = client.createChatCompletion(configure); + String content = chatCompletion.getChoices().get(0).getMessage().getContent(); + // System.out.println(content); + Assert.assertNotNull(content); + } + @Test public void testCreateChatCompletion() {