Merge branch 'master' into xsn/vision_2

Xuan Son Nguyen 2025-01-22 13:28:31 +01:00
commit 32daa38333
65 changed files with 7551 additions and 952 deletions


@ -345,8 +345,18 @@ struct lora_merge_ctx {
gf = ggml_new_graph(ctx0);
struct ggml_tensor * cur = inp_base;
for (size_t i = 0; i < adapters.size(); ++i) {
struct ggml_tensor * a_T = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_cast(ctx0, inp_a[i], GGML_TYPE_F32)));
struct ggml_tensor * delta = ggml_mul_mat(ctx0, a_T, ggml_cast(ctx0, inp_b[i], GGML_TYPE_F32));
struct ggml_tensor * delta;
bool is_tok_embd = string_starts_with(name_base, "token_embd");
if (is_tok_embd) {
printf("%s : detected token embeddings tensor\n", __func__);
delta = ggml_mul_mat(ctx0,
ggml_cast(ctx0, inp_b[i], GGML_TYPE_F32),
ggml_cast(ctx0, inp_a[i], GGML_TYPE_F32));
} else {
delta = ggml_mul_mat(ctx0,
ggml_cont(ctx0, ggml_transpose(ctx0, ggml_cast(ctx0, inp_a[i], GGML_TYPE_F32))),
ggml_cast(ctx0, inp_b[i], GGML_TYPE_F32));
}
// scale
const float alpha = adapters[i]->alpha;
const float rank = (float) inp_b[i]->ne[0];


@ -0,0 +1,46 @@
## MiniCPM-o 2.6
Currently, this README covers only MiniCPM-o 2.6's image capabilities; full omni-mode support will be added as soon as possible.
### Prepare models and code
Download the [MiniCPM-o-2_6](https://huggingface.co/openbmb/MiniCPM-o-2_6) PyTorch model from Hugging Face into a "MiniCPM-o-2_6" folder.
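For example, assuming the `huggingface-cli` tool is installed, one way to fetch the weights is:
```bash
# install the Hugging Face CLI if needed
pip install -U "huggingface_hub[cli]"
# download the PyTorch checkpoint into the MiniCPM-o-2_6 folder
huggingface-cli download openbmb/MiniCPM-o-2_6 --local-dir MiniCPM-o-2_6
```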
Clone llama.cpp:
```bash
git clone git@github.com:OpenBMB/llama.cpp.git
cd llama.cpp
git checkout minicpm-omni
```
### Usage of MiniCPM-o 2.6
Convert the PyTorch model to GGUF files (you can also download our pre-converted [GGUF files](https://huggingface.co/openbmb/MiniCPM-o-2_6-gguf)):
```bash
python ./examples/llava/minicpmv-surgery.py -m ../MiniCPM-o-2_6
python ./examples/llava/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-o-2_6 --minicpmv-projector ../MiniCPM-o-2_6/minicpmv.projector --output-dir ../MiniCPM-o-2_6/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5 --minicpmv_version 4
python ./convert_hf_to_gguf.py ../MiniCPM-o-2_6/model
# quantize int4 version
./llama-quantize ../MiniCPM-o-2_6/model/ggml-model-f16.gguf ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf Q4_K_M
```
Build llama.cpp using `CMake` (see https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md for details):
```bash
cmake -B build
cmake --build build --config Release
```
Run inference on Linux or macOS:
```bash
# run f16 version
./llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-f16.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
# run quantized int4 version
./llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
# or run in interactive mode
./llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -i
```


@ -718,6 +718,9 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
else if (ctx->minicpmv_version == 3) {
pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1);
}
else if (ctx->minicpmv_version == 4) {
pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1);
}
ggml_set_name(pos_embed, "pos_embed");
ggml_set_input(pos_embed);
}
@ -1053,6 +1056,11 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
n_head = hidden_size/d_head;
num_query = 64;
}
else if (ctx->minicpmv_version == 4) {
hidden_size = 3584;
n_head = hidden_size/d_head;
num_query = 64;
}
struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b);
Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
@ -2041,6 +2049,7 @@ static std::vector<std::vector<clip_image_u8 *>> uhd_slice_image(const clip_imag
images[images.size()-1].push_back(patch);
}
}
clip_image_u8_free(refine_image);
}
return images;
}
@ -2079,6 +2088,13 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
clip_image_f32_free(res);
}
}
for (size_t i = 0; i < imgs.size(); ++i) {
for (size_t j = 0; j < imgs[i].size(); ++j) {
if (imgs[i][j] != nullptr) {
clip_image_u8_free(imgs[i][j]);
}
}
}
return true;
}
else if (ctx->has_qwen2vl_merger) {
@ -2335,6 +2351,9 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i
else if (ctx->minicpmv_version == 3) {
n_patches = 64;
}
else if (ctx->minicpmv_version == 4) {
n_patches = 64;
}
} else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
int patch_size = params.patch_size * 2;
int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
@ -2514,8 +2533,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
// -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
int* positions_data = (int*)malloc(ggml_nbytes(positions));
int bucket_coords_h[70];
int bucket_coords_w[70];
int bucket_coords_h[1024];
int bucket_coords_w[1024];
for (int i = 0; i < pos_h; i++){
bucket_coords_h[i] = std::floor(70.0*i/pos_h);
}
@ -2543,6 +2562,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
else if (ctx->minicpmv_version == 3) {
embed_dim = 3584;
}
else if (ctx->minicpmv_version == 4) {
embed_dim = 3584;
}
auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h));
float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed));
@ -2786,6 +2808,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
else if (ctx->minicpmv_version == 3) {
return 3584;
}
else if (ctx->minicpmv_version == 4) {
return 3584;
}
}
if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
return ctx->vision_model.mm_1_b->ne[0];


@ -216,7 +216,7 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *>
return true;
}
static clip_image_f32 * only_v2_5_reshape_by_patch(clip_image_f32 * image, int patch_size) {
static clip_image_f32 * reshape_by_patch(clip_image_f32 * image, int patch_size) {
int width = image->nx;
int height = image->ny;
int num_patches = (height / patch_size) * (width / patch_size);
@ -277,13 +277,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]);
}
else {
int has_minicpmv_projector = clip_is_minicpmv(ctx_clip);
if (has_minicpmv_projector == 2) {
encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]);
}
else if (has_minicpmv_projector == 3) {
encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]);
}
encoded = clip_image_encode(ctx_clip, n_threads, reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]);
}
if (!encoded) {
@ -313,6 +307,9 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
load_image_size->height = img->ny;
clip_add_load_image_size(ctx_clip, load_image_size);
LOG_INF("%s: load_image_size %d %d\n", __func__, load_image_size->width, load_image_size->height);
delete[] img_res_v.data;
img_res_v.size = 0;
img_res_v.data = nullptr;
}
else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) {
// flat / default llava-1.5 type embedding


@ -140,6 +140,9 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e
else if (has_minicpmv_projector == 3) {
system_prompt = "<|im_start|>user\n";
}
else if (has_minicpmv_projector == 4) {
system_prompt = "<|im_start|>user\n";
}
LOG_INF("%s: image token past: %d\n", __func__, n_past);
eval_string(ctx_llava->ctx_llama, (system_prompt+"<image>").c_str(), params->n_batch, &n_past, false);
process_eval_image_embed(ctx_llava, embeds, params->n_batch, &n_past, idx++);
@ -227,6 +230,9 @@ static struct common_sampler * llama_init(struct llava_context * ctx_llava, comm
else if (has_minicpmv_projector == 3) {
user_prompt = "<|im_start|>user\n" + prompt;
}
else if (has_minicpmv_projector == 4) {
user_prompt = "<|im_start|>user\n" + prompt;
}
}
eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false);
@ -236,6 +242,9 @@ static struct common_sampler * llama_init(struct llava_context * ctx_llava, comm
else if (has_minicpmv_projector == 3) {
eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false);
}
else if (has_minicpmv_projector == 4) {
eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false);
}
// generate the response
@ -308,7 +317,6 @@ int main(int argc, char ** argv) {
const auto * tmp = llama_loop(ctx_llava, smpl, n_past);
response += tmp;
if (strcmp(tmp, "</s>") == 0) break;
if (strstr(tmp, "###")) break; // Yi-VL behavior
printf("%s", tmp);// mistral llava-1.6
if (strstr(response.c_str(), "<user>")) break; // minicpm-v
fflush(stdout);


@ -501,7 +501,7 @@ default_image_mean = [0.48145466, 0.4578275, 0.40821073]
default_image_std = [0.26862954, 0.26130258, 0.27577711]
ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None)
ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None)
ap.add_argument('--minicpmv_version', type=int, help='minicpmv_version: MiniCPM-V-2 use 1; MiniCPM-V-2.5 use 2; MiniCPM-V-2.6 use 3', default=2)
ap.add_argument('--minicpmv_version', type=int, help='minicpmv_version: MiniCPM-V-2 use 1; MiniCPM-V-2.5 use 2; MiniCPM-V-2.6 use 3; MiniCPM-o-2.6 use 4', default=2)
# with proper
args = ap.parse_args()
@ -545,12 +545,19 @@ if args.use_f32:
minicpmv_version = args.minicpmv_version
emb_dim = 4096
block_count = 26
if minicpmv_version == 1:
emb_dim = 2304
block_count = 26
elif minicpmv_version == 2:
emb_dim = 4096
block_count = 27
elif minicpmv_version == 3:
emb_dim = 3584
block_count = 27
elif minicpmv_version == 4:
emb_dim = 3584
block_count = 27
default_vision_config = {
"hidden_size": 1152,
@ -567,6 +574,9 @@ model = Idefics2VisionTransformer(vision_config)
if minicpmv_version == 3:
vision_config = SiglipVisionConfig(**default_vision_config)
model = SiglipVisionTransformer(vision_config)
elif minicpmv_version == 4:
vision_config = SiglipVisionConfig(**default_vision_config)
model = SiglipVisionTransformer(vision_config)
processor = None
# if model.attn_pool is not None:
@ -587,7 +597,7 @@ elif args.minicpmv_projector is not None:
fname_middle = "mmproj-"
has_text_encoder = False
has_minicpmv_projector = True
minicpmv_version = 3
minicpmv_version = 4
elif args.vision_only:
fname_middle = "vision-"
has_text_encoder = False
@ -625,7 +635,6 @@ if has_vision_encoder:
fout.add_uint32("clip.vision.projection_dim", 0)
fout.add_uint32(add_key_str(KEY_ATTENTION_HEAD_COUNT, VISION), 16)
fout.add_float32(add_key_str(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
block_count = 26
fout.add_uint32(add_key_str(KEY_BLOCK_COUNT, VISION), block_count)
if processor is not None:


@ -8,7 +8,7 @@ ap.add_argument("-m", "--model", help="Path to MiniCPM-V model")
args = ap.parse_args()
# find the model part that includes the multimodal projector weights
model = AutoModel.from_pretrained(args.model, trust_remote_code=True, local_files_only=True)
model = AutoModel.from_pretrained(args.model, trust_remote_code=True, local_files_only=True, torch_dtype=torch.bfloat16)
checkpoint = model.state_dict()
# get a list of mm tensor names


@ -4,6 +4,7 @@
#include "log.h"
#include "sampling.h"
#include "llama.h"
#include "chat-template.hpp"
#include <cstdio>
#include <cstring>
@ -84,14 +85,6 @@ static void sigint_handler(int signo) {
}
#endif
static std::string chat_add_and_format(struct llama_model * model, std::vector<common_chat_msg> & chat_msgs, const std::string & role, const std::string & content) {
common_chat_msg new_msg{role, content};
auto formatted = common_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user");
chat_msgs.push_back({role, content});
LOG_DBG("formatted: '%s'\n", formatted.c_str());
return formatted;
}
int main(int argc, char ** argv) {
common_params params;
g_params = &params;
@ -165,6 +158,7 @@ int main(int argc, char ** argv) {
}
const llama_vocab * vocab = llama_model_get_vocab(model);
auto chat_templates = common_chat_templates_from_model(model, params.chat_template);
LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads);
@ -207,7 +201,7 @@ int main(int argc, char ** argv) {
}
// auto enable conversation mode if chat template is available
const bool has_chat_template = !common_get_builtin_chat_template(model).empty() || !params.chat_template.empty();
const bool has_chat_template = chat_templates.has_explicit_template && chat_templates.template_default;
if (params.conversation_mode == COMMON_CONVERSATION_MODE_AUTO) {
if (has_chat_template) {
LOG_INF("%s: chat template is available, enabling conversation mode (disable it with -no-cnv)\n", __func__);
@ -225,7 +219,7 @@ int main(int argc, char ** argv) {
// print chat template example in conversation mode
if (params.conversation_mode) {
if (params.enable_chat_template) {
LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(model, params.chat_template).c_str());
LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(*chat_templates.template_default, params.use_jinja).c_str());
} else {
LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
}
@ -269,10 +263,18 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd_inp;
auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) {
common_chat_msg new_msg{role, content};
auto formatted = common_chat_format_single(*chat_templates.template_default, chat_msgs, new_msg, role == "user", g_params->use_jinja);
chat_msgs.push_back({role, content});
LOG_DBG("formatted: '%s'\n", formatted.c_str());
return formatted;
};
{
auto prompt = (params.conversation_mode && params.enable_chat_template)
// format the system prompt in conversation mode (fallback to default if empty)
? chat_add_and_format(model, chat_msgs, "system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt)
? chat_add_and_format("system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt)
// otherwise use the prompt as is
: params.prompt;
if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
@ -779,7 +781,7 @@ int main(int argc, char ** argv) {
}
if (params.enable_chat_template) {
chat_add_and_format(model, chat_msgs, "assistant", assistant_ss.str());
chat_add_and_format("assistant", assistant_ss.str());
}
is_interacting = true;
LOG("\n");
@ -844,7 +846,7 @@ int main(int argc, char ** argv) {
bool format_chat = params.conversation_mode && params.enable_chat_template;
std::string user_inp = format_chat
? chat_add_and_format(model, chat_msgs, "user", std::move(buffer))
? chat_add_and_format("user", std::move(buffer))
: std::move(buffer);
// TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
const auto line_pfx = common_tokenize(ctx, params.input_prefix, false, true);


@ -1,5 +1,5 @@
set(TARGET llama-run)
add_executable(${TARGET} run.cpp)
add_executable(${TARGET} run.cpp linenoise.cpp/linenoise.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)
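If needed, the updated target can be built on its own with the standard CMake workflow used elsewhere in this repository (a minimal sketch; paths per the default build layout):
```bash
cmake -B build
cmake --build build --config Release --target llama-run
./build/bin/llama-run --help
```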


@ -0,0 +1,26 @@
Copyright (c) 2010-2014, Salvatore Sanfilippo <antirez at gmail dot com>
Copyright (c) 2010-2013, Pieter Noordhuis <pcnoordhuis at gmail dot com>
Copyright (c) 2025, Eric Curtin <ericcurtin17 at gmail dot com>
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

File diff suppressed because it is too large.


@ -0,0 +1,128 @@
/* linenoise.h -- VERSION 1.0
*
* Guerrilla line editing library against the idea that a line editing lib
* needs to be 20,000 lines of C++ code.
*
* See linenoise.cpp for more information.
*
* ------------------------------------------------------------------------
*
* Copyright (c) 2010-2023, Salvatore Sanfilippo <antirez at gmail dot com>
* Copyright (c) 2010-2013, Pieter Noordhuis <pcnoordhuis at gmail dot com>
* Copyright (c) 2025, Eric Curtin <ericcurtin17 at gmail dot com>
*
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
*
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#ifndef __LINENOISE_H
#define __LINENOISE_H
#ifdef __cplusplus
extern "C" {
#endif
#include <stddef.h> /* For size_t. */
#include <stdlib.h>
extern const char *linenoiseEditMore;
/* The linenoiseState structure represents the state during line editing.
* We pass this state to functions implementing specific editing
* functionalities. */
struct linenoiseState {
int in_completion; /* The user pressed TAB and we are now in completion
* mode, so input is handled by completeLine(). */
size_t completion_idx; /* Index of next completion to propose. */
int ifd; /* Terminal stdin file descriptor. */
int ofd; /* Terminal stdout file descriptor. */
char *buf; /* Edited line buffer. */
size_t buflen; /* Edited line buffer size. */
const char *prompt; /* Prompt to display. */
size_t plen; /* Prompt length. */
size_t pos; /* Current cursor position. */
size_t oldpos; /* Previous refresh cursor position. */
size_t len; /* Current edited line length. */
size_t cols; /* Number of columns in terminal. */
size_t oldrows; /* Rows used by last refreshed line (multiline mode) */
int history_index; /* The history index we are currently editing. */
};
struct linenoiseCompletions {
size_t len = 0;
char ** cvec = nullptr;
bool to_free = true;
~linenoiseCompletions() {
if (!to_free) {
return;
}
for (size_t i = 0; i < len; ++i) {
free(cvec[i]);
}
free(cvec);
}
};
/* Non blocking API. */
int linenoiseEditStart(struct linenoiseState *l, int stdin_fd, int stdout_fd, char *buf, size_t buflen, const char *prompt);
const char *linenoiseEditFeed(struct linenoiseState *l);
void linenoiseEditStop(struct linenoiseState *l);
void linenoiseHide(struct linenoiseState *l);
void linenoiseShow(struct linenoiseState *l);
/* Blocking API. */
const char *linenoise(const char *prompt);
void linenoiseFree(void *ptr);
/* Completion API. */
typedef void(linenoiseCompletionCallback)(const char *, linenoiseCompletions *);
typedef const char*(linenoiseHintsCallback)(const char *, int *color, int *bold);
typedef void(linenoiseFreeHintsCallback)(const char *);
void linenoiseSetCompletionCallback(linenoiseCompletionCallback *);
void linenoiseSetHintsCallback(linenoiseHintsCallback *);
void linenoiseSetFreeHintsCallback(linenoiseFreeHintsCallback *);
void linenoiseAddCompletion(linenoiseCompletions *, const char *);
/* History API. */
int linenoiseHistoryAdd(const char *line);
int linenoiseHistorySetMaxLen(int len);
int linenoiseHistorySave(const char *filename);
int linenoiseHistoryLoad(const char *filename);
/* Other utilities. */
void linenoiseClearScreen(void);
void linenoiseSetMultiLine(int ml);
void linenoisePrintKeyCodes(void);
void linenoiseMaskModeEnable(void);
void linenoiseMaskModeDisable(void);
#ifdef __cplusplus
}
#endif
#endif /* __LINENOISE_H */


@ -19,13 +19,16 @@
#include <cstring>
#include <filesystem>
#include <iostream>
#include <list>
#include <sstream>
#include <string>
#include <vector>
#include "common.h"
#include "json.hpp"
#include "linenoise.cpp/linenoise.h"
#include "llama-cpp.h"
#include "chat-template.hpp"
#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) || defined(_WIN32)
[[noreturn]] static void sigint_handler(int) {
@ -103,6 +106,7 @@ class Opt {
llama_model_params model_params;
std::string model_;
std::string user;
bool use_jinja = false;
int context_size = -1, ngl = -1;
float temperature = -1;
bool verbose = false;
@ -154,6 +158,8 @@ class Opt {
} else if (options_parsing &&
(parse_flag(argv, i, "-v", "--verbose") || parse_flag(argv, i, "-v", "--log-verbose"))) {
verbose = true;
} else if (options_parsing && strcmp(argv[i], "--jinja") == 0) {
use_jinja = true;
} else if (options_parsing && parse_flag(argv, i, "-h", "--help")) {
help = true;
return 0;
@ -536,7 +542,7 @@ class LlamaData {
llama_sampler_ptr sampler;
llama_context_ptr context;
std::vector<llama_chat_message> messages;
std::vector<std::string> msg_strs;
std::list<std::string> msg_strs;
std::vector<char> fmtted;
int init(Opt & opt) {
@ -711,13 +717,31 @@ static void add_message(const char * role, const std::string & text, LlamaData &
}
// Function to apply the chat template and resize `formatted` if needed
static int apply_chat_template(LlamaData & llama_data, const bool append) {
static int apply_chat_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) {
if (use_jinja) {
json messages = json::array();
for (const auto & msg : llama_data.messages) {
messages.push_back({
{"role", msg.role},
{"content", msg.content},
});
}
try {
auto result = tmpl.apply(messages, /* tools= */ json(), append);
llama_data.fmtted.resize(result.size() + 1);
memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
return result.size();
} catch (const std::exception & e) {
printe("failed to render the chat template: %s\n", e.what());
return -1;
}
}
int result = llama_chat_apply_template(
llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(), llama_data.messages.size(), append,
tmpl.source().c_str(), llama_data.messages.data(), llama_data.messages.size(), append,
append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
if (append && result > static_cast<int>(llama_data.fmtted.size())) {
llama_data.fmtted.resize(result);
result = llama_chat_apply_template(llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(),
result = llama_chat_apply_template(tmpl.source().c_str(), llama_data.messages.data(),
llama_data.messages.size(), append, llama_data.fmtted.data(),
llama_data.fmtted.size());
}
@ -727,10 +751,12 @@ static int apply_chat_template(LlamaData & llama_data, const bool append) {
// Function to tokenize the prompt
static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt,
std::vector<llama_token> & prompt_tokens) {
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
std::vector<llama_token> & prompt_tokens, const LlamaData & llama_data) {
const bool is_first = llama_get_kv_cache_used_cells(llama_data.context.get()) == 0;
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
prompt_tokens.resize(n_prompt_tokens);
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true,
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first,
true) < 0) {
printe("failed to tokenize the prompt\n");
return -1;
@ -776,7 +802,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
const llama_vocab * vocab = llama_model_get_vocab(llama_data.model.get());
std::vector<llama_token> tokens;
if (tokenize_prompt(vocab, prompt, tokens) < 0) {
if (tokenize_prompt(vocab, prompt, tokens, llama_data) < 0) {
return 1;
}
@ -807,24 +833,44 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
batch = llama_batch_get_one(&new_token_id, 1);
}
printf("\033[0m");
return 0;
}
static int read_user_input(std::string & user) {
std::getline(std::cin, user);
static int read_user_input(std::string & user_input) {
static const char * prompt_prefix = "> ";
#ifdef WIN32
printf(
"\r%*s"
"\r\033[0m%s",
get_terminal_width(), " ", prompt_prefix);
std::getline(std::cin, user_input);
if (std::cin.eof()) {
printf("\n");
return 1;
}
if (user == "/bye") {
#else
std::unique_ptr<char, decltype(&std::free)> line(const_cast<char *>(linenoise(prompt_prefix)), free);
if (!line) {
return 1;
}
if (user.empty()) {
user_input = line.get();
#endif
if (user_input == "/bye") {
return 1;
}
if (user_input.empty()) {
return 2;
}
#ifndef WIN32
linenoiseHistoryAdd(line.get());
#endif
return 0; // Should have data in happy path
}
@ -847,8 +893,8 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt,
}
// Helper function to apply the chat template and handle errors
static int apply_chat_template_with_error_handling(LlamaData & llama_data, const bool append, int & output_length) {
const int new_len = apply_chat_template(llama_data, append);
static int apply_chat_template_with_error_handling(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
const int new_len = apply_chat_template(tmpl, llama_data, append, use_jinja);
if (new_len < 0) {
printe("failed to apply the chat template\n");
return -1;
@ -865,10 +911,6 @@ static int handle_user_input(std::string & user_input, const std::string & user)
return 0; // No need for interactive input
}
printf(
"\r%*s"
"\r\033[32m> \033[0m",
get_terminal_width(), " ");
return read_user_input(user_input); // Returns true if input ends the loop
}
@ -911,9 +953,11 @@ static int get_user_input(std::string & user_input, const std::string & user) {
}
// Main chat loop function
static int chat_loop(LlamaData & llama_data, const std::string & user) {
static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) {
int prev_len = 0;
llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
auto chat_templates = common_chat_templates_from_model(llama_data.model.get(), "");
GGML_ASSERT(chat_templates.template_default);
static const bool stdout_a_terminal = is_stdout_a_terminal();
while (true) {
// Get user input
@ -924,7 +968,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user) {
add_message("user", user.empty() ? user_input : user, llama_data);
int new_len;
if (apply_chat_template_with_error_handling(llama_data, true, new_len) < 0) {
if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, true, new_len, use_jinja) < 0) {
return 1;
}
@ -939,7 +983,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user) {
}
add_message("assistant", response, llama_data);
if (apply_chat_template_with_error_handling(llama_data, false, prev_len) < 0) {
if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, false, prev_len, use_jinja) < 0) {
return 1;
}
}
@ -999,7 +1043,7 @@ int main(int argc, const char ** argv) {
return 1;
}
if (chat_loop(llama_data, opt.user)) {
if (chat_loop(llama_data, opt.user, opt.use_jinja)) {
return 1;
}


@ -126,7 +126,7 @@ The project is under active development, and we are [looking for feedback and co
| `--grammar GRAMMAR` | BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '') |
| `--grammar-file FNAME` | file to read grammar from |
| `-j, --json-schema SCHEMA` | JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object<br/>For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead |
| `--jinja` | Enable experimental Jinja templating engine (needed for tool use) |
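For illustration, a minimal sketch of exercising the new flag (hypothetical model path and tool definition; per the `oaicompat_completion_params_parse` changes later in this diff, the same request without `--jinja` is rejected with "tools param requires --jinja flag"):
```bash
# start the server with Jinja templating enabled (hypothetical model path)
./llama-server -m model.gguf --jinja &

# send an OpenAI-compatible chat completion request that includes a "tools" array
curl http://localhost:8080/v1/chat/completions -d '{
  "messages": [{"role": "user", "content": "What is the weather in Paris?"}],
  "tools": [{"type": "function", "function": {
    "name": "get_weather",
    "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}
  }}]
}'
```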
**Example-specific params**

File diff suppressed because it is too large.


@ -19,6 +19,7 @@
#include "loading.html.hpp"
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstddef>
#include <cinttypes>
@ -32,6 +33,8 @@
using json = nlohmann::ordered_json;
constexpr int HTTP_POLLING_SECONDS = 1;
enum stop_type {
STOP_TYPE_NONE,
STOP_TYPE_EOS,
@ -264,6 +267,11 @@ struct server_task {
params.speculative.n_min = std::max(params.speculative.n_min, 2);
params.speculative.n_max = std::max(params.speculative.n_max, 0);
// Use OpenAI API logprobs only if n_probs wasn't provided
if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){
params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs);
}
if (data.contains("lora")) {
if (data.at("lora").is_array()) {
params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora"));
@ -1602,6 +1610,30 @@ struct server_response {
// should never reach here
}
// same as recv(), but have timeout in seconds
// if timeout is reached, nullptr is returned
server_task_result_ptr recv_with_timeout(const std::unordered_set<int> & id_tasks, int timeout) {
while (true) {
std::unique_lock<std::mutex> lock(mutex_results);
bool cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout), [&]{
return !queue_results.empty();
});
if (!cr_res) {
return nullptr;
}
for (int i = 0; i < (int) queue_results.size(); i++) {
if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
server_task_result_ptr res = std::move(queue_results[i]);
queue_results.erase(queue_results.begin() + i);
return res;
}
}
}
// should never reach here
}
// single-task version of recv()
server_task_result_ptr recv(int id_task) {
std::unordered_set<int> id_tasks = {id_task};
@ -1661,6 +1693,8 @@ struct server_context {
// Necessary similarity of prompt for slot selection
float slot_prompt_similarity = 0.0f;
common_chat_templates chat_templates;
~server_context() {
// Clear any sampling context
for (server_slot & slot : slots) {
@ -1701,13 +1735,16 @@ struct server_context {
add_bos_token = llama_vocab_get_add_bos(vocab);
has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
if (!params_base.speculative.model.empty()) {
if (!params_base.speculative.model.empty() || !params_base.speculative.hf_repo.empty()) {
SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str());
auto params_dft = params_base;
params_dft.devices = params_base.speculative.devices;
params_dft.hf_file = params_base.speculative.hf_file;
params_dft.hf_repo = params_base.speculative.hf_repo;
params_dft.model = params_base.speculative.model;
params_dft.model_url = params_base.speculative.model_url;
params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
params_dft.n_parallel = 1;
@ -1737,14 +1774,39 @@ struct server_context {
cparams_dft.type_v = GGML_TYPE_F16;
}
chat_templates = common_chat_templates_from_model(model, params_base.chat_template);
GGML_ASSERT(chat_templates.template_default.get() != nullptr);
return true;
}
bool validate_builtin_chat_template() const {
bool validate_builtin_chat_template(bool use_jinja) const {
llama_chat_message chat[] = {{"user", "test"}};
const char * tmpl = llama_model_chat_template(model);
const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
return chat_res > 0;
if (use_jinja) {
auto templates = common_chat_templates_from_model(model, "");
GGML_ASSERT(templates.template_default);
try {
templates.template_default->apply({{
{"role", "user"},
{"content", "test"},
}}, json(), true);
if (templates.template_tool_use) {
templates.template_tool_use->apply({{
{"role", "user"},
{"content", "test"},
}}, json(), true);
}
return true;
} catch (const std::exception & e) {
SRV_ERR("failed to apply template: %s\n", e.what());
return false;
}
} else {
const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
return chat_res > 0;
}
}
void init() {
@ -2322,10 +2384,21 @@ struct server_context {
void receive_multi_results(
const std::unordered_set<int> & id_tasks,
const std::function<void(std::vector<server_task_result_ptr>&)> & result_handler,
const std::function<void(json)> & error_handler) {
const std::function<void(json)> & error_handler,
const std::function<bool()> & is_connection_closed) {
std::vector<server_task_result_ptr> results(id_tasks.size());
for (size_t i = 0; i < id_tasks.size(); i++) {
server_task_result_ptr result = queue_results.recv(id_tasks);
for (int i = 0; i < (int)id_tasks.size(); i++) {
server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS);
if (is_connection_closed()) {
cancel_tasks(id_tasks);
return;
}
if (result == nullptr) {
i--; // retry
continue;
}
if (result->is_error()) {
error_handler(result->to_json());
@ -2349,10 +2422,20 @@ struct server_context {
void receive_cmpl_results_stream(
const std::unordered_set<int> & id_tasks,
const std::function<bool(server_task_result_ptr&)> & result_handler,
const std::function<void(json)> & error_handler) {
const std::function<void(json)> & error_handler,
const std::function<bool()> & is_connection_closed) {
size_t n_finished = 0;
while (true) {
server_task_result_ptr result = queue_results.recv(id_tasks);
server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS);
if (is_connection_closed()) {
cancel_tasks(id_tasks);
return;
}
if (result == nullptr) {
continue; // retry
}
if (result->is_error()) {
error_handler(result->to_json());
@ -3609,9 +3692,12 @@ int main(int argc, char ** argv) {
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
{ "total_slots", ctx_server.params_base.n_parallel },
{ "model_path", ctx_server.params_base.model },
{ "chat_template", common_get_builtin_chat_template(ctx_server.model) },
{ "chat_template", ctx_server.chat_templates.template_default->source() },
{ "build_info", build_info },
};
if (ctx_server.params_base.use_jinja && ctx_server.chat_templates.template_tool_use) {
data["chat_template_tool_use"] = ctx_server.chat_templates.template_tool_use->source();
}
res_ok(res, data);
};
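As a quick check of the `/props` fields added above, a running server can be queried directly (hypothetical default host/port; `jq` assumed available):
```bash
# "chat_template" is always present; "chat_template_tool_use" only appears
# when the server was started with --jinja and the model ships a tool-use template
curl -s http://localhost:8080/props | jq -r '.chat_template'
curl -s http://localhost:8080/props | jq -r '.chat_template_tool_use // empty'
```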
@ -3634,6 +3720,7 @@ int main(int argc, char ** argv) {
const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok](
server_task_type type,
json & data,
std::function<bool()> is_connection_closed,
httplib::Response & res,
oaicompat_type oaicompat) {
GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
@ -3695,7 +3782,7 @@ int main(int argc, char ** argv) {
}
}, [&](const json & error_data) {
res_error(res, error_data);
});
}, is_connection_closed);
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
} else {
@ -3705,6 +3792,7 @@ int main(int argc, char ** argv) {
if (res_json.is_array()) {
for (const auto & res : res_json) {
if (!server_sent_event(sink, "data", res)) {
// sending failed (HTTP connection closed), cancel the generation
return false;
}
}
@ -3714,6 +3802,9 @@ int main(int argc, char ** argv) {
}
}, [&](const json & error_data) {
server_sent_event(sink, "error", error_data);
}, [&sink]() {
// note: do not use req.is_connection_closed here because req is already destroyed
return !sink.is_writable();
});
if (oaicompat != OAICOMPAT_TYPE_NONE) {
static const std::string ev_done = "data: [DONE]\n\n";
@ -3736,6 +3827,7 @@ int main(int argc, char ** argv) {
return handle_completions_impl(
SERVER_TASK_TYPE_COMPLETION,
data,
req.is_connection_closed,
res,
OAICOMPAT_TYPE_NONE);
};
@ -3745,6 +3837,7 @@ int main(int argc, char ** argv) {
return handle_completions_impl(
SERVER_TASK_TYPE_COMPLETION,
data,
req.is_connection_closed,
res,
OAICOMPAT_TYPE_COMPLETION);
};
@ -3821,6 +3914,7 @@ int main(int argc, char ** argv) {
return handle_completions_impl(
SERVER_TASK_TYPE_INFILL,
data,
req.is_connection_closed,
res,
OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
};
@ -3831,10 +3925,14 @@ int main(int argc, char ** argv) {
return;
}
json data = oaicompat_chat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
auto body = json::parse(req.body);
const auto & chat_template = body.contains("tools") && ctx_server.chat_templates.template_tool_use ? *ctx_server.chat_templates.template_tool_use : *ctx_server.chat_templates.template_default;
json data = oaicompat_completion_params_parse(body, chat_template, params.use_jinja);
return handle_completions_impl(
SERVER_TASK_TYPE_COMPLETION,
data,
req.is_connection_closed,
res,
OAICOMPAT_TYPE_CHAT);
};
@ -3981,7 +4079,7 @@ int main(int argc, char ** argv) {
}, [&](const json & error_data) {
res_error(res, error_data);
error = true;
});
}, req.is_connection_closed);
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
}
@ -4071,7 +4169,7 @@ int main(int argc, char ** argv) {
}, [&](const json & error_data) {
res_error(res, error_data);
error = true;
});
}, req.is_connection_closed);
}
if (error) {
@ -4240,7 +4338,7 @@ int main(int argc, char ** argv) {
// if a custom chat template is not supplied, we will use the one that comes with the model (if any)
if (params.chat_template.empty()) {
if (!ctx_server.validate_builtin_chat_template()) {
if (!ctx_server.validate_builtin_chat_template(params.use_jinja)) {
LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
params.chat_template = "chatml";
}
@ -4248,8 +4346,8 @@ int main(int argc, char ** argv) {
// print sample chat example to make it clear which template is used
LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
params.chat_template.empty() ? "(built-in)" : params.chat_template.c_str(),
common_chat_format_example(ctx_server.model, params.chat_template).c_str());
ctx_server.chat_templates.template_default->source().c_str(),
common_chat_format_example(*ctx_server.chat_templates.template_default, ctx_server.params_base.use_jinja).c_str());
ctx_server.queue_tasks.on_new_task(std::bind(
&server_context::process_single_task, &ctx_server, std::placeholders::_1));


@ -4,22 +4,26 @@ from utils import *
server = ServerPreset.tinyllama2()
@pytest.fixture(scope="module", autouse=True)
@pytest.fixture(autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
@pytest.mark.parametrize(
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja,chat_template",
[
(None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
(None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", False, None),
(None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", True, None),
(None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None),
]
)
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template):
global server
server.jinja = jinja
server.chat_template = chat_template
server.start()
res = server.make_request("POST", "/chat/completions", data={
"model": model,


@ -1,4 +1,5 @@
import pytest
import requests
import time
from openai import OpenAI
from utils import *
@ -405,3 +406,23 @@ def test_n_probs_post_sampling():
assert "bytes" in prob and type(prob["bytes"]) == list
# because the test model usually outputs tokens with either 100% or 0% probability, we need to check all the top_probs
assert any(prob["prob"] == 1.0 for prob in tok["top_probs"])
def test_cancel_request():
global server
server.n_ctx = 4096
server.n_predict = -1
server.n_slots = 1
server.server_slots = True
server.start()
# send a request that will take a long time, but cancel it before it finishes
try:
server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
}, timeout=0.1)
except requests.exceptions.ReadTimeout:
pass # expected
# make sure the slot is free
time.sleep(1) # wait for HTTP_POLLING_SECONDS
res = server.make_request("GET", "/slots")
assert res.body[0]["is_processing"] == False


@ -26,6 +26,9 @@ from re import RegexFlag
import wget
DEFAULT_HTTP_TIMEOUT = 10 if "LLAMA_SANITIZE" not in os.environ else 30
class ServerResponse:
headers: dict
status_code: int
@ -69,13 +72,14 @@ class ServerProcess:
pooling: str | None = None
draft: int | None = None
api_key: str | None = None
response_format: str | None = None
lora_files: List[str] | None = None
disable_ctx_shift: int | None = False
draft_min: int | None = None
draft_max: int | None = None
no_webui: bool | None = None
jinja: bool | None = None
chat_template: str | None = None
chat_template_file: str | None = None
# session variables
process: subprocess.Popen | None = None
@ -88,7 +92,7 @@ class ServerProcess:
if "PORT" in os.environ:
self.server_port = int(os.environ["PORT"])
def start(self, timeout_seconds: int = 10) -> None:
def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
if "LLAMA_SERVER_BIN_PATH" in os.environ:
server_path = os.environ["LLAMA_SERVER_BIN_PATH"]
elif os.name == "nt":
@ -166,8 +170,12 @@ class ServerProcess:
server_args.extend(["--draft-min", self.draft_min])
if self.no_webui:
server_args.append("--no-webui")
if self.jinja:
server_args.append("--jinja")
if self.chat_template:
server_args.extend(["--chat-template", self.chat_template])
if self.chat_template_file:
server_args.extend(["--chat-template-file", self.chat_template_file])
args = [str(arg) for arg in [server_path, *server_args]]
print(f"bench: starting server with: {' '.join(args)}")
@ -219,17 +227,18 @@ class ServerProcess:
path: str,
data: dict | Any | None = None,
headers: dict | None = None,
timeout: float | None = None,
) -> ServerResponse:
url = f"http://{self.server_host}:{self.server_port}{path}"
parse_body = False
if method == "GET":
response = requests.get(url, headers=headers)
response = requests.get(url, headers=headers, timeout=timeout)
parse_body = True
elif method == "POST":
response = requests.post(url, headers=headers, json=data)
response = requests.post(url, headers=headers, json=data, timeout=timeout)
parse_body = True
elif method == "OPTIONS":
response = requests.options(url, headers=headers)
response = requests.options(url, headers=headers, timeout=timeout)
else:
raise ValueError(f"Unimplemented method: {method}")
result = ServerResponse()


@ -16,6 +16,8 @@
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"
#include "minja.hpp"
#include "chat-template.hpp"
#include <random>
#include <sstream>
@ -349,7 +351,7 @@ static llama_tokens format_infill(
}
// Format given chat. If tmpl is empty, we take the template from model metadata
inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) {
inline std::string format_chat(const common_chat_template & tmpl, const std::vector<json> & messages) {
std::vector<common_chat_msg> chat;
for (size_t i = 0; i < messages.size(); ++i) {
@ -377,7 +379,7 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
chat.push_back({role, content});
}
const auto formatted_chat = common_chat_apply_template(model, tmpl, chat, true);
const auto formatted_chat = common_chat_apply_template(tmpl, chat, true, /* use_jinja= */ false);
LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str());
return formatted_chat;
@ -576,14 +578,23 @@ static json oaicompat_completion_params_parse(const json & body) {
return llama_params;
}
static json oaicompat_chat_completion_params_parse(
const struct llama_model * model,
const json & body, /* openai api json semantics */
const std::string & chat_template) {
static json oaicompat_completion_params_parse(
const json & body, /* openai api json semantics */
const common_chat_template & tmpl,
bool use_jinja)
{
json llama_params;
// Apply chat template to the list of messages
llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"));
auto tools = json_value(body, "tools", json());
auto has_tools = tools.is_array() && !tools.empty();
if (has_tools) {
if (use_jinja) {
LOG_WRN("tools param is not fully supported yet\n");
} else {
throw std::runtime_error("tools param requires --jinja flag");
}
}
// Handle "stop" field
if (body.contains("stop") && body.at("stop").is_string()) {
@ -606,6 +617,13 @@ static json oaicompat_chat_completion_params_parse(
}
}
// Apply chat template to the list of messages
if (use_jinja) {
llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true);
} else {
llama_params["prompt"] = format_chat(tmpl, body.at("messages"));
}
// Handle "n" field
int n_choices = json_value(body, "n", 1);
if (n_choices != 1) {
@ -621,7 +639,7 @@ static json oaicompat_chat_completion_params_parse(
}
// Params supported by OAI but unsupported by llama.cpp
static const std::vector<std::string> unsupported_params { "tools", "tool_choice" };
static const std::vector<std::string> unsupported_params { "tool_choice" };
for (const auto & param : unsupported_params) {
if (body.contains(param)) {
throw std::runtime_error("Unsupported param: " + param);


@ -98,10 +98,12 @@ int main(int argc, char ** argv) {
auto generate = [&](const std::string & prompt) {
std::string response;
const bool is_first = llama_get_kv_cache_used_cells(ctx) == 0;
// tokenize the prompt
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
std::vector<llama_token> prompt_tokens(n_prompt_tokens);
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), llama_get_kv_cache_used_cells(ctx) == 0, true) < 0) {
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first, true) < 0) {
GGML_ABORT("failed to tokenize the prompt\n");
}
@ -161,7 +163,7 @@ int main(int argc, char ** argv) {
break;
}
const char * tmpl = llama_model_chat_template(model);
const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
// add the user input to the message list and format it
messages.push_back({"user", strdup(user.c_str())});


@ -425,6 +425,33 @@ static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) {
prompt_add(prompt, vocab, "<|im_start|>\n", true, true);
}
static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str) {
const std::string& delimiter = "<|text_sep|>";
std::vector<llama_token> result;
size_t start = 0;
size_t end = str.find(delimiter);
//first token is always a newline, as it was not previously added
result.push_back(common_tokenize(vocab, "\n", false, true)[0]);
while (end != std::string::npos) {
std::string current_word = str.substr(start, end - start);
auto tmp = common_tokenize(vocab, current_word, false, true);
result.push_back(tmp[0]);
start = end + delimiter.length();
end = str.find(delimiter, start);
}
// Add the last part
std::string current_word = str.substr(start);
auto tmp = common_tokenize(vocab, current_word, false, true);
if (tmp.size() > 0) {
result.push_back(tmp[0]);
}
return result;
}
int main(int argc, char ** argv) {
common_params params;
@ -494,6 +521,7 @@ int main(int argc, char ** argv) {
const auto t_main_start = ggml_time_us();
std::vector<llama_token> codes;
std::vector<llama_token> guide_tokens;
// process prompt and generate voice codes
{
@ -508,6 +536,9 @@ int main(int argc, char ** argv) {
// convert the input text into the necessary format expected by OuteTTS
{
std::string prompt_clean = process_text(params.prompt);
if (params.vocoder.use_guide_tokens) {
guide_tokens = prepare_guide_tokens(vocab, prompt_clean);
}
LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
@ -717,6 +748,8 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
int n_past = batch.n_tokens;
int n_decode = 0;
bool next_token_uses_guide_token = true;
while (n_decode <= n_predict) {
// prepare the next batch
common_batch_clear(batch);
@ -728,7 +761,17 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
continue;
}
const llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);
llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);
//guide tokens help prevent hallucinations by forcing the TTS to use the correct word
if (!guide_tokens.empty() && next_token_uses_guide_token && !llama_vocab_is_control(vocab, new_token_id) && !llama_vocab_is_eog(vocab, new_token_id)) {
llama_token guide_token = guide_tokens[0];
guide_tokens.erase(guide_tokens.begin());
new_token_id = guide_token; //ensure correct word fragment is used
}
//this is the token id that always precedes a new word
next_token_uses_guide_token = (new_token_id == 198);
common_sampler_accept(smpl[i], new_token_id, true);