From 8a0ad47da91ab4b8f5e13c819cb4701a2ebe8ca8 Mon Sep 17 00:00:00 2001
From: Omkar Pangarkar
Date: Tue, 14 Dec 2021 15:21:11 -0500
Subject: [PATCH] use Ray 1.9 api change (#107)

---
 ray/adaptdl_ray/tune/adaptdl_trial.py       | 33 +++++++--------------
 ray/adaptdl_ray/tune/adaptdl_trial_sched.py |  6 ++--
 2 files changed, 14 insertions(+), 25 deletions(-)

diff --git a/ray/adaptdl_ray/tune/adaptdl_trial.py b/ray/adaptdl_ray/tune/adaptdl_trial.py
index becd619f..2906fca2 100644
--- a/ray/adaptdl_ray/tune/adaptdl_trial.py
+++ b/ray/adaptdl_ray/tune/adaptdl_trial.py
@@ -15,7 +15,6 @@
 
 from datetime import datetime
 import logging
-import copy
 from typing import List
 
 import ray
@@ -24,10 +23,6 @@
 from ray.tune import PlacementGroupFactory
 from ray.tune.function_runner import FuncCheckpointUtil
 from ray.tune.trainable import TrainableUtil
-from ray.tune.resources import resources_to_json
-from ray._private.utils import binary_to_hex
-import ray.cloudpickle as cloudpickle
-from ray.tune.trial import Location
 
 from adaptdl_ray.adaptdl import AdaptDLJobMixin
 from adaptdl_ray.tune.adaptdl_trainable import AdaptDLTrainableCreator
@@ -49,26 +44,16 @@ def _num_replicas(self) -> int:
         return self.get_trainable_cls()._num_workers
 
     def __getstate__(self):
-        state = self.__dict__.copy()
+        copy_state = {}
         # Remove problematic members
         for k in ("_trial_in_use", "_cached_metrics"):
-            del state[k]
-
-        state["resources"] = resources_to_json(self.resources)
-
-        for key in self._nonjson_fields:
-            state[key] = binary_to_hex(cloudpickle.dumps(state.get(key)))
-
-        state["runner"] = None
-        state["location"] = Location()
-        # Avoid waiting for events that will never occur on resume.
-        state["restoring_from"] = None
-        state["saving_to"] = None
-
-        state["_state_json"] = None
-        state["_state_valid"] = False
-
-        return copy.deepcopy(state)
+            copy_state[k] = self.__dict__[k]
+            del self.__dict__[k]
+        state = super().__getstate__()
+        # Restore members
+        for k, v in copy_state.items():
+            self.__dict__[k] = v
+        return state
 
     def _requeue(self,
                  old_trial: Trial,
@@ -118,6 +103,8 @@ def _clone_from(cls,
             rescale_count=rescale_count,
             config=trial.config,
             experiment_tag=trial.experiment_tag,
+            evaluated_params=trial.evaluated_params,
+            stopping_criterion=trial.stopping_criterion,
             trial_id=trial.trial_id,
             restore_path=restore_path,
             local_dir="/tmp",  # TODO: Decide a proper way
diff --git a/ray/adaptdl_ray/tune/adaptdl_trial_sched.py b/ray/adaptdl_ray/tune/adaptdl_trial_sched.py
index d18825b3..378b99cc 100644
--- a/ray/adaptdl_ray/tune/adaptdl_trial_sched.py
+++ b/ray/adaptdl_ray/tune/adaptdl_trial_sched.py
@@ -108,11 +108,13 @@ def choose_trial_to_run(
             self, trial_runner: "trial_runner.TrialRunner") -> Optional[Trial]:
         for trial in trial_runner.get_trials():
             if (trial.status == Trial.PENDING
-                    and trial_runner.has_resources_for_trial(trial)):
+                    and trial_runner.trial_executor.
+                    has_resources_for_trial(trial)):
                 return trial
         for trial in trial_runner.get_trials():
             if (trial.status == Trial.PAUSED and
-                    trial_runner.has_resources_for_trial(trial) and
+                    trial_runner.trial_executor.
+                    has_resources_for_trial(trial) and
                     len(self._allocs) == 0):
                 # Note: this puts the trial back to RUNNING, we allow Trials to
                 # resume when the allocation cache is empty and we reach a sync
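
Note (not part of the patch): the __getstate__ rewrite replaces a hand-copied
version of Trial's pickling logic with delegation to super().__getstate__(),
so AdaptDLTrial no longer has to track changes to Trial's internals across
Ray releases. A minimal, self-contained sketch of this pop/delegate/restore
pattern follows; BaseTrial and AdaptDLLikeTrial are hypothetical stand-ins,
and the try/finally is an extra safeguard the patch itself does not add:

import pickle


class BaseTrial:
    # Stand-in for ray.tune.trial.Trial: builds the picklable state.
    def __getstate__(self):
        return self.__dict__.copy()


class AdaptDLLikeTrial(BaseTrial):
    # Toy subclass showing the pop / delegate / restore pattern.
    def __init__(self):
        self.trial_id = "abc123"
        self._trial_in_use = None             # stands in for an unpicklable handle
        self._cached_metrics = {"loss": 0.1}

    def __getstate__(self):
        saved = {}
        # Temporarily drop members the base class must not serialize.
        for k in ("_trial_in_use", "_cached_metrics"):
            saved[k] = self.__dict__.pop(k)
        try:
            state = super().__getstate__()
        finally:
            # Put them back so the live object is left unchanged.
            self.__dict__.update(saved)
        return state


trial = AdaptDLLikeTrial()
restored = pickle.loads(pickle.dumps(trial))
assert "_cached_metrics" not in restored.__dict__   # dropped from the pickle
assert trial._cached_metrics == {"loss": 0.1}       # live object untouched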
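
The scheduler change tracks Ray 1.9 moving has_resources_for_trial from
TrialRunner onto its trial_executor, which is why choose_trial_to_run now
calls trial_runner.trial_executor.has_resources_for_trial(trial). If a
caller had to support both old and new Ray, a hypothetical shim (not in
this patch) could dispatch on whichever attribute exists:

def has_resources_for_trial(trial_runner, trial):
    # Ray 1.9 moved has_resources_for_trial from TrialRunner to its
    # trial_executor; older releases expose it on the runner itself.
    executor = getattr(trial_runner, "trial_executor", None)
    if executor is not None and hasattr(executor, "has_resources_for_trial"):
        return executor.has_resources_for_trial(trial)
    return trial_runner.has_resources_for_trial(trial)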