Merge branch 'master' into compilade/lazy-convert-hf

commit bffdaf4010
Francis Couture-Harpin 2024-05-08 10:56:03 -04:00
43 changed files with 1646 additions and 179 deletions


@ -103,6 +103,8 @@ set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for
set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
    "llama: max. batch size for using peer access")
option(LLAMA_CUDA_NO_PEER_COPY "llama: do not use peer to peer copies" OFF)
option(LLAMA_CUDA_NO_VMM "llama: do not try to use CUDA VMM" OFF)
option(LLAMA_CURL "llama: use libcurl to download model from an URL" OFF)
option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF)
option(LLAMA_HIP_UMA "llama: use HIP unified memory architecture" OFF)
@ -409,6 +411,9 @@ if (LLAMA_CUDA)
if (LLAMA_CUDA_FORCE_MMQ)
add_compile_definitions(GGML_CUDA_FORCE_MMQ)
endif()
if (LLAMA_CUDA_NO_VMM)
add_compile_definitions(GGML_CUDA_NO_VMM)
endif()
add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y})
if (DEFINED LLAMA_CUDA_DMMV_Y)
@ -434,7 +439,11 @@ if (LLAMA_CUDA)
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
endif()
if (LLAMA_CUDA_NO_VMM)
# No VMM requested, no need to link directly with the cuda driver lib (libcuda.so)
else()
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cuda_driver) # required by cuDeviceGetAttribute(), cuMemGetAllocationGranularity(...), ...
endif()
if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
# 52 == lowest CUDA 12 standard


@ -20,7 +20,8 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)
### Hot topics

- **Initial Flash-Attention support: https://github.com/ggerganov/llama.cpp/pull/5021**
- BPE pre-tokenization support has been added: https://github.com/ggerganov/llama.cpp/pull/6920
- MoE memory layout has been updated - reconvert models for `mmap` support and regenerate `imatrix` https://github.com/ggerganov/llama.cpp/pull/6387
- Model sharding instructions using `gguf-split` https://github.com/ggerganov/llama.cpp/discussions/6404
- Fix major bug in Metal batched inference https://github.com/ggerganov/llama.cpp/pull/6225
@ -935,17 +936,25 @@ If your issue is with model generation quality, then please at least scan the fo
### Android
#### Build on Android using Termux
[Termux](https://github.com/termux/termux-app#installation) is a method to execute `llama.cpp` on an Android device (no root required).
```
apt update && apt upgrade -y
apt install git make cmake
```
It's recommended to move your model inside the `~/` directory for best performance:
```
cd storage/downloads
mv model.gguf ~/
```
[Get the code](https://github.com/ggerganov/llama.cpp#get-the-code) & [follow the Linux build instructions](https://github.com/ggerganov/llama.cpp#build) to build `llama.cpp`.
#### Building the Project using Android NDK

Obtain the [Android NDK](https://developer.android.com/ndk) and then build with CMake.
First, install the essential packages for termux:
```
pkg install clang wget git cmake
```
Second, obtain the [Android NDK](https://developer.android.com/ndk) and then build with CMake:
Execute the following commands on your computer to avoid downloading the NDK to your mobile. Alternatively, you can also do this in Termux:
```
$ mkdir build-android
$ cd build-android
@ -953,7 +962,9 @@ $ export NDK=<your_ndk_directory>
$ cmake -DCMAKE_TOOLCHAIN_FILE=$NDK/build/cmake/android.toolchain.cmake -DANDROID_ABI=arm64-v8a -DANDROID_PLATFORM=android-23 -DCMAKE_C_FLAGS=-march=armv8.4a+dotprod ..
$ make
```
Install [termux](https://github.com/termux/termux-app#installation) on your device and run `termux-setup-storage` to get access to your SD card (if Android 11+ then run the command twice).

Finally, copy these built `llama` binaries and the model file to your device storage. Because the file permissions in the Android sdcard cannot be changed, you can copy the executable files to the `/data/data/com.termux/files/home/bin` path, and then execute the following commands in Termux to add executable permission:

(Assuming you have pushed the built executable files to the /sdcard/llama.cpp/bin path using `adb push`)
@ -975,25 +986,10 @@ $cd /data/data/com.termux/files/home/bin
$./main -m ../model/llama-2-7b-chat.Q4_K_M.gguf -n 128 -cml
```

Here's a demo of an interactive session running on a Pixel 5 phone:

https://user-images.githubusercontent.com/271616/225014776-1d567049-ad71-4ef2-b050-55b0b3b9274c.mp4
#### Build on Android using Termux
[Termux](https://github.com/termux/termux-app#installation) is an alternative to execute `llama.cpp` on an Android device (no root required).
```
apt update && apt upgrade -y
apt install git
```
It's recommended to move your model inside the `~/` directory for best performance:
```
cd storage/downloads
mv model.gguf ~/
```
[Follow the Linux build instructions](https://github.com/ggerganov/llama.cpp#build) to build `llama.cpp`.
### Docker

#### Prerequisites


@ -160,9 +160,8 @@ function gg_run_test_scripts_debug {
set -e

(cd ./examples/gguf-split && time bash tests.sh "$SRC/build-ci-debug/bin" "$MNT/models") 2>&1 | tee -a $OUT/${ci}-scripts.log
(cd ./examples/quantize && time bash tests.sh "$SRC/build-ci-debug/bin" "$MNT/models") 2>&1 | tee -a $OUT/${ci}-scripts.log
#(cd ./examples/quantize && time bash tests.sh "$SRC/build-ci-debug/bin" "$MNT/models") 2>&1 | tee -a $OUT/${ci}-scripts.log
set +e
}
@ -695,8 +694,10 @@ test $ret -eq 0 && gg_run ctest_release
if [ -z ${GG_BUILD_LOW_PERF} ]; then
    test $ret -eq 0 && gg_run embd_bge_small

    if [ -z ${GG_BUILD_CLOUD} ] || [ ${GG_BUILD_EXTRA_TESTS_0} ]; then
        test $ret -eq 0 && gg_run test_scripts_debug
        test $ret -eq 0 && gg_run test_scripts_release
    fi

    if [ -z ${GG_BUILD_VRAM_GB} ] || [ ${GG_BUILD_VRAM_GB} -ge 8 ]; then
        if [ -z ${GG_BUILD_CUDA} ]; then


@ -911,6 +911,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.instruct = true;
return true;
}
if (arg == "-cnv" || arg == "--conversation") {
params.conversation = true;
return true;
}
if (arg == "-cml" || arg == "--chatml") {
params.chatml = true;
return true;
@ -1417,6 +1421,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --version show version and build info\n"); printf(" --version show version and build info\n");
printf(" -i, --interactive run in interactive mode\n"); printf(" -i, --interactive run in interactive mode\n");
printf(" --interactive-first run in interactive mode and wait for input right away\n"); printf(" --interactive-first run in interactive mode and wait for input right away\n");
printf(" -cnv, --conversation run in conversation mode (does not print special tokens and suffix/prefix)\n");
printf(" -ins, --instruct run in instruction mode (use with Alpaca models)\n"); printf(" -ins, --instruct run in instruction mode (use with Alpaca models)\n");
printf(" -cml, --chatml run in chatml mode (use with ChatML-compatible models)\n"); printf(" -cml, --chatml run in chatml mode (use with ChatML-compatible models)\n");
printf(" --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n"); printf(" --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n");


@ -140,6 +140,7 @@ struct gpt_params {
bool random_prompt = false; // do not randomize prompt if none provided
bool use_color = false; // use color to distinguish generations and inputs
bool interactive = false; // interactive mode
bool conversation = false; // conversation mode (does not print special tokens and suffix/prefix)
bool chatml = false; // chatml mode (used for models trained on chatml syntax)
bool prompt_cache_all = false; // save user input and generations to prompt cache
bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it


@ -35,6 +35,8 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
result->prev.resize(params.n_prev);

result->n_considered = 0;

llama_sampling_set_rng_seed(result, params.seed);

return result;
@ -64,6 +66,7 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
ctx->cur.clear();
ctx->n_considered = 0;
}

void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
@ -253,6 +256,8 @@ static llama_token llama_sampling_sample_impl(
}
}

ctx_sampling->n_considered = cur_p.size;

return id;
}


@ -81,6 +81,7 @@ struct llama_sampling_context {
// TODO: replace with ring-buffer
std::vector<llama_token> prev;
std::vector<llama_token_data> cur;
size_t n_considered;

std::mt19937 rng;
};


@ -67,6 +67,9 @@ models = [
{"name": "gpt-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/openai-community/gpt2", }, {"name": "gpt-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/openai-community/gpt2", },
{"name": "refact", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/smallcloudai/Refact-1_6-base", }, {"name": "refact", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/smallcloudai/Refact-1_6-base", },
{"name": "command-r", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/CohereForAI/c4ai-command-r-v01", }, {"name": "command-r", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/CohereForAI/c4ai-command-r-v01", },
{"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", },
{"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", },
{"name": "dbrx", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", },
] ]
# make directory "models/tokenizers" if it doesn't exist # make directory "models/tokenizers" if it doesn't exist
@ -150,6 +153,8 @@ for model in models:
# print the "pre_tokenizer" content from the tokenizer.json # print the "pre_tokenizer" content from the tokenizer.json
with open(f"models/tokenizers/{name}/tokenizer.json", "r", encoding="utf-8") as f: with open(f"models/tokenizers/{name}/tokenizer.json", "r", encoding="utf-8") as f:
cfg = json.load(f) cfg = json.load(f)
normalizer = cfg["normalizer"]
logger.info("normalizer: " + json.dumps(normalizer, indent=4))
pre_tokenizer = cfg["pre_tokenizer"] pre_tokenizer = cfg["pre_tokenizer"]
logger.info("pre_tokenizer: " + json.dumps(pre_tokenizer, indent=4)) logger.info("pre_tokenizer: " + json.dumps(pre_tokenizer, indent=4))


@ -397,6 +397,15 @@ class Model:
if chkhsh == "9c2227e4dd922002fb81bde4fc02b0483ca4f12911410dee2255e4987644e3f8": if chkhsh == "9c2227e4dd922002fb81bde4fc02b0483ca4f12911410dee2255e4987644e3f8":
# ref: https://huggingface.co/CohereForAI/c4ai-command-r-v01 # ref: https://huggingface.co/CohereForAI/c4ai-command-r-v01
res = "command-r" res = "command-r"
if chkhsh == "e636dc30a262dcc0d8c323492e32ae2b70728f4df7dfe9737d9f920a282b8aea":
# ref: https://huggingface.co/Qwen/Qwen1.5-7B
res = "qwen2"
if chkhsh == "b6dc8df998e1cfbdc4eac8243701a65afe638679230920b50d6f17d81c098166":
# ref: https://huggingface.co/allenai/OLMo-1.7-7B-hf
res = "olmo"
if chkhsh == "a8594e3edff7c29c003940395316294b2c623e09894deebbc65f33f1515df79e":
# ref: https://huggingface.co/databricks/dbrx-instruct
res = "dbrx"
if res is None: if res is None:
logger.warning("\n") logger.warning("\n")
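The `chkhsh` values compared above are, roughly, the SHA-256 digest of the token ids that the model's Hugging Face tokenizer produces for a fixed probe string (the update script shown earlier also logs the `normalizer` and `pre_tokenizer` it found). A sketch of that idea, with a short placeholder probe string instead of the real multilingual one:

```python
from hashlib import sha256
from transformers import AutoTokenizer

chktxt = "Hello world 123"  # placeholder; the real script uses a much longer probe string
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-7B")
chkhsh = sha256(str(tokenizer.encode(chktxt)).encode()).hexdigest()
print(chkhsh)  # only matches the table above when the real probe string is used
```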
@ -2248,8 +2257,9 @@ class OlmoModel(Model):
def set_gguf_parameters(self):
    super().set_gguf_parameters()
    self.gguf_writer.add_layer_norm_eps(1e-5)
    clip_qkv = self.hparams.get("clip_qkv")
    if clip_qkv is not None:
        self.gguf_writer.add_clamp_kqv(clip_qkv)

# Same as super class, but permuting q_proj, k_proj
# Copied from: LlamaModel


@ -1512,25 +1512,27 @@ def main(args_in: list[str] | None = None) -> None:
if args.big_endian:
    endianess = gguf.GGUFEndian.BIG

params = None
if args.pad_vocab or not args.vocab_only:
    params = Params.load(model_plus)
    if params.n_ctx == -1:
        if args.ctx is None:
            msg = """\
                The model doesn't have a context size, and you didn't specify one with --ctx
                Please specify one with --ctx:
                 - LLaMA v1: --ctx 2048
                 - LLaMA v2: --ctx 4096"""
            parser.error(textwrap.dedent(msg))
        params.n_ctx = args.ctx

    if args.outtype:
        params.ftype = {
            "f32": GGMLFileType.AllF32,
            "f16": GGMLFileType.MostlyF16,
            "q8_0": GGMLFileType.MostlyQ8_0,
        }[args.outtype]

    logger.info(f"params = {params}")

model_parent_path = model_plus.paths[0].parent
vocab_path = Path(args.vocab_dir or args.model or model_parent_path)
@ -1543,6 +1545,17 @@ def main(args_in: list[str] | None = None) -> None:
if not args.outfile:
    raise ValueError("need --outfile if using --vocab-only")
outfile = args.outfile

if params is None:
    params = Params(
        n_vocab    = vocab.vocab_size,
        n_embd     = 1,
        n_layer    = 1,
        n_ctx      = 1,
        n_ff       = 1,
        n_head     = 1,
        n_head_kv  = 1,
        f_norm_eps = 1e-5,
    )

OutputFile.write_vocab_only(outfile, params, vocab, special_vocab,
                            endianess=endianess, pad_vocab=args.pad_vocab)
logger.info(f"Wrote {outfile}")


@ -23,7 +23,7 @@ Install BLIS:
sudo make install
```

We recommend using openmp since it's easier to modify the cores being used.

### llama.cpp compilation


@ -96,9 +96,9 @@ NOTE: The dimensions in `ggml` are typically in the reverse order of the `pytorc
This is the funniest part: you have to provide the inference graph implementation of the new model architecture in `llama_build_graph`.

Have a look at existing implementations like `build_llama`, `build_dbrx` or `build_bert`.

When implementing a new graph, please note that the underlying `ggml` backends might not support them all; support for missing backend operations can be added in another PR.

Note: to debug the inference graph, you can use [eval-callback](../examples/eval-callback).
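For readers new to `ggml`, here is a minimal, generic sketch of building and running a compute graph with the public C API. It is not llama.cpp's internal graph-builder code, just the general pattern those `build_*` functions follow (transform tensors, expand the result into a graph, compute):

```c
#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,  // room for tensor data and graph metadata
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // y = W*x + b with W: 4x3, x: 4, b: 3 (ggml_mul_mat contracts over ne[0])
    struct ggml_tensor * W = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 3);
    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3);
    ggml_set_f32(W, 1.0f); ggml_set_f32(x, 2.0f); ggml_set_f32(b, 0.5f);

    struct ggml_tensor * y = ggml_add(ctx, ggml_mul_mat(ctx, W, x), b);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, y);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    printf("y[0] = %f\n", ggml_get_f32_1d(y, 0));  // 1*2*4 + 0.5 = 8.5
    ggml_free(ctx);
    return 0;
}
```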


@ -575,7 +575,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
GGML_ASSERT(tokens_input->type == GGML_TYPE_I32);

auto add_to_f32 = [] (struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) {
if (ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16 || a->type == GGML_TYPE_BF16) {
return ggml_add_cast(ctx, a, b, GGML_TYPE_F32);
} else if (a->type == GGML_TYPE_F32) {
return ggml_add(ctx, a, b);


@ -19,6 +19,7 @@
struct Stats {
std::vector<float> values;
std::vector<int> counts;
int ncall = 0;
};
@ -121,12 +122,10 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
auto & e = m_stats[wname];
++e.ncall;
// NOTE: since we select top-k experts, the number of calls for the expert tensors will be k times larger
// using the following line, we can correct for that if needed by replacing the line above with:
//if (idx == t->src[0]->ne[0] - 1) ++e.ncall;
if (e.values.empty()) {
e.values.resize(src1->ne[0]*n_as, 0);
e.counts.resize(src1->ne[0]*n_as, 0);
}
else if (e.values.size() != (size_t)src1->ne[0]*n_as) {
fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]*n_as);
@ -153,6 +152,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
for (int j = 0; j < (int)src1->ne[0]; ++j) {
e.values[e_start + j] += x[j]*x[j];
e.counts[e_start + j]++;
}
}
}
@ -170,6 +170,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
auto& e = m_stats[wname];
if (e.values.empty()) {
e.values.resize(src1->ne[0], 0);
e.counts.resize(src1->ne[0], 0);
}
else if (e.values.size() != (size_t)src1->ne[0]) {
fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]);
@ -183,6 +184,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
const float * x = data + row * src1->ne[0];
for (int j = 0; j < (int)src1->ne[0]; ++j) {
e.values[j] += x[j]*x[j];
e.counts[j]++;
}
}
if (e.ncall > m_last_call) {
@ -222,7 +224,13 @@ void IMatrixCollector::save_imatrix(const char * fname, const char * dataset) co
out.write((const char *) &p.second.ncall, sizeof(p.second.ncall));
int nval = p.second.values.size();
out.write((const char *) &nval, sizeof(nval));
if (nval > 0) {
std::vector<float> tmp(nval);
for (int i = 0; i < nval; i++) {
tmp[i] = (p.second.values[i] / static_cast<float>(p.second.counts[i])) * static_cast<float>(p.second.ncall);
}
out.write((const char*)tmp.data(), nval*sizeof(float));
}
}

// Write the number of calls the matrix was computed with
@ -270,14 +278,28 @@ bool IMatrixCollector::load_imatrix(const char * imatrix_file, std::unordered_ma
imatrix_data = {};
return false;
}

// When re-called from load_imatrix() with add set, this will already be created.
if (e.values.empty()) {
    e.values.resize(nval, 0);
    e.counts.resize(nval, 0);
}

std::vector<float> tmp(nval);
in.read((char*)tmp.data(), nval*sizeof(float));
if (in.fail()) {
    printf("%s: failed reading data for entry %d\n",__func__,i);
    imatrix_data = {};
    return false;
}

// Recreate the state as expected by save_imatrix(), and correct for weighted sum.
for (int i = 0; i < nval; i++) {
    e.values[i] += tmp[i];
    e.counts[i] += ncall;
}
e.ncall += ncall;
}

return true;
}
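To make the purpose of the new `counts` vector concrete: `save_imatrix()` now writes the per-element mean of x*x scaled by `ncall`, and `load_imatrix()` folds that back into the running sums, so merging several imatrix files preserves the averages even when some tensors (e.g. MoE experts) were hit an uneven number of times. A toy, self-contained sketch of that round trip (not the project's code):

```cpp
#include <cassert>
#include <vector>

int main() {
    std::vector<float> values = {8.0f, 2.0f};  // running sums of x*x
    std::vector<int>   counts = {4, 4};        // how many activations contributed
    int ncall = 4;

    // what save_imatrix() stores: mean(x*x) scaled by ncall
    std::vector<float> stored(values.size());
    for (size_t i = 0; i < values.size(); i++) stored[i] = values[i] / counts[i] * ncall;

    // what load_imatrix() reconstructs in a fresh Stats entry
    std::vector<float> values2(values.size(), 0.0f);
    std::vector<int>   counts2(values.size(), 0);
    for (size_t i = 0; i < stored.size(); i++) { values2[i] += stored[i]; counts2[i] += ncall; }

    // the recovered per-element mean matches the original one
    for (size_t i = 0; i < values.size(); i++) assert(values2[i] / counts2[i] == values[i] / counts[i]);
    return 0;
}
```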


@ -56,7 +56,7 @@ python ./examples/llava/convert-image-encoder-to-gguf.py -m ../clip-vit-large-pa
python ./convert.py ../llava-v1.5-7b --skip-unknown
```

Now both the LLaMA part and the image encoder are in the `llava-v1.5-7b` directory.

## LLaVA 1.6 gguf conversion

1) First clone a LLaVA 1.6 model:


@ -143,7 +143,7 @@ The `--ctx-size` option allows you to set the size of the prompt context used by
### Extended Context Size

Some fine-tuned models have extended the context length by scaling RoPE. For example, if the original pre-trained model has a context length (max sequence length) of 4096 (4k) and the fine-tuned model has 32k, that is a scaling factor of 8, and it should work by setting the above `--ctx-size` to 32768 (32k) and `--rope-scale` to 8.

- `--rope-scale N`: Where N is the linear scaling factor used by the fine-tuned model.
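For example, a model fine-tuned from a 4k base to a 32k context could be run roughly like this (model path and prompt are placeholders):

```
./main -m models/llama-2-7b-32k.Q4_K_M.gguf --ctx-size 32768 --rope-scale 8 -p "Once upon a time"
```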
@ -286,7 +286,7 @@ These options help improve the performance and memory usage of the LLaMA models.
- `--numa distribute`: Pin an equal proportion of the threads to the cores on each NUMA node. This will spread the load amongst all cores on the system, utilizing all memory channels at the expense of potentially requiring memory to travel over the slow links between nodes.
- `--numa isolate`: Pin all threads to the NUMA node that the program starts on. This limits the number of cores and amount of memory that can be used, but guarantees all memory access remains local to the NUMA node.
- `--numa numactl`: Pin threads to the CPUMAP that is passed to the program by starting it with the numactl utility. This is the most flexible mode, and allows arbitrary core usage patterns, for example a map that uses all the cores on one NUMA node, and just enough cores on a second node to saturate the inter-node memory bus.

These flags attempt optimizations that help on some systems with non-uniform memory access. This currently consists of one of the above strategies, and disabling prefetch and readahead for mmap. The latter causes mapped pages to be faulted in on first access instead of all at once, and in combination with pinning threads to NUMA nodes, more of the pages end up on the NUMA node where they are used. Note that if the model is already in the system page cache, for example because of a previous run without this option, this will have little effect unless you drop the page cache first. This can be done by rebooting the system or on Linux by writing '3' to '/proc/sys/vm/drop_caches' as root.


@ -362,6 +362,9 @@ int main(int argc, char ** argv) {
params.interactive_first = true;
params.antiprompt.emplace_back("<|im_start|>user\n");
}
else if (params.conversation) {
params.interactive_first = true;
}

// enable interactive mode if interactive start is specified
if (params.interactive_first) {
@ -733,7 +736,7 @@ int main(int argc, char ** argv) {
// display text
if (input_echo && display) {
for (auto id : embd) {
const std::string token_str = llama_token_to_piece(ctx, id, !params.conversation);
printf("%s", token_str.c_str());

if (embd.size() > 1) {
@ -796,7 +799,7 @@ int main(int argc, char ** argv) {
// deal with end of generation tokens in interactive mode
if (llama_token_is_eog(model, llama_sampling_last(ctx_sampling))) {
LOG("found an EOG token\n");

if (params.interactive) {
if (!params.antiprompt.empty()) {
@ -816,7 +819,7 @@ int main(int argc, char ** argv) {
if (n_past > 0 && is_interacting) {
LOG("waiting for user input\n");

if (params.conversation || params.instruct || params.chatml) {
printf("\n> ");
}
@ -826,7 +829,7 @@ int main(int argc, char ** argv) {
}

std::string buffer;
if (!params.input_prefix.empty() && !params.conversation) {
LOG("appending input prefix: '%s'\n", params.input_prefix.c_str());
printf("%s", params.input_prefix.c_str());
}
@ -850,7 +853,7 @@ int main(int argc, char ** argv) {
// Entering an empty line lets the user pass control back
if (buffer.length() > 1) {
// append input suffix if any
if (!params.input_suffix.empty() && !params.conversation) {
LOG("appending input suffix: '%s'\n", params.input_suffix.c_str());
printf("%s", params.input_suffix.c_str());
}


@ -46,7 +46,8 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "Q5_K_M", LLAMA_FTYPE_MOSTLY_Q5_K_M, " 4.45G, +0.0122 ppl @ LLaMA-v1-7B", }, { "Q5_K_M", LLAMA_FTYPE_MOSTLY_Q5_K_M, " 4.45G, +0.0122 ppl @ LLaMA-v1-7B", },
{ "Q6_K", LLAMA_FTYPE_MOSTLY_Q6_K, " 5.15G, +0.0008 ppl @ LLaMA-v1-7B", }, { "Q6_K", LLAMA_FTYPE_MOSTLY_Q6_K, " 5.15G, +0.0008 ppl @ LLaMA-v1-7B", },
{ "Q8_0", LLAMA_FTYPE_MOSTLY_Q8_0, " 6.70G, +0.0004 ppl @ LLaMA-v1-7B", }, { "Q8_0", LLAMA_FTYPE_MOSTLY_Q8_0, " 6.70G, +0.0004 ppl @ LLaMA-v1-7B", },
{ "F16", LLAMA_FTYPE_MOSTLY_F16, "13.00G @ 7B", }, { "F16", LLAMA_FTYPE_MOSTLY_F16, "14.00G, -0.0020 ppl @ Mistral-7B", },
{ "BF16", LLAMA_FTYPE_MOSTLY_BF16, "14.00G, -0.0050 ppl @ Mistral-7B", },
{ "F32", LLAMA_FTYPE_ALL_F32, "26.00G @ 7B", }, { "F32", LLAMA_FTYPE_ALL_F32, "26.00G @ 7B", },
// Note: Ensure COPY comes after F32 to avoid ftype 0 from matching. // Note: Ensure COPY comes after F32 to avoid ftype 0 from matching.
{ "COPY", LLAMA_FTYPE_ALL_F32, "only copy tensors, no quantizing", }, { "COPY", LLAMA_FTYPE_ALL_F32, "only copy tensors, no quantizing", },


@ -62,6 +62,18 @@ page cache before using this. See https://github.com/ggerganov/llama.cpp/issues/
- `--chat-template JINJA_TEMPLATE`: Set custom jinja chat template. This parameter accepts a string, not a file name. Default: template taken from model's metadata. We only support [some pre-defined templates](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template)
- `--log-disable`: Output logs to stdout only, not to `llama.log`. Default: enabled
- `--log-format FORMAT`: Define the log output to FORMAT: json or text Default: `json`
- `--rope-scaling` : RoPE scaling method. Defaults to linear unless otherwise specified by the model. Options are `none`, `linear`, `yarn`
- `--rope-freq-base N` : RoPE frequency base (default: loaded from model)
- `--rope-freq-scale N`: RoPE frequency scaling factor, expands context by a factor of 1/N (e.g. 0.25)
- `--yarn-ext-factor N` : YaRN: extrapolation mix factor (Default: 1.0, 0.0 = full interpolation)
- `--yarn-attn-factor N` : YaRN: scale sqrt(t) or attention magnitude (default: 1.0)
- `--yarn-beta-slow N`: YaRN: High correction dim or alpha (default: 1.0)
- `--yarn-beta-fast N`: YaRN: low correction dim or beta (default: 32.0)
- `--pooling` : Pooling type for embeddings, use model default if unspecified. Options are `none`, `mean`, `cls`
- `-dt N`, `--defrag-thold N`: KV cache defragmentation threshold (default: -1.0, < 0 = disabled)
- `-fa`, `--flash-attn` : enable flash attention (default: disabled).
- `-ctk TYPE`, `--cache-type-k TYPE` : KV cache data type for K (default: `f16`, options `f32`, `f16`, `q8_0`, `q4_0`, `q4_1`, `iq4_nl`, `q5_0`, or `q5_1`)
- `-ctv TYPE`, `--cache-type-v TYPE` : KV cache type for V (default `f16`, see `-ctk` for options)
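As an illustrative combination of the flags above (the model path is a placeholder, and whether a given cache type or flash attention is usable depends on the backend and model):

```
./server -m models/7B/ggml-model-q4_0.gguf -c 4096 -fa -ctk q8_0 --rope-scaling yarn --rope-freq-scale 0.5
```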
**If compiled with `LLAMA_SERVER_SSL=ON`**

- `--ssl-key-file FNAME`: path to a PEM-encoded SSL private key file
@ -260,7 +272,7 @@ node index.js
`logit_bias`: Modify the likelihood of a token appearing in the generated text completion. For example, use `"logit_bias": [[15043,1.0]]` to increase the likelihood of the token 'Hello', or `"logit_bias": [[15043,-1.0]]` to decrease its likelihood. Setting the value to false, `"logit_bias": [[15043,false]]` ensures that the token `Hello` is never produced. The tokens can also be represented as strings, e.g. `[["Hello, World!",-0.5]]` will reduce the likelihood of all the individual tokens that represent the string `Hello, World!`, just like the `presence_penalty` does. Default: `[]`

`n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token given the sampling settings. Note that for temperature < 0 the tokens are sampled greedily but token probabilities are still being calculated via a simple softmax of the logits without considering any other sampler settings. Default: `0`
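For example, to request the top 3 token probabilities alongside each generated token:

```
curl --request POST --url http://localhost:8080/completion \
     --data '{"prompt": "Building a website can be done in 10 simple steps:", "n_predict": 16, "n_probs": 3}'
```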
`min_keep`: If greater than 0, force samplers to return N possible tokens at minimum. Default: `0`
@ -319,7 +331,7 @@ Notice that each `probs` is an array of length `n_probs`.
`content`: Set the text to tokenize.

`add_special`: Boolean indicating if special tokens, i.e. `BOS`, should be inserted. Default: `false`
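An illustrative request pair (token ids depend on the model; with `add_special` the model's `BOS` id is prepended):

```
curl --request POST --url http://localhost:8080/tokenize \
     --data '{"content": "What is the capital of France?", "add_special": true}'

# the ids below are placeholders, use the ones returned by /tokenize
curl --request POST --url http://localhost:8080/detokenize \
     --data '{"tokens": [1, 2, 3]}'
```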
- **POST** `/detokenize`: Convert tokens to text.


@ -2266,17 +2266,31 @@ struct server_context {
llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };

result.tok = id;

const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs);
if (n_probs > 0) {
    const size_t n_considered = slot.ctx_sampling->n_considered;

    // Make sure at least n_probs top tokens are at the front of the vector:
    if (slot.sparams.temp == 0.0f && n_probs > n_considered) {
        llama_sample_top_k(ctx, &cur_p, n_probs, 0);
    }

    if (slot.sparams.temp == 0.0f) {
        // With greedy sampling the probabilities have possibly not been calculated.
        for (size_t i = 0; i < n_probs; ++i) {
            result.probs.push_back({
                cur_p.data[i].id,
                i == 0 ? 1.0f : 0.0f
            });
        }
    } else {
        for (size_t i = 0; i < n_probs; ++i) {
            result.probs.push_back({
                cur_p.data[i].id,
                i >= n_considered ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
            });
        }
    }
}

if (!process_token(result, slot)) {
@ -3633,7 +3647,8 @@ int main(int argc, char ** argv) {
std::vector<llama_token> tokens;
if (body.count("content") != 0) {
const bool add_special = json_value(body, "add_special", false);
tokens = ctx_server.tokenize(body["content"], add_special);
}
const json data = format_tokenizer_response(tokens);
return res.set_content(data.dump(), "application/json; charset=utf-8");


@ -7,6 +7,7 @@ Feature: llama.cpp server
And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
And a model file test-model.gguf
And a model alias tinyllama-2
And BOS token is 1
And 42 as server seed
# KV Cache corresponds to the total amount of tokens
# that can be stored across all independent sequences: #4130
@ -91,7 +92,18 @@ Feature: llama.cpp server
""" """
What is the capital of France ? What is the capital of France ?
""" """
Then tokens can be detokenize Then tokens can be detokenized
And tokens do not begin with BOS
Scenario: Tokenize w/ BOS
Given adding special tokens
When tokenizing:
"""
What is the capital of Germany?
"""
Then tokens begin with BOS
Given first token is removed
Then tokens can be detokenized
Scenario: Models available Scenario: Models available
Given available models Given available models


@ -376,6 +376,11 @@ def step_seed(context, seed):
    context.seed.append(seed)


@step('BOS token is {bos:d}')
def step_bos_token(context, bos):
    context.bos = bos


@step('a prefix prompt')
def step_prompt_prefix(context):
    context.prompt_prefix = context_text(context)
@ -656,21 +661,29 @@ async def all_embeddings_are_generated(context):
    assert_embeddings(context.tasks_result.pop().pop())


@step('adding special tokens')
def step_tokenize_set_add_special(context):
    context.tokenize_add_special = True


@step('tokenizing')
@async_run_until_complete
async def step_tokenize(context):
    context.tokenized_text = context_text(context)
    async with aiohttp.ClientSession() as session:
        tokenize_args = {
            "content": context.tokenized_text,
        }
        if getattr(context, 'tokenize_add_special', None) is not None:
            tokenize_args['add_special'] = context.tokenize_add_special
        async with session.post(f'{context.base_url}/tokenize',
                                json=tokenize_args) as response:
            assert response.status == 200
            tokenize_json = await response.json()
            context.tokens = tokenize_json['tokens']


@step('tokens can be detokenized')
@async_run_until_complete
async def step_detokenize(context):
    assert len(context.tokens) > 0
@ -685,6 +698,21 @@ async def step_detokenize(context):
    assert context.tokenized_text == detokenize_json['content'].strip()


@step('tokens begin with BOS')
def step_strings_for_tokenization(context):
    assert context.tokens[0] == context.bos


@step('tokens do not begin with BOS')
def step_strings_for_tokenization(context):
    assert context.tokens[0] != context.bos


@step('first token is removed')
def step_strings_for_tokenization(context):
    context.tokens = context.tokens[1:]


@step('an OPTIONS request is sent from {origin}')
@async_run_until_complete
async def step_options_request(context, origin):


@ -49,18 +49,18 @@ extern bool server_log_json;
#define LOG_WARNING(MSG, ...) server_log("WARN", __func__, __LINE__, MSG, __VA_ARGS__)
#define LOG_INFO(   MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__)

static inline void server_log(const char * level, const char * function, int line, const char * message, const json & extra);

template <typename T>
static T json_value(const json & body, const std::string & key, const T & default_value) {
    // Fallback null to default value
    if (body.contains(key) && !body.at(key).is_null()) {
        try {
            return body.at(key);
        } catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const &) {
            std::stringstream ss;
            ss << "Wrong type supplied for parameter '" << key << "'. Expected '" << json(default_value).type_name() << "', using default value.";
            LOG_WARNING(ss.str().c_str(), body);
            return default_value;
        }
    } else {
@ -68,16 +68,16 @@ static T json_value(const json &body, const std::string &key, const T &default_v
}
}

static inline void server_log(const char * level, const char * function, int line, const char * message, const json & extra) {
std::stringstream ss_tid;
ss_tid << std::this_thread::get_id();
json log = json{
{"tid", ss_tid.str()},
{"timestamp", time(nullptr)},
};

if (server_log_json) {
log.merge_patch({
{"level", level},
{"function", function},
{"line", line},
@ -98,7 +98,7 @@ static inline void server_log(const char *level, const char *function, int line,
}

std::stringstream ss;
ss << buf << " |";
for (const auto & el : log.items())
{
const std::string value = el.value().dump(-1, ' ', false, json::error_handler_t::replace);
ss << " " << el.key() << "=" << value;


@ -1,6 +1,6 @@
# llama.cpp/example/sycl

This example program provides the tools for llama.cpp for SYCL on Intel GPU.

## Tool

flake.lock (generated, 30 changes)

@ -5,11 +5,11 @@
"nixpkgs-lib": "nixpkgs-lib" "nixpkgs-lib": "nixpkgs-lib"
}, },
"locked": { "locked": {
"lastModified": 1712014858, "lastModified": 1714641030,
"narHash": "sha256-sB4SWl2lX95bExY2gMFG5HIzvva5AVMJd4Igm+GpZNw=", "narHash": "sha256-yzcRNDoyVP7+SCNX0wmuDju1NUCt8Dz9+lyUXEI0dbI=",
"owner": "hercules-ci", "owner": "hercules-ci",
"repo": "flake-parts", "repo": "flake-parts",
"rev": "9126214d0a59633752a136528f5f3b9aa8565b7d", "rev": "e5d10a24b66c3ea8f150e47dfdb0416ab7c3390e",
"type": "github" "type": "github"
}, },
"original": { "original": {
@ -20,11 +20,11 @@
},
"nixpkgs": {
"locked": {
"lastModified": 1714635257,
"narHash": "sha256-4cPymbty65RvF1DWQfc+Bc8B233A1BWxJnNULJKQ1EY=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "63c3a29ca82437c87573e4c6919b09a24ea61b0f",
"type": "github"
},
"original": {
@ -36,20 +36,14 @@
},
"nixpkgs-lib": {
"locked": {
"lastModified": 1714640452,
"narHash": "sha256-QBx10+k6JWz6u7VsohfSw8g8hjdBZEf8CFzXH1/1Z94=",
"type": "tarball",
"url": "https://github.com/NixOS/nixpkgs/archive/50eb7ecf4cd0a5756d7275c8ba36790e5bd53e33.tar.gz"
},
"original": {
"type": "tarball",
"url": "https://github.com/NixOS/nixpkgs/archive/50eb7ecf4cd0a5756d7275c8ba36790e5bd53e33.tar.gz"
}
},
"root": {


@ -113,7 +113,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
for (int id = 0; id < info.device_count; ++id) {
int device_vmm = 0;

#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
CUdevice device;
CU_CHECK(cuDeviceGet(&device, id));
CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));
@ -259,7 +259,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
};

// pool with virtual memory
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
@ -356,7 +356,7 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
#endif // !defined(GGML_USE_HIPBLAS)

std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
if (ggml_cuda_info().devices[device].vmm) {
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
}


@ -17,6 +17,83 @@
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))
/**
 * Converts brain16 to float32.
 *
 * The bfloat16 floating point format has the following structure:
 *
 *   sign (1 bit) | exponent (8 bits) | mantissa (7 bits)
 *   0b0000000000000000 brain16
 *
 * Since bf16 has the same number of exponent bits as a 32bit float,
 * encoding and decoding numbers becomes relatively straightforward.
 *
 *   sign (1 bit) | exponent (8 bits) | mantissa (23 bits)
 *   0b00000000000000000000000000000000 IEEE binary32
 *
 * For comparison, the standard fp16 format has fewer exponent bits.
 *
 *   sign (1 bit) | exponent (5 bits) | mantissa (10 bits)
 *   0b0000000000000000 IEEE binary16
 *
 * @see IEEE 754-2008
 */
static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
union {
float f;
uint32_t i;
} u;
u.i = (uint32_t)h.bits << 16;
return u.f;
}
/**
* Converts float32 to brain16.
*
* This function is binary identical to AMD Zen4 VCVTNEPS2BF16.
* Subnormals shall be flushed to zero, and NANs will be quiet.
* This code should vectorize nicely if using modern compilers.
*/
static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
ggml_bf16_t h;
union {
float f;
uint32_t i;
} u;
u.f = s;
if ((u.i & 0x7fffffff) > 0x7f800000) { /* nan */
h.bits = (u.i >> 16) | 64; /* force to quiet */
return h;
}
if (!(u.i & 0x7f800000)) { /* subnormal */
h.bits = (u.i & 0x80000000) >> 16; /* flush to zero */
return h;
}
h.bits = (u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16;
return h;
}
#define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
#define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
#ifdef __cplusplus
extern "C" {
#endif
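A couple of worked values for the rounding above, restated as a standalone check (this mirrors the union trick in `ggml_compute_fp32_to_bf16()`; it is not part of the header):

```c
#include <stdint.h>
#include <stdio.h>

static uint16_t fp32_to_bf16_bits(float s) {
    union { float f; uint32_t i; } u = { s };
    if ((u.i & 0x7fffffff) > 0x7f800000) return (u.i >> 16) | 64;          /* NaN -> quiet NaN */
    if (!(u.i & 0x7f800000))             return (u.i & 0x80000000) >> 16;  /* subnormal -> signed zero */
    return (u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16;                     /* round to nearest even */
}

int main(void) {
    printf("0x%04x\n", (unsigned) fp32_to_bf16_bits(1.0f));       /* 0x3f80 */
    printf("0x%04x\n", (unsigned) fp32_to_bf16_bits(-2.0f));      /* 0xc000 */
    printf("0x%04x\n", (unsigned) fp32_to_bf16_bits(3.140625f));  /* 0x4049, round-trips exactly */
    return 0;
}
```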


@ -803,7 +803,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_GET_ROWS:
{
return op->src[0]->type != GGML_TYPE_BF16 && op->ne[3] == 1;
}
default:
return false;


@ -2175,7 +2175,7 @@ kernel void kernel_flash_attn_ext_f16(
const short D4 = D/4;
const short D8 = D/8;
//const short Q8 = Q/8;
const short NW = N_SIMDWIDTH;
const short SH = (C + Q); // shared memory per simdgroup in (half)


@ -12450,6 +12450,24 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
const size_t nb = nbytes/ggml_type_size(type);

switch (type) {
case GGML_TYPE_BF16:
{
int nans = 0;
int infs = 0;
const unsigned short * f = (const unsigned short *) data;
for (size_t i = 0; i < nb; ++i) {
nans += (f[i] & 0x7fff) > 0x7f80;
infs += (f[i] & 0x7fff) == 0x7f80;
}
if (nans) {
fprintf(stderr, "%s: found %d NaNs in row of %zu BF16 values\n", __func__, nans, nb);
return false;
}
if (infs) {
fprintf(stderr, "%s: found %d infinities in row of %zu BF16 values\n", __func__, infs, nb);
return false;
}
} break;
case GGML_TYPE_F16:
{
const ggml_fp16_t * f = (const ggml_fp16_t *) data;

ggml.c (1031 changes): file diff suppressed because it is too large

ggml.h (20 changes)

@ -326,14 +326,20 @@ extern "C" {
// get ggml_status name string
GGML_API GGML_CALL const char * ggml_status_to_string(enum ggml_status status);

// ieee 754-2008 half-precision float16
// todo: make this not an integral type
typedef uint16_t ggml_fp16_t;
GGML_API float       ggml_fp16_to_fp32(ggml_fp16_t);
GGML_API ggml_fp16_t ggml_fp32_to_fp16(float);
GGML_API void        ggml_fp16_to_fp32_row(const ggml_fp16_t *, float *, int64_t);
GGML_API void        ggml_fp32_to_fp16_row(const float *, ggml_fp16_t *, int64_t);

// google brain half-precision bfloat16
typedef struct { uint16_t bits; } ggml_bf16_t;
GGML_API ggml_bf16_t ggml_fp32_to_bf16(float);
GGML_API float       ggml_bf16_to_fp32(ggml_bf16_t);  // consider just doing << 16
GGML_API void        ggml_bf16_to_fp32_row(const ggml_bf16_t *, float *, int64_t);
GGML_API void        ggml_fp32_to_bf16_row(const float *, ggml_bf16_t *, int64_t);

struct ggml_object;
struct ggml_context;
@ -370,6 +376,7 @@ extern "C" {
GGML_TYPE_I64 = 27,
GGML_TYPE_F64 = 28,
GGML_TYPE_IQ1_M = 29,
GGML_TYPE_BF16 = 30,
GGML_TYPE_COUNT,
};
@ -410,6 +417,7 @@ extern "C" {
GGML_FTYPE_MOSTLY_IQ2_S = 21, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors
GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors
};

// available tensor operations:


@ -817,6 +817,7 @@ class GGMLQuantizationType(IntEnum):
I64 = 27
F64 = 28
IQ1_M = 29
BF16 = 30


class GGUFEndian(IntEnum):
@ -888,6 +889,7 @@ GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
GGMLQuantizationType.I64: (1, 8), GGMLQuantizationType.I64: (1, 8),
GGMLQuantizationType.F64: (1, 8), GGMLQuantizationType.F64: (1, 8),
GGMLQuantizationType.IQ1_M: (256, QK_K // 8 + QK_K // 16 + QK_K // 32), GGMLQuantizationType.IQ1_M: (256, QK_K // 8 + QK_K // 16 + QK_K // 32),
GGMLQuantizationType.BF16: (1, 2),
} }
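Note: the (block size, type size) pair added for BF16 mirrors the tables on the C side: a row of ne elements occupies ne / block_size * type_size bytes. A small sketch using the ggml helpers that appear elsewhere in this diff (ggml_blck_size, ggml_type_size, ggml_row_size); the 4096-element row count is just an example value:

#include <cstdio>
#include "ggml.h"

int main() {
    const int64_t ne = 4096;   // elements per row (example value)
    const ggml_type types[] = { GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16 };
    for (ggml_type t : types) {
        printf("%s: blck=%d type_size=%zu row of %lld elems = %zu bytes\n",
               ggml_type_name(t), (int) ggml_blck_size(t), ggml_type_size(t),
               (long long) ne, ggml_row_size(t, ne));
    }
    return 0;
}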

View file

@ -51,7 +51,7 @@ single-line ::= [^\n]+ "\n"`
## Sequences and Alternatives ## Sequences and Alternatives
The order of symbols in a sequence matter. For example, in `"1. " move " " move "\n"`, the `"1. "` must come before the first `move`, etc. The order of symbols in a sequence matters. For example, in `"1. " move " " move "\n"`, the `"1. "` must come before the first `move`, etc.
Alternatives, denoted by `|`, give different sequences that are acceptable. For example, in `move ::= pawn | nonpawn | castle`, `move` can be a `pawn` move, a `nonpawn` move, or a `castle`. Alternatives, denoted by `|`, give different sequences that are acceptable. For example, in `move ::= pawn | nonpawn | castle`, `move` can be a `pawn` move, a `nonpawn` move, or a `castle`.

View file

@ -3175,6 +3175,7 @@ struct llama_model_loader {
switch (type_max) { switch (type_max) {
case GGML_TYPE_F32: ftype = LLAMA_FTYPE_ALL_F32; break; case GGML_TYPE_F32: ftype = LLAMA_FTYPE_ALL_F32; break;
case GGML_TYPE_F16: ftype = LLAMA_FTYPE_MOSTLY_F16; break; case GGML_TYPE_F16: ftype = LLAMA_FTYPE_MOSTLY_F16; break;
case GGML_TYPE_BF16: ftype = LLAMA_FTYPE_MOSTLY_BF16; break;
case GGML_TYPE_Q4_0: ftype = LLAMA_FTYPE_MOSTLY_Q4_0; break; case GGML_TYPE_Q4_0: ftype = LLAMA_FTYPE_MOSTLY_Q4_0; break;
case GGML_TYPE_Q4_1: ftype = LLAMA_FTYPE_MOSTLY_Q4_1; break; case GGML_TYPE_Q4_1: ftype = LLAMA_FTYPE_MOSTLY_Q4_1; break;
case GGML_TYPE_Q5_0: ftype = LLAMA_FTYPE_MOSTLY_Q5_0; break; case GGML_TYPE_Q5_0: ftype = LLAMA_FTYPE_MOSTLY_Q5_0; break;
@ -3666,6 +3667,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
switch (ftype) { switch (ftype) {
case LLAMA_FTYPE_ALL_F32: return "all F32"; case LLAMA_FTYPE_ALL_F32: return "all F32";
case LLAMA_FTYPE_MOSTLY_F16: return "F16"; case LLAMA_FTYPE_MOSTLY_F16: return "F16";
case LLAMA_FTYPE_MOSTLY_BF16: return "BF16";
case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0"; case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0";
case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1"; case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1";
case LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16: case LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16:
@ -4389,6 +4391,15 @@ static void llm_load_vocab(
} else if ( } else if (
tokenizer_pre == "command-r") { tokenizer_pre == "command-r") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_COMMAND_R; vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_COMMAND_R;
} else if (
tokenizer_pre == "qwen2") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_QWEN2;
} else if (
tokenizer_pre == "olmo") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_OLMO;
} else if (
tokenizer_pre == "dbrx") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DBRX;
} else { } else {
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
} }
@ -6126,6 +6137,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
|| !( || !(
model.ftype == LLAMA_FTYPE_ALL_F32 || model.ftype == LLAMA_FTYPE_ALL_F32 ||
model.ftype == LLAMA_FTYPE_MOSTLY_F16 || model.ftype == LLAMA_FTYPE_MOSTLY_F16 ||
model.ftype == LLAMA_FTYPE_MOSTLY_BF16 ||
model.ftype == LLAMA_FTYPE_MOSTLY_Q4_0 || model.ftype == LLAMA_FTYPE_MOSTLY_Q4_0 ||
model.ftype == LLAMA_FTYPE_MOSTLY_Q4_1 model.ftype == LLAMA_FTYPE_MOSTLY_Q4_1
) )
@ -12194,6 +12206,7 @@ struct llm_tokenizer_bpe {
case LLAMA_VOCAB_TYPE_BPE: case LLAMA_VOCAB_TYPE_BPE:
switch (vocab.type_pre) { switch (vocab.type_pre) {
case LLAMA_VOCAB_PRE_TYPE_LLAMA3: case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
case LLAMA_VOCAB_PRE_TYPE_DBRX:
word_collection = unicode_regex_split(text, { word_collection = unicode_regex_split(text, {
// original regex from tokenizer.json // original regex from tokenizer.json
//"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", //"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
@ -12248,10 +12261,18 @@ struct llm_tokenizer_bpe {
}); });
break; break;
case LLAMA_VOCAB_PRE_TYPE_GPT2: case LLAMA_VOCAB_PRE_TYPE_GPT2:
case LLAMA_VOCAB_PRE_TYPE_OLMO:
word_collection = unicode_regex_split(text, { word_collection = unicode_regex_split(text, {
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
}); });
break; break;
case LLAMA_VOCAB_PRE_TYPE_QWEN2:
word_collection = unicode_regex_split(text, {
// original regex from tokenizer.json
// "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
});
break;
default: default:
// default regex for BPE tokenization pre-processing // default regex for BPE tokenization pre-processing
word_collection = unicode_regex_split(text, { word_collection = unicode_regex_split(text, {
@ -14154,13 +14175,16 @@ static void llama_tensor_dequantize_internal(
if (qtype.to_float == NULL) { if (qtype.to_float == NULL) {
throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(tensor->type))); throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(tensor->type)));
} }
} else if (tensor->type != GGML_TYPE_F16) { } else if (tensor->type != GGML_TYPE_F16 &&
tensor->type != GGML_TYPE_BF16) {
throw std::runtime_error(format("cannot dequantize/convert tensor type %s", ggml_type_name(tensor->type))); throw std::runtime_error(format("cannot dequantize/convert tensor type %s", ggml_type_name(tensor->type)));
} }
if (nthread < 2) { if (nthread < 2) {
if (tensor->type == GGML_TYPE_F16) { if (tensor->type == GGML_TYPE_F16) {
ggml_fp16_to_fp32_row((ggml_fp16_t *)tensor->data, f32_output, nelements); ggml_fp16_to_fp32_row((ggml_fp16_t *)tensor->data, f32_output, nelements);
} else if (tensor->type == GGML_TYPE_BF16) {
ggml_bf16_to_fp32_row((ggml_bf16_t *)tensor->data, f32_output, nelements);
} else if (ggml_is_quantized(tensor->type)) { } else if (ggml_is_quantized(tensor->type)) {
qtype.to_float(tensor->data, f32_output, nelements); qtype.to_float(tensor->data, f32_output, nelements);
} else { } else {
@ -14169,7 +14193,14 @@ static void llama_tensor_dequantize_internal(
return; return;
} }
size_t block_size = tensor->type == GGML_TYPE_F16 ? 1 : (size_t)ggml_blck_size(tensor->type); size_t block_size;
if (tensor->type == GGML_TYPE_F16 ||
tensor->type == GGML_TYPE_BF16) {
block_size = 1;
} else {
block_size = (size_t)ggml_blck_size(tensor->type);
}
size_t block_size_bytes = ggml_type_size(tensor->type); size_t block_size_bytes = ggml_type_size(tensor->type);
GGML_ASSERT(nelements % block_size == 0); GGML_ASSERT(nelements % block_size == 0);
@ -14188,6 +14219,8 @@ static void llama_tensor_dequantize_internal(
auto compute = [qtype] (ggml_type typ, uint8_t * inbuf, float * outbuf, int nels) { auto compute = [qtype] (ggml_type typ, uint8_t * inbuf, float * outbuf, int nels) {
if (typ == GGML_TYPE_F16) { if (typ == GGML_TYPE_F16) {
ggml_fp16_to_fp32_row((ggml_fp16_t *)inbuf, outbuf, nels); ggml_fp16_to_fp32_row((ggml_fp16_t *)inbuf, outbuf, nels);
} else if (typ == GGML_TYPE_BF16) {
ggml_bf16_to_fp32_row((ggml_bf16_t *)inbuf, outbuf, nels);
} else { } else {
qtype.to_float(inbuf, outbuf, nels); qtype.to_float(inbuf, outbuf, nels);
} }
@ -14548,6 +14581,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
case LLAMA_FTYPE_MOSTLY_Q5_1: default_type = GGML_TYPE_Q5_1; break; case LLAMA_FTYPE_MOSTLY_Q5_1: default_type = GGML_TYPE_Q5_1; break;
case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break; case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break;
case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break; case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break;
case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break;
case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break; case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break;
// K-quants // K-quants
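Note: because BF16 gets a block size of 1 in the dequantization path above, a tensor can be split at any element boundary. A sketch of how per-thread conversion with the new ggml_bf16_to_fp32_row might look (the helper name and the src/dst/nelements/nthread parameters are illustrative, not llama.cpp code):

#include <algorithm>
#include <cstdint>
#include <thread>
#include <vector>
#include "ggml.h"

static void bf16_to_f32_mt(const ggml_bf16_t * src, float * dst,
                           int64_t nelements, int nthread) {
    std::vector<std::thread> workers;
    const int64_t chunk = (nelements + nthread - 1) / nthread;
    for (int t = 0; t < nthread; ++t) {
        const int64_t start = t * chunk;
        const int64_t count = std::min(chunk, nelements - start);
        if (count <= 0) {
            break;
        }
        // BF16 has a block size of 1, so any split point is valid.
        workers.emplace_back([=] { ggml_bf16_to_fp32_row(src + start, dst + start, count); });
    }
    for (auto & w : workers) {
        w.join();
    }
}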

View file

@ -81,6 +81,9 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_GPT2 = 7, LLAMA_VOCAB_PRE_TYPE_GPT2 = 7,
LLAMA_VOCAB_PRE_TYPE_REFACT = 8, LLAMA_VOCAB_PRE_TYPE_REFACT = 8,
LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9, LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9,
LLAMA_VOCAB_PRE_TYPE_QWEN2 = 10,
LLAMA_VOCAB_PRE_TYPE_OLMO = 11,
LLAMA_VOCAB_PRE_TYPE_DBRX = 12,
}; };
// note: these values should be synchronized with ggml_rope // note: these values should be synchronized with ggml_rope
@ -136,6 +139,7 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_IQ2_M = 29, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ2_M = 29, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors
LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
}; };

Binary file not shown.

View file

@ -0,0 +1,106 @@
ied 4 ½ months
__ggml_vocab_test__
Führer
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
Hello world
__ggml_vocab_test__
Hello world
__ggml_vocab_test__
Hello World
__ggml_vocab_test__
Hello World
__ggml_vocab_test__
Hello World!
__ggml_vocab_test__
Hello, world!
__ggml_vocab_test__
Hello, world!
__ggml_vocab_test__
this is 🦙.cpp
__ggml_vocab_test__
w048 7tuijk dsdfhu
__ggml_vocab_test__
нещо на Български
__ggml_vocab_test__
កាន់តែពិសេសអាចខលចេញ
__ggml_vocab_test__
🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
Hello
__ggml_vocab_test__
(
__ggml_vocab_test__
=
__ggml_vocab_test__
' era
__ggml_vocab_test__
Hello, y'all! How are you 😁 ?我想在apple工作1314151天
__ggml_vocab_test__
3
__ggml_vocab_test__
33
__ggml_vocab_test__
333
__ggml_vocab_test__
3333
__ggml_vocab_test__
33333
__ggml_vocab_test__
333333
__ggml_vocab_test__
3333333
__ggml_vocab_test__
33333333
__ggml_vocab_test__
333333333
__ggml_vocab_test__
🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天 ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL
__ggml_vocab_test__

View file

@ -0,0 +1,43 @@
1122 220 19 220 26062 3951
37 50753 261
220
256
262
197
198
271
1406
1572
9707 1879
21927 1879
9707 4337
21927 4337
21927 4337 0
9707 11 1879 0
21927 11 1879 0
419 374 11162 99 247 13 10821
86 15 19 23 220 22 83 1963 41808 11472 2940 16739
78762 14144 1456 13073 63471 33594 3038 133178 79012
146394 97529 241 44258 233 146568 44258 224 147603 20879 115 146280 44258 223 146280 147272 97529 227 147805 148301 147270 44258 223 146848
145836 320 8252 8 26525 114 378 235 149921 30543 320 35673 99066 97534 8 25521 227 320 3243 42365 429 702 1181 1828 3950 8
9707
21927
220 21927
256 21927
262 21927
262 21927 198 262 21927
320
198 284
6 11385
9707 11 379 64848 0 2585 525 498 26525 223 937 104100 18493 22377 99257 16 18 16 19 16 20 16 35727 21216
18
18 18
18 18 18
18 18 18 18
18 18 18 18 18
18 18 18 18 18 18
18 18 18 18 18 18 18
18 18 18 18 18 18 18 18
18 18 18 18 18 18 18 18 18
198 4710 14731 65497 7847 1572 2303 78672 10947 145836 320 8252 8 26525 114 378 235 149921 30543 320 35673 99066 97534 8 25521 227 11162 99 247 149955 220 18 220 18 18 220 18 18 18 220 18 18 18 18 220 18 18 18 18 18 220 18 18 18 18 18 18 220 18 18 18 18 18 18 18 220 18 18 18 18 18 18 18 18 220 18 13 18 220 18 496 18 220 18 1112 18 220 146394 97529 241 44258 233 146568 44258 224 147603 20879 115 146280 44258 223 146280 147272 97529 227 144534 937 104100 18493 22377 99257 16 18 16 19 16 20 16 35727 21216 55460 53237 18658 14144 1456 13073 63471 33594 3038 133178 79012 3355 4605 4605 13874 13874 73594 3014 3014 28149 17085 2928 26610 7646 358 3003 1012 364 83 813 566 594 1052 11 364 787 498 2704 30 364 44 537 2704 358 3278 1281 432 11 364 35 498 1075 1045 15243 30 1205 6 42612 264 63866 43

View file

@ -93,11 +93,14 @@ help_s = (
"specified values are averaged WITHOUT weighing by the --repetitions parameter of llama-bench." "specified values are averaged WITHOUT weighing by the --repetitions parameter of llama-bench."
) )
parser.add_argument("-s", "--show", help=help_s) parser.add_argument("-s", "--show", help=help_s)
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
known_args, unknown_args = parser.parse_known_args() known_args, unknown_args = parser.parse_known_args()
logging.basicConfig(level=logging.DEBUG if known_args.verbose else logging.INFO)
if unknown_args: if unknown_args:
logger.error(f"Received unknown args: {unknown_args}.") logger.error(f"Received unknown args: {unknown_args}.\n")
parser.print_help() parser.print_help()
sys.exit(1) sys.exit(1)
@ -110,7 +113,7 @@ if input_file is None:
input_file = sqlite_files[0] input_file = sqlite_files[0]
if input_file is None: if input_file is None:
logger.error("Cannot find a suitable input file, please provide one.") logger.error("Cannot find a suitable input file, please provide one.\n")
parser.print_help() parser.print_help()
sys.exit(1) sys.exit(1)
@ -202,12 +205,12 @@ elif repo is not None:
hexsha8_baseline = find_parent_in_data(repo.heads.master.commit) hexsha8_baseline = find_parent_in_data(repo.heads.master.commit)
if hexsha8_baseline is None: if hexsha8_baseline is None:
logger.error("No baseline was provided and did not find data for any master branch commits.") logger.error("No baseline was provided and did not find data for any master branch commits.\n")
parser.print_help() parser.print_help()
sys.exit(1) sys.exit(1)
else: else:
logger.error("No baseline was provided and the current working directory " logger.error("No baseline was provided and the current working directory "
"is not part of a git repository from which a baseline could be inferred.") "is not part of a git repository from which a baseline could be inferred.\n")
parser.print_help() parser.print_help()
sys.exit(1) sys.exit(1)
@ -238,7 +241,7 @@ elif repo is not None:
break break
if hexsha8_compare is None: if hexsha8_compare is None:
logger.error("No compare target was provided and did not find data for any non-master commits.") logger.error("No compare target was provided and did not find data for any non-master commits.\n")
parser.print_help() parser.print_help()
sys.exit(1) sys.exit(1)
else: else:
@ -361,7 +364,7 @@ if "gpu_info" in show:
headers = [PRETTY_NAMES[p] for p in show] headers = [PRETTY_NAMES[p] for p in show]
headers += ["Test", f"t/s {name_baseline}", f"t/s {name_compare}", "Speedup"] headers += ["Test", f"t/s {name_baseline}", f"t/s {name_compare}", "Speedup"]
logger.info(tabulate( print(tabulate( # noqa: NP100
table, table,
headers=headers, headers=headers,
floatfmt=".2f", floatfmt=".2f",

View file

@ -1,6 +1,3 @@
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
// vi: set et ft=c++ ts=4 sts=4 sw=4 fenc=utf-8 :vi
//
// Copyright 2024 Mozilla Foundation // Copyright 2024 Mozilla Foundation
// //
// Permission is hereby granted, free of charge, to any person obtaining // Permission is hereby granted, free of charge, to any person obtaining
@ -585,15 +582,15 @@ class tinyBLAS_Q0_ARM {
}; };
#endif // __ARM_FEATURE_DOTPROD #endif // __ARM_FEATURE_DOTPROD
#if defined(__AVX2__) || defined(__AVX512F__) #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
template <typename TA, typename TB, typename TC> template <typename TA, typename TB, typename TC>
class tinyBLAS_Q0_AVX2 { class tinyBLAS_Q0_AVX {
public: public:
tinyBLAS_Q0_AVX2(int64_t k, tinyBLAS_Q0_AVX(int64_t k,
const TA *A, int64_t lda, const TA *A, int64_t lda,
const TB *B, int64_t ldb, const TB *B, int64_t ldb,
TC *C, int64_t ldc, TC *C, int64_t ldc,
int ith, int nth) int ith, int nth)
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
} }
@ -728,14 +725,34 @@ class tinyBLAS_Q0_AVX2 {
__m256 Cv[RN][RM] = {}; __m256 Cv[RN][RM] = {};
for (int64_t l = 0; l < k; ++l) for (int64_t l = 0; l < k; ++l)
for (int64_t j = 0; j < RN; ++j) for (int64_t j = 0; j < RN; ++j)
for (int64_t i = 0; i < RM; ++i) for (int64_t i = 0; i < RM; ++i) {
#if defined(__AVX2__)
__m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
load(A + lda * (ii + i) + l)),
_mm256_sign_epi8(load(B + ldb * (jj + j) + l),
load(A + lda * (ii + i) + l)));
#else
__m128i ali0 = load0(A + lda * (ii + i) + l);
__m128i ali1 = load1(A + lda * (ii + i) + l);
__m128i blj0 = load0(B + ldb * (jj + j) + l);
__m128i blj1 = load1(B + ldb * (jj + j) + l);
__m128i sepAA0 = _mm_sign_epi8(ali0, ali0);
__m128i sepAA1 = _mm_sign_epi8(ali1, ali1);
__m128i sepBA0 = _mm_sign_epi8(blj0, ali0);
__m128i sepBA1 = _mm_sign_epi8(blj1, ali1);
// updot
const __m128i oneFill = _mm_set1_epi16(1);
__m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);
__m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);
__m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0)));
#endif
Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) * Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
unhalf(B[ldb * (jj + j) + l].d)), unhalf(B[ldb * (jj + j) + l].d)),
updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), udTmp,
load(A + lda * (ii + i) + l)), Cv[j][i]);
_mm256_sign_epi8(load(B + ldb * (jj + j) + l), }
load(A + lda * (ii + i) + l))),
Cv[j][i]);
for (int64_t j = 0; j < RN; ++j) for (int64_t j = 0; j < RN; ++j)
for (int64_t i = 0; i < RM; ++i) for (int64_t i = 0; i < RM; ++i)
C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
@ -746,10 +763,28 @@ class tinyBLAS_Q0_AVX2 {
return _mm256_loadu_si256((const __m256i *)b->qs); return _mm256_loadu_si256((const __m256i *)b->qs);
} }
inline __m128i load0(const block_q8_0 *b) {
return _mm_loadu_si128((const __m128i *)b->qs);
}
inline __m128i load1(const block_q8_0 *b) {
return _mm_loadu_si128(((const __m128i *)b->qs) + 1);
}
inline __m256i load(const block_q4_0 *b) { inline __m256i load(const block_q4_0 *b) {
return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8)); return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
} }
inline __m128i load0(const block_q4_0 *b) {
const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8));
}
inline __m128i load1(const block_q4_0 *b) {
const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
}
inline __m256 updot(__m256i u, __m256i s) { inline __m256 updot(__m256i u, __m256i s) {
__m256i res; __m256i res;
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
@ -777,7 +812,7 @@ class tinyBLAS_Q0_AVX2 {
const int ith; const int ith;
const int nth; const int nth;
}; };
#endif // __AVX2__ #endif // __AVX__
} // namespace } // namespace
@ -928,8 +963,8 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
case GGML_TYPE_Q8_0: { case GGML_TYPE_Q8_0: {
if (Btype != GGML_TYPE_Q8_0) if (Btype != GGML_TYPE_Q8_0)
return false; return false;
#if defined(__AVX2__) || defined(__AVX512F__) #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
tinyBLAS_Q0_AVX2<block_q8_0, block_q8_0, float> tb{ tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{
k, (const block_q8_0 *)A, lda, k, (const block_q8_0 *)A, lda,
(const block_q8_0 *)B, ldb, (const block_q8_0 *)B, ldb,
(float *)C, ldc, (float *)C, ldc,
@ -952,8 +987,8 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
case GGML_TYPE_Q4_0: { case GGML_TYPE_Q4_0: {
if (Btype != GGML_TYPE_Q8_0) if (Btype != GGML_TYPE_Q8_0)
return false; return false;
#if defined(__AVX2__) || defined(__AVX512F__) #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
tinyBLAS_Q0_AVX2<block_q4_0, block_q8_0, float> tb{ tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{
k, (const block_q4_0 *)A, lda, k, (const block_q4_0 *)A, lda,
(const block_q8_0 *)B, ldb, (const block_q8_0 *)B, ldb,
(float *)C, ldc, (float *)C, ldc,
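Note: the new #else branch above reproduces the AVX2/VNNI updot with 128-bit SSSE3 instructions: take |a| with _mm_sign_epi8, fold a's sign into b, multiply-accumulate byte pairs to int16 with _mm_maddubs_epi16, then widen to int32 with _mm_madd_epi16 against a vector of ones. A standalone sketch of that pattern for a single 16-byte dot product (compile with -mssse3; the small inputs stay well clear of the int16 saturation in _mm_maddubs_epi16):

#include <immintrin.h>
#include <cstdint>
#include <cstdio>

// Dot product of two vectors of 16 signed 8-bit values using the same
// sign / maddubs / madd trick as the fallback above.
static int32_t dot_q8_sse(const int8_t * a, const int8_t * b) {
    const __m128i va   = _mm_loadu_si128((const __m128i *) a);
    const __m128i vb   = _mm_loadu_si128((const __m128i *) b);
    const __m128i absa = _mm_sign_epi8(va, va);                   // |a|, usable as unsigned bytes
    const __m128i sb   = _mm_sign_epi8(vb, va);                   // b with a's sign folded in
    const __m128i p16  = _mm_maddubs_epi16(absa, sb);             // pairwise a*b summed to int16
    const __m128i p32  = _mm_madd_epi16(p16, _mm_set1_epi16(1));  // widen to 4x int32
    int32_t tmp[4];
    _mm_storeu_si128((__m128i *) tmp, p32);
    return tmp[0] + tmp[1] + tmp[2] + tmp[3];
}

int main() {
    int8_t a[16], b[16];
    int32_t ref = 0;
    for (int i = 0; i < 16; ++i) {
        a[i] = (int8_t) (i - 8);
        b[i] = (int8_t) (3 - i);
        ref += a[i] * b[i];
    }
    printf("sse=%d ref=%d\n", dot_q8_sse(a, b), ref);
    return 0;
}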

View file

@ -84,6 +84,7 @@ llama_test(test-tokenizer-0 NAME test-tokenizer-0-starcoder ARGS ${CMAKE
llama_test(test-tokenizer-0 NAME test-tokenizer-0-gpt-2 ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt-2.gguf) llama_test(test-tokenizer-0 NAME test-tokenizer-0-gpt-2 ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt-2.gguf)
llama_test(test-tokenizer-0 NAME test-tokenizer-0-refact ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-refact.gguf) llama_test(test-tokenizer-0 NAME test-tokenizer-0-refact ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-refact.gguf)
llama_test(test-tokenizer-0 NAME test-tokenizer-0-command-r ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-command-r.gguf) llama_test(test-tokenizer-0 NAME test-tokenizer-0-command-r ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-command-r.gguf)
llama_test(test-tokenizer-0 NAME test-tokenizer-0-qwen2 ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-qwen2.gguf)
# build test-tokenizer-1-bpe target once and add many tests # build test-tokenizer-1-bpe target once and add many tests
add_executable(test-tokenizer-1-bpe test-tokenizer-1-bpe.cpp) add_executable(test-tokenizer-1-bpe test-tokenizer-1-bpe.cpp)

View file

@ -50,7 +50,7 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_I32) { if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_I32) {
ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float)); ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float));
} else if (ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16) { } else if (ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16 || tensor->type == GGML_TYPE_BF16) {
GGML_ASSERT(size % ggml_blck_size(tensor->type) == 0); GGML_ASSERT(size % ggml_blck_size(tensor->type) == 0);
std::vector<uint8_t> dataq(ggml_row_size(tensor->type, size)); std::vector<uint8_t> dataq(ggml_row_size(tensor->type, size));
std::vector<float> imatrix(tensor->ne[0], 1.0f); // dummy importance matrix std::vector<float> imatrix(tensor->ne[0], 1.0f); // dummy importance matrix
@ -92,6 +92,8 @@ static std::vector<float> tensor_to_float(const ggml_tensor * t) {
size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0/bs*t->nb[0]; size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0/bs*t->nb[0];
if (t->type == GGML_TYPE_F16) { if (t->type == GGML_TYPE_F16) {
tv.push_back(ggml_fp16_to_fp32(*(ggml_fp16_t*)&buf[i])); tv.push_back(ggml_fp16_to_fp32(*(ggml_fp16_t*)&buf[i]));
} else if (t->type == GGML_TYPE_BF16) {
tv.push_back(ggml_bf16_to_fp32(*(ggml_bf16_t*)&buf[i]));
} else if (t->type == GGML_TYPE_F32) { } else if (t->type == GGML_TYPE_F32) {
tv.push_back(*(float *) &buf[i]); tv.push_back(*(float *) &buf[i]);
} else if (t->type == GGML_TYPE_I32) { } else if (t->type == GGML_TYPE_I32) {
@ -1898,7 +1900,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
std::default_random_engine rng(0); std::default_random_engine rng(0);
const ggml_type all_types[] = { const ggml_type all_types[] = {
GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16,
GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
GGML_TYPE_Q8_0, GGML_TYPE_Q8_0,