Skip to content

Commit

Permalink
fix: Revert "fix: don't strip SQL comments in Explore (#28363)" (#28567)
Browse files Browse the repository at this point in the history
  • Loading branch information
michael-s-molina committed May 17, 2024
1 parent bfa85b4 commit 53f98af
Show file tree
Hide file tree
Showing 8 changed files with 12 additions and 15 deletions.
2 changes: 1 addition & 1 deletion superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1492,7 +1492,7 @@ def get_rendered_sql(
msg=ex.message,
)
) from ex
sql = sqlparse.format(sql.strip("\t\r\n; "))
sql = sqlparse.format(sql.strip("\t\r\n; "), strip_comments=True)
if not sql:
raise QueryObjectValidationError(_("Virtual dataset query cannot be empty"))
if len(sqlparse.split(sql)) > 1:
Expand Down
5 changes: 3 additions & 2 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,8 +916,9 @@ def apply_top_to_sql(cls, sql: str, limit: int) -> str:
cte = None
sql_remainder = None
sql = sql.strip(" \t\n;")
sql_statement = sqlparse.format(sql, strip_comments=True)
query_limit: int | None = sql_parse.extract_top_from_query(
sql, cls.top_keywords
sql_statement, cls.top_keywords
)
if not limit:
final_limit = query_limit
Expand All @@ -926,7 +927,7 @@ def apply_top_to_sql(cls, sql: str, limit: int) -> str:
else:
final_limit = limit
if not cls.allows_cte_in_subquery:
cte, sql_remainder = sql_parse.get_cte_remainder_query(sql)
cte, sql_remainder = sql_parse.get_cte_remainder_query(sql_statement)
if cte:
str_statement = str(sql_remainder)
cte = cte + "\n"
Expand Down
2 changes: 1 addition & 1 deletion superset/models/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,7 +1071,7 @@ def get_rendered_sql(
msg=ex.message,
)
) from ex
sql = sqlparse.format(sql.strip("\t\r\n; "))
sql = sqlparse.format(sql.strip("\t\r\n; "), strip_comments=True)
if not sql:
raise QueryObjectValidationError(_("Virtual dataset query cannot be empty"))
if len(sqlparse.split(sql)) > 1:
Expand Down
1 change: 1 addition & 0 deletions superset/sqllab/query_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def render(self, execution_context: SqlJsonExecutionContext) -> str:

parsed_query = ParsedQuery(
query_model.sql,
strip_comments=True,
engine=query_model.database.db_engine_spec.engine,
)
rendered_query = sql_template_processor.process_template(
Expand Down
1 change: 0 additions & 1 deletion tests/integration_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,6 @@ def virtual_dataset():
"SELECT 3 as col1, 'd' as col2, 1.3, NULL, '2000-01-04 00:00:00', 4 "
"UNION ALL "
"SELECT 4 as col1, 'e' as col2, 1.4, NULL, '2000-01-05 00:00:00', 5 "
"\n /* CONTAINS A RANDOM COMMENT */ \n"
"UNION ALL "
"SELECT 5 as col1, 'f' as col2, 1.5, NULL, '2000-01-06 00:00:00', 6 "
"UNION ALL "
Expand Down
4 changes: 1 addition & 3 deletions tests/integration_tests/core_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,9 +539,7 @@ def test_comments_in_sqlatable_query(self):
database=get_example_database(),
)
rendered_query = str(table.get_from_clause()[0])
assert "comment 1" in rendered_query
assert "comment 2" in rendered_query
assert "FROM tbl" in rendered_query
self.assertEqual(clean_query, rendered_query)

def test_slice_payload_no_datasource(self):
form_data = {
Expand Down
8 changes: 3 additions & 5 deletions tests/integration_tests/datasource_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,12 +529,10 @@ def test_get_samples(test_client, login_as_admin, virtual_dataset):
assert "coltypes" in rv2.json["result"]
assert "data" in rv2.json["result"]

sql = (
f"select * from ({virtual_dataset.sql}) as tbl "
f'limit {app.config["SAMPLES_ROW_LIMIT"]}'
eager_samples = virtual_dataset.database.get_df(
f"select * from ({virtual_dataset.sql}) as tbl"
f' limit {app.config["SAMPLES_ROW_LIMIT"]}'
)
eager_samples = virtual_dataset.database.get_df(sql)

# the col3 is Decimal
eager_samples["col3"] = eager_samples["col3"].apply(float)
eager_samples = eager_samples.to_dict(orient="records")
Expand Down
4 changes: 2 additions & 2 deletions tests/integration_tests/sqllab_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,9 +571,9 @@ def test_sql_json_parameter_error(self):
assert data["status"] == "success"

data = self.run_sql(
"SELECT * FROM birth_names WHERE state = '{{ state }}' -- blabblah {{ extra1 }}\nLIMIT 10",
"SELECT * FROM birth_names WHERE state = '{{ state }}' -- blabblah {{ extra1 }} {{fake.fn()}}\nLIMIT 10",
"3",
template_params=json.dumps({"state": "CA", "extra1": "comment"}),
template_params=json.dumps({"state": "CA"}),
)
assert data["status"] == "success"

Expand Down

0 comments on commit 53f98af

Please sign in to comment.