diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py
index c81be17cee..25fb037f87 100644
--- a/sdk/python/feast/feature_store.py
+++ b/sdk/python/feast/feature_store.py
@@ -1405,7 +1405,7 @@ def write_to_online_store(
         provider.ingest_df(feature_view, entities, df)

     @log_exceptions_and_usage
-    def write_to_offline_store(
+    def _write_to_offline_store(
         self,
         feature_view_name: str,
         df: pd.DataFrame,
@@ -1423,8 +1423,9 @@ def write_to_offline_store(
         feature_view = self.get_feature_view(
             feature_view_name, allow_registry_cache=allow_registry_cache
         )
+        table = pa.Table.from_pandas(df)
         provider = self._get_provider()
-        provider.ingest_df_to_offline_store(feature_view, df)
+        provider.ingest_df_to_offline_store(feature_view, table)

     @log_exceptions_and_usage
     def get_online_features(
diff --git a/sdk/python/feast/infra/offline_stores/file.py b/sdk/python/feast/infra/offline_stores/file.py
index 7288223883..194c233f53 100644
--- a/sdk/python/feast/infra/offline_stores/file.py
+++ b/sdk/python/feast/infra/offline_stores/file.py
@@ -1,7 +1,7 @@
 import uuid
 from datetime import datetime
 from pathlib import Path
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union

 import dask.dataframe as dd
 import pandas as pd
@@ -404,6 +404,42 @@ def write_logged_features(
             existing_data_behavior="overwrite_or_ignore",
         )

+    @staticmethod
+    def offline_write_batch(
+        config: RepoConfig,
+        feature_view: FeatureView,
+        data: pyarrow.Table,
+        progress: Optional[Callable[[int], Any]],
+    ):
+        if not feature_view.batch_source:
+            raise ValueError(
+                "feature view does not have a batch source to persist offline data"
+            )
+        if not isinstance(config.offline_store, FileOfflineStoreConfig):
+            raise ValueError(
+                f"offline store config is of type {type(config.offline_store)} when file type required"
+            )
+        if not isinstance(feature_view.batch_source, FileSource):
+            raise ValueError(
+                f"feature view batch source is {type(feature_view.batch_source)} not file source"
+            )
+        file_options = feature_view.batch_source.file_options
+        filesystem, path = FileSource.create_filesystem_and_path(
+            file_options.uri, file_options.s3_endpoint_override
+        )
+
+        prev_table = pyarrow.parquet.read_table(path, memory_map=True)
+        if prev_table.column_names != data.column_names:
+            raise ValueError(
+                f"Input dataframe has incorrect schema or wrong order, expected columns are: {prev_table.column_names}"
+            )
+        if data.schema != prev_table.schema:
+            data = data.cast(prev_table.schema)
+        new_table = pyarrow.concat_tables([data, prev_table])
+        writer = pyarrow.parquet.ParquetWriter(path, data.schema, filesystem=filesystem)
+        writer.write_table(new_table)
+        writer.close()
+
 def _get_entity_df_event_timestamp_range(
     entity_df: Union[pd.DataFrame, str],
     entity_df_event_timestamp_col: str,
diff --git a/sdk/python/feast/infra/offline_stores/offline_store.py b/sdk/python/feast/infra/offline_stores/offline_store.py
index 6c95283358..cd807764ba 100644
--- a/sdk/python/feast/infra/offline_stores/offline_store.py
+++ b/sdk/python/feast/infra/offline_stores/offline_store.py
@@ -274,8 +274,8 @@ def write_logged_features(
     @staticmethod
     def offline_write_batch(
         config: RepoConfig,
-        table: FeatureView,
-        data: pd.DataFrame,
+        feature_view: FeatureView,
+        data: pyarrow.Table,
         progress: Optional[Callable[[int], Any]],
     ):
         """
@@ -287,7 +287,7 @@ def write_logged_features(
         Args:
             config: Repo configuration object
             table: FeatureView to write the data to.
-            data: dataframe containing feature data and timestamp column for historical feature retrieval
+            data: pyarrow table containing feature data and timestamp column for historical feature retrieval
             progress: Optional function to be called once every mini-batch of rows is written to
                 the online store. Can be used to display progress.
         """
diff --git a/sdk/python/feast/infra/passthrough_provider.py b/sdk/python/feast/infra/passthrough_provider.py
index ef72541147..e023afe782 100644
--- a/sdk/python/feast/infra/passthrough_provider.py
+++ b/sdk/python/feast/infra/passthrough_provider.py
@@ -104,10 +104,11 @@ def offline_write_batch(
         self,
         config: RepoConfig,
         table: FeatureView,
-        data: pd.DataFrame,
+        data: pa.Table,
         progress: Optional[Callable[[int], Any]],
     ) -> None:
         set_usage_attribute("provider", self.__class__.__name__)
+
         if self.offline_store:
             self.offline_store.offline_write_batch(config, table, data, progress)

@@ -143,6 +144,14 @@ def ingest_df(
             self.repo_config, feature_view, rows_to_write, progress=None
         )

+    def ingest_df_to_offline_store(self, feature_view: FeatureView, table: pa.Table):
+        set_usage_attribute("provider", self.__class__.__name__)
+
+        if feature_view.batch_source.field_mapping is not None:
+            table = _run_field_mapping(table, feature_view.batch_source.field_mapping)
+
+        self.offline_write_batch(self.repo_config, feature_view, table, None)
+
     def materialize_single_feature_view(
         self,
         config: RepoConfig,
diff --git a/sdk/python/feast/infra/provider.py b/sdk/python/feast/infra/provider.py
index c6c9b75787..d2e37e69db 100644
--- a/sdk/python/feast/infra/provider.py
+++ b/sdk/python/feast/infra/provider.py
@@ -127,7 +127,7 @@ def ingest_df(
         pass

     def ingest_df_to_offline_store(
-        self, feature_view: FeatureView, df: pd.DataFrame,
+        self, feature_view: FeatureView, df: pyarrow.Table,
     ):
         """
         Ingests a DataFrame directly into the offline store
diff --git a/sdk/python/tests/integration/offline_store/test_offline_write.py b/sdk/python/tests/integration/offline_store/test_offline_write.py
new file mode 100644
index 0000000000..41f6ea89fa
--- /dev/null
+++ b/sdk/python/tests/integration/offline_store/test_offline_write.py
@@ -0,0 +1,226 @@
+import random
+from datetime import datetime, timedelta
+
+import numpy as np
+import pandas as pd
+import pytest
+
+from feast import FeatureView, Field
+from feast.types import Float32, Int32
+from tests.integration.feature_repos.universal.entities import driver
+
+
+@pytest.mark.integration
+@pytest.mark.universal_online_stores
+def test_writing_incorrect_order_fails(environment, universal_data_sources):
+    # TODO(kevjumba) handle incorrect order later, for now schema must be in the order that the filesource is in
+    store = environment.feature_store
+    _, _, data_sources = universal_data_sources
+    driver_stats = FeatureView(
+        name="driver_stats",
+        entities=["driver"],
+        schema=[
+            Field(name="avg_daily_trips", dtype=Int32),
+            Field(name="conv_rate", dtype=Float32),
+        ],
+        source=data_sources.driver,
+    )
+
+    now = datetime.utcnow()
+    ts = pd.Timestamp(now).round("ms")
+
+    entity_df = pd.DataFrame.from_dict(
+        {"driver_id": [1001, 1002], "event_timestamp": [ts - timedelta(hours=3), ts]}
+    )
+
+    store.apply([driver(), driver_stats])
+    df = store.get_historical_features(
+        entity_df=entity_df,
+        features=["driver_stats:conv_rate", "driver_stats:avg_daily_trips"],
+        full_feature_names=False,
+    ).to_df()
+
+    assert df["conv_rate"].isnull().all()
+    assert df["avg_daily_trips"].isnull().all()
+
+    expected_df = pd.DataFrame.from_dict(
+        {
+            "driver_id": [1001, 1002],
+            "event_timestamp": [ts - timedelta(hours=3), ts],
+            "conv_rate": [random.random(), random.random()],
+            "avg_daily_trips": [random.randint(0, 10), random.randint(0, 10)],
+            "created": [ts, ts],
+        },
+    )
+    with pytest.raises(ValueError):
+        store._write_to_offline_store(
+            driver_stats.name, expected_df, allow_registry_cache=False
+        )
+
+
+@pytest.mark.integration
+@pytest.mark.universal_online_stores
+def test_writing_incorrect_schema_fails(environment, universal_data_sources):
+    # TODO(kevjumba) handle incorrect order later, for now schema must be in the order that the filesource is in
+    store = environment.feature_store
+    _, _, data_sources = universal_data_sources
+    driver_stats = FeatureView(
+        name="driver_stats",
+        entities=["driver"],
+        schema=[
+            Field(name="avg_daily_trips", dtype=Int32),
+            Field(name="conv_rate", dtype=Float32),
+        ],
+        source=data_sources.driver,
+    )
+
+    now = datetime.utcnow()
+    ts = pd.Timestamp(now).round("ms")
+
+    entity_df = pd.DataFrame.from_dict(
+        {"driver_id": [1001, 1002], "event_timestamp": [ts - timedelta(hours=3), ts]}
+    )
+
+    store.apply([driver(), driver_stats])
+    df = store.get_historical_features(
+        entity_df=entity_df,
+        features=["driver_stats:conv_rate", "driver_stats:avg_daily_trips"],
+        full_feature_names=False,
+    ).to_df()
+
+    assert df["conv_rate"].isnull().all()
+    assert df["avg_daily_trips"].isnull().all()
+
+    expected_df = pd.DataFrame.from_dict(
+        {
+            "event_timestamp": [ts - timedelta(hours=3), ts],
+            "driver_id": [1001, 1002],
+            "conv_rate": [random.random(), random.random()],
+            "incorrect_schema": [random.randint(0, 10), random.randint(0, 10)],
+            "created": [ts, ts],
+        },
+    )
+    with pytest.raises(ValueError):
+        store._write_to_offline_store(
+            driver_stats.name, expected_df, allow_registry_cache=False
+        )
+
+
+@pytest.mark.integration
+@pytest.mark.universal_online_stores
+def test_writing_consecutively_to_offline_store(environment, universal_data_sources):
+    store = environment.feature_store
+    _, _, data_sources = universal_data_sources
+    driver_stats = FeatureView(
+        name="driver_stats",
+        entities=["driver"],
+        schema=[
+            Field(name="avg_daily_trips", dtype=Int32),
+            Field(name="conv_rate", dtype=Float32),
+            Field(name="acc_rate", dtype=Float32),
+        ],
+        source=data_sources.driver,
+        ttl=timedelta(minutes=10),
+    )
+
+    now = datetime.utcnow()
+    ts = pd.Timestamp(now, unit="ns")
+
+    entity_df = pd.DataFrame.from_dict(
+        {
+            "driver_id": [1001, 1001],
+            "event_timestamp": [ts - timedelta(hours=4), ts - timedelta(hours=3)],
+        }
+    )
+
+    store.apply([driver(), driver_stats])
+    df = store.get_historical_features(
+        entity_df=entity_df,
+        features=["driver_stats:conv_rate", "driver_stats:avg_daily_trips"],
+        full_feature_names=False,
+    ).to_df()
+
+    assert df["conv_rate"].isnull().all()
+    assert df["avg_daily_trips"].isnull().all()
+
+    first_df = pd.DataFrame.from_dict(
+        {
+            "event_timestamp": [ts - timedelta(hours=4), ts - timedelta(hours=3)],
+            "driver_id": [1001, 1001],
+            "conv_rate": [random.random(), random.random()],
+            "acc_rate": [random.random(), random.random()],
+            "avg_daily_trips": [random.randint(0, 10), random.randint(0, 10)],
+            "created": [ts, ts],
+        },
+    )
+    store._write_to_offline_store(
+        driver_stats.name, first_df, allow_registry_cache=False
+    )
+
+    after_write_df = store.get_historical_features(
+        entity_df=entity_df,
+        features=["driver_stats:conv_rate", "driver_stats:avg_daily_trips"],
+        full_feature_names=False,
+    ).to_df()
+
+    assert len(after_write_df) == len(first_df)
+    assert np.where(
after_write_df["conv_rate"].reset_index(drop=True) + == first_df["conv_rate"].reset_index(drop=True) + ) + assert np.where( + after_write_df["avg_daily_trips"].reset_index(drop=True) + == first_df["avg_daily_trips"].reset_index(drop=True) + ) + + second_df = pd.DataFrame.from_dict( + { + "event_timestamp": [ts - timedelta(hours=1), ts], + "driver_id": [1001, 1001], + "conv_rate": [random.random(), random.random()], + "acc_rate": [random.random(), random.random()], + "avg_daily_trips": [random.randint(0, 10), random.randint(0, 10)], + "created": [ts, ts], + }, + ) + + store._write_to_offline_store( + driver_stats.name, second_df, allow_registry_cache=False + ) + + entity_df = pd.DataFrame.from_dict( + { + "driver_id": [1001, 1001, 1001, 1001], + "event_timestamp": [ + ts - timedelta(hours=4), + ts - timedelta(hours=3), + ts - timedelta(hours=1), + ts, + ], + } + ) + + after_write_df = store.get_historical_features( + entity_df=entity_df, + features=[ + "driver_stats:conv_rate", + "driver_stats:acc_rate", + "driver_stats:avg_daily_trips", + ], + full_feature_names=False, + ).to_df() + + expected_df = pd.concat([first_df, second_df]) + assert len(after_write_df) == len(expected_df) + assert np.where( + after_write_df["conv_rate"].reset_index(drop=True) + == expected_df["conv_rate"].reset_index(drop=True) + ) + assert np.where( + after_write_df["acc_rate"].reset_index(drop=True) + == expected_df["acc_rate"].reset_index(drop=True) + ) + assert np.where( + after_write_df["avg_daily_trips"].reset_index(drop=True) + == expected_df["avg_daily_trips"].reset_index(drop=True) + )