Fix merge

This commit is contained in:
ochafik 2025-01-14 01:14:35 +00:00
parent e7ff6ecd93
commit 7a7d6f6a22
9 changed files with 14 additions and 23 deletions

View file

@ -1929,8 +1929,9 @@ minja::chat_template llama_chat_template_from_model(
chat_template = _llama_model_meta_val_str(model, "tokenizer.chat_template");
}
}
auto bos_token = _common_token_to_piece(model, llama_token_bos(model), true);
auto eos_token = _common_token_to_piece(model, llama_token_eos(model), true);
const auto vocab = llama_model_get_vocab(model);
auto bos_token = common_token_to_piece(vocab, llama_vocab_bos(vocab), true);
auto eos_token = common_token_to_piece(vocab, llama_vocab_eos(vocab), true);
return {std::move(chat_template), bos_token, eos_token};
}

View file

@ -100,7 +100,7 @@ std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx,
char common_sampler_type_to_chr(enum common_sampler_type cnstr);
std::string common_sampler_type_to_str(enum common_sampler_type cnstr);
bool common_sampler_trigger_grammar(const struct llama_model * model, common_sampler * gsmpl, const std::string & trigger);
bool common_sampler_trigger_grammar(const struct llama_vocab * vocab, common_sampler * gsmpl, const std::string & trigger);
std::vector<enum common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std::string & chars);

View file

@ -3729,7 +3729,7 @@ int main(int argc, char ** argv) {
const auto handle_props = [&ctx_server, &res_ok, &get_chat_templates](const httplib::Request &, httplib::Response & res) {
// this endpoint is publicly available, please only return what is safe to be exposed
const auto & templates = get_chat_templates();
const auto vocab = llama_vocab_from_model(ctx_server.model);
const auto vocab = llama_model_get_vocab(ctx_server.model);
json data = {
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
{ "total_slots", ctx_server.params_base.n_parallel },
@ -3765,7 +3765,6 @@ int main(int argc, char ** argv) {
json & data,
httplib::Response & res,
oaicompat_type oaicompat,
bool oaicompat_chat = false,
llama_tool_call_style tool_call_style = llama_tool_call_style::None) {
GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
@ -3976,7 +3975,8 @@ int main(int argc, char ** argv) {
SERVER_TASK_TYPE_COMPLETION,
data,
res,
OAICOMPAT_TYPE_CHAT);
OAICOMPAT_TYPE_CHAT,
tool_call_style);
};
const auto handle_models = [&params, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {

View file

@ -241,7 +241,7 @@ CODE_INTEPRETER_TOOL = {
])
def test_completion_with_required_tool(template_name: str, n_predict: int, tool: dict, expected_arguments: dict):
global server
server.use_jinja = True
server.jinja = True
server.n_predict = n_predict
server.chat_template_file = f'../../../tests/chat/templates/{template_name}.jinja'
server.start()
@ -278,7 +278,7 @@ def test_completion_with_required_tool(template_name: str, n_predict: int, tool:
])
def test_completion_without_tool_call(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
global server
server.use_jinja = True
server.jinja = True
server.n_predict = n_predict
server.chat_template_file = f'../../../tests/chat/templates/{template_name}.jinja'
server.start()
@ -322,7 +322,7 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools:
])
def test_hello_world_tool_call(tool: dict, expected_arguments: dict, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
global server
server.use_jinja = True
server.jinja = True
server.n_ctx = 8192
server.n_predict = 128
server.model_hf_repo = hf_repo

View file

@ -157,10 +157,6 @@ class ServerProcess:
if self.lora_files:
for lora_file in self.lora_files:
server_args.extend(["--lora", lora_file])
if self.chat_template_file:
server_args.extend(["--chat-template-file", self.chat_template_file])
if self.use_jinja:
server_args.append("--jinja")
if self.disable_ctx_shift:
server_args.extend(["--no-context-shift"])
if self.api_key:

View file

@ -595,7 +595,7 @@ static json oaicompat_completion_params_parse(
if (has_tools) {
if (stream) {
throw std::runtime_error("Cannot use tools with stream");
}
}
if (use_jinja) {
if (tool_call_style == llama_tool_call_style::UnknownToolCallStyle) {
throw std::runtime_error("Chat template does not seem to support tools. Override the model template with --chat-template.");

View file

@ -1193,8 +1193,6 @@ extern "C" {
const char * grammar_str,
const char * grammar_root);
LLAMA_API bool llama_sampler_is_grammar_empty(struct llama_sampler * gsmpl);
/// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first.
LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size)
@ -1256,6 +1254,8 @@ extern "C" {
// Returns the sampled token
LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx);
LLAMA_API bool llama_sampler_is_grammar_empty(struct llama_sampler * smpl);
// TODO: extend in the future
//LLAMA_API void llama_decode_with_sampler(struct llama_context * ctx, struct llama_sampler * smpl, struct llama_batch batch, ...);

View file

@ -1511,11 +1511,6 @@ static struct llama_sampler_i llama_sampler_grammar_i = {
/* .free = */ llama_sampler_grammar_free,
};
bool llama_sampler_is_grammar_empty(struct llama_sampler * gsmpl) {
struct llama_sampler_grammar * ctx = (struct llama_sampler_grammar *) gsmpl->ctx;
return ctx->grammar == nullptr;
}
struct llama_sampler * llama_sampler_init_grammar(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
auto * ctx = new llama_sampler_grammar;

View file

@ -1130,8 +1130,7 @@ struct llm_build_context {
rope_type (hparams.rope_type),
cb (cb),
buf_compute_meta (lctx.buf_compute_meta) {
// all
ializations should be done in init()
// all initializations should be done in init()
}
void init() {