diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py
index 26d414232f..4a561994c3 100644
--- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py
+++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py
@@ -1,7 +1,7 @@
 import tempfile
 import warnings
 from datetime import datetime
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import pandas
@@ -191,6 +191,68 @@ def get_historical_features(
             ),
         )
 
+    @staticmethod
+    def offline_write_batch(
+        config: RepoConfig,
+        feature_view: FeatureView,
+        table: pyarrow.Table,
+        progress: Optional[Callable[[int], Any]],
+    ):
+        if not feature_view.batch_source:
+            raise ValueError(
+                "feature view does not have a batch source to persist offline data"
+            )
+        if not isinstance(config.offline_store, SparkOfflineStoreConfig):
+            raise ValueError(
+                f"offline store config is of type {type(config.offline_store)} when spark type required"
+            )
+        if not isinstance(feature_view.batch_source, SparkSource):
+            raise ValueError(
+                f"feature view batch source is {type(feature_view.batch_source)} not spark source"
+            )
+
+        pa_schema, column_names = offline_utils.get_pyarrow_schema_from_batch_source(
+            config, feature_view.batch_source
+        )
+        if column_names != table.column_names:
+            raise ValueError(
+                f"The input pyarrow table has schema {table.schema} with the incorrect columns {table.column_names}. "
+                f"The schema is expected to be {pa_schema} with the columns (in this exact order) to be {column_names}."
+            )
+
+        spark_session = get_spark_session_or_start_new_with_repoconfig(
+            store_config=config.offline_store
+        )
+
+        if feature_view.batch_source.path:
+            # write data to disk so that it can be loaded into spark (for preserving column types)
+            with tempfile.NamedTemporaryFile(suffix=".parquet") as tmp_file:
+                print(tmp_file.name)
+                pq.write_table(table, tmp_file.name)
+
+                # load data
+                df_batch = spark_session.read.parquet(tmp_file.name)
+
+                # load existing data to get spark table schema
+                df_existing = spark_session.read.format(
+                    feature_view.batch_source.file_format
+                ).load(feature_view.batch_source.path)
+
+                # cast columns if applicable
+                df_batch = _cast_data_frame(df_batch, df_existing)
+
+                df_batch.write.format(feature_view.batch_source.file_format).mode(
+                    "append"
+                ).save(feature_view.batch_source.path)
+        elif feature_view.batch_source.query:
+            raise NotImplementedError(
+                "offline_write_batch not implemented for batch sources specified by query"
+            )
+        else:
+            raise NotImplementedError(
+                "offline_write_batch not implemented for batch sources specified by a table"
+            )
+
     @staticmethod
     @log_exceptions_and_usage(offline_store="spark")
     def pull_all_from_table_or_query(
@@ -388,6 +450,24 @@ def _format_datetime(t: datetime) -> str:
     return dt
 
 
+def _cast_data_frame(
+    df_new: pyspark.sql.DataFrame, df_existing: pyspark.sql.DataFrame
+) -> pyspark.sql.DataFrame:
+    """Convert new dataframe's columns to the same types as existing dataframe while preserving the order of columns"""
+    existing_dtypes = {k: v for k, v in df_existing.dtypes}
+    new_dtypes = {k: v for k, v in df_new.dtypes}
+
+    select_expression = []
+    for col, new_type in new_dtypes.items():
+        existing_type = existing_dtypes[col]
+        if new_type != existing_type:
+            select_expression.append(f"cast({col} as {existing_type}) as {col}")
+        else:
+            select_expression.append(col)
+
+    return df_new.selectExpr(*select_expression)
+
+
 MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN = """
 /*
  Compute a deterministic hash for the `left_table_query_string` that will be used throughout
diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/tests/data_source.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/tests/data_source.py
index 95bedd1b40..ab1acbef73 100644
--- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/tests/data_source.py
+++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/tests/data_source.py
@@ -1,3 +1,6 @@
+import os
+import shutil
+import tempfile
 import uuid
 from typing import Dict, List
 
@@ -48,6 +51,8 @@ def __init__(self, project_name: str, *args, **kwargs):
 
     def teardown(self):
         self.spark_session.stop()
+        for table in self.tables:
+            shutil.rmtree(table)
 
     def create_offline_store_config(self):
         self.spark_offline_store_config = SparkOfflineStoreConfig()
@@ -86,11 +91,17 @@ def create_data_source(
             .appName("pytest-pyspark-local-testing")
             .getOrCreate()
         )
-        self.spark_session.createDataFrame(df).createOrReplaceTempView(destination_name)
-        self.tables.append(destination_name)
+        temp_dir = tempfile.mkdtemp(prefix="spark_offline_store_test_data")
+
+        path = os.path.join(temp_dir, destination_name)
+        self.tables.append(path)
+
+        self.spark_session.createDataFrame(df).write.parquet(path)
 
         return SparkSource(
-            table=destination_name,
+            name=destination_name,
+            file_format="parquet",
+            path=path,
             timestamp_field=timestamp_field,
             created_timestamp_column=created_timestamp_column,
             field_mapping=field_mapping or {"ts_1": "ts"},