Skip to content

Commit

Permalink
Fix unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida committed Mar 30, 2022
1 parent f349047 commit 8d2318c
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 95 deletions.
48 changes: 26 additions & 22 deletions superset/connectors/sqla/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,11 @@ def load_or_create_tables( # pylint: disable=too-many-arguments

# set the default schema in tables that don't have it
if default_schema:
tables = list(tables)
for i, table in enumerate(tables):
fixed_tables = list(tables)
for i, table in enumerate(fixed_tables):
if table.schema is None:
tables[i] = Table(table.table, default_schema, table.catalog)
fixed_tables[i] = Table(table.table, default_schema, table.catalog)
tables = set(fixed_tables)

# load existing tables
predicate = or_(
Expand All @@ -196,35 +197,38 @@ def load_or_create_tables( # pylint: disable=too-many-arguments
new_tables = session.query(NewTable).filter(predicate).all()

# add missing tables
inspector = inspect(engine)
existing = {(table.schema, table.name) for table in new_tables}
for table in tables:
if (table.schema, table.table) not in existing:
column_metadata = inspector.get_columns(table.table, schema=table.schema)

physical_columns = []
for column in column_metadata:
physical_columns.append(
NewColumn(
name=column["name"],
type=str(column["type"]),
expression=conditional_quote(column["name"]),
is_temporal=column["type"].python_type.__name__.upper()
in TEMPORAL_TYPES,
is_aggregation=False,
is_physical=True,
is_spatial=False,
is_partition=False,
is_increase_desired=True,
),
try:
inspector = inspect(engine)
column_metadata = inspector.get_columns(
table.table, schema=table.schema
)
except Exception: # pylint: disable=broad-except
continue
columns = [
NewColumn(
name=column["name"],
type=str(column["type"]),
expression=conditional_quote(column["name"]),
is_temporal=column["type"].python_type.__name__.upper()
in TEMPORAL_TYPES,
is_aggregation=False,
is_physical=True,
is_spatial=False,
is_partition=False,
is_increase_desired=True,
)
for column in column_metadata
]
new_tables.append(
NewTable(
name=table.table,
schema=table.schema,
catalog=None,
database_id=database_id,
columns=physical_columns,
columns=columns,
)
)
existing.add((table.schema, table.table))
Expand Down
13 changes: 3 additions & 10 deletions superset/migrations/shared/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,7 @@ def find_nodes_by_key(element: Any, target: str) -> Iterator[Any]:
yield from find_nodes_by_key(value, target)


def extract_table_references(
default_schema: Optional[str], sql_text: str, sqla_dialect: str
) -> Set[Table]:
def extract_table_references(sql_text: str, sqla_dialect: str) -> Set[Table]:
"""
Return all the dependencies from a SQL sql_text.
"""
Expand All @@ -102,12 +100,7 @@ def extract_table_references(
parsed = ParsedQuery(sql_text)
return parsed.tables

tables = [
return {
Table(*[part["value"] for part in table["name"][::-1]])
for table in find_nodes_by_key(tree, "Table")
]
for i, table in enumerate(tables):
if table.schema is None:
tables[i] = Table(table.table, default_schema, table.catalog)

return set(tables)
}
45 changes: 24 additions & 21 deletions superset/migrations/versions/b8d3a24d9131_new_dataset_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,13 @@ def load_or_create_tables(
if not tables:
return []

# set the default schema in tables that don't have it
if default_schema:
tables = list(tables)
for i, table in enumerate(tables):
if table.schema is None:
tables[i] = Table(table.table, default_schema, table.catalog)

# load existing tables
predicate = or_(
*[
Expand All @@ -259,7 +266,7 @@ def load_or_create_tables(
)
new_tables = session.query(NewTable).filter(predicate).all()

# use original database model to the engine
# use original database model to get the engine
engine = (
session.query(OriginalDatabase)
.filter_by(id=database_id)
Expand All @@ -273,30 +280,28 @@ def load_or_create_tables(
for table in tables:
if (table.schema, table.table) not in existing:
column_metadata = inspector.get_columns(table.table, schema=table.schema)

physical_columns = []
for column in column_metadata:
physical_columns.append(
NewColumn(
name=column["name"],
type=str(column["type"]),
expression=conditional_quote(column["name"]),
is_temporal=column["type"].python_type.__name__.upper()
in TEMPORAL_TYPES,
is_aggregation=False,
is_physical=True,
is_spatial=False,
is_partition=False,
is_increase_desired=True,
),
columns = [
NewColumn(
name=column["name"],
type=str(column["type"]),
expression=conditional_quote(column["name"]),
is_temporal=column["type"].python_type.__name__.upper()
in TEMPORAL_TYPES,
is_aggregation=False,
is_physical=True,
is_spatial=False,
is_partition=False,
is_increase_desired=True,
)
for column in column_metadata
]
new_tables.append(
NewTable(
name=table.table,
schema=table.schema,
catalog=None,
database_id=database_id,
columns=physical_columns,
columns=columns,
)
)
existing.add((table.schema, table.table))
Expand Down Expand Up @@ -413,9 +418,7 @@ def after_insert(target: SqlaTable) -> None: # pylint: disable=too-many-locals
column.is_physical = False

# find referenced tables
referenced_tables = extract_table_references(
target.schema, target.sql, dialect_class.name,
)
referenced_tables = extract_table_references(target.sql, dialect_class.name)
tables = load_or_create_tables(
session,
target.database_id,
Expand Down
56 changes: 56 additions & 0 deletions tests/unit_tests/migrations/shared/utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-outside-toplevel, unused-argument

"""
Test the SIP-68 migration.
"""

from pytest_mock import MockerFixture

from superset.sql_parse import Table


def test_extract_table_references(mocker: MockerFixture, app_context: None) -> None:
    """
    Verify ``extract_table_references`` for the common query shapes.

    Covers: a query with no tables, an unqualified table, a fully
    qualified ``catalog.schema.table`` reference, a join with multiple
    tables, and the sqlparse fallback path taken when sqloxide cannot
    parse the statement.
    """
    from superset.migrations.shared.utils import extract_table_references

    # A constant SELECT references no tables at all.
    assert extract_table_references("SELECT 1", "trino") == set()

    # A bare table name: schema and catalog stay unset.
    unqualified = extract_table_references("SELECT 1 FROM some_table", "trino")
    assert unqualified == {Table(table="some_table", schema=None, catalog=None)}

    # A fully qualified name is split into catalog, schema, and table.
    qualified = extract_table_references(
        "SELECT 1 FROM some_catalog.some_schema.some_table", "trino"
    )
    assert qualified == {
        Table(table="some_table", schema="some_schema", catalog="some_catalog")
    }

    # Every table in a join is reported.
    joined = extract_table_references(
        "SELECT * FROM some_table JOIN other_table ON some_table.id = other_table.id",
        "trino",
    )
    assert joined == {
        Table(table="some_table", schema=None, catalog=None),
        Table(table="other_table", schema=None, catalog=None),
    }

    # test falling back to sqlparse
    logger = mocker.patch("superset.migrations.shared.utils.logger")
    sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table"
    fallback = extract_table_references(
        sql,
        "trino",
    )
    # ``table`` is a reserved word sqloxide rejects; sqlparse only finds
    # the other reference, and a warning is logged about the failure.
    assert fallback == {Table(table="other_table", schema=None, catalog=None)}
    logger.warning.assert_called_with("Unable to parse query with sqloxide: %s", sql)

This file was deleted.

0 comments on commit 8d2318c

Please sign in to comment.