Refine savable #8758

Merged · 2 commits · Jul 19, 2024

Changes from all commits
17 changes: 9 additions & 8 deletions paddlenlp/transformers/configuration_utils.py
@@ -297,7 +297,7 @@
        return ret

    @classmethod
-    def _get_nonsavable_keys(cls):
+    def _get_unsavable_keys(cls):
        ret = set()
        for attrs in [
            cls.op_fusion_attributes,
@@ -516,7 +516,7 @@
    _auto_class: Optional[str] = None

    # Fix me, it is global for all config
-    _nonsavable_keys = set()
+    _unsavable_keys = set()

    def __setattr__(self, key, value):
        if key in super().__getattribute__("attribute_map"):
@@ -542,7 +542,8 @@
        kwargs = attribute_map(self, kwargs=kwargs)
        kwargs.pop("transformers_version", None)
        llm_meta = LlmMetaConfig._get_defaults()
-        self._nonsavable_keys.update(LlmMetaConfig._get_nonsavable_keys())
+        self._unsavable_keys.update(LlmMetaConfig._get_unsavable_keys())
+        self._unsavable_keys.remove("tensor_parallel_degree")

        kwargs = set_expected_keys(self, llm_meta, kwargs)
        if self.sequence_parallel:
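
For context on the two set operations introduced above, here is a standalone sketch of their semantics (plain Python, not part of the patch; the example key names mirror the ones used in this PR and in the test below):

unsavable_keys = set()
unsavable_keys.update({"use_fused_rms_norm", "tensor_parallel_degree"})  # e.g. defaults merged from LlmMetaConfig
unsavable_keys.remove("tensor_parallel_degree")  # raises KeyError if the key is absent; set.discard() would not
print(unsavable_keys)  # {'use_fused_rms_norm'}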
@@ -1011,14 +1012,14 @@

        return serializable_config_dict

-    def register_nonsaveable_keys(self, keys):
+    def register_unsavable_keys(self, keys):
        # Save: do not save it in any case
        # Print: show it if it has a non-default value
        if type(keys) == list or type(keys) == tuple:
            for key in keys:
-                self._nonsavable_keys.add(key)
+                self._unsavable_keys.add(key)
        else:
-            self._nonsavable_keys.add(keys)
+            self._unsavable_keys.add(keys)

Codecov / codecov/patch annotation on paddlenlp/transformers/configuration_utils.py#L1022: added line #L1022 was not covered by tests.

    def to_dict(self, saving_file=False) -> Dict[str, Any]:
        """
@@ -1047,9 +1048,9 @@
                output[key] = value

        # Fix for rewritten from_pretrained method (hasattr check)
-        if saving_file and hasattr(self, "_nonsavable_keys"):
+        if saving_file and hasattr(self, "_unsavable_keys"):
            for key in list(output.keys()):
-                if key in self._nonsavable_keys:
+                if key in self._unsavable_keys:
                    output.pop(key)

        if hasattr(self, "quantization_config"):
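A minimal usage sketch of the renamed API, assuming a PretrainedConfig subclass with default kwargs works out of the box; ToyConfig and run_only_flag are hypothetical names used only for illustration:

from paddlenlp.transformers.configuration_utils import PretrainedConfig

class ToyConfig(PretrainedConfig):  # hypothetical subclass, illustration only
    model_type = "toy"

    def __init__(self, hidden_size=8, **kwargs):
        self.hidden_size = hidden_size
        super().__init__(**kwargs)

cfg = ToyConfig()
cfg.run_only_flag = True                        # hypothetical runtime-only attribute
cfg.register_unsavable_keys(["run_only_flag"])  # renamed from register_nonsaveable_keys

# Unsavable keys stay visible in the default dict view but should be
# dropped when serializing for saving (the to_dict(saving_file=True) path).
assert "run_only_flag" in cfg.to_dict()
assert "run_only_flag" not in cfg.to_dict(saving_file=True)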
4 changes: 2 additions & 2 deletions tests/transformers/test_configuration_utils.py
@@ -96,7 +96,7 @@ def test_model_config_save(self):

        config.test_nonsave = "test"
        config.test_nonsave_2 = "test"
-        config.register_nonsaveable_keys(["test_nonsave"])
+        config.register_unsavable_keys(["test_nonsave"])

        with tempfile.TemporaryDirectory() as tp:
            config.save_pretrained(tp)
@@ -105,7 +105,7 @@ def test_model_config_save(self):
            loaded_config = json.load(open(os.path.join(tp, "config.json"), "r"))
            assert "fuse_attention_qkv" in loaded_config, "fuse_attention_qkv needs to be saved"
            assert "use_fused_rms_norm" not in loaded_config, "use_fused_rms_norm does not need to be saved"
-            assert "tensor_parallel_degree" not in loaded_config, "tensor_parallel_degree don't need to save"
+            assert "tensor_parallel_degree" in loaded_config, "tensor_parallel_degree needs to be saved"
            assert "paddlenlp_version" in loaded_config, "always save paddlenlp_version"
            assert (
                "quantization_config" in loaded_config and "quant_type" in loaded_config["quantization_config"]
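Finally, a hedged sketch of the round-trip behavior the updated test asserts, reusing the hypothetical ToyConfig from the sketch above; which keys end up in config.json depends on the concrete config class, so treat the assertions as expectations rather than guarantees:

import json
import os
import tempfile

cfg = ToyConfig()  # hypothetical config class from the earlier sketch
cfg.test_nonsave = "test"
cfg.register_unsavable_keys(["test_nonsave"])

with tempfile.TemporaryDirectory() as tp:
    cfg.save_pretrained(tp)
    with open(os.path.join(tp, "config.json"), "r") as f:
        loaded = json.load(f)

# After this change, tensor_parallel_degree is expected to be persisted again,
# while keys registered as unsavable should still be filtered out of config.json.
assert "tensor_parallel_degree" in loaded
assert "test_nonsave" not in loaded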