Skip to content

Commit

Permalink
vine: file pruning by depth (#4057)
Browse files Browse the repository at this point in the history
  • Loading branch information
JinZhou5042 authored Feb 14, 2025
1 parent f38ba02 commit 2fc88bc
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 25 deletions.
69 changes: 58 additions & 11 deletions taskvine/src/bindings/python3/ndcctools/taskvine/compat/dask_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def hashable(s):
except TypeError:
return False

def __init__(self, dsk, low_memory_mode=False):
def __init__(self, dsk, low_memory_mode=False, prune_depth=0):
self._dsk = dsk

# child -> parents. I.e., which parents need the result of child
Expand All @@ -73,9 +73,6 @@ def __init__(self, dsk, low_memory_mode=False):
# key->value of its computation
self._result_of = {}

# child -> nodes that use the child as an input, and that have not been completed
self._pending_parents_of = defaultdict(lambda: set())

# key->depth. The shallowest level the key is found
self._depth_of = defaultdict(lambda: float('inf'))

Expand All @@ -86,6 +83,10 @@ def __init__(self, dsk, low_memory_mode=False):
if low_memory_mode:
self._flatten_graph()

self.prune_depth = prune_depth
self.pending_consumers = defaultdict(int)
self.pending_producers = defaultdict(lambda: set())

self.initialize_graph()

def left_to_compute(self):
Expand All @@ -103,6 +104,11 @@ def initialize_graph(self):
for key, sexpr in self._working_graph.items():
self.set_relations(key, sexpr)

# Then initialize pending consumers if pruning is enabled
if self.prune_depth > 0:
self._initialize_pending_consumers()
self._initialize_pending_producers()

def find_dependencies(self, sexpr, depth=0):
dependencies = set()
if self.graph_keyp(sexpr):
Expand All @@ -123,7 +129,53 @@ def set_relations(self, key, sexpr):

for c in self._children_of[key]:
self._parents_of[c].add(key)
self._pending_parents_of[c].add(key)

def _initialize_pending_consumers(self):
    """Count, for every key, its distinct consumers within ``prune_depth`` levels.

    Consumers of a key are found by following ``self._parents_of`` edges
    upward: the key's parents are at level 1, their parents at level 2, and
    so on. The number of distinct consumers at levels
    ``1..self.prune_depth`` is stored in ``self.pending_consumers[key]``.

    Keys that already have an entry in ``self.pending_consumers`` are left
    untouched.
    """
    # Local import so the method does not depend on the file's import block.
    from collections import deque

    max_depth = self.prune_depth

    for key in self._working_graph:
        if key in self.pending_consumers:
            continue

        # Every consumer is recorded at most once, so the size of this set
        # is the consumer count for `key`.
        visited = set()

        # Breadth-first walk over (consumer, depth) pairs. A deque gives
        # O(1) pops from the left, whereas list.pop(0) is O(n) per pop.
        queue = deque((c, 1) for c in self._parents_of[key])

        while queue:
            consumer, depth = queue.popleft()
            if depth > max_depth or consumer in visited:
                continue
            visited.add(consumer)

            # Only expand to the next level while below the depth limit.
            if depth < max_depth:
                queue.extend((c, depth + 1) for c in self._parents_of[consumer])

        self.pending_consumers[key] = len(visited)

def _initialize_pending_producers(self):
    """Record, for every key, its distinct producers within ``prune_depth`` levels.

    Producers of a key are found by following ``self._children_of`` edges
    downward: the key's children are at level 1, their children at level 2,
    and so on. The set of distinct producers at levels
    ``1..self.prune_depth`` is stored in ``self.pending_producers[key]``,
    replacing any previous entry.

    Does nothing when ``self.prune_depth`` is zero or negative.
    """
    # Local import so the method does not depend on the file's import block.
    from collections import deque

    max_depth = self.prune_depth
    if max_depth <= 0:
        return

    for key in self._working_graph:
        # Doubles as the "already visited" set: a producer is recorded the
        # first time it is reached and skipped afterwards.
        producers = set()

        # Breadth-first walk over (producer, depth) pairs. A deque gives
        # O(1) pops from the left, whereas list.pop(0) is O(n) per pop.
        queue = deque((p, 1) for p in self._children_of[key])

        while queue:
            producer, depth = queue.popleft()
            if depth > max_depth or producer in producers:
                continue
            producers.add(producer)

            # Only expand to the next level while below the depth limit.
            if depth < max_depth:
                queue.extend((p, depth + 1) for p in self._children_of[producer])

        self.pending_producers[key] = producers

def get_ready(self):
""" List of [(key, sexpr),...] ready for computation.
Expand All @@ -148,6 +200,7 @@ def set_result(self, key, value):
of computations that become ready to be executed """
rs = {}
self._result_of[key] = value

for p in self._parents_of[key]:
self._missing_of[p].discard(key)

Expand All @@ -164,9 +217,6 @@ def set_result(self, key, value):
else:
rs[p] = (p, sexpr)

for c in self._children_of[key]:
self._pending_parents_of[c].discard(key)

return rs.values()

def _flatten_graph(self):
Expand Down Expand Up @@ -228,9 +278,6 @@ def get_missing_children(self, key):
def get_parents(self, key):
    """Return the parents recorded for `key` in the child -> parents map."""
    parents_by_child = self._parents_of
    return parents_by_child[key]

def get_pending_parents(self, key):
    """Return the entry for `key` in the pending-parents map.

    That map tracks, per child, the nodes that use the child as an input
    and have not been completed.
    """
    pending_by_child = self._pending_parents_of
    return pending_by_child[key]

def set_targets(self, keys):
""" Values of keys that need to be computed. """
self._targets.update(keys)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class DaskVine(Manager):
# fn(*args) at some point during its execution to produce the dask task result.
# Should return a tuple of (wrapper result, dask call result). Use for debugging.
# @param wrapper_proc Function to process results from wrapper on completion. (default is print)
# @param prune_files If True, remove files from the cluster after they are no longer needed.
# @param prune_depth Control pruning behavior: 0 (default) - no pruning, 1 - only check direct consumers, 2+ - check consumers up to specified depth
def get(self, dsk, keys, *,
environment=None,
extra_files=None,
Expand All @@ -132,7 +132,7 @@ def get(self, dsk, keys, *,
progress_label="[green]tasks",
wrapper=None,
wrapper_proc=print,
prune_files=False,
prune_depth=0,
hoisting_modules=None, # Deprecated, use lib_modules
import_modules=None, # Deprecated, use lib_modules
lazy_transfers=True, # Deprecated, use worker_transfers
Expand Down Expand Up @@ -174,7 +174,7 @@ def get(self, dsk, keys, *,
self.progress_label = progress_label
self.wrapper = wrapper
self.wrapper_proc = wrapper_proc
self.prune_files = prune_files
self.prune_depth = prune_depth
self.category_info = defaultdict(lambda: {"num_tasks": 0, "total_execution_time": 0})
self.max_priority = float('inf')
self.min_priority = float('-inf')
Expand Down Expand Up @@ -212,7 +212,7 @@ def _dask_execute(self, dsk, keys):
indices = {k: inds for (k, inds) in find_dask_keys(keys)}
keys_flatten = indices.keys()

dag = DaskVineDag(dsk, low_memory_mode=self.low_memory_mode)
dag = DaskVineDag(dsk, low_memory_mode=self.low_memory_mode, prune_depth=self.prune_depth)
tag = f"dag-{id(dag)}"

# create Library if using 'function-calls' task mode.
Expand Down Expand Up @@ -294,8 +294,12 @@ def _dask_execute(self, dsk, keys):
if t.key in dsk:
bar_update(advance=1)

if self.prune_files:
self._prune_file(dag, t.key)
if self.prune_depth > 0:
for p in dag.pending_producers[t.key]:
dag.pending_consumers[p] -= 1
if dag.pending_consumers[p] == 0:
p_result = dag.get_result(p)
self.prune_file(p_result._file)
else:
retries_left = t.decrement_retry()
print(f"task id {t.id} key {t.key} failed: {t.result}. {retries_left} attempts left.\n{t.std_output}")
Expand Down Expand Up @@ -446,14 +450,6 @@ def _fill_key_result(self, dag, key):
return raw.load()
else:
return raw

def _prune_file(self, dag, key):
    """Prune the result file of each child of `key` with no pending parents.

    For every child reported by ``dag.get_children(key)``, the child's
    result is looked up in `dag` and its ``_file`` attribute is passed to
    ``self.prune_file`` once ``dag.get_pending_parents(child)`` is empty.
    """
    for child in dag.get_children(key):
        # Some parent still needs this child's output; keep the file.
        if dag.get_pending_parents(child):
            continue
        child_result = dag.get_result(child)
        self.prune_file(child_result._file)

##
# @class ndcctools.taskvine.dask_executor.DaskVineFile
#
Expand Down

0 comments on commit 2fc88bc

Please sign in to comment.