From 8affbf7fe09b71617d5613fa2d60887ed7c7bad7 Mon Sep 17 00:00:00 2001 From: Stepan Bagritsevich Date: Fri, 29 Nov 2024 19:02:16 +0400 Subject: [PATCH] fix(search_family): Remove the output of extra fields in the FT.AGGREGATE command fixes dragonflydb#4230 Signed-off-by: Stepan Bagritsevich --- src/server/search/aggregator.cc | 47 +++++++++++++++++-------- src/server/search/aggregator.h | 17 ++++++--- src/server/search/search_family.cc | 31 +++++++++++----- src/server/search/search_family_test.cc | 41 ++++++++++++++++----- 4 files changed, 101 insertions(+), 35 deletions(-) diff --git a/src/server/search/aggregator.cc b/src/server/search/aggregator.cc index 255d82e10857..4b6b4a5620cf 100644 --- a/src/server/search/aggregator.cc +++ b/src/server/search/aggregator.cc @@ -11,10 +11,10 @@ namespace dfly::aggregate { namespace { struct GroupStep { - PipelineResult operator()(std::vector values) { + PipelineResult operator()(PipelineResult result) { // Separate items into groups absl::flat_hash_map, std::vector> groups; - for (auto& value : values) { + for (auto& value : result.values) { groups[Extract(value)].push_back(std::move(value)); } @@ -28,7 +28,18 @@ struct GroupStep { } out.push_back(std::move(doc)); } - return out; + + absl::flat_hash_set fields_to_print; + fields_to_print.reserve(fields_.size() + reducers_.size()); + + for (auto& field : fields_) { + fields_to_print.insert(std::move(field)); + } + for (auto& reducer : reducers_) { + fields_to_print.insert(std::move(reducer.result_field)); + } + + return {std::move(out), std::move(fields_to_print)}; } absl::FixedArray Extract(const DocValues& dv) { @@ -104,34 +115,42 @@ PipelineStep MakeGroupStep(absl::Span fields, } PipelineStep MakeSortStep(std::string_view field, bool descending) { - return [field = std::string(field), descending](std::vector values) -> PipelineResult { + return [field = std::string(field), descending](PipelineResult result) -> PipelineResult { + auto& values = result.values; + std::sort(values.begin(), values.end(), [field](const DocValues& l, const DocValues& r) { auto it1 = l.find(field); auto it2 = r.find(field); return it1 == l.end() || (it2 != r.end() && it1->second < it2->second); }); - if (descending) + + if (descending) { std::reverse(values.begin(), values.end()); - return values; + } + + result.fields_to_print.insert(field); + return result; }; } PipelineStep MakeLimitStep(size_t offset, size_t num) { - return [offset, num](std::vector values) -> PipelineResult { + return [offset, num](PipelineResult result) { + auto& values = result.values; values.erase(values.begin(), values.begin() + std::min(offset, values.size())); values.resize(std::min(num, values.size())); - return values; + return result; }; } -PipelineResult Process(std::vector values, absl::Span steps) { +PipelineResult Process(std::vector values, + absl::Span fields_to_print, + absl::Span steps) { + PipelineResult result{std::move(values), {fields_to_print.begin(), fields_to_print.end()}}; for (auto& step : steps) { - auto result = step(std::move(values)); - if (!result.has_value()) - return result; - values = std::move(result.value()); + PipelineResult step_result = step(std::move(result)); + result = std::move(step_result); } - return values; + return result; } } // namespace dfly::aggregate diff --git a/src/server/search/aggregator.h b/src/server/search/aggregator.h index 727c0ba96ed0..4f4008bce238 100644 --- a/src/server/search/aggregator.h +++ b/src/server/search/aggregator.h @@ -5,6 +5,7 @@ #pragma once #include +#include #include #include @@ -19,10 +20,16 @@ namespace dfly::aggregate { using Value = ::dfly::search::SortableValue; using DocValues = absl::flat_hash_map; // documents sent through the pipeline -// TODO: Replace DocValues with compact linear search map instead of hash map +struct PipelineResult { + // Values to be passed to the next step + // TODO: Replace DocValues with compact linear search map instead of hash map + std::vector values; -using PipelineResult = io::Result, facade::ErrorReply>; -using PipelineStep = std::function)>; // Group, Sort, etc. + // Fields from values to be printed + absl::flat_hash_set fields_to_print; +}; + +using PipelineStep = std::function; // Group, Sort, etc. // Iterator over Span that yields doc[field] or monostate if not present. // Extra clumsy for STL compatibility! @@ -82,6 +89,8 @@ PipelineStep MakeSortStep(std::string_view field, bool descending = false); PipelineStep MakeLimitStep(size_t offset, size_t num); // Process values with given steps -PipelineResult Process(std::vector values, absl::Span steps); +PipelineResult Process(std::vector values, + absl::Span fields_to_print, + absl::Span steps); } // namespace dfly::aggregate diff --git a/src/server/search/search_family.cc b/src/server/search/search_family.cc index f1151dc60a67..1b37d1943de2 100644 --- a/src/server/search/search_family.cc +++ b/src/server/search/search_family.cc @@ -980,22 +980,35 @@ void SearchFamily::FtAggregate(CmdArgList args, Transaction* tx, SinkReplyBuilde make_move_iterator(sub_results.end())); } - auto agg_results = aggregate::Process(std::move(values), params->steps); - if (!agg_results.has_value()) - return builder->SendError(agg_results.error()); + std::vector load_fields; + if (params->load_fields) { + load_fields.reserve(params->load_fields->size()); + for (const auto& field : params->load_fields.value()) { + load_fields.push_back(field.GetShortName()); + } + } + + auto agg_results = aggregate::Process(std::move(values), load_fields, params->steps); - size_t result_size = agg_results->size(); auto* rb = static_cast(builder); auto sortable_value_sender = SortableValueSender(rb); + const size_t result_size = agg_results.values.size(); rb->StartArray(result_size + 1); rb->SendLong(result_size); - for (const auto& result : agg_results.value()) { - rb->StartArray(result.size() * 2); - for (const auto& [k, v] : result) { - rb->SendBulkString(k); - std::visit(sortable_value_sender, v); + const size_t field_count = agg_results.fields_to_print.size(); + for (const auto& value : agg_results.values) { + rb->StartArray(field_count * 2); + for (const auto& field : agg_results.fields_to_print) { + rb->SendBulkString(field); + + auto it = value.find(field); + if (it != value.end()) { + std::visit(sortable_value_sender, it->second); + } else { + rb->SendNull(); + } } } } diff --git a/src/server/search/search_family_test.cc b/src/server/search/search_family_test.cc index 9fa68bd66757..fe89c412f27a 100644 --- a/src/server/search/search_family_test.cc +++ b/src/server/search/search_family_test.cc @@ -962,15 +962,12 @@ TEST_F(SearchFamilyTest, AggregateGroupBy) { EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("foo_total", "20", "word", "item2"), IsMap("foo_total", "50", "word", "item1"))); - /* - Temporary not supported - resp = Run({"ft.aggregate", "i1", "*", "LOAD", "2", "foo", "text", "GROUPBY", "2", "@word", - "@text", "REDUCE", "SUM", "1", "@foo", "AS", "foo_total"}); EXPECT_THAT(resp, - IsUnordArrayWithSize(IsMap("foo_total", "20", "word", ArgType(RespExpr::NIL), "text", "\"second - key\""), IsMap("foo_total", "40", "word", ArgType(RespExpr::NIL), "text", "\"third key\""), - IsMap({"foo_total", "10", "word", ArgType(RespExpr::NIL), "text", "\"first key"}))); - */ + "@text", "REDUCE", "SUM", "1", "@foo", "AS", "foo_total"}); + EXPECT_THAT(resp, IsUnordArrayWithSize( + IsMap("foo_total", "40", "word", "item1", "text", "\"third key\""), + IsMap("foo_total", "20", "word", "item2", "text", "\"second key\""), + IsMap("foo_total", "10", "word", "item1", "text", "\"first key\""))); } TEST_F(SearchFamilyTest, JsonAggregateGroupBy) { @@ -1632,4 +1629,32 @@ TEST_F(SearchFamilyTest, SearchLoadReturnHash) { EXPECT_THAT(resp, IsMapWithSize("h2", IsMap("a", "two"), "h1", IsMap("a", "one"))); } +// Test that FT.AGGREGATE prints only needed fields +TEST_F(SearchFamilyTest, AggregateResultFields) { + Run({"JSON.SET", "j1", ".", R"({"a":"1","b":"2","c":"3"})"}); + Run({"JSON.SET", "j2", ".", R"({"a":"4","b":"5","c":"6"})"}); + Run({"JSON.SET", "j3", ".", R"({"a":"7","b":"8","c":"9"})"}); + + auto resp = Run({"FT.CREATE", "index", "ON", "JSON", "SCHEMA", "$.a", "AS", "a", "TEXT", + "SORTABLE", "$.b", "AS", "b", "TEXT", "$.c", "AS", "c", "TEXT"}); + EXPECT_EQ(resp, "OK"); + + resp = Run({"FT.AGGREGATE", "index", "*"}); + EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap(), IsMap(), IsMap())); + + resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "a"}); + EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("a", "1"), IsMap("a", "4"), IsMap("a", "7"))); + + resp = Run({"FT.AGGREGATE", "index", "*", "LOAD", "1", "@b", "SORTBY", "1", "a"}); + EXPECT_THAT(resp, + IsUnordArrayWithSize(IsMap("b", "\"2\"", "a", "1"), IsMap("b", "\"5\"", "a", "4"), + IsMap("b", "\"8\"", "a", "7"))); + + resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "a", "GROUPBY", "2", "@b", "@a", + "REDUCE", "COUNT", "0", "AS", "count"}); + EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("b", "\"8\"", "a", "7", "count", "1"), + IsMap("b", "\"2\"", "a", "1", "count", "1"), + IsMap("b", "\"5\"", "a", "4", "count", "1"))); +} + } // namespace dfly