Beam callbacks (#127)
tomwhite authored Sep 13, 2022
1 parent 1adef3d commit c0d776e
Showing 5 changed files with 107 additions and 39 deletions.
5 changes: 4 additions & 1 deletion cubed/core/array.py
@@ -248,11 +248,14 @@ def on_task_end(self, event):
 
 @dataclass
 class TaskEndEvent:
-    """Callback information about a completed task."""
+    """Callback information about a completed task (or tasks)."""
 
     array_name: str
     """Name of the array that the task is for."""
 
+    num_tasks: int = 1
+    """Number of tasks that this event applies to (default 1)."""
+
     task_create_tstamp: Optional[float] = None
     """Timestamp of when the task was created by the client."""
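Note: the new num_tasks field lets a runtime report a batch of completed tasks in a single event, which is how the Beam metrics poller later in this commit uses it. A minimal sketch of a callback consuming the field (the TaskTally class is hypothetical; only the on_task_end hook appears in this commit):

from cubed.core.array import TaskEndEvent

class TaskTally:
    """Hypothetical callback: tally completed tasks per array."""

    def __init__(self):
        self.per_array = {}

    def on_task_end(self, event: TaskEndEvent):
        # a single event may cover several tasks, so add num_tasks rather than 1
        count = self.per_array.get(event.array_name, 0)
        self.per_array[event.array_name] = count + event.num_tasks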
2 changes: 1 addition & 1 deletion cubed/extensions/tqdm.py
@@ -31,7 +31,7 @@ def on_compute_end(self, dag):
             pbar.close()
 
     def on_task_end(self, event):
-        self.pbars[event.array_name].update()
+        self.pbars[event.array_name].update(event.num_tasks)
 
 
 @contextlib.contextmanager
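tqdm's update(n) advances a bar by n steps, so one batched event now moves the bar by the whole task count. A standalone illustration (plain tqdm, independent of cubed):

from tqdm import tqdm

pbar = tqdm(total=100)
pbar.update(7)  # advance by 7 in one call, as on_task_end does with event.num_tasks
pbar.close()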
122 changes: 91 additions & 31 deletions cubed/runtime/executors/beam.py
@@ -1,9 +1,12 @@
 import itertools
+import sched
+import time
 from dataclasses import dataclass
-from typing import Any, Iterable, List, Tuple, cast
+from typing import Any, Iterable, List, Optional, Tuple, cast
 
 import apache_beam as beam
 import networkx as nx
+from apache_beam.runners.runner import PipelineState
 from rechunker.types import (
     Config,
     NoArgumentStageFunction,
@@ -12,6 +15,7 @@
     Stage,
 )
 
+from cubed.core.array import TaskEndEvent
 from cubed.core.plan import visit_nodes
 from cubed.runtime.types import DagExecutor
@@ -40,6 +44,7 @@ class _SingleArgumentStage(beam.PTransform):
     step: int
     stage: Stage
     config: Config
+    name: Optional[str] = None
 
     def prepare_stage(self, last: int) -> Iterable[Tuple[int, Any]]:
         """Propagate current stage to Mappables for parallel execution."""
@@ -56,7 +61,8 @@ def exec_stage(self, last: int, arg: Any) -> int:
 
         self.stage.function(arg, config=self.config)  # type: ignore
 
-        beam.metrics.metric.Metrics.counter("cubed", "completed_tasks").inc()
+        if self.name is not None:
+            beam.metrics.metric.Metrics.counter(self.name, "completed_tasks").inc()
 
         return self.step
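Namespacing the counter by array name (instead of the previous fixed "cubed" namespace) is what lets completed tasks be attributed to individual arrays. A sketch of reading one array's counter back from a PipelineResult, mirroring the query style of get_array_counts_from_metrics below (the completed_tasks helper is assumed, not part of the commit):

import apache_beam as beam

def completed_tasks(result, array_name):
    # query all "completed_tasks" counters, then keep the one namespace we want
    f = beam.metrics.MetricsFilter().with_name("completed_tasks")
    counters = result.metrics().query(f)["counters"]
    return sum(m.result for m in counters if m.key.metric.namespace == array_name)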

@@ -130,41 +136,47 @@ class BeamDagExecutor(DagExecutor):
     """An execution engine that uses Apache Beam."""
 
     def execute_dag(self, dag, callbacks=None, array_names=None, **kwargs):
-        if callbacks is not None:
-            raise NotImplementedError("Callbacks not supported")
         dag = dag.copy()
-        with beam.Pipeline(**kwargs) as pipeline:
-            for name, node in visit_nodes(dag):
-                rechunker_pipeline = node["pipeline"]
-
-                dep_nodes = list(dag.predecessors(name))
-
-                pcolls = [
-                    p
-                    for (n, p) in nx.get_node_attributes(dag, "pcoll").items()
-                    if n in dep_nodes
-                ]
-                if len(pcolls) == 0:
-                    pcoll = pipeline | gensym("Start") >> beam.Create([-1])
-                    pcoll = add_to_pcoll(rechunker_pipeline, pcoll)
-                    dag.nodes[name]["pcoll"] = pcoll
-
-                elif len(pcolls) == 1:
-                    pcoll = pcolls[0]
-                    pcoll = add_to_pcoll(rechunker_pipeline, pcoll)
-                    dag.nodes[name]["pcoll"] = pcoll
-                else:
-                    pcoll = pcolls | gensym("Flatten") >> beam.Flatten()
-                    pcoll |= gensym("Distinct") >> beam.Distinct()
-                    pcoll = add_to_pcoll(rechunker_pipeline, pcoll)
-                    dag.nodes[name]["pcoll"] = pcoll
+        pipeline = beam.Pipeline(**kwargs)
+
+        for name, node in visit_nodes(dag):
+            rechunker_pipeline = node["pipeline"]
+
+            dep_nodes = list(dag.predecessors(name))
+
+            pcolls = [
+                p
+                for (n, p) in nx.get_node_attributes(dag, "pcoll").items()
+                if n in dep_nodes
+            ]
+            if len(pcolls) == 0:
+                pcoll = pipeline | gensym("Start") >> beam.Create([-1])
+                pcoll = add_to_pcoll(name, rechunker_pipeline, pcoll)
+                dag.nodes[name]["pcoll"] = pcoll
+
+            elif len(pcolls) == 1:
+                pcoll = pcolls[0]
+                pcoll = add_to_pcoll(name, rechunker_pipeline, pcoll)
+                dag.nodes[name]["pcoll"] = pcoll
+            else:
+                pcoll = pcolls | gensym("Flatten") >> beam.Flatten()
+                pcoll |= gensym("Distinct") >> beam.Distinct()
+                pcoll = add_to_pcoll(name, rechunker_pipeline, pcoll)
+                dag.nodes[name]["pcoll"] = pcoll
+
+        result = pipeline.run()
+
+        if callbacks is None:
+            result.wait_until_finish()
+        else:
+            wait_until_finish_with_callbacks(result, callbacks)
 
 
-def add_to_pcoll(rechunker_pipeline, pcoll):
+def add_to_pcoll(name, rechunker_pipeline, pcoll):
     for step, stage in enumerate(rechunker_pipeline.stages):
         if stage.mappable is not None:
             pcoll |= stage.name >> _SingleArgumentStage(
-                step, stage, rechunker_pipeline.config
+                step, stage, rechunker_pipeline.config, name
             )
         else:
             pcoll |= stage.name >> beam.Map(
@@ -181,3 +193,51 @@ def add_to_pcoll(rechunker_pipeline, pcoll):
 
     pcoll |= gensym("End") >> beam.Map(lambda x: -1)
     return pcoll
+
+
+# A generalized version of Beam's PipelineResult.wait_until_finish method
+# that polls for Beam metrics to make callbacks.
+# If the pipeline is already done (e.g. the DirectRunner, which blocks)
+# then all callbacks will be called before returning immediately.
+def wait_until_finish_with_callbacks(result, callbacks):
+    MetricCallbackPoller(result, callbacks)
+
+
+class MetricCallbackPoller:
+    def __init__(self, result, callbacks):
+        self.result = result
+        self.callbacks = callbacks
+        self.array_counts = {}
+        self.scheduler = sched.scheduler(time.time, time.sleep)
+        poll(self, self.result)  # poll immediately
+        self.scheduler.run()
+
+    def update(self, new_array_counts):
+        for name, new_count in new_array_counts.items():
+            old_count = self.array_counts.get(name, 0)
+            # it's possible that new_count < old_count
+            event = TaskEndEvent(name, num_tasks=(new_count - old_count))
+            if self.callbacks is not None:
+                [callback.on_task_end(event) for callback in self.callbacks]
+            self.array_counts[name] = new_count
+
+
+def poll(poller, result):
+    new_array_counts = get_array_counts_from_metrics(result)
+    poller.update(new_array_counts)
+    state = result.state
+    if PipelineState.is_terminal(state):
+        return
+    else:
+        # poll again in 5 seconds
+        scheduler = poller.scheduler
+        scheduler.enter(5, 1, poll, (poller, result))
+
+
+def get_array_counts_from_metrics(result):
+    filter = beam.metrics.MetricsFilter().with_name("completed_tasks")
+    metrics = result.metrics().query(filter)["counters"]
+    new_array_counts = {
+        metric.key.metric.namespace: metric.result for metric in metrics
+    }
+    return new_array_counts
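The poller's update() turns successive counter snapshots into per-array deltas, emitting one TaskEndEvent per array on each poll. A small runnable sketch of that bookkeeping, with assumed counter values (the deltas helper is illustrative, not part of the commit):

from cubed.core.array import TaskEndEvent

def deltas(old, new):
    # per-array task deltas, as MetricCallbackPoller.update computes them
    return [TaskEndEvent(name, num_tasks=new[name] - old.get(name, 0)) for name in new]

events = deltas({"array-a": 3}, {"array-a": 7, "array-b": 2})
# -> TaskEndEvent("array-a", num_tasks=4), TaskEndEvent("array-b", num_tasks=2)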
7 changes: 3 additions & 4 deletions cubed/tests/test_core.py
@@ -12,6 +12,7 @@
 from cubed.extensions.tqdm import TqdmProgressBar
 from cubed.primitive.blockwise import apply_blockwise
 from cubed.runtime.executors.python import PythonDagExecutor
+from cubed.runtime.types import DagExecutor
 from cubed.tests.utils import MAIN_EXECUTORS, MODAL_EXECUTORS, create_zarr
 
 
@@ -388,13 +389,11 @@ def on_task_end(self, event):
             >= event.task_create_tstamp
             > 0
         )
-        self.value += 1
+        self.value += event.num_tasks
 
 
 def test_callbacks(spec, executor):
-    from cubed.runtime.executors.lithops import LithopsDagExecutor
-
-    if not isinstance(executor, (PythonDagExecutor, LithopsDagExecutor)):
+    if not isinstance(executor, DagExecutor):
         pytest.skip(f"{type(executor)} does not support callbacks")
 
     task_counter = TaskCounter()
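For context, only the on_task_end method of the test's TaskCounter is visible in the diff; its likely shape, with the constructor assumed, is:

class TaskCounter:
    def __init__(self):
        self.value = 0

    def on_task_end(self, event):
        # timestamp assertions omitted; see the diff above
        self.value += event.num_tasks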
10 changes: 8 additions & 2 deletions examples/dataflow-add-random.py
@@ -2,10 +2,12 @@
 import logging
 
 from apache_beam.options.pipeline_options import PipelineOptions
+from tqdm.contrib.logging import logging_redirect_tqdm
 
 import cubed
 import cubed.array_api as xp
 import cubed.random
+from cubed.extensions.tqdm import TqdmProgressBar
 from cubed.runtime.executors.beam import BeamDagExecutor
 
 
@@ -29,8 +31,12 @@ def run(argv=None):
         (50000, 50000), chunks=(5000, 5000), spec=spec
     )  # 200MB chunks
     c = xp.add(a, b)
-    # use store=None to write to temporary zarr
-    cubed.to_zarr(c, store=None, executor=executor, options=beam_options)
+    with logging_redirect_tqdm():
+        progress = TqdmProgressBar()
+        # use store=None to write to temporary zarr
+        cubed.to_zarr(
+            c, store=None, executor=executor, callbacks=[progress], options=beam_options
+        )
 
 
 if __name__ == "__main__":
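logging_redirect_tqdm routes standard-library logging records through tqdm.write while the context is active, so Beam's log output does not garble the progress bar. A standalone illustration (plain tqdm, independent of this example):

import logging

from tqdm import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm

with logging_redirect_tqdm():
    pbar = tqdm(total=10)
    logging.warning("printed via tqdm.write, above the bar")
    pbar.update(10)
    pbar.close()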
