Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(search_family): Remove the output of extra fields in the FT.AGGREGATE command #4231

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 33 additions & 14 deletions src/server/search/aggregator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ namespace dfly::aggregate {
namespace {

struct GroupStep {
PipelineResult operator()(std::vector<DocValues> values) {
PipelineResult operator()(PipelineResult result) {
// Separate items into groups
absl::flat_hash_map<absl::FixedArray<Value>, std::vector<DocValues>> groups;
for (auto& value : values) {
for (auto& value : result.values) {
groups[Extract(value)].push_back(std::move(value));
}

Expand All @@ -28,7 +28,18 @@ struct GroupStep {
}
out.push_back(std::move(doc));
}
return out;

absl::flat_hash_set<std::string> fields_to_print;
fields_to_print.reserve(fields_.size() + reducers_.size());

for (auto& field : fields_) {
fields_to_print.insert(std::move(field));
}
for (auto& reducer : reducers_) {
fields_to_print.insert(std::move(reducer.result_field));
}

return {std::move(out), std::move(fields_to_print)};
}

absl::FixedArray<Value> Extract(const DocValues& dv) {
Expand Down Expand Up @@ -104,34 +115,42 @@ PipelineStep MakeGroupStep(absl::Span<const std::string_view> fields,
}

PipelineStep MakeSortStep(std::string_view field, bool descending) {
return [field = std::string(field), descending](std::vector<DocValues> values) -> PipelineResult {
return [field = std::string(field), descending](PipelineResult result) -> PipelineResult {
auto& values = result.values;

std::sort(values.begin(), values.end(), [field](const DocValues& l, const DocValues& r) {
auto it1 = l.find(field);
auto it2 = r.find(field);
return it1 == l.end() || (it2 != r.end() && it1->second < it2->second);
});
if (descending)

if (descending) {
std::reverse(values.begin(), values.end());
return values;
}

result.fields_to_print.insert(field);
return result;
};
}

PipelineStep MakeLimitStep(size_t offset, size_t num) {
return [offset, num](std::vector<DocValues> values) -> PipelineResult {
return [offset, num](PipelineResult result) {
auto& values = result.values;
values.erase(values.begin(), values.begin() + std::min(offset, values.size()));
values.resize(std::min(num, values.size()));
return values;
return result;
};
}

PipelineResult Process(std::vector<DocValues> values, absl::Span<const PipelineStep> steps) {
PipelineResult Process(std::vector<DocValues> values,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please put on your prod-duty hat and think about all the VLOG statements you would like to have if we hit problems in production with this code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

absl::Span<const std::string_view> fields_to_print,
absl::Span<const PipelineStep> steps) {
PipelineResult result{std::move(values), {fields_to_print.begin(), fields_to_print.end()}};
for (auto& step : steps) {
auto result = step(std::move(values));
if (!result.has_value())
return result;
values = std::move(result.value());
PipelineResult step_result = step(std::move(result));
result = std::move(step_result);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer to see the step method modify the result in place instead of all these moves

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean that we need to refactor this logic and have something like Aggregator that stores current result and has methods like DoSort, DoReduce? Or use pointer here

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about this question — are we going to refactor the steps so they modify the result in place instead of moving it every time?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I spoke about it today. I'm finalizing it

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you going to create another PR?

Copy link
Contributor Author

@BagritsevichStepan BagritsevichStepan Dec 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think that it would be better to create another PR

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same goes for the logs that Roman mentioned: we first need to refactor the code before adding the logs in the appropriate places

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}
return values;
return result;
}

} // namespace dfly::aggregate
17 changes: 13 additions & 4 deletions src/server/search/aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#pragma once

#include <absl/container/flat_hash_map.h>
#include <absl/container/flat_hash_set.h>
#include <absl/types/span.h>

#include <string>
Expand All @@ -19,10 +20,16 @@ namespace dfly::aggregate {
using Value = ::dfly::search::SortableValue;
using DocValues = absl::flat_hash_map<std::string, Value>; // documents sent through the pipeline

// TODO: Replace DocValues with compact linear search map instead of hash map
struct PipelineResult {
// Values to be passed to the next step
// TODO: Replace DocValues with compact linear search map instead of hash map
std::vector<DocValues> values;

using PipelineResult = io::Result<std::vector<DocValues>, facade::ErrorReply>;
using PipelineStep = std::function<PipelineResult(std::vector<DocValues>)>; // Group, Sort, etc.
// Fields from values to be printed
absl::flat_hash_set<std::string> fields_to_print;
};

using PipelineStep = std::function<PipelineResult(PipelineResult)>; // Group, Sort, etc.

// Iterator over Span<DocValues> that yields doc[field] or monostate if not present.
// Extra clumsy for STL compatibility!
Expand Down Expand Up @@ -82,6 +89,8 @@ PipelineStep MakeSortStep(std::string_view field, bool descending = false);
PipelineStep MakeLimitStep(size_t offset, size_t num);

// Process values with given steps
PipelineResult Process(std::vector<DocValues> values, absl::Span<const PipelineStep> steps);
PipelineResult Process(std::vector<DocValues> values,
absl::Span<const std::string_view> fields_to_print,
absl::Span<const PipelineStep> steps);

} // namespace dfly::aggregate
52 changes: 24 additions & 28 deletions src/server/search/aggregator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@ TEST(AggregatorTest, Sort) {
};
PipelineStep steps[] = {MakeSortStep("a", false)};

auto result = Process(values, steps);
auto result = Process(values, {"a"}, steps);

EXPECT_TRUE(result);
EXPECT_EQ(result->at(0)["a"], Value(0.5));
EXPECT_EQ(result->at(1)["a"], Value(1.0));
EXPECT_EQ(result->at(2)["a"], Value(1.5));
EXPECT_EQ(result.values[0]["a"], Value(0.5));
EXPECT_EQ(result.values[1]["a"], Value(1.0));
EXPECT_EQ(result.values[2]["a"], Value(1.5));
}

TEST(AggregatorTest, Limit) {
Expand All @@ -35,12 +34,11 @@ TEST(AggregatorTest, Limit) {
};
PipelineStep steps[] = {MakeLimitStep(1, 2)};

auto result = Process(values, steps);
auto result = Process(values, {"i"}, steps);

EXPECT_TRUE(result);
EXPECT_EQ(result->size(), 2);
EXPECT_EQ(result->at(0)["i"], Value(2.0));
EXPECT_EQ(result->at(1)["i"], Value(3.0));
EXPECT_EQ(result.values.size(), 2);
EXPECT_EQ(result.values[0]["i"], Value(2.0));
EXPECT_EQ(result.values[1]["i"], Value(3.0));
}

TEST(AggregatorTest, SimpleGroup) {
Expand All @@ -54,12 +52,11 @@ TEST(AggregatorTest, SimpleGroup) {
std::string_view fields[] = {"tag"};
PipelineStep steps[] = {MakeGroupStep(fields, {})};

auto result = Process(values, steps);
EXPECT_TRUE(result);
EXPECT_EQ(result->size(), 2);
auto result = Process(values, {"i", "tag"}, steps);
EXPECT_EQ(result.values.size(), 2);

EXPECT_EQ(result->at(0).size(), 1);
std::set<Value> groups{result->at(0)["tag"], result->at(1)["tag"]};
EXPECT_EQ(result.values[0].size(), 1);
std::set<Value> groups{result.values[0]["tag"], result.values[1]["tag"]};
std::set<Value> expected{"even", "odd"};
EXPECT_EQ(groups, expected);
}
Expand All @@ -83,25 +80,24 @@ TEST(AggregatorTest, GroupWithReduce) {
Reducer{"null-field", "distinct-null", FindReducerFunc(ReducerFunc::COUNT_DISTINCT)}};
PipelineStep steps[] = {MakeGroupStep(fields, std::move(reducers))};

auto result = Process(values, steps);
EXPECT_TRUE(result);
EXPECT_EQ(result->size(), 2);
auto result = Process(values, {"i", "half-i", "tag"}, steps);
EXPECT_EQ(result.values.size(), 2);

// Reorder even first
if (result->at(0).at("tag") == Value("odd"))
std::swap(result->at(0), result->at(1));
if (result.values[0].at("tag") == Value("odd"))
std::swap(result.values[0], result.values[1]);

// Even
EXPECT_EQ(result->at(0).at("count"), Value{(double)5});
EXPECT_EQ(result->at(0).at("sum-i"), Value{(double)2 + 4 + 6 + 8});
EXPECT_EQ(result->at(0).at("distinct-hi"), Value{(double)3});
EXPECT_EQ(result->at(0).at("distinct-null"), Value{(double)1});
EXPECT_EQ(result.values[0].at("count"), Value{(double)5});
EXPECT_EQ(result.values[0].at("sum-i"), Value{(double)2 + 4 + 6 + 8});
EXPECT_EQ(result.values[0].at("distinct-hi"), Value{(double)3});
EXPECT_EQ(result.values[0].at("distinct-null"), Value{(double)1});

// Odd
EXPECT_EQ(result->at(1).at("count"), Value{(double)5});
EXPECT_EQ(result->at(1).at("sum-i"), Value{(double)1 + 3 + 5 + 7 + 9});
EXPECT_EQ(result->at(1).at("distinct-hi"), Value{(double)3});
EXPECT_EQ(result->at(1).at("distinct-null"), Value{(double)1});
EXPECT_EQ(result.values[1].at("count"), Value{(double)5});
EXPECT_EQ(result.values[1].at("sum-i"), Value{(double)1 + 3 + 5 + 7 + 9});
EXPECT_EQ(result.values[1].at("distinct-hi"), Value{(double)3});
EXPECT_EQ(result.values[1].at("distinct-null"), Value{(double)1});
}

} // namespace dfly::aggregate
30 changes: 21 additions & 9 deletions src/server/search/search_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -981,22 +981,34 @@ void SearchFamily::FtAggregate(CmdArgList args, const CommandContext& cmd_cntx)
make_move_iterator(sub_results.end()));
}

auto agg_results = aggregate::Process(std::move(values), params->steps);
if (!agg_results.has_value())
return builder->SendError(agg_results.error());
std::vector<std::string_view> load_fields;
if (params->load_fields) {
load_fields.reserve(params->load_fields->size());
for (const auto& field : params->load_fields.value()) {
load_fields.push_back(field.GetShortName());
}
}

auto agg_results = aggregate::Process(std::move(values), load_fields, params->steps);

size_t result_size = agg_results->size();
auto* rb = static_cast<RedisReplyBuilder*>(cmd_cntx.rb);
auto sortable_value_sender = SortableValueSender(rb);

const size_t result_size = agg_results.values.size();
rb->StartArray(result_size + 1);
rb->SendLong(result_size);

for (const auto& result : agg_results.value()) {
rb->StartArray(result.size() * 2);
for (const auto& [k, v] : result) {
rb->SendBulkString(k);
std::visit(sortable_value_sender, v);
const size_t field_count = agg_results.fields_to_print.size();
for (const auto& value : agg_results.values) {
rb->StartArray(field_count * 2);
for (const auto& field : agg_results.fields_to_print) {
rb->SendBulkString(field);

if (auto it = value.find(field); it != value.end()) {
std::visit(sortable_value_sender, it->second);
} else {
rb->SendNull();
}
}
}
}
Expand Down
41 changes: 33 additions & 8 deletions src/server/search/search_family_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -962,15 +962,12 @@ TEST_F(SearchFamilyTest, AggregateGroupBy) {
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("foo_total", "20", "word", "item2"),
IsMap("foo_total", "50", "word", "item1")));

/*
Temporary not supported

resp = Run({"ft.aggregate", "i1", "*", "LOAD", "2", "foo", "text", "GROUPBY", "2", "@word",
"@text", "REDUCE", "SUM", "1", "@foo", "AS", "foo_total"}); EXPECT_THAT(resp,
IsUnordArrayWithSize(IsMap("foo_total", "20", "word", ArgType(RespExpr::NIL), "text", "\"second
key\""), IsMap("foo_total", "40", "word", ArgType(RespExpr::NIL), "text", "\"third key\""),
IsMap({"foo_total", "10", "word", ArgType(RespExpr::NIL), "text", "\"first key"})));
*/
"@text", "REDUCE", "SUM", "1", "@foo", "AS", "foo_total"});
EXPECT_THAT(resp, IsUnordArrayWithSize(
IsMap("foo_total", "40", "word", "item1", "text", "\"third key\""),
IsMap("foo_total", "20", "word", "item2", "text", "\"second key\""),
IsMap("foo_total", "10", "word", "item1", "text", "\"first key\"")));
}

