-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fleetrun launch in legacy mode #40568
Changes from all commits
1d5b3ae
c7da5b5
41c3854
591c5f8
127114f
be9e1da
66086d8
c7db6dd
6b71984
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from paddle.distributed.launch import plugins | ||
|
||
from .node import Node | ||
from .status import Status | ||
from .args_envs import parse_args, fetch_envs, env_args_mapping | ||
|
||
import logging | ||
|
||
|
||
class Context(object):
    """Shared state for one launch: arguments, environment, logger and node.

    The context parses the command line, snapshots the process environment,
    lets selected environment variables override the parsed arguments, and
    then optionally hands itself to every enabled plugin.
    """

    def __init__(self, enable_plugin=True):
        """Build the context from the command line and the environment.

        Args:
            enable_plugin: when true, every callable in
                ``plugins.enabled_plugins`` is invoked with this context
                after the rest of the state has been initialised.
        """
        self.args, self.unknown_args = parse_args()
        self.envs = fetch_envs()

        # Environment overrides must be applied before the logger is built,
        # otherwise a log level supplied through the environment would be
        # written to ``args.log_level`` only after the level was already set.
        self.set_env_in_args()

        self.logger = self.get_logger()

        self.node = Node()
        self.status = Status()

        # design for event queue, later
        self.events = []

        if enable_plugin:
            self._enable_plugin()

    def is_legacy_mode(self):
        """Return True when the job must be run by the legacy launcher.

        Legacy mode is selected when any of the following holds, checked in
        this order:

        * ``--legacy`` is set;
        * the parser left unrecognised arguments behind;
        * one of the deprecated environment variables is present.

        Otherwise the new launcher is used and False is returned.
        """
        if self.args.legacy:
            return True

        if self.unknown_args:
            self.logger.warning("Compatible mode enable with args %s",
                                self.unknown_args)
            return True

        legacy_env_list = [
            'DISTRIBUTED_TRAINER_ENDPOINTS',
            'PADDLE_ELASTIC_JOB_ID',
            'PADDLE_DISTRI_BACKEND',
            'FLAGS_START_PORT',
        ]

        for env in legacy_env_list:
            if env in self.envs:
                self.logger.warning(
                    "ENV %s is deprecated, legacy launch enable", env)
                return True

        # Nothing requested the legacy path (whether or not --master is
        # given), so the new launcher handles the job.
        return False

    def get_envs(self):
        """Return a copy of the environment snapshot taken at construction."""
        return self.envs.copy()

    def _enable_plugin(self):
        """Call every enabled plugin with this context as its only argument."""
        for pl in plugins.enabled_plugins:
            pl(self)

    def get_logger(self, level=logging.INFO):
        """Return the shared ``"LAUNCH"`` logger, configured for this context.

        Args:
            level: fallback level used when ``args.log_level`` is empty or
                None.

        Returns:
            The ``logging.Logger`` named ``"LAUNCH"``.
        """
        logger = logging.getLogger("LAUNCH")

        # Tolerate a missing/empty log level instead of failing on .upper().
        level_name = (self.args.log_level or "").upper()
        logger.setLevel(level_name or level)

        # logging.getLogger returns the same object on every call, so only
        # attach the stream handler once; otherwise each new Context (or
        # repeated call) would duplicate every log line.
        if not logger.handlers:
            formatter = logging.Formatter(
                fmt='%(name)s %(levelname)s %(asctime)s %(message)s')
            ch = logging.StreamHandler()
            ch.setFormatter(formatter)
            logger.addHandler(ch)
        return logger

    def set_env_in_args(self):
        """Override parsed arguments with values found in the environment.

        For every ``ENV_NAME -> arg_name`` pair in ``env_args_mapping`` whose
        variable is present in ``self.envs``, the attribute ``arg_name`` on
        ``self.args`` is replaced by the environment value.
        """
        # NOTE(review): values are copied as-is from the environment snapshot,
        # with no conversion to the argument's declared type - confirm that
        # consumers of the overridden arguments accept that.
        for k, v in env_args_mapping.items():
            if k in self.envs:
                setattr(self.args, v, self.envs[k])
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
from argparse import ArgumentParser, REMAINDER | ||
|
||
# Maps an environment variable name to the name of the parsed-argument
# attribute that it overrides.
env_args_mapping = {
    # base parameters
    "POD_IP": "host",
    "PADDLE_MASTER": "master",
    "PADDLE_DEVICES": "devices",
    "PADDLE_NNODES": "nnodes",
    "PADDLE_MODE": "mode",
    "PADDLE_LOG_LEVEL": "log_level",
    "PADDLE_NPROC_PER_NODE": "nproc_per_node",
    "PADDLE_JOB_ID": "job_id",
    "PADDLE_RANK": "rank",
    "PADDLE_LOG_DIR": "log_dir",
    # elastic parameters
    "PADDLE_MAX_RESTART": "max_restart",
    "PADDLE_ELASTIC_LEVEL": "elastic_level",
    "PADDLE_ELASTIC_TIMEOUT": "elastic_timeout",
    # parameter-server parameters
    "PADDLE_SERVER_NUM": "server_num",
    "PADDLE_TRAINER_NUM": "trainer_num",
    "PADDLE_SERVERS_ENDPOINTS": "servers",
    "PADDLE_TRAINERS_ENDPOINTS": "trainers",
    "PADDLE_GLOO_PORT": "gloo_port",
    "PADDLE_WITH_GLOO": "with_gloo",
}
|
||
|
||
def fetch_envs():
    """Drop the proxy variables from the process environment and snapshot it.

    ``http_proxy`` and ``https_proxy`` are removed from ``os.environ`` itself
    (a missing variable is ignored), so the removal is visible to the whole
    process and not only to the returned snapshot.

    Returns:
        A plain ``dict`` copy of the remaining environment variables.
    """
    for proxy_var in ('http_proxy', 'https_proxy'):
        os.environ.pop(proxy_var, None)

    return dict(os.environ)
