From 343aa7ee62805d38fc750057b414f6204749ecdd Mon Sep 17 00:00:00 2001 From: "Carsten Wolff (cawo)" Date: Tue, 31 Dec 2024 13:11:06 +0000 Subject: [PATCH] [IMP] pg: add returning_id option to parallel_execute The default behavior is unchanged. This adds the possibility to parallelize modifying queries that have a `RETURNING id` clause. For those, return the resulting ids (in a defined order) instead of the affected row count. To avoid misuse add a warning to the docstring and try to detect queries other than the ones of the intended form. Raise an error if such are found. --- src/util/pg.py | 44 ++++++++++++++++++++++++++++---------------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/src/util/pg.py b/src/util/pg.py index de8b0f1e..c3fc90fd 100644 --- a/src/util/pg.py +++ b/src/util/pg.py @@ -73,24 +73,24 @@ def savepoint(cr): yield -def _parallel_execute_serial(cr, queries, logger=_logger): - cnt = 0 +def _parallel_execute_serial(cr, queries, logger, returning_id): + res = [] if returning_id else 0 for query in log_progress(queries, logger, qualifier="queries", size=len(queries)): cr.execute(query) - cnt += cr.rowcount - return cnt + res += cr.fetchall() if returning_id else cr.rowcount + return res if ThreadPoolExecutor is not None: - def _parallel_execute_threaded(cr, queries, logger=_logger): + def _parallel_execute_threaded(cr, queries, logger, returning_id): if not queries: return None if len(queries) == 1: # No need to spawn other threads cr.execute(queries[0]) - return cr.rowcount + return cr.fetchall() if returning_id else cr.rowcount max_workers = min(get_max_workers(), len(queries)) cursor = db_connect(cr.dbname).cursor @@ -98,9 +98,9 @@ def _parallel_execute_threaded(cr, queries, logger=_logger): def execute(query): with cursor() as tcr: tcr.execute(query) - cnt = tcr.rowcount + res = tcr.fetchall() if returning_id else tcr.rowcount tcr.commit() - return cnt + return res cr.commit() @@ -109,7 +109,7 @@ def execute(query): 
errorcodes.SERIALIZATION_FAILURE, } failed_queries = [] - tot_cnt = 0 + tot_res = [] if returning_id else 0 with ThreadPoolExecutor(max_workers=max_workers) as executor: future_queries = {executor.submit(execute, q): q for q in queries} for future in log_progress( @@ -121,7 +121,7 @@ def execute(query): log_hundred_percent=True, ): try: - tot_cnt += future.result() or 0 + tot_res += future.result() or ([] if returning_id else 0) except psycopg2.OperationalError as exc: if exc.pgcode not in CONCURRENCY_ERRORCODES: raise @@ -131,16 +131,16 @@ def execute(query): if failed_queries: logger.warning("Serialize queries that failed due to concurrency issues") - tot_cnt += _parallel_execute_serial(cr, failed_queries, logger=logger) + tot_res += _parallel_execute_serial(cr, failed_queries, logger, returning_id) cr.commit() - return tot_cnt + return tot_res else: _parallel_execute_threaded = _parallel_execute_serial -def parallel_execute(cr, queries, logger=_logger): +def parallel_execute(cr, queries, logger=_logger, returning_id=False): """ Execute queries in parallel. @@ -154,15 +154,20 @@ def parallel_execute(cr, queries, logger=_logger): :param list(str) queries: list of queries to execute concurrently :param `~logging.Logger` logger: logger used to report the progress - :return: the sum of `cr.rowcount` for each query run + :param bool returning_id: whether to return a tuple of affected ids (default: return affected row count) + :return: the sum of `cr.rowcount` for each query run or, if `returning_id`, a sorted tuple of all returned ids :rtype: int .. warning:: + - As a side effect, the cursor will be committed. + - Due to the nature of `cr.rowcount`, the return value of this function may represent an underestimate of the real number of affected records. For instance, when some records are deleted/updated as a result of an `ondelete` clause, they won't be taken into account. - - As a side effect, the cursor will be committed.
+ - It would not be generally safe to use this function for selecting queries. Because of this, + `returning_id=True` is only accepted for `UPDATE/DELETE/INSERT/MERGE [...] RETURNING id` queries. Also, the + caller cannot influence the order of the returned result tuples; the result is always sorted in ascending order. .. note:: If a concurrency issue occurs, the *failing* queries will be retried sequentially. @@ -172,7 +177,14 @@ def parallel_execute(cr, queries, logger=_logger): if getattr(threading.current_thread(), "testing", False) else _parallel_execute_threaded ) - return parallel_execute_impl(cr, queries, logger=_logger) + + if returning_id: + returning_id_re = re.compile(r"(?s)(?:UPDATE|DELETE|INSERT|MERGE).*RETURNING\s+\S*\.?id\s*$") + if not all((bool(returning_id_re.search(q)) for q in queries)): + raise ValueError("The returning_id parameter can only be used with certain queries.") + + res = parallel_execute_impl(cr, queries, logger, returning_id) + return tuple(sorted([id for (id,) in res])) if returning_id else res def format_query(cr, query, *args, **kwargs):