From 724e767d594e81dd401017e15d6798fbe236329e Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Tue, 30 Jul 2024 12:51:01 +0200 Subject: [PATCH] Fix push_to_hub by not calling create_branch if branch exists (#7069) * Fix push_to_hub by not calling create_branch if branch exists * Fix push_to_hub by not calling create_branch if branch exists * Reword comment * Fix push_to_hub by not calling create_branch if PR ref * Update test --- src/datasets/arrow_dataset.py | 3 ++- src/datasets/dataset_dict.py | 3 ++- tests/test_hub.py | 5 ++--- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 7ba052d3fde5..25b27d091a3c 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -5677,7 +5677,8 @@ def push_to_hub( ) repo_id = repo_url.repo_id - if revision is not None: + if revision is not None and not revision.startswith("refs/pr/"): + # We do not call create_branch for a PR reference: 400 Bad Request api.create_branch(repo_id, branch=revision, token=token, repo_type="dataset", exist_ok=True) if not data_dir: diff --git a/src/datasets/dataset_dict.py b/src/datasets/dataset_dict.py index 5d2d9dcd9ffe..cf4a6cc98f8b 100644 --- a/src/datasets/dataset_dict.py +++ b/src/datasets/dataset_dict.py @@ -1708,7 +1708,8 @@ def push_to_hub( ) repo_id = repo_url.repo_id - if revision is not None: + if revision is not None and not revision.startswith("refs/pr/"): + # We do not call create_branch for a PR reference: 400 Bad Request api.create_branch(repo_id, branch=revision, token=token, repo_type="dataset", exist_ok=True) if not data_dir: diff --git a/tests/test_hub.py b/tests/test_hub.py index ab766d017794..9485fe83a71c 100644 --- a/tests/test_hub.py +++ b/tests/test_hub.py @@ -66,9 +66,8 @@ def test_convert_to_parquet(temporary_repo, hf_api, hf_token, ci_hub_config, ci_ _ = convert_to_parquet(repo_id, token=hf_token, trust_remote_code=True) # mock_create_branch assert mock_create_branch.called - assert mock_create_branch.call_count == 2 - for call_args, expected_branch in zip(mock_create_branch.call_args_list, ["refs/pr/1", "script"]): - assert call_args.kwargs.get("branch") == expected_branch + assert mock_create_branch.call_count == 1 + assert mock_create_branch.call_args.kwargs.get("branch") == "script" # mock_create_commit assert mock_create_commit.called assert mock_create_commit.call_count == 2