|
||
|
||
def _str2bool(value):
    """Convert a command-line string such as ``"true"`` or ``"0"`` to a bool.

    ``type=bool`` is not usable with argparse because ``bool("False")`` is
    True: every non-empty string is truthy.

    Raises:
        ValueError: if the text is not a recognised boolean spelling;
            argparse reports this as an invalid argument value.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise ValueError("not a boolean value: {!r}".format(value))


def parse_args():
    """Parse the launcher command line.

    Unrecognised options are tolerated rather than rejected.

    Returns:
        A ``(namespace, unknown_args)`` tuple as produced by
        ``ArgumentParser.parse_known_args``: the parsed arguments and the
        list of argument strings the parser did not recognise.
    """
    parser = ArgumentParser()

    base_group = parser.add_argument_group("Base Parameters")

    base_group.add_argument(
        "--master",
        type=str,
        default=None,
        help="the master/rendezvous server, ip:port")

    # Review feedback on this option: explain it further, and state that it
    # is meant for internal debugging only - external developers do not need
    # to care about the difference between the two launch modes.
    base_group.add_argument(
        "--legacy",
        type=_str2bool,
        default=False,
        help="use legacy launch (for internal debugging only)")

    base_group.add_argument(
        "--rank", type=int, default=-1, help="the peer rank")

    base_group.add_argument(
        "--log_level", type=str, default="INFO", help="log level. Default INFO")

    base_group.add_argument(
        "--nnodes",
        type=str,
        default="1",
        help="the number of peers, i.e. pod/node number")

    base_group.add_argument(
        "--nproc_per_node",
        type=int,
        default=None,
        help="the number of processes in a pod")

    base_group.add_argument(
        "--log_dir",
        type=str,
        default="log",
        help="the path for each process's log. Default ./log")
    base_group.add_argument(
        "--mode",
        type=str,
        default="collective",
        help="run mode of the job, collective/ps/ps-heter")

    base_group.add_argument(
        "--job_id",
        type=str,
        default="default",
        help="unique id of the job. Default default")

    base_group.add_argument(
        "--devices",
        type=str,
        default=None,
        help="accelerate devices. as --gpus,npus,xps")

    base_group.add_argument("--host", type=str, default=None, help="host ip")

    base_group.add_argument(
        "training_script",
        type=str,
        help="the full path of py script,"
        "followed by arguments for the "
        "training script")

    # Everything after the script path is collected verbatim for the script.
    base_group.add_argument('training_script_args', nargs=REMAINDER)

    ps_group = parser.add_argument_group("Parameter-Server Parameters")
    # for parameter server
    ps_group.add_argument(
        "--servers", type=str, default='', help="servers endpoints full list")
    ps_group.add_argument(
        "--trainers", type=str, default='', help="trainers endpoints full list")

    ps_group.add_argument(
        "--trainer_num", type=int, default=None, help="number of trainers")
    ps_group.add_argument(
        "--server_num", type=int, default=None, help="number of servers")
    ps_group.add_argument(
        "--gloo_port", type=int, default=6767, help="gloo http port")
    ps_group.add_argument(
        "--with_gloo", type=str, default="0", help="use gloo or not")

    # parameter elastic mode
    elastic_group = parser.add_argument_group("Elastic Parameters")
    elastic_group.add_argument(
        "--max_restart",
        type=int,
        default=3,
        help="the times can restart. Default 3")

    elastic_group.add_argument(
        "--elastic_level",
        type=int,
        default=-1,
        help="elastic level: -1 disable, 0 failed exit, peers hold, 1 internal restart"
    )

    elastic_group.add_argument(
        "--elastic_timeout",
        type=int,
        default=30,
        help="seconds to wait before elastic perform training")

    return parser.parse_known_args()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如果是在平台上通过环境变量配置的,对应的是新版本的逻辑还是老版本的逻辑?可以再 check 一下是否有兼容性问题。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
新版本统一 --devices, 否则使用老版本