Merge branch 'ggerganov:master' into server_branch

This commit is contained in:
pudepiedj 2024-02-19 12:44:51 +00:00 committed by GitHub
commit aaffb2387f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
24 changed files with 339 additions and 123 deletions

View file

@ -741,13 +741,6 @@ function(get_flags CCID CCVER)
if (CCVER VERSION_GREATER_EQUAL 8.1.0) if (CCVER VERSION_GREATER_EQUAL 8.1.0)
list(APPEND CXX_FLAGS -Wextra-semi) list(APPEND CXX_FLAGS -Wextra-semi)
endif() endif()
elseif (CCID MATCHES "Intel")
if (NOT LLAMA_SYCL)
# enable max optimization level when using Intel compiler
set(C_FLAGS -ipo -O3 -static -fp-model=fast -flto -fno-stack-protector)
set(CXX_FLAGS -ipo -O3 -static -fp-model=fast -flto -fno-stack-protector)
add_link_options(-fuse-ld=lld -static-intel)
endif()
endif() endif()
set(GF_C_FLAGS ${C_FLAGS} PARENT_SCOPE) set(GF_C_FLAGS ${C_FLAGS} PARENT_SCOPE)
@ -778,10 +771,7 @@ endif()
set(CUDA_CXX_FLAGS "") set(CUDA_CXX_FLAGS "")
if (LLAMA_CUBLAS) if (LLAMA_CUBLAS)
set(CUDA_FLAGS ${CXX_FLAGS} -use_fast_math) set(CUDA_FLAGS -use_fast_math)
if (NOT MSVC)
list(APPEND CUDA_FLAGS -Wno-pedantic)
endif()
if (LLAMA_ALL_WARNINGS AND NOT MSVC) if (LLAMA_ALL_WARNINGS AND NOT MSVC)
set(NVCC_CMD ${CMAKE_CUDA_COMPILER} .c) set(NVCC_CMD ${CMAKE_CUDA_COMPILER} .c)
@ -814,7 +804,11 @@ if (LLAMA_CUBLAS)
message("-- CUDA host compiler is ${CUDA_CCID} ${CUDA_CCVER}") message("-- CUDA host compiler is ${CUDA_CCID} ${CUDA_CCVER}")
get_flags(${CUDA_CCID} ${CUDA_CCVER}) get_flags(${CUDA_CCID} ${CUDA_CCVER})
list(APPEND CUDA_CXX_FLAGS ${GF_CXX_FLAGS}) # This is passed to -Xcompiler later list(APPEND CUDA_CXX_FLAGS ${CXX_FLAGS} ${GF_CXX_FLAGS}) # This is passed to -Xcompiler later
endif()
if (NOT MSVC)
list(APPEND CUDA_CXX_FLAGS -Wno-pedantic)
endif() endif()
endif() endif()

View file

@ -100,6 +100,7 @@ endif
MK_CPPFLAGS = -I. -Icommon MK_CPPFLAGS = -I. -Icommon
MK_CFLAGS = -std=c11 -fPIC MK_CFLAGS = -std=c11 -fPIC
MK_CXXFLAGS = -std=c++11 -fPIC MK_CXXFLAGS = -std=c++11 -fPIC
MK_NVCCFLAGS = -std=c++11
# -Ofast tends to produce faster code, but may not be available for some compilers. # -Ofast tends to produce faster code, but may not be available for some compilers.
ifdef LLAMA_FAST ifdef LLAMA_FAST
@ -220,30 +221,6 @@ ifeq ($(LLAMA_FATAL_WARNINGS),1)
MK_CXXFLAGS += -Werror MK_CXXFLAGS += -Werror
endif endif
ifeq ($(CC_IS_CLANG), 1)
# clang options
MK_CFLAGS += -Wunreachable-code-break -Wunreachable-code-return
MK_HOST_CXXFLAGS += -Wunreachable-code-break -Wunreachable-code-return -Wmissing-prototypes -Wextra-semi
ifneq '' '$(and $(CC_IS_LLVM_CLANG),$(filter 1,$(shell expr $(CC_VER) \>= 030800)))'
MK_CFLAGS += -Wdouble-promotion
endif
ifneq '' '$(and $(CC_IS_APPLE_CLANG),$(filter 1,$(shell expr $(CC_VER) \>= 070300)))'
MK_CFLAGS += -Wdouble-promotion
endif
else
# gcc options
MK_CFLAGS += -Wdouble-promotion
MK_HOST_CXXFLAGS += -Wno-array-bounds
ifeq ($(shell expr $(CC_VER) \>= 070100), 1)
MK_HOST_CXXFLAGS += -Wno-format-truncation
endif
ifeq ($(shell expr $(CC_VER) \>= 080100), 1)
MK_HOST_CXXFLAGS += -Wextra-semi
endif
endif
# this version of Apple ld64 is buggy # this version of Apple ld64 is buggy
ifneq '' '$(findstring dyld-1015.7,$(shell $(CC) $(LDFLAGS) -Wl,-v 2>&1))' ifneq '' '$(findstring dyld-1015.7,$(shell $(CC) $(LDFLAGS) -Wl,-v 2>&1))'
MK_CPPFLAGS += -DHAVE_BUGGY_APPLE_LINKER MK_CPPFLAGS += -DHAVE_BUGGY_APPLE_LINKER
@ -468,7 +445,7 @@ ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
ifdef JETSON_EOL_MODULE_DETECT ifdef JETSON_EOL_MODULE_DETECT
$(NVCC) -I. -Icommon -D_XOPEN_SOURCE=600 -D_GNU_SOURCE -DNDEBUG -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I/usr/local/cuda/targets/aarch64-linux/include -std=c++11 -O3 $(NVCCFLAGS) -Xcompiler "$(CUDA_CXXFLAGS)" -c $< -o $@ $(NVCC) -I. -Icommon -D_XOPEN_SOURCE=600 -D_GNU_SOURCE -DNDEBUG -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I/usr/local/cuda/targets/aarch64-linux/include -std=c++11 -O3 $(NVCCFLAGS) -Xcompiler "$(CUDA_CXXFLAGS)" -c $< -o $@
else else
$(NVCC) $(BASE_CXXFLAGS) $(NVCCFLAGS) -Wno-pedantic -Xcompiler "$(CUDA_CXXFLAGS)" -c $< -o $@ $(NVCC) $(NVCCFLAGS) -Xcompiler "$(CUDA_CXXFLAGS)" -c $< -o $@
endif # JETSON_EOL_MODULE_DETECT endif # JETSON_EOL_MODULE_DETECT
endif # LLAMA_CUBLAS endif # LLAMA_CUBLAS
@ -579,7 +556,7 @@ override LDFLAGS := $(MK_LDFLAGS) $(LDFLAGS)
ifdef LLAMA_CUBLAS ifdef LLAMA_CUBLAS
GF_CC := $(NVCC) $(NVCCFLAGS) 2>/dev/null .c -Xcompiler GF_CC := $(NVCC) $(NVCCFLAGS) 2>/dev/null .c -Xcompiler
include scripts/get-flags.mk include scripts/get-flags.mk
CUDA_CXXFLAGS := $(GF_CXXFLAGS) CUDA_CXXFLAGS := $(BASE_CXXFLAGS) $(GF_CXXFLAGS) -Wno-pedantic
endif endif
# #
@ -891,3 +868,7 @@ tests/test-model-load-cancel: tests/test-model-load-cancel.cpp ggml.o llama.o te
tests/test-autorelease: tests/test-autorelease.cpp ggml.o llama.o tests/get-model.cpp $(COMMON_DEPS) $(OBJS) tests/test-autorelease: tests/test-autorelease.cpp ggml.o llama.o tests/get-model.cpp $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
tests/test-chat-template: tests/test-chat-template.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

