diff --git a/sdk/python/feast/errors.py b/sdk/python/feast/errors.py
index 11ce9ebc62..84a9f12ec4 100644
--- a/sdk/python/feast/errors.py
+++ b/sdk/python/feast/errors.py
@@ -467,10 +467,15 @@ def __init__(self):


 class DataFrameSerializationError(FeastError):
-    def __init__(self, input_dict: dict):
-        super().__init__(
-            f"Failed to serialize the provided dictionary into a pandas DataFrame: {input_dict.keys()}"
-        )
+    def __init__(self, input: Any):
+        if isinstance(input, dict):
+            super().__init__(
+                f"Failed to serialize the provided dictionary into a pandas DataFrame: {input.keys()}"
+            )
+        else:
+            super().__init__(
+                "Failed to serialize the provided input into a pandas DataFrame"
+            )


 class PermissionNotFoundException(FeastError):
diff --git a/sdk/python/feast/feature_service.py b/sdk/python/feast/feature_service.py
index 8b8cbac8ea..5c062f15c6 100644
--- a/sdk/python/feast/feature_service.py
+++ b/sdk/python/feast/feature_service.py
@@ -84,7 +84,9 @@ def __init__(
             if isinstance(feature_grouping, BaseFeatureView):
                 self.feature_view_projections.append(feature_grouping.projection)

-    def infer_features(self, fvs_to_update: Dict[str, FeatureView]):
+    def infer_features(
+        self, fvs_to_update: Dict[str, Union[FeatureView, BaseFeatureView]]
+    ):
         """
         Infers the features for the projections of this feature service, and updates this feature
         service in place.
diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py
index 52556eda15..6112279a02 100644
--- a/sdk/python/feast/feature_store.py
+++ b/sdk/python/feast/feature_store.py
@@ -608,7 +608,12 @@ def _make_inferences(
         update_feature_views_with_inferred_features_and_entities(
             sfvs_to_update, entities + entities_to_update, self.config
         )
-        # TODO(kevjumba): Update schema inferrence
+        # We need to attach the timestamp fields to the underlying data sources
+        # and cascade the dependencies
+        update_feature_views_with_inferred_features_and_entities(
+            odfvs_to_update, entities + entities_to_update, self.config
+        )
+        # TODO(kevjumba): Update schema inference
         for sfv in sfvs_to_update:
             if not sfv.schema:
                 raise ValueError(
@@ -618,8 +623,13 @@
         for odfv in odfvs_to_update:
             odfv.infer_features()

+        odfvs_to_write = [
+            odfv for odfv in odfvs_to_update if odfv.write_to_online_store
+        ]
+        # Update to include ODFVs with write to online store
         fvs_to_update_map = {
-            view.name: view for view in [*views_to_update, *sfvs_to_update]
+            view.name: view
+            for view in [*views_to_update, *sfvs_to_update, *odfvs_to_write]
         }
         for feature_service in feature_services_to_update:
             feature_service.infer_features(fvs_to_update=fvs_to_update_map)
@@ -847,6 +857,11 @@ def apply(
         ]
         sfvs_to_update = [ob for ob in objects if isinstance(ob, StreamFeatureView)]
         odfvs_to_update = [ob for ob in objects if isinstance(ob, OnDemandFeatureView)]
+        odfvs_with_writes_to_update = [
+            ob
+            for ob in objects
+            if isinstance(ob, OnDemandFeatureView) and ob.write_to_online_store
+        ]
         services_to_update = [ob for ob in objects if isinstance(ob, FeatureService)]
         data_sources_set_to_update = {
             ob for ob in objects if isinstance(ob, DataSource)
@@ -868,10 +883,23 @@ def apply(
         for batch_source in batch_sources_to_add:
             data_sources_set_to_update.add(batch_source)

-        for fv in itertools.chain(views_to_update, sfvs_to_update):
-            data_sources_set_to_update.add(fv.batch_source)
-            if fv.stream_source:
-                data_sources_set_to_update.add(fv.stream_source)
+        for fv in itertools.chain(
+            views_to_update, sfvs_to_update, odfvs_with_writes_to_update
+        ):
+            if isinstance(fv, FeatureView):
+                data_sources_set_to_update.add(fv.batch_source)
+                if hasattr(fv, "stream_source"):
+                    if fv.stream_source:
+                        data_sources_set_to_update.add(fv.stream_source)
+            if isinstance(fv, OnDemandFeatureView):
+                for source_fvp in fv.source_feature_view_projections:
+                    odfv_batch_source: Optional[DataSource] = (
+                        fv.source_feature_view_projections[source_fvp].batch_source
+                    )
+                    if odfv_batch_source is not None:
+                        data_sources_set_to_update.add(odfv_batch_source)
+            else:
+                pass

         for odfv in odfvs_to_update:
             for v in odfv.source_request_sources.values():
@@ -884,7 +912,9 @@

         # Validate all feature views and make inferences.
         self._validate_all_feature_views(
-            views_to_update, odfvs_to_update, sfvs_to_update
+            views_to_update,
+            odfvs_to_update,
+            sfvs_to_update,
         )
         self._make_inferences(
             data_sources_to_update,
@@ -989,7 +1019,9 @@
         tables_to_delete: List[FeatureView] = (
             views_to_delete + sfvs_to_delete if not partial else []  # type: ignore
         )
-        tables_to_keep: List[FeatureView] = views_to_update + sfvs_to_update  # type: ignore
+        tables_to_keep: List[
+            Union[FeatureView, StreamFeatureView, OnDemandFeatureView]
+        ] = views_to_update + sfvs_to_update + odfvs_with_writes_to_update  # type: ignore

         self._get_provider().update_infra(
             project=self.project,
@@ -1444,19 +1476,18 @@
             inputs: Optional the dictionary object to be written
             allow_registry_cache (optional): Whether to allow retrieving feature views from a cached registry.
         """
-        # TODO: restrict this to work with online StreamFeatureViews and validate the FeatureView type
+        feature_view_dict = {
+            fv_proto.name: fv_proto
+            for fv_proto in self.list_all_feature_views(allow_registry_cache)
+        }
         try:
-            feature_view: FeatureView = self.get_stream_feature_view(
-                feature_view_name, allow_registry_cache=allow_registry_cache
-            )
+            feature_view = feature_view_dict[feature_view_name]
         except FeatureViewNotFoundException:
-            feature_view = self.get_feature_view(
-                feature_view_name, allow_registry_cache=allow_registry_cache
-            )
+            raise FeatureViewNotFoundException(feature_view_name, self.project)
         if df is not None and inputs is not None:
             raise ValueError("Both df and inputs cannot be provided at the same time.")
         if df is None and inputs is not None:
-            if isinstance(inputs, dict):
+            if isinstance(inputs, dict) or isinstance(inputs, List):
                 try:
                     df = pd.DataFrame(inputs)
                 except Exception as _:
@@ -1465,6 +1496,13 @@
                 pass
             else:
                 raise ValueError("inputs must be a dictionary or a pandas DataFrame.")
+        if df is not None and inputs is None:
+            if isinstance(df, dict) or isinstance(df, List):
+                try:
+                    df = pd.DataFrame(df)
+                except Exception as _:
+                    raise DataFrameSerializationError(df)
+
         provider = self._get_provider()
         provider.ingest_df(feature_view, df)
diff --git a/sdk/python/feast/inference.py b/sdk/python/feast/inference.py
index b9fb9b694d..39782e1e31 100644
--- a/sdk/python/feast/inference.py
+++ b/sdk/python/feast/inference.py
@@ -183,10 +183,13 @@
             )

         if not fv.features:
