From 49e456eb7bd0002aaf853bc560b1f6097a0916f8 Mon Sep 17 00:00:00 2001 From: Tom White Date: Sat, 10 Sep 2022 16:20:30 +0100 Subject: [PATCH] Beam callbacks --- cubed/core/array.py | 5 +- cubed/extensions/tqdm.py | 2 +- cubed/runtime/executors/beam.py | 119 ++++++++++++++++++++++++-------- cubed/tests/test_core.py | 7 +- examples/dataflow-add-random.py | 10 ++- 5 files changed, 105 insertions(+), 38 deletions(-) diff --git a/cubed/core/array.py b/cubed/core/array.py index f478b172a..2c025281e 100644 --- a/cubed/core/array.py +++ b/cubed/core/array.py @@ -232,11 +232,14 @@ def on_task_end(self, event): @dataclass class TaskEndEvent: - """Callback information about a completed task.""" + """Callback information about a completed task (or tasks).""" array_name: str """Name of the array that the task is for.""" + num_tasks: int = 1 + """Number of tasks that this event applies to (default 1).""" + task_create_tstamp: Optional[float] = None """Timestamp of when the task was created by the client.""" diff --git a/cubed/extensions/tqdm.py b/cubed/extensions/tqdm.py index 964b1c98c..d4524e3d7 100644 --- a/cubed/extensions/tqdm.py +++ b/cubed/extensions/tqdm.py @@ -31,7 +31,7 @@ def on_compute_end(self, dag): pbar.close() def on_task_end(self, event): - self.pbars[event.array_name].update() + self.pbars[event.array_name].update(event.num_tasks) @contextlib.contextmanager diff --git a/cubed/runtime/executors/beam.py b/cubed/runtime/executors/beam.py index 101725a1f..167c89aee 100644 --- a/cubed/runtime/executors/beam.py +++ b/cubed/runtime/executors/beam.py @@ -1,9 +1,12 @@ import itertools +import sched +import time from dataclasses import dataclass from typing import Any, Iterable, List, Tuple, cast import apache_beam as beam import networkx as nx +from apache_beam.runners.runner import PipelineState from rechunker.types import ( Config, NoArgumentStageFunction, @@ -12,6 +15,7 @@ Stage, ) +from cubed.core.array import TaskEndEvent from cubed.core.plan import visit_nodes from cubed.runtime.types import DagExecutor @@ -37,6 +41,7 @@ def _no_arg_stage( class _SingleArgumentStage(beam.PTransform): """Execute mappable stage in parallel.""" + name: str step: int stage: Stage config: Config @@ -56,7 +61,7 @@ def exec_stage(self, last: int, arg: Any) -> int: self.stage.function(arg, config=self.config) # type: ignore - beam.metrics.metric.Metrics.counter("cubed", "completed_tasks").inc() + beam.metrics.metric.Metrics.counter(self.name, "completed_tasks").inc() return self.step @@ -130,41 +135,47 @@ class BeamDagExecutor(DagExecutor): """An execution engine that uses Apache Beam.""" def execute_dag(self, dag, callbacks=None, array_names=None, **kwargs): - if callbacks is not None: - raise NotImplementedError("Callbacks not supported") dag = dag.copy() - with beam.Pipeline(**kwargs) as pipeline: - for name, node in visit_nodes(dag): - rechunker_pipeline = node["pipeline"] - - dep_nodes = list(dag.predecessors(name)) - - pcolls = [ - p - for (n, p) in nx.get_node_attributes(dag, "pcoll").items() - if n in dep_nodes - ] - if len(pcolls) == 0: - pcoll = pipeline | gensym("Start") >> beam.Create([-1]) - pcoll = add_to_pcoll(rechunker_pipeline, pcoll) - dag.nodes[name]["pcoll"] = pcoll - - elif len(pcolls) == 1: - pcoll = pcolls[0] - pcoll = add_to_pcoll(rechunker_pipeline, pcoll) - dag.nodes[name]["pcoll"] = pcoll - else: - pcoll = pcolls | gensym("Flatten") >> beam.Flatten() - pcoll |= gensym("Distinct") >> beam.Distinct() - pcoll = add_to_pcoll(rechunker_pipeline, pcoll) - dag.nodes[name]["pcoll"] = pcoll + 
pipeline = beam.Pipeline(**kwargs) + + for name, node in visit_nodes(dag): + rechunker_pipeline = node["pipeline"] + + dep_nodes = list(dag.predecessors(name)) + + pcolls = [ + p + for (n, p) in nx.get_node_attributes(dag, "pcoll").items() + if n in dep_nodes + ] + if len(pcolls) == 0: + pcoll = pipeline | gensym("Start") >> beam.Create([-1]) + pcoll = add_to_pcoll(name, rechunker_pipeline, pcoll) + dag.nodes[name]["pcoll"] = pcoll + + elif len(pcolls) == 1: + pcoll = pcolls[0] + pcoll = add_to_pcoll(name, rechunker_pipeline, pcoll) + dag.nodes[name]["pcoll"] = pcoll + else: + pcoll = pcolls | gensym("Flatten") >> beam.Flatten() + pcoll |= gensym("Distinct") >> beam.Distinct() + pcoll = add_to_pcoll(name, rechunker_pipeline, pcoll) + dag.nodes[name]["pcoll"] = pcoll + + result = pipeline.run() + + if callbacks is None: + result.wait_until_finish() + else: + wait_until_finish_with_callbacks(result, callbacks) -def add_to_pcoll(rechunker_pipeline, pcoll): +def add_to_pcoll(name, rechunker_pipeline, pcoll): for step, stage in enumerate(rechunker_pipeline.stages): if stage.mappable is not None: pcoll |= stage.name >> _SingleArgumentStage( - step, stage, rechunker_pipeline.config + name, step, stage, rechunker_pipeline.config ) else: pcoll |= stage.name >> beam.Map( @@ -181,3 +192,51 @@ def add_to_pcoll(rechunker_pipeline, pcoll): pcoll |= gensym("End") >> beam.Map(lambda x: -1) return pcoll + + +# A generalized version of Beam's PipelineResult.wait_until_finish method +# that polls for Beam metrics to make callbacks. +# If the pipeline is already done (e.g. the DirectRunner, which blocks) +# then all callbacks will be called before returning immediately. +def wait_until_finish_with_callbacks(result, callbacks): + MetricCallbackPoller(result, callbacks) + + +class MetricCallbackPoller: + def __init__(self, result, callbacks): + self.result = result + self.callbacks = callbacks + self.array_counts = {} + self.scheduler = sched.scheduler(time.time, time.sleep) + poll(self, self.result) # poll immediately + self.scheduler.run() + + def update(self, new_array_counts): + for name, new_count in new_array_counts.items(): + old_count = self.array_counts.get(name, 0) + # it's possible that new_count < old_count + event = TaskEndEvent(name, num_tasks=(new_count - old_count)) + if self.callbacks is not None: + [callback.on_task_end(event) for callback in self.callbacks] + self.array_counts[name] = new_count + + +def poll(poller, result): + new_array_counts = get_array_counts_from_metrics(result) + poller.update(new_array_counts) + state = result.state + if PipelineState.is_terminal(state): + return + else: + # poll again in 5 seconds + scheduler = poller.scheduler + scheduler.enter(5, 1, poll, (poller, result)) + + +def get_array_counts_from_metrics(result): + filter = beam.metrics.MetricsFilter().with_name("completed_tasks") + metrics = result.metrics().query(filter)["counters"] + new_array_counts = { + metric.key.metric.namespace: metric.result for metric in metrics + } + return new_array_counts diff --git a/cubed/tests/test_core.py b/cubed/tests/test_core.py index 7938a2532..0110ec2d2 100644 --- a/cubed/tests/test_core.py +++ b/cubed/tests/test_core.py @@ -12,6 +12,7 @@ from cubed.extensions.tqdm import TqdmProgressBar from cubed.primitive.blockwise import apply_blockwise from cubed.runtime.executors.python import PythonDagExecutor +from cubed.runtime.types import DagExecutor from cubed.tests.utils import MAIN_EXECUTORS, MODAL_EXECUTORS, create_zarr @@ -388,13 +389,11 @@ def on_task_end(self, event): 
>= event.task_create_tstamp > 0 ) - self.value += 1 + self.value += event.num_tasks def test_callbacks(spec, executor): - from cubed.runtime.executors.lithops import LithopsDagExecutor - - if not isinstance(executor, (PythonDagExecutor, LithopsDagExecutor)): + if not isinstance(executor, DagExecutor): pytest.skip(f"{type(executor)} does not support callbacks") task_counter = TaskCounter() diff --git a/examples/dataflow-add-random.py b/examples/dataflow-add-random.py index 39f721e98..4bd3cba2d 100644 --- a/examples/dataflow-add-random.py +++ b/examples/dataflow-add-random.py @@ -2,10 +2,12 @@ import logging from apache_beam.options.pipeline_options import PipelineOptions +from tqdm.contrib.logging import logging_redirect_tqdm import cubed import cubed.array_api as xp import cubed.random +from cubed.extensions.tqdm import TqdmProgressBar from cubed.runtime.executors.beam import BeamDagExecutor @@ -29,8 +31,12 @@ def run(argv=None): (50000, 50000), chunks=(5000, 5000), spec=spec ) # 200MB chunks c = xp.add(a, b) - # use store=None to write to temporary zarr - cubed.to_zarr(c, store=None, executor=executor, options=beam_options) + with logging_redirect_tqdm(): + progress = TqdmProgressBar() + # use store=None to write to temporary zarr + cubed.to_zarr( + c, store=None, executor=executor, callbacks=[progress], options=beam_options + ) if __name__ == "__main__":
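

Usage sketch (illustrative, not part of the patch): a duck-typed callback that tallies completed
tasks per array, modelled on the TaskCounter test above. It assumes no Callback base class is
required beyond the hooks TqdmProgressBar defines, and that on_compute_start mirrors the
on_compute_end signature shown in the tqdm diff; the TaskTally class, the "tmp" work_dir, and the
tiny asarray inputs are hypothetical. The point it demonstrates is that with BeamDagExecutor the
events come from polled metric deltas, so a callback should add event.num_tasks rather than assume
one task per on_task_end call.

import cubed
import cubed.array_api as xp
from cubed.runtime.executors.beam import BeamDagExecutor


class TaskTally:
    # Hypothetical callback: only the hooks invoked by the runtime are defined.
    def __init__(self):
        self.counts = {}

    def on_compute_start(self, dag):  # assumed to take the dag, like on_compute_end
        pass

    def on_compute_end(self, dag):
        pass

    def on_task_end(self, event):
        # event.num_tasks defaults to 1, but the Beam executor batches updates,
        # so always add the count instead of incrementing by one.
        self.counts[event.array_name] = (
            self.counts.get(event.array_name, 0) + event.num_tasks
        )


if __name__ == "__main__":
    # Illustrative local spec; values follow the patterns used in the cubed tests.
    spec = cubed.Spec("tmp", max_mem=100000)
    a = xp.asarray([[1, 2], [3, 4]], chunks=(2, 2), spec=spec)
    b = xp.asarray([[5, 6], [7, 8]], chunks=(2, 2), spec=spec)
    c = xp.add(a, b)
    tally = TaskTally()
    # store=None writes to a temporary zarr, as in the Dataflow example above;
    # with no pipeline options, BeamDagExecutor falls back to the DirectRunner.
    cubed.to_zarr(c, store=None, executor=BeamDagExecutor(), callbacks=[tally])
    print(tally.counts)

With the DirectRunner the pipeline blocks inside run(), so the poller fires all callbacks once at
the end; on Dataflow the same callback would receive incremental batched updates every few seconds.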