View file

@ -272,7 +272,7 @@ Please install [Visual Studio](https://visualstudio.microsoft.com/) which impact
a. Please follow the procedure in [Get the Intel® oneAPI Base Toolkit ](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html). a. Please follow the procedure in [Get the Intel® oneAPI Base Toolkit ](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html).
Recommend to install to default folder: **/opt/intel/oneapi**. Recommend to install to default folder: **C:\Program Files (x86)\Intel\oneAPI**.
Following guide uses the default folder as example. If you use other folder, please modify the following guide info with your folder. Following guide uses the default folder as example. If you use other folder, please modify the following guide info with your folder.

View file

@ -61,7 +61,7 @@ variety of hardware - locally and in the cloud.
- Plain C/C++ implementation without any dependencies - Plain C/C++ implementation without any dependencies
- Apple silicon is a first-class citizen - optimized via ARM NEON, Accelerate and Metal frameworks - Apple silicon is a first-class citizen - optimized via ARM NEON, Accelerate and Metal frameworks
- AVX, AVX2 and AVX512 support for x86 architectures - AVX, AVX2 and AVX512 support for x86 architectures
- 2-bit, 3-bit, 4-bit, 5-bit, 6-bit, and 8-bit integer quantization for faster inference and reduced memory use - 1.5-bit, 2-bit, 3-bit, 4-bit, 5-bit, 6-bit, and 8-bit integer quantization for faster inference and reduced memory use
- Custom CUDA kernels for running LLMs on NVIDIA GPUs (support for AMD GPUs via HIP) - Custom CUDA kernels for running LLMs on NVIDIA GPUs (support for AMD GPUs via HIP)
- Vulkan, SYCL, and (partial) OpenCL backend support - Vulkan, SYCL, and (partial) OpenCL backend support
- CPU+GPU hybrid inference to partially accelerate models larger than the total VRAM capacity - CPU+GPU hybrid inference to partially accelerate models larger than the total VRAM capacity
@ -768,7 +768,7 @@ The time per token is measured on a MacBook M1 Pro 32GB RAM using 4 and 8 thread
#### How to run #### How to run
1. Download/extract: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research 1. Download/extract: https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
2. Run `./perplexity -m models/7B/ggml-model-q4_0.gguf -f wiki.test.raw` 2. Run `./perplexity -m models/7B/ggml-model-q4_0.gguf -f wiki.test.raw`
3. Output: 3. Output:
``` ```

View file

@ -219,7 +219,7 @@ function gg_run_open_llama_3b_v2 {
gg_wget models-mnt/open-llama/3B-v2/ https://huggingface.co/openlm-research/open_llama_3b_v2/resolve/main/pytorch_model.bin gg_wget models-mnt/open-llama/3B-v2/ https://huggingface.co/openlm-research/open_llama_3b_v2/resolve/main/pytorch_model.bin
gg_wget models-mnt/open-llama/3B-v2/ https://huggingface.co/openlm-research/open_llama_3b_v2/raw/main/generation_config.json gg_wget models-mnt/open-llama/3B-v2/ https://huggingface.co/openlm-research/open_llama_3b_v2/raw/main/generation_config.json
gg_wget models-mnt/wikitext/ https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip gg_wget models-mnt/wikitext/ https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
unzip -o models-mnt/wikitext/wikitext-2-raw-v1.zip -d models-mnt/wikitext/ unzip -o models-mnt/wikitext/wikitext-2-raw-v1.zip -d models-mnt/wikitext/
head -n 60 models-mnt/wikitext/wikitext-2-raw/wiki.test.raw > models-mnt/wikitext/wikitext-2-raw/wiki.test-60.raw head -n 60 models-mnt/wikitext/wikitext-2-raw/wiki.test.raw > models-mnt/wikitext/wikitext-2-raw/wiki.test-60.raw
@ -401,7 +401,7 @@ function gg_run_open_llama_7b_v2 {
gg_wget models-mnt/open-llama/7B-v2/ https://huggingface.co/openlm-research/open_llama_7b_v2/resolve/main/pytorch_model-00002-of-00002.bin gg_wget models-mnt/open-llama/7B-v2/ https://huggingface.co/openlm-research/open_llama_7b_v2/resolve/main/pytorch_model-00002-of-00002.bin
gg_wget models-mnt/open-llama/7B-v2/ https://huggingface.co/openlm-research/open_llama_7b_v2/raw/main/generation_config.json gg_wget models-mnt/open-llama/7B-v2/ https://huggingface.co/openlm-research/open_llama_7b_v2/raw/main/generation_config.json
gg_wget models-mnt/wikitext/ https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip gg_wget models-mnt/wikitext/ https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
unzip -o models-mnt/wikitext/wikitext-2-raw-v1.zip -d models-mnt/wikitext/ unzip -o models-mnt/wikitext/wikitext-2-raw-v1.zip -d models-mnt/wikitext/
path_models="../models-mnt/open-llama/7B-v2" path_models="../models-mnt/open-llama/7B-v2"

View file

@ -1704,6 +1704,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
} }
fprintf(stream, "lora_base: %s\n", params.lora_base.c_str()); fprintf(stream, "lora_base: %s\n", params.lora_base.c_str());
fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu); fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
fprintf(stream, "min_keep: %d # default: 0 (disabled)\n", sparams.min_keep);
fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat); fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat);
fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau); fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau);
fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta); fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta);

