Skip to content

Commit

Permalink
fix: Return an empty infra object from sql registry when it doesn't e…
Browse files Browse the repository at this point in the history
…xist (#3022)

* fix: Return an empty infra object from sql registry when it doesn't exist

Signed-off-by: Achal Shah <achals@gmail.com>

* better

Signed-off-by: Achal Shah <achals@gmail.com>

* types

Signed-off-by: Achal Shah <achals@gmail.com>

* fix hasattr

Signed-off-by: Achal Shah <achals@gmail.com>
  • Loading branch information
achals authored Aug 5, 2022
1 parent b1660aa commit 8ba87d1
Showing 1 changed file with 39 additions and 15 deletions.
54 changes: 39 additions & 15 deletions sdk/python/feast/infra/registry_stores/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, List, Optional, Set, Union
from typing import Any, Callable, List, Optional, Set, Union

from sqlalchemy import ( # type: ignore
BigInteger,
Expand Down Expand Up @@ -560,7 +560,7 @@ def update_infra(self, infra: Infra, project: str, commit: bool = True):
)

def get_infra(self, project: str, allow_cache: bool = False) -> Infra:
return self._get_object(
infra_object = self._get_object(
managed_infra,
"infra_obj",
project,
Expand All @@ -570,6 +570,8 @@ def get_infra(self, project: str, allow_cache: bool = False) -> Infra:
"infra_proto",
None,
)
infra_object = infra_object or InfraProto()
return Infra.from_proto(infra_object)

def apply_user_metadata(
self,
Expand Down Expand Up @@ -683,11 +685,18 @@ def commit(self):
pass

def _apply_object(
self, table, project: str, id_field_name, obj, proto_field_name, name=None
self,
table: Table,
project: str,
id_field_name,
obj,
proto_field_name,
name=None,
):
self._maybe_init_project_metadata(project)

name = name or obj.name
name = name or obj.name if hasattr(obj, "name") else None
assert name, f"name needs to be provided for {obj}"
with self.engine.connect() as conn:
update_datetime = datetime.utcnow()
update_time = int(update_datetime.timestamp())
Expand Down Expand Up @@ -749,7 +758,14 @@ def _maybe_init_project_metadata(self, project):
conn.execute(insert_stmt)
usage.set_current_project_uuid(new_project_uuid)

def _delete_object(self, table, name, project, id_field_name, not_found_exception):
def _delete_object(
self,
table: Table,
name: str,
project: str,
id_field_name: str,
not_found_exception: Optional[Callable],
):
with self.engine.connect() as conn:
stmt = delete(table).where(
getattr(table.c, id_field_name) == name, table.c.project_id == project
Expand All @@ -763,14 +779,14 @@ def _delete_object(self, table, name, project, id_field_name, not_found_exceptio

def _get_object(
self,
table,
name,
project,
proto_class,
python_class,
id_field_name,
proto_field_name,
not_found_exception,
table: Table,
name: str,
project: str,
proto_class: Any,
python_class: Any,
id_field_name: str,
proto_field_name: str,
not_found_exception: Optional[Callable],
):
self._maybe_init_project_metadata(project)

Expand All @@ -782,10 +798,18 @@ def _get_object(
if row:
_proto = proto_class.FromString(row[proto_field_name])
return python_class.from_proto(_proto)
raise not_found_exception(name, project)
if not_found_exception:
raise not_found_exception(name, project)
else:
return None

def _list_objects(
self, table, project, proto_class, python_class, proto_field_name
self,
table: Table,
project: str,
proto_class: Any,
python_class: Any,
proto_field_name: str,
):
self._maybe_init_project_metadata(project)
with self.engine.connect() as conn:
Expand Down

0 comments on commit 8ba87d1

Please sign in to comment.