Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: rename get_iterable #24994

Merged
merged 7 commits into from
Aug 18, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions superset/connectors/base/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def data_for_slices( # pylint: disable=too-many-locals
form_data = slc.form_data
# pull out all required metrics from the form_data
for metric_param in METRIC_FORM_DATA_PARAMS:
for metric in utils.get_iterable(form_data.get(metric_param) or []):
for metric in utils.get_as_list(form_data.get(metric_param) or []):
metric_names.add(utils.get_metric_name(metric))
if utils.is_adhoc_metric(metric):
column = metric.get("column") or {}
Expand Down Expand Up @@ -377,7 +377,7 @@ def data_for_slices( # pylint: disable=too-many-locals
if utils.is_adhoc_column(column)
else column
for column_param in COLUMN_FORM_DATA_PARAMS
for column in utils.get_iterable(form_data.get(column_param) or [])
for column in utils.get_as_list(form_data.get(column_param) or [])
]
column_names.update(_columns)

Expand Down
4 changes: 2 additions & 2 deletions superset/daos/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
DAOUpdateFailedError,
)
from superset.extensions import db
from superset.utils.core import get_iterable
from superset.utils.core import get_as_list

T = TypeVar("T", bound=Model)

Expand Down Expand Up @@ -216,7 +216,7 @@ def delete(cls, items: T | list[T], commit: bool = True) -> None:
"""

try:
for item in get_iterable(items):
for item in get_as_list(items):
db.session.delete(item)

if commit:
Expand Down
9 changes: 5 additions & 4 deletions superset/daos/chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import logging
from datetime import datetime
from typing import TYPE_CHECKING
from typing import cast, TYPE_CHECKING
betodealmeida marked this conversation as resolved.
Show resolved Hide resolved

from sqlalchemy.exc import SQLAlchemyError

Expand All @@ -27,7 +27,7 @@
from superset.extensions import db
from superset.models.core import FavStar, FavStarClassName
from superset.models.slice import Slice
from superset.utils.core import get_iterable, get_user_id
from superset.utils.core import get_as_list, get_user_id

if TYPE_CHECKING:
from superset.connectors.base.models import BaseDatasource
Expand All @@ -39,8 +39,9 @@ class ChartDAO(BaseDAO[Slice]):
base_filter = ChartFilter

@classmethod
def delete(cls, items: Slice | list[Slice], commit: bool = True) -> None:
item_ids = [item.id for item in get_iterable(items)]
def delete(cls, item_or_items: Slice | list[Slice], commit: bool = True) -> None:
items = cast(list[Slice], item_or_items)
item_ids = [item.id for item in items]
# bulk delete, first delete related data
# bulk delete itself
try:
Expand Down
11 changes: 7 additions & 4 deletions superset/daos/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import json
import logging
from datetime import datetime
from typing import Any
from typing import Any, cast
betodealmeida marked this conversation as resolved.
Show resolved Hide resolved

from flask import g
from flask_appbuilder.models.sqla.interface import SQLAInterface
Expand Down Expand Up @@ -48,7 +48,7 @@
from superset.models.embedded_dashboard import EmbeddedDashboard
from superset.models.filter_set import FilterSet
from superset.models.slice import Slice
from superset.utils.core import get_iterable, get_user_id
from superset.utils.core import get_as_list, get_user_id
from superset.utils.dashboard_filter_scopes_converter import copy_filter_scopes

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -191,8 +191,11 @@ def update_charts_owners(model: Dashboard, commit: bool = True) -> Dashboard:
return model

@classmethod
def delete(cls, items: Dashboard | list[Dashboard], commit: bool = True) -> None:
item_ids = [item.id for item in get_iterable(items)]
def delete(
cls, item_or_items: Dashboard | list[Dashboard], commit: bool = True
) -> None:
items = cast(list[Dashboard], get_as_list(item_or_items))
item_ids = [item.id for item in items]
try:
db.session.query(Dashboard).filter(Dashboard.id.in_(item_ids)).delete(
synchronize_session="fetch"
Expand Down
12 changes: 6 additions & 6 deletions superset/daos/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from __future__ import annotations

import logging
from typing import Any
from typing import Any, cast
betodealmeida marked this conversation as resolved.
Show resolved Hide resolved

from sqlalchemy.exc import SQLAlchemyError

Expand All @@ -28,7 +28,7 @@
from superset.models.core import Database
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.utils.core import DatasourceType, get_iterable
from superset.utils.core import DatasourceType, get_as_list
from superset.views.base import DatasourceFilter

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -312,7 +312,7 @@ def find_dataset_metric(cls, dataset_id: int, metric_id: int) -> SqlMetric | Non
@classmethod
def delete(
cls,
items: SqlaTable | list[SqlaTable],
item_or_items: SqlaTable | list[SqlaTable],
commit: bool = True,
) -> None:
"""
Expand All @@ -326,16 +326,16 @@ def delete(
:raises DAODeleteFailedError: If the deletion failed
:see: https://docs.sqlalchemy.org/en/latest/orm/queryguide/dml.html
"""

items = cast(list[SqlaTable], get_as_list(item_or_items))
try:
db.session.query(SqlaTable).filter(
SqlaTable.id.in_(item.id for item in get_iterable(items))
SqlaTable.id.in_(item.id for item in items)
).delete(synchronize_session="fetch")

connection = db.session.connection()
mapper = next(iter(cls.model_cls.registry.mappers)) # type: ignore

for item in get_iterable(items):
for item in items:
security_manager.dataset_after_delete(mapper, connection, item)

if commit:
Expand Down
9 changes: 6 additions & 3 deletions superset/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1578,12 +1578,15 @@ def split(
yield string[i:]


def get_iterable(x: Any) -> list[Any]:
T = TypeVar("T")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!



def get_as_list(x: T | list[T]) -> list[T]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe just as_list?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think either as_iterable or as_list is likely best.

Copy link
Member Author

@betodealmeida betodealmeida Aug 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@john-bodley as_iterable is super confusing, because there are many things that are already iterable but will still be modified by this function, eg:

>>> from typing import Any
>>> def is_iterable(foo: Any) -> bool:
...     try:
...         iter(foo)
...     except Exception:
...         return False
...     return True
...
>>>
>>> a = (1,2,3)
>>> is_iterable(a)
True
>>> as_iterable(a)
[(1, 2, 3)]

Copy link
Member Author

@betodealmeida betodealmeida Aug 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And to be clear, I think this function could be called as_iterable, but then it should be improved to not check if the passed object is a list, but instead use an improved version of the is_iterable function above, something like this:

def is_iterable(foo: Any) -> bool:
    if isinstance(foo, str):
        return False
    try:
        iter(foo)
    except Exception:
        return False
    return True

But that might introduce bugs. In this PR I just want to fix the erroneous name.

"""
Get an iterable (list) representation of the object.
Wrap an object in a list if it's not a list.

:param x: The object
:returns: An iterable representation
:returns: A list wrapping the object if it's not already a list
"""
return x if isinstance(x, list) else [x]

Expand Down
10 changes: 5 additions & 5 deletions tests/integration_tests/utils_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
format_timedelta,
GenericDataType,
get_form_data_token,
get_iterable,
get_as_list,
get_email_address_list,
get_stacktrace,
json_int_dttm_ser,
Expand Down Expand Up @@ -749,10 +749,10 @@ def test_get_or_create_db_existing_invalid_uri(self):
database = get_or_create_db("test_db", "sqlite:///superset.db")
assert database.sqlalchemy_uri == "sqlite:///superset.db"

def test_get_iterable(self):
self.assertListEqual(get_iterable(123), [123])
self.assertListEqual(get_iterable([123]), [123])
self.assertListEqual(get_iterable("foo"), ["foo"])
def test_get_as_list(self):
self.assertListEqual(get_as_list(123), [123])
self.assertListEqual(get_as_list([123]), [123])
self.assertListEqual(get_as_list("foo"), ["foo"])

@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
def test_build_extra_filters(self):
Expand Down