[OPIK-689] [FR]: Add Support for Additional AI Provider Configurations (Local Models) (#1276)

* [OPIK-689] [FR]: Add Support for Additional AI Provider Configurations (Local Models)

* fix/add description messages
andriidudar authored Feb 14, 2025
1 parent b24c58c commit 61eba4b
Showing 27 changed files with 1,047 additions and 343 deletions.
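
The change set lets the playground talk to locally hosted models (for example an Ollama server) in addition to the cloud providers. A minimal sketch of the kind of configuration involved — the field names below are illustrative assumptions, not the actual Opik types:

    // Hypothetical shape of a locally configured AI provider entry.
    // Names and fields are assumptions for illustration only.
    interface LocalProviderConfig {
      provider: string; // e.g. "ollama"
      url: string;      // base URL of the local chat-completions endpoint
      models: string[]; // model ids the local server exposes
    }

    const ollamaExample: LocalProviderConfig = {
      provider: "ollama",
      url: "http://localhost:11434/v1/chat/completions",
      models: ["llama3.1", "mistral"],
    };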
2 changes: 1 addition & 1 deletion apps/opik-frontend/src/api/api.ts
@@ -25,7 +25,7 @@ export const COMPARE_EXPERIMENTS_KEY = "compare-experiments";
export const SPANS_KEY = "spans";
export const TRACES_KEY = "traces";
export const TRACE_KEY = "trace";
export const PROVIDERS_KEYS_KEY = "providerKeys";
export const PROVIDERS_KEYS_KEY = "provider-keys";
export const AUTOMATIONS_KEY = "automations";
export const PROJECTS_KEY = "projects";
export const PROJECT_STATISTICS_KEY = "project-statistics";
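
The renamed constant is only a TanStack Query cache key, so the change affects cache identity rather than any API route. A hedged sketch of typical usage (the invalidation helper below is illustrative, not code from this commit):

    import { useQueryClient } from "@tanstack/react-query";
    import { PROVIDERS_KEYS_KEY } from "@/api/api";

    // Illustrative only: refetch every query registered under
    // [PROVIDERS_KEYS_KEY, params] after a provider key changes.
    const useInvalidateProviderKeys = () => {
      const queryClient = useQueryClient();
      return () =>
        queryClient.invalidateQueries({ queryKey: [PROVIDERS_KEYS_KEY] });
    };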

@@ -16,17 +16,21 @@ import api, {
TRACES_REST_ENDPOINT,
} from "@/api/api";
import { snakeCaseObj } from "@/lib/utils";
-import { getModelProvider } from "@/lib/llm";
import { createBatchProcessor } from "@/lib/batches";
import { RunStreamingReturn } from "@/api/playground/useCompletionProxyStreaming";
-import { LLMPromptConfigsType, PROVIDER_MODEL_TYPE } from "@/types/providers";
+import {
+LLMPromptConfigsType,
+PROVIDER_MODEL_TYPE,
+PROVIDER_TYPE,
+} from "@/types/providers";
import { ProviderMessageType } from "@/types/llm";

export interface LogQueueParams extends RunStreamingReturn {
promptId: string;
datasetItemId?: string;
datasetName: string | null;
model: PROVIDER_MODEL_TYPE | "";
+provider: PROVIDER_TYPE | "";
providerMessages: ProviderMessageType[];
configs: LLMPromptConfigsType;
}
@@ -98,7 +102,7 @@ const getSpanFromRun = (run: LogQueueParams, traceId: string): LogSpan => {
output: { choices: run.choices ? run.choices : [] },
usage: !run.usage ? undefined : pick(run.usage, USAGE_FIELDS_TO_SEND),
metadata: {
-created_from: run.model ? getModelProvider(run.model) : "",
+created_from: run.provider,
usage: run.usage,
model: run.model,
parameters: run.configs,
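
With a local model the provider can no longer be inferred from the model id, so the playground now passes it through the log queue explicitly instead of calling getModelProvider. A rough before/after sketch with hypothetical values (the model and provider names are illustrative):

    // Before: provider derived from the model id; ids outside the static
    // model list (e.g. local models) had no matching provider.
    // created_from = getModelProvider("gpt-4o-mini") // -> "openai"

    // After: the run already carries the provider chosen in the UI.
    const run = { model: "llama3.1", provider: "ollama" };
    const createdFrom = run.provider; // "ollama", independent of the model id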

apps/opik-frontend/src/api/playground/useCompletionProxyStreaming.ts
@@ -1,13 +1,15 @@
import { useCallback } from "react";

import dayjs from "dayjs";
+import isObject from "lodash/isObject";

import { UsageType } from "@/types/shared";
import {
ChatCompletionMessageChoiceType,
ChatCompletionResponse,
ChatCompletionProviderErrorMessageType,
ChatCompletionSuccessMessageType,
ChatCompletionOpikErrorMessageType,
+ChatCompletionPythonProxyErrorMessageType,
} from "@/types/playground";
import { isValidJsonObject, safelyParseJSON, snakeCaseObj } from "@/lib/utils";
import { BASE_API_URL } from "@/api/api";
@@ -19,13 +21,20 @@ const getNowUtcTimeISOString = (): string => {
};

interface GetCompletionProxyStreamParams {
+url?: string;
model: PROVIDER_MODEL_TYPE | "";
messages: ProviderMessageType[];
signal: AbortSignal;
configs: LLMPromptConfigsType;
workspaceName: string;
}

+const isPythonProxyError = (
+response: ChatCompletionResponse,
+): response is ChatCompletionPythonProxyErrorMessageType => {
+return "detail" in response;
+};

const isOpikError = (
response: ChatCompletionResponse,
): response is ChatCompletionOpikErrorMessageType => {
@@ -42,13 +51,14 @@ const isProviderError = (
};

const getCompletionProxyStream = async ({
+url = `${BASE_API_URL}/v1/private/chat/completions`,
model,
messages,
signal,
configs,
workspaceName,
}: GetCompletionProxyStreamParams) => {
-return fetch(`${BASE_API_URL}/v1/private/chat/completions`, {
+return fetch(url, {
method: "POST",
headers: {
"Content-Type": "application/json",
@@ -67,6 +77,7 @@ const getCompletionProxyStream = async ({
};

export interface RunStreamingArgs {
+url?: string;
model: PROVIDER_MODEL_TYPE | "";
messages: ProviderMessageType[];
configs: LLMPromptConfigsType;
@@ -82,6 +93,7 @@ export interface RunStreamingReturn {
choices: ChatCompletionMessageChoiceType[] | null;
providerError: null | string;
opikError: null | string;
+pythonProxyError: null | string;
}

interface UseCompletionProxyStreamingParameters {
@@ -91,8 +103,9 @@ interface UseCompletionProxyStreamingParameters {
const useCompletionProxyStreaming = ({
workspaceName,
}: UseCompletionProxyStreamingParameters) => {
-const runStreaming = useCallback(
+return useCallback(
async ({
+url,
model,
messages,
configs,
@@ -106,11 +119,13 @@ const useCompletionProxyStreaming = ({
let choices: ChatCompletionMessageChoiceType[] = [];

// errors
+let pythonProxyError = null;
let opikError = null;
let providerError = null;

try {
const response = await getCompletionProxyStream({
+url,
model,
messages,
configs,
@@ -156,23 +171,43 @@ const useCompletionProxyStreaming = ({
opikError = parsedMessage.errors.join(" ");
};

+const handlePythonProxyErrorMessage = (
+parsedMessage: ChatCompletionPythonProxyErrorMessageType,
+) => {
+if (
+isObject(parsedMessage.detail) &&
+"error" in parsedMessage.detail
+) {
+pythonProxyError = parsedMessage.detail.error;
+} else {
+pythonProxyError = parsedMessage.detail ?? "Python proxy error";
+}
+};

// an analogue of true && reader
// we need it to wait till the stream is closed
while (reader) {
const { done, value } = await reader.read();

-if (done || opikError || providerError) {
+if (done || opikError || pythonProxyError || providerError) {
break;
}

const chunk = decoder.decode(value, { stream: true });
const lines = chunk.split("\n").filter((line) => line.trim() !== "");

for (const line of lines) {
-const parsed = safelyParseJSON(line) as ChatCompletionResponse;
+const ollamaDataPrefix = "data:";
+const JSONData = line.startsWith(ollamaDataPrefix)
+? line.split(ollamaDataPrefix)[1]
+: line;

+const parsed = safelyParseJSON(JSONData) as ChatCompletionResponse;

// handle different message types
-if (isOpikError(parsed)) {
+if (isPythonProxyError(parsed)) {
+handlePythonProxyErrorMessage(parsed);
+} else if (isOpikError(parsed)) {
handleOpikErrorMessage(parsed);
} else if (isProviderError(parsed)) {
handleAIPlatformErrorMessage(parsed);
@@ -188,6 +223,7 @@ const useCompletionProxyStreaming = ({
result: accumulatedValue,
providerError,
opikError,
+pythonProxyError,
usage,
choices,
};
@@ -205,15 +241,14 @@ const useCompletionProxyStreaming = ({
result: accumulatedValue,
providerError,
opikError: opikError || defaultErrorMessage,
+pythonProxyError,
usage: null,
choices,
};
}
},
[workspaceName],
);

-return runStreaming;
};

export default useCompletionProxyStreaming;
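
Two mechanics in this hook are worth calling out: the request can now target an arbitrary completions endpoint instead of the built-in Opik proxy, and Ollama-style stream chunks that arrive as "data:{...}" lines are unwrapped before JSON parsing. A minimal standalone sketch of the prefix handling, mirroring the loop above (safelyParseJSON is replaced with plain JSON.parse here):

    // Ollama streams SSE-style lines ("data:{...}"); the Opik proxy streams
    // bare JSON lines. Strip the prefix when present, then parse.
    const parseStreamLine = (line: string): unknown => {
      const ollamaDataPrefix = "data:";
      const jsonData = line.startsWith(ollamaDataPrefix)
        ? line.slice(ollamaDataPrefix.length)
        : line;
      return JSON.parse(jsonData);
    };

    parseStreamLine('data:{"done":false}'); // -> { done: false }
    parseStreamLine('{"choices":[]}');      // -> { choices: [] }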

apps/opik-frontend/src/api/provider-keys/useProviderKeys.ts
@@ -5,6 +5,7 @@ import api, {
QueryConfig,
} from "@/api/api";
import { ProviderKey } from "@/types/providers";
+import useLocalAIProviderData from "@/hooks/useLocalAIProviderData";

type UseProviderKeysListParams = {
workspaceName: string;
@@ -15,21 +16,28 @@ type UseProviderKeysListResponse = {
total: number;
};

-const getProviderKeys = async ({ signal }: QueryFunctionContext) => {
+const getProviderKeys = async (
+{ signal }: QueryFunctionContext,
+extendWithLocalData: (
+data: UseProviderKeysListResponse,
+) => UseProviderKeysListResponse,
+) => {
const { data } = await api.get(PROVIDER_KEYS_REST_ENDPOINT, {
signal,
});

-return data;
+return extendWithLocalData(data);
};

export default function useProviderKeys(
params: UseProviderKeysListParams,
options?: QueryConfig<UseProviderKeysListResponse>,
) {
+const { extendWithLocalData } = useLocalAIProviderData();

return useQuery({
queryKey: [PROVIDERS_KEYS_KEY, params],
-queryFn: (context) => getProviderKeys(context),
+queryFn: (context) => getProviderKeys(context, extendWithLocalData),
...options,
});
}
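
useLocalAIProviderData itself is not part of this diff; all that is visible here is that it can fold locally stored provider entries into the server response before the list reaches the UI. A speculative sketch of what such a merge could look like — the hook's real implementation may differ:

    import { ProviderKey } from "@/types/providers";

    // Speculative sketch only; the real useLocalAIProviderData is not shown here.
    const extendWithLocalData = (
      data: { content: ProviderKey[]; total: number },
      localKeys: ProviderKey[],
    ): { content: ProviderKey[]; total: number } => ({
      content: [...data.content, ...localKeys],
      total: data.total + localKeys.length,
    });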

@@ -2,7 +2,6 @@ import React, { useCallback, useMemo, useRef, useState } from "react";
import isNull from "lodash/isNull";
import pick from "lodash/pick";

-import { PROVIDER_MODELS } from "@/constants/llm";
import { PROVIDERS } from "@/constants/providers";

import {
@@ -29,11 +28,12 @@ import { PROVIDER_MODEL_TYPE, PROVIDER_TYPE } from "@/types/providers";
import useProviderKeys from "@/api/provider-keys/useProviderKeys";
import AddEditAIProviderDialog from "@/components/shared/AddEditAIProviderDialog/AddEditAIProviderDialog";
import { areAllProvidersConfigured } from "@/lib/provider";
+import useLLMProviderModelsData from "@/hooks/useLLMProviderModelsData";

interface PromptModelSelectProps {
value: PROVIDER_MODEL_TYPE | "";
workspaceName: string;
-onChange: (value: PROVIDER_MODEL_TYPE) => void;
+onChange: (value: PROVIDER_MODEL_TYPE, provider: PROVIDER_TYPE) => void;
hasError?: boolean;
provider: PROVIDER_TYPE | "";
onAddProvider?: (provider: PROVIDER_TYPE) => void;
@@ -57,6 +57,7 @@ const PromptModelSelect = ({
const [openConfigDialog, setOpenConfigDialog] = React.useState(false);
const [filterValue, setFilterValue] = useState("");
const [openProviderMenu, setOpenProviderMenu] = useState<string | null>(null);
+const { getProviderModels } = useLLMProviderModelsData();

const { data } = useProviderKeys(
{
@@ -74,7 +75,7 @@ const PromptModelSelect = ({

const groupOptions = useMemo(() => {
const filteredByConfiguredProviders = pick(
-PROVIDER_MODELS,
+getProviderModels(),
configuredProviderKeys,
);

@@ -97,10 +98,11 @@ const PromptModelSelect = ({
label: PROVIDERS[providerName].label,
options,
icon: PROVIDERS[providerName].icon,
+provider: providerName,
};
})
.filter((g): g is NonNullable<typeof g> => !isNull(g));
-}, [configuredProviderKeys, onlyWithStructuredOutput]);
+}, [configuredProviderKeys, onlyWithStructuredOutput, getProviderModels]);

const filteredOptions = useMemo(() => {
if (filterValue === "") {
@@ -131,9 +133,9 @@ const PromptModelSelect = ({

const handleOnChange = useCallback(
(value: PROVIDER_MODEL_TYPE) => {
-onChange(value);
+onChange(value, openProviderMenu as PROVIDER_TYPE);
},
-[onChange],
+[onChange, openProviderMenu],
);

const handleSelectOpenChange = useCallback((open: boolean) => {
@@ -191,16 +193,17 @@ const PromptModelSelect = ({
return (
<div>
{groupOptions.map((group) => (
-<Popover key={group.label} open={group.label === openProviderMenu}>
+<Popover key={group.label} open={group.provider === openProviderMenu}>
<PopoverTrigger asChild>
<div
key={group.label}
-onMouseEnter={() => setOpenProviderMenu(group.label)}
+onMouseEnter={() => setOpenProviderMenu(group.provider)}
onMouseLeave={() => setOpenProviderMenu(null)}
className={cn(
"comet-body-s flex h-10 w-full items-center rounded-sm p-0 pl-2 hover:bg-primary-foreground justify-center",
{
"bg-primary-foreground": group.label === openProviderMenu,
"bg-primary-foreground":
group.provider === openProviderMenu,
},
)}
>
@@ -215,7 +218,7 @@ const PromptModelSelect = ({
align="start"
className="max-h-[400px] overflow-y-auto p-1"
sideOffset={-5}
-onMouseEnter={() => setOpenProviderMenu(group.label)}
+onMouseEnter={() => setOpenProviderMenu(group.provider)}
hideWhenDetached
>
{group.options.map((option) => {
@@ -308,6 +311,7 @@ const PromptModelSelect = ({
</Select>
<AddEditAIProviderDialog
key={resetDialogKeyRef.current}
+excludedProviders={configuredProviderKeys}
open={openConfigDialog}
setOpen={setOpenConfigDialog}
onAddProvider={onAddProvider}
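
Callers of PromptModelSelect now receive the provider together with the chosen model, so they no longer need to derive it from the model id. A hedged sketch of how a parent component might consume the new two-argument onChange (the handler name and logging are illustrative):

    import { PROVIDER_MODEL_TYPE, PROVIDER_TYPE } from "@/types/providers";

    // Illustrative parent handler; what it does with the values is up to the page.
    const handleModelChange = (
      model: PROVIDER_MODEL_TYPE,
      provider: PROVIDER_TYPE,
    ) => {
      console.log("selected model", model, "from provider", provider);
    };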

@@ -17,6 +17,7 @@ import { Button } from "@/components/ui/button";

import OpenAIModelConfigs from "@/components/pages-shared/llm/PromptModelSettings/providerConfigs/OpenAIModelConfigs";
import AnthropicModelConfigs from "@/components/pages-shared/llm/PromptModelSettings/providerConfigs/AnthropicModelConfigs";
+import isEmpty from "lodash/isEmpty";

interface PromptModelConfigsProps {
provider: PROVIDER_TYPE | "";
@@ -53,12 +54,12 @@ const PromptModelConfigs = ({
return;
};

const noProvider = provider === "";
const disabled = provider === "" || isEmpty(configs);

return (
<DropdownMenu>
<DropdownMenuTrigger asChild>
<Button variant="outline" size={size} disabled={noProvider}>
<Button variant="outline" size={size} disabled={disabled}>
<Settings2 className="size-3.5" />
</Button>
</DropdownMenuTrigger>