-            raise RegistryInferenceFailure(
-                "FeatureView",
-                f"Could not infer Features for the FeatureView named {fv.name}.",
-            )
+            if isinstance(fv, OnDemandFeatureView):
+                return None
+            else:
+                raise RegistryInferenceFailure(
+                    "FeatureView",
+                    f"Could not infer Features for the FeatureView named {fv.name}.",
+                )


 def _infer_features_and_entities(
@@ -209,6 +212,7 @@
             fv, join_keys, run_inference_for_features,
config ) + entity_columns: List[Field] = fv.entity_columns if fv.entity_columns else [] columns_to_exclude = { fv.batch_source.timestamp_field, fv.batch_source.created_timestamp_column, @@ -235,7 +239,7 @@ def _infer_features_and_entities( if field.name not in [ entity_column.name for entity_column in fv.entity_columns ]: - fv.entity_columns.append(field) + entity_columns.append(field) elif not re.match( "^__|__$", col_name ): # double underscores often signal an internal-use column @@ -256,6 +260,8 @@ def _infer_features_and_entities( if field.name not in [feature.name for feature in fv.features]: fv.features.append(field) + fv.entity_columns = entity_columns + def _infer_on_demand_features_and_entities( fv: OnDemandFeatureView, @@ -282,18 +288,19 @@ def _infer_on_demand_features_and_entities( batch_source = getattr(source_feature_view, "batch_source") batch_field_mapping = getattr(batch_source or None, "field_mapping") - if batch_field_mapping: - for ( - original_col, - mapped_col, - ) in batch_field_mapping.items(): - if mapped_col in columns_to_exclude: - columns_to_exclude.remove(mapped_col) - columns_to_exclude.add(original_col) + for ( + original_col, + mapped_col, + ) in batch_field_mapping.items(): + if mapped_col in columns_to_exclude: + columns_to_exclude.remove(mapped_col) + columns_to_exclude.add(original_col) + + table_column_names_and_types = batch_source.get_table_column_names_and_types( + config + ) + batch_field_mapping = getattr(batch_source, "field_mapping", {}) - table_column_names_and_types = ( - batch_source.get_table_column_names_and_types(config) - ) for col_name, col_datatype in table_column_names_and_types: if col_name in columns_to_exclude: continue diff --git a/sdk/python/feast/infra/materialization/aws_lambda/lambda_engine.py b/sdk/python/feast/infra/materialization/aws_lambda/lambda_engine.py index 93e33d5949..2864012055 100644 --- a/sdk/python/feast/infra/materialization/aws_lambda/lambda_engine.py +++ b/sdk/python/feast/infra/materialization/aws_lambda/lambda_engine.py @@ -24,6 +24,7 @@ from feast.infra.offline_stores.offline_store import OfflineStore from feast.infra.online_stores.online_store import OnlineStore from feast.infra.registry.base_registry import BaseRegistry +from feast.on_demand_feature_view import OnDemandFeatureView from feast.repo_config import FeastConfigBaseModel, RepoConfig from feast.stream_feature_view import StreamFeatureView from feast.utils import _get_column_names @@ -80,10 +81,10 @@ def update( self, project: str, views_to_delete: Sequence[ - Union[BatchFeatureView, StreamFeatureView, FeatureView] + Union[BatchFeatureView, StreamFeatureView, FeatureView, OnDemandFeatureView] ], views_to_keep: Sequence[ - Union[BatchFeatureView, StreamFeatureView, FeatureView] + Union[BatchFeatureView, StreamFeatureView, FeatureView, OnDemandFeatureView] ], entities_to_delete: Sequence[Entity], entities_to_keep: Sequence[Entity], diff --git a/sdk/python/feast/infra/materialization/batch_materialization_engine.py b/sdk/python/feast/infra/materialization/batch_materialization_engine.py index 8e854a508d..af92b95d17 100644 --- a/sdk/python/feast/infra/materialization/batch_materialization_engine.py +++ b/sdk/python/feast/infra/materialization/batch_materialization_engine.py @@ -12,6 +12,7 @@ from feast.infra.offline_stores.offline_store import OfflineStore from feast.infra.online_stores.online_store import OnlineStore from feast.infra.registry.base_registry import BaseRegistry +from feast.on_demand_feature_view import OnDemandFeatureView from 
feast.repo_config import RepoConfig from feast.stream_feature_view import StreamFeatureView @@ -89,7 +90,7 @@ def update( Union[BatchFeatureView, StreamFeatureView, FeatureView] ], views_to_keep: Sequence[ - Union[BatchFeatureView, StreamFeatureView, FeatureView] + Union[BatchFeatureView, StreamFeatureView, FeatureView, OnDemandFeatureView] ], entities_to_delete: Sequence[Entity], entities_to_keep: Sequence[Entity], diff --git a/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py b/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py index 24608baebf..3abb6fffd6 100644 --- a/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py +++ b/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py @@ -23,6 +23,7 @@ from feast.infra.online_stores.online_store import OnlineStore from feast.infra.passthrough_provider import PassthroughProvider from feast.infra.registry.base_registry import BaseRegistry +from feast.on_demand_feature_view import OnDemandFeatureView from feast.protos.feast.core.FeatureView_pb2 import FeatureView as FeatureViewProto from feast.repo_config import FeastConfigBaseModel, RepoConfig from feast.stream_feature_view import StreamFeatureView @@ -77,10 +78,10 @@ def update( self, project: str, views_to_delete: Sequence[ - Union[BatchFeatureView, StreamFeatureView, FeatureView] + Union[BatchFeatureView, StreamFeatureView, FeatureView, OnDemandFeatureView] ], views_to_keep: Sequence[ - Union[BatchFeatureView, StreamFeatureView, FeatureView] + Union[BatchFeatureView, StreamFeatureView, FeatureView, OnDemandFeatureView] ], entities_to_delete: Sequence[Entity], entities_to_keep: Sequence[Entity], diff --git a/sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_engine.py b/sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_engine.py index 510b6b4e4c..a0ccbcd768 100644 --- a/sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_engine.py +++ b/sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_engine.py @@ -24,6 +24,7 @@ from feast.infra.offline_stores.offline_store import OfflineStore from feast.infra.online_stores.online_store import OnlineStore from feast.infra.registry.base_registry import BaseRegistry +from feast.on_demand_feature_view import OnDemandFeatureView from feast.repo_config import FeastConfigBaseModel from feast.stream_feature_view import StreamFeatureView from feast.utils import _get_column_names @@ -122,10 +123,10 @@ def update( self, project: str, views_to_delete: Sequence[ - Union[BatchFeatureView, StreamFeatureView, FeatureView] + Union[BatchFeatureView, StreamFeatureView, FeatureView, OnDemandFeatureView] ], views_to_keep: Sequence[ - Union[BatchFeatureView, StreamFeatureView, FeatureView] + Union[BatchFeatureView, StreamFeatureView, FeatureView, OnDemandFeatureView] ], entities_to_delete: Sequence[Entity], entities_to_keep: Sequence[Entity], diff --git a/sdk/python/feast/infra/materialization/local_engine.py b/sdk/python/feast/infra/materialization/local_engine.py index d818571453..fa60950f29 100644 --- a/sdk/python/feast/infra/materialization/local_engine.py +++ b/sdk/python/feast/infra/materialization/local_engine.py @@ -10,6 +10,7 @@ from feast.infra.offline_stores.offline_store import OfflineStore from feast.infra.online_stores.online_store import OnlineStore from feast.infra.registry.base_registry import BaseRegistry +from feast.on_demand_feature_view import 
OnDemandFeatureView from feast.repo_config import FeastConfigBaseModel, RepoConfig from feast.stream_feature_view import StreamFeatureView from feast.utils import ( @@ -69,10 +70,10 @@ def update( self, project: str, views_to_delete: Sequence[ - Union[BatchFeatureView, StreamFeatureView, FeatureView] + Union[BatchFeatureView, StreamFeatureView, FeatureView, OnDemandFeatureView] ], views_to_keep: Sequence[ - Union[BatchFeatureView, StreamFeatureView, FeatureView] + Union[BatchFeatureView, StreamFeatureView, FeatureView, OnDemandFeatureView] ], entities_to_delete: Sequence[Entity], entities_to_keep: Sequence[Entity], diff --git a/sdk/python/feast/infra/materialization/snowflake_engine.py b/sdk/python/feast/infra/materialization/snowflake_engine.py index 3b789a3e5d..f4be9dd7da 100644 --- a/sdk/python/feast/infra/materialization/snowflake_engine.py +++ b/sdk/python/feast/infra/materialization/snowflake_engine.py @@ -31,6 +31,7 @@ get_snowflake_online_store_path, package_snowpark_zip, ) +from feast.on_demand_feature_view import OnDemandFeatureView from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto from feast.protos.feast.types.Value_pb2 import Value as ValueProto from feast.repo_config import FeastConfigBaseModel, RepoConfig @@ -120,10 +121,10 @@ def update( self, project: str, views_to_delete: Sequence[ - Union[BatchFeatureView, StreamFeatureView, FeatureView] + Union[BatchFeatureView, StreamFeatureView, FeatureView, OnDemandFeatureView] ], views_to_keep: Sequence[ - Union[BatchFeatureView, StreamFeatureView, FeatureView] + Union[BatchFeatureView, StreamFeatureView, FeatureView, OnDemandFeatureView] ], entities_to_delete: Sequence[Entity], entities_to_keep: Sequence[Entity], diff --git a/sdk/python/feast/infra/online_stores/online_store.py b/sdk/python/feast/infra/online_stores/online_store.py index fdb5b055cf..ea86bd9175 100644 --- a/sdk/python/feast/infra/online_stores/online_store.py +++ b/sdk/python/feast/infra/online_stores/online_store.py @@ -17,6 +17,7 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union from feast import Entity, utils +from feast.batch_feature_view import BatchFeatureView from feast.feature_service import FeatureService from feast.feature_view import FeatureView from feast.infra.infra_object import InfraObject @@ -27,6 +28,7 @@ from feast.protos.feast.types.Value_pb2 import RepeatedValue from feast.protos.feast.types.Value_pb2 import Value as ValueProto from feast.repo_config import RepoConfig +from feast.stream_feature_view import StreamFeatureView class OnlineStore(ABC): @@ -288,7 +290,9 @@ def update( self, config: RepoConfig, tables_to_delete: Sequence[FeatureView], - tables_to_keep: Sequence[FeatureView], + tables_to_keep: Sequence[ + Union[BatchFeatureView, StreamFeatureView, FeatureView] + ], entities_to_delete: Sequence[Entity], entities_to_keep: Sequence[Entity], partial: bool, diff --git a/sdk/python/feast/infra/passthrough_provider.py b/sdk/python/feast/infra/passthrough_provider.py index c3c3048a89..11ccbb2450 100644 --- a/sdk/python/feast/infra/passthrough_provider.py +++ b/sdk/python/feast/infra/passthrough_provider.py @@ -5,7 +5,8 @@ import pyarrow as pa from tqdm import tqdm -from feast import importer +from feast import OnDemandFeatureView, importer +from feast.base_feature_view import BaseFeatureView from feast.batch_feature_view import BatchFeatureView from feast.data_source import DataSource from feast.entity import Entity @@ -122,7 +123,7 @@ def update_infra( self, project: str, 
tables_to_delete: Sequence[FeatureView], - tables_to_keep: Sequence[FeatureView], + tables_to_keep: Sequence[Union[FeatureView, OnDemandFeatureView]], entities_to_delete: Sequence[Entity], entities_to_keep: Sequence[Entity], partial: bool, @@ -160,7 +161,7 @@ def teardown_infra( def online_write_batch( self, config: RepoConfig, - table: FeatureView, + table: Union[FeatureView, BaseFeatureView, OnDemandFeatureView], data: List[ Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]] ], @@ -274,25 +275,52 @@ def retrieve_online_documents( def ingest_df( self, - feature_view: FeatureView, + feature_view: Union[BaseFeatureView, FeatureView, OnDemandFeatureView], df: pd.DataFrame, + field_mapping: Optional[Dict] = None, ): table = pa.Table.from_pandas(df) - - if feature_view.batch_source.field_mapping is not None: - table = _run_pyarrow_field_mapping( - table, feature_view.batch_source.field_mapping + if isinstance(feature_view, OnDemandFeatureView): + if not field_mapping: + field_mapping = {} + table = _run_pyarrow_field_mapping(table, field_mapping) + join_keys = { + entity.name: entity.dtype.to_value_type() + for entity in feature_view.entity_columns + } + rows_to_write = _convert_arrow_to_proto(table, feature_view, join_keys) + + self.online_write_batch( + self.repo_config, feature_view, rows_to_write, progress=None ) + else: + if hasattr(feature_view, "entity_columns"): + join_keys = { + entity.name: entity.dtype.to_value_type() + for entity in feature_view.entity_columns + } + else: + join_keys = {} + + # Note: A dictionary mapping of column names in this data + # source to feature names in a feature table or view. Only used for feature + # columns, not entity or timestamp columns. + if hasattr(feature_view, "batch_source"): + if feature_view.batch_source.field_mapping is not None: + table = _run_pyarrow_field_mapping( + table, feature_view.batch_source.field_mapping + ) + else: + table = _run_pyarrow_field_mapping(table, {}) - join_keys = { - entity.name: entity.dtype.to_value_type() - for entity in feature_view.entity_columns - } - rows_to_write = _convert_arrow_to_proto(table, feature_view, join_keys) + if not isinstance(feature_view, BaseFeatureView): + for entity in feature_view.entity_columns: + join_keys[entity.name] = entity.dtype.to_value_type() + rows_to_write = _convert_arrow_to_proto(table, feature_view, join_keys) - self.online_write_batch( - self.repo_config, feature_view, rows_to_write, progress=None - ) + self.online_write_batch( + self.repo_config, feature_view, rows_to_write, progress=None + ) def ingest_df_to_offline_store(self, feature_view: FeatureView, table: pa.Table): if feature_view.batch_source.field_mapping is not None: diff --git a/sdk/python/feast/infra/provider.py b/sdk/python/feast/infra/provider.py index c0062dde02..22c2088861 100644 --- a/sdk/python/feast/infra/provider.py +++ b/sdk/python/feast/infra/provider.py @@ -8,6 +8,7 @@ from tqdm import tqdm from feast import FeatureService, errors +from feast.base_feature_view import BaseFeatureView from feast.data_source import DataSource from feast.entity import Entity from feast.feature_view import FeatureView @@ -15,6 +16,7 @@ from feast.infra.infra_object import Infra from feast.infra.offline_stores.offline_store import RetrievalJob from feast.infra.registry.base_registry import BaseRegistry +from feast.on_demand_feature_view import OnDemandFeatureView from feast.online_response import OnlineResponse from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto from 
feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto @@ -47,7 +49,7 @@ def update_infra( self, project: str, tables_to_delete: Sequence[FeatureView], - tables_to_keep: Sequence[FeatureView], + tables_to_keep: Sequence[Union[FeatureView, OnDemandFeatureView]], entities_to_delete: Sequence[Entity], entities_to_keep: Sequence[Entity], partial: bool, @@ -125,8 +127,9 @@ def online_write_batch( def ingest_df( self, - feature_view: FeatureView, + feature_view: Union[BaseFeatureView, FeatureView, OnDemandFeatureView], df: pd.DataFrame, + field_mapping: Optional[Dict] = None, ): """ Persists a dataframe to the online store. @@ -134,6 +137,7 @@ def ingest_df( Args: feature_view: The feature view to which the dataframe corresponds. df: The dataframe to be persisted. + field_mapping: A dictionary mapping dataframe column names to feature names. """ pass diff --git a/sdk/python/feast/type_map.py b/sdk/python/feast/type_map.py index e47e4fcbe4..315b73909e 100644 --- a/sdk/python/feast/type_map.py +++ b/sdk/python/feast/type_map.py @@ -559,6 +559,7 @@ def pa_to_feast_value_type(pa_type_as_str: str) -> ValueType: "binary": ValueType.BYTES, "bool": ValueType.BOOL, "null": ValueType.NULL, + "list": ValueType.DOUBLE_LIST, } value_type = type_map[pa_type_as_str] diff --git a/sdk/python/feast/types.py b/sdk/python/feast/types.py index 4b07c58d19..9fb3207e6d 100644 --- a/sdk/python/feast/types.py +++ b/sdk/python/feast/types.py @@ -12,9 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod +from datetime import datetime, timezone from enum import Enum from typing import Dict, Union +import pyarrow + from feast.value_type import ValueType PRIMITIVE_FEAST_TYPES_TO_VALUE_TYPES = { @@ -30,6 +33,10 @@ } +def _utc_now() -> datetime: + return datetime.now(tz=timezone.utc) + + class ComplexFeastType(ABC): """ A ComplexFeastType represents a structured type that is recognized by Feast. @@ -103,7 +110,6 @@ def __hash__(self): Float64 = PrimitiveFeastType.FLOAT64 UnixTimestamp = PrimitiveFeastType.UNIX_TIMESTAMP - SUPPORTED_BASE_TYPES = [ Invalid, String, @@ -159,7 +165,6 @@ def __str__(self): FeastType = Union[ComplexFeastType, PrimitiveFeastType] - VALUE_TYPES_TO_FEAST_TYPES: Dict["ValueType", FeastType] = { ValueType.UNKNOWN: Invalid, ValueType.BYTES: Bytes, @@ -180,6 +185,40 @@ def __str__(self): ValueType.UNIX_TIMESTAMP_LIST: Array(UnixTimestamp), } +FEAST_TYPES_TO_PYARROW_TYPES = { + String: pyarrow.string(), + Bool: pyarrow.bool_(), + Int32: pyarrow.int32(), + Int64: pyarrow.int64(), + Float32: pyarrow.float32(), + Float64: pyarrow.float64(), + # Note: datetime only supports microseconds /~https://github.com/python/cpython/blob/3.8/Lib/datetime.py#L1559 + UnixTimestamp: pyarrow.timestamp("us", tz=_utc_now().tzname()), +} + + +def from_feast_to_pyarrow_type(feast_type: FeastType) -> pyarrow.DataType: + """ + Converts a Feast type to a PyArrow type. + + Args: + feast_type: The Feast type to be converted. + + Raises: + ValueError: The conversion could not be performed. 
+ """ + assert isinstance( + feast_type, (ComplexFeastType, PrimitiveFeastType) + ), f"Expected FeastType, got {type(feast_type)}" + if isinstance(feast_type, PrimitiveFeastType): + if feast_type in FEAST_TYPES_TO_PYARROW_TYPES: + return FEAST_TYPES_TO_PYARROW_TYPES[feast_type] + elif isinstance(feast_type, ComplexFeastType): + # Handle the case when feast_type is an instance of ComplexFeastType + pass + + raise ValueError(f"Could not convert Feast type {feast_type} to PyArrow type.") + def from_value_type( value_type: ValueType, diff --git a/sdk/python/feast/utils.py b/sdk/python/feast/utils.py index 8a9f1fadae..ec2da79782 100644 --- a/sdk/python/feast/utils.py +++ b/sdk/python/feast/utils.py @@ -28,7 +28,6 @@ from feast.constants import FEAST_FS_YAML_FILE_PATH_ENV_NAME from feast.entity import Entity from feast.errors import ( - EntityNotFoundException, FeatureNameCollisionError, FeatureViewNotFoundException, RequestDataNotFoundInEntityRowsException, @@ -43,10 +42,12 @@ from feast.protos.feast.types.Value_pb2 import RepeatedValue as RepeatedValueProto from feast.protos.feast.types.Value_pb2 import Value as ValueProto from feast.type_map import python_values_to_proto_values +from feast.types import from_feast_to_pyarrow_type from feast.value_type import ValueType from feast.version import get_version if typing.TYPE_CHECKING: + from feast.base_feature_view import BaseFeatureView from feast.feature_service import FeatureService from feast.feature_view import FeatureView from feast.infra.registry.base_registry import BaseRegistry @@ -227,6 +228,18 @@ def _coerce_datetime(ts): def _convert_arrow_to_proto( + table: Union[pyarrow.Table, pyarrow.RecordBatch], + feature_view: Union["FeatureView", "BaseFeatureView", "OnDemandFeatureView"], + join_keys: Dict[str, ValueType], +) -> List[Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]]: + # This is a workaround for isinstance(feature_view, OnDemandFeatureView), which triggers a circular import + if getattr(feature_view, "source_request_sources", None): + return _convert_arrow_odfv_to_proto(table, feature_view, join_keys) # type: ignore[arg-type] + else: + return _convert_arrow_fv_to_proto(table, feature_view, join_keys) # type: ignore[arg-type] + + +def _convert_arrow_fv_to_proto( table: Union[pyarrow.Table, pyarrow.RecordBatch], feature_view: "FeatureView", join_keys: Dict[str, ValueType], @@ -287,6 +300,76 @@ def _convert_arrow_to_proto( return list(zip(entity_keys, features, event_timestamps, created_timestamps)) +def _convert_arrow_odfv_to_proto( + table: Union[pyarrow.Table, pyarrow.RecordBatch], + feature_view: "OnDemandFeatureView", + join_keys: Dict[str, ValueType], +) -> List[Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]]: + # Avoid ChunkedArrays which guarantees `zero_copy_only` available. 
+    if isinstance(table, pyarrow.Table):
+        table = table.to_batches()[0]
+
+    columns = [
+        (field.name, field.dtype.to_value_type()) for field in feature_view.features
+    ] + list(join_keys.items())
+
+    proto_values_by_column = {
+        column: python_values_to_proto_values(
+            table.column(column).to_numpy(zero_copy_only=False), value_type
+        )
+        for column, value_type in columns
+        if column in table.column_names
+    }
+    # Adding On Demand Features
+    for feature in feature_view.features:
+        if (
+            feature.name in [c[0] for c in columns]
+            and feature.name not in proto_values_by_column
+        ):
+            # initializing the column as null
+            null_column = pyarrow.array(
+                [None] * table.num_rows,
+                type=from_feast_to_pyarrow_type(feature.dtype),
+            )
+            updated_table = pyarrow.RecordBatch.from_arrays(
+                table.columns + [null_column],
+                schema=table.schema.append(
+                    pyarrow.field(feature.name, null_column.type)
+                ),
+            )
+            proto_values_by_column[feature.name] = python_values_to_proto_values(
+                updated_table.column(feature.name).to_numpy(zero_copy_only=False),
+                feature.dtype.to_value_type(),
+            )
+
+    entity_keys = [
+        EntityKeyProto(
+            join_keys=join_keys,
+            entity_values=[proto_values_by_column[k][idx] for k in join_keys],
+        )
+        for idx in range(table.num_rows)
+    ]
+
+    # Serialize the features per row
+    feature_dict = {
+        feature.name: proto_values_by_column[feature.name]
+        for feature in feature_view.features
+    }
+    features = [dict(zip(feature_dict, vars)) for vars in zip(*feature_dict.values())]
+
+    # We need to artificially add event_timestamps and created_timestamps
+    event_timestamps = []
+    timestamp_values = pd.to_datetime([_utc_now() for i in range(table.num_rows)])
+
+    for val in timestamp_values:
+        event_timestamps.append(_coerce_datetime(val))
+
+    # setting them equivalent
+    created_timestamps = event_timestamps
+
+    return list(zip(entity_keys, features, event_timestamps, created_timestamps))
+
+
 def _validate_entity_values(join_key_values: Dict[str, List[ValueProto]]):
     set_of_row_lengths = {len(v) for v in join_key_values.values()}
     if len(set_of_row_lengths) > 1:
@@ -931,6 +1014,17 @@ def _prepare_entities_to_read_from_online_store(

     num_rows = _validate_entity_values(entity_proto_values)

+    odfv_entities: List[Entity] = []
+    request_source_keys: List[str] = []
+    for on_demand_feature_view in requested_on_demand_feature_views:
+        odfv_entities.append(*getattr(on_demand_feature_view, "entities", []))
+        for source in on_demand_feature_view.source_request_sources:
+            source_schema = on_demand_feature_view.source_request_sources[source].schema
+            for column in source_schema:
+                request_source_keys.append(column.name)
+
+    join_keys_set.update(set(odfv_entities))
+
     join_key_values: Dict[str, List[ValueProto]] = {}
     request_data_features: Dict[str, List[ValueProto]] = {}
     # Entity rows may be either entities or request data.
@@ -938,22 +1032,20 @@
         # Found request data
         if join_key_or_entity_name in needed_request_data:
             request_data_features[join_key_or_entity_name] = values
-        else:
-            if join_key_or_entity_name in join_keys_set:
-                join_key = join_key_or_entity_name
-            else:
-                try:
-                    join_key = entity_name_to_join_key_map[join_key_or_entity_name]
-                except KeyError:
-                    raise EntityNotFoundException(join_key_or_entity_name, project)
-                else:
-                    warnings.warn(
-                        "Using entity name is deprecated. Use join_key instead."
-                    )
-
-            # All join keys should be returned in the result.
+ elif join_key_or_entity_name in join_keys_set: + # It's a join key + join_key = join_key_or_entity_name requested_result_row_names.add(join_key) join_key_values[join_key] = values + elif join_key_or_entity_name in entity_name_to_join_key_map: + # It's an entity name (deprecated) + join_key = entity_name_to_join_key_map[join_key_or_entity_name] + warnings.warn("Using entity name is deprecated. Use join_key instead.") + requested_result_row_names.add(join_key) + join_key_values[join_key] = values + else: + # Key is not recognized (likely a feature value), so we skip it. + continue # Or handle accordingly ensure_request_data_values_exist(needed_request_data, request_data_features) diff --git a/sdk/python/tests/integration/materialization/test_snowflake.py b/sdk/python/tests/integration/materialization/test_snowflake.py index dc9d684ab5..5f01641c3b 100644 --- a/sdk/python/tests/integration/materialization/test_snowflake.py +++ b/sdk/python/tests/integration/materialization/test_snowflake.py @@ -5,7 +5,7 @@ from feast import Field from feast.entity import Entity -from feast.feature_view import FeatureView +from feast.feature_view import DUMMY_ENTITY_FIELD, FeatureView from feast.types import Array, Bool, Bytes, Float64, Int32, Int64, String, UnixTimestamp from feast.utils import _utc_now from tests.data.data_creator import create_basic_driver_dataset @@ -225,7 +225,7 @@ def test_snowflake_materialization_entityless_fv(): try: fs.apply([overall_stats_fv, driver]) - assert overall_stats_fv.entity_columns != [] + assert overall_stats_fv.entity_columns == [DUMMY_ENTITY_FIELD] # materialization is run in two steps and # we use timestamp from generated dataframe as a split point diff --git a/sdk/python/tests/unit/local_feast_tests/test_local_feature_store.py b/sdk/python/tests/unit/local_feast_tests/test_local_feature_store.py index cc48295b20..1b9b48d8d0 100644 --- a/sdk/python/tests/unit/local_feast_tests/test_local_feature_store.py +++ b/sdk/python/tests/unit/local_feast_tests/test_local_feature_store.py @@ -11,7 +11,7 @@ from feast.entity import Entity from feast.feast_object import ALL_RESOURCE_TYPES from feast.feature_store import FeatureStore -from feast.feature_view import DUMMY_ENTITY_ID, FeatureView +from feast.feature_view import DUMMY_ENTITY_ID, DUMMY_ENTITY_NAME, FeatureView from feast.field import Field from feast.infra.offline_stores.file_source import FileSource from feast.infra.online_stores.sqlite import SqliteOnlineStoreConfig @@ -20,7 +20,7 @@ from feast.permissions.policy import RoleBasedPolicy from feast.repo_config import RepoConfig from feast.stream_feature_view import stream_feature_view -from feast.types import Array, Bytes, Float32, Int64, String +from feast.types import Array, Bytes, Float32, Int64, String, ValueType, from_value_type from tests.integration.feature_repos.universal.feature_views import TAGS from tests.utils.cli_repo_creator import CliRunner, get_example_repo from tests.utils.data_source_test_creator import prep_file_source @@ -347,7 +347,7 @@ def test_apply_entities_and_feature_views(test_feature_store): "test_feature_store", [lazy_fixture("feature_store_with_local_registry")], ) -def test_apply_dummuy_entity_and_feature_view_columns(test_feature_store): +def test_apply_dummy_entity_and_feature_view_columns(test_feature_store): assert isinstance(test_feature_store, FeatureStore) # Create Feature Views batch_source = FileSource( @@ -357,33 +357,51 @@ def test_apply_dummuy_entity_and_feature_view_columns(test_feature_store): 
created_timestamp_column="timestamp", ) - e1 = Entity(name="fs1_my_entity_1", description="something") + e1 = Entity( + name="fs1_my_entity_1", description="something", value_type=ValueType.INT64 + ) - fv = FeatureView( + fv_no_entity = FeatureView( name="my_feature_view_no_entity", schema=[ Field(name="fs1_my_feature_1", dtype=Int64), - Field(name="fs1_my_feature_2", dtype=String), - Field(name="fs1_my_feature_3", dtype=Array(String)), - Field(name="fs1_my_feature_4", dtype=Array(Bytes)), - Field(name="fs1_my_entity_2", dtype=Int64), + Field(name="fs1_my_entity_1", dtype=Int64), ], entities=[], tags={"team": "matchmaking"}, source=batch_source, ttl=timedelta(minutes=5), ) + fv_with_entity = FeatureView( + name="my_feature_view_with_entity", + schema=[ + Field(name="fs1_my_feature_1", dtype=Int64), + Field(name="fs1_my_entity_1", dtype=Int64), + ], + entities=[e1], + tags={"team": "matchmaking"}, + source=batch_source, + ttl=timedelta(minutes=5), + ) # Check that the entity_columns are empty before applying - assert fv.entity_columns == [] - + assert fv_no_entity.entities == [DUMMY_ENTITY_NAME] + assert fv_no_entity.entity_columns == [] + # Note that this test is a special case rooted in the entity being included in the schema + assert fv_with_entity.entity_columns == [ + Field(name=e1.join_key, dtype=from_value_type(e1.value_type)) + ] # Register Feature View - test_feature_store.apply([fv, e1]) - fv_actual = test_feature_store.get_feature_view("my_feature_view_no_entity") - + test_feature_store.apply([e1, fv_no_entity, fv_with_entity]) + fv_from_online_store = test_feature_store.get_feature_view( + "my_feature_view_no_entity" + ) # Note that after the apply() the feature_view serializes the Dummy Entity ID - assert fv.entity_columns[0].name == DUMMY_ENTITY_ID - assert fv_actual.entity_columns[0].name == DUMMY_ENTITY_ID + assert fv_no_entity.entity_columns[0].name == DUMMY_ENTITY_ID + assert fv_from_online_store.entity_columns[0].name == DUMMY_ENTITY_ID + assert fv_from_online_store.entities == [] + assert fv_no_entity.entities == [DUMMY_ENTITY_NAME] + assert fv_with_entity.entity_columns[0].name == e1.join_key test_feature_store.teardown() diff --git a/sdk/python/tests/unit/online_store/test_online_writes.py b/sdk/python/tests/unit/online_store/test_online_writes.py index 7927b564c6..803a93b48d 100644 --- a/sdk/python/tests/unit/online_store/test_online_writes.py +++ b/sdk/python/tests/unit/online_store/test_online_writes.py @@ -77,6 +77,7 @@ def setUp(self): ) # Before apply() join_keys is empty assert driver_stats_fv.join_keys == [] + assert driver_stats_fv.entity_columns == [] @on_demand_feature_view( sources=[driver_stats_fv[["conv_rate", "acc_rate"]]], @@ -104,6 +105,8 @@ def test_view(inputs: dict[str, Any]) -> dict[str, Any]: ) # after apply() join_keys is [driver] assert driver_stats_fv.join_keys == [driver.join_key] + assert driver_stats_fv.entity_columns[0].name == driver.join_key + self.store.write_to_online_store( feature_view_name="driver_hourly_stats", df=driver_df ) diff --git a/sdk/python/tests/unit/test_on_demand_feature_view.py b/sdk/python/tests/unit/test_on_demand_feature_view.py index 6073891aba..4b30bd6be9 100644 --- a/sdk/python/tests/unit/test_on_demand_feature_view.py +++ b/sdk/python/tests/unit/test_on_demand_feature_view.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +import datetime from typing import Any, Dict, List import pandas as pd @@ -50,6 +50,15 @@ def python_native_udf(features_dict: Dict[str, Any]) -> Dict[str, Any]: return output_dict +def python_writes_test_udf(features_dict: Dict[str, Any]) -> Dict[str, Any]: + output_dict: Dict[str, List[Any]] = { + "output1": features_dict["feature1"] + 100, + "output2": features_dict["feature2"] + 101, + "output3": datetime.datetime.now(), + } + return output_dict + + @pytest.mark.filterwarnings("ignore:udf and udf_string parameters are deprecated") def test_hash(): file_source = FileSource(name="my-file-source", path="test.parquet") @@ -261,3 +270,89 @@ def test_from_proto_backwards_compatible_udf(): reserialized_proto.feature_transformation.udf_string == on_demand_feature_view.feature_transformation.udf_string ) + + +def test_on_demand_feature_view_writes_protos(): + file_source = FileSource(name="my-file-source", path="test.parquet") + feature_view = FeatureView( + name="my-feature-view", + entities=[], + schema=[ + Field(name="feature1", dtype=Float32), + Field(name="feature2", dtype=Float32), + ], + source=file_source, + ) + sources = [feature_view] + on_demand_feature_view = OnDemandFeatureView( + name="my-on-demand-feature-view", + sources=sources, + schema=[ + Field(name="output1", dtype=Float32), + Field(name="output2", dtype=Float32), + ], + feature_transformation=PandasTransformation( + udf=udf1, udf_string="udf1 source code" + ), + write_to_online_store=True, + ) + + proto = on_demand_feature_view.to_proto() + reserialized_proto = OnDemandFeatureView.from_proto(proto) + + assert on_demand_feature_view.write_to_online_store + assert proto.spec.write_to_online_store + assert reserialized_proto.write_to_online_store + + proto.spec.write_to_online_store = False + reserialized_proto = OnDemandFeatureView.from_proto(proto) + assert not reserialized_proto.write_to_online_store + + +def test_on_demand_feature_view_stored_writes(): + file_source = FileSource(name="my-file-source", path="test.parquet") + feature_view = FeatureView( + name="my-feature-view", + entities=[], + schema=[ + Field(name="feature1", dtype=Float32), + Field(name="feature2", dtype=Float32), + ], + source=file_source, + ) + sources = [feature_view] + + on_demand_feature_view = OnDemandFeatureView( + name="my-on-demand-feature-view", + sources=sources, + schema=[ + Field(name="output1", dtype=Float32), + Field(name="output2", dtype=Float32), + ], + feature_transformation=PythonTransformation( + udf=python_writes_test_udf, udf_string="python native udf source code" + ), + description="testing on demand feature view stored writes", + mode="python", + write_to_online_store=True, + ) + + transformed_output = on_demand_feature_view.transform_dict( + { + "feature1": 0, + "feature2": 1, + } + ) + expected_output = {"feature1": 0, "feature2": 1, "output1": 100, "output2": 102} + keys_to_validate = [ + "feature1", + "feature2", + "output1", + "output2", + ] + for k in keys_to_validate: + assert transformed_output[k] == expected_output[k] + + assert transformed_output["output3"] is not None and isinstance( + transformed_output["output3"], datetime.datetime + ) diff --git a/sdk/python/tests/unit/test_on_demand_python_transformation.py b/sdk/python/tests/unit/test_on_demand_python_transformation.py index b994ea8042..2410ee03aa 100644 --- a/sdk/python/tests/unit/test_on_demand_python_transformation.py +++ b/sdk/python/tests/unit/test_on_demand_python_transformation.py @@ -28,7 +28,9 @@ Float64, Int64, String, + UnixTimestamp, ValueType, + 
_utc_now, from_value_type, ) @@ -71,6 +73,13 @@ def setUp(self): timestamp_field="event_timestamp", created_timestamp_column="created", ) + input_request_source = RequestSource( + name="counter_source", + schema=[ + Field(name="counter", dtype=Int64), + Field(name="input_datetime", dtype=UnixTimestamp), + ], + ) driver_stats_fv = FeatureView( name="driver_hourly_stats", @@ -165,6 +174,36 @@ def python_singleton_view(inputs: dict[str, Any]) -> dict[str, Any]: ) return output + @on_demand_feature_view( + sources=[ + driver_stats_fv[["conv_rate", "acc_rate"]], + input_request_source, + ], + schema=[ + Field(name="conv_rate_plus_acc", dtype=Float64), + Field(name="current_datetime", dtype=UnixTimestamp), + Field(name="counter", dtype=Int64), + Field(name="input_datetime", dtype=UnixTimestamp), + ], + mode="python", + write_to_online_store=True, + ) + def python_stored_writes_feature_view( + inputs: dict[str, Any], + ) -> dict[str, Any]: + output: dict[str, Any] = { + "conv_rate_plus_acc": [ + conv_rate + acc_rate + for conv_rate, acc_rate in zip( + inputs["conv_rate"], inputs["acc_rate"] + ) + ], + "current_datetime": [datetime.now() for _ in inputs["conv_rate"]], + "counter": [c + 1 for c in inputs["counter"]], + "input_datetime": [d for d in inputs["input_datetime"]], + } + return output + with pytest.raises(TypeError): # Note the singleton view will fail as the type is # expected to be a list which can be confirmed in _infer_features_dict @@ -189,6 +228,7 @@ def python_singleton_view(inputs: dict[str, Any]) -> dict[str, Any]: python_view, python_demo_view, driver_stats_entity_less_fv, + python_stored_writes_feature_view, ] ) self.store.write_to_online_store( @@ -199,15 +239,17 @@ def python_singleton_view(inputs: dict[str, Any]) -> dict[str, Any]: ] assert driver_stats_entity_less_fv.entity_columns == [DUMMY_ENTITY_FIELD] - assert len(self.store.list_all_feature_views()) == 5 + assert len(self.store.list_all_feature_views()) == 6 assert len(self.store.list_feature_views()) == 2 - assert len(self.store.list_on_demand_feature_views()) == 3 + assert len(self.store.list_on_demand_feature_views()) == 4 assert len(self.store.list_stream_feature_views()) == 0 def test_python_pandas_parity(self): entity_rows = [ { "driver_id": 1001, + "counter": 0, + "input_datetime": _utc_now(), } ] @@ -289,6 +331,40 @@ def test_python_docs_demo(self): == online_python_response["conv_rate_plus_val2_python"][0] ) + def test_stored_writes(self): + # Note that here we shouldn't have to pass the request source features for reading + # because they should have already been written to the online store + current_datetime = _utc_now() + entity_rows_to_read = [ + { + "driver_id": 1001, + "counter": 0, + "input_datetime": current_datetime, + } + ] + + online_python_response = self.store.get_online_features( + entity_rows=entity_rows_to_read, + features=[ + "python_stored_writes_feature_view:conv_rate_plus_acc", + "python_stored_writes_feature_view:current_datetime", + "python_stored_writes_feature_view:counter", + "python_stored_writes_feature_view:input_datetime", + ], + ).to_dict() + + assert sorted(list(online_python_response.keys())) == sorted( + [ + "driver_id", + "conv_rate_plus_acc", + "counter", + "current_datetime", + "input_datetime", + ] + ) + print(online_python_response) + # Now this is where we need to test the stored writes, this should return the same output as the previous + class TestOnDemandPythonTransformationAllDataTypes(unittest.TestCase): def setUp(self): @@ -403,15 +479,15 @@ def python_view(inputs: 
dict[str, Any]) -> dict[str, Any]: self.store.apply( [driver, driver_stats_source, driver_stats_fv, python_view] ) - self.store.write_to_online_store( - feature_view_name="driver_hourly_stats", df=driver_df - ) - fv_applied = self.store.get_feature_view("driver_hourly_stats") assert fv_applied.entities == [driver.name] # Note here that after apply() is called, the entity_columns are populated with the join_key assert fv_applied.entity_columns[0].name == driver.join_key + self.store.write_to_online_store( + feature_view_name="driver_hourly_stats", df=driver_df + ) + def test_python_transformation_returning_all_data_types(self): entity_rows = [ { @@ -526,3 +602,215 @@ def python_view(inputs: dict[str, Any]) -> dict[str, Any]: ), ): store.apply([request_source, python_view]) + + +class TestOnDemandTransformationsWithWrites(unittest.TestCase): + def test_stored_writes(self): + with tempfile.TemporaryDirectory() as data_dir: + self.store = FeatureStore( + config=RepoConfig( + project="test_on_demand_python_transformation", + registry=os.path.join(data_dir, "registry.db"), + provider="local", + entity_key_serialization_version=2, + online_store=SqliteOnlineStoreConfig( + path=os.path.join(data_dir, "online.db") + ), + ) + ) + + # Generate test data. + end_date = datetime.now().replace(microsecond=0, second=0, minute=0) + start_date = end_date - timedelta(days=15) + + driver_entities = [1001, 1002, 1003, 1004, 1005] + driver_df = create_driver_hourly_stats_df( + driver_entities, start_date, end_date + ) + driver_stats_path = os.path.join(data_dir, "driver_stats.parquet") + driver_df.to_parquet( + path=driver_stats_path, allow_truncated_timestamps=True + ) + + driver = Entity(name="driver", join_keys=["driver_id"]) + + driver_stats_source = FileSource( + name="driver_hourly_stats_source", + path=driver_stats_path, + timestamp_field="event_timestamp", + created_timestamp_column="created", + ) + input_request_source = RequestSource( + name="counter_source", + schema=[ + Field(name="counter", dtype=Int64), + Field(name="input_datetime", dtype=UnixTimestamp), + ], + ) + + driver_stats_fv = FeatureView( + name="driver_hourly_stats", + entities=[driver], + ttl=timedelta(days=0), + schema=[ + Field(name="conv_rate", dtype=Float32), + Field(name="acc_rate", dtype=Float32), + Field(name="avg_daily_trips", dtype=Int64), + ], + online=True, + source=driver_stats_source, + ) + assert driver_stats_fv.entities == [driver.name] + assert driver_stats_fv.entity_columns == [] + + @on_demand_feature_view( + entities=[driver], + sources=[ + driver_stats_fv[["conv_rate", "acc_rate"]], + input_request_source, + ], + schema=[ + Field(name="conv_rate_plus_acc", dtype=Float64), + Field(name="current_datetime", dtype=UnixTimestamp), + Field(name="counter", dtype=Int64), + Field(name="input_datetime", dtype=UnixTimestamp), + ], + mode="python", + write_to_online_store=True, + ) + def python_stored_writes_feature_view( + inputs: dict[str, Any], + ) -> dict[str, Any]: + output: dict[str, Any] = { + "conv_rate_plus_acc": [ + conv_rate + acc_rate + for conv_rate, acc_rate in zip( + inputs["conv_rate"], inputs["acc_rate"] + ) + ], + "current_datetime": [datetime.now() for _ in inputs["conv_rate"]], + "counter": [c + 1 for c in inputs["counter"]], + "input_datetime": [d for d in inputs["input_datetime"]], + } + return output + + assert python_stored_writes_feature_view.entities == [driver.name] + assert python_stored_writes_feature_view.entity_columns == [] + + self.store.apply( + [ + driver, + driver_stats_source, + 
driver_stats_fv, + python_stored_writes_feature_view, + ] + ) + fv_applied = self.store.get_feature_view("driver_hourly_stats") + odfv_applied = self.store.get_on_demand_feature_view( + "python_stored_writes_feature_view" + ) + + assert fv_applied.entities == [driver.name] + assert odfv_applied.entities == [driver.name] + + # Note here that after apply() is called, the entity_columns are populated with the join_key + # assert fv_applied.entity_columns[0].name == driver.join_key + assert fv_applied.entity_columns[0].name == driver.join_key + assert odfv_applied.entity_columns[0].name == driver.join_key + + assert len(self.store.list_all_feature_views()) == 2 + assert len(self.store.list_feature_views()) == 1 + assert len(self.store.list_on_demand_feature_views()) == 1 + assert len(self.store.list_stream_feature_views()) == 0 + assert ( + driver_stats_fv.entity_columns + == self.store.get_feature_view("driver_hourly_stats").entity_columns + ) + assert ( + python_stored_writes_feature_view.entity_columns + == self.store.get_on_demand_feature_view( + "python_stored_writes_feature_view" + ).entity_columns + ) + + current_datetime = _utc_now() + fv_entity_rows_to_write = [ + { + "driver_id": 1001, + "conv_rate": 0.25, + "acc_rate": 0.25, + "avg_daily_trips": 2, + "event_timestamp": current_datetime, + "created": current_datetime, + } + ] + odfv_entity_rows_to_write = [ + { + "driver_id": 1001, + "counter": 0, + "input_datetime": current_datetime, + } + ] + fv_entity_rows_to_read = [ + { + "driver_id": 1001, + } + ] + # Note that here we shouldn't have to pass the request source features for reading + # because they should have already been written to the online store + odfv_entity_rows_to_read = [ + { + "driver_id": 1001, + "conv_rate": 0.25, + "acc_rate": 0.50, + "counter": 0, + "input_datetime": current_datetime, + } + ] + print("storing fv features") + self.store.write_to_online_store( + feature_view_name="driver_hourly_stats", + df=fv_entity_rows_to_write, + ) + print("reading fv features") + online_python_response = self.store.get_online_features( + entity_rows=fv_entity_rows_to_read, + features=[ + "driver_hourly_stats:conv_rate", + "driver_hourly_stats:acc_rate", + "driver_hourly_stats:avg_daily_trips", + ], + ).to_dict() + + assert online_python_response == { + "driver_id": [1001], + "conv_rate": [0.25], + "avg_daily_trips": [2], + "acc_rate": [0.25], + } + + print("storing odfv features") + self.store.write_to_online_store( + feature_view_name="python_stored_writes_feature_view", + df=odfv_entity_rows_to_write, + ) + print("reading odfv features") + online_odfv_python_response = self.store.get_online_features( + entity_rows=odfv_entity_rows_to_read, + features=[ + "python_stored_writes_feature_view:conv_rate_plus_acc", + "python_stored_writes_feature_view:current_datetime", + "python_stored_writes_feature_view:counter", + "python_stored_writes_feature_view:input_datetime", + ], + ).to_dict() + print(online_odfv_python_response) + assert sorted(list(online_odfv_python_response.keys())) == sorted( + [ + "driver_id", + "conv_rate_plus_acc", + "counter", + "current_datetime", + "input_datetime", + ] + )