Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New api about checkpoint and models #10878

Merged
merged 39 commits into from
Jun 10, 2018

Conversation

seiriosPlus
Copy link
Collaborator

@seiriosPlus seiriosPlus commented May 23, 2018

  1. add Checkpoint handle and config in Train.py
  2. update load_model/save_model API
    Related update fluid Train API param_path to checkpoint_config #10828
    Related Incremental Learning Support for Fluid with Distribution #10870

@seiriosPlus seiriosPlus changed the title [WIP] New api about cpkt/params New api about checkpoint and models May 29, 2018
@seiriosPlus seiriosPlus requested a review from Yancey1989 May 30, 2018 06:26
checkpoint_dir = os.getcwd()
raise ValueError("The values of 'checkpoint_dir' should not be None")

if trainer_args and not isinstance(trainer_args, dict):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what're the details about trainer_args? Shall we need a class instead of a parameter with dict type, it's confusing for users.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

trainer_args not for the user, just for the developers, currently, trainer_args only contains step_id, epoch_id, there maybe have more arguments need to be saved in the checkpoint.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use the clear parameters, it would be confused with other developers, use step_id and epoch_id or a Class as the configuration parameter.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I write two functions to make it more clearly.

save_trainer_args(cur_dir, trainer_id, trainer_args)

if is_chief:
save_persist_vars_without_grad(executor, cur_dir, main_program)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks all gradient vars are all not persistent, so maybe the function name would shorter for save_persistent_vars? BTW, persist is a verb, we need the adjective one: persistent .

Copy link
Collaborator Author

@seiriosPlus seiriosPlus Jun 4, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First, I find arguments named "X@GRAD" are persistent .
Second, save_persistent_vars do not filter RAW arguments.

if "@GRAD" in var.name:
return False

if ".trainer_" in var.name:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add some comments to explain what's the meaning of the hard code .blcok and .trainer_ ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

_lru_delete(checkpoint_dir, max_num_checkpoints)


def load_checkpoint(executor, checkpoint_dir=None, main_program=None):
def need_load_checkpoint(checkpoint_dir):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm confusing about this function, for need to do ... means the function should return a boolean variable, so that we can use the function as:

if need_load_checkpoint(xxx):
    # resume from the checkpoint
else:
    # train from the beginning

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I renamed need_load_checkpoint to get_latest_checkpoint_serial, make it more meaningful.

"""
if checkpoint_dir is None:
checkpoint_dir = os.getcwd()
raise ValueError("The values of 'checkpoint_dir' should not be None")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

checkpoint_dir should not be None.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

checkpoint_dir should be checked in Trainer.py, A property directory will be given by Trainer.py.

_lru_delete(checkpoint_dir, max_num_checkpoints)


def load_checkpoint(executor, checkpoint_dir=None, main_program=None):
def get_latest_checkpoint_serial(checkpoint_dir):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems we don't need get_latest_checkpoint_serial , do the same thing with _get_latest_checkpoint_dir, or we can rename _get_latest_checkpoint_dir to get_latest_checkpoint_serial.

:param main_program
"""

if checkpoint_dir is None:
checkpoint_dir = os.getcwd()
raise ValueError(
"The values of 'checkpoint_dir' or 'serial' should not be None")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"The values of 'checkpoint_dir' or 'serial' should not be None")

Seems here only check checkpoint_dir ?

if serial < 0:
return
if main_program is None:
raise ValueError("The values of 'main_program'should not be None")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The values of 'main_program'should not be None

main_program should not be None.

load_persist_vars_without_grad will load variables from a directory by an executor,
the variable named end with "@GRAD" will not be loaded.

:param executor
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add the details comments about the parameters.

@@ -193,14 +253,18 @@ def _dist_transpile_if_necessary(self, optimize_ops, params_grads):
current_endpoint = os.getenv("PADDLE_CURRENT_IP", "") + ":" + port
# the unique trainer id, starting from 0, needed by trainer
# only
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
self.chief = self.trainer_id == 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why hard code trainer_id and chief ere equal to 0 ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will default initialize trainer_id =0 and chief = True as default.
If run PaddlePaddle as local, there is only one trainer, its trainer_id is 0, and it is the chief obviously.
If run PaddlePaddle as distribution, we will get PADDLE_TRAINER_ID from env, there will only have one trainer as the chief.

checkpoint_dir = os.getcwd()
raise ValueError("The values of 'checkpoint_dir' should not be None")

if trainer_args and not isinstance(trainer_args, dict):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use the clear parameters, it would be confused with other developers, use step_id and epoch_id or a Class as the configuration parameter.

Copy link
Contributor

@Yancey1989 Yancey1989 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a unit test for the checkpoint feature.

Copy link
Contributor

@Yancey1989 Yancey1989 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have some comments above, please follow them thanks.

if not os.path.isdir(trainer_dir):
os.makedirs(trainer_dir)

return trainer_dir


def _lru_delete(dirname, max_num_checkpoints=3):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems this function does not do implement a real LRU algorithms, scroll_delete would be better.

int(serial)
except ValueError:
serial = _get_dir_serial(cur_dir)
if serial == -1:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please merge the two condition statement.

@@ -348,6 +423,41 @@ def _get_or_create_parallel_executor(self):
loss_name=self.train_func_outputs[0].name)
return self._get_parallel_executor()

def _clean_checkpoint(self):
if not self.checkpoint:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use assert instead of return directly, otherwise we don't know whether this function success.

return trainer_args

def _save_checkpoint(self, epoch_id, step_id):
if not self.checkpoint:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same reason, use assert please.

@@ -473,79 +478,143 @@ def save_checkpoint(executor,

:param executor
:param checkpoint_dir
:param max_num_checkpoints
:param save_interval_secs
:param trainer_id
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need more details comments.


class TestCheckpoint(unittest.TestCase):
def setUp(self):
self.dirname = "/tmp/ckpt"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

better to use tempfile instead of the hard code path.

Copy link
Contributor

@Yancey1989 Yancey1989 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, this feature is related to the API for user, please @typhoonzero double check.

:param executor executor for save the value
:param checkpoint_dir the checkpoint directory
:param trainer_id currect trainer id
:param is_chief if the trainer id equals 0, the is_chief will be true
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If have is_chief why still need to pass trainer_id?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

each trainer need to save its arguments practicality.
Only chief need to save variables.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have deleted code about chief

def _get_serial_dir(serial, checkpoint_dir):
serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
return os.path.join(checkpoint_dir, serial_folder)
def load_persist_vars_without_grad(executor,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this is needed?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

write load_persist_vars_without_grad just because of the filter.
I need to write a new filter to filter variables.

@typhoonzero
Copy link
Contributor

The test seems didn't pass?

@seiriosPlus
Copy link
Collaborator Author

I will check it.

@seiriosPlus seiriosPlus merged commit d896134 into PaddlePaddle:develop Jun 10, 2018
@seiriosPlus seiriosPlus deleted the new_api_about_cpkt branch June 10, 2018 09:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants