diff --git a/CMakeLists.txt b/CMakeLists.txt index 477c5b57c..0e22ee230 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -103,6 +103,8 @@ set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING "llama: max. batch size for using peer access") option(LLAMA_CUDA_NO_PEER_COPY "llama: do not use peer to peer copies" OFF) +option(LLAMA_CUDA_NO_VMM "llama: do not try to use CUDA VMM" OFF) + option(LLAMA_CURL "llama: use libcurl to download model from an URL" OFF) option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF) option(LLAMA_HIP_UMA "llama: use HIP unified memory architecture" OFF) @@ -409,6 +411,9 @@ if (LLAMA_CUDA) if (LLAMA_CUDA_FORCE_MMQ) add_compile_definitions(GGML_CUDA_FORCE_MMQ) endif() + if (LLAMA_CUDA_NO_VMM) + add_compile_definitions(GGML_CUDA_NO_VMM) + endif() add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X}) add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y}) if (DEFINED LLAMA_CUDA_DMMV_Y) @@ -434,7 +439,11 @@ if (LLAMA_CUDA) set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt) endif() - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cuda_driver) + if (LLAMA_CUDA_NO_VMM) + # No VMM requested, no need to link directly with the cuda driver lib (libcuda.so) + else() + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cuda_driver) # required by cuDeviceGetAttribute(), cuMemGetAllocationGranularity(...), ... + endif() if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES) # 52 == lowest CUDA 12 standard diff --git a/README.md b/README.md index 885322e68..1c960b8c1 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,8 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others) ### Hot topics -- **BPE pre-tokenization support has been added: https://github.com/ggerganov/llama.cpp/pull/6920** +- **Initial Flash-Attention support: https://github.com/ggerganov/llama.cpp/pull/5021** +- BPE pre-tokenization support has been added: https://github.com/ggerganov/llama.cpp/pull/6920 - MoE memory layout has been updated - reconvert models for `mmap` support and regenerate `imatrix` https://github.com/ggerganov/llama.cpp/pull/6387 - Model sharding instructions using `gguf-split` https://github.com/ggerganov/llama.cpp/discussions/6404 - Fix major bug in Metal batched inference https://github.com/ggerganov/llama.cpp/pull/6225 @@ -935,17 +936,25 @@ If your issue is with model generation quality, then please at least scan the fo ### Android +#### Build on Android using Termux +[Termux](https://github.com/termux/termux-app#installation) is a method to execute `llama.cpp` on an Android device (no root required). +``` +apt update && apt upgrade -y +apt install git make cmake +``` + +It's recommended to move your model inside the `~/` directory for best performance: +``` +cd storage/downloads +mv model.gguf ~/ +``` + +[Get the code](https://github.com/ggerganov/llama.cpp#get-the-code) & [follow the Linux build instructions](https://github.com/ggerganov/llama.cpp#build) to build `llama.cpp`. + #### Building the Project using Android NDK -You can easily run `llama.cpp` on Android device with [termux](https://termux.dev/). - -First, install the essential packages for termux: -``` -pkg install clang wget git cmake -``` -Second, obtain the [Android NDK](https://developer.android.com/ndk) and then build with CMake: - -You can execute the following commands on your computer to avoid downloading the NDK to your mobile. Of course, you can also do this in Termux. +Obtain the [Android NDK](https://developer.android.com/ndk) and then build with CMake. +Execute the following commands on your computer to avoid downloading the NDK to your mobile. Alternatively, you can also do this in Termux: ``` $ mkdir build-android $ cd build-android @@ -953,7 +962,9 @@ $ export NDK= $ cmake -DCMAKE_TOOLCHAIN_FILE=$NDK/build/cmake/android.toolchain.cmake -DANDROID_ABI=arm64-v8a -DANDROID_PLATFORM=android-23 -DCMAKE_C_FLAGS=-march=armv8.4a+dotprod .. $ make ``` -Install [termux](https://termux.dev/) on your device and run `termux-setup-storage` to get access to your SD card. + +Install [termux](https://github.com/termux/termux-app#installation) on your device and run `termux-setup-storage` to get access to your SD card (if Android 11+ then run the command twice). + Finally, copy these built `llama` binaries and the model file to your device storage. Because the file permissions in the Android sdcard cannot be changed, you can copy the executable files to the `/data/data/com.termux/files/home/bin` path, and then execute the following commands in Termux to add executable permission: (Assumed that you have pushed the built executable files to the /sdcard/llama.cpp/bin path using `adb push`) @@ -975,25 +986,10 @@ $cd /data/data/com.termux/files/home/bin $./main -m ../model/llama-2-7b-chat.Q4_K_M.gguf -n 128 -cml ``` -Here is a demo of an interactive session running on Pixel 5 phone: +Here's a demo of an interactive session running on Pixel 5 phone: https://user-images.githubusercontent.com/271616/225014776-1d567049-ad71-4ef2-b050-55b0b3b9274c.mp4 -#### Build on Android using Termux -[Termux](https://github.com/termux/termux-app#installation) is an alternative to execute `llama.cpp` on an Android device (no root required). -``` -apt update && apt upgrade -y -apt install git -``` - -It's recommended to move your model inside the `~/` directory for best performance: -``` -cd storage/downloads -mv model.gguf ~/ -``` - -[Follow the Linux build instructions](https://github.com/ggerganov/llama.cpp#build) to build `llama.cpp`. - ### Docker #### Prerequisites diff --git a/ci/run.sh b/ci/run.sh index bf21b6b31..e67c1a5ff 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -160,9 +160,8 @@ function gg_run_test_scripts_debug { set -e - # TODO: too slow, run on dedicated node - #(cd ./examples/gguf-split && time bash tests.sh "$SRC/build-ci-debug/bin" "$MNT/models") 2>&1 | tee -a $OUT/${ci}-scripts.log - #(cd ./examples/quantize && time bash tests.sh "$SRC/build-ci-debug/bin" "$MNT/models") 2>&1 | tee -a $OUT/${ci}-scripts.log + (cd ./examples/gguf-split && time bash tests.sh "$SRC/build-ci-debug/bin" "$MNT/models") 2>&1 | tee -a $OUT/${ci}-scripts.log + (cd ./examples/quantize && time bash tests.sh "$SRC/build-ci-debug/bin" "$MNT/models") 2>&1 | tee -a $OUT/${ci}-scripts.log set +e } @@ -695,8 +694,10 @@ test $ret -eq 0 && gg_run ctest_release if [ -z ${GG_BUILD_LOW_PERF} ]; then test $ret -eq 0 && gg_run embd_bge_small - test $ret -eq 0 && gg_run test_scripts_debug - test $ret -eq 0 && gg_run test_scripts_release + if [ -z ${GG_BUILD_CLOUD} ] || [ ${GG_BUILD_EXTRA_TESTS_0} ]; then + test $ret -eq 0 && gg_run test_scripts_debug + test $ret -eq 0 && gg_run test_scripts_release + fi if [ -z ${GG_BUILD_VRAM_GB} ] || [ ${GG_BUILD_VRAM_GB} -ge 8 ]; then if [ -z ${GG_BUILD_CUDA} ]; then diff --git a/common/common.cpp b/common/common.cpp index 467fb014e..4a9da284e 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -911,6 +911,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.instruct = true; return true; } + if (arg == "-cnv" || arg == "--conversation") { + params.conversation = true; + return true; + } if (arg == "-cml" || arg == "--chatml") { params.chatml = true; return true; @@ -1417,6 +1421,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" --version show version and build info\n"); printf(" -i, --interactive run in interactive mode\n"); printf(" --interactive-first run in interactive mode and wait for input right away\n"); + printf(" -cnv, --conversation run in conversation mode (does not print special tokens and suffix/prefix)\n"); printf(" -ins, --instruct run in instruction mode (use with Alpaca models)\n"); printf(" -cml, --chatml run in chatml mode (use with ChatML-compatible models)\n"); printf(" --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n"); diff --git a/common/common.h b/common/common.h index 9252a4b63..6f00a2cca 100644 --- a/common/common.h +++ b/common/common.h @@ -140,6 +140,7 @@ struct gpt_params { bool random_prompt = false; // do not randomize prompt if none provided bool use_color = false; // use color to distinguish generations and inputs bool interactive = false; // interactive mode + bool conversation = false; // conversation mode (does not print special tokens and suffix/prefix) bool chatml = false; // chatml mode (used for models trained on chatml syntax) bool prompt_cache_all = false; // save user input and generations to prompt cache bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it diff --git a/common/sampling.cpp b/common/sampling.cpp index cc83600d9..3715a7985 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -35,6 +35,8 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_ result->prev.resize(params.n_prev); + result->n_considered = 0; + llama_sampling_set_rng_seed(result, params.seed); return result; @@ -64,6 +66,7 @@ void llama_sampling_reset(llama_sampling_context * ctx) { std::fill(ctx->prev.begin(), ctx->prev.end(), 0); ctx->cur.clear(); + ctx->n_considered = 0; } void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) { @@ -253,6 +256,8 @@ static llama_token llama_sampling_sample_impl( } } + ctx_sampling->n_considered = cur_p.size; + return id; } diff --git a/common/sampling.h b/common/sampling.h index cf7081e36..5b73ecdcd 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -81,6 +81,7 @@ struct llama_sampling_context { // TODO: replace with ring-buffer std::vector prev; std::vector cur; + size_t n_considered; std::mt19937 rng; }; diff --git a/convert-hf-to-gguf-update.py b/convert-hf-to-gguf-update.py index 46a225462..a26f45a5f 100755 --- a/convert-hf-to-gguf-update.py +++ b/convert-hf-to-gguf-update.py @@ -67,6 +67,9 @@ models = [ {"name": "gpt-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/openai-community/gpt2", }, {"name": "refact", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/smallcloudai/Refact-1_6-base", }, {"name": "command-r", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/CohereForAI/c4ai-command-r-v01", }, + {"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", }, + {"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", }, + {"name": "dbrx", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", }, ] # make directory "models/tokenizers" if it doesn't exist @@ -150,6 +153,8 @@ for model in models: # print the "pre_tokenizer" content from the tokenizer.json with open(f"models/tokenizers/{name}/tokenizer.json", "r", encoding="utf-8") as f: cfg = json.load(f) + normalizer = cfg["normalizer"] + logger.info("normalizer: " + json.dumps(normalizer, indent=4)) pre_tokenizer = cfg["pre_tokenizer"] logger.info("pre_tokenizer: " + json.dumps(pre_tokenizer, indent=4)) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 04e93b719..1dc18b2a5 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -397,6 +397,15 @@ class Model: if chkhsh == "9c2227e4dd922002fb81bde4fc02b0483ca4f12911410dee2255e4987644e3f8": # ref: https://huggingface.co/CohereForAI/c4ai-command-r-v01 res = "command-r" + if chkhsh == "e636dc30a262dcc0d8c323492e32ae2b70728f4df7dfe9737d9f920a282b8aea": + # ref: https://huggingface.co/Qwen/Qwen1.5-7B + res = "qwen2" + if chkhsh == "b6dc8df998e1cfbdc4eac8243701a65afe638679230920b50d6f17d81c098166": + # ref: https://huggingface.co/allenai/OLMo-1.7-7B-hf + res = "olmo" + if chkhsh == "a8594e3edff7c29c003940395316294b2c623e09894deebbc65f33f1515df79e": + # ref: https://huggingface.co/databricks/dbrx-instruct + res = "dbrx" if res is None: logger.warning("\n") @@ -2248,8 +2257,9 @@ class OlmoModel(Model): def set_gguf_parameters(self): super().set_gguf_parameters() self.gguf_writer.add_layer_norm_eps(1e-5) - if "clip_qkv" in self.hparams is not None: - self.gguf_writer.add_clamp_kqv(self.hparams["clip_qkv"]) + clip_qkv = self.hparams.get("clip_qkv") + if clip_qkv is not None: + self.gguf_writer.add_clamp_kqv(clip_qkv) # Same as super class, but permuting q_proj, k_proj # Copied from: LlamaModel diff --git a/convert.py b/convert.py index 65ac5a720..148bfd66a 100755 --- a/convert.py +++ b/convert.py @@ -1512,25 +1512,27 @@ def main(args_in: list[str] | None = None) -> None: if args.big_endian: endianess = gguf.GGUFEndian.BIG - params = Params.load(model_plus) - if params.n_ctx == -1: - if args.ctx is None: - msg = """\ - The model doesn't have a context size, and you didn't specify one with --ctx - Please specify one with --ctx: - - LLaMA v1: --ctx 2048 - - LLaMA v2: --ctx 4096""" - parser.error(textwrap.dedent(msg)) - params.n_ctx = args.ctx + params = None + if args.pad_vocab or not args.vocab_only: + params = Params.load(model_plus) + if params.n_ctx == -1: + if args.ctx is None: + msg = """\ + The model doesn't have a context size, and you didn't specify one with --ctx + Please specify one with --ctx: + - LLaMA v1: --ctx 2048 + - LLaMA v2: --ctx 4096""" + parser.error(textwrap.dedent(msg)) + params.n_ctx = args.ctx - if args.outtype: - params.ftype = { - "f32": GGMLFileType.AllF32, - "f16": GGMLFileType.MostlyF16, - "q8_0": GGMLFileType.MostlyQ8_0, - }[args.outtype] + if args.outtype: + params.ftype = { + "f32": GGMLFileType.AllF32, + "f16": GGMLFileType.MostlyF16, + "q8_0": GGMLFileType.MostlyQ8_0, + }[args.outtype] - logger.info(f"params = {params}") + logger.info(f"params = {params}") model_parent_path = model_plus.paths[0].parent vocab_path = Path(args.vocab_dir or args.model or model_parent_path) @@ -1543,6 +1545,17 @@ def main(args_in: list[str] | None = None) -> None: if not args.outfile: raise ValueError("need --outfile if using --vocab-only") outfile = args.outfile + if params is None: + params = Params( + n_vocab = vocab.vocab_size, + n_embd = 1, + n_layer = 1, + n_ctx = 1, + n_ff = 1, + n_head = 1, + n_head_kv = 1, + f_norm_eps = 1e-5, + ) OutputFile.write_vocab_only(outfile, params, vocab, special_vocab, endianess=endianess, pad_vocab=args.pad_vocab) logger.info(f"Wrote {outfile}") diff --git a/docs/BLIS.md b/docs/BLIS.md index 0bcd6eeef..c933766b7 100644 --- a/docs/BLIS.md +++ b/docs/BLIS.md @@ -23,7 +23,7 @@ Install BLIS: sudo make install ``` -We recommend using openmp since it's easier to modify the cores been used. +We recommend using openmp since it's easier to modify the cores being used. ### llama.cpp compilation diff --git a/docs/HOWTO-add-model.md b/docs/HOWTO-add-model.md index a56b78344..48769cdf6 100644 --- a/docs/HOWTO-add-model.md +++ b/docs/HOWTO-add-model.md @@ -96,9 +96,9 @@ NOTE: The dimensions in `ggml` are typically in the reverse order of the `pytorc This is the funniest part, you have to provide the inference graph implementation of the new model architecture in `llama_build_graph`. -Have a look to existing implementation like `build_llama`, `build_dbrx` or `build_bert`. +Have a look at existing implementation like `build_llama`, `build_dbrx` or `build_bert`. -When implementing a new graph, please note that the underlying `ggml` backends might not support them all, support of missing backend operations can be added in another PR. +When implementing a new graph, please note that the underlying `ggml` backends might not support them all, support for missing backend operations can be added in another PR. Note: to debug the inference graph: you can use [eval-callback](../examples/eval-callback). diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp index 3da5317b3..22743b1bf 100644 --- a/examples/finetune/finetune.cpp +++ b/examples/finetune/finetune.cpp @@ -575,7 +575,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs( GGML_ASSERT(tokens_input->type == GGML_TYPE_I32); auto add_to_f32 = [] (struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) { - if (ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16) { + if (ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16 || a->type == GGML_TYPE_BF16) { return ggml_add_cast(ctx, a, b, GGML_TYPE_F32); } else if (a->type == GGML_TYPE_F32) { return ggml_add(ctx, a, b); diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 71e7a727f..82b19fc4f 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -19,6 +19,7 @@ struct Stats { std::vector values; + std::vector counts; int ncall = 0; }; @@ -121,12 +122,10 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * auto & e = m_stats[wname]; ++e.ncall; - // NOTE: since we select top-k experts, the number of calls for the expert tensors will be k times larger - // using the following line, we can correct for that if needed by replacing the line above with: - //if (idx == t->src[0]->ne[0] - 1) ++e.ncall; if (e.values.empty()) { e.values.resize(src1->ne[0]*n_as, 0); + e.counts.resize(src1->ne[0]*n_as, 0); } else if (e.values.size() != (size_t)src1->ne[0]*n_as) { fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]*n_as); @@ -153,6 +152,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * for (int j = 0; j < (int)src1->ne[0]; ++j) { e.values[e_start + j] += x[j]*x[j]; + e.counts[e_start + j]++; } } } @@ -170,6 +170,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * auto& e = m_stats[wname]; if (e.values.empty()) { e.values.resize(src1->ne[0], 0); + e.counts.resize(src1->ne[0], 0); } else if (e.values.size() != (size_t)src1->ne[0]) { fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]); @@ -183,6 +184,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * const float * x = data + row * src1->ne[0]; for (int j = 0; j < (int)src1->ne[0]; ++j) { e.values[j] += x[j]*x[j]; + e.counts[j]++; } } if (e.ncall > m_last_call) { @@ -222,7 +224,13 @@ void IMatrixCollector::save_imatrix(const char * fname, const char * dataset) co out.write((const char *) &p.second.ncall, sizeof(p.second.ncall)); int nval = p.second.values.size(); out.write((const char *) &nval, sizeof(nval)); - if (nval > 0) out.write((const char *) p.second.values.data(), nval * sizeof(float)); + if (nval > 0) { + std::vector tmp(nval); + for (int i = 0; i < nval; i++) { + tmp[i] = (p.second.values[i] / static_cast(p.second.counts[i])) * static_cast(p.second.ncall); + } + out.write((const char*)tmp.data(), nval*sizeof(float)); + } } // Write the number of call the matrix was computed with @@ -270,14 +278,28 @@ bool IMatrixCollector::load_imatrix(const char * imatrix_file, std::unordered_ma imatrix_data = {}; return false; } - e.values.resize(nval); - in.read((char*)e.values.data(), nval*sizeof(float)); + + // When re-called from load_imatrix() with add set, this will already be created. + if (e.values.empty()) { + e.values.resize(nval, 0); + e.counts.resize(nval, 0); + } + + std::vector tmp(nval); + in.read((char*)tmp.data(), nval*sizeof(float)); if (in.fail()) { printf("%s: failed reading data for entry %d\n",__func__,i); imatrix_data = {}; return false; } - e.ncall = ncall; + + // Recreate the state as expected by save_imatrix(), and corerct for weighted sum. + for (int i = 0; i < nval; i++) { + e.values[i] += tmp[i]; + e.counts[i] += ncall; + } + e.ncall += ncall; + } return true; } diff --git a/examples/llava/README.md b/examples/llava/README.md index d4810d42e..4fb0cf381 100644 --- a/examples/llava/README.md +++ b/examples/llava/README.md @@ -56,7 +56,7 @@ python ./examples/llava/convert-image-encoder-to-gguf.py -m ../clip-vit-large-pa python ./convert.py ../llava-v1.5-7b --skip-unknown ``` -Now both the LLaMA part and the image encoder is in the `llava-v1.5-7b` directory. +Now both the LLaMA part and the image encoder are in the `llava-v1.5-7b` directory. ## LLaVA 1.6 gguf conversion 1) First clone a LLaVA 1.6 model: diff --git a/examples/main/README.md b/examples/main/README.md index e7a38743c..97e2ae4c2 100644 --- a/examples/main/README.md +++ b/examples/main/README.md @@ -143,7 +143,7 @@ The `--ctx-size` option allows you to set the size of the prompt context used by ### Extended Context Size -Some fine-tuned models have extended the context length by scaling RoPE. For example, if the original pre-trained model have a context length (max sequence length) of 4096 (4k) and the fine-tuned model have 32k. That is a scaling factor of 8, and should work by setting the above `--ctx-size` to 32768 (32k) and `--rope-scale` to 8. +Some fine-tuned models have extended the context length by scaling RoPE. For example, if the original pre-trained model has a context length (max sequence length) of 4096 (4k) and the fine-tuned model has 32k. That is a scaling factor of 8, and should work by setting the above `--ctx-size` to 32768 (32k) and `--rope-scale` to 8. - `--rope-scale N`: Where N is the linear scaling factor used by the fine-tuned model. @@ -286,7 +286,7 @@ These options help improve the performance and memory usage of the LLaMA models. - `--numa distribute`: Pin an equal proportion of the threads to the cores on each NUMA node. This will spread the load amongst all cores on the system, utilitizing all memory channels at the expense of potentially requiring memory to travel over the slow links between nodes. - `--numa isolate`: Pin all threads to the NUMA node that the program starts on. This limits the number of cores and amount of memory that can be used, but guarantees all memory access remains local to the NUMA node. -- `--numa numactl`: Pin threads to the CPUMAP that is passed to the program by starting it with the numactl utility. This is the most flexible mode, and allow arbitraty core usage patterns, for example a map that uses all the cores on one NUMA nodes, and just enough cores on a second node to saturate the inter-node memory bus. +- `--numa numactl`: Pin threads to the CPUMAP that is passed to the program by starting it with the numactl utility. This is the most flexible mode, and allow arbitrary core usage patterns, for example a map that uses all the cores on one NUMA nodes, and just enough cores on a second node to saturate the inter-node memory bus. These flags attempt optimizations that help on some systems with non-uniform memory access. This currently consists of one of the above strategies, and disabling prefetch and readahead for mmap. The latter causes mapped pages to be faulted in on first access instead of all at once, and in combination with pinning threads to NUMA nodes, more of the pages end up on the NUMA node where they are used. Note that if the model is already in the system page cache, for example because of a previous run without this option, this will have little effect unless you drop the page cache first. This can be done by rebooting the system or on Linux by writing '3' to '/proc/sys/vm/drop_caches' as root. diff --git a/examples/main/main.cpp b/examples/main/main.cpp index eabbc2db3..49acd6bab 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -362,6 +362,9 @@ int main(int argc, char ** argv) { params.interactive_first = true; params.antiprompt.emplace_back("<|im_start|>user\n"); } + else if (params.conversation) { + params.interactive_first = true; + } // enable interactive mode if interactive start is specified if (params.interactive_first) { @@ -733,7 +736,7 @@ int main(int argc, char ** argv) { // display text if (input_echo && display) { for (auto id : embd) { - const std::string token_str = llama_token_to_piece(ctx, id); + const std::string token_str = llama_token_to_piece(ctx, id, !params.conversation); printf("%s", token_str.c_str()); if (embd.size() > 1) { @@ -796,7 +799,7 @@ int main(int argc, char ** argv) { // deal with end of generation tokens in interactive mode if (llama_token_is_eog(model, llama_sampling_last(ctx_sampling))) { - LOG("found EOS token\n"); + LOG("found an EOG token\n"); if (params.interactive) { if (!params.antiprompt.empty()) { @@ -816,7 +819,7 @@ int main(int argc, char ** argv) { if (n_past > 0 && is_interacting) { LOG("waiting for user input\n"); - if (params.instruct || params.chatml) { + if (params.conversation || params.instruct || params.chatml) { printf("\n> "); } @@ -826,7 +829,7 @@ int main(int argc, char ** argv) { } std::string buffer; - if (!params.input_prefix.empty()) { + if (!params.input_prefix.empty() && !params.conversation) { LOG("appending input prefix: '%s'\n", params.input_prefix.c_str()); printf("%s", params.input_prefix.c_str()); } @@ -850,7 +853,7 @@ int main(int argc, char ** argv) { // Entering a empty line lets the user pass control back if (buffer.length() > 1) { // append input suffix if any - if (!params.input_suffix.empty()) { + if (!params.input_suffix.empty() && !params.conversation) { LOG("appending input suffix: '%s'\n", params.input_suffix.c_str()); printf("%s", params.input_suffix.c_str()); } diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 432cc2b4f..909eab283 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -46,7 +46,8 @@ static const std::vector QUANT_OPTIONS = { { "Q5_K_M", LLAMA_FTYPE_MOSTLY_Q5_K_M, " 4.45G, +0.0122 ppl @ LLaMA-v1-7B", }, { "Q6_K", LLAMA_FTYPE_MOSTLY_Q6_K, " 5.15G, +0.0008 ppl @ LLaMA-v1-7B", }, { "Q8_0", LLAMA_FTYPE_MOSTLY_Q8_0, " 6.70G, +0.0004 ppl @ LLaMA-v1-7B", }, - { "F16", LLAMA_FTYPE_MOSTLY_F16, "13.00G @ 7B", }, + { "F16", LLAMA_FTYPE_MOSTLY_F16, "14.00G, -0.0020 ppl @ Mistral-7B", }, + { "BF16", LLAMA_FTYPE_MOSTLY_BF16, "14.00G, -0.0050 ppl @ Mistral-7B", }, { "F32", LLAMA_FTYPE_ALL_F32, "26.00G @ 7B", }, // Note: Ensure COPY comes after F32 to avoid ftype 0 from matching. { "COPY", LLAMA_FTYPE_ALL_F32, "only copy tensors, no quantizing", }, diff --git a/examples/server/README.md b/examples/server/README.md index b96a4444a..650317991 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -62,6 +62,18 @@ page cache before using this. See https://github.com/ggerganov/llama.cpp/issues/ - `--chat-template JINJA_TEMPLATE`: Set custom jinja chat template. This parameter accepts a string, not a file name. Default: template taken from model's metadata. We only support [some pre-defined templates](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template) - `--log-disable`: Output logs to stdout only, not to `llama.log`. Default: enabled - `--log-format FORMAT`: Define the log output to FORMAT: json or text Default: `json` +- `--rope-scaling` : RoPE scaling method. Defaults to linear unless otherwise specified by the model. Options are `none`, `linear`, `yarn` +- `--rope-freq-base N` : RoPE frequency base (default: loaded from model) +- `--rope-freq-scale N`: RoPE frequency scaling factor, expands context by a factor of 1/N (e.g. 0.25) +- `--yarn-ext-factor N` : YaRN: extrapolation mix factor (Default: 1.0, 0.0 = full interpolation) +- `--yarn-attn-factor N` : YaRN: scale sqrt(t) or attention magnitude (default: 1.0) +- `--yarn-beta-slow N`: YaRN: High correction dim or alpha (default: 1.0) +- `--yarn-beta-fast N`: YaRN: low correction dim or beta (default: 32.0) +- `--pooling` : Pooling type for embeddings, use model default if unspecified. Options are `none`, `mean`, `cls` +- `-dt N`, `--defrag-thold N`: KV cache defragmentation threshold (default: -1.0, < 0 = disabled) +- `-fa`, `--flash-attn` : enable flash attention (default: disabled). +- `-ctk TYPE`, `--cache-type-k TYPE` : KV cache data type for K (default: `f16`, options `f32`, `f16`, `q8_0`, `q4_0`, `q4_1`, `iq4_nl`, `q5_0`, or `q5_1`) +- `-ctv TYPE`, `--cache-type-v TYPE` : KV cache type for V (default `f16`, see `-ctk` for options) **If compiled with `LLAMA_SERVER_SSL=ON`** - `--ssl-key-file FNAME`: path to file a PEM-encoded SSL private key @@ -260,7 +272,7 @@ node index.js `logit_bias`: Modify the likelihood of a token appearing in the generated text completion. For example, use `"logit_bias": [[15043,1.0]]` to increase the likelihood of the token 'Hello', or `"logit_bias": [[15043,-1.0]]` to decrease its likelihood. Setting the value to false, `"logit_bias": [[15043,false]]` ensures that the token `Hello` is never produced. The tokens can also be represented as strings, e.g. `[["Hello, World!",-0.5]]` will reduce the likelihood of all the individual tokens that represent the string `Hello, World!`, just like the `presence_penalty` does. Default: `[]` - `n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token. Default: `0` + `n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token given the sampling settings. Note that for temperature < 0 the tokens are sampled greedily but token probabilities are still being calculated via a simple softmax of the logits without considering any other sampler settings. Default: `0` `min_keep`: If greater than 0, force samplers to return N possible tokens at minimum. Default: `0` @@ -319,7 +331,7 @@ Notice that each `probs` is an array of length `n_probs`. `content`: Set the text to tokenize. - Note that a special `BOS` token is never inserted. + `add_special`: Boolean indicating if special tokens, i.e. `BOS`, should be inserted. Default: `false` - **POST** `/detokenize`: Convert tokens to text. diff --git a/examples/server/server.cpp b/examples/server/server.cpp index ff0814b2f..06c0be567 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2266,17 +2266,31 @@ struct server_context { llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false }; result.tok = id; - const int32_t n_probs = slot.sparams.n_probs; - if (slot.sparams.temp <= 0 && n_probs > 0) { - // for llama_sample_token_greedy we need to sort candidates - llama_sample_softmax(ctx, &cur_p); - } + const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs); + if (n_probs > 0) { + const size_t n_considered = slot.ctx_sampling->n_considered; - for (size_t i = 0; i < std::min(cur_p.size, (size_t) n_probs); ++i) { - result.probs.push_back({ - cur_p.data[i].id, - cur_p.data[i].p - }); + // Make sure at least n_probs top tokens are at the front of the vector: + if (slot.sparams.temp == 0.0f && n_probs > n_considered) { + llama_sample_top_k(ctx, &cur_p, n_probs, 0); + } + + if (slot.sparams.temp == 0.0f) { + // With greedy sampling the probabilities have possibly not been calculated. + for (size_t i = 0; i < n_probs; ++i) { + result.probs.push_back({ + cur_p.data[i].id, + i == 0 ? 1.0f : 0.0f + }); + } + } else { + for (size_t i = 0; i < n_probs; ++i) { + result.probs.push_back({ + cur_p.data[i].id, + i >= n_considered ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability. + }); + } + } } if (!process_token(result, slot)) { @@ -3633,7 +3647,8 @@ int main(int argc, char ** argv) { std::vector tokens; if (body.count("content") != 0) { - tokens = ctx_server.tokenize(body["content"], false); + const bool add_special = json_value(body, "add_special", false); + tokens = ctx_server.tokenize(body["content"], add_special); } const json data = format_tokenizer_response(tokens); return res.set_content(data.dump(), "application/json; charset=utf-8"); diff --git a/examples/server/tests/features/server.feature b/examples/server/tests/features/server.feature index 646a4e49d..d21c09135 100644 --- a/examples/server/tests/features/server.feature +++ b/examples/server/tests/features/server.feature @@ -7,6 +7,7 @@ Feature: llama.cpp server And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models And a model file test-model.gguf And a model alias tinyllama-2 + And BOS token is 1 And 42 as server seed # KV Cache corresponds to the total amount of tokens # that can be stored across all independent sequences: #4130 @@ -91,7 +92,18 @@ Feature: llama.cpp server """ What is the capital of France ? """ - Then tokens can be detokenize + Then tokens can be detokenized + And tokens do not begin with BOS + + Scenario: Tokenize w/ BOS + Given adding special tokens + When tokenizing: + """ + What is the capital of Germany? + """ + Then tokens begin with BOS + Given first token is removed + Then tokens can be detokenized Scenario: Models available Given available models diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index ca2f0b20e..f4b1ac1d7 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -376,6 +376,11 @@ def step_seed(context, seed): context.seed.append(seed) +@step('BOS token is {bos:d}') +def step_bos_token(context, bos): + context.bos = bos + + @step('a prefix prompt') def step_prompt_prefix(context): context.prompt_prefix = context_text(context) @@ -656,21 +661,29 @@ async def all_embeddings_are_generated(context): assert_embeddings(context.tasks_result.pop().pop()) +@step('adding special tokens') +def step_tokenize_set_add_special(context): + context.tokenize_add_special = True + + @step('tokenizing') @async_run_until_complete async def step_tokenize(context): context.tokenized_text = context_text(context) async with aiohttp.ClientSession() as session: + tokenize_args = { + "content": context.tokenized_text, + } + if getattr(context, 'tokenize_add_special', None) is not None: + tokenize_args['add_special'] = context.tokenize_add_special async with session.post(f'{context.base_url}/tokenize', - json={ - "content": context.tokenized_text, - }) as response: + json=tokenize_args) as response: assert response.status == 200 tokenize_json = await response.json() context.tokens = tokenize_json['tokens'] -@step('tokens can be detokenize') +@step('tokens can be detokenized') @async_run_until_complete async def step_detokenize(context): assert len(context.tokens) > 0 @@ -685,6 +698,21 @@ async def step_detokenize(context): assert context.tokenized_text == detokenize_json['content'].strip() +@step('tokens begin with BOS') +def step_strings_for_tokenization(context): + assert context.tokens[0] == context.bos + + +@step('tokens do not begin with BOS') +def step_strings_for_tokenization(context): + assert context.tokens[0] != context.bos + + +@step('first token is removed') +def step_strings_for_tokenization(context): + context.tokens = context.tokens[1:] + + @step('an OPTIONS request is sent from {origin}') @async_run_until_complete async def step_options_request(context, origin): diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 1a2212502..af12f497d 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -49,18 +49,18 @@ extern bool server_log_json; #define LOG_WARNING(MSG, ...) server_log("WARN", __func__, __LINE__, MSG, __VA_ARGS__) #define LOG_INFO( MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__) -static inline void server_log(const char *level, const char *function, int line, const char *message, const nlohmann::ordered_json &extra); +static inline void server_log(const char * level, const char * function, int line, const char * message, const json & extra); template -static T json_value(const json &body, const std::string &key, const T &default_value) { +static T json_value(const json & body, const std::string & key, const T & default_value) { // Fallback null to default value - if (body.contains(key) && !body.at(key).is_null()){ + if (body.contains(key) && !body.at(key).is_null()) { try { - return body.value(key, default_value); - } - catch (nlohmann::json_abi_v3_11_3::detail::type_error const&){ - std::string message = "Wrong type supplied for parameter '" + key + "'. Expected '" + typeid(default_value).name() + "', using default value."; - server_log("WARN", __func__, __LINE__, message.c_str(), body); + return body.at(key); + } catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const &) { + std::stringstream ss; + ss << "Wrong type supplied for parameter '" << key << "'. Expected '" << json(default_value).type_name() << "', using default value."; + LOG_WARNING(ss.str().c_str(), body); return default_value; } } else { @@ -68,16 +68,16 @@ static T json_value(const json &body, const std::string &key, const T &default_v } } -static inline void server_log(const char *level, const char *function, int line, const char *message, const nlohmann::ordered_json &extra) { +static inline void server_log(const char * level, const char * function, int line, const char * message, const json & extra) { std::stringstream ss_tid; ss_tid << std::this_thread::get_id(); - json log = nlohmann::ordered_json{ + json log = json{ {"tid", ss_tid.str()}, {"timestamp", time(nullptr)}, }; if (server_log_json) { - log.merge_patch( { + log.merge_patch({ {"level", level}, {"function", function}, {"line", line}, @@ -98,7 +98,7 @@ static inline void server_log(const char *level, const char *function, int line, } std::stringstream ss; ss << buf << " |"; - for (const auto& el : log.items()) + for (const auto & el : log.items()) { const std::string value = el.value().dump(-1, ' ', false, json::error_handler_t::replace); ss << " " << el.key() << "=" << value; diff --git a/examples/sycl/README.md b/examples/sycl/README.md index b46f17f39..c589c2d3a 100644 --- a/examples/sycl/README.md +++ b/examples/sycl/README.md @@ -1,6 +1,6 @@ # llama.cpp/example/sycl -This example program provide the tools for llama.cpp for SYCL on Intel GPU. +This example program provides the tools for llama.cpp for SYCL on Intel GPU. ## Tool diff --git a/flake.lock b/flake.lock index b738da7c6..c9ead0bf7 100644 --- a/flake.lock +++ b/flake.lock @@ -5,11 +5,11 @@ "nixpkgs-lib": "nixpkgs-lib" }, "locked": { - "lastModified": 1712014858, - "narHash": "sha256-sB4SWl2lX95bExY2gMFG5HIzvva5AVMJd4Igm+GpZNw=", + "lastModified": 1714641030, + "narHash": "sha256-yzcRNDoyVP7+SCNX0wmuDju1NUCt8Dz9+lyUXEI0dbI=", "owner": "hercules-ci", "repo": "flake-parts", - "rev": "9126214d0a59633752a136528f5f3b9aa8565b7d", + "rev": "e5d10a24b66c3ea8f150e47dfdb0416ab7c3390e", "type": "github" }, "original": { @@ -20,11 +20,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1714076141, - "narHash": "sha256-Drmja/f5MRHZCskS6mvzFqxEaZMeciScCTFxWVLqWEY=", + "lastModified": 1714635257, + "narHash": "sha256-4cPymbty65RvF1DWQfc+Bc8B233A1BWxJnNULJKQ1EY=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "7bb2ccd8cdc44c91edba16c48d2c8f331fb3d856", + "rev": "63c3a29ca82437c87573e4c6919b09a24ea61b0f", "type": "github" }, "original": { @@ -36,20 +36,14 @@ }, "nixpkgs-lib": { "locked": { - "dir": "lib", - "lastModified": 1711703276, - "narHash": "sha256-iMUFArF0WCatKK6RzfUJknjem0H9m4KgorO/p3Dopkk=", - "owner": "NixOS", - "repo": "nixpkgs", - "rev": "d8fe5e6c92d0d190646fb9f1056741a229980089", - "type": "github" + "lastModified": 1714640452, + "narHash": "sha256-QBx10+k6JWz6u7VsohfSw8g8hjdBZEf8CFzXH1/1Z94=", + "type": "tarball", + "url": "https://github.com/NixOS/nixpkgs/archive/50eb7ecf4cd0a5756d7275c8ba36790e5bd53e33.tar.gz" }, "original": { - "dir": "lib", - "owner": "NixOS", - "ref": "nixos-unstable", - "repo": "nixpkgs", - "type": "github" + "type": "tarball", + "url": "https://github.com/NixOS/nixpkgs/archive/50eb7ecf4cd0a5756d7275c8ba36790e5bd53e33.tar.gz" } }, "root": { diff --git a/ggml-cuda.cu b/ggml-cuda.cu index c30554f0c..2d1742c82 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -113,7 +113,7 @@ static ggml_cuda_device_info ggml_cuda_init() { for (int id = 0; id < info.device_count; ++id) { int device_vmm = 0; -#if !defined(GGML_USE_HIPBLAS) +#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) CUdevice device; CU_CHECK(cuDeviceGet(&device, id)); CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device)); @@ -259,7 +259,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { }; // pool with virtual memory -#if !defined(GGML_USE_HIPBLAS) +#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) struct ggml_cuda_pool_vmm : public ggml_cuda_pool { static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB @@ -356,7 +356,7 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { #endif // !defined(GGML_USE_HIPBLAS) std::unique_ptr ggml_backend_cuda_context::new_pool_for_device(int device) { -#if !defined(GGML_USE_HIPBLAS) +#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) if (ggml_cuda_info().devices[device].vmm) { return std::unique_ptr(new ggml_cuda_pool_vmm(device)); } diff --git a/ggml-impl.h b/ggml-impl.h index 94a1cc668..d85b152bf 100644 --- a/ggml-impl.h +++ b/ggml-impl.h @@ -17,6 +17,83 @@ #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b)) +/** + * Converts brain16 to float32. + * + * The bfloat16 floating point format has the following structure: + * + * ┌sign + * │ + * │ ┌exponent + * │ │ + * │ │ ┌mantissa + * │ │ │ + * │┌──┴───┐┌─┴───┐ + * 0b0000000000000000 brain16 + * + * Since bf16 has the same number of exponent bits as a 32bit float, + * encoding and decoding numbers becomes relatively straightforward. + * + * ┌sign + * │ + * │ ┌exponent + * │ │ + * │ │ ┌mantissa + * │ │ │ + * │┌──┴───┐┌─┴───────────────────┐ + * 0b00000000000000000000000000000000 IEEE binary32 + * + * For comparison, the standard fp16 format has fewer exponent bits. + * + * ┌sign + * │ + * │ ┌exponent + * │ │ + * │ │ ┌mantissa + * │ │ │ + * │┌─┴─┐┌─┴──────┐ + * 0b0000000000000000 IEEE binary16 + * + * @see IEEE 754-2008 + */ +static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) { + union { + float f; + uint32_t i; + } u; + u.i = (uint32_t)h.bits << 16; + return u.f; +} + +/** + * Converts float32 to brain16. + * + * This function is binary identical to AMD Zen4 VCVTNEPS2BF16. + * Subnormals shall be flushed to zero, and NANs will be quiet. + * This code should vectorize nicely if using modern compilers. + */ +static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) { + ggml_bf16_t h; + union { + float f; + uint32_t i; + } u; + u.f = s; + if ((u.i & 0x7fffffff) > 0x7f800000) { /* nan */ + h.bits = (u.i >> 16) | 64; /* force to quiet */ + return h; + } + if (!(u.i & 0x7f800000)) { /* subnormal */ + h.bits = (u.i & 0x80000000) >> 16; /* flush to zero */ + return h; + } + h.bits = (u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16; + return h; +} + +#define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x) +#define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x) + #ifdef __cplusplus extern "C" { #endif diff --git a/ggml-metal.m b/ggml-metal.m index 017b72ce9..78cac5041 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -803,7 +803,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const case GGML_OP_DIAG_MASK_INF: case GGML_OP_GET_ROWS: { - return op->ne[3] == 1; + return op->src[0]->type != GGML_TYPE_BF16 && op->ne[3] == 1; } default: return false; diff --git a/ggml-metal.metal b/ggml-metal.metal index 3d4276ae0..46c7d5039 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2175,7 +2175,7 @@ kernel void kernel_flash_attn_ext_f16( const short D4 = D/4; const short D8 = D/8; - const short Q8 = Q/8; + //const short Q8 = Q/8; const short NW = N_SIMDWIDTH; const short SH = (C + Q); // shared memory per simdgroup in (half) diff --git a/ggml-quants.c b/ggml-quants.c index 444d1e55e..9883b6f8c 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -12450,6 +12450,24 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte const size_t nb = nbytes/ggml_type_size(type); switch (type) { + case GGML_TYPE_BF16: + { + int nans = 0; + int infs = 0; + const unsigned short * f = (const unsigned short *) data; + for (size_t i = 0; i < nb; ++i) { + nans += (f[i] & 0x7fff) > 0x7f80; + infs += (f[i] & 0x7fff) == 0x7f80; + } + if (nans) { + fprintf(stderr, "%s: found %d NaNs in row of %zu BF16 values\n", __func__, nans, nb); + return false; + } + if (infs) { + fprintf(stderr, "%s: found %d infinities in row of %zu BF16 values\n", __func__, infs, nb); + return false; + } + } break; case GGML_TYPE_F16: { const ggml_fp16_t * f = (const ggml_fp16_t *) data; diff --git a/ggml.c b/ggml.c index 82179a125..093d38d00 100644 --- a/ggml.c +++ b/ggml.c @@ -322,7 +322,7 @@ static ggml_fp16_t ggml_table_exp_f16[1 << 16]; // precomputed f32 table for f16 (256 KB) (ggml-impl.h) float ggml_table_f32_f16[1 << 16]; -const char * ggml_status_to_string(enum ggml_status status) { +GGML_CALL const char * ggml_status_to_string(enum ggml_status status) { switch (status) { case GGML_STATUS_ALLOC_FAILED: return "GGML status: error (failed to allocate memory)"; case GGML_STATUS_FAILED: return "GGML status: error (operation failed)"; @@ -333,16 +333,26 @@ const char * ggml_status_to_string(enum ggml_status status) { return "GGML status: unknown"; } -// note: do not use these inside ggml.c -// these are meant to be used via the ggml.h API float ggml_fp16_to_fp32(ggml_fp16_t x) { +#define ggml_fp16_to_fp32 do_not_use__ggml_fp16_to_fp32__in_ggml return GGML_FP16_TO_FP32(x); } ggml_fp16_t ggml_fp32_to_fp16(float x) { +#define ggml_fp32_to_fp16 do_not_use__ggml_fp32_to_fp16__in_ggml return GGML_FP32_TO_FP16(x); } +float ggml_bf16_to_fp32(ggml_bf16_t x) { +#define ggml_bf16_to_fp32 do_not_use__ggml_bf16_to_fp32__in_ggml + return GGML_BF16_TO_FP32(x); // it just left shifts +} + +ggml_bf16_t ggml_fp32_to_bf16(float x) { +#define ggml_fp32_to_bf16 do_not_use__ggml_fp32_to_bf16__in_ggml + return GGML_FP32_TO_BF16(x); +} + void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n) { for (int64_t i = 0; i < n; i++) { y[i] = GGML_FP16_TO_FP32(x[i]); @@ -368,6 +378,49 @@ void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n) { } } +void ggml_bf16_to_fp32_row(const ggml_bf16_t * x, float * y, int64_t n) { + int64_t i = 0; +#if defined(__AVX512F__) + for (; i + 16 <= n; i += 16) { + _mm512_storeu_ps(y + i, + _mm512_castsi512_ps( + _mm512_slli_epi32( + _mm512_cvtepu16_epi32( + _mm256_loadu_si256( + (const __m256i *)(x + i))), + 16))); + } +#elif defined(__AVX2__) + for (; i + 8 <= n; i += 8) { + _mm256_storeu_ps(y + i, + _mm256_castsi256_ps( + _mm256_slli_epi32( + _mm256_cvtepu16_epi32( + _mm_loadu_si128( + (const __m128i *)(x + i))), + 16))); + } +#endif + for (; i < n; i++) { + y[i] = GGML_BF16_TO_FP32(x[i]); + } +} + +void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) { + int i = 0; +#if defined(__AVX512BF16__) + for (; i + 32 <= n; i += 32) { + _mm512_storeu_ps( + (__m512 *)(y + i), + (__m512)_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16), + _mm512_loadu_ps(x + i))); + } +#endif + for (; i < n; i++) { + y[i] = GGML_FP32_TO_BF16(x[i]); + } +} + bool ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b) { return memcmp(guid_a, guid_b, sizeof(ggml_guid)) == 0; } @@ -503,6 +556,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc); static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc); +static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc); static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { [GGML_TYPE_I8] = { @@ -845,6 +899,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_size = sizeof(block_q8_K), .is_quantized = true, .from_float = quantize_row_q8_K, + }, + [GGML_TYPE_BF16] = { + .type_name = "bf16", + .blck_size = 1, + .type_size = sizeof(ggml_bf16_t), + .is_quantized = false, + .to_float = (ggml_to_float_t) ggml_bf16_to_fp32_row, + .from_float = (ggml_from_float_t) ggml_fp32_to_bf16_row, + .from_float_reference = (ggml_from_float_t) ggml_fp32_to_bf16_row, + .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16, + .vec_dot_type = GGML_TYPE_BF16, + .nrows = 1, } }; @@ -1480,6 +1546,8 @@ inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } +inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } + inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; } inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; } inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; } @@ -1498,7 +1566,7 @@ static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * UNUSED(by); UNUSED(bs); -#ifdef GGML_SIMD +#if defined(GGML_SIMD) float sumf = 0.0f; const int np = (n & ~(GGML_F32_STEP - 1)); @@ -1534,6 +1602,70 @@ static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * *s = sumf; } +static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + int i = 0; + ggml_float sumf = 0; + +#if defined(__AVX512BF16__) + __m512 c1 = _mm512_setzero_ps(); + __m512 c2 = _mm512_setzero_ps(); + for (; i + 64 <= n; i += 64) { + c1 = _mm512_dpbf16_ps(c1, (__m512bh)_mm512_loadu_ps((const float *)(x + i)), + (__m512bh)_mm512_loadu_ps((const float *)(y + i))); + c2 = _mm512_dpbf16_ps(c2, (__m512bh)_mm512_loadu_ps((const float *)(x + i + 32)), + (__m512bh)_mm512_loadu_ps((const float *)(y + i + 32))); + } + sumf += (ggml_float)_mm512_reduce_add_ps(c1); + sumf += (ggml_float)_mm512_reduce_add_ps(c2); + +#elif defined(__AVX512F__) +#define LOAD(p) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(p))), 16)) + __m512 c1 = _mm512_setzero_ps(); + __m512 c2 = _mm512_setzero_ps(); + for (; i + 32 <= n; i += 32) { + c1 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i), LOAD(y + i)), c1); + c2 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c2); + } + sumf += (ggml_float)_mm512_reduce_add_ps(c1); + sumf += (ggml_float)_mm512_reduce_add_ps(c2); + +#undef LOAD +#elif defined(__AVX2__) +#define LOAD(p) _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16)) + __m256 c1 = _mm256_setzero_ps(); + __m256 c2 = _mm256_setzero_ps(); + __m256 c3 = _mm256_setzero_ps(); + __m256 c4 = _mm256_setzero_ps(); + for (; i + 32 <= n; i += 32) { + c1 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i), LOAD(y + i)), c1); + c2 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 8), LOAD(y + i + 8)), c2); + c3 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c3); + c4 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 24), LOAD(y + i + 24)), c4); + } + __m128 g; + c1 = _mm256_add_ps(_mm256_add_ps(c1, c3), + _mm256_add_ps(c2, c4)); + g = _mm_add_ps(_mm256_extractf128_ps(c1, 1), + _mm256_castps256_ps128(c1)); + g = _mm_add_ps(g, _mm_movehl_ps(g, g)); + g = _mm_add_ss(g, _mm_movehdup_ps(g)); + sumf += (ggml_float)_mm_cvtss_f32(g); + +#undef LOAD +#endif + + for (; i < n; ++i) { + sumf += (ggml_float)(GGML_BF16_TO_FP32(x[i]) * + GGML_BF16_TO_FP32(y[i])); + } + *s = sumf; +} + static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); @@ -1967,6 +2099,14 @@ inline static void ggml_vec_sum_f16_ggf(const int n, float * s, const ggml_fp16_ *s = sum; } +inline static void ggml_vec_sum_bf16_ggf(const int n, float * s, const ggml_bf16_t * x) { + float sum = 0.0f; + for (int i = 0; i < n; ++i) { + sum += GGML_BF16_TO_FP32(x[i]); + } + *s = sum; +} + inline static void ggml_vec_max_f32(const int n, float * s, const float * x) { #ifndef GGML_USE_ACCELERATE float max = -INFINITY; @@ -2377,7 +2517,7 @@ void ggml_numa_init(enum ggml_numa_strategy numa_flag) { // figure out which node we're on uint current_cpu; int getcpu_ret = 0; -#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 28) +#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 28) || defined(__COSMOPOLITAN__) getcpu_ret = getcpu(¤t_cpu, &g_state.numa.current_node); #else // old glibc doesn't have a wrapper for this call. Fall back on direct syscall @@ -2588,6 +2728,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { switch (ftype) { case GGML_FTYPE_ALL_F32: wtype = GGML_TYPE_F32; break; case GGML_FTYPE_MOSTLY_F16: wtype = GGML_TYPE_F16; break; + case GGML_FTYPE_MOSTLY_BF16: wtype = GGML_TYPE_BF16; break; case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break; case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break; case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break; @@ -2729,15 +2870,16 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { { const uint64_t t_start = ggml_time_us(); UNUSED(t_start); - ggml_fp16_t ii; for (int i = 0; i < (1 << 16); ++i) { - uint16_t ui = i; - memcpy(&ii, &ui, sizeof(ii)); - const float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii); + union { + uint16_t u16; + ggml_fp16_t fp16; + } u = {i}; + float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16); ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f)); ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f)); ggml_table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f)); - ggml_table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f)); + ggml_table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f)); } const uint64_t t_end = ggml_time_us(); UNUSED(t_end); @@ -3201,6 +3343,13 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) { ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value)); } } break; + case GGML_TYPE_BF16: + { + assert(tensor->nb[0] == sizeof(ggml_fp16_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value)); + } + } break; case GGML_TYPE_F32: { assert(tensor->nb[0] == sizeof(float)); @@ -3253,6 +3402,13 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) { ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value)); } } break; + case GGML_TYPE_BF16: + { + assert(tensor->nb[0] == sizeof(ggml_bf16_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value)); + } + } break; case GGML_TYPE_F32: { assert(tensor->nb[0] == sizeof(float)); @@ -3320,6 +3476,11 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) { GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]); } + case GGML_TYPE_BF16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t)); + return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]); + } case GGML_TYPE_F32: { GGML_ASSERT(tensor->nb[0] == sizeof(float)); @@ -3362,6 +3523,11 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) { GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value); } break; + case GGML_TYPE_BF16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t)); + ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value); + } break; case GGML_TYPE_F32: { GGML_ASSERT(tensor->nb[0] == sizeof(float)); @@ -3385,6 +3551,8 @@ int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i return ((int32_t *) data)[0]; case GGML_TYPE_F16: return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]); + case GGML_TYPE_BF16: + return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]); case GGML_TYPE_F32: return ((float *) data)[0]; default: @@ -3413,6 +3581,10 @@ void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, { ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value); } break; + case GGML_TYPE_BF16: + { + ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value); + } break; case GGML_TYPE_F32: { ((float *)(data))[0] = value; @@ -3451,6 +3623,11 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) { GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]); } + case GGML_TYPE_BF16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t)); + return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]); + } case GGML_TYPE_F32: { GGML_ASSERT(tensor->nb[0] == sizeof(float)); @@ -3493,6 +3670,11 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) { GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value); } break; + case GGML_TYPE_BF16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t)); + ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value); + } break; case GGML_TYPE_F32: { GGML_ASSERT(tensor->nb[0] == sizeof(float)); @@ -3516,6 +3698,8 @@ float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, return ((int32_t *) data)[0]; case GGML_TYPE_F16: return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]); + case GGML_TYPE_BF16: + return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]); case GGML_TYPE_F32: return ((float *) data)[0]; default: @@ -3544,6 +3728,10 @@ void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, { ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value); } break; + case GGML_TYPE_BF16: + { + ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value); + } break; case GGML_TYPE_F32: { ((float *)(data))[0] = value; @@ -3738,7 +3926,11 @@ static struct ggml_tensor * ggml_add_cast_impl( // TODO: support less-strict constraint // GGML_ASSERT(ggml_can_repeat(b, a)); GGML_ASSERT(ggml_can_repeat_rows(b, a)); - GGML_ASSERT(ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16); // currently only supported for quantized input and f16 + + // currently only supported for quantized input and f16 + GGML_ASSERT(ggml_is_quantized(a->type) || + a->type == GGML_TYPE_F16 || + a->type == GGML_TYPE_BF16); bool is_node = false; @@ -7215,8 +7407,8 @@ static void ggml_compute_forward_dup_same_cont( ((char *) src0->data + ie0*nb00), (ie1 - ie0) * ggml_type_size(src0->type)); } - } + static void ggml_compute_forward_dup_f16( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -7490,6 +7682,366 @@ static void ggml_compute_forward_dup_f16( } } +static void ggml_compute_forward_dup_bf16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); + + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + return; + } + + GGML_TENSOR_UNARY_OP_LOCALS + + const int ith = params->ith; // thread index + const int nth = params->nth; // number of threads + + if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) { + ggml_compute_forward_dup_same_cont(params, dst); + return; + } + + // parallelize by rows + const int nr = ne01; + // number of rows per thread + const int dr = (nr + nth - 1) / nth; + // row range for this thread + const int ir0 = dr * ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (src0->type == dst->type && + ne00 == ne0 && + nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) { + // copy by rows + const size_t rs = ne00*nb00; + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ir0; i01 < ir1; i01++) { + memcpy( + ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), + ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), + rs); + } + } + } + return; + } + + // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy + + if (ggml_is_contiguous(dst)) { + if (nb00 == sizeof(ggml_bf16_t)) { + if (dst->type == GGML_TYPE_BF16) { + size_t id = 0; + const size_t rs = ne00 * nb00; + char * dst_ptr = (char *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + memcpy(dst_ptr + id, src0_ptr, rs); + id += rs; + } + id += rs * (ne01 - ir1); + } + } + } else if (dst->type == GGML_TYPE_F16) { + size_t id = 0; + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + for (int i00 = 0; i00 < ne00; i00++) { + dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00])); + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else if (dst->type == GGML_TYPE_F32) { + size_t id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + for (int i00 = 0; i00 < ne00; i00++) { + dst_ptr[id] = GGML_BF16_TO_FP32(src0_ptr[i00]); + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else if (type_traits[dst->type].from_float) { + ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float; + float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; + + size_t id = 0; + size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type)); + char * dst_ptr = (char *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + + for (int i00 = 0; i00 < ne00; i00++) { + src0_f32[i00] = GGML_BF16_TO_FP32(src0_ptr[i00]); + } + + quantize_row_q(src0_f32, dst_ptr + id, ne00); + id += rs; + } + id += rs * (ne01 - ir1); + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } + } else { + //printf("%s: this is not optimal - fix me\n", __func__); + + if (dst->type == GGML_TYPE_F32) { + size_t id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = GGML_BF16_TO_FP32(*src0_ptr); + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else if (dst->type == GGML_TYPE_BF16) { + size_t id = 0; + ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = *src0_ptr; + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else if (dst->type == GGML_TYPE_F16) { + size_t id = 0; + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr)); + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } + } + return; + } + + // dst counters + int64_t i10 = 0; + int64_t i11 = 0; + int64_t i12 = 0; + int64_t i13 = 0; + + if (dst->type == GGML_TYPE_BF16) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + memcpy(dst_ptr, src0_ptr, sizeof(ggml_bf16_t)); + + if (++i10 == ne00) { + i10 = 0; + if (++i11 == ne01) { + i11 = 0; + if (++i12 == ne02) { + i12 = 0; + if (++i13 == ne03) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else if (dst->type == GGML_TYPE_F16) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr)); + + if (++i10 == ne0) { + i10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else if (dst->type == GGML_TYPE_F32) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + *(float *) dst_ptr = GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr); + + if (++i10 == ne0) { + i10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } +} + static void ggml_compute_forward_dup_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -7617,6 +8169,24 @@ static void ggml_compute_forward_dup_f32( id += ne00 * (ne01 - ir1); } } + } else if (dst->type == GGML_TYPE_BF16) { + size_t id = 0; + ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = GGML_FP32_TO_BF16(*src0_ptr); + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } } else { GGML_ASSERT(false); // TODO: implement } @@ -7736,6 +8306,58 @@ static void ggml_compute_forward_dup_f32( } } } + } else if (dst->type == GGML_TYPE_BF16) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + *(ggml_bf16_t *) dst_ptr = GGML_FP32_TO_BF16(*(const float *) src0_ptr); + + if (++i10 == ne0) { + i10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } } else { GGML_ASSERT(false); // TODO: implement } @@ -7909,6 +8531,10 @@ static void ggml_compute_forward_dup( { ggml_compute_forward_dup_f16(params, dst); } break; + case GGML_TYPE_BF16: + { + ggml_compute_forward_dup_bf16(params, dst); + } break; case GGML_TYPE_F32: { ggml_compute_forward_dup_f32(params, dst); @@ -8091,6 +8717,85 @@ static void ggml_compute_forward_add_f16_f32( } } +static void ggml_compute_forward_add_bf16_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(src0->type == GGML_TYPE_BF16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + if (dst->type == GGML_TYPE_F32) { + GGML_ASSERT( nb0 == sizeof(float)); + } + else { + GGML_ASSERT(dst->type == GGML_TYPE_BF16); + GGML_ASSERT( nb0 == sizeof(ggml_bf16_t)); + } + + GGML_ASSERT(nb00 == sizeof(ggml_bf16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (nb10 == sizeof(float)) { + if (dst->type == GGML_TYPE_BF16) { + for (int ir = ir0; ir < ir1; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); + + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i]); + } + } + } else { + for (int ir = ir0; ir < ir1; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); + + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i]; + } + } + } + } + else { + // src1 is not contiguous + GGML_ASSERT(false); + } +} + static void ggml_compute_forward_add_f16_f16( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -8147,6 +8852,62 @@ static void ggml_compute_forward_add_f16_f16( } } +static void ggml_compute_forward_add_bf16_bf16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(src0->type == GGML_TYPE_BF16); + GGML_ASSERT(src1->type == GGML_TYPE_BF16); + GGML_ASSERT(dst->type == GGML_TYPE_BF16); + + GGML_ASSERT( nb0 == sizeof(ggml_bf16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_bf16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (nb10 == sizeof(ggml_bf16_t)) { + for (int ir = ir0; ir < ir1; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + ggml_bf16_t * src1_ptr = (ggml_bf16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); + + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + GGML_BF16_TO_FP32(src1_ptr[i])); + } + } + } + else { + // src1 is not contiguous + GGML_ASSERT(false); + } +} + static void ggml_compute_forward_add_q_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -8256,6 +9017,18 @@ static void ggml_compute_forward_add( GGML_ASSERT(false); } } break; + case GGML_TYPE_BF16: + { + if (src1->type == GGML_TYPE_BF16) { + ggml_compute_forward_add_bf16_bf16(params, dst); + } + else if (src1->type == GGML_TYPE_F32) { + ggml_compute_forward_add_bf16_f32(params, dst); + } + else { + GGML_ASSERT(false); + } + } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -8514,6 +9287,110 @@ static void ggml_compute_forward_add1_q_f32( } } +static void ggml_compute_forward_add1_bf16_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_scalar(src1)); + + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + return; + } + + // scalar to add + const float v = *(float *) src1->data; + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(src0->type == GGML_TYPE_BF16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_BF16); + + GGML_ASSERT( nb0 == sizeof(ggml_bf16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_bf16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); + ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v); + } + } +} + +static void ggml_compute_forward_add1_bf16_bf16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_scalar(src1)); + + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + return; + } + + // scalar to add + const float v = GGML_BF16_TO_FP32(*(ggml_bf16_t *) src1->data); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(src0->type == GGML_TYPE_BF16); + GGML_ASSERT(src1->type == GGML_TYPE_BF16); + GGML_ASSERT(dst->type == GGML_TYPE_BF16); + + GGML_ASSERT( nb0 == sizeof(ggml_bf16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_bf16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); + ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v); + } + } +} + static void ggml_compute_forward_add1( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -8538,6 +9415,18 @@ static void ggml_compute_forward_add1( GGML_ASSERT(false); } } break; + case GGML_TYPE_BF16: + { + if (src1->type == GGML_TYPE_BF16) { + ggml_compute_forward_add1_bf16_bf16(params, dst); + } + else if (src1->type == GGML_TYPE_F32) { + ggml_compute_forward_add1_bf16_f32(params, dst); + } + else { + GGML_ASSERT(false); + } + } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -8666,6 +9555,7 @@ static void ggml_compute_forward_acc( ggml_compute_forward_acc_f32(params, dst); } break; case GGML_TYPE_F16: + case GGML_TYPE_BF16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -9187,6 +10077,40 @@ static void ggml_compute_forward_sum_f16( ((ggml_fp16_t *) dst->data)[0] = GGML_FP32_TO_FP16(sum); } +static void ggml_compute_forward_sum_bf16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + assert(params->ith == 0); + assert(ggml_is_scalar(dst)); + + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + return; + } + + assert(src0->nb[0] == sizeof(ggml_bf16_t)); + + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + + float sum = 0; + float row_sum = 0; + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + ggml_vec_sum_bf16_ggf(ne00, + &row_sum, + (ggml_bf16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03)); + sum += row_sum; + } + } + } + ((ggml_bf16_t *) dst->data)[0] = GGML_FP32_TO_BF16(sum); +} + static void ggml_compute_forward_sum( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -9202,6 +10126,10 @@ static void ggml_compute_forward_sum( { ggml_compute_forward_sum_f16(params, dst); } break; + case GGML_TYPE_BF16: + { + ggml_compute_forward_sum_bf16(params, dst); + } break; default: { GGML_ASSERT(false); @@ -9476,6 +10404,7 @@ static void ggml_compute_forward_repeat( switch (src0->type) { case GGML_TYPE_F16: + case GGML_TYPE_BF16: case GGML_TYPE_I16: { ggml_compute_forward_repeat_f16(params, dst); @@ -11793,6 +12722,7 @@ static void ggml_compute_forward_set( ggml_compute_forward_set_f32(params, dst); } break; case GGML_TYPE_F16: + case GGML_TYPE_BF16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -11967,6 +12897,49 @@ static void ggml_compute_forward_get_rows_f16( } } +static void ggml_compute_forward_get_rows_bf16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + return; + } + + GGML_TENSOR_BINARY_OP_LOCALS + + const int64_t nc = ne00; + const int64_t nr = ggml_nelements(src1); + + assert(ne0 == nc); + assert(ne02 == ne11); + assert(nb00 == sizeof(ggml_bf16_t)); + assert(ggml_nrows(dst) == nr); + + const int ith = params->ith; + const int nth = params->nth; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int64_t i = ir0; i < ir1; ++i) { + const int64_t i12 = i/(ne11*ne10); + const int64_t i11 = (i - i12*ne11*ne10)/ne10; + const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); + const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); + + ggml_bf16_to_fp32_row( + (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), + (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); + } +} + static void ggml_compute_forward_get_rows_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -12044,6 +13017,10 @@ static void ggml_compute_forward_get_rows( { ggml_compute_forward_get_rows_f16(params, dst); } break; + case GGML_TYPE_BF16: + { + ggml_compute_forward_get_rows_bf16(params, dst); + } break; case GGML_TYPE_F32: case GGML_TYPE_I32: { @@ -12739,6 +13716,7 @@ static void ggml_compute_forward_alibi( { ggml_compute_forward_alibi_f32(params, dst); } break; + case GGML_TYPE_BF16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -12828,6 +13806,7 @@ static void ggml_compute_forward_clamp( ggml_compute_forward_clamp_f32(params, dst); } break; case GGML_TYPE_F16: + case GGML_TYPE_BF16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -15921,6 +16900,7 @@ static void ggml_compute_forward_get_rel_pos( switch (src0->type) { case GGML_TYPE_F16: + case GGML_TYPE_BF16: { ggml_compute_forward_get_rel_pos_f16(params, dst); } break; @@ -18785,7 +19765,10 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa case GGML_OP_CPY: case GGML_OP_DUP: { - if (ggml_is_quantized(node->type)) { + if (ggml_is_quantized(node->type) || + // F16 -> BF16 and BF16 -> F16 copies go through intermediate F32 + (node->src[0]->type == GGML_TYPE_F16 && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) || + (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16)) { cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks; } } break; @@ -18864,7 +19847,8 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa const int64_t ne10 = node->src[1]->ne[0]; // L const int64_t ne11 = node->src[1]->ne[1]; // Cin - if (node->src[0]->type == GGML_TYPE_F16 && + if ((node->src[0]->type == GGML_TYPE_F16 || + node->src[0]->type == GGML_TYPE_BF16) && node->src[1]->type == GGML_TYPE_F32) { cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02; cur += sizeof(ggml_fp16_t)*ne10*ne11; @@ -18900,6 +19884,9 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa } else if (node->src[1]->type == GGML_TYPE_F16) { cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1) cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2 + } else if (node->src[1]->type == GGML_TYPE_BF16) { + cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2 } } break; case GGML_OP_FLASH_ATTN_EXT: @@ -18916,6 +19903,9 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa } else if (node->src[1]->type == GGML_TYPE_F16) { cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1) cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2 + } else if (node->src[1]->type == GGML_TYPE_BF16) { + cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2 } } break; case GGML_OP_FLASH_ATTN_BACK: @@ -18929,6 +19919,9 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa } else if (node->src[1]->type == GGML_TYPE_F16) { cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1) cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 + } else if (node->src[1]->type == GGML_TYPE_BF16) { + cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 } } break; @@ -19705,7 +20698,9 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) { fprintf(fp, "%d", ggml_get_i32_1d(node, j)); } - else if (node->type == GGML_TYPE_F32 || node->type == GGML_TYPE_F16) { + else if (node->type == GGML_TYPE_F32 || + node->type == GGML_TYPE_F16 || + node->type == GGML_TYPE_BF16) { fprintf(fp, "%.1e", (double)ggml_get_f32_1d(node, j)); } else { @@ -20763,6 +21758,12 @@ size_t ggml_quantize_chunk( ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n); result = n * elemsize; } break; + case GGML_TYPE_BF16: + { + size_t elemsize = sizeof(ggml_bf16_t); + ggml_fp32_to_bf16_row(src + start, (ggml_bf16_t *)dst + start, n); + result = n * elemsize; + } break; case GGML_TYPE_F32: { size_t elemsize = sizeof(float); diff --git a/ggml.h b/ggml.h index a11795973..fe6053822 100644 --- a/ggml.h +++ b/ggml.h @@ -326,14 +326,20 @@ extern "C" { // get ggml_status name string GGML_API GGML_CALL const char * ggml_status_to_string(enum ggml_status status); + // ieee 754-2008 half-precision float16 + // todo: make this not an integral type typedef uint16_t ggml_fp16_t; + GGML_API float ggml_fp16_to_fp32(ggml_fp16_t); + GGML_API ggml_fp16_t ggml_fp32_to_fp16(float); + GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t *, float *, int64_t); + GGML_API void ggml_fp32_to_fp16_row(const float *, ggml_fp16_t *, int64_t); - // convert FP16 <-> FP32 - GGML_API float ggml_fp16_to_fp32(ggml_fp16_t x); - GGML_API ggml_fp16_t ggml_fp32_to_fp16(float x); - - GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n); - GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n); + // google brain half-precision bfloat16 + typedef struct { uint16_t bits; } ggml_bf16_t; + GGML_API ggml_bf16_t ggml_fp32_to_bf16(float); + GGML_API float ggml_bf16_to_fp32(ggml_bf16_t); // consider just doing << 16 + GGML_API void ggml_bf16_to_fp32_row(const ggml_bf16_t *, float *, int64_t); + GGML_API void ggml_fp32_to_bf16_row(const float *, ggml_bf16_t *, int64_t); struct ggml_object; struct ggml_context; @@ -370,6 +376,7 @@ extern "C" { GGML_TYPE_I64 = 27, GGML_TYPE_F64 = 28, GGML_TYPE_IQ1_M = 29, + GGML_TYPE_BF16 = 30, GGML_TYPE_COUNT, }; @@ -410,6 +417,7 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ2_S = 21, // except 1d tensors GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors + GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors }; // available tensor operations: diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 1ebdee962..5951c0bb0 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -817,6 +817,7 @@ class GGMLQuantizationType(IntEnum): I64 = 27 F64 = 28 IQ1_M = 29 + BF16 = 30 class GGUFEndian(IntEnum): @@ -888,6 +889,7 @@ GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = { GGMLQuantizationType.I64: (1, 8), GGMLQuantizationType.F64: (1, 8), GGMLQuantizationType.IQ1_M: (256, QK_K // 8 + QK_K // 16 + QK_K // 32), + GGMLQuantizationType.BF16: (1, 2), } diff --git a/grammars/README.md b/grammars/README.md index c924e8d46..2b8384d9d 100644 --- a/grammars/README.md +++ b/grammars/README.md @@ -51,7 +51,7 @@ single-line ::= [^\n]+ "\n"` ## Sequences and Alternatives -The order of symbols in a sequence matter. For example, in `"1. " move " " move "\n"`, the `"1. "` must come before the first `move`, etc. +The order of symbols in a sequence matters. For example, in `"1. " move " " move "\n"`, the `"1. "` must come before the first `move`, etc. Alternatives, denoted by `|`, give different sequences that are acceptable. For example, in `move ::= pawn | nonpawn | castle`, `move` can be a `pawn` move, a `nonpawn` move, or a `castle`. diff --git a/llama.cpp b/llama.cpp index aeb5c08df..9c72d118f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3175,6 +3175,7 @@ struct llama_model_loader { switch (type_max) { case GGML_TYPE_F32: ftype = LLAMA_FTYPE_ALL_F32; break; case GGML_TYPE_F16: ftype = LLAMA_FTYPE_MOSTLY_F16; break; + case GGML_TYPE_BF16: ftype = LLAMA_FTYPE_MOSTLY_BF16; break; case GGML_TYPE_Q4_0: ftype = LLAMA_FTYPE_MOSTLY_Q4_0; break; case GGML_TYPE_Q4_1: ftype = LLAMA_FTYPE_MOSTLY_Q4_1; break; case GGML_TYPE_Q5_0: ftype = LLAMA_FTYPE_MOSTLY_Q5_0; break; @@ -3666,6 +3667,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { switch (ftype) { case LLAMA_FTYPE_ALL_F32: return "all F32"; case LLAMA_FTYPE_MOSTLY_F16: return "F16"; + case LLAMA_FTYPE_MOSTLY_BF16: return "BF16"; case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0"; case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1"; case LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16: @@ -4389,6 +4391,15 @@ static void llm_load_vocab( } else if ( tokenizer_pre == "command-r") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_COMMAND_R; + } else if ( + tokenizer_pre == "qwen2") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_QWEN2; + } else if ( + tokenizer_pre == "olmo") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_OLMO; + } else if ( + tokenizer_pre == "dbrx") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DBRX; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } @@ -6126,6 +6137,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam || !( model.ftype == LLAMA_FTYPE_ALL_F32 || model.ftype == LLAMA_FTYPE_MOSTLY_F16 || + model.ftype == LLAMA_FTYPE_MOSTLY_BF16 || model.ftype == LLAMA_FTYPE_MOSTLY_Q4_0 || model.ftype == LLAMA_FTYPE_MOSTLY_Q4_1 ) @@ -12194,6 +12206,7 @@ struct llm_tokenizer_bpe { case LLAMA_VOCAB_TYPE_BPE: switch (vocab.type_pre) { case LLAMA_VOCAB_PRE_TYPE_LLAMA3: + case LLAMA_VOCAB_PRE_TYPE_DBRX: word_collection = unicode_regex_split(text, { // original regex from tokenizer.json //"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", @@ -12248,10 +12261,18 @@ struct llm_tokenizer_bpe { }); break; case LLAMA_VOCAB_PRE_TYPE_GPT2: + case LLAMA_VOCAB_PRE_TYPE_OLMO: word_collection = unicode_regex_split(text, { "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", }); break; + case LLAMA_VOCAB_PRE_TYPE_QWEN2: + word_collection = unicode_regex_split(text, { + // original regex from tokenizer.json + // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }); + break; default: // default regex for BPE tokenization pre-processing word_collection = unicode_regex_split(text, { @@ -14154,13 +14175,16 @@ static void llama_tensor_dequantize_internal( if (qtype.to_float == NULL) { throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(tensor->type))); } - } else if (tensor->type != GGML_TYPE_F16) { + } else if (tensor->type != GGML_TYPE_F16 && + tensor->type != GGML_TYPE_BF16) { throw std::runtime_error(format("cannot dequantize/convert tensor type %s", ggml_type_name(tensor->type))); } if (nthread < 2) { if (tensor->type == GGML_TYPE_F16) { ggml_fp16_to_fp32_row((ggml_fp16_t *)tensor->data, f32_output, nelements); + } else if (tensor->type == GGML_TYPE_BF16) { + ggml_bf16_to_fp32_row((ggml_bf16_t *)tensor->data, f32_output, nelements); } else if (ggml_is_quantized(tensor->type)) { qtype.to_float(tensor->data, f32_output, nelements); } else { @@ -14169,7 +14193,14 @@ static void llama_tensor_dequantize_internal( return; } - size_t block_size = tensor->type == GGML_TYPE_F16 ? 1 : (size_t)ggml_blck_size(tensor->type); + size_t block_size; + if (tensor->type == GGML_TYPE_F16 || + tensor->type == GGML_TYPE_BF16) { + block_size = 1; + } else { + block_size = (size_t)ggml_blck_size(tensor->type); + } + size_t block_size_bytes = ggml_type_size(tensor->type); GGML_ASSERT(nelements % block_size == 0); @@ -14188,6 +14219,8 @@ static void llama_tensor_dequantize_internal( auto compute = [qtype] (ggml_type typ, uint8_t * inbuf, float * outbuf, int nels) { if (typ == GGML_TYPE_F16) { ggml_fp16_to_fp32_row((ggml_fp16_t *)inbuf, outbuf, nels); + } else if (typ == GGML_TYPE_BF16) { + ggml_bf16_to_fp32_row((ggml_bf16_t *)inbuf, outbuf, nels); } else { qtype.to_float(inbuf, outbuf, nels); } @@ -14548,6 +14581,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_Q5_1: default_type = GGML_TYPE_Q5_1; break; case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break; case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break; + case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break; case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break; // K-quants diff --git a/llama.h b/llama.h index e2fd53ab7..0b2e708d0 100644 --- a/llama.h +++ b/llama.h @@ -81,6 +81,9 @@ extern "C" { LLAMA_VOCAB_PRE_TYPE_GPT2 = 7, LLAMA_VOCAB_PRE_TYPE_REFACT = 8, LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9, + LLAMA_VOCAB_PRE_TYPE_QWEN2 = 10, + LLAMA_VOCAB_PRE_TYPE_OLMO = 11, + LLAMA_VOCAB_PRE_TYPE_DBRX = 12, }; // note: these values should be synchronized with ggml_rope @@ -136,6 +139,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ2_M = 29, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors + LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; diff --git a/models/ggml-vocab-qwen2.gguf b/models/ggml-vocab-qwen2.gguf new file mode 100644 index 000000000..541e475bc Binary files /dev/null and b/models/ggml-vocab-qwen2.gguf differ diff --git a/models/ggml-vocab-qwen2.gguf.inp b/models/ggml-vocab-qwen2.gguf.inp new file mode 100644 index 000000000..0a89107c6 --- /dev/null +++ b/models/ggml-vocab-qwen2.gguf.inp @@ -0,0 +1,106 @@ +ied 4 ½ months +__ggml_vocab_test__ +Führer +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + + +__ggml_vocab_test__ + + + +__ggml_vocab_test__ + + + + +__ggml_vocab_test__ + + +__ggml_vocab_test__ +Hello world +__ggml_vocab_test__ + Hello world +__ggml_vocab_test__ +Hello World +__ggml_vocab_test__ + Hello World +__ggml_vocab_test__ + Hello World! +__ggml_vocab_test__ +Hello, world! +__ggml_vocab_test__ + Hello, world! +__ggml_vocab_test__ + this is 🦙.cpp +__ggml_vocab_test__ +w048 7tuijk dsdfhu +__ggml_vocab_test__ +нещо на Български +__ggml_vocab_test__ +កាន់តែពិសេសអាចខលចេញ +__ggml_vocab_test__ +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) +__ggml_vocab_test__ +Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello + Hello +__ggml_vocab_test__ + ( +__ggml_vocab_test__ + + = +__ggml_vocab_test__ +' era +__ggml_vocab_test__ +Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ +__ggml_vocab_test__ +3 +__ggml_vocab_test__ +33 +__ggml_vocab_test__ +333 +__ggml_vocab_test__ +3333 +__ggml_vocab_test__ +33333 +__ggml_vocab_test__ +333333 +__ggml_vocab_test__ +3333333 +__ggml_vocab_test__ +33333333 +__ggml_vocab_test__ +333333333 +__ggml_vocab_test__ + + + + + + + + + + + +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL +__ggml_vocab_test__ diff --git a/models/ggml-vocab-qwen2.gguf.out b/models/ggml-vocab-qwen2.gguf.out new file mode 100644 index 000000000..401a510e8 --- /dev/null +++ b/models/ggml-vocab-qwen2.gguf.out @@ -0,0 +1,43 @@ + 1122 220 19 220 26062 3951 + 37 50753 261 + + 220 + 256 + 262 + 197 + 198 + 271 + 1406 + 1572 + 9707 1879 + 21927 1879 + 9707 4337 + 21927 4337 + 21927 4337 0 + 9707 11 1879 0 + 21927 11 1879 0 + 419 374 11162 99 247 13 10821 + 86 15 19 23 220 22 83 1963 41808 11472 2940 16739 + 78762 14144 1456 13073 63471 33594 3038 133178 79012 + 146394 97529 241 44258 233 146568 44258 224 147603 20879 115 146280 44258 223 146280 147272 97529 227 147805 148301 147270 44258 223 146848 + 145836 320 8252 8 26525 114 378 235 149921 30543 320 35673 99066 97534 8 25521 227 320 3243 42365 429 702 1181 1828 3950 8 + 9707 + 21927 + 220 21927 + 256 21927 + 262 21927 + 262 21927 198 262 21927 + 320 + 198 284 + 6 11385 + 9707 11 379 64848 0 2585 525 498 26525 223 937 104100 18493 22377 99257 16 18 16 19 16 20 16 35727 21216 + 18 + 18 18 + 18 18 18 + 18 18 18 18 + 18 18 18 18 18 + 18 18 18 18 18 18 + 18 18 18 18 18 18 18 + 18 18 18 18 18 18 18 18 + 18 18 18 18 18 18 18 18 18 + 198 4710 14731 65497 7847 1572 2303 78672 10947 145836 320 8252 8 26525 114 378 235 149921 30543 320 35673 99066 97534 8 25521 227 11162 99 247 149955 220 18 220 18 18 220 18 18 18 220 18 18 18 18 220 18 18 18 18 18 220 18 18 18 18 18 18 220 18 18 18 18 18 18 18 220 18 18 18 18 18 18 18 18 220 18 13 18 220 18 496 18 220 18 1112 18 220 146394 97529 241 44258 233 146568 44258 224 147603 20879 115 146280 44258 223 146280 147272 97529 227 144534 937 104100 18493 22377 99257 16 18 16 19 16 20 16 35727 21216 55460 53237 18658 14144 1456 13073 63471 33594 3038 133178 79012 3355 4605 4605 13874 13874 73594 3014 3014 28149 17085 2928 26610 7646 358 3003 1012 364 83 813 566 594 1052 11 364 787 498 2704 30 364 44 537 2704 358 3278 1281 432 11 364 35 498 1075 1045 15243 30 1205 6 42612 264 63866 43 diff --git a/scripts/compare-llama-bench.py b/scripts/compare-llama-bench.py index 3892fd25c..fed3c1ee3 100755 --- a/scripts/compare-llama-bench.py +++ b/scripts/compare-llama-bench.py @@ -93,11 +93,14 @@ help_s = ( "specified values are averaged WITHOUT weighing by the --repetitions parameter of llama-bench." ) parser.add_argument("-s", "--show", help=help_s) +parser.add_argument("--verbose", action="store_true", help="increase output verbosity") known_args, unknown_args = parser.parse_known_args() +logging.basicConfig(level=logging.DEBUG if known_args.verbose else logging.INFO) + if unknown_args: - logger.error(f"Received unknown args: {unknown_args}.") + logger.error(f"Received unknown args: {unknown_args}.\n") parser.print_help() sys.exit(1) @@ -110,7 +113,7 @@ if input_file is None: input_file = sqlite_files[0] if input_file is None: - logger.error("Cannot find a suitable input file, please provide one.") + logger.error("Cannot find a suitable input file, please provide one.\n") parser.print_help() sys.exit(1) @@ -202,12 +205,12 @@ elif repo is not None: hexsha8_baseline = find_parent_in_data(repo.heads.master.commit) if hexsha8_baseline is None: - logger.error("No baseline was provided and did not find data for any master branch commits.") + logger.error("No baseline was provided and did not find data for any master branch commits.\n") parser.print_help() sys.exit(1) else: logger.error("No baseline was provided and the current working directory " - "is not part of a git repository from which a baseline could be inferred.") + "is not part of a git repository from which a baseline could be inferred.\n") parser.print_help() sys.exit(1) @@ -238,7 +241,7 @@ elif repo is not None: break if hexsha8_compare is None: - logger.error("No compare target was provided and did not find data for any non-master commits.") + logger.error("No compare target was provided and did not find data for any non-master commits.\n") parser.print_help() sys.exit(1) else: @@ -361,7 +364,7 @@ if "gpu_info" in show: headers = [PRETTY_NAMES[p] for p in show] headers += ["Test", f"t/s {name_baseline}", f"t/s {name_compare}", "Speedup"] -logger.info(tabulate( +print(tabulate( # noqa: NP100 table, headers=headers, floatfmt=".2f", diff --git a/sgemm.cpp b/sgemm.cpp index 4e0159804..40ba9d7e9 100644 --- a/sgemm.cpp +++ b/sgemm.cpp @@ -1,6 +1,3 @@ -// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*- -// vi: set et ft=c++ ts=4 sts=4 sw=4 fenc=utf-8 :vi -// // Copyright 2024 Mozilla Foundation // // Permission is hereby granted, free of charge, to any person obtaining @@ -585,15 +582,15 @@ class tinyBLAS_Q0_ARM { }; #endif // __ARM_FEATURE_DOTPROD -#if defined(__AVX2__) || defined(__AVX512F__) +#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) template -class tinyBLAS_Q0_AVX2 { +class tinyBLAS_Q0_AVX { public: - tinyBLAS_Q0_AVX2(int64_t k, - const TA *A, int64_t lda, - const TB *B, int64_t ldb, - TC *C, int64_t ldc, - int ith, int nth) + tinyBLAS_Q0_AVX(int64_t k, + const TA *A, int64_t lda, + const TB *B, int64_t ldb, + TC *C, int64_t ldc, + int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } @@ -728,14 +725,34 @@ class tinyBLAS_Q0_AVX2 { __m256 Cv[RN][RM] = {}; for (int64_t l = 0; l < k; ++l) for (int64_t j = 0; j < RN; ++j) - for (int64_t i = 0; i < RM; ++i) + for (int64_t i = 0; i < RM; ++i) { +#if defined(__AVX2__) + __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), + load(A + lda * (ii + i) + l)), + _mm256_sign_epi8(load(B + ldb * (jj + j) + l), + load(A + lda * (ii + i) + l))); +#else + __m128i ali0 = load0(A + lda * (ii + i) + l); + __m128i ali1 = load1(A + lda * (ii + i) + l); + __m128i blj0 = load0(B + ldb * (jj + j) + l); + __m128i blj1 = load1(B + ldb * (jj + j) + l); + + __m128i sepAA0 = _mm_sign_epi8(ali0, ali0); + __m128i sepAA1 = _mm_sign_epi8(ali1, ali1); + __m128i sepBA0 = _mm_sign_epi8(blj0, ali0); + __m128i sepBA1 = _mm_sign_epi8(blj1, ali1); + + // updot + const __m128i oneFill = _mm_set1_epi16(1); + __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0); + __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1); + __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0))); +#endif Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) * unhalf(B[ldb * (jj + j) + l].d)), - updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), - load(A + lda * (ii + i) + l)), - _mm256_sign_epi8(load(B + ldb * (jj + j) + l), - load(A + lda * (ii + i) + l))), - Cv[j][i]); + udTmp, + Cv[j][i]); + } for (int64_t j = 0; j < RN; ++j) for (int64_t i = 0; i < RM; ++i) C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); @@ -746,10 +763,28 @@ class tinyBLAS_Q0_AVX2 { return _mm256_loadu_si256((const __m256i *)b->qs); } + inline __m128i load0(const block_q8_0 *b) { + return _mm_loadu_si128((const __m128i *)b->qs); + } + + inline __m128i load1(const block_q8_0 *b) { + return _mm_loadu_si128(((const __m128i *)b->qs) + 1); + } + inline __m256i load(const block_q4_0 *b) { return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8)); } + inline __m128i load0(const block_q4_0 *b) { + const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs)); + return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8)); + } + + inline __m128i load1(const block_q4_0 *b) { + const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs)); + return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8)); + } + inline __m256 updot(__m256i u, __m256i s) { __m256i res; #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) @@ -777,7 +812,7 @@ class tinyBLAS_Q0_AVX2 { const int ith; const int nth; }; -#endif // __AVX2__ +#endif // __AVX__ } // namespace @@ -928,8 +963,8 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda case GGML_TYPE_Q8_0: { if (Btype != GGML_TYPE_Q8_0) return false; -#if defined(__AVX2__) || defined(__AVX512F__) - tinyBLAS_Q0_AVX2 tb{ +#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) + tinyBLAS_Q0_AVX tb{ k, (const block_q8_0 *)A, lda, (const block_q8_0 *)B, ldb, (float *)C, ldc, @@ -952,8 +987,8 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda case GGML_TYPE_Q4_0: { if (Btype != GGML_TYPE_Q8_0) return false; -#if defined(__AVX2__) || defined(__AVX512F__) - tinyBLAS_Q0_AVX2 tb{ +#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) + tinyBLAS_Q0_AVX tb{ k, (const block_q4_0 *)A, lda, (const block_q8_0 *)B, ldb, (float *)C, ldc, diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 208f0cf7b..d409a1d6b 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -84,6 +84,7 @@ llama_test(test-tokenizer-0 NAME test-tokenizer-0-starcoder ARGS ${CMAKE llama_test(test-tokenizer-0 NAME test-tokenizer-0-gpt-2 ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt-2.gguf) llama_test(test-tokenizer-0 NAME test-tokenizer-0-refact ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-refact.gguf) llama_test(test-tokenizer-0 NAME test-tokenizer-0-command-r ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-command-r.gguf) +llama_test(test-tokenizer-0 NAME test-tokenizer-0-qwen2 ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-qwen2.gguf) # build test-tokenizer-1-bpe target once and add many tests add_executable(test-tokenizer-1-bpe test-tokenizer-1-bpe.cpp) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index b27c1291e..41718e001 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -50,7 +50,7 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_I32) { ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float)); - } else if (ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16) { + } else if (ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16 || tensor->type == GGML_TYPE_BF16) { GGML_ASSERT(size % ggml_blck_size(tensor->type) == 0); std::vector dataq(ggml_row_size(tensor->type, size)); std::vector imatrix(tensor->ne[0], 1.0f); // dummy importance matrix @@ -92,6 +92,8 @@ static std::vector tensor_to_float(const ggml_tensor * t) { size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0/bs*t->nb[0]; if (t->type == GGML_TYPE_F16) { tv.push_back(ggml_fp16_to_fp32(*(ggml_fp16_t*)&buf[i])); + } else if (t->type == GGML_TYPE_BF16) { + tv.push_back(ggml_bf16_to_fp32(*(ggml_bf16_t*)&buf[i])); } else if (t->type == GGML_TYPE_F32) { tv.push_back(*(float *) &buf[i]); } else if (t->type == GGML_TYPE_I32) { @@ -1898,7 +1900,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op std::default_random_engine rng(0); const ggml_type all_types[] = { - GGML_TYPE_F32, GGML_TYPE_F16, + GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0,