Skip to content

Commit

Permalink
use Ray 1.9 api change (#107)
Browse files (browse the repository at this point in the history)
  • Loading branch information
odp authored Dec 14, 2021
1 parent 88c7227 commit 8a0ad47
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 25 deletions.
33 changes: 10 additions & 23 deletions ray/adaptdl_ray/tune/adaptdl_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

from datetime import datetime
import logging
import copy
from typing import List

import ray
Expand All @@ -24,10 +23,6 @@
from ray.tune import PlacementGroupFactory
from ray.tune.function_runner import FuncCheckpointUtil
from ray.tune.trainable import TrainableUtil
from ray.tune.resources import resources_to_json
from ray._private.utils import binary_to_hex
import ray.cloudpickle as cloudpickle
from ray.tune.trial import Location

from adaptdl_ray.adaptdl import AdaptDLJobMixin
from adaptdl_ray.tune.adaptdl_trainable import AdaptDLTrainableCreator
Expand All @@ -49,26 +44,16 @@ def _num_replicas(self) -> int:
return self.get_trainable_cls()._num_workers

def __getstate__(self):
state = self.__dict__.copy()
copy_state = {}
# Remove problematic members
for k in ("_trial_in_use", "_cached_metrics"):
del state[k]

state["resources"] = resources_to_json(self.resources)

for key in self._nonjson_fields:
state[key] = binary_to_hex(cloudpickle.dumps(state.get(key)))

state["runner"] = None
state["location"] = Location()
# Avoid waiting for events that will never occur on resume.
state["restoring_from"] = None
state["saving_to"] = None

state["_state_json"] = None
state["_state_valid"] = False

return copy.deepcopy(state)
copy_state[k] = self.__dict__[k]
del self.__dict__[k]
state = super().__getstate__()
# Restore members
for k, v in copy_state.items():
self.__dict__[k] = v
return state

def _requeue(self,
old_trial: Trial,
Expand Down Expand Up @@ -118,6 +103,8 @@ def _clone_from(cls,
rescale_count=rescale_count,
config=trial.config,
experiment_tag=trial.experiment_tag,
evaluated_params=trial.evaluated_params,
stopping_criterion=trial.stopping_criterion,
trial_id=trial.trial_id,
restore_path=restore_path,
local_dir="/tmp", # TODO: Decide a proper way
Expand Down
6 changes: 4 additions & 2 deletions ray/adaptdl_ray/tune/adaptdl_trial_sched.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,13 @@ def choose_trial_to_run(
self, trial_runner: "trial_runner.TrialRunner") -> Optional[Trial]:
for trial in trial_runner.get_trials():
if (trial.status == Trial.PENDING
and trial_runner.has_resources_for_trial(trial)):
and trial_runner.trial_executor.
has_resources_for_trial(trial)):
return trial
for trial in trial_runner.get_trials():
if (trial.status == Trial.PAUSED and
trial_runner.has_resources_for_trial(trial) and
trial_runner.trial_executor.
has_resources_for_trial(trial) and
len(self._allocs) == 0):
# Note: this puts the trial back to RUNNING, we allow Trials to
# resume when the allocation cache is empty and we reach a sync
Expand Down

0 comments on commit 8a0ad47

Please sign in to comment.