View file

@ -248,7 +248,10 @@ static llama_token llama_sampling_sample_impl(
llama_sample_temp(ctx_main, &cur_p, temp); llama_sample_temp(ctx_main, &cur_p, temp);
id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu); id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
} else { } else {
sampler_queue(ctx_main, params, cur_p, 1); // temperature sampling
size_t min_keep = std::max(1, params.min_keep);
sampler_queue(ctx_main, params, cur_p, min_keep);
id = llama_sample_token(ctx_main, &cur_p); id = llama_sample_token(ctx_main, &cur_p);

View file

@ -22,6 +22,7 @@ enum class llama_sampler_type : char {
typedef struct llama_sampling_params { typedef struct llama_sampling_params {
int32_t n_prev = 64; // number of previous tokens to remember int32_t n_prev = 64; // number of previous tokens to remember
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
int32_t top_k = 40; // <= 0 to use vocab size int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled float top_p = 0.95f; // 1.0 = disabled
float min_p = 0.05f; // 0.0 = disabled float min_p = 0.05f; // 0.0 = disabled

View file

@ -1533,16 +1533,17 @@ int main(int argc, char ** argv) {
int n_past = 0; int n_past = 0;
ggml_cgraph gf = {}; struct ggml_cgraph * gf = NULL;
gf = ggml_new_graph_custom(ctx0, LLAMA_TRAIN_MAX_NODES, true);
get_example_targets_batch(ctx0, 64*ex+0, tokens_input, targets); get_example_targets_batch(ctx0, 64*ex+0, tokens_input, targets);
struct ggml_tensor * logits = forward_batch(&model, &kv_self, ctx0, &gf, tokens_input, n_tokens, n_past, n_batch); struct ggml_tensor * logits = forward_batch(&model, &kv_self, ctx0, gf, tokens_input, n_tokens, n_past, n_batch);
// struct ggml_tensor * e = cross_entropy_loss(ctx0, targets, logits); // struct ggml_tensor * e = cross_entropy_loss(ctx0, targets, logits);
struct ggml_tensor * e = square_error_loss(ctx0, targets, logits); struct ggml_tensor * e = square_error_loss(ctx0, targets, logits);
ggml_build_forward_expand(&gf, e); ggml_build_forward_expand(gf, e);
ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1); ggml_graph_compute_helper(work_buffer, gf, /*n_threads*/ 1);
float error_before_opt = ggml_get_f32_1d(e, 0); float error_before_opt = ggml_get_f32_1d(e, 0);
@ -1552,8 +1553,8 @@ int main(int argc, char ** argv) {
opt_params_lbfgs.lbfgs.n_iter = 16; opt_params_lbfgs.lbfgs.n_iter = 16;
ggml_opt(ctx0, opt_params_lbfgs, e); ggml_opt(ctx0, opt_params_lbfgs, e);
// //
ggml_build_forward_expand(&gf, e); ggml_build_forward_expand(gf, e);
ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1); ggml_graph_compute_helper(work_buffer, gf, /*n_threads*/ 1);
float error_after_opt = ggml_get_f32_1d(e, 0); float error_after_opt = ggml_get_f32_1d(e, 0);
@ -1600,13 +1601,14 @@ int main(int argc, char ** argv) {
}; };
struct ggml_context * ctx0 = ggml_init(params); struct ggml_context * ctx0 = ggml_init(params);
ggml_cgraph gf = {}; struct ggml_cgraph * gf = NULL;
gf = ggml_new_graph_custom(ctx0, LLAMA_TRAIN_MAX_NODES, true);
int n_past = 0; int n_past = 0;
struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, &gf, tokens_input, sample_ctx, n_past); struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, gf, tokens_input, sample_ctx, n_past);
ggml_build_forward_expand(&gf, logits); ggml_build_forward_expand(gf, logits);
ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1); ggml_graph_compute_helper(work_buffer, gf, /*n_threads*/ 1);
struct ggml_tensor * best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sample_ctx); struct ggml_tensor * best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sample_ctx);
struct ggml_tensor * probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, sample_ctx); struct ggml_tensor * probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, sample_ctx);

View file

@ -25,9 +25,6 @@ if len(clip_tensors) > 0:
clip = {name.replace("vision_tower.vision_tower.", ""): checkpoint[name].float() for name in clip_tensors} clip = {name.replace("vision_tower.vision_tower.", ""): checkpoint[name].float() for name in clip_tensors}
torch.save(clip, f"{args.model}/llava.clip") torch.save(clip, f"{args.model}/llava.clip")
# remove these tensors
for name in clip_tensors:
del checkpoint[name]
# added tokens should be removed to be able to convert Mistral models # added tokens should be removed to be able to convert Mistral models
if os.path.exists(f"{args.model}/added_tokens.json"): if os.path.exists(f"{args.model}/added_tokens.json"):
@ -35,7 +32,6 @@ if len(clip_tensors) > 0:
f.write("{}\n") f.write("{}\n")
torch.save(checkpoint, path)
print("Done!") print("Done!")
print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.") print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.")

View file

@ -309,7 +309,7 @@ static void process_logits(int n_vocab, const float * logits, const int * tokens
} }
static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & params) { static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & params) {
// Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research // Download: https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
// Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw` // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
// Output: `perplexity: 13.5106 [114/114]` // Output: `perplexity: 13.5106 [114/114]`
// BOS tokens will be added for each chunk before eval // BOS tokens will be added for each chunk before eval
@ -447,7 +447,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
return perplexity_v2(ctx, params); return perplexity_v2(ctx, params);
} }
// Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research // Download: https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
// Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw` // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
// Output: `perplexity: 13.5106 [114/114]` // Output: `perplexity: 13.5106 [114/114]`
// BOS tokens will be added for each chunk before eval // BOS tokens will be added for each chunk before eval

