diff --git a/.github/workflows/psalm.yml b/.github/workflows/psalm.yml index f7f3f436..efc6a1f3 100644 --- a/.github/workflows/psalm.yml +++ b/.github/workflows/psalm.yml @@ -39,7 +39,7 @@ jobs: fail-fast: false matrix: ocp-version: [ 'dev-master' ] - php-version: [ '8.0', '8.1', '8.2', '8.3' ] + php-version: [ '8.1', '8.2', '8.3' ] name: Psalm check on PHP ${{ matrix.php-version }} and OCP ${{ matrix.ocp-version }} diff --git a/appinfo/routes.php b/appinfo/routes.php index 5c29feae..dc657ce2 100644 --- a/appinfo/routes.php +++ b/appinfo/routes.php @@ -1,4 +1,5 @@ registerTaskProcessingProvider(ReformulateProvider::class); $context->registerTaskProcessingTaskType(ChangeToneTaskType::class); $context->registerTaskProcessingProvider(ChangeToneProvider::class); + if (class_exists('OCP\\TaskProcessing\\TaskTypes\\TextToTextChatWithTools')) { + $context->registerTaskProcessingProvider(\OCA\OpenAi\TaskProcessing\TextToTextChatWithToolsProvider::class); + } } if ($this->appConfig->getValueString(Application::APP_ID, 't2i_provider_enabled', '1') === '1') { $context->registerTaskProcessingProvider(TextToImageProvider::class); diff --git a/lib/Controller/ConfigController.php b/lib/Controller/ConfigController.php index 4d15ab3f..0be3d9f5 100644 --- a/lib/Controller/ConfigController.php +++ b/lib/Controller/ConfigController.php @@ -1,4 +1,5 @@ openAiAPIService->isUsingOpenAi() || $this->openAiSettingsService->getChatEndpointEnabled()) { $completion = $this->openAiAPIService->createChatCompletion($this->userId, $adminModel, $prompt, null, null, 1, 100); + $completion = $completion['messages']; } else { $completion = $this->openAiAPIService->createCompletion($this->userId, $prompt, 1, $adminModel, 100); } @@ -137,6 +138,7 @@ public function translate(?string $fromLanguage, string $toLanguage, string $tex if ($this->openAiAPIService->isUsingOpenAi() || $this->openAiSettingsService->getChatEndpointEnabled()) { $completion = $this->openAiAPIService->createChatCompletion($this->userId, $adminModel, $prompt, null, null, 1, PHP_INT_MAX); + $completion = $completion['messages']; } else { $completion = $this->openAiAPIService->createCompletion($this->userId, $prompt, 1, $adminModel, 4000); } diff --git a/lib/Service/OpenAiAPIService.php b/lib/Service/OpenAiAPIService.php index 54e821be..b33bb19a 100644 --- a/lib/Service/OpenAiAPIService.php +++ b/lib/Service/OpenAiAPIService.php @@ -1,4 +1,5 @@ > * @throws Exception */ public function createChatCompletion( ?string $userId, string $model, - string $userPrompt, + ?string $userPrompt = null, ?string $systemPrompt = null, ?array $history = null, int $n = 1, ?int $maxTokens = null, ?array $extraParams = null, + ?string $toolMessage = null, + ?array $tools = null, ): array { if ($this->isQuotaExceeded($userId, Application::QUOTA_TYPE_TEXT)) { throw new Exception($this->l10n->t('Text generation quota exceeded'), Http::STATUS_TOO_MANY_REQUESTS); @@ -384,21 +389,37 @@ public function createChatCompletion( } if ($history !== null) { foreach ($history as $i => $historyEntry) { - if (str_starts_with($historyEntry, 'system:')) { - $historyEntry = preg_replace('/^system:/', '', $historyEntry); - $messages[] = ['role' => 'system', 'content' => $historyEntry]; - } elseif (str_starts_with($historyEntry, 'user:')) { - $historyEntry = preg_replace('/^user:/', '', $historyEntry); - $messages[] = ['role' => 'user', 'content' => $historyEntry]; - } elseif (((int)$i) % 2 === 0) { - // we assume even indexes are user messages and odd ones are system ones - $messages[] = ['role' => 'user', 'content' => $historyEntry]; - } else { - $messages[] = ['role' => 'system', 'content' => $historyEntry]; + $message = json_decode($historyEntry, true); + if ($message['role'] === 'human') { + $message['role'] = 'user'; + } + if (isset($message['tool_calls']) && is_array($message['tool_calls'])) { + $message['tool_calls'] = array_map(static function ($toolCall) { + $formattedToolCall = [ + 'id' => $toolCall['id'], + 'type' => 'function', + 'function' => $toolCall, + ]; + $formattedToolCall['function']['arguments'] = json_encode($toolCall['args']); + unset($formattedToolCall['function']['id']); + unset($formattedToolCall['function']['args']); + unset($formattedToolCall['function']['type']); + return $formattedToolCall; + }, $message['tool_calls']); } + $messages[] = $message; + } + } + if ($userPrompt !== null) { + $messages[] = ['role' => 'user', 'content' => $userPrompt]; + } + if ($toolMessage !== null) { + $msgs = json_decode($toolMessage, true); + foreach ($msgs as $msg) { + $msg['role'] = 'tool'; + $messages[] = $msg; } } - $messages[] = ['role' => 'user', 'content' => $userPrompt]; $params = [ 'model' => $model === Application::DEFAULT_MODEL_ID ? Application::DEFAULT_COMPLETION_MODEL_ID : $model, @@ -406,6 +427,9 @@ public function createChatCompletion( 'max_tokens' => $maxTokens, 'n' => $n, ]; + if ($tools !== null) { + $params['tools'] = $tools; + } if ($userId !== null && $this->isUsingOpenAi()) { $params['user'] = $userId; } @@ -434,10 +458,30 @@ public function createChatCompletion( $this->logger->warning('Could not create quota usage for user: ' . $userId . ' and quota type: ' . Application::QUOTA_TYPE_TEXT . '. Error: ' . $e->getMessage(), ['app' => Application::APP_ID]); } } - $completions = []; + $completions = [ + 'messages' => [], + 'tool_calls' => [], + ]; foreach ($response['choices'] as $choice) { - $completions[] = $choice['message']['content']; + // get tool calls only if this is the finish reason and it's defined and it's an array + if ($choice['finish_reason'] === 'tool_calls' + && isset($choice['message']['tool_calls']) + && is_array($choice['message']['tool_calls']) + ) { + // fix the tool_calls format, make it like expected by the context_agent app + $choice['message']['tool_calls'] = array_map(static function ($toolCall) { + $toolCall['function']['id'] = $toolCall['id']; + $toolCall['function']['args'] = json_decode($toolCall['function']['arguments']); + unset($toolCall['function']['arguments']); + return $toolCall['function']; + }, $choice['message']['tool_calls']); + $completions['tool_calls'][] = json_encode($choice['message']['tool_calls']); + } + // always try to get a message + if (isset($choice['message']['content']) && is_string($choice['message']['content'])) { + $completions['messages'][] = $choice['message']['content']; + } } return $completions; diff --git a/lib/Service/OpenAiSettingsService.php b/lib/Service/OpenAiSettingsService.php index 1be2b41d..9f15aa6c 100644 --- a/lib/Service/OpenAiSettingsService.php +++ b/lib/Service/OpenAiSettingsService.php @@ -1,4 +1,5 @@ openAiAPIService->isUsingOpenAi() || $this->openAiSettingsService->getChatEndpointEnabled()) { $completion = $this->openAiAPIService->createChatCompletion($userId, $model, $prompt, null, null, 1, $maxTokens); + $completion = $completion['messages']; } else { $completion = $this->openAiAPIService->createCompletion($userId, $prompt, 1, $model, $maxTokens); } diff --git a/lib/TaskProcessing/ContextWriteProvider.php b/lib/TaskProcessing/ContextWriteProvider.php index d128924c..8b36701e 100644 --- a/lib/TaskProcessing/ContextWriteProvider.php +++ b/lib/TaskProcessing/ContextWriteProvider.php @@ -129,6 +129,7 @@ public function process(?string $userId, array $input, callable $reportProgress) try { if ($this->openAiAPIService->isUsingOpenAi() || $this->openAiSettingsService->getChatEndpointEnabled()) { $completion = $this->openAiAPIService->createChatCompletion($userId, $model, $prompt, null, null, 1, $maxTokens); + $completion = $completion['messages']; } else { $completion = $this->openAiAPIService->createCompletion($userId, $prompt, 1, $model, $maxTokens); } diff --git a/lib/TaskProcessing/HeadlineProvider.php b/lib/TaskProcessing/HeadlineProvider.php index 79e69c1a..dda9b449 100644 --- a/lib/TaskProcessing/HeadlineProvider.php +++ b/lib/TaskProcessing/HeadlineProvider.php @@ -117,6 +117,7 @@ public function process(?string $userId, array $input, callable $reportProgress) try { if ($this->openAiAPIService->isUsingOpenAi() || $this->openAiSettingsService->getChatEndpointEnabled()) { $completion = $this->openAiAPIService->createChatCompletion($userId, $model, $prompt, null, null, 1, $maxTokens); + $completion = $completion['messages']; } else { $completion = $this->openAiAPIService->createCompletion($userId, $prompt, 1, $model, $maxTokens); } diff --git a/lib/TaskProcessing/ReformulateProvider.php b/lib/TaskProcessing/ReformulateProvider.php index 14f94ba4..59e87ed3 100644 --- a/lib/TaskProcessing/ReformulateProvider.php +++ b/lib/TaskProcessing/ReformulateProvider.php @@ -117,6 +117,7 @@ public function process(?string $userId, array $input, callable $reportProgress) try { if ($this->openAiAPIService->isUsingOpenAi() || $this->openAiSettingsService->getChatEndpointEnabled()) { $completion = $this->openAiAPIService->createChatCompletion($userId, $model, $prompt, null, null, 1, $maxTokens); + $completion = $completion['messages']; } else { $completion = $this->openAiAPIService->createCompletion($userId, $prompt, 1, $model, $maxTokens); } diff --git a/lib/TaskProcessing/SummaryProvider.php b/lib/TaskProcessing/SummaryProvider.php index 2dce20d4..222a25d0 100644 --- a/lib/TaskProcessing/SummaryProvider.php +++ b/lib/TaskProcessing/SummaryProvider.php @@ -117,6 +117,7 @@ public function process(?string $userId, array $input, callable $reportProgress) try { if ($this->openAiAPIService->isUsingOpenAi() || $this->openAiSettingsService->getChatEndpointEnabled()) { $completion = $this->openAiAPIService->createChatCompletion($userId, $model, $prompt, null, null, 1, $maxTokens); + $completion = $completion['messages']; } else { $completion = $this->openAiAPIService->createCompletion($userId, $prompt, 1, $model, $maxTokens); } diff --git a/lib/TaskProcessing/TextToTextChatProvider.php b/lib/TaskProcessing/TextToTextChatProvider.php index 998b0ace..7886729d 100644 --- a/lib/TaskProcessing/TextToTextChatProvider.php +++ b/lib/TaskProcessing/TextToTextChatProvider.php @@ -104,6 +104,7 @@ public function process(?string $userId, array $input, callable $reportProgress) try { $completion = $this->openAiAPIService->createChatCompletion($userId, $adminModel, $userPrompt, $systemPrompt, $history, 1, $maxTokens); + $completion = $completion['messages']; } catch (Exception $e) { throw new RuntimeException('OpenAI/LocalAI request failed: ' . $e->getMessage()); } diff --git a/lib/TaskProcessing/TextToTextChatWithToolsProvider.php b/lib/TaskProcessing/TextToTextChatWithToolsProvider.php new file mode 100644 index 00000000..186bfb5e --- /dev/null +++ b/lib/TaskProcessing/TextToTextChatWithToolsProvider.php @@ -0,0 +1,142 @@ +openAiAPIService->getServiceName(); + } + + public function getTaskTypeId(): string { + return TextToTextChatWithTools::ID; + } + + public function getExpectedRuntime(): int { + return $this->openAiAPIService->getExpTextProcessingTime(); + } + + public function getInputShapeEnumValues(): array { + return []; + } + + public function getInputShapeDefaults(): array { + return []; + } + + public function getOptionalInputShape(): array { + return [ + 'max_tokens' => new ShapeDescriptor( + $this->l->t('Maximum output words'), + $this->l->t('The maximum number of words/tokens that can be generated in the completion.'), + EShapeType::Number + ), + ]; + } + + public function getOptionalInputShapeEnumValues(): array { + return []; + } + + public function getOptionalInputShapeDefaults(): array { + return []; + } + + public function getOutputShapeEnumValues(): array { + return []; + } + + public function getOptionalOutputShape(): array { + return []; + } + + public function getOptionalOutputShapeEnumValues(): array { + return []; + } + + public function process(?string $userId, array $input, callable $reportProgress): array { + $startTime = time(); + $adminModel = $this->appConfig->getValueString(Application::APP_ID, 'default_completion_model_id', Application::DEFAULT_COMPLETION_MODEL_ID) ?: Application::DEFAULT_COMPLETION_MODEL_ID; + + if (!isset($input['input']) || !is_string($input['input'])) { + throw new RuntimeException('Invalid input'); + } + $userPrompt = $input['input']; + if ($userPrompt === '') { + $userPrompt = null; + } + + if (!isset($input['system_prompt']) || !is_string($input['system_prompt'])) { + throw new RuntimeException('Invalid system_prompt'); + } + $systemPrompt = $input['system_prompt']; + + if (!isset($input['tool_message']) || !is_string($input['tool_message'])) { + throw new RuntimeException('Invalid tool_message'); + } + $toolMessage = $input['tool_message']; + if ($toolMessage === '') { + $toolMessage = null; + } + + if (!isset($input['tools']) || !is_string($input['tools'])) { + throw new RuntimeException('Invalid tools'); + } + $tools = json_decode($input['tools']); + if (!is_array($tools) || !\array_is_list($tools)) { + throw new RuntimeException('Invalid JSON tools'); + } + + if (!isset($input['history']) || !is_array($input['history']) || !\array_is_list($input['history'])) { + throw new RuntimeException('Invalid history'); + } + $history = $input['history']; + + $maxTokens = null; + if (isset($input['max_tokens']) && is_int($input['max_tokens'])) { + $maxTokens = $input['max_tokens']; + } + + try { + $completion = $this->openAiAPIService->createChatCompletion( + $userId, $adminModel, $userPrompt, $systemPrompt, $history, 1, $maxTokens, null, $toolMessage, $tools + ); + } catch (Exception $e) { + throw new RuntimeException('OpenAI/LocalAI request failed: ' . $e->getMessage()); + } + if (count($completion['messages']) > 0 || count($completion['tool_calls']) > 0) { + $endTime = time(); + $this->openAiAPIService->updateExpTextProcessingTime($endTime - $startTime); + return [ + 'output' => array_pop($completion['messages']) ?? '', + 'tool_calls' => array_pop($completion['tool_calls']) ?? '', + ]; + } + + throw new RuntimeException('No result in OpenAI/LocalAI response.'); + } +} diff --git a/lib/TaskProcessing/TextToTextProvider.php b/lib/TaskProcessing/TextToTextProvider.php index 09a481ed..5b1fd8ee 100644 --- a/lib/TaskProcessing/TextToTextProvider.php +++ b/lib/TaskProcessing/TextToTextProvider.php @@ -122,6 +122,7 @@ public function process(?string $userId, array $input, callable $reportProgress) try { if ($this->openAiAPIService->isUsingOpenAi() || $this->openAiSettingsService->getChatEndpointEnabled()) { $completion = $this->openAiAPIService->createChatCompletion($userId, $model, $prompt, null, null, 1, $maxTokens); + $completion = $completion['messages']; } else { $completion = $this->openAiAPIService->createCompletion($userId, $prompt, 1, $model, $maxTokens); } diff --git a/lib/TaskProcessing/TopicsProvider.php b/lib/TaskProcessing/TopicsProvider.php index 745e77fc..b9d11a8b 100644 --- a/lib/TaskProcessing/TopicsProvider.php +++ b/lib/TaskProcessing/TopicsProvider.php @@ -117,6 +117,7 @@ public function process(?string $userId, array $input, callable $reportProgress) try { if ($this->openAiAPIService->isUsingOpenAi() || $this->openAiSettingsService->getChatEndpointEnabled()) { $completion = $this->openAiAPIService->createChatCompletion($userId, $model, $prompt, null, null, 1, $maxTokens); + $completion = $completion['messages']; } else { $completion = $this->openAiAPIService->createCompletion($userId, $prompt, 1, $model, $maxTokens); } diff --git a/lib/TaskProcessing/TranslateProvider.php b/lib/TaskProcessing/TranslateProvider.php index cef11dc9..071672a8 100644 --- a/lib/TaskProcessing/TranslateProvider.php +++ b/lib/TaskProcessing/TranslateProvider.php @@ -168,6 +168,7 @@ public function process(?string $userId, array $input, callable $reportProgress) if ($this->openAiAPIService->isUsingOpenAi() || $this->openAiSettingsService->getChatEndpointEnabled()) { $completion = $this->openAiAPIService->createChatCompletion($userId, $model, $prompt, null, null, 1, $maxTokens); + $completion = $completion['messages']; } else { $completion = $this->openAiAPIService->createCompletion($userId, $prompt, 1, $model, $maxTokens); } diff --git a/vendor-bin/php-cs-fixer/composer.lock b/vendor-bin/php-cs-fixer/composer.lock index 33edf055..cfc4d58c 100644 --- a/vendor-bin/php-cs-fixer/composer.lock +++ b/vendor-bin/php-cs-fixer/composer.lock @@ -97,16 +97,16 @@ }, { "name": "php-cs-fixer/shim", - "version": "v3.64.0", + "version": "v3.65.0", "source": { "type": "git", "url": "/~https://github.com/PHP-CS-Fixer/shim.git", - "reference": "81ccfd24baf3a10810dab1152c403981a790b837" + "reference": "4983ec79b9dee926695ac324ea6e8d291935525d" }, "dist": { "type": "zip", - "url": "https://api.github.com/repos/PHP-CS-Fixer/shim/zipball/81ccfd24baf3a10810dab1152c403981a790b837", - "reference": "81ccfd24baf3a10810dab1152c403981a790b837", + "url": "https://api.github.com/repos/PHP-CS-Fixer/shim/zipball/4983ec79b9dee926695ac324ea6e8d291935525d", + "reference": "4983ec79b9dee926695ac324ea6e8d291935525d", "shasum": "" }, "require": { @@ -143,9 +143,9 @@ "description": "A tool to automatically fix PHP code style", "support": { "issues": "/~https://github.com/PHP-CS-Fixer/shim/issues", - "source": "/~https://github.com/PHP-CS-Fixer/shim/tree/v3.64.0" + "source": "/~https://github.com/PHP-CS-Fixer/shim/tree/v3.65.0" }, - "time": "2024-08-30T23:10:11+00:00" + "time": "2024-11-25T00:39:41+00:00" } ], "aliases": [], diff --git a/vendor-bin/phpunit/composer.lock b/vendor-bin/phpunit/composer.lock index 2158e030..3dd4aaf8 100644 --- a/vendor-bin/phpunit/composer.lock +++ b/vendor-bin/phpunit/composer.lock @@ -634,16 +634,16 @@ }, { "name": "phpunit/phpunit", - "version": "9.6.21", + "version": "9.6.22", "source": { "type": "git", "url": "/~https://github.com/sebastianbergmann/phpunit.git", - "reference": "de6abf3b6f8dd955fac3caad3af7a9504e8c2ffa" + "reference": "f80235cb4d3caa59ae09be3adf1ded27521d1a9c" }, "dist": { "type": "zip", - "url": "https://api.github.com/repos/sebastianbergmann/phpunit/zipball/de6abf3b6f8dd955fac3caad3af7a9504e8c2ffa", - "reference": "de6abf3b6f8dd955fac3caad3af7a9504e8c2ffa", + "url": "https://api.github.com/repos/sebastianbergmann/phpunit/zipball/f80235cb4d3caa59ae09be3adf1ded27521d1a9c", + "reference": "f80235cb4d3caa59ae09be3adf1ded27521d1a9c", "shasum": "" }, "require": { @@ -654,7 +654,7 @@ "ext-mbstring": "*", "ext-xml": "*", "ext-xmlwriter": "*", - "myclabs/deep-copy": "^1.12.0", + "myclabs/deep-copy": "^1.12.1", "phar-io/manifest": "^2.0.4", "phar-io/version": "^3.2.1", "php": ">=7.3", @@ -717,7 +717,7 @@ "support": { "issues": "/~https://github.com/sebastianbergmann/phpunit/issues", "security": "/~https://github.com/sebastianbergmann/phpunit/security/policy", - "source": "/~https://github.com/sebastianbergmann/phpunit/tree/9.6.21" + "source": "/~https://github.com/sebastianbergmann/phpunit/tree/9.6.22" }, "funding": [ { @@ -733,7 +733,7 @@ "type": "tidelift" } ], - "time": "2024-09-19T10:50:18+00:00" + "time": "2024-12-05T13:48:26+00:00" }, { "name": "sebastian/cli-parser",