From 63e489c025d61c7ca5ec06c5d10f36e2b76aaa1d Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Tue, 18 Feb 2025 18:03:23 +0000 Subject: [PATCH] tool-call: refactor common chat / tool-call api (+ tests / fixes) (#11900) * tool-call refactoring: moved common_chat_* to chat.h, common_chat_templates_init return a unique_ptr to opaque type * addressed clang-tidy lints in [test-]chat.* * rm minja deps from util & common & move it to common/minja/ * add name & tool_call_id to common_chat_msg * add common_chat_tool * added json <-> tools, msgs conversions to chat.h * fix double bos/eos jinja avoidance hack (was preventing inner bos/eos tokens) * fix deepseek r1 slow test (no longer opening w/ new template) * allow empty tools w/ auto + grammar * fix & test server grammar & json_schema params w/ & w/o --jinja --- Makefile | 2 +- common/CMakeLists.txt | 6 +- common/arg.cpp | 1 + common/chat.cpp | 730 ++++++++++++--- common/chat.h | 134 +++ common/chat.hpp | 55 -- common/common.cpp | 170 ---- common/common.h | 56 -- common/{ => minja}/chat-template.hpp | 0 common/{ => minja}/minja.hpp | 0 examples/main/main.cpp | 27 +- examples/run/run.cpp | 70 +- examples/server/server.cpp | 63 +- .../server/tests/unit/test_chat_completion.py | 45 +- examples/server/tests/unit/test_tool_call.py | 10 +- examples/server/utils.hpp | 126 +-- tests/test-chat-template.cpp | 54 +- tests/test-chat.cpp | 835 ++++++++++-------- 18 files changed, 1388 insertions(+), 996 deletions(-) create mode 100644 common/chat.h delete mode 100644 common/chat.hpp rename common/{ => minja}/chat-template.hpp (100%) rename common/{ => minja}/minja.hpp (100%) diff --git a/Makefile b/Makefile index 662194086eaaf..fb9a3b44890a0 100644 --- a/Makefile +++ b/Makefile @@ -1364,7 +1364,7 @@ llama-server: \ examples/server/index.html.hpp \ examples/server/loading.html.hpp \ common/chat.cpp \ - common/chat.hpp \ + common/chat.h \ common/chat-template.hpp \ common/json.hpp \ common/minja.hpp \ diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index c2b4aa7d09f1c..17146fffc1168 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -57,8 +57,7 @@ add_library(${TARGET} STATIC arg.h base64.hpp chat.cpp - chat.hpp - chat-template.hpp + chat.h common.cpp common.h console.cpp @@ -68,7 +67,8 @@ add_library(${TARGET} STATIC llguidance.cpp log.cpp log.h - minja.hpp + minja/chat-template.hpp + minja/minja.hpp ngram-cache.cpp ngram-cache.h sampling.cpp diff --git a/common/arg.cpp b/common/arg.cpp index f06aa1076cca7..eb8beccac2ee7 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2,6 +2,7 @@ #include "log.h" #include "sampling.h" +#include "chat.h" #include #include diff --git a/common/chat.cpp b/common/chat.cpp index f21a9d2a63a4b..9ebe4c5784cbc 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -1,8 +1,433 @@ -#include "chat.hpp" -#include "chat-template.hpp" +#include "chat.h" #include "json-schema-to-grammar.h" #include "log.h" -#include "minja.hpp" +#include "minja/chat-template.hpp" +#include "minja/minja.hpp" + +#include + +typedef minja::chat_template common_chat_template; + +struct common_chat_templates { + bool has_explicit_template; // Model had builtin template or template overridde was specified. 
+ std::unique_ptr template_default; // always set (defaults to chatml) + std::unique_ptr template_tool_use; +}; + +struct templates_params { + json messages; + json tools; + common_chat_tool_choice tool_choice; + json json_schema; + bool parallel_tool_calls; + bool stream; + std::string grammar; + bool add_generation_prompt = true; + bool extract_reasoning = true; +}; + +common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) { + if (tool_choice == "auto") { + return COMMON_CHAT_TOOL_CHOICE_AUTO; + } + if (tool_choice == "none") { + return COMMON_CHAT_TOOL_CHOICE_NONE; + } + if (tool_choice == "required") { + return COMMON_CHAT_TOOL_CHOICE_REQUIRED; + } + throw std::runtime_error("Invalid tool_choice: " + tool_choice); +} + +template <> +std::vector common_chat_msgs_parse_oaicompat(const json & messages) { + std::vector msgs; + + try { + + if (!messages.is_array()) { + throw std::runtime_error("Expected 'messages' to be an array, got " + messages.dump()); + } + + for (const auto & message : messages) { + if (!message.is_object()) { + throw std::runtime_error("Expected 'message' to be an object, got " + message.dump()); + } + + common_chat_msg msg; + if (!message.contains("role")) { + throw std::runtime_error("Missing 'role' in message: " + message.dump()); + } + msg.role = message.at("role"); + + if (message.contains("content")) { + const auto & content = message.at("content"); + if (content.is_string()) { + msg.content = content; + } else if (content.is_array()) { + for (const auto & part : content) { + if (!part.contains("type")) { + throw std::runtime_error("Missing content part type: " + part.dump()); + } + const auto & type = part.at("type"); + if (type != "text") { + throw std::runtime_error("Unsupported content part type: " + type.dump()); + } + common_chat_msg_content_part msg_part; + msg_part.type = type; + msg_part.text = part.at("text"); + msg.content_parts.push_back(msg_part); + } + } else if (!content.is_null()) { + throw std::runtime_error("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: /~https://github.com/ggml-org/llama.cpp/issues/8367)"); + } + } else { + throw std::runtime_error("Expected 'content' (ref: /~https://github.com/ggml-org/llama.cpp/issues/8367)"); + } + if (message.contains("reasoning_content")) { + msg.reasoning_content = message.at("reasoning_content"); + } + if (message.contains("name")) { + msg.tool_name = message.at("name"); + } + if (message.contains("tool_call_id")) { + msg.tool_call_id = message.at("tool_call_id"); + } + if (message.contains("tool_calls")) { + for (const auto & tool_call : message.at("tool_calls")) { + common_chat_tool_call tc; + if (!tool_call.contains("type")) { + throw std::runtime_error("Missing tool call type: " + tool_call.dump()); + } + const auto & type = tool_call.at("type"); + if (type != "function") { + throw std::runtime_error("Unsupported tool call type: " + tool_call.dump()); + } + if (!tool_call.contains("function")) { + throw std::runtime_error("Missing tool call function: " + tool_call.dump()); + } + const auto & fc = tool_call.at("function"); + if (!fc.contains("name")) { + throw std::runtime_error("Missing tool call name: " + tool_call.dump()); + } + tc.name = fc.at("name"); + tc.arguments = fc.at("arguments"); + if (tool_call.contains("id")) { + tc.id = tool_call.at("id"); + } + msg.tool_calls.push_back(tc); + } + } + + msgs.push_back(msg); + } + } catch (const std::exception & e) { + throw std::runtime_error("Failed to parse messages: " + 
std::string(e.what()) + "; messages = " + messages.dump(2)); + } + + return msgs; +} + +template <> +json common_chat_msgs_to_json_oaicompat(const std::vector & msgs, bool concat_typed_text) { + json messages = json::array(); + for (const auto & msg : msgs) { + if (!msg.content.empty() && !msg.content_parts.empty()) { + throw std::runtime_error("Cannot specify both content and content_parts"); + } + json jmsg { + {"role", msg.role}, + }; + if (!msg.content.empty()) { + jmsg["content"] = msg.content; + } else if (!msg.content_parts.empty()) { + if (concat_typed_text) { + std::string text; + for (const auto & part : msg.content_parts) { + if (part.type != "text") { + LOG_WRN("Ignoring content part type: %s\n", part.type.c_str()); + continue; + } + if (!text.empty()) { + text += '\n'; + } + text += part.text; + } + jmsg["content"] = text; + } else { + auto & parts = jmsg["content"] = json::array(); + for (const auto & part : msg.content_parts) { + parts.push_back({ + {"type", part.type}, + {"text", part.text}, + }); + } + } + } else { + jmsg["content"] = json(); // null + } + if (!msg.reasoning_content.empty()) { + jmsg["reasoning_content"] = msg.reasoning_content; + } + if (!msg.tool_name.empty()) { + jmsg["name"] = msg.tool_name; + } + if (!msg.tool_call_id.empty()) { + jmsg["tool_call_id"] = msg.tool_call_id; + } + if (!msg.tool_calls.empty()) { + auto & tool_calls = jmsg["tool_calls"] = json::array(); + for (const auto & tool_call : msg.tool_calls) { + json tc { + {"type", "function"}, + {"function", { + {"name", tool_call.name}, + {"arguments", tool_call.arguments}, + }}, + }; + if (!tool_call.id.empty()) { + tc["id"] = tool_call.id; + } + tool_calls.push_back(tc); + } + } + messages.push_back(jmsg); + } + return messages; +} + +template <> +std::vector common_chat_msgs_parse_oaicompat(const std::string & messages) { + return common_chat_msgs_parse_oaicompat(json::parse(messages)); +} + +template <> +std::vector common_chat_tools_parse_oaicompat(const json & tools) { + std::vector result; + + try { + if (!tools.is_null()) { + if (!tools.is_array()) { + throw std::runtime_error("Expected 'tools' to be an array, got " + tools.dump()); + } + for (const auto & tool : tools) { + if (!tool.contains("type")) { + throw std::runtime_error("Missing tool type: " + tool.dump()); + } + const auto & type = tool.at("type"); + if (!type.is_string() || type != "function") { + throw std::runtime_error("Unsupported tool type: " + tool.dump()); + } + if (!tool.contains("function")) { + throw std::runtime_error("Missing tool function: " + tool.dump()); + } + + const auto & function = tool.at("function"); + result.push_back({ + /* .name = */ function.at("name"), + /* .description = */ function.at("description"), + /* .parameters = */ function.at("parameters").dump(), + }); + } + } + } catch (const std::exception & e) { + throw std::runtime_error("Failed to parse tools: " + std::string(e.what()) + "; tools = " + tools.dump(2)); + } + + return result; +} + +template <> +std::vector common_chat_tools_parse_oaicompat(const std::string & tools) { + return common_chat_tools_parse_oaicompat(json::parse(tools)); +} + +template <> +json common_chat_tools_to_json_oaicompat(const std::vector & tools) { + if (tools.empty()) { + return json(); + } + + auto result = json::array(); + for (const auto & tool : tools) { + result.push_back({ + {"type", "function"}, + {"function", { + {"name", tool.name}, + {"description", tool.description}, + {"parameters", json::parse(tool.parameters)}, + }}, + }); + } + return result; +} + 
+bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { + if (use_jinja) { + try { + common_chat_msg msg; + msg.role = "user"; + msg.content = "test"; + + auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl); + + common_chat_templates_inputs inputs; + inputs.messages = {msg}; + + common_chat_templates_apply(tmpls.get(), inputs); + return true; + } catch (const std::exception & e) { + LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what()); + return false; + } + } + llama_chat_message chat[] = {{"user", "test"}}; + const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0); + return res >= 0; +} + +std::string common_chat_format_single( + const struct common_chat_templates * tmpls, + const std::vector & past_msg, + const common_chat_msg & new_msg, + bool add_ass, + bool use_jinja) { + + common_chat_templates_inputs inputs; + inputs.use_jinja = use_jinja; + + std::string fmt_past_msg; + if (!past_msg.empty()) { + inputs.messages = past_msg; + inputs.add_generation_prompt = false; + fmt_past_msg = common_chat_templates_apply(tmpls, inputs).prompt; + } + std::ostringstream ss; + // if the past_msg ends with a newline, we must preserve it in the formatted version + if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') { + ss << "\n"; + }; + // format chat with new_msg + inputs.messages.push_back(new_msg); + inputs.add_generation_prompt = add_ass; + auto fmt_new_msg = common_chat_templates_apply(tmpls, inputs).prompt; + // get the diff part + ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); + return ss.str(); +} + +std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja) { + common_chat_templates_inputs inputs; + inputs.use_jinja = use_jinja; + auto add_simple_msg = [&](auto role, auto content) { + common_chat_msg msg; + msg.role = role; + msg.content = content; + inputs.messages.push_back(msg); + }; + add_simple_msg("system", "You are a helpful assistant"); + add_simple_msg("user", "Hello"); + add_simple_msg("assistant", "Hi there"); + add_simple_msg("user", "How are you?"); + return common_chat_templates_apply(tmpls, inputs).prompt; +} + +#define CHATML_TEMPLATE_SRC \ + "{%- for message in messages -%}\n" \ + " {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \ + "{%- endfor -%}\n" \ + "{%- if add_generation_prompt -%}\n" \ + " {{- '<|im_start|>assistant\n' -}}\n" \ + "{%- endif -%}" + +void common_chat_templates_free(struct common_chat_templates * tmpls) { + delete tmpls; +} + +bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls) { + return tmpls->has_explicit_template; +} + +const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant) { + if (variant != nullptr) { + if (strcmp(variant, "tool_use") == 0) { + if (tmpls->template_tool_use) { + return tmpls->template_tool_use->source().c_str(); + } + return nullptr; + } else { + LOG_DBG("%s: unknown template variant: %s\n", __func__, variant); + } + } + return tmpls->template_default->source().c_str(); +} + +common_chat_templates_ptr common_chat_templates_init( + const struct llama_model * model, + const std::string & chat_template_override, + const std::string & bos_token_override, + const std::string & eos_token_override) +{ + std::string default_template_src; + std::string template_tool_use_src; + + bool has_explicit_template = !chat_template_override.empty(); + if 
(chat_template_override.empty()) { + GGML_ASSERT(model != nullptr); + const auto * str = llama_model_chat_template(model, /* name */ nullptr); + if (str) { + default_template_src = str; + has_explicit_template = true; + } + str = llama_model_chat_template(model, /* name */ "tool_use"); + if (str) { + template_tool_use_src = str; + has_explicit_template = true; + } + } else { + default_template_src = chat_template_override; + } + if (default_template_src.empty() || default_template_src == "chatml") { + if (!template_tool_use_src.empty()) { + default_template_src = template_tool_use_src; + } else { + default_template_src = CHATML_TEMPLATE_SRC; + } + } + std::string token_bos = bos_token_override; + std::string token_eos = eos_token_override; + if (model) { + const auto * vocab = llama_model_get_vocab(model); + const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) { + if (token == LLAMA_TOKEN_NULL) { + if (default_template_src.find(jinja_variable_name) != std::string::npos + || template_tool_use_src.find(jinja_variable_name) != std::string::npos) { + LOG_WRN("common_chat_templates_init: warning: vocab does not have a %s token, jinja template won't work as intended.\n", name); + } + return std::string(); + } + return common_token_to_piece(vocab, token, true); + }; + token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token"); + token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token"); + } + common_chat_templates_ptr tmpls(new common_chat_templates()); + tmpls->has_explicit_template = has_explicit_template; + try { + tmpls->template_default = std::make_unique(default_template_src, token_bos, token_eos); + } catch (const std::exception & e) { + LOG_ERR("%s: failed to parse chat template (defaulting to chatml): %s \n", __func__, e.what()); + tmpls->template_default = std::make_unique(CHATML_TEMPLATE_SRC, token_bos, token_eos); + } + if (!template_tool_use_src.empty()) { + try { + tmpls->template_tool_use = std::make_unique(template_tool_use_src, token_bos, token_eos); + } catch (const std::exception & e) { + LOG_ERR("%s: failed to parse tool use chat template (ignoring it): %s\n", __func__, e.what()); + } + } + return tmpls; +} std::string common_chat_format_name(common_chat_format format) { switch (format) { @@ -38,22 +463,22 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons json_error_locator() : position(0), found_error(false) {} - bool parse_error(std::size_t position, const std::string &, const json::exception &) override { + bool parse_error(std::size_t position, const std::string &, const json::exception &) override { // NOLINT this->position = position - 1; this->found_error = true; return false; } - bool null() override { return true; } - bool boolean(bool) override { return true; } - bool number_integer(number_integer_t) override { return true; } - bool number_unsigned(number_unsigned_t) override { return true; } - bool number_float(number_float_t, const string_t &) override { return true; } - bool string(string_t &) override { return true; } - bool binary(binary_t &) override { return true; } - bool start_object(std::size_t) override { return true; } - bool key(string_t &) override { return true; } + bool null() override { return true; } // NOLINT + bool boolean(bool) override { return true; } // NOLINT + bool number_integer(number_integer_t) override { return true; } // NOLINT + bool number_unsigned(number_unsigned_t) override { return true; } // NOLINT + bool number_float(number_float_t, 
const string_t &) override { return true; } // NOLINT + bool string(string_t &) override { return true; } // NOLINT + bool binary(binary_t &) override { return true; } // NOLINT + bool start_object(std::size_t) override { return true; } // NOLINT + bool key(string_t &) override { return true; } // NOLINT bool end_object() override { return true; } - bool start_array(std::size_t) override { return true; } + bool start_array(std::size_t) override { return true; } // NOLINT bool end_array() override { return true; } }; json_error_locator err_loc; @@ -187,13 +612,20 @@ static std::string apply( // tmpl_inputs.now = std::chrono::system_clock::now(); minja::chat_template_options tmpl_opts; - tmpl_opts.use_bos_token = false; - tmpl_opts.use_eos_token = false; - - return tmpl.apply(tmpl_inputs, tmpl_opts); + // To avoid double BOS / EOS tokens, we're manually removing begining / trailing tokens + // instead of using `chat_template_options.use_bos_token = false`, since these tokens + // may be needed inside the template / between messages too. + auto result = tmpl.apply(tmpl_inputs, tmpl_opts); + if (string_starts_with(result, tmpl.bos_token())) { + result = result.substr(tmpl.bos_token().size()); + } + if (string_ends_with(result, tmpl.eos_token())) { + result = result.substr(0, result.size() - tmpl.eos_token().size()); + } + return result; } -static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { +static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; auto tool_call_schemas = json::array(); @@ -247,7 +679,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp {"required", json::array({"tool_call"})}, }; const auto schema = - inputs.tool_choice != "required" + inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 
json { {"anyOf", json::array({ tool_call, @@ -303,9 +735,9 @@ static common_chat_msg common_chat_parse_generic(const std::string & input) { return result; } -static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { +static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; - data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; data.grammar = build_grammar([&](const common_grammar_builder & builder) { auto schemas = json::array(); foreach_function(inputs.tools, [&](const json & tool) { @@ -348,9 +780,9 @@ static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input) return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]"); } -static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { +static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; - data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; data.grammar = build_grammar([&](const common_grammar_builder & builder) { auto schemas = json::array(); foreach_function(inputs.tools, [&](const json & tool) { @@ -455,10 +887,10 @@ static void expect_tool_parameters(const std::string & name, const json & parame const auto & parameters_required = parameters.at("required"); for (const auto & prop : expected_properties) { if (!parameters_properties.contains(prop)) { - throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop); + throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop); // NOLINT } if (std::find(parameters_required.begin(), parameters_required.end(), json(prop)) == parameters_required.end()) { - throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop); + throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop); // NOLINT } } if (parameters_properties.size() != expected_properties.size()) { @@ -466,18 +898,16 @@ static void expect_tool_parameters(const std::string & name, const json & parame } } -static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct common_chat_inputs & inputs, bool allow_python_tag_builtin_tools) { +static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) { auto builtin_tools = json::array(); common_chat_params data; - data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector tool_rules; auto handle_builtin_tool = [&](const std::string & name, const json & parameters) { - if (name == "wolfram_alpha") { + if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") { // /~https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py - expect_tool_parameters(name, parameters, 
{"query"}); - } else if (name == "web_search" || name == "brave_search") { // /~https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py expect_tool_parameters(name, parameters, {"query"}); } else if (name == "python" || name == "code_interpreter") { @@ -489,7 +919,7 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com std::vector kvs; for (const auto & [key, value] : parameters.at("properties").items()) { - kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); + kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT } tool_rules.push_back( @@ -560,34 +990,33 @@ static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bo auto arg_value_str = raw_args.substr(it_eq + 1); auto arg_value = json::parse(arg_value_str); - return { - /* .role = */ "assistant", - /* .content = */ match.prefix().str(), - /* .tool_calls = */ { - { - /* .name = */ match[1], - /* .arguments = */ (json { - {arg_name, arg_value}, - }).dump(), - /* .id = */ "", - }, - }, - }; + common_chat_msg msg; + msg.role = "assistant"; + msg.content = match.prefix().str(); + msg.tool_calls.push_back({ + /* .name = */ name, + /* .arguments = */ (json { + {arg_name, arg_value}, + }).dump(), + /* .id = */ "", + }); + return msg; } } return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex); } -static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { +static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; if (inputs.tools.is_array() && !inputs.tools.empty()) { - data.grammar_lazy = inputs.tool_choice != "required" && inputs.json_schema.is_null(); + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null(); data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector tool_rules; foreach_function(inputs.tools, [&](const json & tool) { const auto & function = tool.at("function"); std::string name = function.at("name"); auto parameters = function.at("parameters"); + builder.resolve_refs(parameters); auto args_rule = builder.add_schema(name + "-args", parameters); tool_rules.push_back(builder.add_rule(name + "-call", "\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n" @@ -666,15 +1095,15 @@ static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input, return msg; } -static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { - fprintf(stderr, "%s\n", __func__); +static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) { + LOG_DBG("%s\n", __func__); common_chat_params data; data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, { {"datetime", "Jan 29 2025 13:00:00 GMT"}, {"functions", json(inputs.tools.empty() ? 
"" : inputs.tools.dump(2))}, }); if (inputs.tools.is_array() && !inputs.tools.empty()) { - data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; data.grammar = build_grammar([&](const common_grammar_builder & builder) { auto schemas = json::array(); foreach_function(inputs.tools, [&](const json & tool) { @@ -712,14 +1141,14 @@ static common_chat_msg common_chat_parse_firefunction_v2(const std::string & inp return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1); } -static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { +static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct templates_params & inputs) { // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}... // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar common_chat_params data; data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2; if (inputs.tools.is_array() && !inputs.tools.empty()) { - data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector first_tool_rules; std::vector subsequent_tool_rules; @@ -727,6 +1156,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ const auto & function = tool.at("function"); std::string name = function.at("name"); auto parameters = function.at("parameters"); + builder.resolve_refs(parameters); auto args_rule = builder.add_schema(name + "-args", parameters); first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule)); subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule)); @@ -795,14 +1225,14 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in } } -static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { +static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) { // /~https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt common_chat_params data; json tools = inputs.tools.is_null() ? 
inputs.tools : json::array(); std::string python_code_argument_name; auto has_raw_python = false; - data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector tool_rules; foreach_function(inputs.tools, [&](const json & tool) { @@ -814,7 +1244,7 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con throw std::runtime_error("Missing type in python tool"); } has_raw_python = true; - auto type = parameters.at("type"); + const auto & type = parameters.at("type"); if (type == "object") { auto properties = parameters.at("properties"); for (auto it = properties.begin(); it != properties.end(); ++it) { @@ -854,17 +1284,15 @@ static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::s std::smatch match; if (std::regex_search(input, match, python_tag_regex)) { auto code = match[1].str(); - return { - /* .role = */ "assistant", - /* .content = */ match.prefix().str(), - /* .tool_calls = */ { - { - /* .name = */ "python", - /* .arguments = */ (json {{"code", code}}).dump(), - /* .id = */ "", - }, - } - }; + common_chat_msg msg; + msg.role = "assistant"; + msg.content = match.prefix().str(); + msg.tool_calls.push_back({ + /* .name = */ "python", + /* .arguments = */ (json {{"code", code}}).dump(), + /* .id = */ "", + }); + return msg; } static std::regex function_regex(R"()"); static std::regex close_regex(R"()"); @@ -872,10 +1300,10 @@ static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::s return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex); } -static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { +static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; // (content)?({"name": "foo", "arguments": {"a": 1}})* - data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector tool_rules; foreach_function(inputs.tools, [&](const json & tool) { @@ -908,20 +1336,18 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input) std::regex middle_pattern(R"([\n\s]*[\n\s]*)"); std::regex end_pattern(R"([\n\s]*[\n\s]*$)"); + common_chat_msg msg; + msg.role = "assistant"; + auto end = input.end(); std::sregex_iterator rend; std::sregex_iterator rit(input.begin(), end, start_pattern); if (rit == rend) { - return { - /* .role = */ "assistant", - /* .content = */ input, - /* .tool_calls = */ {}, - }; + msg.content = input; + return msg; } - common_chat_msg result; - result.role = "assistant"; - result.content = rit->prefix(); + msg.content = rit->prefix(); auto it = rit->suffix().first; while (it != end) { @@ -930,7 +1356,7 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input) throw std::runtime_error("Failed to parse json tool call"); } const auto & arguments = call.at("arguments"); - result.tool_calls.push_back({ + msg.tool_calls.push_back({ call.at("name"), arguments.dump(), // arguments.is_string() ? 
arguments.get() : arguments.dump(), @@ -947,17 +1373,17 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input) break; } } - return result; + return msg; } catch (const std::exception & e) { - return { - /* .role = */ "assistant", - /* .content = */ input, - /* .tool_calls = */ {}, - }; + LOG_ERR("Failed to parse hermes 2 pro input: %s\n", e.what()); + common_chat_msg msg; + msg.role = "assistant"; + msg.content = input; + return msg; } } -static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { +static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; @@ -973,12 +1399,35 @@ static common_chat_params common_chat_params_init_without_tools(const common_cha return data; } -common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { +static common_chat_params common_chat_templates_apply_jinja( + const struct common_chat_templates * tmpls, + const struct common_chat_templates_inputs & inputs) +{ + templates_params params; + params.tools = common_chat_tools_to_json_oaicompat(inputs.tools); + const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use + ? *tmpls->template_tool_use + : *tmpls->template_default; const auto & src = tmpl.source(); const auto & caps = tmpl.original_caps(); + params.messages = common_chat_msgs_to_json_oaicompat(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content); + params.add_generation_prompt = inputs.add_generation_prompt; + params.extract_reasoning = inputs.extract_reasoning; + params.tool_choice = inputs.tool_choice; + params.grammar = inputs.grammar; + if (!inputs.json_schema.empty()) { + params.json_schema = json::parse(inputs.json_schema); + } - if (inputs.tools.is_array()) { - if (inputs.tool_choice != "none" && !inputs.grammar.empty()) { + if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) { + LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n"); + params.parallel_tool_calls = false; + } else { + params.parallel_tool_calls = inputs.parallel_tool_calls; + } + + if (params.tools.is_array()) { + if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) { throw std::runtime_error("Cannot specify grammar with tools"); } if (caps.supports_tool_calls && !caps.supports_tools) { @@ -987,68 +1436,135 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co } // DeepSeek R1: use handler in all cases except json schema (thinking / tools). - if (src.find("<|tool▁calls▁begin|>") != std::string::npos && inputs.json_schema.is_null()) { - return common_chat_params_init_deepseek_r1(tmpl, inputs); + if (src.find("<|tool▁calls▁begin|>") != std::string::npos && params.json_schema.is_null()) { + return common_chat_params_init_deepseek_r1(tmpl, params); } // Command R7B: : use handler in all cases except json schema (thinking / tools). 
- if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && inputs.json_schema.is_null()) { - return common_chat_params_init_command_r7b(tmpl, inputs); + if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && params.json_schema.is_null()) { + return common_chat_params_init_command_r7b(tmpl, params); } // Use generic handler when mixing tools + JSON schema. // TODO: support that mix in handlers below. - if ((!inputs.tools.is_array() && inputs.json_schema.is_object())) { - return common_chat_params_init_generic(tmpl, inputs); + if ((params.tools.is_array() && params.json_schema.is_object())) { + return common_chat_params_init_generic(tmpl, params); } // Functionary prepends "all\n" to plain content outputs, so we use its handler in all cases. if (src.find(">>>all") != std::string::npos) { - return common_chat_params_init_functionary_v3_2(tmpl, inputs); + return common_chat_params_init_functionary_v3_2(tmpl, params); } // Firefunction v2 requires datetime and functions in the context even w/o tools, so we also use its handler in all cases. if (src.find(" functools[") != std::string::npos) { - return common_chat_params_init_firefunction_v2(tmpl, inputs); + return common_chat_params_init_firefunction_v2(tmpl, params); } // Plain handler (no tools) - if (inputs.tools.is_null() || inputs.tool_choice == "none") { - return common_chat_params_init_without_tools(tmpl, inputs); + if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { + return common_chat_params_init_without_tools(tmpl, params); } // Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools) if (src.find("") != std::string::npos) { - return common_chat_params_init_hermes_2_pro(tmpl, inputs); + return common_chat_params_init_hermes_2_pro(tmpl, params); } // Functionary v3.1 (w/ tools) if (src.find("<|start_header_id|>") != std::string::npos && src.find("ipython<|end_header_id|>") != std::string::npos) { auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos; - return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools); + return common_chat_params_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools); } // Mistral Nemo (w/ tools) if (src.find("[TOOL_CALLS]") != std::string::npos) { - return common_chat_params_init_mistral_nemo(tmpl, inputs); + return common_chat_params_init_mistral_nemo(tmpl, params); } // Generic fallback - return common_chat_params_init_generic(tmpl, inputs); + return common_chat_params_init_generic(tmpl, params); +} + +// Legacy template route (adhoc C++ implementation of known templates), forward to llama_chat_apply_template. 
+static common_chat_params common_chat_templates_apply_legacy( + const struct common_chat_templates * tmpls, + const struct common_chat_templates_inputs & inputs) +{ + int alloc_size = 0; + std::vector chat; + std::vector contents; + for (const auto & msg : inputs.messages) { + auto content = msg.content; + for (const auto & part : msg.content_parts) { + if (part.type != "text") { + LOG_WRN("Ignoring non-text content part: %s\n", part.type.c_str()); + continue; + } + if (!content.empty()) { + content += "\n";; + } + content += part.text; + } + contents.emplace_back(std::move(content)); + } + for (size_t i = 0; i < contents.size(); ++i) { + const auto & msg = inputs.messages[i]; + const auto & content = contents[i]; + chat.push_back({msg.role.c_str(), content.c_str()}); + alloc_size += (msg.role.size() + content.size()) * 1.25; + } + + std::vector buf(alloc_size); + + // run the first time to get the total output length + const auto & src = tmpls->template_default->source(); + int32_t res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size()); + + // error: chat template is not supported + if (res < 0) { + // if the custom "tmpl" is not supported, we throw an error + // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template() + throw std::runtime_error("this custom template is not supported"); + } + + // if it turns out that our buffer is too small, we resize it + if ((size_t) res > buf.size()) { + buf.resize(res); + res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size()); + } + + common_chat_params params; + params.prompt = std::string(buf.data(), res); + if (!inputs.json_schema.empty()) { + params.grammar = json_schema_to_grammar(json::parse(inputs.json_schema)); + } else { + params.grammar = inputs.grammar; + } + return params; +} + +common_chat_params common_chat_templates_apply( + const struct common_chat_templates * tmpls, + const struct common_chat_templates_inputs & inputs) +{ + GGML_ASSERT(tmpls != nullptr); + return inputs.use_jinja + ? common_chat_templates_apply_jinja(tmpls, inputs) + : common_chat_templates_apply_legacy(tmpls, inputs); } static common_chat_msg common_chat_parse_content_only(const std::string & input) { - return { - /* .role = */ "assistant", - /* .content = */ input, - /* .tool_calls = */ {}, - }; + common_chat_msg msg; + msg.role = "assistant"; + msg.content = input; + return msg; } common_chat_msg common_chat_parse(const std::string & input, common_chat_format format) { diff --git a/common/chat.h b/common/chat.h new file mode 100644 index 0000000000000..e77bef82b9edd --- /dev/null +++ b/common/chat.h @@ -0,0 +1,134 @@ +// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers. 
+ +#pragma once + +#include "common.h" +#include +#include + +struct common_chat_templates; + +struct common_chat_tool_call { + std::string name; + std::string arguments; + std::string id; +}; + +struct common_chat_msg_content_part { + std::string type; + std::string text; +}; + +struct common_chat_msg { + std::string role; + std::string content; + std::vector content_parts = {}; + std::vector tool_calls = {}; + std::string reasoning_content; + std::string tool_name; + std::string tool_call_id; +}; + +struct common_chat_tool { + std::string name; + std::string description; + std::string parameters; +}; + +enum common_chat_tool_choice { + COMMON_CHAT_TOOL_CHOICE_AUTO, + COMMON_CHAT_TOOL_CHOICE_REQUIRED, + COMMON_CHAT_TOOL_CHOICE_NONE, +}; + +enum common_chat_format { + COMMON_CHAT_FORMAT_CONTENT_ONLY, + COMMON_CHAT_FORMAT_GENERIC, + COMMON_CHAT_FORMAT_MISTRAL_NEMO, + COMMON_CHAT_FORMAT_LLAMA_3_X, + COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS, + COMMON_CHAT_FORMAT_DEEPSEEK_R1, + COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, + COMMON_CHAT_FORMAT_FIREFUNCTION_V2, + COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, + COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1, + COMMON_CHAT_FORMAT_HERMES_2_PRO, + COMMON_CHAT_FORMAT_COMMAND_R7B, + COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING, + + COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats +}; + +struct common_chat_templates_inputs { + std::vector messages; + std::string grammar; + std::string json_schema; + bool add_generation_prompt = true; + bool use_jinja = true; + // Parameters below only supported when use_jinja is true + std::vector tools; + common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO; + bool parallel_tool_calls = false; + bool extract_reasoning = true; +}; + +struct common_chat_params { + common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + std::string prompt; + std::string grammar; + bool grammar_lazy = false; + std::vector grammar_triggers; + std::vector preserved_tokens; + std::vector additional_stops; +}; + +// Check if the template supplied via "--chat-template" is supported or not. 
Returns true if it's valid +bool common_chat_verify_template(const std::string & tmpl, bool use_jinja); + +void common_chat_templates_free(struct common_chat_templates * tmpls); + +struct common_chat_templates_deleter { void operator()(common_chat_templates * tmpls) { common_chat_templates_free(tmpls); } }; + +typedef std::unique_ptr common_chat_templates_ptr; + +common_chat_templates_ptr common_chat_templates_init( + const struct llama_model * model, + const std::string & chat_template_override, + const std::string & bos_token_override = "", + const std::string & eos_token_override = ""); + +bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls); +const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant = nullptr); + + +struct common_chat_params common_chat_templates_apply( + const struct common_chat_templates * tmpls, + const struct common_chat_templates_inputs & inputs); + +// Format single message, while taking into account the position of that message in chat history +std::string common_chat_format_single( + const struct common_chat_templates * tmpls, + const std::vector & past_msg, + const common_chat_msg & new_msg, + bool add_ass, + bool use_jinja); + +// Returns an example of formatted chat +std::string common_chat_format_example( + const struct common_chat_templates * tmpls, + bool use_jinja); + +std::string common_chat_format_name(common_chat_format format); +common_chat_msg common_chat_parse( const std::string & input, common_chat_format format); + +common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice); + +// Parses a JSON array of messages in OpenAI's chat completion API format. +// T can be std::string containing JSON or nlohmann::ordered_json +template std::vector common_chat_msgs_parse_oaicompat(const T & messages); +template T common_chat_msgs_to_json_oaicompat(const std::vector & msgs, bool concat_typed_text = false); + +// Parses a JSON array of tools in OpenAI's chat completion tool call API format. +// T can be std::string containing JSON or nlohmann::ordered_json +template std::vector common_chat_tools_parse_oaicompat(const T & tools); +template T common_chat_tools_to_json_oaicompat(const std::vector & tools); diff --git a/common/chat.hpp b/common/chat.hpp deleted file mode 100644 index ba1632f669cf7..0000000000000 --- a/common/chat.hpp +++ /dev/null @@ -1,55 +0,0 @@ -// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers. 
- -#pragma once - -#include "common.h" -#include -#include -#include -#include - -using json = nlohmann::ordered_json; - -struct common_chat_inputs { - json messages; - json tools; - json tool_choice; - json json_schema; - bool parallel_tool_calls; - bool stream; - std::string grammar; - bool add_generation_prompt = true; - bool extract_reasoning = true; -}; - -enum common_chat_format { - COMMON_CHAT_FORMAT_CONTENT_ONLY, - COMMON_CHAT_FORMAT_GENERIC, - COMMON_CHAT_FORMAT_MISTRAL_NEMO, - COMMON_CHAT_FORMAT_LLAMA_3_X, - COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS, - COMMON_CHAT_FORMAT_DEEPSEEK_R1, - COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, - COMMON_CHAT_FORMAT_FIREFUNCTION_V2, - COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, - COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1, - COMMON_CHAT_FORMAT_HERMES_2_PRO, - COMMON_CHAT_FORMAT_COMMAND_R7B, - COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING, - - COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats -}; - -struct common_chat_params { - common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY; - json prompt; - std::string grammar; - bool grammar_lazy = false; - std::vector grammar_triggers; - std::vector preserved_tokens; - std::vector additional_stops; -}; - -struct common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & params); -std::string common_chat_format_name(common_chat_format format); -common_chat_msg common_chat_parse( const std::string & input, common_chat_format format); diff --git a/common/common.cpp b/common/common.cpp index 8661e164ada6b..d2b0d50e3ee39 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -12,8 +12,6 @@ #include "json.hpp" #include "json-schema-to-grammar.h" #include "llama.h" -#include "chat.hpp" -#include "chat-template.hpp" #include #include @@ -1768,174 +1766,6 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto return text; } -// -// Chat template utils -// - -bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { - if (use_jinja) { - try { - auto chat_template = common_chat_template(tmpl, "", ""); - common_chat_inputs inputs; - inputs.messages = json::array({{ - {"role", "user"}, - {"content", "test"}, - }}); - common_chat_params_init(chat_template, inputs); - return true; - } catch (const std::exception & e) { - LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what()); - return false; - } - } - llama_chat_message chat[] = {{"user", "test"}}; - const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0); - return res >= 0; -} - -std::string common_chat_apply_template( - const common_chat_template & tmpl, - const std::vector & msgs, - bool add_ass, - bool use_jinja) { - if (use_jinja) { - auto messages = json::array(); - for (const auto & msg : msgs) { - messages.push_back({{"role", msg.role}, {"content", msg.content}}); - } - common_chat_inputs inputs; - inputs.messages = messages; - inputs.add_generation_prompt = add_ass; - return common_chat_params_init(tmpl, inputs).prompt; - } - - int alloc_size = 0; - std::vector chat; - for (const auto & msg : msgs) { - chat.push_back({msg.role.c_str(), msg.content.c_str()}); - alloc_size += (msg.role.size() + msg.content.size()) * 1.25; - } - - std::vector buf(alloc_size); - - // run the first time to get the total output length - int32_t res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size()); - - // error: chat template is not supported - if (res < 
0) { - // if the custom "tmpl" is not supported, we throw an error - // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template() - throw std::runtime_error("this custom template is not supported"); - } - - // if it turns out that our buffer is too small, we resize it - if ((size_t) res > buf.size()) { - buf.resize(res); - res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size()); - } - - std::string formatted_chat(buf.data(), res); - return formatted_chat; -} - -std::string common_chat_format_single( - const common_chat_template & tmpl, - const std::vector & past_msg, - const common_chat_msg & new_msg, - bool add_ass, - bool use_jinja) { - std::ostringstream ss; - auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(tmpl, past_msg, false, use_jinja); - std::vector chat_new(past_msg); - // if the past_msg ends with a newline, we must preserve it in the formatted version - if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') { - ss << "\n"; - }; - // format chat with new_msg - chat_new.push_back(new_msg); - auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja); - // get the diff part - ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); - return ss.str(); -} - -std::string common_chat_format_example(const common_chat_template & tmpl, bool use_jinja) { - std::vector msgs = { - {"system", "You are a helpful assistant", {}}, - {"user", "Hello", {}}, - {"assistant", "Hi there", {}}, - {"user", "How are you?", {}}, - }; - return common_chat_apply_template(tmpl, msgs, true, use_jinja); -} - -#define CHATML_TEMPLATE_SRC \ - "{%- for message in messages -%}\n" \ - " {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \ - "{%- endfor -%}\n" \ - "{%- if add_generation_prompt -%}\n" \ - " {{- '<|im_start|>assistant\n' -}}\n" \ - "{%- endif -%}" - -common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override) -{ - std::string default_template_src; - std::string template_tool_use_src; - - bool has_explicit_template = !chat_template_override.empty(); - if (chat_template_override.empty()) { - auto str = llama_model_chat_template(model, /* name */ nullptr); - if (str) { - default_template_src = str; - has_explicit_template = true; - } - str = llama_model_chat_template(model, /* name */ "tool_use"); - if (str) { - template_tool_use_src = str; - has_explicit_template = true; - } - } else { - default_template_src = chat_template_override; - } - if (default_template_src.empty() || default_template_src == "chatml") { - if (!template_tool_use_src.empty()) { - default_template_src = template_tool_use_src; - } else { - default_template_src = CHATML_TEMPLATE_SRC; - } - } - auto vocab = llama_model_get_vocab(model); - const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) { - if (token == LLAMA_TOKEN_NULL) { - if (default_template_src.find(jinja_variable_name) != std::string::npos - || template_tool_use_src.find(jinja_variable_name) != std::string::npos) { - LOG_WRN("%s: warning: vocab does not have a %s token, jinja template won't work as intended.\n", __func__, name); - } - return std::string(); - } else { - return common_token_to_piece(vocab, token, true); - } - }; - auto token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token"); - auto 
token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token"); - try { - return { - has_explicit_template, - std::make_unique(default_template_src, token_bos, token_eos), - template_tool_use_src.empty() - ? nullptr - : std::make_unique(template_tool_use_src, token_bos, token_eos), - }; - } catch (const std::exception & e) { - LOG_ERR("%s: failed to parse chat template: %s\n", __func__, e.what()); - return { - has_explicit_template, - std::make_unique(CHATML_TEMPLATE_SRC, token_bos, token_eos), - nullptr, - }; - } -} - // // KV cache utils // diff --git a/common/common.h b/common/common.h index 98b9a4464787a..10bcc10d51bb5 100644 --- a/common/common.h +++ b/common/common.h @@ -616,62 +616,6 @@ std::string common_detokenize( const std::vector & tokens, bool special = true); -// -// Chat template utils -// - -struct common_tool_call { - std::string name; - std::string arguments; - std::string id; -}; - -// same with llama_chat_message, but uses std::string -struct common_chat_msg { - std::string role; - std::string content; - std::vector tool_calls; - std::string reasoning_content = ""; -}; - -// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid -bool common_chat_verify_template(const std::string & tmpl, bool use_jinja); - -namespace minja { - class chat_template; -} - -typedef minja::chat_template common_chat_template; - -struct common_chat_templates { - bool has_explicit_template; // Model had builtin template or template overridde was specified. - std::unique_ptr template_default; // always set (defaults to chatml) - std::unique_ptr template_tool_use; -}; - -// CPP wrapper for llama_chat_apply_template -// If the built-in template is not supported, we default to chatml -// If the custom "tmpl" is not supported, we throw an error -std::string common_chat_apply_template( - const common_chat_template & tmpl, - const std::vector & chat, - bool add_ass, - bool use_jinja); - -// Format single message, while taking into account the position of that message in chat history -std::string common_chat_format_single( - const common_chat_template & tmpl, - const std::vector & past_msg, - const common_chat_msg & new_msg, - bool add_ass, - bool use_jinja); - -// Returns an example of formatted chat -std::string common_chat_format_example( - const common_chat_template & tmpl, bool use_jinja); - -common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override); - // // KV cache utils // diff --git a/common/chat-template.hpp b/common/minja/chat-template.hpp similarity index 100% rename from common/chat-template.hpp rename to common/minja/chat-template.hpp diff --git a/common/minja.hpp b/common/minja/minja.hpp similarity index 100% rename from common/minja.hpp rename to common/minja/minja.hpp diff --git a/examples/main/main.cpp b/examples/main/main.cpp index e654d3542c6c3..cf8659b037ee3 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -4,7 +4,7 @@ #include "log.h" #include "sampling.h" #include "llama.h" -#include "chat-template.hpp" +#include "chat.h" #include #include @@ -158,7 +158,7 @@ int main(int argc, char ** argv) { } const llama_vocab * vocab = llama_model_get_vocab(model); - auto chat_templates = common_chat_templates_from_model(model, params.chat_template); + auto chat_templates = common_chat_templates_init(model, params.chat_template); LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads); @@ -201,7 +201,7 @@ int 
main(int argc, char ** argv) { } // auto enable conversation mode if chat template is available - const bool has_chat_template = chat_templates.has_explicit_template && chat_templates.template_default; + const bool has_chat_template = common_chat_templates_was_explicit(chat_templates.get()); if (params.conversation_mode == COMMON_CONVERSATION_MODE_AUTO) { if (has_chat_template) { LOG_INF("%s: chat template is available, enabling conversation mode (disable it with -no-cnv)\n", __func__); @@ -219,7 +219,7 @@ int main(int argc, char ** argv) { // print chat template example in conversation mode if (params.conversation_mode) { if (params.enable_chat_template) { - LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(*chat_templates.template_default, params.use_jinja).c_str()); + LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(chat_templates.get(), params.use_jinja).c_str()); } else { LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__); } @@ -264,9 +264,11 @@ int main(int argc, char ** argv) { std::vector embd_inp; auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) { - common_chat_msg new_msg{role, content, {}}; - auto formatted = common_chat_format_single(*chat_templates.template_default, chat_msgs, new_msg, role == "user", g_params->use_jinja); - chat_msgs.push_back({role, content, {}}); + common_chat_msg new_msg; + new_msg.role = role; + new_msg.content = content; + auto formatted = common_chat_format_single(chat_templates.get(), chat_msgs, new_msg, role == "user", g_params->use_jinja); + chat_msgs.push_back(new_msg); LOG_DBG("formatted: '%s'\n", formatted.c_str()); return formatted; }; @@ -755,11 +757,14 @@ int main(int argc, char ** argv) { // check for reverse prompt using special tokens llama_token last_token = common_sampler_last(smpl); - if (std::find(antiprompt_token.begin(), antiprompt_token.end(), last_token) != antiprompt_token.end()) { - if (params.interactive) { - is_interacting = true; + for (auto token : antiprompt_token) { + if (token == last_token) { + if (params.interactive) { + is_interacting = true; + } + is_antiprompt = true; + break; } - is_antiprompt = true; } if (is_antiprompt) { diff --git a/examples/run/run.cpp b/examples/run/run.cpp index 9362da22083d3..ed8644ef78d97 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -24,7 +24,7 @@ #include #include -#include "chat-template.hpp" +#include "chat.h" #include "common.h" #include "json.hpp" #include "linenoise.cpp/linenoise.h" @@ -557,7 +557,7 @@ class LlamaData { llama_model_ptr model; llama_sampler_ptr sampler; llama_context_ptr context; - std::vector messages; + std::vector messages; // TODO: switch to common_chat_msg std::list msg_strs; std::vector fmtted; @@ -834,44 +834,23 @@ static void add_message(const char * role, const std::string & text, LlamaData & } // Function to apply the chat template and resize `formatted` if needed -static int apply_chat_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) { - if (use_jinja) { - json messages = json::array(); - for (const auto & msg : llama_data.messages) { - messages.push_back({ - {"role", msg.role}, - {"content", msg.content}, - }); - } - try { - minja::chat_template_inputs tmpl_inputs; - tmpl_inputs.messages = messages; - tmpl_inputs.add_generation_prompt = append; - - minja::chat_template_options tmpl_opts; - tmpl_opts.use_bos_token = false; - 
tmpl_opts.use_eos_token = false; - - auto result = tmpl.apply(tmpl_inputs, tmpl_opts); - llama_data.fmtted.resize(result.size() + 1); - memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1); - return result.size(); - } catch (const std::exception & e) { - printe("failed to render the chat template: %s\n", e.what()); - return -1; - } - } - int result = llama_chat_apply_template( - tmpl.source().c_str(), llama_data.messages.data(), llama_data.messages.size(), append, - append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0); - if (append && result > static_cast(llama_data.fmtted.size())) { - llama_data.fmtted.resize(result); - result = llama_chat_apply_template(tmpl.source().c_str(), llama_data.messages.data(), - llama_data.messages.size(), append, llama_data.fmtted.data(), - llama_data.fmtted.size()); - } - - return result; +static int apply_chat_template(const struct common_chat_templates * tmpls, LlamaData & llama_data, const bool append, bool use_jinja) { + common_chat_templates_inputs inputs; + for (const auto & msg : llama_data.messages) { + common_chat_msg cmsg; + cmsg.role = msg.role; + cmsg.content = msg.content; + inputs.messages.push_back(cmsg); + } + inputs.add_generation_prompt = append; + inputs.use_jinja = use_jinja; + + auto chat_params = common_chat_templates_apply(tmpls, inputs); + // TODO: use other params for tool calls. + auto result = chat_params.prompt; + llama_data.fmtted.resize(result.size() + 1); + memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1); + return result.size(); } // Function to tokenize the prompt @@ -1015,8 +994,8 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt, } // Helper function to apply the chat template and handle errors -static int apply_chat_template_with_error_handling(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) { - const int new_len = apply_chat_template(tmpl, llama_data, append, use_jinja); +static int apply_chat_template_with_error_handling(const common_chat_templates * tmpls, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) { + const int new_len = apply_chat_template(tmpls, llama_data, append, use_jinja); if (new_len < 0) { printe("failed to apply the chat template\n"); return -1; @@ -1078,8 +1057,7 @@ static int get_user_input(std::string & user_input, const std::string & user) { static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) { int prev_len = 0; llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get())); - auto chat_templates = common_chat_templates_from_model(llama_data.model.get(), ""); - GGML_ASSERT(chat_templates.template_default); + auto chat_templates = common_chat_templates_init(llama_data.model.get(), ""); static const bool stdout_a_terminal = is_stdout_a_terminal(); while (true) { // Get user input @@ -1090,7 +1068,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_ add_message("user", user.empty() ? 
user_input : user, llama_data); int new_len; - if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, true, new_len, use_jinja) < 0) { + if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, use_jinja) < 0) { return 1; } @@ -1105,7 +1083,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_ } add_message("assistant", response, llama_data); - if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, false, prev_len, use_jinja) < 0) { + if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, use_jinja) < 0) { return 1; } } diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 5707c766d7e05..809bfe0e36cd7 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -329,9 +329,6 @@ struct server_task { } // process "json_schema" and "grammar" - if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) { - throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both"); - } if (data.contains("json_schema") && !data.contains("grammar")) { try { auto schema = json_value(data, "json_schema", json::object()); @@ -1807,7 +1804,7 @@ struct server_context { // Necessary similarity of prompt for slot selection float slot_prompt_similarity = 0.0f; - common_chat_templates chat_templates; + common_chat_templates_ptr chat_templates; ~server_context() { // Clear any sampling context @@ -1891,45 +1888,17 @@ struct server_context { llama_init_dft.context.reset(); } - if (params_base.chat_template.empty() && !validate_builtin_chat_template(params.use_jinja)) { + chat_templates = common_chat_templates_init(model, params_base.chat_template); + try { + common_chat_format_example(chat_templates.get(), params.use_jinja); + } catch (const std::exception & e) { SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. 
This may cause the model to output suboptimal responses\n", __func__); - chat_templates = common_chat_templates_from_model(model, "chatml"); - } else { - chat_templates = common_chat_templates_from_model(model, params_base.chat_template); + chat_templates = common_chat_templates_init(model, "chatml"); } - GGML_ASSERT(chat_templates.template_default.get() != nullptr); return true; } - bool validate_builtin_chat_template(bool use_jinja) const { - llama_chat_message chat[] = {{"user", "test"}}; - - if (use_jinja) { - auto templates = common_chat_templates_from_model(model, ""); - common_chat_inputs inputs; - inputs.messages = json::array({{ - {"role", "user"}, - {"content", "test"}, - }}); - GGML_ASSERT(templates.template_default); - try { - common_chat_params_init(*templates.template_default, inputs); - if (templates.template_tool_use) { - common_chat_params_init(*templates.template_tool_use, inputs); - } - return true; - } catch (const std::exception & e) { - SRV_ERR("failed to apply template: %s\n", e.what()); - return false; - } - } else { - const char * tmpl = llama_model_chat_template(model, /* name */ nullptr); - const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0); - return chat_res > 0; - } - } - void init() { const int32_t n_ctx_slot = n_ctx / params_base.n_parallel; @@ -3822,13 +3791,15 @@ int main(int argc, char ** argv) { { "default_generation_settings", ctx_server.default_generation_settings_for_props }, { "total_slots", ctx_server.params_base.n_parallel }, { "model_path", ctx_server.params_base.model }, - { "chat_template", ctx_server.chat_templates.template_default->source() }, - { "bos_token", ctx_server.chat_templates.template_default->bos_token() }, - { "eos_token", ctx_server.chat_templates.template_default->eos_token() }, + { "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) }, + { "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)}, + { "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)}, { "build_info", build_info }, }; - if (ctx_server.params_base.use_jinja && ctx_server.chat_templates.template_tool_use) { - data["chat_template_tool_use"] = ctx_server.chat_templates.template_tool_use->source(); + if (ctx_server.params_base.use_jinja) { + if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates.get(), "tool_use")) { + data["chat_template_tool_use"] = tool_use_src; + } } res_ok(res, data); @@ -4063,7 +4034,7 @@ int main(int argc, char ** argv) { } auto body = json::parse(req.body); - json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates); + json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get()); return handle_completions_impl( SERVER_TASK_TYPE_COMPLETION, @@ -4076,7 +4047,7 @@ int main(int argc, char ** argv) { // same with handle_chat_completions, but without inference part const auto handle_apply_template = [&ctx_server, ¶ms, &res_ok](const httplib::Request & req, httplib::Response & res) { auto body = json::parse(req.body); - json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates); + json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get()); res_ok(res, {{ "prompt", std::move(data.at("prompt")) }}); }; @@ 
-4493,8 +4464,8 @@ int main(int argc, char ** argv) { // print sample chat example to make it clear which template is used LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, - ctx_server.chat_templates.template_default->source().c_str(), - common_chat_format_example(*ctx_server.chat_templates.template_default, ctx_server.params_base.use_jinja).c_str()); + common_chat_templates_source(ctx_server.chat_templates.get()), + common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params_base.use_jinja).c_str()); ctx_server.queue_tasks.on_new_task([&ctx_server](const server_task & task) { ctx_server.process_single_task(task); diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py index f23d5cff49abc..af1dcb5b96554 100644 --- a/examples/server/tests/unit/test_chat_completion.py +++ b/examples/server/tests/unit/test_chat_completion.py @@ -21,6 +21,8 @@ def create_server(): (None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"), ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None), ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None), + (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", False, None), + (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None), ] ) def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template): @@ -44,7 +46,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte assert res.body["usage"]["completion_tokens"] == n_predicted choice = res.body["choices"][0] assert "assistant" == choice["message"]["role"] - assert match_regex(re_content, choice["message"]["content"]) + assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}' assert choice["finish_reason"] == finish_reason @@ -169,6 +171,47 @@ def test_completion_with_response_format(response_format: dict, n_predicted: int assert "error" in res.body +@pytest.mark.parametrize("jinja,json_schema,n_predicted,re_content", [ + (False, {"const": "42"}, 6, "\"42\""), + (True, {"const": "42"}, 6, "\"42\""), +]) +def test_completion_with_json_schema(jinja: bool, json_schema: dict, n_predicted: int, re_content: str): + global server + server.jinja = jinja + server.start() + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": n_predicted, + "messages": [ + {"role": "system", "content": "You are a coding assistant."}, + {"role": "user", "content": "Write an example"}, + ], + "json_schema": json_schema, + }) + assert res.status_code == 200, f'Expected 200, got {res.status_code}' + choice = res.body["choices"][0] + assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}' + + +@pytest.mark.parametrize("jinja,grammar,n_predicted,re_content", [ + (False, 'root ::= "a"{5,5}', 6, "a{5,5}"), + (True, 'root ::= "a"{5,5}', 6, "a{5,5}"), +]) +def test_completion_with_grammar(jinja: bool, grammar: str, n_predicted: int, re_content: str): + global server + 
server.jinja = jinja + server.start() + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": n_predicted, + "messages": [ + {"role": "user", "content": "Does not matter what I say, does it?"}, + ], + "grammar": grammar, + }) + assert res.status_code == 200, res.body + choice = res.body["choices"][0] + assert match_regex(re_content, choice["message"]["content"]), choice["message"]["content"] + + @pytest.mark.parametrize("messages", [ None, "string", diff --git a/examples/server/tests/unit/test_tool_call.py b/examples/server/tests/unit/test_tool_call.py index ba3367b4f332d..a91a2f3333ca3 100644 --- a/examples/server/tests/unit/test_tool_call.py +++ b/examples/server/tests/unit/test_tool_call.py @@ -356,12 +356,12 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | (None, 128, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)), (None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), (None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), - ("^> 0.56$", 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), + (None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), (None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), # TODO: fix these (wrong results, either didn't respect decimal instruction or got wrong value) - ("^The y-coordinate [\\s\\S]*?\\*\\*0.5\\*\\*", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), - ("[\\s\\S]*?\\*\\*0\\.5\\*\\*", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)), + ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + # ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)), ]) def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None): global server @@ -401,7 +401,7 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, { "role": "tool", "name": "calculate", - "content": 0.55644242476, + "content": "0.55644242476", "tool_call_id": "call_6789" } ], @@ -444,7 +444,7 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, (128, None, "^The sum of 102 and 7 is 109.*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), (1024, 'deepseek', "To find the sum of.*", "I need to calculate the sum of 102 and 7.*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), - (1024, 'none', "\n?I need[\\s\\S]*?\n?To find.*", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + (1024, 'none', "^I need[\\s\\S]*?\n?To find.*", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), (1024, 'deepseek', "To find the sum of.*", "First, I [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)), ]) diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 60cb2673ec2ec..6f8ab2b93aac7 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -12,9 +12,7 @@ // Change JSON_ASSERT from assert() to GGML_ASSERT: #define JSON_ASSERT GGML_ASSERT #include "json.hpp" -#include "minja.hpp" -#include "chat.hpp" -#include "chat-template.hpp" +#include "chat.h" #include #include @@ -347,41 +345,6 @@ static llama_tokens format_infill( return embd_inp; } -// Format given chat. 
If tmpl is empty, we take the template from model metadata -inline std::string format_chat(const common_chat_template & tmpl, const std::vector & messages) { - std::vector chat; - - for (size_t i = 0; i < messages.size(); ++i) { - const auto & curr_msg = messages[i]; - - std::string role = json_value(curr_msg, "role", std::string("")); - - std::string content; - if (curr_msg.contains("content")) { - if (curr_msg["content"].is_string()) { - content = curr_msg["content"].get(); - } else if (curr_msg["content"].is_array()) { - for (const auto & part : curr_msg["content"]) { - if (part.contains("text")) { - content += "\n" + part["text"].get(); - } - } - } else { - throw std::runtime_error("Invalid 'content' type (ref: /~https://github.com/ggml-org/llama.cpp/issues/8367)"); - } - } else { - throw std::runtime_error("Missing 'content' (ref: /~https://github.com/ggml-org/llama.cpp/issues/8367)"); - } - - chat.push_back({role, content, /* tool_calls= */ {}}); - } - - const auto formatted_chat = common_chat_apply_template(tmpl, chat, true, /* use_jinja= */ false); - LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str()); - - return formatted_chat; -} - // // base64 utils (TODO: move to common in the future) // @@ -579,12 +542,9 @@ static json oaicompat_completion_params_parse( const json & body, /* openai api json semantics */ bool use_jinja, common_reasoning_format reasoning_format, - const common_chat_templates & chat_templates) + const struct common_chat_templates * tmpls) { json llama_params; - const auto & tmpl = body.contains("tools") && chat_templates.template_tool_use - ? *chat_templates.template_tool_use - : *chat_templates.template_default; auto tools = json_value(body, "tools", json()); auto stream = json_value(body, "stream", false); @@ -610,62 +570,58 @@ static json oaicompat_completion_params_parse( llama_params["stop"] = json_value(body, "stop", json::array()); } + auto json_schema = json_value(body, "json_schema", json()); + auto grammar = json_value(body, "grammar", std::string()); + if (!json_schema.is_null() && !grammar.empty()) { + throw std::runtime_error("Cannot use both json_schema and grammar"); + } + // Handle "response_format" field if (body.contains("response_format")) { json response_format = json_value(body, "response_format", json::object()); std::string response_type = json_value(response_format, "type", std::string()); if (response_type == "json_object") { - llama_params["json_schema"] = json_value(response_format, "schema", json::object()); + json_schema = json_value(response_format, "schema", json::object()); } else if (response_type == "json_schema") { json json_schema = json_value(response_format, "json_schema", json::object()); - llama_params["json_schema"] = json_value(json_schema, "schema", json::object()); + json_schema = json_value(json_schema, "schema", json::object()); } else if (!response_type.empty() && response_type != "text") { throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type); } } + common_chat_templates_inputs inputs; + inputs.messages = common_chat_msgs_parse_oaicompat(body.at("messages")); + inputs.tools = common_chat_tools_parse_oaicompat(tools); + inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto"))); + inputs.json_schema = json_schema.is_null() ? 
"" : json_schema.dump(); + inputs.grammar = grammar; + inputs.add_generation_prompt = true; + inputs.use_jinja = use_jinja; + inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); + inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE; + if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) { + throw std::runtime_error("Cannot use custom grammar constraints with tools."); + } + // Apply chat template to the list of messages - if (use_jinja) { - auto tool_choice = json_value(body, "tool_choice", std::string("auto")); - if (tool_choice != "none" && tool_choice != "auto" && tool_choice != "required") { - throw std::runtime_error("Invalid tool_choice: " + tool_choice); - } - if (tool_choice != "none" && llama_params.contains("grammar")) { - throw std::runtime_error("Cannot use custom grammar constraints with tools."); - } - common_chat_inputs inputs; - inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE; - inputs.messages = body.at("messages"); - inputs.tools = tools; - inputs.tool_choice = tool_choice; - inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); - if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) { - LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n"); - inputs.parallel_tool_calls = false; - } - inputs.stream = stream; - // TODO: support mixing schema w/ tools beyond generic format. - inputs.json_schema = json_value(llama_params, "json_schema", json()); - auto chat_params = common_chat_params_init(tmpl, inputs); - - llama_params["chat_format"] = static_cast(chat_params.format); - llama_params["prompt"] = chat_params.prompt; - llama_params["grammar"] = chat_params.grammar; - llama_params["grammar_lazy"] = chat_params.grammar_lazy; - auto grammar_triggers = json::array(); - for (const auto & trigger : chat_params.grammar_triggers) { - grammar_triggers.push_back({ - {"word", trigger.word}, - {"at_start", trigger.at_start}, - }); - } - llama_params["grammar_triggers"] = grammar_triggers; - llama_params["preserved_tokens"] = chat_params.preserved_tokens; - for (const auto & stop : chat_params.additional_stops) { - llama_params["stop"].push_back(stop); - } - } else { - llama_params["prompt"] = format_chat(tmpl, body.at("messages")); + auto chat_params = common_chat_templates_apply(tmpls, inputs); + + llama_params["chat_format"] = static_cast(chat_params.format); + llama_params["prompt"] = chat_params.prompt; + llama_params["grammar"] = chat_params.grammar; + llama_params["grammar_lazy"] = chat_params.grammar_lazy; + auto grammar_triggers = json::array(); + for (const auto & trigger : chat_params.grammar_triggers) { + grammar_triggers.push_back({ + {"word", trigger.word}, + {"at_start", trigger.at_start}, + }); + } + llama_params["grammar_triggers"] = grammar_triggers; + llama_params["preserved_tokens"] = chat_params.preserved_tokens; + for (const auto & stop : chat_params.additional_stops) { + llama_params["stop"].push_back(stop); } // Handle "n" field diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index e0314ae1d6296..9231c517afb0b 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -1,13 +1,14 @@ #include #include #include +#include #undef NDEBUG #include #include "llama.h" #include "common.h" -#include "chat-template.hpp" +#include "chat.h" static std::string normalize_newlines(const std::string & s) { #ifdef _WIN32 @@ 
-18,6 +19,13 @@ static std::string normalize_newlines(const std::string & s) { #endif } +static common_chat_msg simple_msg(const std::string & role, const std::string & content) { + common_chat_msg msg; + msg.role = role; + msg.content = content; + return msg; +} + int main(void) { std::vector conversation { {"system", "You are a helpful assistant"}, @@ -50,7 +58,7 @@ int main(void) { /* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", /* .expected_output= */ "[INST] You are a helpful assistant\nHello [/INST]Hi there[INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", /* .expected_output_jinja= */ "", - /* .bos_token= */ "", + /* .bos_token= */ "", /* .eos_token= */ "", }, { @@ -72,8 +80,8 @@ int main(void) { { /* .name= */ "mlabonne/AlphaMonarch-7B", /* .template_str= */ "{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}", - /* .expected_output= */ "system\nYou are a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\n I am an assistant \nuser\nAnother question\nassistant\n", - /* .expected_output_jinja= */ "system\nYou are a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\n I am an assistant \nuser\nAnother question\nassistant\n", + /* .expected_output= */ "system\nYou are a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\n I am an assistant \nuser\nAnother question\nassistant\n", + /* .expected_output_jinja= */ "", /* .bos_token= */ "", /* .eos_token= */ "", }, @@ -87,7 +95,7 @@ int main(void) { /* .name= */ "OrionStarAI/Orion-14B-Chat", /* .template_str= */ "{% for message in messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}", /* .expected_output= */ "Human: You are a helpful assistant\n\nHello\n\nAssistant: Hi thereHuman: Who are you\n\nAssistant: I am an assistant Human: Another question\n\nAssistant: ", - /* .expected_output_jinja= */ "Human: You are a helpful assistant\nHello\n\nAssistant: Hi thereHuman: Who are you\n\nAssistant: I am an assistant Human: Another question\n\nAssistant: ", + /* .expected_output_jinja= */ "Human: You are a helpful assistant\nHello\n\nAssistant: Hi thereHuman: Who are you\n\nAssistant: I am an assistant Human: Another question\n\nAssistant: ", /* .bos_token= */ "", /* .eos_token= */ "", }, @@ -304,12 +312,9 @@ int main(void) { } } - json messages = json::array(); + std::vector messages; for (const auto & msg : conversation) { - messages.push_back({ - {"role", msg.role}, - {"content", msg.content}, - }); + messages.push_back(simple_msg(msg.role, msg.content)); } for (const auto & test_case : test_cases) { if (!test_case.supported_with_jinja) { @@ -317,8 +322,13 @@ int main(void) { } printf("\n\n=== %s (jinja) ===\n\n", test_case.name.c_str()); try 
{ - minja::chat_template tmpl(test_case.template_str, test_case.bos_token, test_case.eos_token); - auto output = normalize_newlines(tmpl.apply(messages, json(), add_generation_prompt)); + auto tmpls = common_chat_templates_init(/* model= */ nullptr, test_case.template_str.c_str(), test_case.bos_token, test_case.eos_token); + common_chat_templates_inputs inputs; + inputs.use_jinja = true; + inputs.messages = messages; + inputs.add_generation_prompt = add_generation_prompt; + auto output = common_chat_templates_apply(tmpls.get(), inputs).prompt; + output = normalize_newlines(output); auto expected_output = normalize_newlines(test_case.expected_output_jinja.empty() ? test_case.expected_output : test_case.expected_output_jinja); if (output != expected_output) { printf("Expected:\n%s\n", expected_output.c_str()); @@ -336,11 +346,11 @@ int main(void) { // test llama_chat_format_single for system message printf("\n\n=== llama_chat_format_single (system message) ===\n\n"); std::vector chat2; - common_chat_msg sys_msg{"system", "You are a helpful assistant", {}}; + auto sys_msg = simple_msg("system", "You are a helpful assistant"); auto fmt_sys = [&](std::string tmpl_str) { - minja::chat_template tmpl(tmpl_str, "", ""); - auto output = common_chat_format_single(tmpl, chat2, sys_msg, false, /* use_jinja= */ false); + auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl_str); + auto output = common_chat_format_single(tmpls.get(), chat2, sys_msg, false, /* use_jinja= */ false); printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str()); printf("-------------------------\n"); return output; @@ -360,14 +370,14 @@ int main(void) { // test llama_chat_format_single for user message printf("\n\n=== llama_chat_format_single (user message) ===\n\n"); - chat2.push_back({"system", "You are a helpful assistant", {}}); - chat2.push_back({"user", "Hello", {}}); - chat2.push_back({"assistant", "I am assistant", {}}); - common_chat_msg new_msg{"user", "How are you", {}}; + chat2.push_back(simple_msg("system", "You are a helpful assistant")); + chat2.push_back(simple_msg("user", "Hello")); + chat2.push_back(simple_msg("assistant", "I am assistant")); + auto new_msg = simple_msg("user", "How are you"); - auto fmt_single = [&](std::string tmpl_str) { - minja::chat_template tmpl(tmpl_str, "", ""); - auto output = common_chat_format_single(tmpl, chat2, new_msg, true, /* use_jinja= */ false); + auto fmt_single = [&](const std::string & tmpl_str) { + auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl_str.c_str()); + auto output = common_chat_format_single(tmpls.get(), chat2, new_msg, true, /* use_jinja= */ false); printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str()); printf("-------------------------\n"); return output; diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 2836caf6a71a3..6435923054859 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -10,38 +10,12 @@ #include #include -#include "chat-template.hpp" -#include "chat.hpp" +#include "chat.h" #include "llama-grammar.h" #include "unicode.h" using json = nlohmann::ordered_json; -static common_chat_msg msg_from_json(const json & message) { - common_chat_msg ret; - ret.role = "assistant"; - if (message.contains("content") && !message.at("content").is_null()) { - ret.content = message.at("content"); - } - if (message.contains("tool_plan")) { - ret.reasoning_content = message.at("tool_plan"); - } - if (message.contains("reasoning_content")) { - ret.reasoning_content = message.at("reasoning_content"); - 
} - auto has_tool_calls = message.contains("tool_calls"); - if (has_tool_calls) { - for (const auto & tc : message.at("tool_calls")) { - const auto & arguments = tc.at("function").at("arguments"); - ret.tool_calls.push_back({ - tc.at("function").at("name").get(), - arguments.is_string() ? arguments.get() : arguments.dump(), - tc.contains("id") ? tc.at("id").get() : "", - }); - } - } - return ret; -} template static void assert_equals(const T & expected, const T & actual) { if (expected != actual) { @@ -53,7 +27,7 @@ template static void assert_equals(const T & expected, const T & actua } static std::string read_file(const std::string & path) { - std::cerr << "# Reading: " << path << std::endl << std::flush; + std::cerr << "# Reading: " << path << '\n' << std::flush; std::ifstream fs(path, std::ios_base::binary); if (!fs.is_open()) { fs = std::ifstream("../" + path, std::ios_base::binary); @@ -66,10 +40,14 @@ static std::string read_file(const std::string & path) { fs.seekg(0); std::string out; out.resize(static_cast(size)); - fs.read(&out[0], static_cast(size)); + fs.read(out.data(), static_cast(size)); return out; } +static common_chat_templates_ptr read_templates(const std::string & path) { + return common_chat_templates_ptr(common_chat_templates_init(/* model= */ nullptr, read_file(path))); +} + static std::unique_ptr build_grammar(const std::string & grammar_str) { return std::unique_ptr( llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0)); @@ -90,110 +68,102 @@ static bool match_string(const std::string & input, llama_grammar * grammar) { } } - for (const auto & stack : stacks_cur) { - if (stack.empty()) { - // An empty stack means that the grammar has been completed - return true; - } + if (std::any_of(stacks_cur.begin(), stacks_cur.end(), [](const auto & stack) { return stack.empty(); })) { + // An empty stack means that the grammar has been completed + return true; } return false; } -// Dumps `{"a": 1}` as `"{\"a\": 1}"`, unlike nlohmann::json::dump which would dump it as `"{\"a\":1}"`. 
-static std::string dump(const json & j) { - return minja::Value(j).dump(-1, /* to_json= */ true); -} - static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) { assert_equals(expected.role, actual.role); assert_equals(expected.content, actual.content); + assert_equals(expected.content_parts.size(), actual.content_parts.size()); + for (size_t i = 0; i < expected.content_parts.size(); i++) { + const auto & expected_part = expected.content_parts[i]; + const auto & actual_part = actual.content_parts[i]; + assert_equals(expected_part.type, actual_part.type); + assert_equals(expected_part.text, actual_part.text); + } assert_equals(expected.reasoning_content, actual.reasoning_content); assert_equals(expected.tool_calls.size(), actual.tool_calls.size()); for (size_t i = 0; i < expected.tool_calls.size(); i++) { const auto & expected_tool_call = expected.tool_calls[i]; const auto & actual_tool_call = actual.tool_calls[i]; assert_equals(expected_tool_call.name, actual_tool_call.name); - assert_equals(dump(json::parse(expected_tool_call.arguments)), dump(json::parse(actual_tool_call.arguments))); + assert_equals(json::parse(expected_tool_call.arguments).dump(), json::parse(actual_tool_call.arguments).dump()); assert_equals(expected_tool_call.id, actual_tool_call.id); } } -const auto special_function_tool = json::parse(R"({ - "type": "function", - "function": { - "name": "special_function", - "description": "I'm special", - "parameters": { - "type": "object", - "properties": { - "arg1": { - "type": "integer", - "description": "The arg." - } - }, - "required": ["arg1"] - } - } -})"); -const auto python_tool = json::parse(R"({ - "type": "function", - "function": { - "name": "python", - "description": "an ipython interpreter", - "parameters": { - "type": "object", - "properties": { - "code": { - "type": "string", - "description": "Python code to execute." - } - }, - "required": ["code"] - } - } -})"); -const auto code_interpreter_tool = json::parse(R"({ - "type": "function", - "function": { - "name": "code_interpreter", - "description": "an ipython interpreter", - "parameters": { - "type": "object", - "properties": { - "code": { - "type": "string", - "description": "Python code to execute." - } - }, - "required": ["code"] - } - } -})"); -const json tools = { special_function_tool, python_tool }; -const json llama_3_1_tools = { special_function_tool, code_interpreter_tool }; +common_chat_tool special_function_tool { + /* .name = */ "special_function", + /* .description = */ "I'm special", + /* .parameters = */ R"({ + "type": "object", + "properties": { + "arg1": { + "type": "integer", + "description": "The arg." + } + }, + "required": ["arg1"] + })", +}; +common_chat_tool python_tool { + /* .name = */ "python", + /* .description = */ "an ipython interpreter", + /* .parameters = */ R"({ + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "Python code to execute." + } + }, + "required": ["code"] + })", +}; +common_chat_tool code_interpreter_tool { + /* .name = */ "code_interpreter", + /* .description = */ "an ipython interpreter", + /* .parameters = */ R"({ + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "Python code to execute." 
+ } + }, + "required": ["code"] + })", +}; +std::vector tools { special_function_tool, python_tool }; +std::vector llama_3_1_tools { special_function_tool, code_interpreter_tool }; struct delta_data { std::string delta; common_chat_params params; }; -static delta_data init_delta(const common_chat_template & tmpl, const std::vector & end_tokens, - const json & user_message, const json & delta_message, const json & tools, - const json & tool_choice, +static delta_data init_delta(const struct common_chat_templates * tmpls, const std::vector & end_tokens, + const common_chat_msg & user_message, + const common_chat_msg & delta_message, + const std::vector & tools, + const common_chat_tool_choice & tool_choice, bool think = false) { - common_chat_inputs inputs; + common_chat_templates_inputs inputs; inputs.parallel_tool_calls = true; - inputs.messages = json::array(); inputs.messages.push_back(user_message); inputs.tools = tools; inputs.tool_choice = tool_choice; inputs.extract_reasoning = think; - auto params_prefix = common_chat_params_init(tmpl, inputs); + auto params_prefix = common_chat_templates_apply(tmpls, inputs); inputs.messages.push_back(delta_message); inputs.add_generation_prompt = false; - auto params_full = common_chat_params_init(tmpl, inputs); + auto params_full = common_chat_templates_apply(tmpls, inputs); std::string prefix = params_prefix.prompt; std::string full = params_full.prompt; @@ -234,30 +204,29 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto gets the diff, removes any end tokens and parses the result w/ the grammar, checking that the parsed message is the same as the test_message */ -static void test_template(const common_chat_template & tmpl, const std::vector & end_tokens, - const json & test_message, const json & tools = {}, const std::string & expected_delta = "", +static void test_templates(const struct common_chat_templates * tmpls, const std::vector & end_tokens, + const common_chat_msg & test_message, + const std::vector & tools = {}, + const std::string & expected_delta = "", bool expect_grammar_triggered = true, bool test_grammar_if_triggered = true, bool think = false) { - common_chat_msg expected_msg = msg_from_json(test_message); - - auto user_message = json{ - { "role", "user" }, - { "content", "Hello, world!" } - }; + common_chat_msg user_message; + user_message.role = "user"; + user_message.content = "Hello, world!"; - for (const auto & tool_choice : json({ "auto", "required" })) { - auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools, tool_choice, think); + for (const auto & tool_choice : std::vector {COMMON_CHAT_TOOL_CHOICE_AUTO, COMMON_CHAT_TOOL_CHOICE_REQUIRED}) { + auto data = init_delta(tmpls, end_tokens, user_message, test_message, tools, tool_choice, think); if (!expected_delta.empty()) { assert_equals(expected_delta, data.delta); } if (expect_grammar_triggered) { const auto msg = common_chat_parse(data.delta, data.params.format); - assert_msg_equals(expected_msg, msg); + assert_msg_equals(test_message, msg); } - if (!expected_msg.tool_calls.empty()) { + if (!test_message.tool_calls.empty()) { GGML_ASSERT(!data.params.grammar.empty()); } if (!data.params.grammar.empty()) { @@ -297,246 +266,339 @@ static void test_template(const common_chat_template & tmpl, const std::vectorI'm thinkingHello, world!\nWhat's up?" }, - }; - json message_assist_thoughts_unparsed_r7b { - { "role", "assistant" }, - { "content", "<|START_THINKING|>I'm thinking<|END_THINKING|>Hello, world!\nWhat's up?" 
}, - }; - json message_assist_thoughts { - { "role", "assistant" }, - { "content", "Hello, world!\nWhat's up?" }, - { "reasoning_content", "I'm thinking" }, - }; - json tool_calls = json::array({{ - { "type", "function" }, - { "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } }, - }}); - - json message_assist_call { - { "role", "assistant"}, - { "content", {}}, - { "tool_calls", { - { - { "type", "function" }, - { "function", { - { "name", "special_function" }, - { "arguments", "{\"arg1\": 1}" }, - }}, - }, - }}, - }; - json message_assist_call_thoughts = { - { "role", "assistant" }, - { "content", nullptr }, - { "reasoning_content", "I'm\nthinking" }, - { "tool_calls", { - { - { "type", "function" }, - { "function", { - { "name", "special_function" }, - { "arguments", "{\"arg1\": 1}" }, - }}, - }, - }}, - }; - json message_assist_call_thoughts_unparsed = { - { "role", "assistant" }, - { "content", "I'm\nthinking" }, - { "tool_calls", { - { - { "type", "function" }, - { "function", { - { "name", "special_function" }, - { "arguments", "{\"arg1\": 1}" }, - }}, - }, - }}, - }; - json message_assist_call_id { - { "role", "assistant"}, - { "content", {}}, - { "tool_calls", { - { - { "type", "function" }, - { "function", { - { "name", "special_function" }, - { "arguments", "{\"arg1\": 1}" }, - }}, - {"id", "123456789"}, - }, - }}, - { "role", "assistant" }, - { "content", {} }, - { "tool_calls", tool_calls } - }; - json message_assist_call_idx { - { "role", "assistant"}, - { "content", {}}, - { "tool_calls", { - { - { "type", "function" }, - { "function", { - { "name", "special_function" }, - { "arguments", "{\"arg1\": 1}" }, - }}, - // Index of the tool call in the tool_calls array - {"id", "0"}, - }, - }}, - { "role", "assistant" }, - { "content", {} }, - { "tool_calls", tool_calls } - }; - json message_assist_call_tool_plan_idx = message_assist_call_idx; - message_assist_call_tool_plan_idx["tool_plan"] = "I'm thinking"; - - auto python_message_assist_call = json{ - { "role", "assistant" }, - { "content", {} }, - { "tool_calls", json{ { - { "type", "function" }, - { "function", - { - { "name", "python" }, - { "arguments", - { - { "code", "print('hey')" }, - } }, - } }, - } } } +const common_chat_msg message_user { + "user", + "Hey there!", + /* .content_parts = */ {}, + /* .tool_calls = */ {}, + /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", +}; + +const common_chat_msg message_user_parts { + "user", + /* .content = */ "", + /* .content_parts = */ { + { "text", "Hey" }, + { "text", "there" }, + }, + /* .tool_calls = */ {}, + /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", +}; +const common_chat_msg message_assist { + "assistant", + "Hello, world!\nWhat's up?", + /* .content_parts = */ {}, + /* .tool_calls = */ {}, + /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", +}; +const common_chat_msg message_assist_thoughts_unparsed_think { + "assistant", + "I'm thinkingHello, world!\nWhat's up?", + /* .content_parts = */ {}, + /* .tool_calls = */ {}, + /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", +}; +const common_chat_msg message_assist_thoughts_unparsed_r7b { + "assistant", + "<|START_THINKING|>I'm thinking<|END_THINKING|>Hello, world!\nWhat's up?", + /* .content_parts = */ {}, + /* .tool_calls = */ {}, + /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", +}; +const common_chat_msg 
message_assist_thoughts { + "assistant", + "Hello, world!\nWhat's up?", + /* .content_parts = */ {}, + /* .tool_calls = */ {}, + /* .reasoning_content = */ "I'm thinking", + /* .tool_name = */ "", + /* .tool_call_id = */ "", +}; +const std::vector tool_calls { + { "special_function", "{\"arg1\": 1}", /* .id = */ "" }, +}; +const std::vector tool_calls_idx { + { "special_function", "{\"arg1\": 1}", /* .id = */ "0" }, +}; +const std::vector tool_calls_id { + { "special_function", "{\"arg1\": 1}", /* .id = */ "123456789" }, +}; + +const common_chat_msg message_assist_call { + "assistant", + "", + /* .content_parts = */ {}, + tool_calls, + /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", +}; +const common_chat_msg message_assist_call_thoughts = { + "assistant", + /* .content = */ "", + /* .content_parts = */ {}, + tool_calls, + /* .reasoning_content = */ "I'm\nthinking", + /* .tool_name = */ "", + /* .tool_call_id = */ "", +}; +const common_chat_msg message_assist_call_thoughts_unparsed = { + "assistant", + /* .content = */ "I'm\nthinking", + /* .content_parts = */ {}, + tool_calls, + /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", +}; +const common_chat_msg message_assist_call_id { + "assistant", + "", + /* .content_parts = */ {}, + tool_calls_id, + /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", +}; +const common_chat_msg message_assist_call_idx { + "assistant", + "", + /* .content_parts = */ {}, + tool_calls_idx, + /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", +}; +const common_chat_msg message_assist_call_python { + "assistant", + "", + /* .content_parts = */ {}, + { { "python", "{\"code\": \"print('hey')\"}", /* .id = */ "" } }, + /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", +}; +const common_chat_msg message_assist_call_code_interpreter { + "assistant", + "", + /* .content_parts = */ {}, + { { "code_interpreter", "{\"code\": \"print('hey')\"}", /* .id = */ "" } }, + /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", +}; + +static void test_msgs_oaicompat_json_conversion() { + std::vector msgs{ + message_user, + message_user_parts, + message_assist_call, + message_assist_call_thoughts, + message_assist_call_thoughts_unparsed, + message_assist_call_id, + message_assist_call_idx, + message_assist_call_python, + message_assist_call_code_interpreter, }; - auto code_interpreter_message_assist_call = json{ - { "role", "assistant" }, - { "content", {} }, - { "tool_calls", json{ { - { "type", "function" }, - { "function", - { - { "name", "code_interpreter" }, - { "arguments", - { - { "code", "print('hey')" }, - } }, - } }, - } } } + for (const auto & msg : msgs) { + auto oai_json = common_chat_msgs_to_json_oaicompat({msg}); + auto msgs2 = common_chat_msgs_parse_oaicompat(oai_json); + assert_equals((size_t) 1, msgs2.size()); + auto msg2 = msgs2[0]; + assert_msg_equals(msg, msg2); + } + assert_equals( + std::string( + "[\n" + " {\n" + " \"role\": \"user\",\n" + " \"content\": [\n" + " {\n" + " \"type\": \"text\",\n" + " \"text\": \"Hey\"\n" + " },\n" + " {\n" + " \"type\": \"text\",\n" + " \"text\": \"there\"\n" + " }\n" + " ]\n" + " }\n" + "]" + ), + common_chat_msgs_to_json_oaicompat({message_user_parts}).dump(2)); + + assert_equals( + std::string( + "[\n" + " {\n" + " \"role\": \"assistant\",\n" + " \"content\": null,\n" + " \"tool_calls\": [\n" + " {\n" + " \"type\": 
\"function\",\n" + " \"function\": {\n" + " \"name\": \"python\",\n" + " \"arguments\": \"{\\\"code\\\": \\\"print('hey')\\\"}\"\n" + " }\n" + " }\n" + " ]\n" + " }\n" + "]" + ), + common_chat_msgs_to_json_oaicompat({message_assist_call_python}).dump(2)); +} + +static void test_tools_oaicompat_json_conversion() { + std::vector tools{ + special_function_tool, + python_tool, + code_interpreter_tool, }; - common_chat_inputs inputs_no_tools; - inputs_no_tools.messages = json::array({message_user}); + for (const auto & tool : tools) { + auto oai_json = common_chat_tools_to_json_oaicompat({tool}); + auto tools2 = common_chat_tools_parse_oaicompat(oai_json); + assert_equals((size_t) 1, tools2.size()); + auto tool2 = tools2[0]; + assert_equals(tool.name, tool2.name); + assert_equals(tool.description, tool2.description); + assert_equals(json::parse(tool.parameters).dump(2), json::parse(tool2.parameters).dump(2)); + } + + assert_equals( + std::string( + "[\n" + " {\n" + " \"type\": \"function\",\n" + " \"function\": {\n" + " \"name\": \"special_function\",\n" + " \"description\": \"I'm special\",\n" + " \"parameters\": {\n" + " \"type\": \"object\",\n" + " \"properties\": {\n" + " \"arg1\": {\n" + " \"type\": \"integer\",\n" + " \"description\": \"The arg.\"\n" + " }\n" + " },\n" + " \"required\": [\n" + " \"arg1\"\n" + " ]\n" + " }\n" + " }\n" + " }\n" + "]" + ), + common_chat_tools_to_json_oaicompat({special_function_tool}).dump(2)); +} + +static void test_template_output_parsers() { + + common_chat_templates_inputs inputs_no_tools; + inputs_no_tools.messages = {message_user}; inputs_no_tools.extract_reasoning = false; - common_chat_inputs inputs_no_tools_think; - inputs_no_tools_think.messages = json::array({message_user}); + common_chat_templates_inputs inputs_no_tools_think; + inputs_no_tools_think.messages = {message_user}; inputs_no_tools_think.extract_reasoning = true; - common_chat_inputs inputs_tools; - inputs_tools.messages = json::array({message_user}); - inputs_tools.tools = json::array({special_function_tool}); + common_chat_templates_inputs inputs_tools; + inputs_tools.messages = {message_user}; + inputs_tools.tools = {special_function_tool}; inputs_tools.extract_reasoning = false; - common_chat_inputs inputs_tools_think; - inputs_tools_think.messages = json::array({message_user}); - inputs_tools_think.tools = json::array({special_function_tool}); + common_chat_templates_inputs inputs_tools_think; + inputs_tools_think.messages = {message_user}; + inputs_tools_think.tools = {special_function_tool}; inputs_tools_think.extract_reasoning = true; - common_chat_inputs inputs_tools_builtin; - inputs_tools_builtin.messages = json::array({message_user}); - inputs_tools_builtin.tools = json::array({python_tool}); + common_chat_templates_inputs inputs_tools_builtin; + inputs_tools_builtin.messages = {message_user}; + inputs_tools_builtin.tools = {python_tool}; inputs_tools_builtin.extract_reasoning = false; { // Not supported yet - const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja"), "", ""); - assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(tmpl, inputs_tools).format); + auto tmpls = read_templates("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja"); + assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_templates_apply(tmpls.get(), inputs_tools).format); } { - const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja"), "", ""); + auto tmpls = 
read_templates("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja"); std::vector end_tokens{ "<|END_OF_TURN_TOKEN|>" }; - assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_no_tools).format); - assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_tools).format); - assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING, common_chat_params_init(tmpl, inputs_tools_think).format); + assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); + assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_templates_apply(tmpls.get(), inputs_tools).format); + assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format); - assert_msg_equals(msg_from_json(message_assist), + assert_msg_equals(message_assist, common_chat_parse( "Hello, world!\nWhat's up?", COMMON_CHAT_FORMAT_COMMAND_R7B)); - assert_msg_equals(msg_from_json(message_assist), + assert_msg_equals(message_assist, common_chat_parse( "Hello, world!\nWhat's up?<|END_RESPONSE|>", COMMON_CHAT_FORMAT_COMMAND_R7B)); - assert_msg_equals(msg_from_json(message_assist), + assert_msg_equals(message_assist, common_chat_parse( "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>", COMMON_CHAT_FORMAT_COMMAND_R7B)); - assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_r7b), + assert_msg_equals(message_assist_thoughts_unparsed_r7b, common_chat_parse( "<|START_THINKING|>I'm thinking<|END_THINKING|>" "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>", COMMON_CHAT_FORMAT_COMMAND_R7B)); - assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_r7b), + assert_msg_equals(message_assist_thoughts_unparsed_r7b, common_chat_parse( "<|START_THINKING|>I'm thinking<|END_THINKING|>" "Hello, world!\nWhat's up?<|END_RESPONSE|>", COMMON_CHAT_FORMAT_COMMAND_R7B)); - assert_msg_equals(msg_from_json(message_assist_thoughts), + assert_msg_equals(message_assist_thoughts, common_chat_parse( "<|START_THINKING|>I'm thinking<|END_THINKING|>" "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>", COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING)); - test_template(tmpl, end_tokens, message_assist_call_idx, tools, + test_templates(tmpls.get(), end_tokens, message_assist_call_idx, tools, "<|START_THINKING|><|END_THINKING|>" "<|START_ACTION|>[\n" " {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n" "]<|END_ACTION|>"); - test_template(tmpl, end_tokens, message_assist_call_tool_plan_idx, tools, - "<|START_THINKING|>I'm thinking<|END_THINKING|>" - "<|START_ACTION|>[\n" - " {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n" - "]<|END_ACTION|>", - /* expect_grammar_triggered= */ true, - /* test_grammar_if_triggered= */ true, - /* think= */ true); - test_template(tmpl, end_tokens, message_assist, tools, + test_templates(tmpls.get(), end_tokens, message_assist, tools, "<|START_RESPONSE|>Hello, world!\n" "What's up?<|END_RESPONSE|>", /* expect_grammar_triggered= */ false); } { - const common_chat_template tmpl(read_file("models/templates/google-gemma-2-2b-it.jinja"), "", ""); + auto tmpls = read_templates("models/templates/google-gemma-2-2b-it.jinja"); std::vector end_tokens{ "" }; - assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format); - assert_equals(COMMON_CHAT_FORMAT_GENERIC, 
common_chat_params_init(tmpl, inputs_tools).format); + assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); + assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_templates_apply(tmpls.get(), inputs_tools).format); assert_equals(COMMON_CHAT_FORMAT_GENERIC, - common_chat_params_init( - common_chat_template(read_file("models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), - "", ""), + common_chat_templates_apply( + read_templates("models/templates/microsoft-Phi-3.5-mini-instruct.jinja").get(), inputs_tools) .format); // Generic tool calls doesn't generate / parse content-only messages symmetrically. - assert_msg_equals(msg_from_json(message_assist), + assert_msg_equals(message_assist, common_chat_parse("{\n" " \"response\": \"Hello, world!\\nWhat's up?\"\n" "}", - common_chat_params_init(tmpl, inputs_tools).format)); - test_template(tmpl, end_tokens, message_assist_call_id, tools, + common_chat_templates_apply(tmpls.get(), inputs_tools).format)); + test_templates(tmpls.get(), end_tokens, message_assist_call_id, tools, "{\n" " \"tool_calls\": [\n" " {\n" @@ -550,143 +612,133 @@ static void test_template_output_parsers() { "}"); } { - const common_chat_template tmpl(read_file("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "", - ""); + auto tmpls = read_templates("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"); std::vector end_tokens{ "" }; - assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_params_init(tmpl, inputs_tools).format); + assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_templates_apply(tmpls.get(), inputs_tools).format); - test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - test_template( - tmpl, end_tokens, message_assist_call_id, tools, + test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + test_templates( + tmpls.get(), end_tokens, message_assist_call_id, tools, "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]"); } { - const common_chat_template tmpl( - read_file("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "", ""); + auto tmpls = read_templates("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"); std::vector end_tokens{ "<|im_end|>" }; - assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_params_init(tmpl, inputs_tools).format); + assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_tools).format); assert_equals( COMMON_CHAT_FORMAT_HERMES_2_PRO, - common_chat_params_init( - common_chat_template(read_file("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), - "", ""), + common_chat_templates_apply( + read_templates("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja").get(), inputs_tools) .format); assert_equals( COMMON_CHAT_FORMAT_HERMES_2_PRO, - common_chat_params_init( - common_chat_template(read_file("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "", ""), + common_chat_templates_apply( + read_templates("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja").get(), inputs_tools) .format); - test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - test_template(tmpl, end_tokens, message_assist_call, tools, + test_templates(tmpls.get(), end_tokens, message_assist, 
tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + test_templates(tmpls.get(), end_tokens, message_assist_call, tools, "\n" "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" ""); - test_template(tmpl, end_tokens, python_message_assist_call, tools, + test_templates(tmpls.get(), end_tokens, message_assist_call_python, tools, "\n" "{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n" ""); } { - const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "", - ""); + auto tmpls = read_templates("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"); std::vector end_tokens{ "<|eom_id|>", "<|eot_id|>" }; - assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format); + assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_templates_apply(tmpls.get(), inputs_tools).format); assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS, - common_chat_params_init(tmpl, inputs_tools_builtin).format); + common_chat_templates_apply(tmpls.get(), inputs_tools_builtin).format); assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS, - common_chat_params_init( - common_chat_template(read_file("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), - "", ""), + common_chat_templates_apply( + read_templates("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja").get(), inputs_tools_builtin) .format); - // test_template(tmpl, end_tokens, message_assist, tools, R"(?)", /* expect_grammar_triggered= */ false); - test_template(tmpl, end_tokens, code_interpreter_message_assist_call, llama_3_1_tools, + // test_templates(tmpls.get(), end_tokens, message_assist, tools, R"(?)", /* expect_grammar_triggered= */ false); + test_templates(tmpls.get(), end_tokens, message_assist_call_code_interpreter, llama_3_1_tools, "<|python_tag|>code_interpreter.call(code=\"print('hey')\")"); - test_template(tmpl, end_tokens, python_message_assist_call, tools, + test_templates(tmpls.get(), end_tokens, message_assist_call_python, tools, "<|python_tag|>python.call(code=\"print('hey')\")"); - test_template(tmpl, end_tokens, message_assist_call, tools, + test_templates(tmpls.get(), end_tokens, message_assist_call, tools, "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}"); } { - const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "", - ""); + auto tmpls = read_templates("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"); std::vector end_tokens{ "<|eom_id|>", "<|eot_id|>" }; - assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format); + assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_templates_apply(tmpls.get(), inputs_tools).format); - test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - test_template(tmpl, end_tokens, message_assist_call, tools, + test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + test_templates(tmpls.get(), end_tokens, message_assist_call, tools, "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}"); } { - const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.1.jinja"), "", - ""); + auto tmpls = read_templates("models/templates/meetkai-functionary-medium-v3.1.jinja"); std::vector end_tokens{ "<|eom_id|>", "<|eot_id|>" }; 
         assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
-                      common_chat_params_init(tmpl, inputs_tools).format);
+                      common_chat_templates_apply(tmpls.get(), inputs_tools).format);
 
-        test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
-        test_template(tmpl, end_tokens, message_assist_call, tools,
+        test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
                       "<function=special_function>{\"arg1\": 1}</function>");
     }
     {
-        const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.2.jinja"), "",
-                                        "");
+        auto tmpls = read_templates("models/templates/meetkai-functionary-medium-v3.2.jinja");
         std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };
 