View file

@ -199,6 +199,8 @@ node index.js
`n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token (default: 0) `n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token (default: 0)
`min_keep`: If greater than 0, force samplers to return N possible tokens at minimum (default: 0)
`image_data`: An array of objects to hold base64-encoded image `data` and its `id`s to be reference in `prompt`. You can determine the place of the image in the prompt as in the following: `USER:[img-12]Describe the image in detail.\nASSISTANT:`. In this case, `[img-12]` will be replaced by the embeddings of the image with id `12` in the following `image_data` array: `{..., "image_data": [{"data": "<BASE64_STRING>", "id": 12}]}`. Use `image_data` only with multimodal models, e.g., LLaVA. `image_data`: An array of objects to hold base64-encoded image `data` and its `id`s to be reference in `prompt`. You can determine the place of the image in the prompt as in the following: `USER:[img-12]Describe the image in detail.\nASSISTANT:`. In this case, `[img-12]` will be replaced by the embeddings of the image with id `12` in the following `image_data` array: `{..., "image_data": [{"data": "<BASE64_STRING>", "id": 12}]}`. Use `image_data` only with multimodal models, e.g., LLaVA.
`slot_id`: Assign the completion task to an specific slot. If is -1 the task will be assigned to a Idle slot (default: -1) `slot_id`: Assign the completion task to an specific slot. If is -1 the task will be assigned to a Idle slot (default: -1)

View file

@ -234,6 +234,7 @@
mirostat_eta: 0.1, // learning rate mirostat_eta: 0.1, // learning rate
grammar: '', grammar: '',
n_probs: 0, // no completion_probabilities, n_probs: 0, // no completion_probabilities,
min_keep: 0, // min probs from each sampler,
image_data: [], image_data: [],
cache_prompt: true, cache_prompt: true,
api_key: '' api_key: ''
@ -791,6 +792,9 @@
<fieldset> <fieldset>
${IntField({ label: "Show Probabilities", max: 10, min: 0, name: "n_probs", value: params.value.n_probs })} ${IntField({ label: "Show Probabilities", max: 10, min: 0, name: "n_probs", value: params.value.n_probs })}
</fieldset> </fieldset>
<fieldset>
${IntField({ label: "Min Probabilities from each Sampler", max: 10, min: 0, name: "min_keep", value: params.value.min_keep })}
</fieldset>
<fieldset> <fieldset>
<label for="api_key">API Key</label> <label for="api_key">API Key</label>
<input type="text" name="api_key" value="${params.value.api_key}" placeholder="Enter API key" oninput=${updateParams} /> <input type="text" name="api_key" value="${params.value.api_key}" placeholder="Enter API key" oninput=${updateParams} />

View file

@ -624,6 +624,7 @@ struct llama_server_context
slot->params.seed = json_value(data, "seed", default_params.seed); slot->params.seed = json_value(data, "seed", default_params.seed);
slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar); slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
slot->sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
if (slot->n_predict > 0 && slot->params.n_predict > slot->n_predict) { if (slot->n_predict > 0 && slot->params.n_predict > slot->n_predict) {
// Might be better to reject the request with a 400 ? // Might be better to reject the request with a 400 ?
@ -1169,6 +1170,7 @@ struct llama_server_context
{"stream", slot.params.stream}, {"stream", slot.params.stream},
{"logit_bias", slot.sparams.logit_bias}, {"logit_bias", slot.sparams.logit_bias},
{"n_probs", slot.sparams.n_probs}, {"n_probs", slot.sparams.n_probs},
{"min_keep", slot.sparams.min_keep},
{"grammar", slot.sparams.grammar}, {"grammar", slot.sparams.grammar},
{"samplers", samplers_sequence} {"samplers", samplers_sequence}
}; };

View file

@ -6205,7 +6205,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
const int ix = rowx*ncols + col; const int ix = rowx*ncols + col;
const int iy = rowy*ncols + col; const int iy = rowy*ncols + col;
const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + slope*pos[col]; const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
vals[col] = val; vals[col] = val;
max_val = max(max_val, val); max_val = max(max_val, val);
@ -9170,17 +9170,17 @@ static void ggml_cuda_op_soft_max(
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
// positions tensor // positions tensor
float * src2_dd = dst_dd; // default to avoid null checks in the kernel float * src2_dd = nullptr;
cuda_pool_alloc<float> src2_f; cuda_pool_alloc<float> src2_f;
ggml_tensor * src2 = dst->src[2]; ggml_tensor * src2 = dst->src[2];
const bool use_src2 = src2 != nullptr; const bool use_src2 = src2 != nullptr;
if (use_src2) { if (use_src2) {
const bool src2_on_device = use_src2 && src2->backend == GGML_BACKEND_GPU; const bool src2_on_device = src2->backend == GGML_BACKEND_GPU;
ggml_tensor_extra_gpu * src2_extra = use_src2 ? (ggml_tensor_extra_gpu *) src2->extra : nullptr;
if (src2_on_device) { if (src2_on_device) {
ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) src2->extra;
src2_dd = (float *) src2_extra->data_device[g_main_device]; src2_dd = (float *) src2_extra->data_device[g_main_device];
} else { } else {
src2_dd = src2_f.alloc(ggml_nelements(src2)); src2_dd = src2_f.alloc(ggml_nelements(src2));

View file

@ -392,7 +392,7 @@ kernel void kernel_soft_max(
float lmax = -INFINITY; float lmax = -INFINITY;
for (int i00 = tpitg; i00 < ne00; i00 += ntg) { for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]); lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
} }
// find the max value in the block // find the max value in the block
@ -417,7 +417,7 @@ kernel void kernel_soft_max(
// parallel sum // parallel sum
float lsum = 0.0f; float lsum = 0.0f;
for (int i00 = tpitg; i00 < ne00; i00 += ntg) { for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]) - max_val); const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
lsum += exp_psrc0; lsum += exp_psrc0;
pdst[i00] = exp_psrc0; pdst[i00] = exp_psrc0;
} }
@ -495,7 +495,7 @@ kernel void kernel_soft_max_4(
float4 lmax4 = -INFINITY; float4 lmax4 = -INFINITY;
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]); lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
} }
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@ -521,7 +521,7 @@ kernel void kernel_soft_max_4(
// parallel sum // parallel sum
float4 lsum4 = 0.0f; float4 lsum4 = 0.0f;
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]) - max_val); const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
lsum4 += exp_psrc4; lsum4 += exp_psrc4;
pdst4[i00] = exp_psrc4; pdst4[i00] = exp_psrc4;
} }
@ -4027,7 +4027,10 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
y4 += 32 * 32; y4 += 32 * 32;
} }
#else #else
// TODO (void) x;
(void) y;
(void) yl;
(void) nb32;
#endif #endif
for (int row = 0; row < N_DST; ++row) { for (int row = 0; row < N_DST; ++row) {
@ -4170,7 +4173,10 @@ void kernel_mul_mv_iq2_xs_f32_impl(
y4 += 32 * 32; y4 += 32 * 32;
} }
#else #else
// TODO (void) x;
(void) y;
(void) yl;
(void) nb32;
#endif #endif
for (int row = 0; row < N_DST; ++row) { for (int row = 0; row < N_DST; ++row) {
@ -4306,7 +4312,10 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
y4 += 32 * 32; y4 += 32 * 32;
} }
#else #else
// TODO (void) x;
(void) y;
(void) yl;
(void) nb32;
#endif #endif
for (int row = 0; row < N_DST; ++row) { for (int row = 0; row < N_DST; ++row) {
@ -4424,7 +4433,10 @@ void kernel_mul_mv_iq1_s_f32_impl(
y4 += 16 * 32; y4 += 16 * 32;
} }
#else #else
// TODO (void) x;
(void) y;
(void) yl;
(void) nb32;
#endif #endif
for (int row = 0; row < N_DST; ++row) { for (int row = 0; row < N_DST; ++row) {
@ -4659,6 +4671,8 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
const float dl = d * sc[0]; const float dl = d * sc[0];
const float ml = min * sc[1]; const float ml = min * sc[1];
#else #else
(void) get_scale_min_k4_just2;
q = q + 16 * (il&1); q = q + 16 * (il&1);
device const uint8_t * s = xb->scales; device const uint8_t * s = xb->scales;
device const half2 * dh = (device const half2 *)xb->d; device const half2 * dh = (device const half2 *)xb->d;

View file

@ -1837,9 +1837,9 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri
float sigma2 = sumx2/QK_K; float sigma2 = sumx2/QK_K;
for (int j = 0; j < QK_K/16; ++j) { for (int j = 0; j < QK_K/16; ++j) {
const float * restrict qw = quant_weights + QK_K * i + 16*j; const float * restrict qw = quant_weights + QK_K * i + 16*j;
for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]); for (int l = 0; l < QK_K/16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]);
for (int l = 0; l < 16; ++l) sw[j] += weight[l]; for (int l = 0; l < QK_K/16; ++l) sw[j] += weight[l];
scales[j] = make_qkx3_quants(16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false); scales[j] = make_qkx3_quants(QK_K/16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
} }
float dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw); float dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw);
@ -3855,7 +3855,7 @@ static inline __m128i get_scale_shuffle(int i) {
} }
#endif #endif
void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) { void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
const int qk = QK8_0; const int qk = QK8_0;
const int nb = n / qk; const int nb = n / qk;
@ -3866,8 +3866,8 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
assert(nrc == 1); assert(nrc == 1);
#endif #endif
UNUSED(nrc); UNUSED(nrc);
UNUSED(bbx); UNUSED(bx);
UNUSED(bby); UNUSED(by);
UNUSED(bs); UNUSED(bs);
const block_q4_0 * restrict x = vx; const block_q4_0 * restrict x = vx;
@ -4024,15 +4024,15 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs); const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs);
__m128i bx = _mm_and_si128(lowMask, tmp); __m128i bx_0 = _mm_and_si128(lowMask, tmp);
__m128i by = _mm_loadu_si128((const __m128i *)y[i].qs); __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs);
bx = _mm_sub_epi8(bx, off); bx_0 = _mm_sub_epi8(bx_0, off);
const __m128i i32_0 = mul_sum_i8_pairs(bx, by); const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
bx = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4)); bx_0 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4));
by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16)); by_0 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
bx = _mm_sub_epi8(bx, off); bx_0 = _mm_sub_epi8(bx_0, off);
const __m128i i32_1 = mul_sum_i8_pairs(bx, by); const __m128i i32_1 = mul_sum_i8_pairs(bx_0, by_0);
// Convert int32_t to float // Convert int32_t to float
__m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1)); __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1));
@ -4222,7 +4222,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
#endif #endif
} }
void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) { void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
const int qk = QK8_1; const int qk = QK8_1;
const int nb = n / qk; const int nb = n / qk;
@ -4233,8 +4233,8 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
assert(nrc == 1); assert(nrc == 1);
#endif #endif
UNUSED(nrc); UNUSED(nrc);
UNUSED(bbx); UNUSED(bx);
UNUSED(bby); UNUSED(by);
UNUSED(bs); UNUSED(bs);
const block_q4_1 * restrict x = vx; const block_q4_1 * restrict x = vx;
@ -4440,7 +4440,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
#endif #endif
} }
void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) { void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
const int qk = QK8_0; const int qk = QK8_0;
const int nb = n / qk; const int nb = n / qk;
@ -4448,8 +4448,8 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
assert(qk == QK5_0); assert(qk == QK5_0);
assert(nrc == 1); assert(nrc == 1);
UNUSED(nrc); UNUSED(nrc);
UNUSED(bbx); UNUSED(bx);
UNUSED(bby); UNUSED(by);
UNUSED(bs); UNUSED(bs);
const block_q5_0 * restrict x = vx; const block_q5_0 * restrict x = vx;
@ -4618,21 +4618,21 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
/* Compute combined scale for the block */ /* Compute combined scale for the block */
const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
__m256i bx = bytes_from_nibbles_32(x[i].qs); __m256i bx_0 = bytes_from_nibbles_32(x[i].qs);
const __m256i bxhi = bytes_from_bits_32(x[i].qh); const __m256i bxhi = bytes_from_bits_32(x[i].qh);
__m128i bxhil = _mm256_castsi256_si128(bxhi); __m128i bxhil = _mm256_castsi256_si128(bxhi);
__m128i bxhih = _mm256_extractf128_si256(bxhi, 1); __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
bxhil = _mm_andnot_si128(bxhil, mask); bxhil = _mm_andnot_si128(bxhil, mask);
bxhih = _mm_andnot_si128(bxhih, mask); bxhih = _mm_andnot_si128(bxhih, mask);
__m128i bxl = _mm256_castsi256_si128(bx); __m128i bxl = _mm256_castsi256_si128(bx_0);
__m128i bxh = _mm256_extractf128_si256(bx, 1); __m128i bxh = _mm256_extractf128_si256(bx_0, 1);
bxl = _mm_or_si128(bxl, bxhil); bxl = _mm_or_si128(bxl, bxhil);
bxh = _mm_or_si128(bxh, bxhih); bxh = _mm_or_si128(bxh, bxhih);
bx = MM256_SET_M128I(bxh, bxl); bx_0 = MM256_SET_M128I(bxh, bxl);
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[i].qs);
const __m256 q = mul_sum_i8_pairs_float(bx, by); const __m256 q = mul_sum_i8_pairs_float(bx_0, by_0);
/* Multiply q with scale and accumulate */ /* Multiply q with scale and accumulate */
acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc); acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
@ -4731,7 +4731,7 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
#endif #endif
} }
void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) { void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
const int qk = QK8_1; const int qk = QK8_1;
const int nb = n / qk; const int nb = n / qk;
@ -4739,8 +4739,8 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
assert(qk == QK5_1); assert(qk == QK5_1);
assert(nrc == 1); assert(nrc == 1);
UNUSED(nrc); UNUSED(nrc);
UNUSED(bbx); UNUSED(bx);
UNUSED(bby); UNUSED(by);
UNUSED(bs); UNUSED(bs);
const block_q5_1 * restrict x = vx; const block_q5_1 * restrict x = vx;
@ -4925,22 +4925,22 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s; summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
__m256i bx = bytes_from_nibbles_32(x[i].qs); __m256i bx_0 = bytes_from_nibbles_32(x[i].qs);
const __m256i bxhi = bytes_from_bits_32(x[i].qh); const __m256i bxhi = bytes_from_bits_32(x[i].qh);
__m128i bxhil = _mm256_castsi256_si128(bxhi); __m128i bxhil = _mm256_castsi256_si128(bxhi);
__m128i bxhih = _mm256_extractf128_si256(bxhi, 1); __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
bxhil = _mm_and_si128(bxhil, mask); bxhil = _mm_and_si128(bxhil, mask);
bxhih = _mm_and_si128(bxhih, mask); bxhih = _mm_and_si128(bxhih, mask);
__m128i bxl = _mm256_castsi256_si128(bx); __m128i bxl = _mm256_castsi256_si128(bx_0);
__m128i bxh = _mm256_extractf128_si256(bx, 1); __m128i bxh = _mm256_extractf128_si256(bx_0, 1);
bxl = _mm_or_si128(bxl, bxhil); bxl = _mm_or_si128(bxl, bxhil);
bxh = _mm_or_si128(bxh, bxhih); bxh = _mm_or_si128(bxh, bxhih);
bx = MM256_SET_M128I(bxh, bxl); bx_0 = MM256_SET_M128I(bxh, bxl);
const __m256 dy = _mm256_set1_ps(y[i].d); const __m256 dy = _mm256_set1_ps(y[i].d);
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[i].qs);
const __m256 q = mul_sum_us8_pairs_float(bx, by); const __m256 q = mul_sum_us8_pairs_float(bx_0, by_0);
acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc); acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
} }
@ -5035,7 +5035,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
#endif #endif
} }
void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) { void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
const int qk = QK8_0; const int qk = QK8_0;
const int nb = n / qk; const int nb = n / qk;
@ -5046,8 +5046,8 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
assert(nrc == 1); assert(nrc == 1);
#endif #endif
UNUSED(nrc); UNUSED(nrc);
UNUSED(bbx); UNUSED(bx);
UNUSED(bby); UNUSED(by);
UNUSED(bs); UNUSED(bs);
const block_q8_0 * restrict x = vx; const block_q8_0 * restrict x = vx;
@ -5169,10 +5169,10 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
// load elements // load elements
vint8m1_t bx = __riscv_vle8_v_i8m1(x[i].qs, vl); vint8m1_t bx_0 = __riscv_vle8_v_i8m1(x[i].qs, vl);
vint8m1_t by = __riscv_vle8_v_i8m1(y[i].qs, vl); vint8m1_t by_0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx, by, vl); vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx_0, by_0, vl);
vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl); vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl);
vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl); vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl);

19
ggml.c
View file

@ -23,6 +23,9 @@
#include <limits.h> #include <limits.h>
#include <stdarg.h> #include <stdarg.h>
#include <signal.h> #include <signal.h>
#if defined(__gnu_linux__)
#include <syscall.h>
#endif
#ifdef GGML_USE_METAL #ifdef GGML_USE_METAL
#include <unistd.h> #include <unistd.h>
@ -1971,7 +1974,7 @@ struct ggml_numa_nodes {
uint32_t n_nodes; uint32_t n_nodes;
uint32_t total_cpus; // hardware threads on system uint32_t total_cpus; // hardware threads on system
uint32_t current_node; // node on which main process is execting uint32_t current_node; // node on which main process is execting
#ifdef __linux__ #if defined(__gnu_linux__)
cpu_set_t cpuset; // cpuset from numactl cpu_set_t cpuset; // cpuset from numactl
#else #else
uint32_t cpuset; // no NUMA support outside of Linux at this time. Use a portable datatype uint32_t cpuset; // no NUMA support outside of Linux at this time. Use a portable datatype
@ -2009,7 +2012,7 @@ inline static void ggml_critical_section_end(void) {
atomic_fetch_sub(&g_state_barrier, 1); atomic_fetch_sub(&g_state_barrier, 1);
} }
#ifdef __linux__ #if defined(__gnu_linux__)
static cpu_set_t ggml_get_numa_affinity(void) { static cpu_set_t ggml_get_numa_affinity(void) {
cpu_set_t cpuset; cpu_set_t cpuset;
pthread_t thread; pthread_t thread;
@ -2031,7 +2034,7 @@ void ggml_numa_init(enum ggml_numa_strategy numa_flag) {
return; return;
} }
#ifdef __linux__ #if defined(__gnu_linux__)
struct stat st; struct stat st;
char path[256]; char path[256];
int rv; int rv;
@ -2063,7 +2066,13 @@ void ggml_numa_init(enum ggml_numa_strategy numa_flag) {
// figure out which node we're on // figure out which node we're on
uint current_cpu; uint current_cpu;
int getcpu_ret = getcpu(&current_cpu, &g_state.numa.current_node); int getcpu_ret = 0;
#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 28)
getcpu_ret = getcpu(&current_cpu, &g_state.numa.current_node);
#else
// old glibc doesn't have a wrapper for this call. Fall back on direct syscall
getcpu_ret = syscall(SYS_getcpu,&current_cpu,&g_state.numa.current_node);
#endif
if (g_state.numa.n_nodes < 1 || g_state.numa.total_cpus < 1 || getcpu_ret != 0) { if (g_state.numa.n_nodes < 1 || g_state.numa.total_cpus < 1 || getcpu_ret != 0) {
g_state.numa.n_nodes = 0; g_state.numa.n_nodes = 0;
@ -16734,7 +16743,7 @@ typedef pthread_t ggml_thread_t;
#endif #endif
// Android's libc implementation "bionic" does not support setting affinity // Android's libc implementation "bionic" does not support setting affinity
#if defined(__linux__) && !defined(__BIONIC__) #if defined(__gnu_linux__)
static void set_numa_thread_affinity(int thread_n) { static void set_numa_thread_affinity(int thread_n) {
if (!ggml_is_numa()) { if (!ggml_is_numa()) {
return; return;

117
llama.cpp
View file

@ -12508,6 +12508,123 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
return 0; return 0;
} }
// trim whitespace from the beginning and end of a string
static std::string trim(const std::string & str) {
size_t start = 0;
size_t end = str.size();
while (start < end && isspace(str[start])) {
start += 1;
}
while (end > start && isspace(str[end - 1])) {
end -= 1;
}
return str.substr(start, end - start);
}
// Simple version of "llama_apply_chat_template" that only works with strings
// This function uses heuristic checks to determine commonly used template. It is not a jinja parser.
static int32_t llama_chat_apply_template_internal(
const std::string & tmpl,
const std::vector<const llama_chat_message *> & chat,
std::string & dest, bool add_ass) {
// Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
std::stringstream ss;
if (tmpl.find("<|im_start|>") != std::string::npos) {
// chatml template
for (auto message : chat) {
ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n";
}
if (add_ass) {
ss << "<|im_start|>assistant\n";
}
} else if (tmpl.find("[INST]") != std::string::npos) {
// llama2 template and its variants
// [variant] support system message
bool support_system_message = tmpl.find("<<SYS>>") != std::string::npos;
// [variant] space before + after response
bool space_around_response = tmpl.find("' ' + eos_token") != std::string::npos;
// [variant] add BOS inside history
bool add_bos_inside_history = tmpl.find("bos_token + '[INST]") != std::string::npos;
// [variant] trim spaces from the input message
bool strip_message = tmpl.find("content.strip()") != std::string::npos;
// construct the prompt
bool is_inside_turn = true; // skip BOS at the beginning
ss << "[INST] ";
for (auto message : chat) {
std::string content = strip_message ? trim(message->content) : message->content;
std::string role(message->role);
if (!is_inside_turn) {
is_inside_turn = true;
ss << (add_bos_inside_history ? "<s>[INST] " : "[INST] ");
}
if (role == "system") {
if (support_system_message) {
ss << "<<SYS>>\n" << content << "\n<</SYS>>\n\n";
} else {
// if the model does not support system message, we still include it in the first message, but without <<SYS>>
ss << content << "\n";
}
} else if (role == "user") {
ss << content << " [/INST]";
} else {
ss << (space_around_response ? " " : "") << content << (space_around_response ? " " : "") << "</s>";
is_inside_turn = false;
}
}
// llama2 templates seem to not care about "add_generation_prompt"
} else if (tmpl.find("<|user|>") != std::string::npos) {
// zephyr template
for (auto message : chat) {
ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n";
}
if (add_ass) {
ss << "<|assistant|>\n";
}
} else {
// template not supported
return -1;
}
dest = ss.str();
return dest.size();
}
LLAMA_API int32_t llama_chat_apply_template(
const struct llama_model * model,
const char * tmpl,
const struct llama_chat_message * chat,
size_t n_msg,
bool add_ass,
char * buf,
int32_t length) {
std::string curr_tmpl(tmpl == nullptr ? "" : tmpl);
if (tmpl == nullptr) {
GGML_ASSERT(model != nullptr);
// load template from model
std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
std::string template_key = "tokenizer.chat_template";
int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), curr_tmpl.size());
if (res < 0) {
// worst case: there is no information about template, we will use chatml by default
curr_tmpl = "<|im_start|>"; // see llama_chat_apply_template_internal
} else {
curr_tmpl = std::string(model_template.data(), model_template.size());
}
}
// format the chat to string
std::vector<const llama_chat_message *> chat_vec;
chat_vec.resize(n_msg);
for (size_t i = 0; i < n_msg; i++) {
chat_vec[i] = &chat[i];
}
std::string formatted_chat;
int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass);
if (res < 0) {
return res;
}
strncpy(buf, formatted_chat.c_str(), length);
return res;
}
struct llama_timings llama_get_timings(struct llama_context * ctx) { struct llama_timings llama_get_timings(struct llama_context * ctx) {
struct llama_timings result = { struct llama_timings result = {
/*.t_start_ms =*/ 1e-3 * ctx->t_start_us, /*.t_start_ms =*/ 1e-3 * ctx->t_start_us,

25
llama.h
View file

@ -305,6 +305,12 @@ extern "C" {
int32_t n_eval; int32_t n_eval;
}; };
// used in chat template
typedef struct llama_chat_message {
const char * role;
const char * content;
} llama_chat_message;
// Helpers for getting default parameters // Helpers for getting default parameters
LLAMA_API struct llama_model_params llama_model_default_params(void); LLAMA_API struct llama_model_params llama_model_default_params(void);
LLAMA_API struct llama_context_params llama_context_default_params(void); LLAMA_API struct llama_context_params llama_context_default_params(void);
@ -699,6 +705,25 @@ extern "C" {
char * buf, char * buf,
int32_t length); int32_t length);
/// Apply chat template. Inspired by hf apply_chat_template() on python.
/// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
/// NOTE: This function only support some known jinja templates. It is not a jinja parser.
/// @param tmpl A Jinja template to use for this chat. If this is nullptr, the models default chat template will be used instead.
/// @param chat Pointer to a list of multiple llama_chat_message
/// @param n_msg Number of llama_chat_message in this chat
/// @param add_ass Whether to end the prompt with the token(s) that indicate the start of an assistant message.
/// @param buf A buffer to hold the output formatted prompt. The recommended alloc size is 2 * (total number of characters of all messages)
/// @param length The size of the allocated buffer
/// @return The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template.
LLAMA_API int32_t llama_chat_apply_template(
const struct llama_model * model,
const char * tmpl,
const struct llama_chat_message * chat,
size_t n_msg,
bool add_ass,
char * buf,
int32_t length);
// //
// Grammar // Grammar
// //

View file

@ -1,6 +1,6 @@
ifeq '' '$(findstring clang,$(shell $(GF_CC) --version))' ifeq '' '$(findstring clang,$(shell $(GF_CC) --version))'
GF_CC_IS_GCC = 1 GF_CC_IS_GCC = 1
GF_CC_VER := $(shell { $(GF_CC) -dumpfullversion 2>/dev/null || $(GF_CC) -dumpversion; } | awk -F. '{ printf("%02d%02d%02d", $$1, $$2, $$3) }') GF_CC_VER := $(shell { $(GF_CC) -dumpfullversion 2>/dev/null; echo; $(GF_CC) -dumpversion; } | awk -F. '/./ { printf("%02d%02d%02d", $$1, $$2, $$3); exit }')
else else
GF_CC_IS_CLANG = 1 GF_CC_IS_CLANG = 1
ifeq '' '$(findstring Apple,$(shell $(GF_CC) --version))' ifeq '' '$(findstring Apple,$(shell $(GF_CC) --version))'

View file

@ -1,6 +1,6 @@
#!/bin/bash #!/bin/bash
wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip wget https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
echo "Usage:" echo "Usage:"
echo "" echo ""

View file

@ -28,6 +28,7 @@ endfunction()
llama_build_and_test_executable(test-quantize-fns.cpp) llama_build_and_test_executable(test-quantize-fns.cpp)
llama_build_and_test_executable(test-quantize-perf.cpp) llama_build_and_test_executable(test-quantize-perf.cpp)
llama_build_and_test_executable(test-sampling.cpp) llama_build_and_test_executable(test-sampling.cpp)
llama_build_and_test_executable(test-chat-template.cpp)
llama_build_executable(test-tokenizer-0-llama.cpp) llama_build_executable(test-tokenizer-0-llama.cpp)
llama_test_executable (test-tokenizer-0-llama test-tokenizer-0-llama.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf) llama_test_executable (test-tokenizer-0-llama test-tokenizer-0-llama.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)

View file

@ -0,0 +1,64 @@
#include <iostream>
#include <string>
#include <vector>
#include <sstream>
#undef NDEBUG
#include <cassert>
#include "llama.h"
int main(void) {
llama_chat_message conversation[] = {
{"system", "You are a helpful assistant"},
{"user", "Hello"},
{"assistant", "Hi there"},
{"user", "Who are you"},
{"assistant", " I am an assistant "},
{"user", "Another question"},
};
size_t message_count = 6;
std::vector<std::string> templates = {
// teknium/OpenHermes-2.5-Mistral-7B
"{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
// mistralai/Mistral-7B-Instruct-v0.2
"{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
// TheBloke/FusionNet_34Bx2_MoE-AWQ
"{%- for idx in range(0, messages|length) -%}\\n{%- if messages[idx]['role'] == 'user' -%}\\n{%- if idx > 1 -%}\\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\\n{%- else -%}\\n{{- messages[idx]['content'] + ' [/INST]' -}}\\n{%- endif -%}\\n{% elif messages[idx]['role'] == 'system' %}\\n{{- '[INST] <<SYS>>\\\\n' + messages[idx]['content'] + '\\\\n<</SYS>>\\\\n\\\\n' -}}\\n{%- elif messages[idx]['role'] == 'assistant' -%}\\n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}}\\n{% endif %}\\n{% endfor %}",
// bofenghuang/vigogne-2-70b-chat
"{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez autant que vous le pouvez.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\\\n' + system_message + '\\\\n<</SYS>>\\\\n\\\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\\\n' + content.strip() + '\\\\n<</SYS>>\\\\n\\\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
};
std::vector<std::string> expected_substr = {
"<|im_start|>assistant\n I am an assistant <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant",
"[/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
"</s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
"[/INST] Hi there </s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
};
std::vector<char> formatted_chat(1024);
int32_t res;
// test invalid chat template
res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation, message_count, true, formatted_chat.data(), formatted_chat.size());
assert(res < 0);
for (size_t i = 0; i < templates.size(); i++) {
std::string custom_template = templates[i];
std::string substr = expected_substr[i];
formatted_chat.resize(1024);
res = llama_chat_apply_template(
nullptr,
custom_template.c_str(),
conversation,
message_count,
true,
formatted_chat.data(),
formatted_chat.size()
);
formatted_chat.resize(res);
std::string output(formatted_chat.data(), formatted_chat.size());
std::cout << output << "\n-------------------------\n";
// expect the "formatted_chat" to contain pre-defined strings
assert(output.find(substr) != std::string::npos);
}
return 0;
}