Skip to content

Commit

Permalink
[OPIK-857] [SDK][Tech Debt] Update usage parsing and validation logic (
Browse files Browse the repository at this point in the history
…#1247)

* add llm provider types

* add token usage dict for vertexai

* use vertexai-specific token usage dict in usage validator

* use vertexai-specific token usage dict in langchain

* fix linter warnings

* fix errors related to recent changes

* fix tests

* fix linter
  • Loading branch information
japdubengsub authored Feb 14, 2025
1 parent a5801dd commit d269147
Show file tree
Hide file tree
Showing 9 changed files with 104 additions and 40 deletions.
10 changes: 6 additions & 4 deletions sdks/python/src/opik/api_objects/opik_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,11 @@ def span(
start_time if start_time is not None else datetime_helpers.local_timestamp()
)

parsed_usage = validation_helpers.validate_and_parse_usage(usage, LOGGER)
parsed_usage = validation_helpers.validate_and_parse_usage(
usage=usage,
logger=LOGGER,
provider=provider,
)
if parsed_usage.full_usage is not None:
metadata = (
{"usage": parsed_usage.full_usage}
Expand Down Expand Up @@ -382,9 +386,7 @@ def span(
output=output,
metadata=metadata,
tags=tags,
usage=parsed_usage.full_usage
if provider == "google_vertexai"
else parsed_usage.supported_usage,
usage=parsed_usage.supported_usage,
model=model,
provider=provider,
error_info=error_info,
Expand Down
8 changes: 6 additions & 2 deletions sdks/python/src/opik/api_objects/span/span.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,9 @@ def update(
Returns:
None
"""
parsed_usage = validation_helpers.validate_and_parse_usage(usage, LOGGER)
parsed_usage = validation_helpers.validate_and_parse_usage(
usage, LOGGER, provider
)
if parsed_usage.full_usage is not None:
metadata = (
{"usage": parsed_usage.full_usage}
Expand Down Expand Up @@ -182,7 +184,9 @@ def span(
start_time = (
start_time if start_time is not None else datetime_helpers.local_timestamp()
)
parsed_usage = validation_helpers.validate_and_parse_usage(usage, LOGGER)
parsed_usage = validation_helpers.validate_and_parse_usage(
usage, LOGGER, provider
)
if parsed_usage.full_usage is not None:
metadata = (
{"usage": parsed_usage.full_usage}
Expand Down
4 changes: 3 additions & 1 deletion sdks/python/src/opik/api_objects/trace/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ def span(
start_time = (
start_time if start_time is not None else datetime_helpers.local_timestamp()
)
parsed_usage = validation_helpers.validate_and_parse_usage(usage, LOGGER)
parsed_usage = validation_helpers.validate_and_parse_usage(
usage, LOGGER, provider
)
if parsed_usage.full_usage is not None:
metadata = (
{"usage": parsed_usage.full_usage}
Expand Down
9 changes: 7 additions & 2 deletions sdks/python/src/opik/api_objects/validation_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,17 @@


def validate_and_parse_usage(
usage: Any, logger: logging.Logger
usage: Any,
logger: logging.Logger,
provider: Optional[str],
) -> usage_validator.ParsedUsage:
if usage is None:
return usage_validator.ParsedUsage()

usage_validator_ = usage_validator.UsageValidator(usage)
usage_validator_ = usage_validator.UsageValidator(
usage=usage,
provider=provider,
)
if usage_validator_.validate().failed():
logger.warning(
logging_messages.INVALID_USAGE_WILL_NOT_BE_LOGGED,
Expand Down
31 changes: 21 additions & 10 deletions sdks/python/src/opik/integrations/langchain/google_run_helpers.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import logging
from typing import Any, Dict, Optional, TYPE_CHECKING, Tuple, cast
from typing import Any, Dict, Final, Optional, TYPE_CHECKING, Tuple

from opik import logging_messages
from opik.types import LLMUsageInfo, UsageDict
from opik.types import LLMProvider, LLMUsageInfo, UsageDictVertexAI
from opik.validation import usage as usage_validator

if TYPE_CHECKING:
from langchain_core.tracers.schemas import Run

LOGGER = logging.getLogger(__name__)

PROVIDER_NAME: Final[LLMProvider] = "google_vertexai"


def get_llm_usage_info(run_dict: Optional[Dict[str, Any]] = None) -> LLMUsageInfo:
if run_dict is None:
Expand All @@ -21,21 +23,30 @@ def get_llm_usage_info(run_dict: Optional[Dict[str, Any]] = None) -> LLMUsageInf
return LLMUsageInfo(provider=provider, model=model, usage=usage_dict)


def _try_get_token_usage(run_dict: Dict[str, Any]) -> Optional[UsageDict]:
def _try_get_token_usage(run_dict: Dict[str, Any]) -> Optional[UsageDictVertexAI]:
try:
provider, _ = _get_provider_and_model(run_dict)

usage_metadata = run_dict["outputs"]["generations"][-1][-1]["generation_info"][
"usage_metadata"
]

token_usage = UsageDict(
token_usage = UsageDictVertexAI(
completion_tokens=usage_metadata["candidates_token_count"],
prompt_tokens=usage_metadata["prompt_token_count"],
total_tokens=usage_metadata["total_token_count"],
)
token_usage.update(usage_metadata)

if usage_validator.UsageValidator(token_usage).validate().ok():
return cast(UsageDict, token_usage)
**usage_metadata,
) # type: ignore

if (
usage_validator.UsageValidator(
usage=token_usage,
provider=provider,
)
.validate()
.ok()
):
return token_usage

return None
except Exception:
Expand Down Expand Up @@ -77,7 +88,7 @@ def _get_provider_and_model(
if invocation_params := run_dict["extra"].get("invocation_params"):
provider = invocation_params.get("_type")
if provider == "vertexai":
provider = "google_vertexai"
provider = PROVIDER_NAME
model = invocation_params.get("model_name")

return provider, model
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def _try_get_token_usage(run_dict: Dict[str, Any]) -> Optional[UsageDict]:
token usage info might be in different places, different formats or completely missing.
"""
try:
provider, _ = _get_provider_and_model(run_dict)

if run_dict["outputs"]["llm_output"] is not None:
token_usage = run_dict["outputs"]["llm_output"]["token_usage"]

Expand All @@ -53,7 +55,11 @@ def _try_get_token_usage(run_dict: Dict[str, Any]) -> Optional[UsageDict]:
)
return None

if usage_validator.UsageValidator(token_usage).validate().ok():
if (
usage_validator.UsageValidator(usage=token_usage, provider=provider)
.validate()
.ok()
):
return cast(UsageDict, token_usage)

return None
Expand Down
26 changes: 24 additions & 2 deletions sdks/python/src/opik/types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import dataclasses
import sys

from typing import Literal, Optional
from typing import Literal, Optional, Union
from typing_extensions import TypedDict

if sys.version_info < (3, 11):
Expand All @@ -12,6 +12,7 @@
SpanType = Literal["general", "tool", "llm"]
FeedbackType = Literal["numerical", "categorical"]
CreatedByType = Literal["evaluation"]
LLMProvider = Literal["openai", "google_vertexai"]


class UsageDict(TypedDict):
Expand All @@ -32,6 +33,27 @@ class UsageDict(TypedDict):
"""The total number of tokens used, including both prompt and completion."""


class UsageDictVertexAI(UsageDict):
    """
    A TypedDict representing token usage information for Google Vertex AI.

    Extends :class:`UsageDict` (so the generic ``completion_tokens``,
    ``prompt_tokens`` and ``total_tokens`` keys are still required) with the
    provider-native Vertex AI field names taken from the ``usage_metadata``
    payload (``candidates_token_count``, ``prompt_token_count``,
    ``total_token_count``, and the optional ``cached_content_token_count``).
    """

    cached_content_token_count: NotRequired[int]
    """The number of tokens read from the cache (optional; only present when
    context caching was used)."""

    candidates_token_count: int
    """The number of tokens used for the completion (Vertex AI name)."""

    prompt_token_count: int
    """The number of tokens used for the prompt (Vertex AI name)."""

    total_token_count: int
    """The total number of tokens used, including both prompt and completion
    (Vertex AI name)."""


class DistributedTraceHeadersDict(TypedDict):
opik_trace_id: str
opik_parent_span_id: str
Expand Down Expand Up @@ -84,4 +106,4 @@ class ErrorInfoDict(TypedDict):
class LLMUsageInfo:
provider: Optional[str] = None
model: Optional[str] = None
usage: Optional[UsageDict] = None
usage: Optional[Union[UsageDict, UsageDictVertexAI]] = None
46 changes: 29 additions & 17 deletions sdks/python/src/opik/validation/usage.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
import pydantic
import dataclasses
from typing import Any, Dict, Optional, Union

import pydantic

from typing import Any, Dict, Optional
from ..types import UsageDict
from . import validator, result
from . import result, validator
from ..types import UsageDict, UsageDictVertexAI


class PydanticWrapper(pydantic.BaseModel):
model_config = pydantic.ConfigDict(extra="forbid")
usage: UsageDict
usage: Union[UsageDict, UsageDictVertexAI]


@dataclasses.dataclass
class ParsedUsage:
full_usage: Optional[Dict[str, Any]] = None
supported_usage: Optional[UsageDict] = None
supported_usage: Optional[Union[UsageDict, UsageDictVertexAI]] = None


EXPECTED_TYPES = "{'completion_tokens': int, 'prompt_tokens': int, 'total_tokens': int}"
Expand All @@ -25,17 +26,23 @@ class UsageValidator(validator.Validator):
Validator for span token usage
"""

def __init__(self, usage: Any):
def __init__(self, usage: Any, provider: Optional[str]):
self.usage = usage

self.provider = provider
self.parsed_usage = ParsedUsage()

def validate(self) -> result.ValidationResult:
try:
if isinstance(self.usage, dict):
filtered_usage = _keep_supported_keys(self.usage)
filtered_usage = self.supported_keys

# run validation
PydanticWrapper(usage=filtered_usage)
supported_usage = UsageDict(**filtered_usage) # type: ignore

if self.provider == "google_vertexai":
supported_usage = UsageDictVertexAI(**filtered_usage) # type: ignore
else:
supported_usage = UsageDict(**filtered_usage) # type: ignore
self.parsed_usage = ParsedUsage(
full_usage=self.usage, supported_usage=supported_usage
)
Expand Down Expand Up @@ -67,13 +74,18 @@ def failure_reason_message(self) -> str:
), "validate() must be called before accessing failure reason message"
return self.validation_result.failure_reasons[0]

@property
def supported_keys(self) -> Dict[str, Any]:
if self.provider == "google_vertexai":
supported_keys = UsageDictVertexAI.__annotations__.keys()
# `openai` and all other
else:
supported_keys = UsageDict.__annotations__.keys()

def _keep_supported_keys(usage: Dict[str, Any]) -> Dict[str, Any]:
supported_keys = UsageDict.__annotations__.keys()
filtered_usage = {}
filtered_usage = {}

for key in supported_keys:
if key in usage:
filtered_usage[key] = usage[key]
for key in supported_keys:
if key in self.usage:
filtered_usage[key] = self.usage[key]

return filtered_usage
return filtered_usage
2 changes: 1 addition & 1 deletion sdks/python/tests/unit/validation/test_usage_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
],
)
def test_usage_validator(usage_dict, is_valid):
tested = usage.UsageValidator(usage_dict)
tested = usage.UsageValidator(usage_dict, provider="some-provider")

assert tested.validate().ok() is is_valid, f"Failed with {usage_dict}"

Expand Down

0 comments on commit d269147

Please sign in to comment.