From 1bef571f6a23c36a26dabacba631763f9a893b83 Mon Sep 17 00:00:00 2001 From: Radoslav Gerganov Date: Tue, 4 Feb 2025 18:16:20 +0200 Subject: [PATCH 1/5] arg : list RPC devices first when using --list-devices (#11655) List devices in the same order as they appear when evaluating the model and splitting tensors across devices, i.e. RPC devices come first in the list. ref #11435 --- common/arg.cpp | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index f5e9b294f..76b898881 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1465,15 +1465,28 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--list-devices"}, "print list of available devices and exit", [](common_params &) { - printf("Available devices:\n"); + std::vector rpc_devices; + std::vector all_devices; for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { auto * dev = ggml_backend_dev_get(i); if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) { - size_t free, total; - ggml_backend_dev_memory(dev, &free, &total); - printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024); + ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev); + if (ggml_backend_reg_name(reg) == std::string("RPC")) { + rpc_devices.push_back(dev); + } else { + all_devices.push_back(dev); + } } } + // insert RPC devices in front + all_devices.insert(all_devices.begin(), rpc_devices.begin(), rpc_devices.end()); + printf("Available devices:\n"); + for (size_t i = 0; i < all_devices.size(); ++i) { + auto * dev = all_devices[i]; + size_t free, total; + ggml_backend_dev_memory(dev, &free, &total); + printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024); + } exit(0); } )); From 3962fc1a7956dd0afacfbce10fd1a3ffd3ad857e Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Tue, 4 Feb 2025 18:25:42 +0100 Subject: [PATCH 2/5] server : add try..catch to places not covered by set_exception_handler (#11620) * server : add try..catch to places not covered by set_exception_handler * log_server_request: rm try catch, add reminder --- examples/server/server.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index e0acc4705..9cdf2058f 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3353,6 +3353,8 @@ static void log_server_request(const httplib::Request & req, const httplib::Resp return; } + // reminder: this function is not covered by httplib's exception handler; if someone does more complicated stuff, think about wrapping it in try-catch + LOG_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status); LOG_DBG("request: %s\n", req.body.c_str()); @@ -3439,9 +3441,13 @@ int main(int argc, char ** argv) { message = "Unknown Exception"; } - json formatted_error = format_error_response(message, ERROR_TYPE_SERVER); - LOG_WRN("got exception: %s\n", formatted_error.dump().c_str()); - res_error(res, formatted_error); + try { + json formatted_error = format_error_response(message, ERROR_TYPE_SERVER); + LOG_WRN("got exception: %s\n", formatted_error.dump().c_str()); + res_error(res, formatted_error); + } catch (const std::exception & e) { + LOG_ERR("got another exception: %s | while hanlding exception: %s\n", e.what(), message.c_str()); + } }); svr->set_error_handler([&res_error](const httplib::Request &, httplib::Response & res) { From 3ec9fd4b77b6aca03a3c2bf678eae3f9517d6904 Mon Sep 17 00:00:00 2001 From: fxzjshm <11426482+fxzjshm@users.noreply.github.com> Date: Wed, 5 Feb 2025 02:18:38 +0800 Subject: [PATCH 3/5] HIP: force max threads per block to be 1024 (#11621) Some old/vendor forked version of llvm still use 256. Explicitly set it to 1024 to align with upstream llvm. Signed-off-by: fxzjshm --- ggml/src/ggml-hip/CMakeLists.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index eb03e10fa..f4a468363 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -46,6 +46,9 @@ endif() message(STATUS "HIP and hipBLAS found") +# Workaround old compilers +set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} --gpu-max-threads-per-block=1024") + file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh") list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h") From fd08255d0dea6625596c0367ee0a11d195f36762 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 4 Feb 2025 22:21:42 +0100 Subject: [PATCH 4/5] CUDA: non-contiguous (RMS) norm support (#11659) * CUDA: non-contiguous (RMS) norm support --------- Co-authored-by: Georgi Gerganov --- ggml/src/ggml-cuda/ggml-cuda.cu | 4 ++ ggml/src/ggml-cuda/norm.cu | 89 ++++++++++++++++++---------- ggml/src/ggml-metal/ggml-metal.m | 5 +- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 2 + src/llama.cpp | 6 +- tests/test-backend-ops.cpp | 38 ++++++++---- 6 files changed, 97 insertions(+), 47 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index bda10aec1..70a598099 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -38,6 +38,7 @@ #include "ggml-cuda/upscale.cuh" #include "ggml-cuda/wkv6.cuh" #include "ggml-cuda/gla.cuh" +#include "ggml.h" #include #include @@ -3139,6 +3140,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g break; case GGML_OP_NORM: case GGML_OP_RMS_NORM: + return true; case GGML_OP_RMS_NORM_BACK: return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0; break; @@ -3181,7 +3183,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_SUM_ROWS: case GGML_OP_ARGSORT: case GGML_OP_ACC: + return true; case GGML_OP_GROUP_NORM: + return ggml_is_contiguous(op->src[0]); case GGML_OP_UPSCALE: case GGML_OP_PAD: case GGML_OP_ARANGE: diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index d991ec972..f127616ed 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -1,12 +1,20 @@ #include "norm.cuh" +#include template -static __global__ void norm_f32(const float * x, float * dst, const int ncols, const float eps) { - const int row = blockIdx.x*blockDim.y + threadIdx.y; - const int tid = threadIdx.x; +static __global__ void norm_f32( + const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel, + const int64_t stride_sample, const float eps) { + const int nrows = gridDim.x; + const int nchannels = gridDim.y; - x += int64_t(row)*ncols; - dst += int64_t(row)*ncols; + const int row = blockIdx.x; + const int channel = blockIdx.y; + const int sample = blockIdx.z; + const int tid = threadIdx.x; + + x += sample*stride_sample + channel*stride_channel + row*stride_row; + dst += ((sample*nchannels + channel)*nrows + row)*ncols; float2 mean_var = make_float2(0.0f, 0.0f); @@ -97,12 +105,19 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr } template -static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) { - const int row = blockIdx.x*blockDim.y + threadIdx.y; - const int tid = threadIdx.x; +static __global__ void rms_norm_f32( + const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel, + const int64_t stride_sample, const float eps) { + const int nrows = gridDim.x; + const int nchannels = gridDim.y; - x += int64_t(row)*ncols; - dst += int64_t(row)*ncols; + const int row = blockIdx.x; + const int channel = blockIdx.y; + const int sample = blockIdx.z; + const int tid = threadIdx.x; + + x += sample*stride_sample + channel*stride_channel + row*stride_row; + dst += ((sample*nchannels + channel)*nrows + row)*ncols; float tmp = 0.0f; // partial sum for thread in warp @@ -186,13 +201,16 @@ static __global__ void rms_norm_back_f32( } } -static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { +static void norm_f32_cuda( + const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples, + const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) { + const dim3 blocks_num(nrows, nchannels, nsamples); if (ncols < 1024) { const dim3 block_dims(WARP_SIZE, 1, 1); - norm_f32<<>>(x, dst, ncols, eps); + norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } else { const dim3 block_dims(1024, 1, 1); - norm_f32<1024><<>>(x, dst, ncols, eps); + norm_f32<1024><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } } @@ -207,13 +225,16 @@ static void group_norm_f32_cuda( } } -static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { +static void rms_norm_f32_cuda( + const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples, + const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) { + const dim3 blocks_num(nrows, nchannels, nsamples); if (ncols < 1024) { const dim3 block_dims(WARP_SIZE, 1, 1); - rms_norm_f32<<>>(x, dst, ncols, eps); + rms_norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } else { const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024><<>>(x, dst, ncols, eps); + rms_norm_f32<1024><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } } @@ -229,23 +250,26 @@ static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float * void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; - const float * src0_d = (const float *)src0->data; - float * dst_d = (float *)dst->data; + const float * src0_d = (const float *) src0->data; + float * dst_d = (float *) dst->data; cudaStream_t stream = ctx.stream(); - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - const int64_t ne00 = src0->ne[0]; - const int64_t nrows = ggml_nrows(src0); + GGML_TENSOR_UNARY_OP_LOCALS; float eps; memcpy(&eps, dst->op_params, sizeof(float)); GGML_ASSERT(eps >= 0.0f); - norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream); + const size_t ts0 = ggml_type_size(src0->type); + GGML_ASSERT(nb00 == ts0); + const int64_t s01 = nb01 / ts0; + const int64_t s02 = nb02 / ts0; + const int64_t s03 = nb03 / ts0; + + norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream); } void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -254,8 +278,6 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -271,23 +293,26 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; - const float * src0_d = (const float *)src0->data; - float * dst_d = (float *)dst->data; + const float * src0_d = (const float *) src0->data; + float * dst_d = (float *) dst->data; cudaStream_t stream = ctx.stream(); - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - const int64_t ne00 = src0->ne[0]; - const int64_t nrows = ggml_nrows(src0); + GGML_TENSOR_UNARY_OP_LOCALS; float eps; memcpy(&eps, dst->op_params, sizeof(float)); GGML_ASSERT(eps >= 0.0f); - rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream); + const size_t ts0 = ggml_type_size(src0->type); + GGML_ASSERT(nb00 == ts0); + const int64_t s01 = nb01 / ts0; + const int64_t s02 = nb02 / ts0; + const int64_t s03 = nb03 / ts0; + + rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream); } void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 9605914ff..0a264be37 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -1206,10 +1206,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_OP_GROUP_NORM: return has_simdgroup_reduction; case GGML_OP_RMS_NORM: - return has_simdgroup_reduction && (op->ne[0] % 4 == 0); + return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); case GGML_OP_ARGMAX: - case GGML_OP_NORM: return true; + case GGML_OP_NORM: + return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]); case GGML_OP_ROPE: { const int mode = ((const int32_t *) op->op_params)[2]; diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 9ca3959ab..48ac489a6 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -8182,9 +8182,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_VIEW: case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: + return true; case GGML_OP_NORM: case GGML_OP_GROUP_NORM: case GGML_OP_RMS_NORM: + return ggml_is_contiguous(op->src[0]); case GGML_OP_ADD: case GGML_OP_ACC: case GGML_OP_MUL: diff --git a/src/llama.cpp b/src/llama.cpp index 5760017e0..aae3c69b5 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -4610,7 +4610,8 @@ struct llm_build_context { ggml_row_size(kv_pe_compresseed->type, kv_lora_rank)); cb(k_pe, "k_pe", il); - kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm + // TODO: the CUDA backend used to not support non-cont. (RMS) norm, investigate removing ggml_cont + kv_compressed = ggml_cont(ctx0, kv_compressed); kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams, model.layers[il].attn_kv_a_norm, NULL, LLM_NORM_RMS, cb, il); @@ -6464,7 +6465,8 @@ struct llm_build_context { ggml_row_size(kv_pe_compresseed->type, kv_lora_rank)); cb(k_pe, "k_pe", il); - kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm + // TODO: the CUDA backend used to not support non-cont. (RMS) norm, investigate removing ggml_cont + kv_compressed = ggml_cont(ctx0, kv_compressed); kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams, model.layers[il].attn_kv_a_norm, NULL, LLM_NORM_RMS, cb, il); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 4c5c4dd9c..1bfd41254 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1674,21 +1674,28 @@ struct test_silu_back : public test_case { struct test_norm : public test_case { const ggml_type type; const std::array ne; - float eps; + const bool v; // whether a is a non-contiguous view + const float eps; std::string vars() override { - return VARS_TO_STR3(type, ne, eps); + return VARS_TO_STR4(type, ne, v, eps); } test_norm(ggml_type type = GGML_TYPE_F32, std::array ne = {64, 5, 4, 3}, + bool v = false, float eps = 1e-6f) - : type(type), ne(ne), eps(eps) {} + : type(type), ne(ne), v(v), eps(eps) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); ggml_set_name(a, "a"); + if (v) { + a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0); + ggml_set_name(a, "view of a"); + } + ggml_tensor * out = ggml_norm(ctx, a, eps); ggml_set_name(out, "out"); @@ -1700,22 +1707,29 @@ struct test_norm : public test_case { struct test_rms_norm : public test_case { const ggml_type type; const std::array ne; - float eps; + const bool v; // whether a is a non-contiguous view + const float eps; std::string vars() override { - return VARS_TO_STR3(type, ne, eps); + return VARS_TO_STR4(type, ne, v, eps); } test_rms_norm(ggml_type type = GGML_TYPE_F32, std::array ne = {64, 5, 4, 3}, + bool v = false, float eps = 1e-6f) - : type(type), ne(ne), eps(eps) {} + : type(type), ne(ne), v(v), eps(eps) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); ggml_set_param(ctx, a); ggml_set_name(a, "a"); + if (v) { + a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0); + ggml_set_name(a, "view of a"); + } + ggml_tensor * out = ggml_rms_norm(ctx, a, eps); ggml_set_name(out, "out"); @@ -1741,7 +1755,7 @@ struct test_rms_norm : public test_case { struct test_rms_norm_back : public test_case { const ggml_type type; const std::array ne; - float eps; + const float eps; std::string vars() override { return VARS_TO_STR3(type, ne, eps); @@ -2919,7 +2933,7 @@ struct test_group_norm : public test_case { const float eps; std::string vars() override { - return VARS_TO_STR3(type, ne, num_groups); + return VARS_TO_STR4(type, ne, num_groups, eps); } test_group_norm(ggml_type type = GGML_TYPE_F32, @@ -3964,9 +3978,11 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_scale()); test_cases.emplace_back(new test_silu_back()); - for (float eps : {0.0f, 1e-7f, 1e-4f, 1e-1f}) { - test_cases.emplace_back(new test_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps)); - test_cases.emplace_back(new test_rms_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps)); + for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) { + for (bool v : {false, true}) { + test_cases.emplace_back(new test_norm (GGML_TYPE_F32, {64, 5, 4, 3}, v, eps)); + test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, v, eps)); + } test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps)); } From 9f4cc8f8d310b13ab3fc93a9be81b1d1376a7fa6 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Wed, 5 Feb 2025 01:00:12 +0000 Subject: [PATCH 5/5] `sync`: minja (#11641) * `sync`: minja https://github.com/google/minja/commit/182de30cdaee78ba86179122f8047b3bdbab7f7f https://github.com/google/minja/pull/46 https://github.com/google/minja/pull/45 --- common/chat-template.hpp | 211 +++++++++++++++++++++++++++++++++------ common/chat.cpp | 45 +++++++-- common/common.cpp | 9 +- common/minja.hpp | 8 +- examples/run/run.cpp | 10 +- tests/test-chat.cpp | 8 +- 6 files changed, 233 insertions(+), 58 deletions(-) diff --git a/common/chat-template.hpp b/common/chat-template.hpp index 58e119a3b..0e88fb361 100644 --- a/common/chat-template.hpp +++ b/common/chat-template.hpp @@ -33,6 +33,29 @@ struct chat_template_caps { bool requires_typed_content = false; }; +struct chat_template_inputs { + nlohmann::ordered_json messages; + nlohmann::ordered_json tools; + bool add_generation_prompt = true; + nlohmann::ordered_json extra_context; + std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); +}; + +struct chat_template_options { + bool apply_polyfills = true; + bool use_bos_token = true; + bool use_eos_token = true; + bool define_strftime_now = true; + + bool polyfill_tools = true; + bool polyfill_tool_call_examples = true; + bool polyfill_tool_calls = true; + bool polyfill_tool_responses = true; + bool polyfill_system_role = true; + bool polyfill_object_arguments = true; + bool polyfill_typed_content = true; +}; + class chat_template { private: @@ -41,6 +64,7 @@ class chat_template { std::string bos_token_; std::string eos_token_; std::shared_ptr template_root_; + std::string tool_call_example_; std::string try_raw_render( const nlohmann::ordered_json & messages, @@ -49,7 +73,18 @@ class chat_template { const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const { try { - auto prompt = apply(messages, tools, add_generation_prompt, extra_context, /* adjust_inputs= */ false); + chat_template_inputs inputs; + inputs.messages = messages; + inputs.tools = tools; + inputs.add_generation_prompt = add_generation_prompt; + inputs.extra_context = extra_context; + // Use fixed date for tests + inputs.now = std::chrono::system_clock::from_time_t(0); + + chat_template_options opts; + opts.apply_polyfills = false; + + auto prompt = apply(inputs, opts); // fprintf(stderr, "try_raw_render: %s\n", prompt.c_str()); return prompt; } catch (const std::exception & e) { @@ -176,6 +211,58 @@ class chat_template { caps_.supports_tool_responses = contains(out, "Some response!"); caps_.supports_tool_call_id = contains(out, "call_911_"); } + + try { + if (!caps_.supports_tools) { + const json user_msg { + {"role", "user"}, + {"content", "Hey"}, + }; + const json args { + {"arg1", "some_value"}, + }; + const json tool_call_msg { + {"role", "assistant"}, + {"content", nullptr}, + {"tool_calls", json::array({ + { + // TODO: detect if requires numerical id or fixed length == 6 like Nemo + {"id", "call_1___"}, + {"type", "function"}, + {"function", { + {"name", "tool_name"}, + {"arguments", (caps_.requires_object_arguments ? args : json(minja::Value(args).dump(-1, /* to_json= */ true)))}, + }}, + }, + })}, + }; + std::string prefix, full; + { + chat_template_inputs inputs; + inputs.messages = json::array({user_msg}); + inputs.add_generation_prompt = true; + prefix = apply(inputs); + } + { + chat_template_inputs inputs; + inputs.messages = json::array({user_msg, tool_call_msg}); + inputs.add_generation_prompt = false; + full = apply(inputs); + } + + if (full.find(prefix) != 0) { + if (prefix.rfind(eos_token_) == prefix.size() - eos_token_.size()) { + prefix = prefix.substr(0, prefix.size() - eos_token_.size()); + } + } + if (full.find(prefix) != 0) { + fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n"); + } + tool_call_example_ = full.substr(prefix.size()); + } + } catch (const std::exception & e) { + fprintf(stderr, "Failed to generate tool call example: %s\n", e.what()); + } } const std::string & source() const { return source_; } @@ -183,28 +270,72 @@ class chat_template { const std::string & eos_token() const { return eos_token_; } const chat_template_caps & original_caps() const { return caps_; } + // Deprecated, please use the form with chat_template_inputs and chat_template_options std::string apply( const nlohmann::ordered_json & messages, const nlohmann::ordered_json & tools, bool add_generation_prompt, const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(), - bool adjust_inputs = true) const + bool apply_polyfills = true) + { + fprintf(stderr, "[%s] Deprecated!\n", __func__); + chat_template_inputs inputs; + inputs.messages = messages; + inputs.tools = tools; + inputs.add_generation_prompt = add_generation_prompt; + inputs.extra_context = extra_context; + inputs.now = std::chrono::system_clock::now(); + + chat_template_options opts; + opts.apply_polyfills = apply_polyfills; + + return apply(inputs, opts); + } + + std::string apply( + const chat_template_inputs & inputs, + const chat_template_options & opts = chat_template_options()) const { json actual_messages; - auto needs_adjustments = adjust_inputs && (false - || !caps_.supports_system_role - || !caps_.supports_tools - || !caps_.supports_tool_responses - || !caps_.supports_tool_calls - || caps_.requires_object_arguments - || caps_.requires_typed_content + auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); + auto has_tool_calls = false; + auto has_tool_responses = false; + auto has_string_content = false; + for (const auto & message : inputs.messages) { + if (message.contains("tool_calls") && !message["tool_calls"].is_null()) { + has_tool_calls = true; + } + if (message.contains("role") && message["role"] == "tool") { + has_tool_responses = true; + } + if (message.contains("content") && message["content"].is_string()) { + has_string_content = true; + } + } + + auto polyfill_system_role = opts.polyfill_system_role && !caps_.supports_system_role; + auto polyfill_tools = opts.polyfill_tools && has_tools && !caps_.supports_tools; + auto polyfill_tool_call_example = polyfill_tools && opts.polyfill_tool_call_examples; + auto polyfill_tool_calls = opts.polyfill_tool_calls && has_tool_calls && !caps_.supports_tool_calls; + auto polyfill_tool_responses = opts.polyfill_tool_responses && has_tool_responses && !caps_.supports_tool_responses; + auto polyfill_object_arguments = opts.polyfill_object_arguments && has_tool_calls && caps_.requires_object_arguments; + auto polyfill_typed_content = opts.polyfill_typed_content && has_string_content && caps_.requires_typed_content; + + auto needs_polyfills = opts.apply_polyfills && (false + || polyfill_system_role + || polyfill_tools + || polyfill_tool_calls + || polyfill_tool_responses + || polyfill_object_arguments + || polyfill_typed_content ); - if (needs_adjustments) { + + if (needs_polyfills) { actual_messages = json::array(); auto add_message = [&](const json & msg) { - if (caps_.requires_typed_content && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) { + if (polyfill_typed_content && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) { actual_messages.push_back({ {"role", msg.at("role")}, {"content", {{ @@ -227,9 +358,17 @@ class chat_template { pending_system.clear(); } }; - auto needs_tools_in_system = !tools.is_null() && tools.size() > 0 && !caps_.supports_tools; - for (const auto & message_ : needs_tools_in_system ? add_system(messages, "Available tools: " + tools.dump(2)) : messages) { + json adjusted_messages; + if (polyfill_tools) { + adjusted_messages = add_system(inputs.messages, + "You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) + + (!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_)); + } else { + adjusted_messages = inputs.messages; + } + + for (const auto & message_ : adjusted_messages) { auto message = message_; if (!message.contains("role") || !message.contains("content")) { throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump()); @@ -237,7 +376,7 @@ class chat_template { std::string role = message.at("role"); if (message.contains("tool_calls")) { - if (caps_.requires_object_arguments || !caps_.supports_tool_calls) { + if (polyfill_object_arguments || polyfill_tool_calls) { for (auto & tool_call : message.at("tool_calls")) { if (tool_call["type"] == "function") { auto & function = tool_call.at("function"); @@ -252,7 +391,7 @@ class chat_template { } } } - if (!caps_.supports_tool_calls) { + if (polyfill_tool_calls) { auto content = message.at("content"); auto tool_calls = json::array(); for (const auto & tool_call : message.at("tool_calls")) { @@ -279,7 +418,7 @@ class chat_template { message.erase("tool_calls"); } } - if (!caps_.supports_tool_responses && role == "tool") { + if (polyfill_tool_responses && role == "tool") { message["role"] = "user"; auto obj = json { {"tool_response", { @@ -296,7 +435,7 @@ class chat_template { message.erase("name"); } - if (!message["content"].is_null() && !caps_.supports_system_role) { + if (!message["content"].is_null() && polyfill_system_role) { std::string content = message.at("content"); if (role == "system") { if (!pending_system.empty()) pending_system += "\n"; @@ -315,28 +454,36 @@ class chat_template { } add_message(message); } - if (!caps_.supports_system_role) { - flush_sys(); - } + flush_sys(); } else { - actual_messages = messages; + actual_messages = inputs.messages; } auto context = minja::Context::make(json({ {"messages", actual_messages}, - {"add_generation_prompt", add_generation_prompt}, - {"bos_token", bos_token_}, - {"eos_token", eos_token_}, + {"add_generation_prompt", inputs.add_generation_prompt}, })); + context->set("bos_token", opts.use_bos_token ? bos_token_ : ""); + context->set("eos_token", opts.use_eos_token ? eos_token_ : ""); + if (opts.define_strftime_now) { + auto now = inputs.now; + context->set("strftime_now", Value::callable([now](const std::shared_ptr &, minja::ArgumentsValue & args) { + args.expectArgs("strftime_now", {1, 1}, {0, 0}); + auto format = args.args[0].get(); - if (!tools.is_null()) { - auto tools_val = minja::Value(tools); - context->set("tools", tools_val); + auto time = std::chrono::system_clock::to_time_t(now); + auto local_time = *std::localtime(&time); + std::ostringstream ss; + ss << std::put_time(&local_time, format.c_str()); + return ss.str(); + })); } - if (!extra_context.is_null()) { - for (auto & kv : extra_context.items()) { - minja::Value val(kv.value()); - context->set(kv.key(), val); + if (!inputs.tools.is_null()) { + context->set("tools", minja::Value(inputs.tools)); + } + if (!inputs.extra_context.is_null()) { + for (auto & kv : inputs.extra_context.items()) { + context->set(kv.key(), minja::Value(kv.value())); } } @@ -353,7 +500,7 @@ class chat_template { std::string existing_system = messages_with_system.at(0).at("content"); messages_with_system[0] = json { {"role", "system"}, - {"content", existing_system + "\n" + system_prompt}, + {"content", existing_system + "\n\n" + system_prompt}, }; } else { messages_with_system.insert(messages_with_system.begin(), json { diff --git a/common/chat.cpp b/common/chat.cpp index 4a113c0ca..ef1c6fb3d 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -163,6 +163,28 @@ static void foreach_function(const json & tools, const std::function", "<|END_ACTION|>", }; - data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); data.format = COMMON_CHAT_FORMAT_COMMAND_R7B; return data; } @@ -477,7 +499,7 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com builder.add_rule("root", string_join(tool_rules, " | ")); }, grammar_options); data.additional_stops.push_back("<|eom_id|>"); - data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, { + data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, { {"tools_in_user_message", false}, {"builtin_tools", builtin_tools.empty() ? json() : builtin_tools}, }); @@ -542,7 +564,8 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_ }; builder.add_rule("root", "\"<|tool▁calls▁begin|>\" (" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " space"); }, grammar_options); - data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + data.prompt = prompt; data.format = COMMON_CHAT_FORMAT_DEEPSEEK_R1; return data; } @@ -556,10 +579,10 @@ static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input) static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { fprintf(stderr, "%s\n", __func__); common_chat_params data; - data.prompt = tmpl.apply(inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, { + data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, { {"datetime", "Jan 29 2025 13:00:00 GMT"}, {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))}, - }, /* adjust_inputs= */ false); + }); if (!inputs.tools.is_null() && !inputs.tools.empty()) { data.grammar_lazy = inputs.tool_choice != "required"; data.grammar = build_grammar([&](const common_grammar_builder & builder) { @@ -603,7 +626,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}... // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar common_chat_params data; - data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2; if (!inputs.tools.is_null() && !inputs.tools.empty()) { data.grammar_lazy = inputs.tool_choice != "required"; @@ -730,7 +753,7 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con data.grammar_triggers.push_back({"" }; }, grammar_options); - data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO; return data; } @@ -846,7 +869,7 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input) static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { common_chat_params data; - data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; data.grammar_lazy = false; if (!inputs.json_schema.is_null()) { diff --git a/common/common.cpp b/common/common.cpp index edba6fb4b..8661e164a 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1904,10 +1904,6 @@ common_chat_templates common_chat_templates_from_model(const struct llama_model default_template_src = CHATML_TEMPLATE_SRC; } } - std::string token_bos; - std::string token_eos; - // TODO: update logic that adds BOS and EOS tokens to the tokenized prompt, in favour of the template. -#if 0 auto vocab = llama_model_get_vocab(model); const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) { if (token == LLAMA_TOKEN_NULL) { @@ -1920,9 +1916,8 @@ common_chat_templates common_chat_templates_from_model(const struct llama_model return common_token_to_piece(vocab, token, true); } }; - token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token"); - token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token"); -#endif + auto token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token"); + auto token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token"); try { return { has_explicit_template, diff --git a/common/minja.hpp b/common/minja.hpp index e77eb69d5..c304b5c66 100644 --- a/common/minja.hpp +++ b/common/minja.hpp @@ -2194,7 +2194,7 @@ private: } TemplateTokenVector tokenize() { - static std::regex comment_tok(R"(\{#([-~]?)(.*?)([-~]?)#\})"); + static std::regex comment_tok(R"(\{#([-~]?)([\s\S\r\n]*?)([-~]?)#\})"); static std::regex expr_open_regex(R"(\{\{([-~])?)"); static std::regex block_open_regex(R"(^\{%([-~])?[\s\n\r]*)"); static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue)\b)"); @@ -2615,6 +2615,7 @@ inline std::shared_ptr Context::builtins() { })); globals.set("join", simple_function("join", { "items", "d" }, [](const std::shared_ptr &, Value & args) { auto do_join = [](Value & items, const std::string & sep) { + if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump()); std::ostringstream oss; auto first = true; for (size_t i = 0, n = items.size(); i < n; ++i) { @@ -2695,6 +2696,10 @@ inline std::shared_ptr Context::builtins() { return Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) { args.expectArgs(is_select ? "select" : "reject", {2, (std::numeric_limits::max)()}, {0, 0}); auto & items = args.args[0]; + if (items.is_null()) + return Value::array(); + if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump()); + auto filter_fn = context->get(args.args[1]); if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump()); @@ -2772,6 +2777,7 @@ inline std::shared_ptr Context::builtins() { auto & items = args.args[0]; if (items.is_null()) return Value::array(); + if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump()); auto attr_name = args.args[1].get(); bool has_test = false; diff --git a/examples/run/run.cpp b/examples/run/run.cpp index ca9273155..39353ba30 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -848,7 +848,15 @@ static int apply_chat_template(const common_chat_template & tmpl, LlamaData & ll }); } try { - auto result = tmpl.apply(messages, /* tools= */ json(), append); + minja::chat_template_inputs tmpl_inputs; + tmpl_inputs.messages = messages; + tmpl_inputs.add_generation_prompt = append; + + minja::chat_template_options tmpl_opts; + tmpl_opts.use_bos_token = false; + tmpl_opts.use_eos_token = false; + + auto result = tmpl.apply(tmpl_inputs, tmpl_opts); llama_data.fmtted.resize(result.size() + 1); memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1); return result.size(); diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 50bd40738..b78da2cdb 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -18,12 +18,8 @@ using json = nlohmann::ordered_json; static common_chat_msg msg_from_json(const json & message) { - common_chat_msg ret{ - "assistant", - "", - {}, - /* .tool_plan = */ "", - }; + common_chat_msg ret; + ret.role = "assistant"; if (message.contains("content") && !message.at("content").is_null()) { ret.content = message.at("content"); }