Merge branch 'ggerganov:master' into server_branch
This commit is contained in:
commit
aaffb2387f
24 changed files with 339 additions and 123 deletions
|
@ -741,13 +741,6 @@ function(get_flags CCID CCVER)
|
|||
if (CCVER VERSION_GREATER_EQUAL 8.1.0)
|
||||
list(APPEND CXX_FLAGS -Wextra-semi)
|
||||
endif()
|
||||
elseif (CCID MATCHES "Intel")
|
||||
if (NOT LLAMA_SYCL)
|
||||
# enable max optimization level when using Intel compiler
|
||||
set(C_FLAGS -ipo -O3 -static -fp-model=fast -flto -fno-stack-protector)
|
||||
set(CXX_FLAGS -ipo -O3 -static -fp-model=fast -flto -fno-stack-protector)
|
||||
add_link_options(-fuse-ld=lld -static-intel)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
set(GF_C_FLAGS ${C_FLAGS} PARENT_SCOPE)
|
||||
|
@ -778,10 +771,7 @@ endif()
|
|||
set(CUDA_CXX_FLAGS "")
|
||||
|
||||
if (LLAMA_CUBLAS)
|
||||
set(CUDA_FLAGS ${CXX_FLAGS} -use_fast_math)
|
||||
if (NOT MSVC)
|
||||
list(APPEND CUDA_FLAGS -Wno-pedantic)
|
||||
endif()
|
||||
set(CUDA_FLAGS -use_fast_math)
|
||||
|
||||
if (LLAMA_ALL_WARNINGS AND NOT MSVC)
|
||||
set(NVCC_CMD ${CMAKE_CUDA_COMPILER} .c)
|
||||
|
@ -814,7 +804,11 @@ if (LLAMA_CUBLAS)
|
|||
message("-- CUDA host compiler is ${CUDA_CCID} ${CUDA_CCVER}")
|
||||
|
||||
get_flags(${CUDA_CCID} ${CUDA_CCVER})
|
||||
list(APPEND CUDA_CXX_FLAGS ${GF_CXX_FLAGS}) # This is passed to -Xcompiler later
|
||||
list(APPEND CUDA_CXX_FLAGS ${CXX_FLAGS} ${GF_CXX_FLAGS}) # This is passed to -Xcompiler later
|
||||
endif()
|
||||
|
||||
if (NOT MSVC)
|
||||
list(APPEND CUDA_CXX_FLAGS -Wno-pedantic)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
|
39
Makefile
39
Makefile
|
@ -97,9 +97,10 @@ endif
|
|||
#
|
||||
|
||||
# keep standard at C11 and C++11
|
||||
MK_CPPFLAGS = -I. -Icommon
|
||||
MK_CFLAGS = -std=c11 -fPIC
|
||||
MK_CXXFLAGS = -std=c++11 -fPIC
|
||||
MK_CPPFLAGS = -I. -Icommon
|
||||
MK_CFLAGS = -std=c11 -fPIC
|
||||
MK_CXXFLAGS = -std=c++11 -fPIC
|
||||
MK_NVCCFLAGS = -std=c++11
|
||||
|
||||
# -Ofast tends to produce faster code, but may not be available for some compilers.
|
||||
ifdef LLAMA_FAST
|
||||
|
@ -220,30 +221,6 @@ ifeq ($(LLAMA_FATAL_WARNINGS),1)
|
|||
MK_CXXFLAGS += -Werror
|
||||
endif
|
||||
|
||||
ifeq ($(CC_IS_CLANG), 1)
|
||||
# clang options
|
||||
MK_CFLAGS += -Wunreachable-code-break -Wunreachable-code-return
|
||||
MK_HOST_CXXFLAGS += -Wunreachable-code-break -Wunreachable-code-return -Wmissing-prototypes -Wextra-semi
|
||||
|
||||
ifneq '' '$(and $(CC_IS_LLVM_CLANG),$(filter 1,$(shell expr $(CC_VER) \>= 030800)))'
|
||||
MK_CFLAGS += -Wdouble-promotion
|
||||
endif
|
||||
ifneq '' '$(and $(CC_IS_APPLE_CLANG),$(filter 1,$(shell expr $(CC_VER) \>= 070300)))'
|
||||
MK_CFLAGS += -Wdouble-promotion
|
||||
endif
|
||||
else
|
||||
# gcc options
|
||||
MK_CFLAGS += -Wdouble-promotion
|
||||
MK_HOST_CXXFLAGS += -Wno-array-bounds
|
||||
|
||||
ifeq ($(shell expr $(CC_VER) \>= 070100), 1)
|
||||
MK_HOST_CXXFLAGS += -Wno-format-truncation
|
||||
endif
|
||||
ifeq ($(shell expr $(CC_VER) \>= 080100), 1)
|
||||
MK_HOST_CXXFLAGS += -Wextra-semi
|
||||
endif
|
||||
endif
|
||||
|
||||
# this version of Apple ld64 is buggy
|
||||
ifneq '' '$(findstring dyld-1015.7,$(shell $(CC) $(LDFLAGS) -Wl,-v 2>&1))'
|
||||
MK_CPPFLAGS += -DHAVE_BUGGY_APPLE_LINKER
|
||||
|
@ -468,7 +445,7 @@ ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
|
|||
ifdef JETSON_EOL_MODULE_DETECT
|
||||
$(NVCC) -I. -Icommon -D_XOPEN_SOURCE=600 -D_GNU_SOURCE -DNDEBUG -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I/usr/local/cuda/targets/aarch64-linux/include -std=c++11 -O3 $(NVCCFLAGS) -Xcompiler "$(CUDA_CXXFLAGS)" -c $< -o $@
|
||||
else
|
||||
$(NVCC) $(BASE_CXXFLAGS) $(NVCCFLAGS) -Wno-pedantic -Xcompiler "$(CUDA_CXXFLAGS)" -c $< -o $@
|
||||
$(NVCC) $(NVCCFLAGS) -Xcompiler "$(CUDA_CXXFLAGS)" -c $< -o $@
|
||||
endif # JETSON_EOL_MODULE_DETECT
|
||||
endif # LLAMA_CUBLAS
|
||||
|
||||
|
@ -579,7 +556,7 @@ override LDFLAGS := $(MK_LDFLAGS) $(LDFLAGS)
|
|||
ifdef LLAMA_CUBLAS
|
||||
GF_CC := $(NVCC) $(NVCCFLAGS) 2>/dev/null .c -Xcompiler
|
||||
include scripts/get-flags.mk
|
||||
CUDA_CXXFLAGS := $(GF_CXXFLAGS)
|
||||
CUDA_CXXFLAGS := $(BASE_CXXFLAGS) $(GF_CXXFLAGS) -Wno-pedantic
|
||||
endif
|
||||
|
||||
#
|
||||
|
@ -891,3 +868,7 @@ tests/test-model-load-cancel: tests/test-model-load-cancel.cpp ggml.o llama.o te
|
|||
tests/test-autorelease: tests/test-autorelease.cpp ggml.o llama.o tests/get-model.cpp $(COMMON_DEPS) $(OBJS)
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
||||
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
||||
|
||||
tests/test-chat-template: tests/test-chat-template.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
||||
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
||||
|
|
|
@ -272,7 +272,7 @@ Please install [Visual Studio](https://visualstudio.microsoft.com/) which impact
|
|||
|
||||
a. Please follow the procedure in [Get the Intel® oneAPI Base Toolkit ](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html).
|
||||
|
||||
Recommend to install to default folder: **/opt/intel/oneapi**.
|
||||
Recommend to install to default folder: **C:\Program Files (x86)\Intel\oneAPI**.
|
||||
|
||||
Following guide uses the default folder as example. If you use other folder, please modify the following guide info with your folder.
|
||||
|
||||
|
|
|
@ -61,7 +61,7 @@ variety of hardware - locally and in the cloud.
|
|||
- Plain C/C++ implementation without any dependencies
|
||||
- Apple silicon is a first-class citizen - optimized via ARM NEON, Accelerate and Metal frameworks
|
||||
- AVX, AVX2 and AVX512 support for x86 architectures
|
||||
- 2-bit, 3-bit, 4-bit, 5-bit, 6-bit, and 8-bit integer quantization for faster inference and reduced memory use
|
||||
- 1.5-bit, 2-bit, 3-bit, 4-bit, 5-bit, 6-bit, and 8-bit integer quantization for faster inference and reduced memory use
|
||||
- Custom CUDA kernels for running LLMs on NVIDIA GPUs (support for AMD GPUs via HIP)
|
||||
- Vulkan, SYCL, and (partial) OpenCL backend support
|
||||
- CPU+GPU hybrid inference to partially accelerate models larger than the total VRAM capacity
|
||||
|
@ -768,7 +768,7 @@ The time per token is measured on a MacBook M1 Pro 32GB RAM using 4 and 8 thread
|
|||
|
||||
#### How to run
|
||||
|
||||
1. Download/extract: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
|
||||
1. Download/extract: https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
|
||||
2. Run `./perplexity -m models/7B/ggml-model-q4_0.gguf -f wiki.test.raw`
|
||||
3. Output:
|
||||
```
|
||||
|
|
|
@ -219,7 +219,7 @@ function gg_run_open_llama_3b_v2 {
|
|||
gg_wget models-mnt/open-llama/3B-v2/ https://huggingface.co/openlm-research/open_llama_3b_v2/resolve/main/pytorch_model.bin
|
||||
gg_wget models-mnt/open-llama/3B-v2/ https://huggingface.co/openlm-research/open_llama_3b_v2/raw/main/generation_config.json
|
||||
|
||||
gg_wget models-mnt/wikitext/ https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip
|
||||
gg_wget models-mnt/wikitext/ https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
|
||||
unzip -o models-mnt/wikitext/wikitext-2-raw-v1.zip -d models-mnt/wikitext/
|
||||
head -n 60 models-mnt/wikitext/wikitext-2-raw/wiki.test.raw > models-mnt/wikitext/wikitext-2-raw/wiki.test-60.raw
|
||||
|
||||
|
@ -401,7 +401,7 @@ function gg_run_open_llama_7b_v2 {
|
|||
gg_wget models-mnt/open-llama/7B-v2/ https://huggingface.co/openlm-research/open_llama_7b_v2/resolve/main/pytorch_model-00002-of-00002.bin
|
||||
gg_wget models-mnt/open-llama/7B-v2/ https://huggingface.co/openlm-research/open_llama_7b_v2/raw/main/generation_config.json
|
||||
|
||||
gg_wget models-mnt/wikitext/ https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip
|
||||
gg_wget models-mnt/wikitext/ https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
|
||||
unzip -o models-mnt/wikitext/wikitext-2-raw-v1.zip -d models-mnt/wikitext/
|
||||
|
||||
path_models="../models-mnt/open-llama/7B-v2"
|
||||
|
|
|
@ -1704,6 +1704,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
|
|||
}
|
||||
fprintf(stream, "lora_base: %s\n", params.lora_base.c_str());
|
||||
fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
|
||||
fprintf(stream, "min_keep: %d # default: 0 (disabled)\n", sparams.min_keep);
|
||||
fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat);
|
||||
fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau);
|
||||
fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta);
|
||||
|
|
|
@ -248,7 +248,10 @@ static llama_token llama_sampling_sample_impl(
|
|||
llama_sample_temp(ctx_main, &cur_p, temp);
|
||||
id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
|
||||
} else {
|
||||
sampler_queue(ctx_main, params, cur_p, 1);
|
||||
// temperature sampling
|
||||
size_t min_keep = std::max(1, params.min_keep);
|
||||
|
||||
sampler_queue(ctx_main, params, cur_p, min_keep);
|
||||
|
||||
id = llama_sample_token(ctx_main, &cur_p);
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ enum class llama_sampler_type : char {
|
|||
typedef struct llama_sampling_params {
|
||||
int32_t n_prev = 64; // number of previous tokens to remember
|
||||
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
|
||||
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
|
||||
int32_t top_k = 40; // <= 0 to use vocab size
|
||||
float top_p = 0.95f; // 1.0 = disabled
|
||||
float min_p = 0.05f; // 0.0 = disabled
|
||||
|
|
|
@ -1533,16 +1533,17 @@ int main(int argc, char ** argv) {
|
|||
|
||||
int n_past = 0;
|
||||
|
||||
ggml_cgraph gf = {};
|
||||
struct ggml_cgraph * gf = NULL;
|
||||
gf = ggml_new_graph_custom(ctx0, LLAMA_TRAIN_MAX_NODES, true);
|
||||
|
||||
get_example_targets_batch(ctx0, 64*ex+0, tokens_input, targets);
|
||||
|
||||
struct ggml_tensor * logits = forward_batch(&model, &kv_self, ctx0, &gf, tokens_input, n_tokens, n_past, n_batch);
|
||||
struct ggml_tensor * logits = forward_batch(&model, &kv_self, ctx0, gf, tokens_input, n_tokens, n_past, n_batch);
|
||||
// struct ggml_tensor * e = cross_entropy_loss(ctx0, targets, logits);
|
||||
struct ggml_tensor * e = square_error_loss(ctx0, targets, logits);
|
||||
|
||||
ggml_build_forward_expand(&gf, e);
|
||||
ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1);
|
||||
ggml_build_forward_expand(gf, e);
|
||||
ggml_graph_compute_helper(work_buffer, gf, /*n_threads*/ 1);
|
||||
|
||||
float error_before_opt = ggml_get_f32_1d(e, 0);
|
||||
|
||||
|
@ -1552,8 +1553,8 @@ int main(int argc, char ** argv) {
|
|||
opt_params_lbfgs.lbfgs.n_iter = 16;
|
||||
ggml_opt(ctx0, opt_params_lbfgs, e);
|
||||
//
|
||||
ggml_build_forward_expand(&gf, e);
|
||||
ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1);
|
||||
ggml_build_forward_expand(gf, e);
|
||||
ggml_graph_compute_helper(work_buffer, gf, /*n_threads*/ 1);
|
||||
|
||||
float error_after_opt = ggml_get_f32_1d(e, 0);
|
||||
|
||||
|
@ -1600,13 +1601,14 @@ int main(int argc, char ** argv) {
|
|||
};
|
||||
struct ggml_context * ctx0 = ggml_init(params);
|
||||
|
||||
ggml_cgraph gf = {};
|
||||
struct ggml_cgraph * gf = NULL;
|
||||
gf = ggml_new_graph_custom(ctx0, LLAMA_TRAIN_MAX_NODES, true);
|
||||
|
||||
int n_past = 0;
|
||||
struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, &gf, tokens_input, sample_ctx, n_past);
|
||||
struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, gf, tokens_input, sample_ctx, n_past);
|
||||
|
||||
ggml_build_forward_expand(&gf, logits);
|
||||
ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1);
|
||||
ggml_build_forward_expand(gf, logits);
|
||||
ggml_graph_compute_helper(work_buffer, gf, /*n_threads*/ 1);
|
||||
|
||||
struct ggml_tensor * best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sample_ctx);
|
||||
struct ggml_tensor * probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, sample_ctx);
|
||||
|
|
|
@ -25,9 +25,6 @@ if len(clip_tensors) > 0:
|
|||
clip = {name.replace("vision_tower.vision_tower.", ""): checkpoint[name].float() for name in clip_tensors}
|
||||
torch.save(clip, f"{args.model}/llava.clip")
|
||||
|
||||
# remove these tensors
|
||||
for name in clip_tensors:
|
||||
del checkpoint[name]
|
||||
|
||||
# added tokens should be removed to be able to convert Mistral models
|
||||
if os.path.exists(f"{args.model}/added_tokens.json"):
|
||||
|
@ -35,7 +32,6 @@ if len(clip_tensors) > 0:
|
|||
f.write("{}\n")
|
||||
|
||||
|
||||
torch.save(checkpoint, path)
|
||||
|
||||
print("Done!")
|
||||
print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.")
|
||||
|
|
|
@ -309,7 +309,7 @@ static void process_logits(int n_vocab, const float * logits, const int * tokens
|
|||
}
|
||||
|
||||
static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & params) {
|
||||
// Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
|
||||
// Download: https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
|
||||
// Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
|
||||
// Output: `perplexity: 13.5106 [114/114]`
|
||||
// BOS tokens will be added for each chunk before eval
|
||||
|
@ -447,7 +447,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
|||
return perplexity_v2(ctx, params);
|
||||
}
|
||||
|
||||
// Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
|
||||
// Download: https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
|
||||
// Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
|
||||
// Output: `perplexity: 13.5106 [114/114]`
|
||||
// BOS tokens will be added for each chunk before eval
|
||||
|
|
|
@ -199,6 +199,8 @@ node index.js
|
|||
|
||||
`n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token (default: 0)
|
||||
|
||||
`min_keep`: If greater than 0, force samplers to return N possible tokens at minimum (default: 0)
|
||||
|
||||
`image_data`: An array of objects to hold base64-encoded image `data` and its `id`s to be reference in `prompt`. You can determine the place of the image in the prompt as in the following: `USER:[img-12]Describe the image in detail.\nASSISTANT:`. In this case, `[img-12]` will be replaced by the embeddings of the image with id `12` in the following `image_data` array: `{..., "image_data": [{"data": "<BASE64_STRING>", "id": 12}]}`. Use `image_data` only with multimodal models, e.g., LLaVA.
|
||||
|
||||
`slot_id`: Assign the completion task to an specific slot. If is -1 the task will be assigned to a Idle slot (default: -1)
|
||||
|
|
|
@ -234,6 +234,7 @@
|
|||
mirostat_eta: 0.1, // learning rate
|
||||
grammar: '',
|
||||
n_probs: 0, // no completion_probabilities,
|
||||
min_keep: 0, // min probs from each sampler,
|
||||
image_data: [],
|
||||
cache_prompt: true,
|
||||
api_key: ''
|
||||
|
@ -791,6 +792,9 @@
|
|||
<fieldset>
|
||||
${IntField({ label: "Show Probabilities", max: 10, min: 0, name: "n_probs", value: params.value.n_probs })}
|
||||
</fieldset>
|
||||
<fieldset>
|
||||
${IntField({ label: "Min Probabilities from each Sampler", max: 10, min: 0, name: "min_keep", value: params.value.min_keep })}
|
||||
</fieldset>
|
||||
<fieldset>
|
||||
<label for="api_key">API Key</label>
|
||||
<input type="text" name="api_key" value="${params.value.api_key}" placeholder="Enter API key" oninput=${updateParams} />
|
||||
|
|
|
@ -624,6 +624,7 @@ struct llama_server_context
|
|||
slot->params.seed = json_value(data, "seed", default_params.seed);
|
||||
slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
|
||||
slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
|
||||
slot->sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
|
||||
|
||||
if (slot->n_predict > 0 && slot->params.n_predict > slot->n_predict) {
|
||||
// Might be better to reject the request with a 400 ?
|
||||
|
@ -1169,6 +1170,7 @@ struct llama_server_context
|
|||
{"stream", slot.params.stream},
|
||||
{"logit_bias", slot.sparams.logit_bias},
|
||||
{"n_probs", slot.sparams.n_probs},
|
||||
{"min_keep", slot.sparams.min_keep},
|
||||
{"grammar", slot.sparams.grammar},
|
||||
{"samplers", samplers_sequence}
|
||||
};
|
||||
|
|
|
@ -6205,7 +6205,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
|
|||
const int ix = rowx*ncols + col;
|
||||
const int iy = rowy*ncols + col;
|
||||
|
||||
const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + slope*pos[col];
|
||||
const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
|
||||
|
||||
vals[col] = val;
|
||||
max_val = max(max_val, val);
|
||||
|
@ -9170,17 +9170,17 @@ static void ggml_cuda_op_soft_max(
|
|||
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
|
||||
|
||||
// positions tensor
|
||||
float * src2_dd = dst_dd; // default to avoid null checks in the kernel
|
||||
float * src2_dd = nullptr;
|
||||
cuda_pool_alloc<float> src2_f;
|
||||
|
||||
ggml_tensor * src2 = dst->src[2];
|
||||
const bool use_src2 = src2 != nullptr;
|
||||
|
||||
if (use_src2) {
|
||||
const bool src2_on_device = use_src2 && src2->backend == GGML_BACKEND_GPU;
|
||||
ggml_tensor_extra_gpu * src2_extra = use_src2 ? (ggml_tensor_extra_gpu *) src2->extra : nullptr;
|
||||
const bool src2_on_device = src2->backend == GGML_BACKEND_GPU;
|
||||
|
||||
if (src2_on_device) {
|
||||
ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) src2->extra;
|
||||
src2_dd = (float *) src2_extra->data_device[g_main_device];
|
||||
} else {
|
||||
src2_dd = src2_f.alloc(ggml_nelements(src2));
|
||||
|
|
|
@ -392,7 +392,7 @@ kernel void kernel_soft_max(
|
|||
float lmax = -INFINITY;
|
||||
|
||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]);
|
||||
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
|
||||
}
|
||||
|
||||
// find the max value in the block
|
||||
|
@ -417,7 +417,7 @@ kernel void kernel_soft_max(
|
|||
// parallel sum
|
||||
float lsum = 0.0f;
|
||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]) - max_val);
|
||||
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
|
||||
lsum += exp_psrc0;
|
||||
pdst[i00] = exp_psrc0;
|
||||
}
|
||||
|
@ -495,7 +495,7 @@ kernel void kernel_soft_max_4(
|
|||
float4 lmax4 = -INFINITY;
|
||||
|
||||
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]);
|
||||
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
|
||||
}
|
||||
|
||||
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
||||
|
@ -521,7 +521,7 @@ kernel void kernel_soft_max_4(
|
|||
// parallel sum
|
||||
float4 lsum4 = 0.0f;
|
||||
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]) - max_val);
|
||||
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
|
||||
lsum4 += exp_psrc4;
|
||||
pdst4[i00] = exp_psrc4;
|
||||
}
|
||||
|
@ -4027,7 +4027,10 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
|||
y4 += 32 * 32;
|
||||
}
|
||||
#else
|
||||
// TODO
|
||||
(void) x;
|
||||
(void) y;
|
||||
(void) yl;
|
||||
(void) nb32;
|
||||
#endif
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
|
@ -4170,7 +4173,10 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
|||
y4 += 32 * 32;
|
||||
}
|
||||
#else
|
||||
// TODO
|
||||
(void) x;
|
||||
(void) y;
|
||||
(void) yl;
|
||||
(void) nb32;
|
||||
#endif
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
|
@ -4306,7 +4312,10 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|||
y4 += 32 * 32;
|
||||
}
|
||||
#else
|
||||
// TODO
|
||||
(void) x;
|
||||
(void) y;
|
||||
(void) yl;
|
||||
(void) nb32;
|
||||
#endif
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
|
@ -4424,7 +4433,10 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|||
y4 += 16 * 32;
|
||||
}
|
||||
#else
|
||||
// TODO
|
||||
(void) x;
|
||||
(void) y;
|
||||
(void) yl;
|
||||
(void) nb32;
|
||||
#endif
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
|
@ -4659,6 +4671,8 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
|
|||
const float dl = d * sc[0];
|
||||
const float ml = min * sc[1];
|
||||
#else
|
||||
(void) get_scale_min_k4_just2;
|
||||
|
||||
q = q + 16 * (il&1);
|
||||
device const uint8_t * s = xb->scales;
|
||||
device const half2 * dh = (device const half2 *)xb->d;
|
||||
|
|
|
@ -1837,9 +1837,9 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri
|
|||
float sigma2 = sumx2/QK_K;
|
||||
for (int j = 0; j < QK_K/16; ++j) {
|
||||
const float * restrict qw = quant_weights + QK_K * i + 16*j;
|
||||
for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]);
|
||||
for (int l = 0; l < 16; ++l) sw[j] += weight[l];
|
||||
scales[j] = make_qkx3_quants(16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
|
||||
for (int l = 0; l < QK_K/16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]);
|
||||
for (int l = 0; l < QK_K/16; ++l) sw[j] += weight[l];
|
||||
scales[j] = make_qkx3_quants(QK_K/16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
|
||||
}
|
||||
|
||||
float dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw);
|
||||
|
@ -3855,7 +3855,7 @@ static inline __m128i get_scale_shuffle(int i) {
|
|||
}
|
||||
#endif
|
||||
|
||||
void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) {
|
||||
void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
|
||||
|
@ -3866,8 +3866,8 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
|||
assert(nrc == 1);
|
||||
#endif
|
||||
UNUSED(nrc);
|
||||
UNUSED(bbx);
|
||||
UNUSED(bby);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
|
||||
const block_q4_0 * restrict x = vx;
|
||||
|
@ -4024,15 +4024,15 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
|||
|
||||
const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs);
|
||||
|
||||
__m128i bx = _mm_and_si128(lowMask, tmp);
|
||||
__m128i by = _mm_loadu_si128((const __m128i *)y[i].qs);
|
||||
bx = _mm_sub_epi8(bx, off);
|
||||
const __m128i i32_0 = mul_sum_i8_pairs(bx, by);
|
||||
__m128i bx_0 = _mm_and_si128(lowMask, tmp);
|
||||
__m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs);
|
||||
bx_0 = _mm_sub_epi8(bx_0, off);
|
||||
const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
|
||||
|
||||
bx = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4));
|
||||
by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
|
||||
bx = _mm_sub_epi8(bx, off);
|
||||
const __m128i i32_1 = mul_sum_i8_pairs(bx, by);
|
||||
bx_0 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4));
|
||||
by_0 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
|
||||
bx_0 = _mm_sub_epi8(bx_0, off);
|
||||
const __m128i i32_1 = mul_sum_i8_pairs(bx_0, by_0);
|
||||
|
||||
// Convert int32_t to float
|
||||
__m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1));
|
||||
|
@ -4222,7 +4222,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
|||
#endif
|
||||
}
|
||||
|
||||
void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) {
|
||||
void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
||||
const int qk = QK8_1;
|
||||
const int nb = n / qk;
|
||||
|
||||
|
@ -4233,8 +4233,8 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
|
|||
assert(nrc == 1);
|
||||
#endif
|
||||
UNUSED(nrc);
|
||||
UNUSED(bbx);
|
||||
UNUSED(bby);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
|
||||
const block_q4_1 * restrict x = vx;
|
||||
|
@ -4440,7 +4440,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
|
|||
#endif
|
||||
}
|
||||
|
||||
void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) {
|
||||
void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
|
||||
|
@ -4448,8 +4448,8 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
|||
assert(qk == QK5_0);
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc);
|
||||
UNUSED(bbx);
|
||||
UNUSED(bby);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
|
||||
const block_q5_0 * restrict x = vx;
|
||||
|
@ -4618,21 +4618,21 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
|||
/* Compute combined scale for the block */
|
||||
const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
|
||||
|
||||
__m256i bx = bytes_from_nibbles_32(x[i].qs);
|
||||
__m256i bx_0 = bytes_from_nibbles_32(x[i].qs);
|
||||
const __m256i bxhi = bytes_from_bits_32(x[i].qh);
|
||||
__m128i bxhil = _mm256_castsi256_si128(bxhi);
|
||||
__m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
|
||||
bxhil = _mm_andnot_si128(bxhil, mask);
|
||||
bxhih = _mm_andnot_si128(bxhih, mask);
|
||||
__m128i bxl = _mm256_castsi256_si128(bx);
|
||||
__m128i bxh = _mm256_extractf128_si256(bx, 1);
|
||||
__m128i bxl = _mm256_castsi256_si128(bx_0);
|
||||
__m128i bxh = _mm256_extractf128_si256(bx_0, 1);
|
||||
bxl = _mm_or_si128(bxl, bxhil);
|
||||
bxh = _mm_or_si128(bxh, bxhih);
|
||||
bx = MM256_SET_M128I(bxh, bxl);
|
||||
bx_0 = MM256_SET_M128I(bxh, bxl);
|
||||
|
||||
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
||||
const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
||||
|
||||
const __m256 q = mul_sum_i8_pairs_float(bx, by);
|
||||
const __m256 q = mul_sum_i8_pairs_float(bx_0, by_0);
|
||||
|
||||
/* Multiply q with scale and accumulate */
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
|
||||
|
@ -4731,7 +4731,7 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
|||
#endif
|
||||
}
|
||||
|
||||
void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) {
|
||||
void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
||||
const int qk = QK8_1;
|
||||
const int nb = n / qk;
|
||||
|
||||
|
@ -4739,8 +4739,8 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
|
|||
assert(qk == QK5_1);
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc);
|
||||
UNUSED(bbx);
|
||||
UNUSED(bby);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
|
||||
const block_q5_1 * restrict x = vx;
|
||||
|
@ -4925,22 +4925,22 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
|
|||
|
||||
summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
|
||||
|
||||
__m256i bx = bytes_from_nibbles_32(x[i].qs);
|
||||
__m256i bx_0 = bytes_from_nibbles_32(x[i].qs);
|
||||
const __m256i bxhi = bytes_from_bits_32(x[i].qh);
|
||||
__m128i bxhil = _mm256_castsi256_si128(bxhi);
|
||||
__m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
|
||||
bxhil = _mm_and_si128(bxhil, mask);
|
||||
bxhih = _mm_and_si128(bxhih, mask);
|
||||
__m128i bxl = _mm256_castsi256_si128(bx);
|
||||
__m128i bxh = _mm256_extractf128_si256(bx, 1);
|
||||
__m128i bxl = _mm256_castsi256_si128(bx_0);
|
||||
__m128i bxh = _mm256_extractf128_si256(bx_0, 1);
|
||||
bxl = _mm_or_si128(bxl, bxhil);
|
||||
bxh = _mm_or_si128(bxh, bxhih);
|
||||
bx = MM256_SET_M128I(bxh, bxl);
|
||||
bx_0 = MM256_SET_M128I(bxh, bxl);
|
||||
|
||||
const __m256 dy = _mm256_set1_ps(y[i].d);
|
||||
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
||||
const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
||||
|
||||
const __m256 q = mul_sum_us8_pairs_float(bx, by);
|
||||
const __m256 q = mul_sum_us8_pairs_float(bx_0, by_0);
|
||||
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
|
||||
}
|
||||
|
@ -5035,7 +5035,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
|
|||
#endif
|
||||
}
|
||||
|
||||
void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) {
|
||||
void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
|
||||
|
@ -5046,8 +5046,8 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
|||
assert(nrc == 1);
|
||||
#endif
|
||||
UNUSED(nrc);
|
||||
UNUSED(bbx);
|
||||
UNUSED(bby);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
|
||||
const block_q8_0 * restrict x = vx;
|
||||
|
@ -5169,10 +5169,10 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
|||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
// load elements
|
||||
vint8m1_t bx = __riscv_vle8_v_i8m1(x[i].qs, vl);
|
||||
vint8m1_t by = __riscv_vle8_v_i8m1(y[i].qs, vl);
|
||||
vint8m1_t bx_0 = __riscv_vle8_v_i8m1(x[i].qs, vl);
|
||||
vint8m1_t by_0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
|
||||
|
||||
vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx, by, vl);
|
||||
vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx_0, by_0, vl);
|
||||
|
||||
vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl);
|
||||
vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl);
|
||||
|
|
19
ggml.c
19
ggml.c
|
@ -23,6 +23,9 @@
|
|||
#include <limits.h>
|
||||
#include <stdarg.h>
|
||||
#include <signal.h>
|
||||
#if defined(__gnu_linux__)
|
||||
#include <syscall.h>
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_METAL
|
||||
#include <unistd.h>
|
||||
|
@ -1971,7 +1974,7 @@ struct ggml_numa_nodes {
|
|||
uint32_t n_nodes;
|
||||
uint32_t total_cpus; // hardware threads on system
|
||||
uint32_t current_node; // node on which main process is execting
|
||||
#ifdef __linux__
|
||||
#if defined(__gnu_linux__)
|
||||
cpu_set_t cpuset; // cpuset from numactl
|
||||
#else
|
||||
uint32_t cpuset; // no NUMA support outside of Linux at this time. Use a portable datatype
|
||||
|
@ -2009,7 +2012,7 @@ inline static void ggml_critical_section_end(void) {
|
|||
atomic_fetch_sub(&g_state_barrier, 1);
|
||||
}
|
||||
|
||||
#ifdef __linux__
|
||||
#if defined(__gnu_linux__)
|
||||
static cpu_set_t ggml_get_numa_affinity(void) {
|
||||
cpu_set_t cpuset;
|
||||
pthread_t thread;
|
||||
|
@ -2031,7 +2034,7 @@ void ggml_numa_init(enum ggml_numa_strategy numa_flag) {
|
|||
return;
|
||||
}
|
||||
|
||||
#ifdef __linux__
|
||||
#if defined(__gnu_linux__)
|
||||
struct stat st;
|
||||
char path[256];
|
||||
int rv;
|
||||
|
@ -2063,7 +2066,13 @@ void ggml_numa_init(enum ggml_numa_strategy numa_flag) {
|
|||
|
||||
// figure out which node we're on
|
||||
uint current_cpu;
|
||||
int getcpu_ret = getcpu(¤t_cpu, &g_state.numa.current_node);
|
||||
int getcpu_ret = 0;
|
||||
#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 28)
|
||||
getcpu_ret = getcpu(¤t_cpu, &g_state.numa.current_node);
|
||||
#else
|
||||
// old glibc doesn't have a wrapper for this call. Fall back on direct syscall
|
||||
getcpu_ret = syscall(SYS_getcpu,¤t_cpu,&g_state.numa.current_node);
|
||||
#endif
|
||||
|
||||
if (g_state.numa.n_nodes < 1 || g_state.numa.total_cpus < 1 || getcpu_ret != 0) {
|
||||
g_state.numa.n_nodes = 0;
|
||||
|
@ -16734,7 +16743,7 @@ typedef pthread_t ggml_thread_t;
|
|||
#endif
|
||||
|
||||
// Android's libc implementation "bionic" does not support setting affinity
|
||||
#if defined(__linux__) && !defined(__BIONIC__)
|
||||
#if defined(__gnu_linux__)
|
||||
static void set_numa_thread_affinity(int thread_n) {
|
||||
if (!ggml_is_numa()) {
|
||||
return;
|
||||
|
|
117
llama.cpp
117
llama.cpp
|
@ -12508,6 +12508,123 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
|
|||
return 0;
|
||||
}
|
||||
|
||||
// trim whitespace from the beginning and end of a string
|
||||
static std::string trim(const std::string & str) {
|
||||
size_t start = 0;
|
||||
size_t end = str.size();
|
||||
while (start < end && isspace(str[start])) {
|
||||
start += 1;
|
||||
}
|
||||
while (end > start && isspace(str[end - 1])) {
|
||||
end -= 1;
|
||||
}
|
||||
return str.substr(start, end - start);
|
||||
}
|
||||
|
||||
// Simple version of "llama_apply_chat_template" that only works with strings
|
||||
// This function uses heuristic checks to determine commonly used template. It is not a jinja parser.
|
||||
static int32_t llama_chat_apply_template_internal(
|
||||
const std::string & tmpl,
|
||||
const std::vector<const llama_chat_message *> & chat,
|
||||
std::string & dest, bool add_ass) {
|
||||
// Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
|
||||
std::stringstream ss;
|
||||
if (tmpl.find("<|im_start|>") != std::string::npos) {
|
||||
// chatml template
|
||||
for (auto message : chat) {
|
||||
ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n";
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "<|im_start|>assistant\n";
|
||||
}
|
||||
} else if (tmpl.find("[INST]") != std::string::npos) {
|
||||
// llama2 template and its variants
|
||||
// [variant] support system message
|
||||
bool support_system_message = tmpl.find("<<SYS>>") != std::string::npos;
|
||||
// [variant] space before + after response
|
||||
bool space_around_response = tmpl.find("' ' + eos_token") != std::string::npos;
|
||||
// [variant] add BOS inside history
|
||||
bool add_bos_inside_history = tmpl.find("bos_token + '[INST]") != std::string::npos;
|
||||
// [variant] trim spaces from the input message
|
||||
bool strip_message = tmpl.find("content.strip()") != std::string::npos;
|
||||
// construct the prompt
|
||||
bool is_inside_turn = true; // skip BOS at the beginning
|
||||
ss << "[INST] ";
|
||||
for (auto message : chat) {
|
||||
std::string content = strip_message ? trim(message->content) : message->content;
|
||||
std::string role(message->role);
|
||||
if (!is_inside_turn) {
|
||||
is_inside_turn = true;
|
||||
ss << (add_bos_inside_history ? "<s>[INST] " : "[INST] ");
|
||||
}
|
||||
if (role == "system") {
|
||||
if (support_system_message) {
|
||||
ss << "<<SYS>>\n" << content << "\n<</SYS>>\n\n";
|
||||
} else {
|
||||
// if the model does not support system message, we still include it in the first message, but without <<SYS>>
|
||||
ss << content << "\n";
|
||||
}
|
||||
} else if (role == "user") {
|
||||
ss << content << " [/INST]";
|
||||
} else {
|
||||
ss << (space_around_response ? " " : "") << content << (space_around_response ? " " : "") << "</s>";
|
||||
is_inside_turn = false;
|
||||
}
|
||||
}
|
||||
// llama2 templates seem to not care about "add_generation_prompt"
|
||||
} else if (tmpl.find("<|user|>") != std::string::npos) {
|
||||
// zephyr template
|
||||
for (auto message : chat) {
|
||||
ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n";
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "<|assistant|>\n";
|
||||
}
|
||||
} else {
|
||||
// template not supported
|
||||
return -1;
|
||||
}
|
||||
dest = ss.str();
|
||||
return dest.size();
|
||||
}
|
||||
|
||||
LLAMA_API int32_t llama_chat_apply_template(
|
||||
const struct llama_model * model,
|
||||
const char * tmpl,
|
||||
const struct llama_chat_message * chat,
|
||||
size_t n_msg,
|
||||
bool add_ass,
|
||||
char * buf,
|
||||
int32_t length) {
|
||||
std::string curr_tmpl(tmpl == nullptr ? "" : tmpl);
|
||||
if (tmpl == nullptr) {
|
||||
GGML_ASSERT(model != nullptr);
|
||||
// load template from model
|
||||
std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
|
||||
std::string template_key = "tokenizer.chat_template";
|
||||
int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), curr_tmpl.size());
|
||||
if (res < 0) {
|
||||
// worst case: there is no information about template, we will use chatml by default
|
||||
curr_tmpl = "<|im_start|>"; // see llama_chat_apply_template_internal
|
||||
} else {
|
||||
curr_tmpl = std::string(model_template.data(), model_template.size());
|
||||
}
|
||||
}
|
||||
// format the chat to string
|
||||
std::vector<const llama_chat_message *> chat_vec;
|
||||
chat_vec.resize(n_msg);
|
||||
for (size_t i = 0; i < n_msg; i++) {
|
||||
chat_vec[i] = &chat[i];
|
||||
}
|
||||
std::string formatted_chat;
|
||||
int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass);
|
||||
if (res < 0) {
|
||||
return res;
|
||||
}
|
||||
strncpy(buf, formatted_chat.c_str(), length);
|
||||
return res;
|
||||
}
|
||||
|
||||
struct llama_timings llama_get_timings(struct llama_context * ctx) {
|
||||
struct llama_timings result = {
|
||||
/*.t_start_ms =*/ 1e-3 * ctx->t_start_us,
|
||||
|
|
25
llama.h
25
llama.h
|
@ -305,6 +305,12 @@ extern "C" {
|
|||
int32_t n_eval;
|
||||
};
|
||||
|
||||
// used in chat template
|
||||
typedef struct llama_chat_message {
|
||||
const char * role;
|
||||
const char * content;
|
||||
} llama_chat_message;
|
||||
|
||||
// Helpers for getting default parameters
|
||||
LLAMA_API struct llama_model_params llama_model_default_params(void);
|
||||
LLAMA_API struct llama_context_params llama_context_default_params(void);
|
||||
|
@ -699,6 +705,25 @@ extern "C" {
|
|||
char * buf,
|
||||
int32_t length);
|
||||
|
||||
/// Apply chat template. Inspired by hf apply_chat_template() on python.
|
||||
/// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
|
||||
/// NOTE: This function only support some known jinja templates. It is not a jinja parser.
|
||||
/// @param tmpl A Jinja template to use for this chat. If this is nullptr, the model’s default chat template will be used instead.
|
||||
/// @param chat Pointer to a list of multiple llama_chat_message
|
||||
/// @param n_msg Number of llama_chat_message in this chat
|
||||
/// @param add_ass Whether to end the prompt with the token(s) that indicate the start of an assistant message.
|
||||
/// @param buf A buffer to hold the output formatted prompt. The recommended alloc size is 2 * (total number of characters of all messages)
|
||||
/// @param length The size of the allocated buffer
|
||||
/// @return The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template.
|
||||
LLAMA_API int32_t llama_chat_apply_template(
|
||||
const struct llama_model * model,
|
||||
const char * tmpl,
|
||||
const struct llama_chat_message * chat,
|
||||
size_t n_msg,
|
||||
bool add_ass,
|
||||
char * buf,
|
||||
int32_t length);
|
||||
|
||||
//
|
||||
// Grammar
|
||||
//
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
ifeq '' '$(findstring clang,$(shell $(GF_CC) --version))'
|
||||
GF_CC_IS_GCC = 1
|
||||
GF_CC_VER := $(shell { $(GF_CC) -dumpfullversion 2>/dev/null || $(GF_CC) -dumpversion; } | awk -F. '{ printf("%02d%02d%02d", $$1, $$2, $$3) }')
|
||||
GF_CC_VER := $(shell { $(GF_CC) -dumpfullversion 2>/dev/null; echo; $(GF_CC) -dumpversion; } | awk -F. '/./ { printf("%02d%02d%02d", $$1, $$2, $$3); exit }')
|
||||
else
|
||||
GF_CC_IS_CLANG = 1
|
||||
ifeq '' '$(findstring Apple,$(shell $(GF_CC) --version))'
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
|
||||
wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip
|
||||
wget https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
|
||||
|
||||
echo "Usage:"
|
||||
echo ""
|
||||
|
|
|
@ -28,6 +28,7 @@ endfunction()
|
|||
llama_build_and_test_executable(test-quantize-fns.cpp)
|
||||
llama_build_and_test_executable(test-quantize-perf.cpp)
|
||||
llama_build_and_test_executable(test-sampling.cpp)
|
||||
llama_build_and_test_executable(test-chat-template.cpp)
|
||||
|
||||
llama_build_executable(test-tokenizer-0-llama.cpp)
|
||||
llama_test_executable (test-tokenizer-0-llama test-tokenizer-0-llama.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
|
||||
|
|
64
tests/test-chat-template.cpp
Normal file
64
tests/test-chat-template.cpp
Normal file
|
@ -0,0 +1,64 @@
|
|||
#include <iostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
|
||||
#undef NDEBUG
|
||||
#include <cassert>
|
||||
|
||||
#include "llama.h"
|
||||
|
||||
int main(void) {
|
||||
llama_chat_message conversation[] = {
|
||||
{"system", "You are a helpful assistant"},
|
||||
{"user", "Hello"},
|
||||
{"assistant", "Hi there"},
|
||||
{"user", "Who are you"},
|
||||
{"assistant", " I am an assistant "},
|
||||
{"user", "Another question"},
|
||||
};
|
||||
size_t message_count = 6;
|
||||
std::vector<std::string> templates = {
|
||||
// teknium/OpenHermes-2.5-Mistral-7B
|
||||
"{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
|
||||
// mistralai/Mistral-7B-Instruct-v0.2
|
||||
"{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
|
||||
// TheBloke/FusionNet_34Bx2_MoE-AWQ
|
||||
"{%- for idx in range(0, messages|length) -%}\\n{%- if messages[idx]['role'] == 'user' -%}\\n{%- if idx > 1 -%}\\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\\n{%- else -%}\\n{{- messages[idx]['content'] + ' [/INST]' -}}\\n{%- endif -%}\\n{% elif messages[idx]['role'] == 'system' %}\\n{{- '[INST] <<SYS>>\\\\n' + messages[idx]['content'] + '\\\\n<</SYS>>\\\\n\\\\n' -}}\\n{%- elif messages[idx]['role'] == 'assistant' -%}\\n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}}\\n{% endif %}\\n{% endfor %}",
|
||||
// bofenghuang/vigogne-2-70b-chat
|
||||
"{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez autant que vous le pouvez.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\\\n' + system_message + '\\\\n<</SYS>>\\\\n\\\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\\\n' + content.strip() + '\\\\n<</SYS>>\\\\n\\\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
|
||||
};
|
||||
std::vector<std::string> expected_substr = {
|
||||
"<|im_start|>assistant\n I am an assistant <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant",
|
||||
"[/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
|
||||
"</s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
|
||||
"[/INST] Hi there </s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
|
||||
};
|
||||
std::vector<char> formatted_chat(1024);
|
||||
int32_t res;
|
||||
|
||||
// test invalid chat template
|
||||
res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation, message_count, true, formatted_chat.data(), formatted_chat.size());
|
||||
assert(res < 0);
|
||||
|
||||
for (size_t i = 0; i < templates.size(); i++) {
|
||||
std::string custom_template = templates[i];
|
||||
std::string substr = expected_substr[i];
|
||||
formatted_chat.resize(1024);
|
||||
res = llama_chat_apply_template(
|
||||
nullptr,
|
||||
custom_template.c_str(),
|
||||
conversation,
|
||||
message_count,
|
||||
true,
|
||||
formatted_chat.data(),
|
||||
formatted_chat.size()
|
||||
);
|
||||
formatted_chat.resize(res);
|
||||
std::string output(formatted_chat.data(), formatted_chat.size());
|
||||
std::cout << output << "\n-------------------------\n";
|
||||
// expect the "formatted_chat" to contain pre-defined strings
|
||||
assert(output.find(substr) != std::string::npos);
|
||||
}
|
||||
return 0;
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue