diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 6acd0c50cddd57..b207d9dae024ba 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -681,25 +681,26 @@ struct server_context {
         add_bos_token = llama_add_bos_token(model);
         has_eos_token = !llama_add_eos_token(model);
 
-        if (!params.model_draft.empty()) {
-            SRV_INF("loading draft model '%s'\n", params.model_draft.c_str());
+        if (!params.speculative.model.empty()) {
+            SRV_INF("loading draft model '%s'\n", params.speculative.model.c_str());
 
             auto params_dft = params;
 
-            params_dft.model        = params.model_draft;
-            params_dft.n_gpu_layers = params.n_gpu_layers_draft;
+            params_dft.model        = params.speculative.model;
+            params_dft.n_ctx        = params.speculative.n_ctx;
+            params_dft.n_gpu_layers = params.speculative.n_gpu_layers;
 
             common_init_result llama_init_dft = common_init_from_params(params_dft);
 
             model_dft = llama_init_dft.model;
 
             if (model_dft == nullptr) {
-                SRV_ERR("failed to load draft model, '%s'\n", params.model_draft.c_str());
+                SRV_ERR("failed to load draft model, '%s'\n", params.speculative.model.c_str());
                 return false;
             }
 
             if (!common_speculative_are_compatible(ctx, llama_init_dft.context)) {
-                SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params.model_draft.c_str(), params.model.c_str());
+                SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params.speculative.model.c_str(), params.model.c_str());
 
                 llama_free      (llama_init_dft.context);
                 llama_free_model(llama_init_dft.model);
@@ -755,7 +756,7 @@ struct server_context {
                 return;
             }
 
-            slot.batch_spec = llama_batch_init(params.n_draft + 1, 0, 1);
+            slot.batch_spec = llama_batch_init(params.speculative.n_max + 1, 0, 1);
         }
 
         SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
@@ -2287,13 +2288,14 @@ struct server_context {
 
             // TODO: configurable through requests
             struct common_speculative_params params_spec;
-            params_spec.n_draft = params.n_draft;
+            params_spec.n_draft = params.speculative.n_max;
             params_spec.n_reuse = 256;
-            params_spec.p_min   = 0.9f;
+            params_spec.p_min   = params.speculative.p_min;
 
             llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
 
-            if (params.n_draft_min > (int) draft.size()) {
+            // ignore small drafts
+            if (params.speculative.n_min > (int) draft.size()) {
                 continue;
             }
 
@@ -2321,9 +2323,7 @@ struct server_context {
             for (size_t i = 0; i < ids.size(); ++i) {
                 completion_token_output result;
 
-                id = ids[i];
-
-                result.tok = id;
+                result.tok = ids[i];
 
                 if (!process_token(result, slot)) {
                     // release slot because of stop condition
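
For reference, the call sites in this diff read from a nested `params.speculative` group instead of the old flat fields (`model_draft`, `n_gpu_layers_draft`, `n_draft`, `n_draft_min`). Below is a minimal sketch of the struct shape those reads imply; the member names come straight from the diff, but the types, defaults, and comments are assumptions for illustration, not the actual definition in `common/common.h`:

```cpp
#include <cstdint>
#include <string>

// Hypothetical sketch of the speculative-decoding parameter group assumed by
// the call sites above. Only the members referenced in the diff are listed,
// and the default values are illustrative guesses.
struct common_params_speculative {
    std::string model;               // draft model path; empty disables speculation
    int32_t     n_ctx        = 0;    // draft-model context size (assumed: 0 = inherit target's)
    int32_t     n_gpu_layers = -1;   // GPU layers for the draft model (assumed: -1 = default)
    int32_t     n_max        = 16;   // max tokens to draft per step (feeds params_spec.n_draft)
    int32_t     n_min        = 5;    // drafts shorter than this are ignored
    float       p_min        = 0.9f; // min draft probability (replaces the hard-coded 0.9f)
};
```

Grouping the draft-model settings this way lets `auto params_dft = params;` copy the full target configuration and then override only the draft-specific fields (`model`, `n_ctx`, `n_gpu_layers`), while everything else is inherited unchanged.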