diff --git a/common/common.cpp b/common/common.cpp
index 05ee7236c..1538cfcab 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1929,8 +1929,9 @@ minja::chat_template llama_chat_template_from_model(
             chat_template = _llama_model_meta_val_str(model, "tokenizer.chat_template");
         }
     }
-    auto bos_token = _common_token_to_piece(model, llama_token_bos(model), true);
-    auto eos_token = _common_token_to_piece(model, llama_token_eos(model), true);
+    const auto vocab = llama_model_get_vocab(model);
+    auto bos_token = common_token_to_piece(vocab, llama_vocab_bos(vocab), true);
+    auto eos_token = common_token_to_piece(vocab, llama_vocab_eos(vocab), true);
     return {std::move(chat_template), bos_token, eos_token};
 }
 
diff --git a/common/sampling.h b/common/sampling.h
index d3a4c3990..e7c0a3dce 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -100,7 +100,7 @@ std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx,
 char        common_sampler_type_to_chr(enum common_sampler_type cnstr);
 std::string common_sampler_type_to_str(enum common_sampler_type cnstr);
 
-bool common_sampler_trigger_grammar(const struct llama_model * model, common_sampler * gsmpl, const std::string & trigger);
+bool common_sampler_trigger_grammar(const struct llama_vocab * vocab, common_sampler * gsmpl, const std::string & trigger);
 
 std::vector<enum common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
 std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std::string & chars);
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index a4eaa0e62..a483b9a26 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -3729,7 +3729,7 @@ int main(int argc, char ** argv) {
     const auto handle_props = [&ctx_server, &res_ok, &get_chat_templates](const httplib::Request &, httplib::Response & res) {
         // this endpoint is publicly available, please only return what is safe to be exposed
        const auto & templates = get_chat_templates();
-        const auto vocab = llama_vocab_from_model(ctx_server.model);
+        const auto vocab = llama_model_get_vocab(ctx_server.model);
         json data = {
             { "default_generation_settings", ctx_server.default_generation_settings_for_props },
             { "total_slots",                 ctx_server.params_base.n_parallel },
@@ -3765,7 +3765,6 @@ int main(int argc, char ** argv) {
             json & data,
             httplib::Response & res,
             oaicompat_type oaicompat,
-            bool oaicompat_chat = false,
             llama_tool_call_style tool_call_style = llama_tool_call_style::None) {
         GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
 
@@ -3976,7 +3975,8 @@ int main(int argc, char ** argv) {
             SERVER_TASK_TYPE_COMPLETION,
             data,
             res,
-            OAICOMPAT_TYPE_CHAT);
+            OAICOMPAT_TYPE_CHAT,
+            tool_call_style);
     };
 
     const auto handle_models = [&params, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py
index 4c40e47d4..4f324c390 100644
--- a/examples/server/tests/unit/test_chat_completion.py
+++ b/examples/server/tests/unit/test_chat_completion.py
@@ -241,7 +241,7 @@ CODE_INTEPRETER_TOOL = {
 ])
 def test_completion_with_required_tool(template_name: str, n_predict: int, tool: dict, expected_arguments: dict):
     global server
-    server.use_jinja = True
+    server.jinja = True
     server.n_predict = n_predict
     server.chat_template_file = f'../../../tests/chat/templates/{template_name}.jinja'
     server.start()
@@ -278,7 +278,7 @@ def test_completion_with_required_tool(template_name: str, n_predict: int, tool:
 ])
 def test_completion_without_tool_call(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
     global server
-    server.use_jinja = True
+    server.jinja = True
     server.n_predict = n_predict
     server.chat_template_file = f'../../../tests/chat/templates/{template_name}.jinja'
     server.start()
@@ -322,7 +322,7 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools:
 ])
 def test_hello_world_tool_call(tool: dict, expected_arguments: dict, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
     global server
-    server.use_jinja = True
+    server.jinja = True
     server.n_ctx = 8192
     server.n_predict = 128
     server.model_hf_repo = hf_repo
diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py
index 6f686dae9..93046b34d 100644
--- a/examples/server/tests/utils.py
+++ b/examples/server/tests/utils.py
@@ -157,10 +157,6 @@ class ServerProcess:
         if self.lora_files:
             for lora_file in self.lora_files:
                 server_args.extend(["--lora", lora_file])
-        if self.chat_template_file:
-            server_args.extend(["--chat-template-file", self.chat_template_file])
-        if self.use_jinja:
-            server_args.append("--jinja")
         if self.disable_ctx_shift:
             server_args.extend(["--no-context-shift"])
         if self.api_key:
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 379025045..8f9a7517c 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -595,7 +595,7 @@ static json oaicompat_completion_params_parse(
     if (has_tools) {
         if (stream) {
             throw std::runtime_error("Cannot use tools with stream");
-        }
+        }
         if (use_jinja) {
             if (tool_call_style == llama_tool_call_style::UnknownToolCallStyle) {
                 throw std::runtime_error("Chat template does not seem to support tools. Override the model template with --chat-template.");
diff --git a/include/llama.h b/include/llama.h
index e2c548b7b..7a19aac15 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1193,8 +1193,6 @@ extern "C" {
             const char * grammar_str,
             const char * grammar_root);
 
-    LLAMA_API bool llama_sampler_is_grammar_empty(struct llama_sampler * gsmpl);
-
     /// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first.
     LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
                                int32_t   penalty_last_n,   // last n tokens to penalize (0 = disable penalty, -1 = context size)
@@ -1256,6 +1254,8 @@ extern "C" {
     // Returns the sampled token
     LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx);
 
+    LLAMA_API bool llama_sampler_is_grammar_empty(struct llama_sampler * smpl);
+
     // TODO: extend in the future
     //LLAMA_API void llama_decode_with_sampler(struct llama_context * ctx, struct llama_sampler * smpl, struct llama_batch batch, ...);
 
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 72408faf0..22cf5d76c 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -1511,11 +1511,6 @@ static struct llama_sampler_i llama_sampler_grammar_i = {
     /* .free   = */ llama_sampler_grammar_free,
 };
 
-bool llama_sampler_is_grammar_empty(struct llama_sampler * gsmpl) {
-    struct llama_sampler_grammar * ctx = (struct llama_sampler_grammar *) gsmpl->ctx;
-    return ctx->grammar == nullptr;
-}
-
 struct llama_sampler * llama_sampler_init_grammar(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
     auto * ctx = new llama_sampler_grammar;
 
diff --git a/src/llama.cpp b/src/llama.cpp
index 3779c3979..daf1b7c97 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -1130,8 +1130,7 @@ struct llm_build_context {
         rope_type        (hparams.rope_type),
         cb               (cb),
         buf_compute_meta (lctx.buf_compute_meta) {
-        // all
-        ializations should be done in init()
+        // all initializations should be done in init()
     }
 
     void init() {