Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pg: add returning_id option to parallel_execute #181

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 28 additions & 16 deletions src/util/pg.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,34 +73,34 @@ def savepoint(cr):
yield


def _parallel_execute_serial(cr, queries, logger=_logger):
cnt = 0
def _parallel_execute_serial(cr, queries, logger, returning_id):
res = [] if returning_id else 0
for query in log_progress(queries, logger, qualifier="queries", size=len(queries)):
cr.execute(query)
cnt += cr.rowcount
return cnt
res += cr.fetchall() if returning_id else cr.rowcount
return res


if ThreadPoolExecutor is not None:

def _parallel_execute_threaded(cr, queries, logger=_logger):
def _parallel_execute_threaded(cr, queries, logger, returning_id):
if not queries:
return None

if len(queries) == 1:
# No need to spawn other threads
cr.execute(queries[0])
return cr.rowcount
return cr.fetchall() if returning_id else cr.rowcount

max_workers = min(get_max_workers(), len(queries))
cursor = db_connect(cr.dbname).cursor

def execute(query):
with cursor() as tcr:
tcr.execute(query)
cnt = tcr.rowcount
res = tcr.fetchall() if returning_id else tcr.rowcount
tcr.commit()
return cnt
return res

cr.commit()

Expand All @@ -109,7 +109,7 @@ def execute(query):
errorcodes.SERIALIZATION_FAILURE,
}
failed_queries = []
tot_cnt = 0
tot_res = [] if returning_id else 0
with ThreadPoolExecutor(max_workers=max_workers) as executor:
future_queries = {executor.submit(execute, q): q for q in queries}
for future in log_progress(
Expand All @@ -121,7 +121,7 @@ def execute(query):
log_hundred_percent=True,
):
try:
tot_cnt += future.result() or 0
tot_res += future.result() or ([] if returning_id else 0)
except psycopg2.OperationalError as exc:
if exc.pgcode not in CONCURRENCY_ERRORCODES:
raise
Expand All @@ -131,16 +131,16 @@ def execute(query):

if failed_queries:
logger.warning("Serialize queries that failed due to concurrency issues")
tot_cnt += _parallel_execute_serial(cr, failed_queries, logger=logger)
tot_res += _parallel_execute_serial(cr, failed_queries, logger, returning_id)
cr.commit()

return tot_cnt
return tot_res

else:
_parallel_execute_threaded = _parallel_execute_serial


def parallel_execute(cr, queries, logger=_logger):
def parallel_execute(cr, queries, logger=_logger, returning_id=False):
"""
Execute queries in parallel.

Expand All @@ -154,15 +154,20 @@ def parallel_execute(cr, queries, logger=_logger):

:param list(str) queries: list of queries to execute concurrently
:param `~logging.Logger` logger: logger used to report the progress
:return: the sum of `cr.rowcount` for each query run
:param bool returning_id: wether to return a tuple of affected ids (default: return affected row count)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
:param bool returning_id: wether to return a tuple of affected ids (default: return affected row count)
:param bool returning_id: whether to return a tuple of affected ids (default: return affected row count)

:return: the sum of `cr.rowcount` for each query run or, if `returning_id` is set, a tuple of all returned ids
:rtype: int or tuple

.. warning::
- As a side effect, the cursor will be committed.

- Due to the nature of `cr.rowcount`, the return value of this function may represent an
underestimate of the real number of affected records. For instance, when some records
are deleted/updated as a result of an `ondelete` clause, they won't be taken into account.

- As a side effect, the cursor will be committed.
- It would not be generally safe to use this function for selecting queries. Because of this,
`returning_id=True` is only accepted for `UPDATE/DELETE/INSERT/MERGE [...] RETURNING id` queries. Also, the
caller cannot influnce the order of the returned result tuples, it is always sorted in ascending order.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
caller cannot influnce the order of the returned result tuples, it is always sorted in ascending order.
caller cannot influence the order of the returned result tuples, it is always sorted in ascending order.


.. note::
If a concurrency issue occurs, the *failing* queries will be retried sequentially.
Expand All @@ -172,7 +177,14 @@ def parallel_execute(cr, queries, logger=_logger):
if getattr(threading.current_thread(), "testing", False)
else _parallel_execute_threaded
)
return parallel_execute_impl(cr, queries, logger=_logger)

if returning_id:
returning_id_re = re.compile(r"(?s)(?:UPDATE|DELETE|INSERT|MERGE).*RETURNING\s+\S*\.?id\s*$")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
returning_id_re = re.compile(r"(?s)(?:UPDATE|DELETE|INSERT|MERGE).*RETURNING\s+\S*\.?id\s*$")
returning_id_re = re.compile(r"(?s)(?:UPDATE|DELETE|INSERT|MERGE).*RETURNING\s+(?:\w+\.)?id\s*$")

Change \S*\.? to (?:\w+\.)? to avoid false positives such as "UPDATE ... RETURNING name,id" and "DELETE FROM ... RETURNING report_id".

if not all((bool(returning_id_re.search(q)) for q in queries)):
raise ValueError("The returning_id parameter can only be used with certain queries.")
Comment on lines +181 to +184
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if returning_id:
returning_id_re = re.compile(r"(?s)(?:UPDATE|DELETE|INSERT|MERGE).*RETURNING\s+\S*\.?id\s*$")
if not all((bool(returning_id_re.search(q)) for q in queries)):
raise ValueError("The returning_id parameter can only be used with certain queries.")
returning_id_re = re.compile(r"(?s)(?:UPDATE|DELETE|INSERT|MERGE).*RETURNING\s+\S*\.?id\s*$")
returning_id = all((bool(returning_id_re.search(q)) for q in queries))

How about we automatically detect whether the queries at hand are returning something?


res = parallel_execute_impl(cr, queries, logger, returning_id)
return tuple(sorted([id for (id,) in res])) if returning_id else res
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return tuple(sorted([id for (id,) in res])) if returning_id else res
return tuple(id for (id,) in res) if returning_id else res

Why sort these?



def format_query(cr, query, *args, **kwargs):
Expand Down