pg: add returning_id option to parallel_execute #181
base: master
@@ -73,34 +73,34 @@ def savepoint(cr):
         yield
 
 
-def _parallel_execute_serial(cr, queries, logger=_logger):
-    cnt = 0
+def _parallel_execute_serial(cr, queries, logger, returning_id):
+    res = [] if returning_id else 0
     for query in log_progress(queries, logger, qualifier="queries", size=len(queries)):
         cr.execute(query)
-        cnt += cr.rowcount
-    return cnt
+        res += cr.fetchall() if returning_id else cr.rowcount
+    return res
 
 
 if ThreadPoolExecutor is not None:
 
-    def _parallel_execute_threaded(cr, queries, logger=_logger):
+    def _parallel_execute_threaded(cr, queries, logger, returning_id):
         if not queries:
             return None
 
         if len(queries) == 1:
             # No need to spawn other threads
             cr.execute(queries[0])
-            return cr.rowcount
+            return cr.fetchall() if returning_id else cr.rowcount
 
         max_workers = min(get_max_workers(), len(queries))
         cursor = db_connect(cr.dbname).cursor
 
         def execute(query):
             with cursor() as tcr:
                 tcr.execute(query)
-                cnt = tcr.rowcount
+                res = tcr.fetchall() if returning_id else tcr.rowcount
                 tcr.commit()
-            return cnt
+            return res
 
         cr.commit()
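
The serial executor above now accumulates polymorphically: `res` starts as `[]` or `0`, and `+=` either extends a list of `RETURNING` tuples or sums rowcounts. A minimal standalone sketch of that pattern, with an invented helper and sample values that are not part of the patch:

def accumulate(results, returning_id):
    # `results` stands in for per-query (rowcount, fetched rows) pairs.
    res = [] if returning_id else 0
    for rowcount, rows in results:
        res += rows if returning_id else rowcount
    return res

sample = [(2, [(1,), (2,)]), (1, [(5,)])]
print(accumulate(sample, returning_id=True))   # [(1,), (2,), (5,)]
print(accumulate(sample, returning_id=False))  # 3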
@@ -109,7 +109,7 @@ def execute(query):
             errorcodes.SERIALIZATION_FAILURE,
         }
         failed_queries = []
-        tot_cnt = 0
+        tot_res = [] if returning_id else 0
         with ThreadPoolExecutor(max_workers=max_workers) as executor:
             future_queries = {executor.submit(execute, q): q for q in queries}
             for future in log_progress(
@@ -121,7 +121,7 @@ def execute(query):
                 log_hundred_percent=True,
             ):
                 try:
-                    tot_cnt += future.result() or 0
+                    tot_res += future.result() or ([] if returning_id else 0)
                 except psycopg2.OperationalError as exc:
                     if exc.pgcode not in CONCURRENCY_ERRORCODES:
                         raise
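
The two hunks above only swap the accumulator inside the existing fan-out/collect loop. For context, a self-contained sketch of that submit-and-collect pattern, with the database replaced by a dummy worker (names and values are illustrative only):

from concurrent.futures import ThreadPoolExecutor, as_completed

def execute(query):
    # Stand-in for running `query` on its own cursor and returning its rows.
    return [(len(query),)]

queries = ["q1", "q22", "q333"]
tot_res = []
with ThreadPoolExecutor(max_workers=3) as executor:
    future_queries = {executor.submit(execute, q): q for q in queries}
    for future in as_completed(future_queries):
        tot_res += future.result() or []

print(sorted(tot_res))  # [(2,), (3,), (4,)]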
@@ -131,16 +131,16 @@ def execute(query):
 
         if failed_queries:
             logger.warning("Serialize queries that failed due to concurrency issues")
-            tot_cnt += _parallel_execute_serial(cr, failed_queries, logger=logger)
+            tot_res += _parallel_execute_serial(cr, failed_queries, logger, returning_id)
             cr.commit()
 
-        return tot_cnt
+        return tot_res
 
 else:
     _parallel_execute_threaded = _parallel_execute_serial
 
 
-def parallel_execute(cr, queries, logger=_logger):
+def parallel_execute(cr, queries, logger=_logger, returning_id=False):
     """
     Execute queries in parallel.
 
@@ -154,15 +154,20 @@ def parallel_execute(cr, queries, logger=_logger):
 
     :param list(str) queries: list of queries to execute concurrently
     :param `~logging.Logger` logger: logger used to report the progress
-    :return: the sum of `cr.rowcount` for each query run
+    :param bool returning_id: whether to return a tuple of affected ids (default: return the affected row count)
+    :return: the sum of `cr.rowcount` for each query run, or a sorted tuple of all returned ids if `returning_id` is set
     :rtype: int
 
     .. warning::
-        - As a side effect, the cursor will be committed.
-
         - Due to the nature of `cr.rowcount`, the return value of this function may represent an
           underestimate of the real number of affected records. For instance, when some records
           are deleted/updated as a result of an `ondelete` clause, they won't be taken into account.
+
+        - As a side effect, the cursor will be committed.
+        - It is not generally safe to use this function for `SELECT` queries. Because of this,
+          `returning_id=True` is only accepted for `UPDATE/DELETE/INSERT/MERGE [...] RETURNING id` queries. Also, the
+          caller cannot influence the order of the returned ids; they are always sorted in ascending order.
 
     .. note::
         If a concurrency issue occurs, the *failing* queries will be retried sequentially.
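
A usage sketch of the new flag, as called from a migration script with an open cursor `cr`; the table name and WHERE clauses are invented for illustration:

queries = [
    "UPDATE res_partner SET active = FALSE WHERE id BETWEEN 1 AND 1000 RETURNING id",
    "UPDATE res_partner SET active = FALSE WHERE id BETWEEN 1001 AND 2000 RETURNING id",
]

# Default behaviour: the summed cr.rowcount of all queries.
count = parallel_execute(cr, queries)

# With returning_id=True: a sorted tuple of the ids that were actually updated,
# regardless of which thread ran which query.
ids = parallel_execute(cr, queries, returning_id=True)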
@@ -172,7 +177,14 @@ def parallel_execute(cr, queries, logger=_logger):
         if getattr(threading.current_thread(), "testing", False)
         else _parallel_execute_threaded
     )
-    return parallel_execute_impl(cr, queries, logger=_logger)
+
+    if returning_id:
+        returning_id_re = re.compile(r"(?s)(?:UPDATE|DELETE|INSERT|MERGE).*RETURNING\s+\S*\.?id\s*$")
+        if not all((bool(returning_id_re.search(q)) for q in queries)):
+            raise ValueError("The returning_id parameter can only be used with certain queries.")

Comment on lines +181 to +184: How about we automatically detect whether the queries at hand are returning something?
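
To make the accepted query shape concrete, a quick check of the validation regex from the hunk above against invented queries (the pattern is copied verbatim; the queries are examples only):

import re

returning_id_re = re.compile(r"(?s)(?:UPDATE|DELETE|INSERT|MERGE).*RETURNING\s+\S*\.?id\s*$")

assert returning_id_re.search("UPDATE res_partner SET active = FALSE WHERE id < 10 RETURNING id")
assert returning_id_re.search("DELETE FROM ir_attachment WHERE res_model = 'foo' RETURNING ir_attachment.id")
assert not returning_id_re.search("SELECT id FROM res_partner")             # plain SELECTs are rejected
assert not returning_id_re.search("UPDATE res_partner SET active = FALSE")  # missing RETURNING clause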
+
+    res = parallel_execute_impl(cr, queries, logger, returning_id)
+    return tuple(sorted([id for (id,) in res])) if returning_id else res

Comment: Why sorting these?
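
On the sorting question above: the `_parallel_execute_*` implementations return a list of single-column tuples, and the final line flattens and sorts them into a plain tuple of ids. A small illustration with invented values:

res = [(42,), (7,), (13,)]                 # shape returned with returning_id=True
ids = tuple(sorted([id for (id,) in res]))
print(ids)  # (7, 13, 42)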
 
 
 def format_query(cr, query, *args, **kwargs):