diff --git a/CMakeLists.txt b/CMakeLists.txt index 0c29b5d09..40a098d01 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -741,13 +741,6 @@ function(get_flags CCID CCVER) if (CCVER VERSION_GREATER_EQUAL 8.1.0) list(APPEND CXX_FLAGS -Wextra-semi) endif() - elseif (CCID MATCHES "Intel") - if (NOT LLAMA_SYCL) - # enable max optimization level when using Intel compiler - set(C_FLAGS -ipo -O3 -static -fp-model=fast -flto -fno-stack-protector) - set(CXX_FLAGS -ipo -O3 -static -fp-model=fast -flto -fno-stack-protector) - add_link_options(-fuse-ld=lld -static-intel) - endif() endif() set(GF_C_FLAGS ${C_FLAGS} PARENT_SCOPE) @@ -778,10 +771,7 @@ endif() set(CUDA_CXX_FLAGS "") if (LLAMA_CUBLAS) - set(CUDA_FLAGS ${CXX_FLAGS} -use_fast_math) - if (NOT MSVC) - list(APPEND CUDA_FLAGS -Wno-pedantic) - endif() + set(CUDA_FLAGS -use_fast_math) if (LLAMA_ALL_WARNINGS AND NOT MSVC) set(NVCC_CMD ${CMAKE_CUDA_COMPILER} .c) @@ -814,7 +804,11 @@ if (LLAMA_CUBLAS) message("-- CUDA host compiler is ${CUDA_CCID} ${CUDA_CCVER}") get_flags(${CUDA_CCID} ${CUDA_CCVER}) - list(APPEND CUDA_CXX_FLAGS ${GF_CXX_FLAGS}) # This is passed to -Xcompiler later + list(APPEND CUDA_CXX_FLAGS ${CXX_FLAGS} ${GF_CXX_FLAGS}) # This is passed to -Xcompiler later + endif() + + if (NOT MSVC) + list(APPEND CUDA_CXX_FLAGS -Wno-pedantic) endif() endif() diff --git a/Makefile b/Makefile index 901798606..29fd2ca9c 100644 --- a/Makefile +++ b/Makefile @@ -97,9 +97,10 @@ endif # # keep standard at C11 and C++11 -MK_CPPFLAGS = -I. -Icommon -MK_CFLAGS = -std=c11 -fPIC -MK_CXXFLAGS = -std=c++11 -fPIC +MK_CPPFLAGS = -I. -Icommon +MK_CFLAGS = -std=c11 -fPIC +MK_CXXFLAGS = -std=c++11 -fPIC +MK_NVCCFLAGS = -std=c++11 # -Ofast tends to produce faster code, but may not be available for some compilers. ifdef LLAMA_FAST @@ -220,30 +221,6 @@ ifeq ($(LLAMA_FATAL_WARNINGS),1) MK_CXXFLAGS += -Werror endif -ifeq ($(CC_IS_CLANG), 1) - # clang options - MK_CFLAGS += -Wunreachable-code-break -Wunreachable-code-return - MK_HOST_CXXFLAGS += -Wunreachable-code-break -Wunreachable-code-return -Wmissing-prototypes -Wextra-semi - - ifneq '' '$(and $(CC_IS_LLVM_CLANG),$(filter 1,$(shell expr $(CC_VER) \>= 030800)))' - MK_CFLAGS += -Wdouble-promotion - endif - ifneq '' '$(and $(CC_IS_APPLE_CLANG),$(filter 1,$(shell expr $(CC_VER) \>= 070300)))' - MK_CFLAGS += -Wdouble-promotion - endif -else - # gcc options - MK_CFLAGS += -Wdouble-promotion - MK_HOST_CXXFLAGS += -Wno-array-bounds - - ifeq ($(shell expr $(CC_VER) \>= 070100), 1) - MK_HOST_CXXFLAGS += -Wno-format-truncation - endif - ifeq ($(shell expr $(CC_VER) \>= 080100), 1) - MK_HOST_CXXFLAGS += -Wextra-semi - endif -endif - # this version of Apple ld64 is buggy ifneq '' '$(findstring dyld-1015.7,$(shell $(CC) $(LDFLAGS) -Wl,-v 2>&1))' MK_CPPFLAGS += -DHAVE_BUGGY_APPLE_LINKER @@ -468,7 +445,7 @@ ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ifdef JETSON_EOL_MODULE_DETECT $(NVCC) -I. 
-Icommon -D_XOPEN_SOURCE=600 -D_GNU_SOURCE -DNDEBUG -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I/usr/local/cuda/targets/aarch64-linux/include -std=c++11 -O3 $(NVCCFLAGS) -Xcompiler "$(CUDA_CXXFLAGS)" -c $< -o $@ else - $(NVCC) $(BASE_CXXFLAGS) $(NVCCFLAGS) -Wno-pedantic -Xcompiler "$(CUDA_CXXFLAGS)" -c $< -o $@ + $(NVCC) $(NVCCFLAGS) -Xcompiler "$(CUDA_CXXFLAGS)" -c $< -o $@ endif # JETSON_EOL_MODULE_DETECT endif # LLAMA_CUBLAS @@ -579,7 +556,7 @@ override LDFLAGS := $(MK_LDFLAGS) $(LDFLAGS) ifdef LLAMA_CUBLAS GF_CC := $(NVCC) $(NVCCFLAGS) 2>/dev/null .c -Xcompiler include scripts/get-flags.mk -CUDA_CXXFLAGS := $(GF_CXXFLAGS) +CUDA_CXXFLAGS := $(BASE_CXXFLAGS) $(GF_CXXFLAGS) -Wno-pedantic endif # @@ -891,3 +868,7 @@ tests/test-model-load-cancel: tests/test-model-load-cancel.cpp ggml.o llama.o te tests/test-autorelease: tests/test-autorelease.cpp ggml.o llama.o tests/get-model.cpp $(COMMON_DEPS) $(OBJS) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + +tests/test-chat-template: tests/test-chat-template.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) diff --git a/README-sycl.md b/README-sycl.md index e3a8e726e..dd5bf9dea 100644 --- a/README-sycl.md +++ b/README-sycl.md @@ -272,7 +272,7 @@ Please install [Visual Studio](https://visualstudio.microsoft.com/) which impact a. Please follow the procedure in [Get the Intel® oneAPI Base Toolkit ](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html). -Recommend to install to default folder: **/opt/intel/oneapi**. +Recommend to install to default folder: **C:\Program Files (x86)\Intel\oneAPI**. Following guide uses the default folder as example. If you use other folder, please modify the following guide info with your folder. diff --git a/README.md b/README.md index 0c4ee5a27..70866e249 100644 --- a/README.md +++ b/README.md @@ -61,7 +61,7 @@ variety of hardware - locally and in the cloud. - Plain C/C++ implementation without any dependencies - Apple silicon is a first-class citizen - optimized via ARM NEON, Accelerate and Metal frameworks - AVX, AVX2 and AVX512 support for x86 architectures -- 2-bit, 3-bit, 4-bit, 5-bit, 6-bit, and 8-bit integer quantization for faster inference and reduced memory use +- 1.5-bit, 2-bit, 3-bit, 4-bit, 5-bit, 6-bit, and 8-bit integer quantization for faster inference and reduced memory use - Custom CUDA kernels for running LLMs on NVIDIA GPUs (support for AMD GPUs via HIP) - Vulkan, SYCL, and (partial) OpenCL backend support - CPU+GPU hybrid inference to partially accelerate models larger than the total VRAM capacity @@ -768,7 +768,7 @@ The time per token is measured on a MacBook M1 Pro 32GB RAM using 4 and 8 thread #### How to run -1. Download/extract: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research +1. Download/extract: https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip 2. Run `./perplexity -m models/7B/ggml-model-q4_0.gguf -f wiki.test.raw` 3. 
Output: ``` diff --git a/ci/run.sh b/ci/run.sh index b94658c96..f3a29c2e9 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -219,7 +219,7 @@ function gg_run_open_llama_3b_v2 { gg_wget models-mnt/open-llama/3B-v2/ https://huggingface.co/openlm-research/open_llama_3b_v2/resolve/main/pytorch_model.bin gg_wget models-mnt/open-llama/3B-v2/ https://huggingface.co/openlm-research/open_llama_3b_v2/raw/main/generation_config.json - gg_wget models-mnt/wikitext/ https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip + gg_wget models-mnt/wikitext/ https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip unzip -o models-mnt/wikitext/wikitext-2-raw-v1.zip -d models-mnt/wikitext/ head -n 60 models-mnt/wikitext/wikitext-2-raw/wiki.test.raw > models-mnt/wikitext/wikitext-2-raw/wiki.test-60.raw @@ -401,7 +401,7 @@ function gg_run_open_llama_7b_v2 { gg_wget models-mnt/open-llama/7B-v2/ https://huggingface.co/openlm-research/open_llama_7b_v2/resolve/main/pytorch_model-00002-of-00002.bin gg_wget models-mnt/open-llama/7B-v2/ https://huggingface.co/openlm-research/open_llama_7b_v2/raw/main/generation_config.json - gg_wget models-mnt/wikitext/ https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip + gg_wget models-mnt/wikitext/ https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip unzip -o models-mnt/wikitext/wikitext-2-raw-v1.zip -d models-mnt/wikitext/ path_models="../models-mnt/open-llama/7B-v2" diff --git a/common/common.cpp b/common/common.cpp index 489462b5a..10ef11829 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1704,6 +1704,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l } fprintf(stream, "lora_base: %s\n", params.lora_base.c_str()); fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu); + fprintf(stream, "min_keep: %d # default: 0 (disabled)\n", sparams.min_keep); fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat); fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau); fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta); diff --git a/common/sampling.cpp b/common/sampling.cpp index 611c327bb..de4331a11 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -248,7 +248,10 @@ static llama_token llama_sampling_sample_impl( llama_sample_temp(ctx_main, &cur_p, temp); id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu); } else { - sampler_queue(ctx_main, params, cur_p, 1); + // temperature sampling + size_t min_keep = std::max(1, params.min_keep); + + sampler_queue(ctx_main, params, cur_p, min_keep); id = llama_sample_token(ctx_main, &cur_p); diff --git a/common/sampling.h b/common/sampling.h index e1279a894..95d875394 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -22,6 +22,7 @@ enum class llama_sampler_type : char { typedef struct llama_sampling_params { int32_t n_prev = 64; // number of previous tokens to remember int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. 
+ int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens int32_t top_k = 40; // <= 0 to use vocab size float top_p = 0.95f; // 1.0 = disabled float min_p = 0.05f; // 0.0 = disabled diff --git a/examples/baby-llama/baby-llama.cpp b/examples/baby-llama/baby-llama.cpp index e7d2ad592..65bb238a0 100644 --- a/examples/baby-llama/baby-llama.cpp +++ b/examples/baby-llama/baby-llama.cpp @@ -1533,16 +1533,17 @@ int main(int argc, char ** argv) { int n_past = 0; - ggml_cgraph gf = {}; + struct ggml_cgraph * gf = NULL; + gf = ggml_new_graph_custom(ctx0, LLAMA_TRAIN_MAX_NODES, true); get_example_targets_batch(ctx0, 64*ex+0, tokens_input, targets); - struct ggml_tensor * logits = forward_batch(&model, &kv_self, ctx0, &gf, tokens_input, n_tokens, n_past, n_batch); + struct ggml_tensor * logits = forward_batch(&model, &kv_self, ctx0, gf, tokens_input, n_tokens, n_past, n_batch); // struct ggml_tensor * e = cross_entropy_loss(ctx0, targets, logits); struct ggml_tensor * e = square_error_loss(ctx0, targets, logits); - ggml_build_forward_expand(&gf, e); - ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1); + ggml_build_forward_expand(gf, e); + ggml_graph_compute_helper(work_buffer, gf, /*n_threads*/ 1); float error_before_opt = ggml_get_f32_1d(e, 0); @@ -1552,8 +1553,8 @@ int main(int argc, char ** argv) { opt_params_lbfgs.lbfgs.n_iter = 16; ggml_opt(ctx0, opt_params_lbfgs, e); // - ggml_build_forward_expand(&gf, e); - ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1); + ggml_build_forward_expand(gf, e); + ggml_graph_compute_helper(work_buffer, gf, /*n_threads*/ 1); float error_after_opt = ggml_get_f32_1d(e, 0); @@ -1600,13 +1601,14 @@ int main(int argc, char ** argv) { }; struct ggml_context * ctx0 = ggml_init(params); - ggml_cgraph gf = {}; + struct ggml_cgraph * gf = NULL; + gf = ggml_new_graph_custom(ctx0, LLAMA_TRAIN_MAX_NODES, true); int n_past = 0; - struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, &gf, tokens_input, sample_ctx, n_past); + struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, gf, tokens_input, sample_ctx, n_past); - ggml_build_forward_expand(&gf, logits); - ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1); + ggml_build_forward_expand(gf, logits); + ggml_graph_compute_helper(work_buffer, gf, /*n_threads*/ 1); struct ggml_tensor * best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sample_ctx); struct ggml_tensor * probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, sample_ctx); diff --git a/examples/llava/llava-surgery.py b/examples/llava/llava-surgery.py index 8b7a62fba..4f2da3bee 100644 --- a/examples/llava/llava-surgery.py +++ b/examples/llava/llava-surgery.py @@ -25,9 +25,6 @@ if len(clip_tensors) > 0: clip = {name.replace("vision_tower.vision_tower.", ""): checkpoint[name].float() for name in clip_tensors} torch.save(clip, f"{args.model}/llava.clip") - # remove these tensors - for name in clip_tensors: - del checkpoint[name] # added tokens should be removed to be able to convert Mistral models if os.path.exists(f"{args.model}/added_tokens.json"): @@ -35,7 +32,6 @@ if len(clip_tensors) > 0: f.write("{}\n") - torch.save(checkpoint, path) print("Done!") print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.") diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 74dcc642a..9ec989389 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -309,7 +309,7 @@ static void process_logits(int n_vocab, 
const float * logits, const int * tokens } static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & params) { - // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research + // Download: https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw` // Output: `perplexity: 13.5106 [114/114]` // BOS tokens will be added for each chunk before eval @@ -447,7 +447,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par return perplexity_v2(ctx, params); } - // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research + // Download: https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw` // Output: `perplexity: 13.5106 [114/114]` // BOS tokens will be added for each chunk before eval diff --git a/examples/server/README.md b/examples/server/README.md index ac5133d24..809e2d37c 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -199,6 +199,8 @@ node index.js `n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token (default: 0) + `min_keep`: If greater than 0, force samplers to return N possible tokens at minimum (default: 0) + `image_data`: An array of objects to hold base64-encoded image `data` and its `id`s to be reference in `prompt`. You can determine the place of the image in the prompt as in the following: `USER:[img-12]Describe the image in detail.\nASSISTANT:`. In this case, `[img-12]` will be replaced by the embeddings of the image with id `12` in the following `image_data` array: `{..., "image_data": [{"data": "", "id": 12}]}`. Use `image_data` only with multimodal models, e.g., LLaVA. `slot_id`: Assign the completion task to an specific slot. If is -1 the task will be assigned to a Idle slot (default: -1) diff --git a/examples/server/public/index.html b/examples/server/public/index.html index b059c75f2..84038ddce 100644 --- a/examples/server/public/index.html +++ b/examples/server/public/index.html @@ -234,6 +234,7 @@ mirostat_eta: 0.1, // learning rate grammar: '', n_probs: 0, // no completion_probabilities, + min_keep: 0, // min probs from each sampler, image_data: [], cache_prompt: true, api_key: '' @@ -791,6 +792,9 @@
${IntField({ label: "Show Probabilities", max: 10, min: 0, name: "n_probs", value: params.value.n_probs })}
+
+ ${IntField({ label: "Min Probabilities from each Sampler", max: 10, min: 0, name: "min_keep", value: params.value.min_keep })}
+
diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 7e0e571b7..53620f45e 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -624,6 +624,7 @@ struct llama_server_context slot->params.seed = json_value(data, "seed", default_params.seed); slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar); slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); + slot->sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep); if (slot->n_predict > 0 && slot->params.n_predict > slot->n_predict) { // Might be better to reject the request with a 400 ? @@ -1169,6 +1170,7 @@ struct llama_server_context {"stream", slot.params.stream}, {"logit_bias", slot.sparams.logit_bias}, {"n_probs", slot.sparams.n_probs}, + {"min_keep", slot.sparams.min_keep}, {"grammar", slot.sparams.grammar}, {"samplers", samplers_sequence} }; diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 933ebbc4e..eef213509 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6205,7 +6205,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f const int ix = rowx*ncols + col; const int iy = rowy*ncols + col; - const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + slope*pos[col]; + const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f); vals[col] = val; max_val = max(max_val, val); @@ -9170,17 +9170,17 @@ static void ggml_cuda_op_soft_max( memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); // positions tensor - float * src2_dd = dst_dd; // default to avoid null checks in the kernel + float * src2_dd = nullptr; cuda_pool_alloc src2_f; ggml_tensor * src2 = dst->src[2]; const bool use_src2 = src2 != nullptr; if (use_src2) { - const bool src2_on_device = use_src2 && src2->backend == GGML_BACKEND_GPU; - ggml_tensor_extra_gpu * src2_extra = use_src2 ? (ggml_tensor_extra_gpu *) src2->extra : nullptr; + const bool src2_on_device = src2->backend == GGML_BACKEND_GPU; if (src2_on_device) { + ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) src2->extra; src2_dd = (float *) src2_extra->data_device[g_main_device]; } else { src2_dd = src2_f.alloc(ggml_nelements(src2)); diff --git a/ggml-metal.metal b/ggml-metal.metal index a00962111..f0d77d446 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -392,7 +392,7 @@ kernel void kernel_soft_max( float lmax = -INFINITY; for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]); + lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)); } // find the max value in the block @@ -417,7 +417,7 @@ kernel void kernel_soft_max( // parallel sum float lsum = 0.0f; for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]) - max_val); + const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val); lsum += exp_psrc0; pdst[i00] = exp_psrc0; } @@ -495,7 +495,7 @@ kernel void kernel_soft_max_4( float4 lmax4 = -INFINITY; for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]); + lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? 
slope*ppos[i00] : 0.0f)); } const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); @@ -521,7 +521,7 @@ kernel void kernel_soft_max_4( // parallel sum float4 lsum4 = 0.0f; for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]) - max_val); + const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val); lsum4 += exp_psrc4; pdst4[i00] = exp_psrc4; } @@ -4027,7 +4027,10 @@ void kernel_mul_mv_iq2_xxs_f32_impl( y4 += 32 * 32; } #else - // TODO + (void) x; + (void) y; + (void) yl; + (void) nb32; #endif for (int row = 0; row < N_DST; ++row) { @@ -4170,7 +4173,10 @@ void kernel_mul_mv_iq2_xs_f32_impl( y4 += 32 * 32; } #else - // TODO + (void) x; + (void) y; + (void) yl; + (void) nb32; #endif for (int row = 0; row < N_DST; ++row) { @@ -4306,7 +4312,10 @@ void kernel_mul_mv_iq3_xxs_f32_impl( y4 += 32 * 32; } #else - // TODO + (void) x; + (void) y; + (void) yl; + (void) nb32; #endif for (int row = 0; row < N_DST; ++row) { @@ -4424,7 +4433,10 @@ void kernel_mul_mv_iq1_s_f32_impl( y4 += 16 * 32; } #else - // TODO + (void) x; + (void) y; + (void) yl; + (void) nb32; #endif for (int row = 0; row < N_DST; ++row) { @@ -4659,6 +4671,8 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg const float dl = d * sc[0]; const float ml = min * sc[1]; #else + (void) get_scale_min_k4_just2; + q = q + 16 * (il&1); device const uint8_t * s = xb->scales; device const half2 * dh = (device const half2 *)xb->d; diff --git a/ggml-quants.c b/ggml-quants.c index 48f5294e1..3319d2ccf 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -1837,9 +1837,9 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri float sigma2 = sumx2/QK_K; for (int j = 0; j < QK_K/16; ++j) { const float * restrict qw = quant_weights + QK_K * i + 16*j; - for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]); - for (int l = 0; l < 16; ++l) sw[j] += weight[l]; - scales[j] = make_qkx3_quants(16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false); + for (int l = 0; l < QK_K/16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]); + for (int l = 0; l < QK_K/16; ++l) sw[j] += weight[l]; + scales[j] = make_qkx3_quants(QK_K/16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false); } float dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw); @@ -3855,7 +3855,7 @@ static inline __m128i get_scale_shuffle(int i) { } #endif -void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) { +void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; @@ -3866,8 +3866,8 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r assert(nrc == 1); #endif UNUSED(nrc); - UNUSED(bbx); - UNUSED(bby); + UNUSED(bx); + UNUSED(by); UNUSED(bs); const block_q4_0 * restrict x = vx; @@ -4024,15 +4024,15 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs); - __m128i bx = _mm_and_si128(lowMask, tmp); - __m128i by = _mm_loadu_si128((const __m128i *)y[i].qs); - bx = _mm_sub_epi8(bx, off); - const __m128i i32_0 = mul_sum_i8_pairs(bx, by); + 
__m128i bx_0 = _mm_and_si128(lowMask, tmp); + __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs); + bx_0 = _mm_sub_epi8(bx_0, off); + const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); - bx = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4)); - by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16)); - bx = _mm_sub_epi8(bx, off); - const __m128i i32_1 = mul_sum_i8_pairs(bx, by); + bx_0 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4)); + by_0 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16)); + bx_0 = _mm_sub_epi8(bx_0, off); + const __m128i i32_1 = mul_sum_i8_pairs(bx_0, by_0); // Convert int32_t to float __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1)); @@ -4222,7 +4222,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r #endif } -void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) { +void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { const int qk = QK8_1; const int nb = n / qk; @@ -4233,8 +4233,8 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r assert(nrc == 1); #endif UNUSED(nrc); - UNUSED(bbx); - UNUSED(bby); + UNUSED(bx); + UNUSED(by); UNUSED(bs); const block_q4_1 * restrict x = vx; @@ -4440,7 +4440,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r #endif } -void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) { +void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; @@ -4448,8 +4448,8 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r assert(qk == QK5_0); assert(nrc == 1); UNUSED(nrc); - UNUSED(bbx); - UNUSED(bby); + UNUSED(bx); + UNUSED(by); UNUSED(bs); const block_q5_0 * restrict x = vx; @@ -4618,21 +4618,21 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r /* Compute combined scale for the block */ const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); - __m256i bx = bytes_from_nibbles_32(x[i].qs); + __m256i bx_0 = bytes_from_nibbles_32(x[i].qs); const __m256i bxhi = bytes_from_bits_32(x[i].qh); __m128i bxhil = _mm256_castsi256_si128(bxhi); __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); bxhil = _mm_andnot_si128(bxhil, mask); bxhih = _mm_andnot_si128(bxhih, mask); - __m128i bxl = _mm256_castsi256_si128(bx); - __m128i bxh = _mm256_extractf128_si256(bx, 1); + __m128i bxl = _mm256_castsi256_si128(bx_0); + __m128i bxh = _mm256_extractf128_si256(bx_0, 1); bxl = _mm_or_si128(bxl, bxhil); bxh = _mm_or_si128(bxh, bxhih); - bx = MM256_SET_M128I(bxh, bxl); + bx_0 = MM256_SET_M128I(bxh, bxl); - const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[i].qs); - const __m256 q = mul_sum_i8_pairs_float(bx, by); + const __m256 q = mul_sum_i8_pairs_float(bx_0, by_0); /* Multiply q with scale and accumulate */ acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc); @@ -4731,7 +4731,7 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r #endif } -void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * 
restrict vy, size_t bby, int nrc) { +void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { const int qk = QK8_1; const int nb = n / qk; @@ -4739,8 +4739,8 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r assert(qk == QK5_1); assert(nrc == 1); UNUSED(nrc); - UNUSED(bbx); - UNUSED(bby); + UNUSED(bx); + UNUSED(by); UNUSED(bs); const block_q5_1 * restrict x = vx; @@ -4925,22 +4925,22 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s; - __m256i bx = bytes_from_nibbles_32(x[i].qs); + __m256i bx_0 = bytes_from_nibbles_32(x[i].qs); const __m256i bxhi = bytes_from_bits_32(x[i].qh); __m128i bxhil = _mm256_castsi256_si128(bxhi); __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); bxhil = _mm_and_si128(bxhil, mask); bxhih = _mm_and_si128(bxhih, mask); - __m128i bxl = _mm256_castsi256_si128(bx); - __m128i bxh = _mm256_extractf128_si256(bx, 1); + __m128i bxl = _mm256_castsi256_si128(bx_0); + __m128i bxh = _mm256_extractf128_si256(bx_0, 1); bxl = _mm_or_si128(bxl, bxhil); bxh = _mm_or_si128(bxh, bxhih); - bx = MM256_SET_M128I(bxh, bxl); + bx_0 = MM256_SET_M128I(bxh, bxl); const __m256 dy = _mm256_set1_ps(y[i].d); - const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[i].qs); - const __m256 q = mul_sum_us8_pairs_float(bx, by); + const __m256 q = mul_sum_us8_pairs_float(bx_0, by_0); acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc); } @@ -5035,7 +5035,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r #endif } -void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) { +void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; @@ -5046,8 +5046,8 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r assert(nrc == 1); #endif UNUSED(nrc); - UNUSED(bbx); - UNUSED(bby); + UNUSED(bx); + UNUSED(by); UNUSED(bs); const block_q8_0 * restrict x = vx; @@ -5169,10 +5169,10 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r for (int i = 0; i < nb; i++) { // load elements - vint8m1_t bx = __riscv_vle8_v_i8m1(x[i].qs, vl); - vint8m1_t by = __riscv_vle8_v_i8m1(y[i].qs, vl); + vint8m1_t bx_0 = __riscv_vle8_v_i8m1(x[i].qs, vl); + vint8m1_t by_0 = __riscv_vle8_v_i8m1(y[i].qs, vl); - vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx, by, vl); + vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx_0, by_0, vl); vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl); vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl); diff --git a/ggml.c b/ggml.c index 8224652a9..4ee2c5e11 100644 --- a/ggml.c +++ b/ggml.c @@ -23,6 +23,9 @@ #include #include #include +#if defined(__gnu_linux__) +#include +#endif #ifdef GGML_USE_METAL #include @@ -1971,7 +1974,7 @@ struct ggml_numa_nodes { uint32_t n_nodes; uint32_t total_cpus; // hardware threads on system uint32_t current_node; // node on which main process is execting -#ifdef __linux__ +#if defined(__gnu_linux__) cpu_set_t cpuset; // cpuset from numactl #else uint32_t cpuset; // no NUMA support outside of Linux at this time. 
Use a portable datatype @@ -2009,7 +2012,7 @@ inline static void ggml_critical_section_end(void) { atomic_fetch_sub(&g_state_barrier, 1); } -#ifdef __linux__ +#if defined(__gnu_linux__) static cpu_set_t ggml_get_numa_affinity(void) { cpu_set_t cpuset; pthread_t thread; @@ -2031,7 +2034,7 @@ void ggml_numa_init(enum ggml_numa_strategy numa_flag) { return; } -#ifdef __linux__ +#if defined(__gnu_linux__) struct stat st; char path[256]; int rv; @@ -2063,7 +2066,13 @@ void ggml_numa_init(enum ggml_numa_strategy numa_flag) { // figure out which node we're on uint current_cpu; - int getcpu_ret = getcpu(&current_cpu, &g_state.numa.current_node); + int getcpu_ret = 0; +#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 28) + getcpu_ret = getcpu(&current_cpu, &g_state.numa.current_node); +#else + // old glibc doesn't have a wrapper for this call. Fall back on direct syscall + getcpu_ret = syscall(SYS_getcpu,&current_cpu,&g_state.numa.current_node); +#endif if (g_state.numa.n_nodes < 1 || g_state.numa.total_cpus < 1 || getcpu_ret != 0) { g_state.numa.n_nodes = 0; @@ -16734,7 +16743,7 @@ typedef pthread_t ggml_thread_t; #endif // Android's libc implementation "bionic" does not support setting affinity -#if defined(__linux__) && !defined(__BIONIC__) +#if defined(__gnu_linux__) static void set_numa_thread_affinity(int thread_n) { if (!ggml_is_numa()) { return; diff --git a/llama.cpp b/llama.cpp index 319eba1f3..9d615237c 100644 --- a/llama.cpp +++ b/llama.cpp @@ -12508,6 +12508,123 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token return 0; } +// trim whitespace from the beginning and end of a string +static std::string trim(const std::string & str) { + size_t start = 0; + size_t end = str.size(); + while (start < end && isspace(str[start])) { + start += 1; + } + while (end > start && isspace(str[end - 1])) { + end -= 1; + } + return str.substr(start, end - start); } + +// Simple version of "llama_apply_chat_template" that only works with strings +// This function uses heuristic checks to determine commonly used template. It is not a jinja parser. +static int32_t llama_chat_apply_template_internal( + const std::string & tmpl, + const std::vector<const llama_chat_message *> & chat, + std::string & dest, bool add_ass) { + // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527 + std::stringstream ss; + if (tmpl.find("<|im_start|>") != std::string::npos) { + // chatml template + for (auto message : chat) { + ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n"; + } + if (add_ass) { + ss << "<|im_start|>assistant\n"; + } + } else if (tmpl.find("[INST]") != std::string::npos) { + // llama2 template and its variants + // [variant] support system message + bool support_system_message = tmpl.find("<<SYS>>") != std::string::npos; + // [variant] space before + after response + bool space_around_response = tmpl.find("' ' + eos_token") != std::string::npos; + // [variant] add BOS inside history + bool add_bos_inside_history = tmpl.find("bos_token + '[INST]") != std::string::npos; + // [variant] trim spaces from the input message + bool strip_message = tmpl.find("content.strip()") != std::string::npos; + // construct the prompt + bool is_inside_turn = true; // skip BOS at the beginning + ss << "[INST] "; + for (auto message : chat) { + std::string content = strip_message ? trim(message->content) : message->content; + std::string role(message->role); + if (!is_inside_turn) { + is_inside_turn = true; + ss << (add_bos_inside_history ? "<s>[INST] " : "[INST] "); + } + if (role == "system") { + if (support_system_message) { + ss << "<<SYS>>\n" << content << "\n<</SYS>>\n\n"; + } else { + // if the model does not support system message, we still include it in the first message, but without <<SYS>> + ss << content << "\n"; + } + } else if (role == "user") { + ss << content << " [/INST]"; + } else { + ss << (space_around_response ? " " : "") << content << (space_around_response ? " " : "") << "</s>"; + is_inside_turn = false; + } + } + // llama2 templates seem to not care about "add_generation_prompt" + } else if (tmpl.find("<|user|>") != std::string::npos) { + // zephyr template + for (auto message : chat) { + ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n"; + } + if (add_ass) { + ss << "<|assistant|>\n"; + } + } else { + // template not supported + return -1; + } + dest = ss.str(); + return dest.size(); +} + +LLAMA_API int32_t llama_chat_apply_template( + const struct llama_model * model, + const char * tmpl, + const struct llama_chat_message * chat, + size_t n_msg, + bool add_ass, + char * buf, + int32_t length) { + std::string curr_tmpl(tmpl == nullptr ? "" : tmpl); + if (tmpl == nullptr) { + GGML_ASSERT(model != nullptr); + // load template from model + std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes + std::string template_key = "tokenizer.chat_template"; + int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size()); + if (res < 0) { + // worst case: there is no information about template, we will use chatml by default + curr_tmpl = "<|im_start|>"; // see llama_chat_apply_template_internal + } else { + curr_tmpl = std::string(model_template.data(), model_template.size()); + } + } + // format the chat to string + std::vector<const llama_chat_message *> chat_vec; + chat_vec.resize(n_msg); + for (size_t i = 0; i < n_msg; i++) { + chat_vec[i] = &chat[i]; + } + std::string formatted_chat; + int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass); + if (res < 0) { + return res; + } + strncpy(buf, formatted_chat.c_str(), length); + return res; +} + struct llama_timings llama_get_timings(struct llama_context * ctx) { struct llama_timings result = { /*.t_start_ms =*/ 1e-3 * ctx->t_start_us, diff --git a/llama.h b/llama.h index 5a97abcc9..77a84c18a 100644 --- a/llama.h +++ b/llama.h @@ -305,6 +305,12 @@ extern "C" { int32_t n_eval; }; + // used in chat template + typedef struct llama_chat_message { + const char * role; + const char * content; + } llama_chat_message; + // Helpers for getting default parameters LLAMA_API struct llama_model_params llama_model_default_params(void); LLAMA_API struct llama_context_params llama_context_default_params(void); @@ -699,6 +705,25 @@ extern "C" { char * buf, int32_t length); + /// Apply chat template. Inspired by hf apply_chat_template() on python. + /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model" + /// NOTE: This function only supports some known jinja templates. It is not a jinja parser. + /// @param tmpl A Jinja template to use for this chat. If this is nullptr, the model’s default chat template will be used instead. + /// @param chat Pointer to a list of multiple llama_chat_message + /// @param n_msg Number of llama_chat_message in this chat + /// @param add_ass Whether to end the prompt with the token(s) that indicate the start of an assistant message.
+ /// @param buf A buffer to hold the output formatted prompt. The recommended alloc size is 2 * (total number of characters of all messages) + /// @param length The size of the allocated buffer + /// @return The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template. + LLAMA_API int32_t llama_chat_apply_template( + const struct llama_model * model, + const char * tmpl, + const struct llama_chat_message * chat, + size_t n_msg, + bool add_ass, + char * buf, + int32_t length); + // // Grammar // diff --git a/scripts/get-flags.mk b/scripts/get-flags.mk index 596d7ead1..a742766d1 100644 --- a/scripts/get-flags.mk +++ b/scripts/get-flags.mk @@ -1,6 +1,6 @@ ifeq '' '$(findstring clang,$(shell $(GF_CC) --version))' GF_CC_IS_GCC = 1 - GF_CC_VER := $(shell { $(GF_CC) -dumpfullversion 2>/dev/null || $(GF_CC) -dumpversion; } | awk -F. '{ printf("%02d%02d%02d", $$1, $$2, $$3) }') + GF_CC_VER := $(shell { $(GF_CC) -dumpfullversion 2>/dev/null; echo; $(GF_CC) -dumpversion; } | awk -F. '/./ { printf("%02d%02d%02d", $$1, $$2, $$3); exit }') else GF_CC_IS_CLANG = 1 ifeq '' '$(findstring Apple,$(shell $(GF_CC) --version))' diff --git a/scripts/get-wikitext-2.sh b/scripts/get-wikitext-2.sh index ff96f331e..7ca760fa6 100755 --- a/scripts/get-wikitext-2.sh +++ b/scripts/get-wikitext-2.sh @@ -1,6 +1,6 @@ #!/bin/bash -wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip +wget https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip echo "Usage:" echo "" diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 3e40a78cd..10326d531 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -28,6 +28,7 @@ endfunction() llama_build_and_test_executable(test-quantize-fns.cpp) llama_build_and_test_executable(test-quantize-perf.cpp) llama_build_and_test_executable(test-sampling.cpp) +llama_build_and_test_executable(test-chat-template.cpp) llama_build_executable(test-tokenizer-0-llama.cpp) llama_test_executable (test-tokenizer-0-llama test-tokenizer-0-llama.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf) diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp new file mode 100644 index 000000000..9830650d4 --- /dev/null +++ b/tests/test-chat-template.cpp @@ -0,0 +1,64 @@ +#include +#include +#include +#include + +#undef NDEBUG +#include + +#include "llama.h" + +int main(void) { + llama_chat_message conversation[] = { + {"system", "You are a helpful assistant"}, + {"user", "Hello"}, + {"assistant", "Hi there"}, + {"user", "Who are you"}, + {"assistant", " I am an assistant "}, + {"user", "Another question"}, + }; + size_t message_count = 6; + std::vector templates = { + // teknium/OpenHermes-2.5-Mistral-7B + "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}", + // mistralai/Mistral-7B-Instruct-v0.2 + "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", + // 
TheBloke/FusionNet_34Bx2_MoE-AWQ + "{%- for idx in range(0, messages|length) -%}\\n{%- if messages[idx]['role'] == 'user' -%}\\n{%- if idx > 1 -%}\\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\\n{%- else -%}\\n{{- messages[idx]['content'] + ' [/INST]' -}}\\n{%- endif -%}\\n{% elif messages[idx]['role'] == 'system' %}\\n{{- '[INST] <>\\\\n' + messages[idx]['content'] + '\\\\n<>\\\\n\\\\n' -}}\\n{%- elif messages[idx]['role'] == 'assistant' -%}\\n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}}\\n{% endif %}\\n{% endfor %}", + // bofenghuang/vigogne-2-70b-chat + "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez autant que vous le pouvez.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\\\n' + system_message + '\\\\n<>\\\\n\\\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\\\n' + content.strip() + '\\\\n<>\\\\n\\\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}", + }; + std::vector expected_substr = { + "<|im_start|>assistant\n I am an assistant <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant", + "[/INST]Hi there[INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", + "[INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", + "[/INST] Hi there [INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", + }; + std::vector formatted_chat(1024); + int32_t res; + + // test invalid chat template + res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation, message_count, true, formatted_chat.data(), formatted_chat.size()); + assert(res < 0); + + for (size_t i = 0; i < templates.size(); i++) { + std::string custom_template = templates[i]; + std::string substr = expected_substr[i]; + formatted_chat.resize(1024); + res = llama_chat_apply_template( + nullptr, + custom_template.c_str(), + conversation, + message_count, + true, + formatted_chat.data(), + formatted_chat.size() + ); + formatted_chat.resize(res); + std::string output(formatted_chat.data(), formatted_chat.size()); + std::cout << output << "\n-------------------------\n"; + // expect the "formatted_chat" to contain pre-defined strings + assert(output.find(substr) != std::string::npos); + } + return 0; +}
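
Usage note (not part of the patch above): a minimal sketch of how a caller might drive the new llama_chat_apply_template() API, including the grow-and-retry step suggested by the doc comment in llama.h for when the formatted prompt does not fit in the initial buffer. The helper name and buffer sizes are illustrative only, and `model` is assumed to come from the usual llama_load_model_from_file() setup.

```cpp
#include <string>
#include <vector>

#include "llama.h"

// Sketch: format a short conversation with the model's built-in chat template (tmpl == nullptr).
// If the returned length exceeds the buffer, re-alloc and re-apply the template as the header suggests.
static std::string format_chat_example(const llama_model * model) {
    const llama_chat_message chat[] = {
        {"system", "You are a helpful assistant"},
        {"user",   "Hello"},
    };
    const size_t n_msg = sizeof(chat) / sizeof(chat[0]);

    std::vector<char> buf(256); // roughly 2x the total message length, per the recommendation in llama.h
    int32_t res = llama_chat_apply_template(model, nullptr, chat, n_msg, true, buf.data(), (int32_t) buf.size());
    if (res > (int32_t) buf.size()) {
        buf.resize(res); // output was truncated: grow the buffer and re-apply
        res = llama_chat_apply_template(model, nullptr, chat, n_msg, true, buf.data(), (int32_t) buf.size());
    }
    if (res < 0) {
        return ""; // template not recognized by the built-in heuristics
    }
    return std::string(buf.data(), (size_t) res);
}
```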