huggingface#3111 Set features correctly when concatenating.
fr.branchaud-charron committed Oct 19, 2021
1 parent adc5cec commit 835b1ea
Showing 2 changed files with 23 additions and 0 deletions.
src/datasets/arrow_dataset.py (10 additions, 0 deletions)
@@ -3643,6 +3643,15 @@ def concatenate_datasets(
         format = {}
         logger.info("Some of the datasets have disparate format. Resetting the format of the concatenated dataset.")

+    # Find column types.
+    if axis == 1:
+        features_d = {}
+        for dset in dsets:
+            features_d.update(dset.features)
+        features = Features(features_d)
+    else:
+        features = dsets[0].features
+
     # Concatenate tables
     tables_to_concat = [dset._data for dset in dsets if len(dset._data) > 0]
     # There might be no table with data left hence return first empty table
@@ -3702,6 +3711,7 @@ def apply_offset_to_indices_table(table, offset):
         fingerprint=fingerprint,
     )
     concatenated_dataset.set_format(**format)
+    concatenated_dataset = concatenated_dataset.cast(features)
     return concatenated_dataset


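For context, the effect of the merge-features-then-cast step above can be seen end to end when concatenating column-wise. Below is a minimal usage sketch (not part of this commit), assuming a datasets build that includes this change; the dataset names and contents are illustrative:

```python
from datasets import ClassLabel, Dataset, Features, Value, concatenate_datasets

# Two datasets with disjoint columns; both label-like columns use ClassLabel.
left = Dataset.from_dict(
    {"label": [0, 1], "text": ["a", "b"]},
    features=Features({"label": ClassLabel(names=["NEG", "POS"]), "text": Value("string")}),
)
right = Dataset.from_dict(
    {"extra": [1, 0]},
    features=Features({"extra": ClassLabel(names=["NEG", "POS"])}),
)

# Column-wise concatenation: with this fix the result is cast to the combined
# Features, so "label" and "extra" stay ClassLabel instead of degrading to
# plain integer columns.
combined = concatenate_datasets([left, right], axis=1)
print(combined.features)
```

The new test below exercises the same property via Dataset.from_dict.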
tests/test_arrow_dataset.py (13 additions, 0 deletions)
@@ -2242,6 +2242,19 @@ def test_concatenate_datasets_duplicate_columns(dataset):
     assert "duplicated" in str(excinfo.value)


+def test_concatenate_datasets_column_typing(dataset):
+    data = {"label": [0, 1, 1, 0], "col_1": ["a", "b", "c", "d"]}
+    data_2 = {"col2": [0, 1, 1, 0]}
+
+    features = Features({"label": ClassLabel(2, names=["POS", "NEG"]), "col_1": Value("string")})
+    features_2 = Features({"col2": ClassLabel(2, names=["POS", "NEG"])})
+    with Dataset.from_dict(data, features=features, info=DatasetInfo(features=features)) as dset:
+        with Dataset.from_dict(data_2, features=features_2, info=DatasetInfo(features=features_2)) as dset2:
+            concatenated = concatenate_datasets([dset, dset2], axis=1)
+            assert isinstance(concatenated.features["label"], ClassLabel)
+            assert isinstance(concatenated.features["col2"], ClassLabel)
+
+
 def test_interleave_datasets():
     d1 = Dataset.from_dict({"a": [0, 1, 2]})
     d2 = Dataset.from_dict({"a": [10, 11, 12, 13]})
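The assertions in the new test hinge on the feature-merging step added in arrow_dataset.py. Isolated from the rest of concatenate_datasets, that step amounts to roughly the following standalone sketch (not the library code itself; variable names are illustrative):

```python
from datasets import ClassLabel, Features, Value

# Features of two datasets being concatenated along axis=1.
features_left = Features({"label": ClassLabel(names=["POS", "NEG"]), "col_1": Value("string")})
features_right = Features({"col2": ClassLabel(names=["POS", "NEG"])})

# Same merge as in the patch: collect the feature mappings of all inputs into
# one dict, then wrap it back into a Features object that the concatenated
# dataset is cast to.
merged = {}
for feats in (features_left, features_right):
    merged.update(feats)
features = Features(merged)

assert isinstance(features["col2"], ClassLabel)  # type information is preserved
```

Duplicate column names are not a concern at this point, since concatenate_datasets already rejects them for axis=1, as the existing test_concatenate_datasets_duplicate_columns shown in the diff context checks.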

1 comment on commit 835b1ea

@github-actions


Show benchmarks

PyArrow==3.0.0


Benchmark: benchmark_array_xd.json

metric | new / old (diff)
read_batch_formatted_as_numpy after write_array2d | 0.008538 / 0.011353 (-0.002815)
read_batch_formatted_as_numpy after write_flattened_sequence | 0.003405 / 0.011008 (-0.007603)
read_batch_formatted_as_numpy after write_nested_sequence | 0.031259 / 0.038508 (-0.007249)
read_batch_unformated after write_array2d | 0.033922 / 0.023109 (0.010813)
read_batch_unformated after write_flattened_sequence | 0.297158 / 0.275898 (0.021260)
read_batch_unformated after write_nested_sequence | 0.325307 / 0.323480 (0.001827)
read_col_formatted_as_numpy after write_array2d | 0.007333 / 0.007986 (-0.000652)
read_col_formatted_as_numpy after write_flattened_sequence | 0.004498 / 0.004328 (0.000170)
read_col_formatted_as_numpy after write_nested_sequence | 0.008984 / 0.004250 (0.004734)
read_col_unformated after write_array2d | 0.040935 / 0.037052 (0.003883)
read_col_unformated after write_flattened_sequence | 0.295503 / 0.258489 (0.037014)
read_col_unformated after write_nested_sequence | 0.332709 / 0.293841 (0.038868)
read_formatted_as_numpy after write_array2d | 0.023378 / 0.128546 (-0.105168)
read_formatted_as_numpy after write_flattened_sequence | 0.008234 / 0.075646 (-0.067413)
read_formatted_as_numpy after write_nested_sequence | 0.257455 / 0.419271 (-0.161816)
read_unformated after write_array2d | 0.046830 / 0.043533 (0.003297)
read_unformated after write_flattened_sequence | 0.311200 / 0.255139 (0.056061)
read_unformated after write_nested_sequence | 0.320581 / 0.283200 (0.037381)
write_array2d | 0.078157 / 0.141683 (-0.063526)
write_flattened_sequence | 1.679493 / 1.452155 (0.227338)
write_nested_sequence | 1.814394 / 1.492716 (0.321678)

Benchmark: benchmark_getitem_100B.json

metric | new / old (diff)
get_batch_of_1024_random_rows | 0.200297 / 0.018006 (0.182290)
get_batch_of_1024_rows | 0.433885 / 0.000490 (0.433395)
get_first_row | 0.005603 / 0.000200 (0.005403)
get_last_row | 0.000101 / 0.000054 (0.000047)

Benchmark: benchmark_indices_mapping.json

metric | new / old (diff)
select | 0.036394 / 0.037411 (-0.001018)
shard | 0.022118 / 0.014526 (0.007592)
shuffle | 0.026620 / 0.176557 (-0.149937)
sort | 0.123798 / 0.737135 (-0.613337)
train_test_split | 0.027241 / 0.296338 (-0.269097)

Benchmark: benchmark_iterating.json

metric | new / old (diff)
read 5000 | 0.416804 / 0.215209 (0.201595)
read 50000 | 4.171330 / 2.077655 (2.093675)
read_batch 50000 10 | 1.824146 / 1.504120 (0.320026)
read_batch 50000 100 | 1.614394 / 1.541195 (0.073199)
read_batch 50000 1000 | 1.623346 / 1.468490 (0.154856)
read_formatted numpy 5000 | 0.377181 / 4.584777 (-4.207596)
read_formatted pandas 5000 | 4.727638 / 3.745712 (0.981926)
read_formatted tensorflow 5000 | 0.911893 / 5.269862 (-4.357969)
read_formatted torch 5000 | 0.841377 / 4.565676 (-3.724300)
read_formatted_batch numpy 5000 10 | 0.041171 / 0.424275 (-0.383104)
read_formatted_batch numpy 5000 1000 | 0.004647 / 0.007607 (-0.002960)
shuffled read 5000 | 0.520655 / 0.226044 (0.294611)
shuffled read 50000 | 5.191782 / 2.268929 (2.922853)
shuffled read_batch 50000 10 | 2.225355 / 55.444624 (-53.219269)
shuffled read_batch 50000 100 | 1.888148 / 6.876477 (-4.988329)
shuffled read_batch 50000 1000 | 1.877861 / 2.142072 (-0.264211)
shuffled read_formatted numpy 5000 | 0.485676 / 4.805227 (-4.319551)
shuffled read_formatted_batch numpy 5000 10 | 0.101251 / 6.500664 (-6.399413)
shuffled read_formatted_batch numpy 5000 1000 | 0.050171 / 0.075469 (-0.025299)

Benchmark: benchmark_map_filter.json

metric | new / old (diff)
filter | 1.555605 / 1.841788 (-0.286183)
map fast-tokenizer batched | 12.429013 / 8.074308 (4.354705)
map identity | 26.804268 / 10.191392 (16.612876)
map identity batched | 0.782427 / 0.680424 (0.102003)
map no-op batched | 0.537764 / 0.534201 (0.003563)
map no-op batched numpy | 0.228385 / 0.579283 (-0.350898)
map no-op batched pandas | 0.501420 / 0.434364 (0.067056)
map no-op batched pytorch | 0.177762 / 0.540337 (-0.362575)
map no-op batched tensorflow | 0.191300 / 1.386936 (-1.195636)

PyArrow==latest

Benchmark: benchmark_array_xd.json

metric | new / old (diff)
read_batch_formatted_as_numpy after write_array2d | 0.008713 / 0.011353 (-0.002640)
read_batch_formatted_as_numpy after write_flattened_sequence | 0.003233 / 0.011008 (-0.007775)
read_batch_formatted_as_numpy after write_nested_sequence | 0.031199 / 0.038508 (-0.007309)
read_batch_unformated after write_array2d | 0.033859 / 0.023109 (0.010749)
read_batch_unformated after write_flattened_sequence | 0.282841 / 0.275898 (0.006943)
read_batch_unformated after write_nested_sequence | 0.335397 / 0.323480 (0.011917)
read_col_formatted_as_numpy after write_array2d | 0.007576 / 0.007986 (-0.000409)
read_col_formatted_as_numpy after write_flattened_sequence | 0.004595 / 0.004328 (0.000266)
read_col_formatted_as_numpy after write_nested_sequence | 0.009037 / 0.004250 (0.004787)
read_col_unformated after write_array2d | 0.041553 / 0.037052 (0.004501)
read_col_unformated after write_flattened_sequence | 0.290293 / 0.258489 (0.031804)
read_col_unformated after write_nested_sequence | 0.342348 / 0.293841 (0.048507)
read_formatted_as_numpy after write_array2d | 0.023842 / 0.128546 (-0.104704)
read_formatted_as_numpy after write_flattened_sequence | 0.008149 / 0.075646 (-0.067498)
read_formatted_as_numpy after write_nested_sequence | 0.254479 / 0.419271 (-0.164792)
read_unformated after write_array2d | 0.046591 / 0.043533 (0.003058)
read_unformated after write_flattened_sequence | 0.287893 / 0.255139 (0.032754)
read_unformated after write_nested_sequence | 0.333378 / 0.283200 (0.050178)
write_array2d | 0.082445 / 0.141683 (-0.059238)
write_flattened_sequence | 1.714000 / 1.452155 (0.261845)
write_nested_sequence | 1.858271 / 1.492716 (0.365555)

Benchmark: benchmark_getitem_100B.json

metric | new / old (diff)
get_batch_of_1024_random_rows | 0.281968 / 0.018006 (0.263961)
get_batch_of_1024_rows | 0.434529 / 0.000490 (0.434039)
get_first_row | 0.038604 / 0.000200 (0.038404)
get_last_row | 0.000375 / 0.000054 (0.000320)

Benchmark: benchmark_indices_mapping.json

metric | new / old (diff)
select | 0.035672 / 0.037411 (-0.001740)
shard | 0.021083 / 0.014526 (0.006557)
shuffle | 0.025326 / 0.176557 (-0.151231)
sort | 0.122373 / 0.737135 (-0.614763)
train_test_split | 0.027349 / 0.296338 (-0.268990)

Benchmark: benchmark_iterating.json

metric | new / old (diff)
read 5000 | 0.410682 / 0.215209 (0.195473)
read 50000 | 4.100667 / 2.077655 (2.023012)
read_batch 50000 10 | 1.739748 / 1.504120 (0.235628)
read_batch 50000 100 | 1.554251 / 1.541195 (0.013056)
read_batch 50000 1000 | 1.555051 / 1.468490 (0.086561)
read_formatted numpy 5000 | 0.376320 / 4.584777 (-4.208457)
read_formatted pandas 5000 | 4.758211 / 3.745712 (1.012499)
read_formatted tensorflow 5000 | 0.934728 / 5.269862 (-4.335133)
read_formatted torch 5000 | 0.863952 / 4.565676 (-3.701725)
read_formatted_batch numpy 5000 10 | 0.041466 / 0.424275 (-0.382810)
read_formatted_batch numpy 5000 1000 | 0.004715 / 0.007607 (-0.002892)
shuffled read 5000 | 0.511994 / 0.226044 (0.285949)
shuffled read 50000 | 5.129285 / 2.268929 (2.860356)
shuffled read_batch 50000 10 | 2.210843 / 55.444624 (-53.233782)
shuffled read_batch 50000 100 | 1.828930 / 6.876477 (-5.047546)
shuffled read_batch 50000 1000 | 1.819538 / 2.142072 (-0.322534)
shuffled read_formatted numpy 5000 | 0.484991 / 4.805227 (-4.320236)
shuffled read_formatted_batch numpy 5000 10 | 0.101811 / 6.500664 (-6.398853)
shuffled read_formatted_batch numpy 5000 1000 | 0.051916 / 0.075469 (-0.023553)

Benchmark: benchmark_map_filter.json

metric | new / old (diff)
filter | 1.507347 / 1.841788 (-0.334441)
map fast-tokenizer batched | 12.498933 / 8.074308 (4.424625)
map identity | 27.288297 / 10.191392 (17.096905)
map identity batched | 0.677019 / 0.680424 (-0.003405)
map no-op batched | 0.510929 / 0.534201 (-0.023272)
map no-op batched numpy | 0.225680 / 0.579283 (-0.353603)
map no-op batched pandas | 0.503333 / 0.434364 (0.068969)
map no-op batched pytorch | 0.183350 / 0.540337 (-0.356987)
map no-op batched tensorflow | 0.194135 / 1.386936 (-1.192801)