-        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_no_tools).format);
-        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
 
-        test_template(tmpl, end_tokens, message_assist, {},
+        test_templates(tmpls.get(), end_tokens, message_assist, {},
                       "all\n"
                       "Hello, world!\n"
                       "What's up?",
                       /* expect_grammar_triggered= */ false);
-        test_template(tmpl, end_tokens, message_assist_call, tools,
+        test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
                       "special_function\n"
                       "{\"arg1\": 1}");
     }
     {
-        const common_chat_template tmpl(read_file("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "",
-                                        "");
+        auto tmpls = read_templates("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja");
         std::vector<std::string> end_tokens{ "<|eot_id|>" };
 
-        assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_params_init(tmpl, inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
 
-        test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
-        test_template(tmpl, end_tokens, message_assist_call, tools,
+        test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
                       " functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
     }
     {
         // Original DeepSeek R1 template. Leaves <|tool▁calls▁begin|> and others unclosed. Our logic fixes the prompt.
-        const common_chat_template tmpl(read_file("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"),
-                                        "", "");
+        auto tmpls = read_templates("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja");
         std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" };
 
-        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
-        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_params_init(tmpl, inputs_tools_think).format);
+        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);
 
-        test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
-        test_template(tmpl, end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
-        assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_think),
+        test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        assert_msg_equals(message_assist_thoughts_unparsed_think,
                           common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
                                             COMMON_CHAT_FORMAT_DEEPSEEK_R1));
-        assert_msg_equals(msg_from_json(message_assist_thoughts),
+        assert_msg_equals(message_assist_thoughts,
                           common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
                                             COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
-        assert_msg_equals(msg_from_json(message_assist_thoughts),
+        assert_msg_equals(message_assist_thoughts,
                           // Latest template update (ast of 20250209) adds a trailing <think>\n if add_generation_prompt is true.
                           common_chat_parse("I'm thinking</think>Hello, world!\nWhat's up?",
                                             COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
-        // test_template(tmpl, end_tokens, message_assist_call, tools,
+        // test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
         //               "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
         //               "```json\n"
         //               "{\"arg1\": 1}\n"
@@ -697,23 +749,22 @@ static void test_template_output_parsers() {
     }
     {
         // Replacement DeepSeek R1 template. Makes the Distill Qwen 7B/32B models happy to call tools and all.
-        const common_chat_template tmpl(read_file("models/templates/llama-cpp-deepseek-r1.jinja"),
-                                        "", "");
+        auto tmpls = read_templates("models/templates/llama-cpp-deepseek-r1.jinja");
         std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" };
 
-        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
-        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_params_init(tmpl, inputs_tools_think).format);
+        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);
 
-        test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
-        test_template(tmpl, end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
-        assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_think),
+        test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        assert_msg_equals(message_assist_thoughts_unparsed_think,
                           common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
                                             COMMON_CHAT_FORMAT_DEEPSEEK_R1));
-        assert_msg_equals(msg_from_json(message_assist_thoughts),
+        assert_msg_equals(message_assist_thoughts,
                           common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
                                             COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
-        assert_msg_equals(msg_from_json(message_assist_call_thoughts_unparsed),
+        assert_msg_equals(message_assist_call_thoughts_unparsed,
                           common_chat_parse(
                               "<think>I'm\nthinking</think>\n\n"
                               "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
@@ -721,7 +772,7 @@ static void test_template_output_parsers() {
                               "{\"arg1\": 1}\n"
                               "```<|tool▁call▁end|><|tool▁calls▁end|>",
                               COMMON_CHAT_FORMAT_DEEPSEEK_R1));
-        assert_msg_equals(msg_from_json(message_assist_call_thoughts),
+        assert_msg_equals(message_assist_call_thoughts,
                           common_chat_parse(
                               "<think>I'm\nthinking</think>\n\n"
                               "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
@@ -729,7 +780,7 @@ static void test_template_output_parsers() {
                               "{\"arg1\": 1}\n"
                               "```<|tool▁call▁end|><|tool▁calls▁end|>",
                               COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
-        test_template(tmpl, end_tokens, message_assist_call, tools,
+        test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
                       "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
                       "```json\n"
                       "{\"arg1\": 1}\n"
@@ -738,38 +789,46 @@
 }
 
 int main(int argc, char ** argv) {
+    try {
 #ifndef _WIN32
-    if (argc > 1) {
-        common_chat_inputs inputs;
-        inputs.messages = {
-            { { "role", "user" }, { "content", "Hey" } }
-        };
-        inputs.tools = json::array({ special_function_tool });
-
-        std::cout << "| Template | Format |\n";
-        std::cout << "|----------|--------|\n";
-
-        for (int i = 1; i < argc; i++) {
-            try {
-                std::string path = argv[i];
-                if (path.rfind(".jinja") != path.size() - 6) {
-                    std::cerr << "Skipping non-jinja file: " << path << std::endl;
-                    continue;
+        if (argc > 1) {
+            common_chat_templates_inputs inputs;
+            common_chat_msg msg;
+            msg.role = "user";
+            msg.content = "Hey";
+            inputs.messages = {msg};
+            inputs.tools = { special_function_tool };
+
+            std::cout << "| Template | Format |\n";
+            std::cout << "|----------|--------|\n";
+
+            for (int i = 1; i < argc; i++) {
+                try {
+                    std::string path = argv[i];
+                    if (path.rfind(".jinja") != path.size() - 6) {
+                        std::cerr << "Skipping non-jinja file: " << path << '\n';
+                        continue;
+                    }
+                    auto tmpls = read_templates(path);
+                    auto parts = string_split(path, "/");
+                    auto name = parts[parts.size() - 1];
+                    auto format = common_chat_format_name(common_chat_templates_apply(tmpls.get(), inputs).format);
+                    std::cout << "| " << name << " | " << format << " |\n";
+                } catch (const std::exception & e) {
+                    std::cerr << "Failed to process " << argv[i] << ": " << e.what() << '\n';
                 }
-                common_chat_template tmpl(read_file(path), "", "");
-                auto parts = string_split(path, "/");
-                auto name = parts[parts.size() - 1];
-                auto format = common_chat_format_name(common_chat_params_init(tmpl, inputs).format);
-                std::cout << "| " << name << " | " << format << " |\n";
-            } catch (const std::exception & e) {
-                std::cerr << "Failed to process " << argv[i] << ": " << e.what() << std::endl;
             }
-        }
-    } else
+        } else
 #endif
-    {
-        test_template_output_parsers();
-        std::cout << "\n[chat] All tests passed!" << std::endl;
+        {
+            test_msgs_oaicompat_json_conversion();
+            test_tools_oaicompat_json_conversion();
+            test_template_output_parsers();
+            std::cout << "\n[chat] All tests passed!" << '\n';
+        }
+        return 0;
+    } catch (const std::exception & e) {
+        std::cerr << "Error: " << e.what() << '\n';
        return 1;
     }
-    return 0;
 }