Commit a98f4bc

add support for the deprecated field. Now it is correctly working for Dify.

tybalex committed Aug 9, 2024
1 parent de003e5 commit a98f4bc
Showing 3 changed files with 72 additions and 26 deletions.
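Context for the change: OpenAI's chat-completions API has two generations of tool-calling fields: the current "tools" request field paired with "tool_calls" in assistant messages, and the deprecated "functions" field paired with "function_call". This commit makes the server accept and emit the deprecated pair, which clients such as Dify still send. A minimal sketch of the two request shapes, with illustrative field values (assumes nlohmann/json; not part of the commit):

// Sketch only: the two tool-calling request shapes. Values are made up.
#include <iostream>
#include <nlohmann/json.hpp>
using json = nlohmann::json;

int main() {
    // Current format: "tools" in the request, "tool_calls" in assistant replies.
    json current_request;
    current_request["tools"] = json::array({
        {{"type", "function"},
         {"function", {{"name", "get_weather"}, {"parameters", json::object()}}}}
    });

    // Deprecated format: "functions" in the request, "function_call" in replies.
    json deprecated_request;
    deprecated_request["functions"] = json::array({
        {{"name", "get_weather"}, {"parameters", json::object()}}
    });

    std::cout << current_request.dump(2) << "\n" << deprecated_request.dump(2) << "\n";
}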
39 changes: 26 additions & 13 deletions examples/server/function-call.hpp
@@ -226,6 +226,9 @@ std::string construct_json_tool_call_str(const json& tool_calls, nlohmann::order
}


+
+
+
const std::vector<json> expand_messages(const json & body, json &tool_name_map) {
std::string function_str = "";
if (body.contains("tools") && !body["tools"].empty()) {
@@ -243,13 +246,11 @@ const std::vector<json> expand_messages(const json & body, json &tool_name_map)
for (size_t i = 0; i < body["messages"].size(); ++i) {
if (body["messages"][i]["role"] != "tool" and func_observation_map.size() > 0) {
// insert the observation from the tool call before the next message
-std::string observation_str = "";
-std::vector<std::string> func_observation_array;
+json func_json_array = json::array();
for (const auto& [key, value] : func_observation_map) {
-func_observation_array.push_back(value);
+func_json_array.push_back(value);
}
-json func_json_array = func_observation_array;
-observation_str = std::string("start observation ") + func_json_array.dump() + std::string(" end observation");
+std::string observation_str = "start observation " + func_json_array.dump() + " end observation";
json observation_call;
observation_call["role"] = "user";
observation_call["content"] = observation_str;
@@ -274,10 +275,15 @@ const std::vector<json> expand_messages(const json & body, json &tool_name_map)
}
}
// else if (body["messages"][i]["role"] == "assistant" and (body["messages"][i]["content"].is_null() or body["messages"][i]["content"]=="") and !body["messages"][i]["tool_calls"].is_null() and !body["messages"][i]["tool_calls"].empty()){
-else if (body["messages"][i]["role"] == "assistant" and body["messages"][i].contains("tool_calls")){
+else if (body["messages"][i]["role"] == "assistant" and (body["messages"][i].contains("tool_calls") or body["messages"][i].contains("function_call"))){
// convert OpenAI function call format to Rubra format
-// string tool_call_str = construct_python_tool_call_str(body["messages"][i]["tool_calls"], func_observation_map);
-std::string tool_call_str = construct_json_tool_call_str(body["messages"][i]["tool_calls"], func_observation_map);
+std::string tool_call_str;
+if (body["messages"][i].contains("tool_calls")) {
+tool_call_str = construct_json_tool_call_str(body["messages"][i]["tool_calls"], func_observation_map);
+}
+else {
+tool_call_str = std::string("starttoolcall") + body["messages"][i]["function_call"].dump() + std::string("endtoolcall");
+}
json function_call;
function_call["role"] = "assistant";
function_call["content"] = tool_call_str;
@@ -293,20 +299,27 @@ const std::vector<json> expand_messages(const json & body, json &tool_name_map)
}

}
+else if (body["messages"][i]["role"] == "function") {
+json func_json_array = json::array();
+func_json_array.push_back(body["messages"][i]["content"]);
+std::string observation_str = std::string("start observation ") + func_json_array.dump() + std::string(" end observation");
+json observation_call;
+observation_call["role"] = "user";
+observation_call["content"] = observation_str;
+temp_vec.push_back(observation_call);
+}
else {
temp_vec.push_back(body["messages"][i]);
}

}
if (func_observation_map.size() > 0) {
// insert the observation from the tool call before the next message
-std::string observation_str = "";
-std::vector<std::string> func_observation_array;
+json func_json_array = json::array();
for (const auto& [key, value] : func_observation_map) {
-func_observation_array.push_back(value);
+func_json_array.push_back(value);
}
-json func_json_array = func_observation_array;
-observation_str = std::string("start observation ") + func_json_array.dump() + std::string(" end observation");
+std::string observation_str = "start observation " + func_json_array.dump() + " end observation";
json observation_call;
observation_call["role"] = "user";
observation_call["content"] = observation_str;
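The two new branches above give deprecated-format messages the same treatment as the existing tool paths: an assistant "function_call" is serialized between the Rubra sentinels, and a "function"-role result is wrapped as a user-visible observation. A standalone sketch of what they produce (message contents illustrative; assumes nlohmann/json):

// Sketch only: output of the new expand_messages branches for deprecated-format messages.
#include <iostream>
#include <nlohmann/json.hpp>
using json = nlohmann::json;

int main() {
    // Assistant message in the deprecated format.
    json assistant_msg = {
        {"role", "assistant"},
        {"function_call", {{"name", "get_weather"}, {"arguments", "{\"city\": \"Berlin\"}"}}}
    };
    // Serialized between the Rubra sentinels, as in the "else" branch above.
    std::string tool_call_str = std::string("starttoolcall")
        + assistant_msg["function_call"].dump() + std::string("endtoolcall");
    std::cout << tool_call_str << "\n";

    // A "function"-role result is wrapped as an observation in a user message.
    json function_msg = {{"role", "function"}, {"name", "get_weather"}, {"content", "sunny"}};
    json func_json_array = json::array();
    func_json_array.push_back(function_msg["content"]);
    std::string observation_str = std::string("start observation ")
        + func_json_array.dump() + std::string(" end observation");
    std::cout << observation_str << "\n";
}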
13 changes: 7 additions & 6 deletions examples/server/server.cpp
@@ -3073,6 +3073,7 @@ int main(int argc, char ** argv) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);

+
const int id_task = ctx_server.queue_tasks.get_new_id();

ctx_server.queue_results.add_waiting_task_id(id_task);
@@ -3091,14 +3092,13 @@
}
ctx_server.queue_results.remove_waiting_task_id(id_task);
} else {
-const auto chunked_content_provider = [id_task, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
+const auto chunked_content_provider = [id_task, &ctx_server, completion_id, data](size_t, httplib::DataSink & sink) {
std::string last_str = "";
bool is_function_call = false;
bool checked_function_call = false;
-json last_result_data;

-auto process_and_send_data = [&](const json& data) {
-std::vector<json> result_array = format_partial_response_oaicompat(data, completion_id);
+auto process_and_send_data = [&](const json& res_data) {
+std::vector<json> result_array = format_partial_response_oaicompat(res_data, completion_id);

for (const auto& item : result_array) {
if (!item.empty()) {
@@ -3116,7 +3116,9 @@

while (true) {
server_task_result result = ctx_server.queue_results.recv(id_task);
-
+if (data.contains("tool_field")) {
+result.data["tool_field"] = data["tool_field"];
+}
if (result.error) {
const std::string error_str = "error: " + result.data.dump(-1, ' ', false, json::error_handler_t::replace) + "\n\n";
LOG_VERBOSE("data stream", {{"to_send", error_str}});
@@ -3132,7 +3134,6 @@
std::string str_to_check = last_str + content;
is_function_call = (str_to_check.find("starttool") != std::string::npos);
}
-
if (!is_function_call && !last_str.empty()) {
std::string temp_str = content;
result.data["content"] = last_str;
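The server.cpp change captures the parsed request ("data") in the streaming lambda so every result chunk can be stamped with the request's "tool_field" before formatting. The tagging pattern, reduced to a standalone sketch (the task-result type is stubbed; only the if-block mirrors the diff):

// Sketch only: propagate the request-level "tool_field" onto each streamed chunk.
#include <nlohmann/json.hpp>
using json = nlohmann::json;

struct server_task_result { json data; bool error = false; }; // stub for illustration

void tag_result(server_task_result & result, const json & request_data) {
    // The OAI-compat formatter reads this back to pick "tool_calls" vs "function_call".
    if (request_data.contains("tool_field")) {
        result.data["tool_field"] = request_data["tool_field"];
    }
}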
46 changes: 39 additions & 7 deletions examples/server/utils.hpp
@@ -361,6 +361,13 @@ static json oaicompat_completion_params_parse(
llama_params["__oaicompat"] = true;
json tool_name_map;
const std::vector<json> expanded_messages = expand_messages(body, tool_name_map);
llama_params["tool_field"] = "tool_calls";
if (body.contains("tools") && !body["tools"].empty()) {
llama_params["tool_field"] = "tool_calls";
}
else if (body.contains("functions") && !body["functions"].empty()) {
llama_params["tool_field"] = "function_call";
}
llama_params["prompt"] = format_chat(model, chat_template, expanded_messages);
llama_params["tool_name_map"] = tool_name_map;

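The hunk above picks the response field once per request: "tool_calls" by default, or "function_call" when the deprecated "functions" field is present; format_partial_response_oaicompat later reads the value back with the same "tool_calls" default. The selection logic as a standalone sketch (assumes nlohmann/json; sample bodies illustrative):

// Sketch only: which OAI-compat field the response should use.
#include <iostream>
#include <nlohmann/json.hpp>
using json = nlohmann::json;

std::string select_tool_field(const json & body) {
    if (body.contains("functions") && !body["functions"].empty()) {
        return "function_call"; // deprecated API
    }
    return "tool_calls"; // current API and the default
}

int main() {
    json with_functions;
    with_functions["functions"] = json::array({{{"name", "f"}}});
    std::cout << select_tool_field(with_functions) << "\n"; // function_call
    std::cout << select_tool_field(json::object()) << "\n"; // tool_calls
}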
@@ -518,7 +525,6 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
return std::vector<json>({result});
}
-
bool first = json_value(result, "oaicompat_token_ctr", 0) == 0;
std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL));

@@ -527,6 +533,7 @@
bool stopped_limit = json_value(result, "stopped_limit", false);
std::string content = json_value(result, "content", std::string(""));
std::vector<json> parsed_content = rubra_fc_json_tool_extractor(content);
+std::string tool_field = json_value(result, "tool_field", std::string("tool_calls"));

std::string finish_reason;
if (stopped_word || stopped_eos) {
@@ -535,7 +542,6 @@
if (stopped_limit) {
finish_reason = "length";
}
-
std::time_t t = std::time(0);

json choices;
@@ -544,6 +550,7 @@
choices = json::array({json{{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()}}});

} else {
if (first) {
if (content.empty()) {
@@ -592,10 +599,27 @@
};
oai_format_tool_calls.push_back(tool_call);
}
-choices = json::array({json{{"finish_reason", nullptr},
-{"index", 0},
-{"delta", json{{"tool_calls", oai_format_tool_calls},
-{"role", "assistant"}}}}});
+if (tool_field == "tool_calls") {
+choices = json::array({json{{"finish_reason", nullptr},
+{"index", 0},
+{"delta", json{{tool_field, oai_format_tool_calls},
+{"role", "assistant"}}}}});
+}
+else {
+choices = json::array({json{{"finish_reason", nullptr},
+{"index", 0},
+{"delta", json{{tool_field, oai_format_tool_calls[0]["function"]},
+{"role", "assistant"}}}}});
+}
+
+json second_ret = json{
+{"choices", choices},
+{"created", t},
+{"id", completion_id},
+{"model", modelname},
+{"object", "chat.completion.chunk"}};
+
+return std::vector<json>({initial_ret, second_ret});
}

}
Expand Down Expand Up @@ -632,10 +656,18 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
};
oai_format_tool_calls.push_back(tool_call);
}
-choices = json::array({json{{"finish_reason", nullptr},
-{"index", 0},
-{"delta", json{{"tool_calls", oai_format_tool_calls},
-{"role", "assistant"}}}}});
+if (tool_field == "tool_calls") {
+choices = json::array({json{{"finish_reason", nullptr},
+{"index", 0},
+{"delta", json{{tool_field, oai_format_tool_calls},
+{"role", "assistant"}}}}});
+}
+else {
+choices = json::array({json{{"finish_reason", nullptr},
+{"index", 0},
+{"delta", json{{tool_field, oai_format_tool_calls[0]["function"]},
+{"role", "assistant"}}}}});
+}
}

}
@@ -657,7 +689,7 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
{"total_tokens", num_tokens_predicted + num_prompt_tokens}
}});
}

return std::vector<json>({ret});
}

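For reference, the two streaming delta shapes the new branches emit, rebuilt as a standalone sketch (call payload made up; assumes nlohmann/json). With "tool_calls" the delta carries the whole array; with the deprecated "function_call" only the first call's "function" object is sent, since the old API models a single call:

// Sketch only: the two delta variants produced by the branches above.
#include <iostream>
#include <nlohmann/json.hpp>
using json = nlohmann::json;

int main() {
    json oai_format_tool_calls = json::array({
        {{"id", "call_0"},
         {"type", "function"},
         {"function", {{"name", "get_weather"}, {"arguments", "{\"city\": \"Berlin\"}"}}}}
    });

    // tool_field == "tool_calls": the delta carries the whole array.
    json delta_tools = {{"tool_calls", oai_format_tool_calls}, {"role", "assistant"}};

    // tool_field == "function_call": only the first entry's "function" object.
    json delta_function = {{"function_call", oai_format_tool_calls[0]["function"]},
                           {"role", "assistant"}};

    std::cout << delta_tools.dump(2) << "\n" << delta_function.dump(2) << "\n";
}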