TEST_F(SearchFamilyTest, JsonAggregateGroupBy) {
Expand Down Expand Up @@ -1632,4 +1629,32 @@ TEST_F(SearchFamilyTest, SearchLoadReturnHash) {
EXPECT_THAT(resp, IsMapWithSize("h2", IsMap("a", "two"), "h1", IsMap("a", "one")));
}

// Test that FT.AGGREGATE prints only needed fields
TEST_F(SearchFamilyTest, AggregateResultFields) {
Run({"JSON.SET", "j1", ".", R"({"a":"1","b":"2","c":"3"})"});
Run({"JSON.SET", "j2", ".", R"({"a":"4","b":"5","c":"6"})"});
Run({"JSON.SET", "j3", ".", R"({"a":"7","b":"8","c":"9"})"});

auto resp = Run({"FT.CREATE", "index", "ON", "JSON", "SCHEMA", "$.a", "AS", "a", "TEXT",
"SORTABLE", "$.b", "AS", "b", "TEXT", "$.c", "AS", "c", "TEXT"});
EXPECT_EQ(resp, "OK");

resp = Run({"FT.AGGREGATE", "index", "*"});
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap(), IsMap(), IsMap()));

resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "a"});
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("a", "1"), IsMap("a", "4"), IsMap("a", "7")));

resp = Run({"FT.AGGREGATE", "index", "*", "LOAD", "1", "@b", "SORTBY", "1", "a"});
EXPECT_THAT(resp,
IsUnordArrayWithSize(IsMap("b", "\"2\"", "a", "1"), IsMap("b", "\"5\"", "a", "4"),
IsMap("b", "\"8\"", "a", "7")));

resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "a", "GROUPBY", "2", "@b", "@a",
"REDUCE", "COUNT", "0", "AS", "count"});
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("b", "\"8\"", "a", "7", "count", "1"),
IsMap("b", "\"2\"", "a", "1", "count", "1"),
IsMap("b", "\"5\"", "a", "4", "count", "1")));
}

} // namespace dfly
Loading