From 8f0fad42b9589f9b11362c74ea5338c51e4c7e61 Mon Sep 17 00:00:00 2001 From: laik Date: Wed, 10 Jul 2024 19:19:10 +0800 Subject: [PATCH 01/26] py : fix extra space in convert_hf_to_gguf.py (#8407) --- convert_hf_to_gguf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 6cea73f08..ecb807262 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1356,7 +1356,7 @@ class LlamaModel(Model): def set_vocab(self): try: - self. _set_vocab_sentencepiece() + self._set_vocab_sentencepiece() except FileNotFoundError: try: self._set_vocab_llama_hf() From e4dd31ff892627665d3c53c5d1a03c0d282d9d45 Mon Sep 17 00:00:00 2001 From: RunningLeon Date: Wed, 10 Jul 2024 19:26:40 +0800 Subject: [PATCH 02/26] py : fix converter for internlm2 (#8321) * update internlm2 * remove unused file * fix lint --- convert_hf_to_gguf.py | 68 +++++++++++++++++++++++++++++++++---------- 1 file changed, 52 insertions(+), 16 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index ecb807262..ebb5ca376 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2144,6 +2144,9 @@ class InternLM2Model(Model): toktype = SentencePieceTokenTypes.UNUSED elif tokenizer.IsByte(token_id): toktype = SentencePieceTokenTypes.BYTE + # take care of ununsed raw token + if piece.startswith('[UNUSED'): + toktype = SentencePieceTokenTypes.UNKNOWN tokens.append(text) scores.append(score) @@ -2159,6 +2162,47 @@ class InternLM2Model(Model): scores.append(-1000.0) toktypes.append(SentencePieceTokenTypes.USER_DEFINED) + chat_eos_token = '<|im_end|>' + chat_eos_token_id = None + + tokenizer_config_file = self.dir_model / 'tokenizer_config.json' + if tokenizer_config_file.is_file(): + with open(tokenizer_config_file, "r", encoding="utf-8") as f: + tokenizer_config_json = json.load(f) + added_tokens_decoder = tokenizer_config_json.get("added_tokens_decoder", {}) + for token_id, foken_data in added_tokens_decoder.items(): + token_id = int(token_id) + token = foken_data["content"] + if token == chat_eos_token: + chat_eos_token_id = token_id + token = token.encode("utf-8") + if toktypes[token_id] != SentencePieceTokenTypes.UNKNOWN: + assert(tokens[token_id] == token) + tokens[token_id] = token + scores[token_id] = -1000.0 + toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED + if foken_data.get("special"): + toktypes[token_id] = SentencePieceTokenTypes.CONTROL + + tokenizer_file = self.dir_model / 'tokenizer.json' + if tokenizer_file.is_file(): + with open(tokenizer_file, "r", encoding="utf-8") as f: + tokenizer_json = json.load(f) + added_tokens = tokenizer_json.get("added_tokens", []) + for foken_data in added_tokens: + token_id = int(foken_data["id"]) + token = foken_data["content"] + if token == chat_eos_token: + chat_eos_token_id = token_id + token = token.encode("utf-8") + if toktypes[token_id] != SentencePieceTokenTypes.UNKNOWN: + assert(tokens[token_id] == token) + tokens[token_id] = token + scores[token_id] = -1000.0 + toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED + if foken_data.get("special"): + toktypes[token_id] = SentencePieceTokenTypes.CONTROL + self.gguf_writer.add_tokenizer_model("llama") self.gguf_writer.add_tokenizer_pre("default") self.gguf_writer.add_token_list(tokens) @@ -2168,28 +2212,16 @@ class InternLM2Model(Model): special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) old_eos = special_vocab.special_token_ids["eos"] - if "chat" in os.path.basename(self.dir_model.absolute()): + if chat_eos_token_id is not None: # For the chat model, we replace the eos with '<|im_end|>'. # TODO: this is a hack, should be fixed # https://github.com/ggerganov/llama.cpp/pull/6745#issuecomment-2067687048 - special_vocab.special_token_ids["eos"] = self._try_get_sft_eos(tokenizer) - logger.warning(f"Replace eos:{old_eos} with a special token:{special_vocab.special_token_ids['eos']} \ -in chat mode so that the conversation can end normally.") + special_vocab.special_token_ids["eos"] = chat_eos_token_id + logger.warning(f"Replace eos:{old_eos} with a special token:{chat_eos_token_id}" + " in chat mode so that the conversation can end normally.") special_vocab.add_to_gguf(self.gguf_writer) - def _try_get_sft_eos(self, tokenizer): - unused_145_list = tokenizer.Encode('[UNUSED_TOKEN_145]') - im_end_list = tokenizer.Encode('<|im_end|>') - eos_token = None - assert (len(unused_145_list) == 1) ^ (len(im_end_list) == 1) - if len(unused_145_list) == 1: - eos_token = unused_145_list[0] - if len(im_end_list) == 1: - eos_token = im_end_list[0] - assert eos_token - return eos_token - def _hf_permute_qk(self, weights, n_head: int, n_head_kv: int): if n_head_kv is not None and n_head != n_head_kv: n_head = n_head_kv @@ -2208,6 +2240,10 @@ in chat mode so that the conversation can end normally.") self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"]) self.gguf_writer.add_file_type(self.ftype) + if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: + if self.hparams["rope_scaling"].get("type") == "linear": + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: num_heads = self.hparams["num_attention_heads"] From a8be1e6f5995bec3b1d8474d48ce3a139553a8e1 Mon Sep 17 00:00:00 2001 From: fairydreaming <166155368+fairydreaming@users.noreply.github.com> Date: Wed, 10 Jul 2024 13:38:58 +0200 Subject: [PATCH 03/26] llama : add assert about missing llama_encode() call (#8400) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Stanisław Szymczyk --- src/llama.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/llama.cpp b/src/llama.cpp index 2b9ace285..80cc1da26 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -13200,6 +13200,8 @@ struct llm_build_context { LLM_NORM_RMS, cb, -1); cb(cur, "result_norm", -1); } else { + GGML_ASSERT(n_outputs_enc > 0 && "call llama_encode() first"); + struct ggml_tensor * embd_enc = llm_build_inp_embd_enc(); struct ggml_tensor * pos_bucket_dec = llm_build_pos_bucket(true); From 7a80710d93ee0f4e0d335d65984ae747e62c0337 Mon Sep 17 00:00:00 2001 From: Borislav Stanimirov Date: Wed, 10 Jul 2024 14:40:53 +0300 Subject: [PATCH 04/26] msvc : silence codecvt c++17 deprecation warnings (#8395) --- common/common.cpp | 4 ++++ src/unicode.cpp | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/common/common.cpp b/common/common.cpp index fc0f3b350..1e5fc30dd 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1,3 +1,7 @@ +#if defined(_MSC_VER) +#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING +#endif + #include "common.h" // Change JSON_ASSERT from assert() to GGML_ASSERT: #define JSON_ASSERT GGML_ASSERT diff --git a/src/unicode.cpp b/src/unicode.cpp index 51daa15af..e05fb9d17 100644 --- a/src/unicode.cpp +++ b/src/unicode.cpp @@ -1,3 +1,7 @@ +#if defined(_MSC_VER) +#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING +#endif + #include "unicode.h" #include "unicode-data.h" From cc61948b1f97165b46990f4c2d9699bf9bb64443 Mon Sep 17 00:00:00 2001 From: Borislav Stanimirov Date: Wed, 10 Jul 2024 14:45:44 +0300 Subject: [PATCH 05/26] llama : C++20 compatibility for u8 strings (#8408) --- src/llama.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 80cc1da26..98a242081 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -57,6 +57,12 @@ #include #endif +#if __cplusplus >= 202000L + #define LU8(x) (const char*)(u8##x) +#else + #define LU8(x) u8##x +#endif + #include #include #include @@ -21511,12 +21517,12 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|assistant|>"; } - } else if (tmpl == "minicpm" || tmpl_contains(u8"<用户>")) { + } else if (tmpl == "minicpm" || tmpl_contains(LU8("<用户>"))) { // MiniCPM-3B-OpenHermes-2.5-v2-GGUF for (auto message : chat) { std::string role(message->role); if (role == "user") { - ss << u8"<用户>"; + ss << LU8("<用户>"); ss << trim(message->content); ss << ""; } else { @@ -21532,7 +21538,7 @@ static int32_t llama_chat_apply_template_internal( } else if (role == "user") { ss << "User: " << message->content << "\n\n"; } else if (role == "assistant") { - ss << "Assistant: " << message->content << u8"<|end▁of▁sentence|>"; + ss << "Assistant: " << message->content << LU8("<|end▁of▁sentence|>"); } } if (add_ass) { From 83321c6958acf20747d288031174cc1bf8816e33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=2E=20Yusuf=20Sar=C4=B1g=C3=B6z?= Date: Wed, 10 Jul 2024 15:12:35 +0300 Subject: [PATCH 06/26] gguf-py rel pipeline (#8410) * Upd gguf-py/readme * Bump patch version for release --- gguf-py/README.md | 1 - gguf-py/pyproject.toml | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/gguf-py/README.md b/gguf-py/README.md index bc46d6e1d..9dd888f31 100644 --- a/gguf-py/README.md +++ b/gguf-py/README.md @@ -79,5 +79,4 @@ python -m twine upload dist/* ``` ## TODO -- [ ] Add tests - [ ] Include conversion scripts as command line entry points in this package. diff --git a/gguf-py/pyproject.toml b/gguf-py/pyproject.toml index 36e63ee3b..62129126b 100644 --- a/gguf-py/pyproject.toml +++ b/gguf-py/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "gguf" -version = "0.9.0" +version = "0.9.1" description = "Read and write ML models in GGUF for GGML" authors = ["GGML "] packages = [ From 0f1a39f3439825acf7e3a1663566d410be152170 Mon Sep 17 00:00:00 2001 From: Dibakar Gope Date: Wed, 10 Jul 2024 07:14:51 -0500 Subject: [PATCH 07/26] ggml : add AArch64 optimized GEMV and GEMM Q4 kernels (#5780) * Arm AArch64: optimized GEMV and GEMM kernels for q4_0_q8_0, and q8_0_q8_0 quantization * Arm AArch64: add optimized GEMV and GEMM asm kernels for q4_0_q8_0 quantization and refactor code to address llama.cpp pr#5780 suggestions * Arm AArch64: add optimized GEMV and GEMM asm kernels for q4_0_q8_0 quantization and refactor code to address llama.cpp pr#5780 suggestions * Arm AArch64: add optimized GEMV and GEMM asm kernels for q4_0_q8_0 quantization and refactor code to address llama.cpp pr#5780 suggestions * Arm AArch64: add optimized GEMV and GEMM asm kernels for q4_0_q8_0 quantization and refactor code to address llama.cpp pr#5780 suggestions * Arm AArch64: add copyright claim only to ggml-aarch64.cpp and ggml-aarch64.h files * Arm AArch64: minor code refactoring for rebase * Arm AArch64: minor code refactoring for resolving a build issue with cmake * Arm AArch64: minor code refactoring to split the Q4_0_AARC64 type into three separate types: Q4_0_4_4, Q4_0_4_8, and Q4_0_8_8 * Arm AArch64: minor code change for resolving a build issue with server-windows * retrigger checks * Arm AArch64: minor code changes for rebase * Arm AArch64: minor changes to skip the pr#7433 vec_dot code for arm cpus with SVE VL not equal to 256 bits * Arm AArch64: remove stale LLAMA_QKK_64 from CMakeLists.txt and delete build.zig * Arm AArch64: add reference scalar gemm and gemv, and avoid dynamic memory allocations during quantization for Q4_0_4_4, Q4_0_4_8, and Q4_0_8_8 * Arm AArch64: add multithreaded quantization support for the new types: Q4_0_4_4, Q4_0_4_8, and Q4_0_8_8 * Arm AArch64: minor code refactoring * Arm AArch64: simplify logic for calling gemm and gemv functions in ggml_compute_forward_mul_mat * Arm AArch64: minimize changes in ggml_compute_forward_mul_mat * Arm AArch64: minor code refactoring, and add reference scalar code to quantize routines for new quant types * Arm AArch64: minor code refactoring * Arm AArch64: minor code refactoring * Arm AArch64: minor code refactoring * rebase on the latest master commit 3fd62a6 and adapt to the new directory structure * Arm AArch64: remove a redundant comment * Arm AArch64: add pragma in ggml-aarch64.c to turn -Woverlength-strings warning off * Arm AArch64: use __aarch64__ check to guard 64-bit neon kernels * Arm AArch64: update docs/build.md README to include compile time flags for buiilding the Q4_0_4_4 quant type --- Makefile | 10 +- Package.swift | 1 + docs/build.md | 2 + examples/quantize/quantize.cpp | 3 + ggml/include/ggml.h | 17 + ggml/src/CMakeLists.txt | 1 + ggml/src/ggml-aarch64.c | 2187 ++++++++++++++++++++++++++++++++ ggml/src/ggml-aarch64.h | 39 + ggml/src/ggml-common.h | 24 + ggml/src/ggml-impl.h | 4 + ggml/src/ggml-quants.c | 122 +- ggml/src/ggml.c | 150 ++- include/llama.h | 3 + src/llama.cpp | 24 +- 14 files changed, 2534 insertions(+), 53 deletions(-) create mode 100644 ggml/src/ggml-aarch64.c create mode 100644 ggml/src/ggml-aarch64.h diff --git a/Makefile b/Makefile index 104120090..b70ebaed5 100644 --- a/Makefile +++ b/Makefile @@ -835,7 +835,8 @@ OBJ_GGML += \ ggml/src/ggml.o \ ggml/src/ggml-alloc.o \ ggml/src/ggml-backend.o \ - ggml/src/ggml-quants.o + ggml/src/ggml-quants.o \ + ggml/src/ggml-aarch64.o OBJ_LLAMA = \ src/llama.o \ @@ -969,6 +970,13 @@ ggml/src/ggml-quants.o: \ ggml/src/ggml-common.h $(CC) $(CFLAGS) -c $< -o $@ +ggml/src/ggml-aarch64.o: \ + ggml/src/ggml-aarch64.c \ + ggml/include/ggml.h \ + ggml/src/ggml-aarch64.h \ + ggml/src/ggml-common.h + $(CC) $(CFLAGS) -c $< -o $@ + ggml/src/ggml-blas.o: \ ggml/src/ggml-blas.cpp \ ggml/include/ggml-blas.h diff --git a/Package.swift b/Package.swift index 77fed86df..d40a48385 100644 --- a/Package.swift +++ b/Package.swift @@ -10,6 +10,7 @@ var sources = [ "ggml/src/ggml-alloc.c", "ggml/src/ggml-backend.c", "ggml/src/ggml-quants.c", + "ggml/src/ggml-aarch64.c", ] var resources: [Resource] = [] diff --git a/docs/build.md b/docs/build.md index bf41bfdf9..d70f72f4c 100644 --- a/docs/build.md +++ b/docs/build.md @@ -28,6 +28,7 @@ In order to build llama.cpp you have four different options. ``` - Notes: + - For `Q4_0_4_4` quantization type build, add the `GGML_NO_LLAMAFILE=1` flag. For example, use `make GGML_NO_LLAMAFILE=1`. - For faster compilation, add the `-j` argument to run multiple jobs in parallel. For example, `make -j 8` will run 8 jobs in parallel. - For faster repeated compilation, install [ccache](https://ccache.dev/). - For debug builds, run `make LLAMA_DEBUG=1` @@ -41,6 +42,7 @@ In order to build llama.cpp you have four different options. **Notes**: + - For `Q4_0_4_4` quantization type build, add the `-DGGML_LLAMAFILE=OFF` cmake option. For example, use `cmake -B build -DGGML_LLAMAFILE=OFF`. - For faster compilation, add the `-j` argument to run multiple jobs in parallel. For example, `cmake --build build --config Release -j 8` will run 8 jobs in parallel. - For faster repeated compilation, install [ccache](https://ccache.dev/). - For debug builds, there are two cases: diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 76e2052d5..1578c4afb 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -46,6 +46,9 @@ static const std::vector QUANT_OPTIONS = { { "Q5_K_M", LLAMA_FTYPE_MOSTLY_Q5_K_M, " 5.33G, +0.0569 ppl @ Llama-3-8B", }, { "Q6_K", LLAMA_FTYPE_MOSTLY_Q6_K, " 6.14G, +0.0217 ppl @ Llama-3-8B", }, { "Q8_0", LLAMA_FTYPE_MOSTLY_Q8_0, " 7.96G, +0.0026 ppl @ Llama-3-8B", }, + { "Q4_0_4_4", LLAMA_FTYPE_MOSTLY_Q4_0_4_4, " 4.34G, +0.4685 ppl @ Llama-3-8B", }, + { "Q4_0_4_8", LLAMA_FTYPE_MOSTLY_Q4_0_4_8, " 4.34G, +0.4685 ppl @ Llama-3-8B", }, + { "Q4_0_8_8", LLAMA_FTYPE_MOSTLY_Q4_0_8_8, " 4.34G, +0.4685 ppl @ Llama-3-8B", }, { "F16", LLAMA_FTYPE_MOSTLY_F16, "14.00G, +0.0020 ppl @ Mistral-7B", }, { "BF16", LLAMA_FTYPE_MOSTLY_BF16, "14.00G, -0.0050 ppl @ Mistral-7B", }, { "F32", LLAMA_FTYPE_ALL_F32, "26.00G @ 7B", }, diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index d895c9acd..1e3677537 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -383,6 +383,9 @@ extern "C" { GGML_TYPE_F64 = 28, GGML_TYPE_IQ1_M = 29, GGML_TYPE_BF16 = 30, + GGML_TYPE_Q4_0_4_4 = 31, + GGML_TYPE_Q4_0_4_8 = 32, + GGML_TYPE_Q4_0_8_8 = 33, GGML_TYPE_COUNT, }; @@ -424,6 +427,9 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors + GGML_FTYPE_MOSTLY_Q4_0_4_4 = 25, // except 1d tensors + GGML_FTYPE_MOSTLY_Q4_0_4_8 = 26, // except 1d tensors + GGML_FTYPE_MOSTLY_Q4_0_8_8 = 27, // except 1d tensors }; // available tensor operations: @@ -2406,6 +2412,12 @@ extern "C" { typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); typedef void (*ggml_vec_dot_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx, const void * GGML_RESTRICT y, size_t by, int nrc); + typedef void (*ggml_from_float_to_mat_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nr, + int64_t k, int64_t bx); + typedef void (*ggml_gemv_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, + const void * GGML_RESTRICT y, int nr, int nc); + typedef void (*ggml_gemm_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, + const void * GGML_RESTRICT y, int nr, int nc); typedef struct { const char * type_name; @@ -2418,6 +2430,11 @@ extern "C" { ggml_vec_dot_t vec_dot; enum ggml_type vec_dot_type; int64_t nrows; // number of rows to process simultaneously; + int64_t ncols; // number of columns to process simultaneously; + int64_t interleave_blcksize; // interleave elements in blocks of interleave_blcksize; + ggml_from_float_to_mat_t from_float_to_mat; + ggml_gemv_t gemv; + ggml_gemm_t gemm; } ggml_type_traits_t; GGML_API ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type); diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index c6694df67..aae5b8e9f 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -1153,6 +1153,7 @@ add_library(ggml ${GGML_SOURCES_ROCM} ${GGML_HEADERS_ROCM} ${GGML_SOURCES_BLAS} ${GGML_HEADERS_BLAS} ${GGML_SOURCES_LLAMAFILE} ${GGML_HEADERS_LLAMAFILE} + ggml-aarch64.c ggml-aarch64.h ) if (EMSCRIPTEN) diff --git a/ggml/src/ggml-aarch64.c b/ggml/src/ggml-aarch64.c new file mode 100644 index 000000000..008718634 --- /dev/null +++ b/ggml/src/ggml-aarch64.c @@ -0,0 +1,2187 @@ +// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd. +#define GGML_COMMON_IMPL_C +#include "ggml-common.h" + +#include "ggml-quants.h" +#include "ggml-impl.h" + +#include +#include +#include +#include +#include // for qsort +#include // for GGML_ASSERT + +#include "ggml-aarch64.h" + +#pragma GCC diagnostic ignored "-Woverlength-strings" + +#define UNUSED GGML_UNUSED + +// Functions to create the interleaved data layout formats + +// interleave 4 block_q4_0s in blocks of interleave_blcksize +// returns an interleaved block_q4_0x4 +// in the interleaved block_q4_0x4, place deltas for 4 block_q4_0 blocks +// first, then interleave quants from 4 block_q4_0s in blocks of interleave_blcksize +// +// - in : an array of block_q4_0 pointers +// - interleave_blcksize : the block_q4_0 quants bytes are interleaved in blocks of +// interleave_blcksize bytes +// - xor_mask : the mask to convert the nibbles in block_q4_0 quants bytes +// from bias offset form to pure sign form (this saves subtract +// operations durin unpacking) +// +static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int interleave_blcksize, unsigned int xor_mask) { + block_q4_0x4 out; + + for (int i = 0; i < 4; i++) { + out.d[i] = in[i].d; + } + + for (int i = 0; i < QK4_0 * 2; i++) { + int src_offset = (i / (4 * interleave_blcksize)) * interleave_blcksize; + int src_id = (i % (4 * interleave_blcksize)) / interleave_blcksize; + src_offset += (i % interleave_blcksize); + + out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask; + } + + return out; +} + +// interleave 8 block_q4_0s in blocks of interleave_blcksize +// returns an interleaved block_q4_0x8 +// in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks +// first, then interleave quants from 8 block_q4_0s in blocks of interleave_blcksize +static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int interleave_blcksize, unsigned int xor_mask) { + block_q4_0x8 out; + + for (int i = 0; i < 8; i++) { + out.d[i] = in[i].d; + } + + for (int i = 0; i < QK4_0 * 4; i++) { + int src_offset = (i / (8 * interleave_blcksize)) * interleave_blcksize; + int src_id = (i % (8 * interleave_blcksize)) / interleave_blcksize; + src_offset += (i % interleave_blcksize); + + out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask; + } + + return out; +} + +void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0x4 * restrict y = (block_q8_0x4 *) vy; + +#if defined(__ARM_NEON) + float32x4_t srcv[4][8]; + float id[4]; + + for (int i = 0; i < nb; i++) { + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int row_iter = 0; row_iter < 4; row_iter++) { + for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j); + for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]); + + for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]); + for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]); + for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 1.0f / d : 0.0f; + + y[i].d[row_iter] = GGML_FP32_TO_FP16(d); + } + + for (int j = 0; j < 8; j++) { + float32x4_t v = vmulq_n_f32(srcv[0][j], id[0]); + int32x4_t vi = vcvtnq_s32_f32(v); + y[i].qs[16 * j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[16 * j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[16 * j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[16 * j + 3] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[1][j], id[1]); + vi = vcvtnq_s32_f32(v); + y[i].qs[16 * j + 4] = vgetq_lane_s32(vi, 0); + y[i].qs[16 * j + 5] = vgetq_lane_s32(vi, 1); + y[i].qs[16 * j + 6] = vgetq_lane_s32(vi, 2); + y[i].qs[16 * j + 7] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[2][j], id[2]); + vi = vcvtnq_s32_f32(v); + y[i].qs[16 * j + 8] = vgetq_lane_s32(vi, 0); + y[i].qs[16 * j + 9] = vgetq_lane_s32(vi, 1); + y[i].qs[16 * j + 10] = vgetq_lane_s32(vi, 2); + y[i].qs[16 * j + 11] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[3][j], id[3]); + vi = vcvtnq_s32_f32(v); + y[i].qs[16 * j + 12] = vgetq_lane_s32(vi, 0); + y[i].qs[16 * j + 13] = vgetq_lane_s32(vi, 1); + y[i].qs[16 * j + 14] = vgetq_lane_s32(vi, 2); + y[i].qs[16 * j + 15] = vgetq_lane_s32(vi, 3); + } + } +#else + // scalar + const int interleave_blcksize = 4; + float srcv[4][QK8_0]; + float id[4]; + + for (int i = 0; i < nb; i++) { + for (int row_iter = 0; row_iter < 4; row_iter++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j]; + amax = MAX(amax, fabsf(srcv[row_iter][j])); + } + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 1.0f / d : 0.0f; + + y[i].d[row_iter] = GGML_FP32_TO_FP16(d); + } + + for (int j = 0; j < QK8_0 * 4; j++) { + int src_offset = (j / (4 * interleave_blcksize)) * interleave_blcksize; + int src_id = (j % (4 * interleave_blcksize)) / interleave_blcksize; + src_offset += (j % interleave_blcksize); + + float x0 = srcv[src_id][src_offset] * id[src_id]; + y[i].qs[j] = roundf(x0);; + } + } +#endif +} + +void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0x4 * restrict y = (block_q8_0x4 *) vy; + +#if defined(__ARM_NEON) + float32x4_t srcv[4][8]; + float id[4]; + + for (int i = 0; i < nb; i++) { + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int row_iter = 0; row_iter < 4; row_iter++) { + for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j); + for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]); + + for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]); + for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]); + for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 1.0f / d : 0.0f; + + y[i].d[row_iter] = GGML_FP32_TO_FP16(d); + } + + for (int j = 0; j < 4; j++) { + float32x4_t v = vmulq_n_f32(srcv[0][2 * j], id[0]); + int32x4_t vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 3] = vgetq_lane_s32(vi, 3); + v = vmulq_n_f32(srcv[0][2 * j + 1], id[0]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 4] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 5] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 6] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 7] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[1][2 * j], id[1]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 8] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 9] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 10] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 11] = vgetq_lane_s32(vi, 3); + v = vmulq_n_f32(srcv[1][2 * j + 1], id[1]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 12] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 13] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 14] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 15] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[2][2 * j], id[2]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 16] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 17] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 18] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 19] = vgetq_lane_s32(vi, 3); + v = vmulq_n_f32(srcv[2][2 * j + 1], id[2]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 20] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 21] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 22] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 23] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[3][2 * j], id[3]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 24] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 25] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 26] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 27] = vgetq_lane_s32(vi, 3); + v = vmulq_n_f32(srcv[3][2 * j + 1], id[3]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 28] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 29] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 30] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 31] = vgetq_lane_s32(vi, 3); + } + } +#else + // scalar + const int interleave_blcksize = 8; + float srcv[4][QK8_0]; + float id[4]; + + for (int i = 0; i < nb; i++) { + for (int row_iter = 0; row_iter < 4; row_iter++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j]; + amax = MAX(amax, fabsf(srcv[row_iter][j])); + } + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 1.0f / d : 0.0f; + + y[i].d[row_iter] = GGML_FP32_TO_FP16(d); + } + + for (int j = 0; j < QK8_0 * 4; j++) { + int src_offset = (j / (4 * interleave_blcksize)) * interleave_blcksize; + int src_id = (j % (4 * interleave_blcksize)) / interleave_blcksize; + src_offset += (j % interleave_blcksize); + + float x0 = srcv[src_id][src_offset] * id[src_id]; + y[i].qs[j] = roundf(x0);; + } + } +#endif +} + +void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nrow, int64_t n_per_row, int64_t interleave_blcksize) { + assert(nrow == 4); + UNUSED(nrow); + if (interleave_blcksize == 4) quantize_q8_0_4x4(x, vy, n_per_row); + else if (interleave_blcksize == 8) quantize_q8_0_4x8(x, vy, n_per_row); + else assert(false); +} + +static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, int nrows_interleaved, int interleave_blcksize) { + assert(n_per_row % QK4_0 == 0); + const int nb = n_per_row / QK4_0; + + void * out_ptr = NULL; + if (nrows_interleaved == 8) { + out_ptr = (block_q4_0x8 *) dst; + } + else if (nrows_interleaved == 4) { + out_ptr = (block_q4_0x4 *) dst; + } + assert(nrows_interleaved <= 8); + block_q4_0 dst_tmp[8]; + + for (int b = 0; b < (nrow * n_per_row); b += nrows_interleaved * n_per_row) { + + for (int64_t x = 0; x < nb; x++) { + + for (int i = 0; i < nrows_interleaved; i++ ) { + quantize_row_q4_0_reference(src + b + i * n_per_row + x * QK4_0, (block_q4_0 *) dst_tmp + i, QK4_0); + } + + if (nrows_interleaved == 8) { + *(block_q4_0x8 *) out_ptr = make_block_q4_0x8(dst_tmp, interleave_blcksize, 0x88); + out_ptr = (block_q4_0x8 *) out_ptr + 1; + } + else if (nrows_interleaved == 4) { + *(block_q4_0x4 *) out_ptr = make_block_q4_0x4(dst_tmp, interleave_blcksize, 0x88); + out_ptr = (block_q4_0x4 *) out_ptr + 1; + } + } + } + + return ((nrow * n_per_row) / QK4_0 * sizeof(block_q4_0)); +} + +size_t quantize_q4_0_4x4(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + if (!quant_weights) { + return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 4); + } + else { + assert(false); + return 0; + } +} + +size_t quantize_q4_0_4x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + if (!quant_weights) { + return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 8); + } + else { + assert(false); + return 0; + } +} + +size_t quantize_q4_0_8x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + if (!quant_weights) { + return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 8, 8); + } + else { + assert(false); + return 0; + } +} + +void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__ARM_FEATURE_SVE) + if (svcntw() == 8) { + GGML_ASSERT(!(ggml_cpu_has_sve() && (svcntw() == 8)) && + "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance"); + } +#endif +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + GGML_ASSERT(!(ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) && + "__ARM_NEON and __ARM_FEATURE_MATMUL_INT8 defined, use the Q4_0_4_8 quantization format for optimal performance"); +#elif defined(__ARM_NEON) && defined(__aarch64__) + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + + __asm__ __volatile__( + "movi v31.16b, #0x4\n" + "movi v30.16b, #0xf0\n" + "add %x[b_ptr], %x[b_ptr], #0x8\n" + "1:" // Column loop + "add x22, %x[a_ptr], #0x2\n" + "movi v29.16b, #0x0\n" + "mov x21, %x[nb]\n" + "2:" // Block loop + "ldr q28, [%x[b_ptr], #0x0]\n" + "ldr q27, [x22, #0x0]\n" + "movi v26.4s, #0x0\n" + "sub x20, x22, #0x2\n" + "ldr q25, [x22, #0x10]\n" + "ldr q24, [%x[b_ptr], #0x10]\n" + "sub x21, x21, #0x1\n" + "add x22, x22, #0x22\n" + "ldr q23, [%x[b_ptr], #0x20]\n" + "ldr q22, [%x[b_ptr], #0x30]\n" + "ld1r { v21.8h }, [x20]\n" + "ldr q20, [%x[b_ptr], #-0x8]\n" + "sshl v16.16b, v28.16b, v31.16b\n" + "and v28.16b, v28.16b, v30.16b\n" + "sshl v19.16b, v24.16b, v31.16b\n" + "and v24.16b, v24.16b, v30.16b\n" + "add %x[b_ptr], %x[b_ptr], #0x48\n" + "sshl v18.16b, v23.16b, v31.16b\n" + "and v23.16b, v23.16b, v30.16b\n" + ".inst 0x4f9be21a // sdot v26.4s, v16.16b, v27.4b[0]\n" + "sshl v17.16b, v22.16b, v31.16b\n" + "and v22.16b, v22.16b, v30.16b\n" + "fcvtl v21.4s, v21.4h\n" + "fcvtl v16.4s, v20.4h\n" + ".inst 0x4f99e39a // sdot v26.4s, v28.16b, v25.4b[0]\n" + "fmul v16.4s, v16.4s, v21.4s\n" + ".inst 0x4fbbe27a // sdot v26.4s, v19.16b, v27.4b[1]\n" + ".inst 0x4fb9e31a // sdot v26.4s, v24.16b, v25.4b[1]\n" + ".inst 0x4f9bea5a // sdot v26.4s, v18.16b, v27.4b[2]\n" + ".inst 0x4f99eafa // sdot v26.4s, v23.16b, v25.4b[2]\n" + ".inst 0x4fbbea3a // sdot v26.4s, v17.16b, v27.4b[3]\n" + ".inst 0x4fb9eada // sdot v26.4s, v22.16b, v25.4b[3]\n" + "scvtf v26.4s, v26.4s, #0x4\n" + "fmla v29.4s, v26.4s, v16.4s\n" + "cbnz x21, 2b\n" + "sub %x[nc], %x[nc], #0x4\n" + "str q29, [%x[res_ptr], #0x0]\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "cbnz %x[nc], 1b\n" + : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc) + : [a_ptr] "r" (a_ptr), [nb] "r" (nb) + : "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22" + ); +#else + float sumf[4]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; + } + sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +#endif +} + +void ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__ARM_FEATURE_SVE) + if (svcntw() == 8) { + GGML_ASSERT(!(ggml_cpu_has_sve() && (svcntw() == 8)) && + "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance"); + } +#endif +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + + __asm__ __volatile__( + "movi v2.16b, #0x4\n" + "movi v1.16b, #0xf0\n" + "add %x[b_ptr], %x[b_ptr], #0x8\n" + "1:" // Column loop + "add x23, %x[a_ptr], #0x2\n" + "movi v0.16b, #0x0\n" + "mov x22, %x[nb]\n" + "2:" // Block loop + "ldr q31, [%x[b_ptr], #0x0]\n" + "ldr q30, [%x[b_ptr], #0x10]\n" + "mov x21, x23\n" + "movi v29.4s, #0x0\n" + "ldr q28, [%x[b_ptr], #0x20]\n" + "ldr q27, [%x[b_ptr], #0x30]\n" + "movi v26.4s, #0x0\n" + "sub x20, x23, #0x2\n" + "ld1r { v25.8h }, [x20]\n" + "ldr q24, [%x[b_ptr], #-0x8]\n" + "sub x22, x22, #0x1\n" + "add x23, x23, #0x22\n" + "ld1r { v23.2d }, [x21], #0x8\n" + "sshl v22.16b, v31.16b, v2.16b\n" + "sshl v16.16b, v30.16b, v2.16b\n" + "add %x[b_ptr], %x[b_ptr], #0x48\n" + "ld1r { v21.2d }, [x21], #0x8\n" + "sshl v20.16b, v28.16b, v2.16b\n" + "sshl v19.16b, v27.16b, v2.16b\n" + "ld1r { v18.2d }, [x21], #0x8\n" + "ld1r { v17.2d }, [x21], #0x8\n" + "and v31.16b, v31.16b, v1.16b\n" + "and v30.16b, v30.16b, v1.16b\n" + ".inst 0x4e9796dd // sdot v29.4s, v22.16b, v23.16b\n" + ".inst 0x4e97961a // sdot v26.4s, v16.16b, v23.16b\n" + "and v28.16b, v28.16b, v1.16b\n" + "and v27.16b, v27.16b, v1.16b\n" + "fcvtl v25.4s, v25.4h\n" + "fcvtl v16.4s, v24.4h\n" + ".inst 0x4e95969d // sdot v29.4s, v20.16b, v21.16b\n" + ".inst 0x4e95967a // sdot v26.4s, v19.16b, v21.16b\n" + "fmul v16.4s, v16.4s, v25.4s\n" + ".inst 0x4e9297fd // sdot v29.4s, v31.16b, v18.16b\n" + ".inst 0x4e9297da // sdot v26.4s, v30.16b, v18.16b\n" + ".inst 0x4e91979d // sdot v29.4s, v28.16b, v17.16b\n" + ".inst 0x4e91977a // sdot v26.4s, v27.16b, v17.16b\n" + "addp v29.4s, v29.4s, v26.4s\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "fmla v0.4s, v29.4s, v16.4s\n" + "cbnz x22, 2b\n" + "sub %x[nc], %x[nc], #0x4\n" + "str q0, [%x[res_ptr], #0x0]\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "cbnz %x[nc], 1b\n" + : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc) + : [a_ptr] "r" (a_ptr), [nb] "r" (nb) + : "memory", "v0", "v1", "v2", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23" + ); +#elif defined(__ARM_NEON) && defined(__aarch64__) + GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) && + "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal " + "performance"); +#else + float sumf[4]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; + } + sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +#endif +} + +void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__ARM_FEATURE_SVE) + if (svcntw() == 8) { + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + + __asm__ __volatile__( + "ptrue p0.b\n" + "add %x[b_ptr], %x[b_ptr], #0x10\n" + "1:" // Column loop + "add x22, %x[a_ptr], #0x2\n" + "mov z31.b, #0x0\n" + "mov x21, %x[nb]\n" + "2:" // Block loop + "ld1b { z30.b }, p0/Z, [%x[b_ptr]]\n" + "ld1b { z29.b }, p0/Z, [%x[b_ptr], #1, MUL VL]\n" + "mov z28.s, #0x0\n" + "mov z27.s, #0x0\n" + "ld1rd { z26.d }, p0/Z, [x22]\n" + "ld1b { z25.b }, p0/Z, [%x[b_ptr], #2, MUL VL]\n" + "sub x20, x22, #0x2\n" + "sub x21, x21, #0x1\n" + "ld1b { z24.b }, p0/Z, [%x[b_ptr], #3, MUL VL]\n" + "ld1rd { z23.d }, p0/Z, [x22, #8]\n" + "lsl z22.b, z30.b, #0x4\n" + "lsl z16.b, z29.b, #0x4\n" + "and z30.b, z30.b, #0xf0\n" + "and z29.b, z29.b, #0xf0\n" + "ld1rd { z21.d }, p0/Z, [x22, #16]\n" + "ld1rd { z20.d }, p0/Z, [x22, #24]\n" + "lsl z19.b, z25.b, #0x4\n" + "and z25.b, z25.b, #0xf0\n" + "ld1rh { z17.h }, p0/Z, [x20]\n" + "ld1h { z18.s }, p0/Z, [%x[b_ptr], #-1, MUL VL]\n" + "sdot z28.s, z22.b, z26.b\n" + "sdot z27.s, z16.b, z26.b\n" + "lsl z16.b, z24.b, #0x4\n" + "add x22, x22, #0x22\n" + "and z24.b, z24.b, #0xf0\n" + "add %x[b_ptr], %x[b_ptr], #0x90\n" + "fcvt z17.s, p0/m, z17.h\n" + "fcvt z18.s, p0/m, z18.h\n" + "sdot z28.s, z19.b, z23.b\n" + "sdot z27.s, z16.b, z23.b\n" + "fmul z18.s, z18.s, z17.s\n" + "sdot z28.s, z30.b, z21.b\n" + "sdot z27.s, z29.b, z21.b\n" + "sdot z28.s, z25.b, z20.b\n" + "sdot z27.s, z24.b, z20.b\n" + "uzp1 z17.s, z28.s, z27.s\n" + "uzp2 z16.s, z28.s, z27.s\n" + "add z17.s, z17.s, z16.s\n" + "asr z17.s, z17.s, #0x4\n" + "scvtf z17.s, p0/m, z17.s\n" + "fmla z31.s, p0/M, z17.s, z18.s\n" + "cbnz x21, 2b\n" + "sub %x[nc], %x[nc], #0x8\n" + "st1w { z31.s }, p0, [%x[res_ptr]]\n" + "add %x[res_ptr], %x[res_ptr], #0x20\n" + "cbnz %x[nc], 1b\n" + : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc) + : [a_ptr] "r" (a_ptr), [nb] "r" (nb) + : "memory", "p0", "x20", "x21", "x22", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); + return; + } + else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { + GGML_ASSERT((ggml_cpu_has_sve() && (svcntw() == 8)) && + "__ARM_FEATURE_SVE for vector size of 256-bits not defined, use the Q4_0_4_8 quantization format for optimal " + "performance"); + } + else if (ggml_cpu_has_neon()) { + GGML_ASSERT(((ggml_cpu_has_sve() && (svcntw() == 8)) || ggml_cpu_has_matmul_int8()) && + "__ARM_FEATURE_SVE for vector size of 256-bits and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 " + "quantization format for optimal performance"); + } +#endif +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + GGML_ASSERT(ggml_cpu_has_sve() && + "__ARM_FEATURE_SVE not defined, use the Q4_0_4_8 quantization format for optimal performance"); +#elif defined(__ARM_NEON) && defined(__aarch64__) + GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) && + "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal " + "performance"); +#else + float sumf[8]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; + } + sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +#endif +} + +void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) + if (svcntw() == 8) { + GGML_ASSERT(!(ggml_cpu_has_sve() && (svcntw() == 8)) && + "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance"); + } +#endif +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + GGML_ASSERT(!(ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) && + "__ARM_NEON and __ARM_FEATURE_MATMUL_INT8 defined, use the Q4_0_4_8 quantization format for optimal performance"); +#elif defined(__ARM_NEON) && defined(__aarch64__) + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + size_t res_stride = bs * sizeof(float); + + __asm__ __volatile__( + "mov x10, %x[nr]\n" + "mov x9, #0x88\n" + "cmp x10, #0x10\n" + "mul x9, %x[nb], x9\n" + "blt 4f\n" + "1:" // Row loop + "add x28, %x[b_ptr], #0x8\n" + "mov x27, %x[nc]\n" + "add x26, %x[res_ptr], %x[res_stride], LSL #4\n" + "2:" // Column loop + "add x25, %x[a_ptr], #0x8\n" + "movi v15.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "mov x24, %x[nb]\n" + "add x23, x25, x9\n" + "movi v18.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "add x22, x23, x9\n" + "movi v11.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "add x21, x22, x9\n" + "movi v23.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v7.16b, #0x0\n" + "movi v0.16b, #0x0\n" + "movi v4.16b, #0x0\n" + "movi v5.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v8.16b, #0x0\n" + "movi v1.16b, #0x0\n" + "3:" // Block loop + "ldr q3, [x28, #0x0]\n" + "ldr q31, [x25, #0x0]\n" + "movi v28.16b, #0x4\n" + "movi v10.4s, #0x0\n" + "ldr q22, [x28, #0x10]\n" + "ldr q6, [x25, #0x10]\n" + "movi v29.4s, #0x0\n" + "movi v9.4s, #0x0\n" + "ldr q27, [x28, #0x20]\n" + "ldr q30, [x28, #0x30]\n" + "movi v20.4s, #0x0\n" + "movi v24.16b, #0xf0\n" + "ldr d2, [x25, #-0x8]\n" + "ldr d26, [x23, #-0x8]\n" + "sshl v12.16b, v3.16b, v28.16b\n" + "sub x20, x28, #0x8\n" + "ldr d17, [x20, #0x0]\n" + "and v3.16b, v3.16b, v24.16b\n" + "subs x24, x24, #0x1\n" + "add x28, x28, #0x48\n" + ".inst 0x4f9fe18a // sdot v10.4s, v12.16b, v31.4b[0]\n" + ".inst 0x4fbfe19d // sdot v29.4s, v12.16b, v31.4b[1]\n" + ".inst 0x4f9fe989 // sdot v9.4s, v12.16b, v31.4b[2]\n" + ".inst 0x4fbfe994 // sdot v20.4s, v12.16b, v31.4b[3]\n" + "sshl v31.16b, v22.16b, v28.16b\n" + "and v22.16b, v22.16b, v24.16b\n" + "fcvtl v17.4s, v17.4h\n" + "fcvtl v2.4s, v2.4h\n" + "fcvtl v26.4s, v26.4h\n" + ".inst 0x4f86e3ea // sdot v10.4s, v31.16b, v6.4b[0]\n" + ".inst 0x4fa6e3fd // sdot v29.4s, v31.16b, v6.4b[1]\n" + ".inst 0x4f86ebe9 // sdot v9.4s, v31.16b, v6.4b[2]\n" + ".inst 0x4fa6ebf4 // sdot v20.4s, v31.16b, v6.4b[3]\n" + "sshl v6.16b, v27.16b, v28.16b\n" + "sshl v28.16b, v30.16b, v28.16b\n" + "and v27.16b, v27.16b, v24.16b\n" + "and v30.16b, v30.16b, v24.16b\n" + "ldr q24, [x25, #0x20]\n" + ".inst 0x4f98e0ca // sdot v10.4s, v6.16b, v24.4b[0]\n" + ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n" + ".inst 0x4f98e8c9 // sdot v9.4s, v6.16b, v24.4b[2]\n" + ".inst 0x4fb8e8d4 // sdot v20.4s, v6.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x30]\n" + ".inst 0x4f98e38a // sdot v10.4s, v28.16b, v24.4b[0]\n" + ".inst 0x4fb8e39d // sdot v29.4s, v28.16b, v24.4b[1]\n" + ".inst 0x4f98eb89 // sdot v9.4s, v28.16b, v24.4b[2]\n" + ".inst 0x4fb8eb94 // sdot v20.4s, v28.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x40]\n" + ".inst 0x4f98e06a // sdot v10.4s, v3.16b, v24.4b[0]\n" + ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n" + ".inst 0x4f98e869 // sdot v9.4s, v3.16b, v24.4b[2]\n" + ".inst 0x4fb8e874 // sdot v20.4s, v3.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x50]\n" + ".inst 0x4f98e2ca // sdot v10.4s, v22.16b, v24.4b[0]\n" + ".inst 0x4fb8e2dd // sdot v29.4s, v22.16b, v24.4b[1]\n" + ".inst 0x4f98eac9 // sdot v9.4s, v22.16b, v24.4b[2]\n" + ".inst 0x4fb8ead4 // sdot v20.4s, v22.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x60]\n" + ".inst 0x4f98e36a // sdot v10.4s, v27.16b, v24.4b[0]\n" + ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n" + ".inst 0x4f98eb69 // sdot v9.4s, v27.16b, v24.4b[2]\n" + ".inst 0x4fb8eb74 // sdot v20.4s, v27.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x70]\n" + "add x25, x25, #0x88\n" + ".inst 0x4f98e3ca // sdot v10.4s, v30.16b, v24.4b[0]\n" + ".inst 0x4fb8e3dd // sdot v29.4s, v30.16b, v24.4b[1]\n" + ".inst 0x4f98ebc9 // sdot v9.4s, v30.16b, v24.4b[2]\n" + ".inst 0x4fb8ebd4 // sdot v20.4s, v30.16b, v24.4b[3]\n" + "fmul v24.4s, v17.4s, v2.s[0]\n" + "scvtf v10.4s, v10.4s, #0x4\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "scvtf v9.4s, v9.4s, #0x4\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "fmla v15.4s, v10.4s, v24.4s\n" + "ldr q24, [x23, #0x0]\n" + "fmul v10.4s, v17.4s, v2.s[1]\n" + "fmla v19.4s, v29.4s, v10.4s\n" + "ldr q10, [x23, #0x10]\n" + "fmul v29.4s, v17.4s, v2.s[2]\n" + "fmul v2.4s, v17.4s, v2.s[3]\n" + "fmla v18.4s, v9.4s, v29.4s\n" + "movi v9.4s, #0x0\n" + "movi v29.4s, #0x0\n" + ".inst 0x4f98e189 // sdot v9.4s, v12.16b, v24.4b[0]\n" + ".inst 0x4fb8e19d // sdot v29.4s, v12.16b, v24.4b[1]\n" + "fmla v14.4s, v20.4s, v2.4s\n" + "movi v20.4s, #0x0\n" + "movi v2.4s, #0x0\n" + ".inst 0x4f98e994 // sdot v20.4s, v12.16b, v24.4b[2]\n" + ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n" + "ldr q24, [x23, #0x20]\n" + ".inst 0x4f8ae3e9 // sdot v9.4s, v31.16b, v10.4b[0]\n" + ".inst 0x4faae3fd // sdot v29.4s, v31.16b, v10.4b[1]\n" + ".inst 0x4f8aebf4 // sdot v20.4s, v31.16b, v10.4b[2]\n" + ".inst 0x4faaebe2 // sdot v2.4s, v31.16b, v10.4b[3]\n" + "ldr q10, [x23, #0x30]\n" + ".inst 0x4f98e0c9 // sdot v9.4s, v6.16b, v24.4b[0]\n" + ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n" + ".inst 0x4f98e8d4 // sdot v20.4s, v6.16b, v24.4b[2]\n" + ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n" + "ldr q24, [x23, #0x40]\n" + ".inst 0x4f8ae389 // sdot v9.4s, v28.16b, v10.4b[0]\n" + ".inst 0x4faae39d // sdot v29.4s, v28.16b, v10.4b[1]\n" + ".inst 0x4f8aeb94 // sdot v20.4s, v28.16b, v10.4b[2]\n" + ".inst 0x4faaeb82 // sdot v2.4s, v28.16b, v10.4b[3]\n" + "ldr q10, [x23, #0x50]\n" + ".inst 0x4f98e069 // sdot v9.4s, v3.16b, v24.4b[0]\n" + ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n" + ".inst 0x4f98e874 // sdot v20.4s, v3.16b, v24.4b[2]\n" + ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n" + "ldr q24, [x23, #0x60]\n" + ".inst 0x4f8ae2c9 // sdot v9.4s, v22.16b, v10.4b[0]\n" + ".inst 0x4faae2dd // sdot v29.4s, v22.16b, v10.4b[1]\n" + ".inst 0x4f8aead4 // sdot v20.4s, v22.16b, v10.4b[2]\n" + ".inst 0x4faaeac2 // sdot v2.4s, v22.16b, v10.4b[3]\n" + "ldr q10, [x23, #0x70]\n" + "add x23, x23, #0x88\n" + ".inst 0x4f98e369 // sdot v9.4s, v27.16b, v24.4b[0]\n" + ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n" + ".inst 0x4f98eb74 // sdot v20.4s, v27.16b, v24.4b[2]\n" + ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n" + "ldr q24, [x22, #0x0]\n" + ".inst 0x4f8ae3c9 // sdot v9.4s, v30.16b, v10.4b[0]\n" + ".inst 0x4faae3dd // sdot v29.4s, v30.16b, v10.4b[1]\n" + ".inst 0x4f8aebd4 // sdot v20.4s, v30.16b, v10.4b[2]\n" + ".inst 0x4faaebc2 // sdot v2.4s, v30.16b, v10.4b[3]\n" + "fmul v10.4s, v17.4s, v26.s[0]\n" + "scvtf v9.4s, v9.4s, #0x4\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "scvtf v2.4s, v2.4s, #0x4\n" + "fmla v11.4s, v9.4s, v10.4s\n" + "ldr q9, [x22, #0x10]\n" + "fmul v10.4s, v17.4s, v26.s[1]\n" + "fmla v13.4s, v29.4s, v10.4s\n" + "ldr d29, [x22, #-0x8]\n" + "fmul v10.4s, v17.4s, v26.s[2]\n" + "fmul v26.4s, v17.4s, v26.s[3]\n" + "fcvtl v29.4s, v29.4h\n" + "fmla v23.4s, v20.4s, v10.4s\n" + "movi v20.4s, #0x0\n" + "movi v10.4s, #0x0\n" + "fmla v16.4s, v2.4s, v26.4s\n" + "movi v26.4s, #0x0\n" + "movi v2.4s, #0x0\n" + ".inst 0x4f98e194 // sdot v20.4s, v12.16b, v24.4b[0]\n" + ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n" + ".inst 0x4f98e99a // sdot v26.4s, v12.16b, v24.4b[2]\n" + ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n" + "ldr q24, [x22, #0x20]\n" + ".inst 0x4f89e3f4 // sdot v20.4s, v31.16b, v9.4b[0]\n" + ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n" + ".inst 0x4f89ebfa // sdot v26.4s, v31.16b, v9.4b[2]\n" + ".inst 0x4fa9ebe2 // sdot v2.4s, v31.16b, v9.4b[3]\n" + "ldr q9, [x22, #0x30]\n" + ".inst 0x4f98e0d4 // sdot v20.4s, v6.16b, v24.4b[0]\n" + ".inst 0x4fb8e0ca // sdot v10.4s, v6.16b, v24.4b[1]\n" + ".inst 0x4f98e8da // sdot v26.4s, v6.16b, v24.4b[2]\n" + ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n" + "ldr q24, [x22, #0x40]\n" + ".inst 0x4f89e394 // sdot v20.4s, v28.16b, v9.4b[0]\n" + ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n" + ".inst 0x4f89eb9a // sdot v26.4s, v28.16b, v9.4b[2]\n" + ".inst 0x4fa9eb82 // sdot v2.4s, v28.16b, v9.4b[3]\n" + "ldr q9, [x22, #0x50]\n" + ".inst 0x4f98e074 // sdot v20.4s, v3.16b, v24.4b[0]\n" + ".inst 0x4fb8e06a // sdot v10.4s, v3.16b, v24.4b[1]\n" + ".inst 0x4f98e87a // sdot v26.4s, v3.16b, v24.4b[2]\n" + ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n" + "ldr q24, [x22, #0x60]\n" + ".inst 0x4f89e2d4 // sdot v20.4s, v22.16b, v9.4b[0]\n" + ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n" + ".inst 0x4f89eada // sdot v26.4s, v22.16b, v9.4b[2]\n" + ".inst 0x4fa9eac2 // sdot v2.4s, v22.16b, v9.4b[3]\n" + "ldr q9, [x22, #0x70]\n" + "add x22, x22, #0x88\n" + ".inst 0x4f98e374 // sdot v20.4s, v27.16b, v24.4b[0]\n" + ".inst 0x4fb8e36a // sdot v10.4s, v27.16b, v24.4b[1]\n" + ".inst 0x4f98eb7a // sdot v26.4s, v27.16b, v24.4b[2]\n" + ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n" + "ldr q24, [x21, #0x0]\n" + ".inst 0x4f89e3d4 // sdot v20.4s, v30.16b, v9.4b[0]\n" + ".inst 0x4fa9e3ca // sdot v10.4s, v30.16b, v9.4b[1]\n" + ".inst 0x4f89ebda // sdot v26.4s, v30.16b, v9.4b[2]\n" + ".inst 0x4fa9ebc2 // sdot v2.4s, v30.16b, v9.4b[3]\n" + "fmul v9.4s, v17.4s, v29.s[0]\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "scvtf v10.4s, v10.4s, #0x4\n" + "scvtf v26.4s, v26.4s, #0x4\n" + "scvtf v2.4s, v2.4s, #0x4\n" + "fmla v25.4s, v20.4s, v9.4s\n" + "ldr q9, [x21, #0x10]\n" + "fmul v20.4s, v17.4s, v29.s[1]\n" + "fmla v7.4s, v10.4s, v20.4s\n" + "ldr d20, [x21, #-0x8]\n" + "fmul v10.4s, v17.4s, v29.s[2]\n" + "fmul v29.4s, v17.4s, v29.s[3]\n" + "fcvtl v20.4s, v20.4h\n" + "fmla v0.4s, v26.4s, v10.4s\n" + "movi v26.4s, #0x0\n" + "movi v10.4s, #0x0\n" + "fmla v4.4s, v2.4s, v29.4s\n" + "movi v2.4s, #0x0\n" + "movi v29.4s, #0x0\n" + ".inst 0x4f98e19a // sdot v26.4s, v12.16b, v24.4b[0]\n" + ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n" + ".inst 0x4f98e982 // sdot v2.4s, v12.16b, v24.4b[2]\n" + ".inst 0x4fb8e99d // sdot v29.4s, v12.16b, v24.4b[3]\n" + "ldr q12, [x21, #0x20]\n" + "fmul v24.4s, v17.4s, v20.s[0]\n" + ".inst 0x4f89e3fa // sdot v26.4s, v31.16b, v9.4b[0]\n" + ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n" + ".inst 0x4f89ebe2 // sdot v2.4s, v31.16b, v9.4b[2]\n" + ".inst 0x4fa9ebfd // sdot v29.4s, v31.16b, v9.4b[3]\n" + "ldr q9, [x21, #0x30]\n" + "fmul v31.4s, v17.4s, v20.s[1]\n" + ".inst 0x4f8ce0da // sdot v26.4s, v6.16b, v12.4b[0]\n" + ".inst 0x4face0ca // sdot v10.4s, v6.16b, v12.4b[1]\n" + ".inst 0x4f8ce8c2 // sdot v2.4s, v6.16b, v12.4b[2]\n" + ".inst 0x4face8dd // sdot v29.4s, v6.16b, v12.4b[3]\n" + "ldr q12, [x21, #0x40]\n" + "fmul v6.4s, v17.4s, v20.s[2]\n" + "fmul v20.4s, v17.4s, v20.s[3]\n" + ".inst 0x4f89e39a // sdot v26.4s, v28.16b, v9.4b[0]\n" + ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n" + ".inst 0x4f89eb82 // sdot v2.4s, v28.16b, v9.4b[2]\n" + ".inst 0x4fa9eb9d // sdot v29.4s, v28.16b, v9.4b[3]\n" + "ldr q9, [x21, #0x50]\n" + ".inst 0x4f8ce07a // sdot v26.4s, v3.16b, v12.4b[0]\n" + ".inst 0x4face06a // sdot v10.4s, v3.16b, v12.4b[1]\n" + ".inst 0x4f8ce862 // sdot v2.4s, v3.16b, v12.4b[2]\n" + ".inst 0x4face87d // sdot v29.4s, v3.16b, v12.4b[3]\n" + "ldr q12, [x21, #0x60]\n" + ".inst 0x4f89e2da // sdot v26.4s, v22.16b, v9.4b[0]\n" + ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n" + ".inst 0x4f89eac2 // sdot v2.4s, v22.16b, v9.4b[2]\n" + ".inst 0x4fa9eadd // sdot v29.4s, v22.16b, v9.4b[3]\n" + "ldr q17, [x21, #0x70]\n" + "add x21, x21, #0x88\n" + ".inst 0x4f8ce37a // sdot v26.4s, v27.16b, v12.4b[0]\n" + ".inst 0x4face36a // sdot v10.4s, v27.16b, v12.4b[1]\n" + ".inst 0x4f8ceb62 // sdot v2.4s, v27.16b, v12.4b[2]\n" + ".inst 0x4faceb7d // sdot v29.4s, v27.16b, v12.4b[3]\n" + ".inst 0x4f91e3da // sdot v26.4s, v30.16b, v17.4b[0]\n" + ".inst 0x4fb1e3ca // sdot v10.4s, v30.16b, v17.4b[1]\n" + ".inst 0x4f91ebc2 // sdot v2.4s, v30.16b, v17.4b[2]\n" + ".inst 0x4fb1ebdd // sdot v29.4s, v30.16b, v17.4b[3]\n" + "scvtf v26.4s, v26.4s, #0x4\n" + "scvtf v10.4s, v10.4s, #0x4\n" + "fmla v5.4s, v26.4s, v24.4s\n" + "scvtf v2.4s, v2.4s, #0x4\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "fmla v21.4s, v10.4s, v31.4s\n" + "fmla v8.4s, v2.4s, v6.4s\n" + "fmla v1.4s, v29.4s, v20.4s\n" + "bgt 3b\n" + "mov x20, %x[res_ptr]\n" + "subs x27, x27, #0x4\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "str q15, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q19, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q18, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q14, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q11, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q13, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q23, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q16, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q25, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q7, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q0, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q4, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q5, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q21, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q8, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q1, [x20, #0x0]\n" + "bne 2b\n" + "mov x20, #0x4\n" + "sub x10, x10, #0x10\n" + "cmp x10, #0x10\n" + "mov %x[res_ptr], x26\n" + "madd %x[a_ptr], x20, x9, %x[a_ptr]\n" + "bge 1b\n" + "4:" // Row loop skip + "cbz x10, 9f\n" + "5:" // Row tail: Row loop + "add x24, %x[b_ptr], #0x8\n" + "mov x23, %x[nc]\n" + "add x22, %x[res_ptr], %x[res_stride], LSL #2\n" + "6:" // Row tail: Column loop + "movi v15.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "add x25, %x[a_ptr], #0x8\n" + "mov x21, %x[nb]\n" + "movi v18.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "7:" // Row tail: Block loop + "ldr q7, [x24, #0x0]\n" + "ldr q5, [x25, #0x0]\n" + "movi v9.16b, #0x4\n" + "movi v4.4s, #0x0\n" + "ldr q3, [x24, #0x10]\n" + "ldr q2, [x25, #0x10]\n" + "movi v1.4s, #0x0\n" + "movi v0.4s, #0x0\n" + "ldr q13, [x24, #0x20]\n" + "ldr q31, [x25, #0x20]\n" + "movi v30.4s, #0x0\n" + "movi v29.16b, #0xf0\n" + "ldr q28, [x24, #0x30]\n" + "ldr q27, [x25, #0x30]\n" + "sshl v20.16b, v7.16b, v9.16b\n" + "sub x20, x24, #0x8\n" + "ldr q26, [x25, #0x40]\n" + "ldr q25, [x25, #0x50]\n" + "sshl v17.16b, v3.16b, v9.16b\n" + "and v7.16b, v7.16b, v29.16b\n" + "ldr q24, [x25, #0x60]\n" + "ldr q16, [x25, #0x70]\n" + "sshl v22.16b, v13.16b, v9.16b\n" + "and v3.16b, v3.16b, v29.16b\n" + "ldr d21, [x20, #0x0]\n" + "ldr d12, [x25, #-0x8]\n" + ".inst 0x4f85e284 // sdot v4.4s, v20.16b, v5.4b[0]\n" + ".inst 0x4fa5e281 // sdot v1.4s, v20.16b, v5.4b[1]\n" + ".inst 0x4f85ea80 // sdot v0.4s, v20.16b, v5.4b[2]\n" + ".inst 0x4fa5ea9e // sdot v30.4s, v20.16b, v5.4b[3]\n" + "sshl v9.16b, v28.16b, v9.16b\n" + "subs x21, x21, #0x1\n" + "and v13.16b, v13.16b, v29.16b\n" + "and v28.16b, v28.16b, v29.16b\n" + "add x25, x25, #0x88\n" + "add x24, x24, #0x48\n" + "fcvtl v21.4s, v21.4h\n" + "fcvtl v12.4s, v12.4h\n" + ".inst 0x4f82e224 // sdot v4.4s, v17.16b, v2.4b[0]\n" + ".inst 0x4fa2e221 // sdot v1.4s, v17.16b, v2.4b[1]\n" + ".inst 0x4f82ea20 // sdot v0.4s, v17.16b, v2.4b[2]\n" + ".inst 0x4fa2ea3e // sdot v30.4s, v17.16b, v2.4b[3]\n" + "fmul v11.4s, v21.4s, v12.s[0]\n" + "fmul v23.4s, v21.4s, v12.s[1]\n" + "fmul v17.4s, v21.4s, v12.s[2]\n" + ".inst 0x4f9fe2c4 // sdot v4.4s, v22.16b, v31.4b[0]\n" + "fmul v6.4s, v21.4s, v12.s[3]\n" + ".inst 0x4fbfe2c1 // sdot v1.4s, v22.16b, v31.4b[1]\n" + ".inst 0x4f9feac0 // sdot v0.4s, v22.16b, v31.4b[2]\n" + ".inst 0x4fbfeade // sdot v30.4s, v22.16b, v31.4b[3]\n" + ".inst 0x4f9be124 // sdot v4.4s, v9.16b, v27.4b[0]\n" + ".inst 0x4fbbe121 // sdot v1.4s, v9.16b, v27.4b[1]\n" + ".inst 0x4f9be920 // sdot v0.4s, v9.16b, v27.4b[2]\n" + ".inst 0x4fbbe93e // sdot v30.4s, v9.16b, v27.4b[3]\n" + ".inst 0x4f9ae0e4 // sdot v4.4s, v7.16b, v26.4b[0]\n" + ".inst 0x4fbae0e1 // sdot v1.4s, v7.16b, v26.4b[1]\n" + ".inst 0x4f9ae8e0 // sdot v0.4s, v7.16b, v26.4b[2]\n" + ".inst 0x4fbae8fe // sdot v30.4s, v7.16b, v26.4b[3]\n" + ".inst 0x4f99e064 // sdot v4.4s, v3.16b, v25.4b[0]\n" + ".inst 0x4fb9e061 // sdot v1.4s, v3.16b, v25.4b[1]\n" + ".inst 0x4f99e860 // sdot v0.4s, v3.16b, v25.4b[2]\n" + ".inst 0x4fb9e87e // sdot v30.4s, v3.16b, v25.4b[3]\n" + ".inst 0x4f98e1a4 // sdot v4.4s, v13.16b, v24.4b[0]\n" + ".inst 0x4fb8e1a1 // sdot v1.4s, v13.16b, v24.4b[1]\n" + ".inst 0x4f98e9a0 // sdot v0.4s, v13.16b, v24.4b[2]\n" + ".inst 0x4fb8e9be // sdot v30.4s, v13.16b, v24.4b[3]\n" + ".inst 0x4f90e384 // sdot v4.4s, v28.16b, v16.4b[0]\n" + ".inst 0x4fb0e381 // sdot v1.4s, v28.16b, v16.4b[1]\n" + ".inst 0x4f90eb80 // sdot v0.4s, v28.16b, v16.4b[2]\n" + ".inst 0x4fb0eb9e // sdot v30.4s, v28.16b, v16.4b[3]\n" + "scvtf v4.4s, v4.4s, #0x4\n" + "scvtf v1.4s, v1.4s, #0x4\n" + "scvtf v0.4s, v0.4s, #0x4\n" + "fmla v15.4s, v4.4s, v11.4s\n" + "scvtf v30.4s, v30.4s, #0x4\n" + "fmla v19.4s, v1.4s, v23.4s\n" + "fmla v18.4s, v0.4s, v17.4s\n" + "fmla v14.4s, v30.4s, v6.4s\n" + "bgt 7b\n" + "mov x20, %x[res_ptr]\n" + "cmp x10, #0x1\n" + "str q15, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x10, #0x2\n" + "str q19, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x10, #0x3\n" + "str q18, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "str q14, [x20, #0x0]\n" + "8:" // Row tail: Accumulator store skip + "subs x23, x23, #0x4\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "bne 6b\n" + "subs x10, x10, #0x4\n" + "add %x[a_ptr], %x[a_ptr], x9\n" + "mov %x[res_ptr], x22\n" + "bgt 5b\n" + "9:" // Row tail: Row loop skip + : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) + : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); +#else + float sumf[4][4]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + } + sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +#endif +} + +void ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) + if (svcntw() == 8) { + GGML_ASSERT(!(ggml_cpu_has_sve() && (svcntw() == 8)) && + "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance"); + } +#endif +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + size_t res_stride = bs * sizeof(float); + + __asm__ __volatile__( + "mov x10, %x[nr]\n" + "mov x9, #0x88\n" + "cmp x10, #0x10\n" + "mul x9, %x[nb], x9\n" + "blt 4f\n" + "1:" // Row loop + "add x28, %x[b_ptr], #0x8\n" + "mov x27, %x[nc]\n" + "add x26, %x[res_ptr], %x[res_stride], LSL #4\n" + "2:" // Column loop + "add x25, %x[a_ptr], #0x8\n" + "movi v2.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "mov x24, %x[nb]\n" + "add x23, x25, x9\n" + "movi v12.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "add x22, x23, x9\n" + "movi v11.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "add x21, x22, x9\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v5.16b, #0x0\n" + "movi v7.16b, #0x0\n" + "movi v4.16b, #0x0\n" + "movi v6.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "movi v24.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "3:" // Block loop + "ldr q21, [x28, #0x0]\n" + "ldr q16, [x28, #0x10]\n" + "movi v1.16b, #0x4\n" + "movi v19.4s, #0x0\n" + "ldr q27, [x25, #0x0]\n" + "ldr q15, [x25, #0x10]\n" + "movi v26.4s, #0x0\n" + "movi v18.4s, #0x0\n" + "ldr q29, [x28, #0x20]\n" + "ldr q3, [x28, #0x30]\n" + "movi v17.4s, #0x0\n" + "movi v0.16b, #0xf0\n" + "ldr d20, [x25, #-0x8]\n" + "ldr d9, [x23, #-0x8]\n" + "sshl v8.16b, v21.16b, v1.16b\n" + "sshl v31.16b, v16.16b, v1.16b\n" + "and v21.16b, v21.16b, v0.16b\n" + "and v16.16b, v16.16b, v0.16b\n" + "sub x20, x28, #0x8\n" + "subs x24, x24, #0x1\n" + "add x28, x28, #0x48\n" + ".inst 0x4e88a773 // smmla v19.4s, v27.16b, v8.16b\n" + ".inst 0x4e9fa77a // smmla v26.4s, v27.16b, v31.16b\n" + "ldr q27, [x25, #0x20]\n" + ".inst 0x4e88a5f2 // smmla v18.4s, v15.16b, v8.16b\n" + ".inst 0x4e9fa5f1 // smmla v17.4s, v15.16b, v31.16b\n" + "sshl v15.16b, v29.16b, v1.16b\n" + "sshl v1.16b, v3.16b, v1.16b\n" + "and v29.16b, v29.16b, v0.16b\n" + "and v3.16b, v3.16b, v0.16b\n" + "ldr q0, [x25, #0x30]\n" + "fcvtl v20.4s, v20.4h\n" + ".inst 0x4e8fa773 // smmla v19.4s, v27.16b, v15.16b\n" + "fcvtl v9.4s, v9.4h\n" + ".inst 0x4e81a77a // smmla v26.4s, v27.16b, v1.16b\n" + "ldr q27, [x25, #0x40]\n" + ".inst 0x4e8fa412 // smmla v18.4s, v0.16b, v15.16b\n" + ".inst 0x4e81a411 // smmla v17.4s, v0.16b, v1.16b\n" + "ldr q0, [x25, #0x50]\n" + ".inst 0x4e95a773 // smmla v19.4s, v27.16b, v21.16b\n" + ".inst 0x4e90a77a // smmla v26.4s, v27.16b, v16.16b\n" + "ldr q27, [x25, #0x60]\n" + ".inst 0x4e95a412 // smmla v18.4s, v0.16b, v21.16b\n" + ".inst 0x4e90a411 // smmla v17.4s, v0.16b, v16.16b\n" + "ldr q0, [x25, #0x70]\n" + "add x25, x25, #0x88\n" + ".inst 0x4e9da773 // smmla v19.4s, v27.16b, v29.16b\n" + ".inst 0x4e83a77a // smmla v26.4s, v27.16b, v3.16b\n" + "ldr d27, [x20, #0x0]\n" + ".inst 0x4e9da412 // smmla v18.4s, v0.16b, v29.16b\n" + ".inst 0x4e83a411 // smmla v17.4s, v0.16b, v3.16b\n" + "fcvtl v27.4s, v27.4h\n" + "uzp1 v0.2d, v19.2d, v26.2d\n" + "uzp2 v26.2d, v19.2d, v26.2d\n" + "fmul v19.4s, v27.4s, v20.s[0]\n" + "scvtf v0.4s, v0.4s, #0x4\n" + "scvtf v26.4s, v26.4s, #0x4\n" + "fmla v2.4s, v0.4s, v19.4s\n" + "ldr q19, [x23, #0x0]\n" + "uzp1 v0.2d, v18.2d, v17.2d\n" + "uzp2 v18.2d, v18.2d, v17.2d\n" + "fmul v17.4s, v27.4s, v20.s[1]\n" + "scvtf v0.4s, v0.4s, #0x4\n" + "scvtf v18.4s, v18.4s, #0x4\n" + "fmla v10.4s, v26.4s, v17.4s\n" + "ldr q17, [x23, #0x10]\n" + "fmul v26.4s, v27.4s, v20.s[2]\n" + "fmul v20.4s, v27.4s, v20.s[3]\n" + "fmla v12.4s, v0.4s, v26.4s\n" + "ldr d0, [x22, #-0x8]\n" + "ldr d26, [x21, #-0x8]\n" + "fcvtl v0.4s, v0.4h\n" + "fmla v28.4s, v18.4s, v20.4s\n" + "movi v20.4s, #0x0\n" + "movi v18.4s, #0x0\n" + ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n" + ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n" + "ldr q19, [x23, #0x20]\n" + "fcvtl v26.4s, v26.4h\n" + ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n" + ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n" + "ldr q19, [x23, #0x40]\n" + ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n" + ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n" + "ldr q19, [x23, #0x60]\n" + ".inst 0x4e9da674 // smmla v20.4s, v19.16b, v29.16b\n" + ".inst 0x4e83a672 // smmla v18.4s, v19.16b, v3.16b\n" + "uzp1 v19.2d, v20.2d, v18.2d\n" + "scvtf v19.4s, v19.4s, #0x4\n" + "uzp2 v20.2d, v20.2d, v18.2d\n" + "fmul v18.4s, v27.4s, v9.s[0]\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "fmla v11.4s, v19.4s, v18.4s\n" + "ldr q18, [x22, #0x0]\n" + "fmul v19.4s, v27.4s, v9.s[1]\n" + "fmla v13.4s, v20.4s, v19.4s\n" + "movi v19.4s, #0x0\n" + "movi v20.4s, #0x0\n" + ".inst 0x4e88a633 // smmla v19.4s, v17.16b, v8.16b\n" + ".inst 0x4e9fa634 // smmla v20.4s, v17.16b, v31.16b\n" + "ldr q17, [x23, #0x30]\n" + ".inst 0x4e8fa633 // smmla v19.4s, v17.16b, v15.16b\n" + ".inst 0x4e81a634 // smmla v20.4s, v17.16b, v1.16b\n" + "ldr q17, [x23, #0x50]\n" + ".inst 0x4e95a633 // smmla v19.4s, v17.16b, v21.16b\n" + ".inst 0x4e90a634 // smmla v20.4s, v17.16b, v16.16b\n" + "ldr q17, [x23, #0x70]\n" + "add x23, x23, #0x88\n" + ".inst 0x4e9da633 // smmla v19.4s, v17.16b, v29.16b\n" + ".inst 0x4e83a634 // smmla v20.4s, v17.16b, v3.16b\n" + "uzp1 v17.2d, v19.2d, v20.2d\n" + "scvtf v17.4s, v17.4s, #0x4\n" + "uzp2 v20.2d, v19.2d, v20.2d\n" + "fmul v19.4s, v27.4s, v9.s[2]\n" + "fmul v9.4s, v27.4s, v9.s[3]\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "fmla v22.4s, v17.4s, v19.4s\n" + "ldr q17, [x22, #0x10]\n" + "movi v19.4s, #0x0\n" + ".inst 0x4e88a653 // smmla v19.4s, v18.16b, v8.16b\n" + "fmla v23.4s, v20.4s, v9.4s\n" + "movi v20.4s, #0x0\n" + "movi v9.4s, #0x0\n" + ".inst 0x4e9fa654 // smmla v20.4s, v18.16b, v31.16b\n" + "ldr q18, [x22, #0x20]\n" + ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n" + ".inst 0x4e8fa653 // smmla v19.4s, v18.16b, v15.16b\n" + ".inst 0x4e81a654 // smmla v20.4s, v18.16b, v1.16b\n" + "ldr q18, [x22, #0x40]\n" + ".inst 0x4e95a653 // smmla v19.4s, v18.16b, v21.16b\n" + ".inst 0x4e90a654 // smmla v20.4s, v18.16b, v16.16b\n" + "ldr q18, [x22, #0x60]\n" + ".inst 0x4e9da653 // smmla v19.4s, v18.16b, v29.16b\n" + ".inst 0x4e83a654 // smmla v20.4s, v18.16b, v3.16b\n" + "movi v18.4s, #0x0\n" + ".inst 0x4e9fa632 // smmla v18.4s, v17.16b, v31.16b\n" + "ldr q17, [x22, #0x30]\n" + ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n" + ".inst 0x4e81a632 // smmla v18.4s, v17.16b, v1.16b\n" + "ldr q17, [x22, #0x50]\n" + ".inst 0x4e95a629 // smmla v9.4s, v17.16b, v21.16b\n" + ".inst 0x4e90a632 // smmla v18.4s, v17.16b, v16.16b\n" + "ldr q17, [x22, #0x70]\n" + "add x22, x22, #0x88\n" + ".inst 0x4e9da629 // smmla v9.4s, v17.16b, v29.16b\n" + ".inst 0x4e83a632 // smmla v18.4s, v17.16b, v3.16b\n" + "uzp1 v17.2d, v19.2d, v20.2d\n" + "uzp2 v20.2d, v19.2d, v20.2d\n" + "fmul v19.4s, v27.4s, v0.s[0]\n" + "scvtf v17.4s, v17.4s, #0x4\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "fmla v25.4s, v17.4s, v19.4s\n" + "ldr q19, [x21, #0x0]\n" + "fmul v17.4s, v27.4s, v0.s[1]\n" + "fmla v5.4s, v20.4s, v17.4s\n" + "ldr q17, [x21, #0x10]\n" + "uzp1 v20.2d, v9.2d, v18.2d\n" + "uzp2 v9.2d, v9.2d, v18.2d\n" + "fmul v18.4s, v27.4s, v0.s[2]\n" + "fmul v0.4s, v27.4s, v0.s[3]\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "scvtf v9.4s, v9.4s, #0x4\n" + "fmla v7.4s, v20.4s, v18.4s\n" + "movi v20.4s, #0x0\n" + "movi v18.4s, #0x0\n" + ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n" + ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n" + "ldr q19, [x21, #0x20]\n" + "fmla v4.4s, v9.4s, v0.4s\n" + "movi v9.4s, #0x0\n" + "movi v0.4s, #0x0\n" + ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n" + "fmul v8.4s, v27.4s, v26.s[0]\n" + ".inst 0x4e9fa620 // smmla v0.4s, v17.16b, v31.16b\n" + "ldr q17, [x21, #0x30]\n" + ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n" + "fmul v31.4s, v27.4s, v26.s[1]\n" + ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n" + "ldr q19, [x21, #0x40]\n" + ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n" + "fmul v15.4s, v27.4s, v26.s[2]\n" + "fmul v27.4s, v27.4s, v26.s[3]\n" + ".inst 0x4e81a620 // smmla v0.4s, v17.16b, v1.16b\n" + "ldr q1, [x21, #0x50]\n" + ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n" + ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n" + "ldr q26, [x21, #0x60]\n" + ".inst 0x4e95a429 // smmla v9.4s, v1.16b, v21.16b\n" + ".inst 0x4e90a420 // smmla v0.4s, v1.16b, v16.16b\n" + "ldr q21, [x21, #0x70]\n" + "add x21, x21, #0x88\n" + ".inst 0x4e9da754 // smmla v20.4s, v26.16b, v29.16b\n" + ".inst 0x4e83a752 // smmla v18.4s, v26.16b, v3.16b\n" + ".inst 0x4e9da6a9 // smmla v9.4s, v21.16b, v29.16b\n" + ".inst 0x4e83a6a0 // smmla v0.4s, v21.16b, v3.16b\n" + "uzp1 v29.2d, v20.2d, v18.2d\n" + "uzp2 v21.2d, v20.2d, v18.2d\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "uzp1 v18.2d, v9.2d, v0.2d\n" + "uzp2 v16.2d, v9.2d, v0.2d\n" + "scvtf v21.4s, v21.4s, #0x4\n" + "fmla v6.4s, v29.4s, v8.4s\n" + "scvtf v18.4s, v18.4s, #0x4\n" + "scvtf v16.4s, v16.4s, #0x4\n" + "fmla v30.4s, v21.4s, v31.4s\n" + "fmla v24.4s, v18.4s, v15.4s\n" + "fmla v14.4s, v16.4s, v27.4s\n" + "bgt 3b\n" + "mov x20, %x[res_ptr]\n" + "subs x27, x27, #0x4\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "str q2, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q10, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q12, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q28, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q11, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q13, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q22, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q23, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q25, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q5, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q7, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q4, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q6, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q30, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q24, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q14, [x20, #0x0]\n" + "bne 2b\n" + "mov x20, #0x4\n" + "sub x10, x10, #0x10\n" + "cmp x10, #0x10\n" + "mov %x[res_ptr], x26\n" + "madd %x[a_ptr], x20, x9, %x[a_ptr]\n" + "bge 1b\n" + "4:" // Row loop skip + "cbz x10, 9f\n" + "5:" // Row tail: Row loop + "add x24, %x[b_ptr], #0x8\n" + "mov x23, %x[nc]\n" + "add x22, %x[res_ptr], %x[res_stride], LSL #2\n" + "6:" // Row tail: Column loop + "movi v2.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "add x25, %x[a_ptr], #0x8\n" + "mov x21, %x[nb]\n" + "movi v12.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "7:" // Row tail: Block loop + "ldr q6, [x24, #0x0]\n" + "ldr q5, [x24, #0x10]\n" + "movi v17.16b, #0x4\n" + "movi v8.4s, #0x0\n" + "ldr q4, [x25, #0x0]\n" + "ldr q13, [x25, #0x10]\n" + "movi v27.4s, #0x0\n" + "movi v0.4s, #0x0\n" + "ldr q31, [x24, #0x20]\n" + "ldr q14, [x24, #0x30]\n" + "movi v29.4s, #0x0\n" + "movi v22.16b, #0xf0\n" + "ldr q11, [x25, #0x20]\n" + "ldr q23, [x25, #0x30]\n" + "sshl v21.16b, v6.16b, v17.16b\n" + "sshl v16.16b, v5.16b, v17.16b\n" + "ldr q20, [x25, #0x40]\n" + "ldr q26, [x25, #0x50]\n" + "and v6.16b, v6.16b, v22.16b\n" + "and v5.16b, v5.16b, v22.16b\n" + "ldr q25, [x25, #0x60]\n" + "ldr q3, [x25, #0x70]\n" + "sshl v19.16b, v31.16b, v17.16b\n" + "sshl v18.16b, v14.16b, v17.16b\n" + "ldr d17, [x25, #-0x8]\n" + ".inst 0x4e95a488 // smmla v8.4s, v4.16b, v21.16b\n" + ".inst 0x4e90a49b // smmla v27.4s, v4.16b, v16.16b\n" + "and v31.16b, v31.16b, v22.16b\n" + ".inst 0x4e95a5a0 // smmla v0.4s, v13.16b, v21.16b\n" + ".inst 0x4e90a5bd // smmla v29.4s, v13.16b, v16.16b\n" + "and v14.16b, v14.16b, v22.16b\n" + "sub x20, x24, #0x8\n" + "ldr d16, [x20, #0x0]\n" + "subs x21, x21, #0x1\n" + "add x25, x25, #0x88\n" + "fcvtl v17.4s, v17.4h\n" + "add x24, x24, #0x48\n" + ".inst 0x4e93a568 // smmla v8.4s, v11.16b, v19.16b\n" + ".inst 0x4e92a57b // smmla v27.4s, v11.16b, v18.16b\n" + ".inst 0x4e93a6e0 // smmla v0.4s, v23.16b, v19.16b\n" + ".inst 0x4e92a6fd // smmla v29.4s, v23.16b, v18.16b\n" + "fcvtl v16.4s, v16.4h\n" + ".inst 0x4e86a688 // smmla v8.4s, v20.16b, v6.16b\n" + ".inst 0x4e85a69b // smmla v27.4s, v20.16b, v5.16b\n" + "fmul v23.4s, v16.4s, v17.s[0]\n" + "fmul v21.4s, v16.4s, v17.s[1]\n" + "fmul v1.4s, v16.4s, v17.s[2]\n" + "fmul v20.4s, v16.4s, v17.s[3]\n" + ".inst 0x4e86a740 // smmla v0.4s, v26.16b, v6.16b\n" + ".inst 0x4e85a75d // smmla v29.4s, v26.16b, v5.16b\n" + ".inst 0x4e9fa728 // smmla v8.4s, v25.16b, v31.16b\n" + ".inst 0x4e8ea73b // smmla v27.4s, v25.16b, v14.16b\n" + ".inst 0x4e9fa460 // smmla v0.4s, v3.16b, v31.16b\n" + ".inst 0x4e8ea47d // smmla v29.4s, v3.16b, v14.16b\n" + "uzp1 v19.2d, v8.2d, v27.2d\n" + "uzp2 v18.2d, v8.2d, v27.2d\n" + "scvtf v19.4s, v19.4s, #0x4\n" + "uzp1 v17.2d, v0.2d, v29.2d\n" + "uzp2 v16.2d, v0.2d, v29.2d\n" + "scvtf v18.4s, v18.4s, #0x4\n" + "fmla v2.4s, v19.4s, v23.4s\n" + "scvtf v17.4s, v17.4s, #0x4\n" + "scvtf v16.4s, v16.4s, #0x4\n" + "fmla v10.4s, v18.4s, v21.4s\n" + "fmla v12.4s, v17.4s, v1.4s\n" + "fmla v28.4s, v16.4s, v20.4s\n" + "bgt 7b\n" + "mov x20, %x[res_ptr]\n" + "cmp x10, #0x1\n" + "str q2, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x10, #0x2\n" + "str q10, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x10, #0x3\n" + "str q12, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "str q28, [x20, #0x0]\n" + "8:" // Row tail: Accumulator store skip + "subs x23, x23, #0x4\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "bne 6b\n" + "subs x10, x10, #0x4\n" + "add %x[a_ptr], %x[a_ptr], x9\n" + "mov %x[res_ptr], x22\n" + "bgt 5b\n" + "9:" // Row tail: Row loop skip + : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) + : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); +#elif defined(__ARM_NEON) && defined(__aarch64__) + GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) && + "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal " + "performance"); +#else + float sumf[4][4]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + } + sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +#endif +} + +void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) + if (svcntw() == 8) { + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + size_t res_stride = bs * sizeof(float); + + __asm__ __volatile__( + "mov x20, #0x4\n" + "mov x13, %x[nr]\n" + "mov z28.s, #-0x4\n" + "mov x12, #0x88\n" + "ptrue p1.b\n" + "whilelt p0.s, XZR, x20\n" + "cmp x13, #0x10\n" + "mul x12, %x[nb], x12\n" + "blt 4f\n" + "1:" // Row loop + "add x11, %x[b_ptr], #0x10\n" + "mov x10, %x[nc]\n" + "add x9, %x[res_ptr], %x[res_stride], LSL #4\n" + "2:" // Column loop + "add x28, %x[a_ptr], #0x8\n" + "mov z24.b, #0x0\n" + "mov z15.b, #0x0\n" + "mov x27, %x[nb]\n" + "add x26, x28, x12\n" + "mov z12.b, #0x0\n" + "mov z0.b, #0x0\n" + "add x25, x26, x12\n" + "mov z13.b, #0x0\n" + "mov z1.b, #0x0\n" + "add x24, x25, x12\n" + "mov z20.b, #0x0\n" + "mov z25.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z16.b, #0x0\n" + "mov z19.b, #0x0\n" + "mov z26.b, #0x0\n" + "mov z8.b, #0x0\n" + "mov z29.b, #0x0\n" + "mov z27.b, #0x0\n" + "mov z10.b, #0x0\n" + "3:" // Block loop + "ld1b { z30.b }, p1/Z, [x11]\n" + "ld1b { z21.b }, p1/Z, [x11, #1, MUL VL]\n" + "mov z18.s, #0x0\n" + "mov z7.s, #0x0\n" + "ld1rqb { z3.b }, p1/Z, [x28]\n" + "ld1rqb { z5.b }, p1/Z, [x28, #16]\n" + "mov z9.s, #0x0\n" + "mov z22.s, #0x0\n" + "ld1b { z4.b }, p1/Z, [x11, #2, MUL VL]\n" + "ld1b { z17.b }, p1/Z, [x11, #3, MUL VL]\n" + "sub x20, x11, #0x10\n" + "sub x23, x28, #0x8\n" + "lsl z31.b, z30.b, #0x4\n" + "lsl z6.b, z21.b, #0x4\n" + "ld1h { z23.s }, p1/Z, [x20]\n" + "sub x22, x26, #0x8\n" + "and z30.b, z30.b, #0xf0\n" + "and z21.b, z21.b, #0xf0\n" + "sub x21, x25, #0x8\n" + "sub x20, x24, #0x8\n" + "lsl z14.b, z4.b, #0x4\n" + "lsl z2.b, z17.b, #0x4\n" + "subs x27, x27, #0x1\n" + "add x11, x11, #0x90\n" + ".inst 0x451f9872 // smmla z18.s, z3.b, z31.b\n" + ".inst 0x45069867 // smmla z7.s, z3.b, z6.b\n" + "ld1rqb { z3.b }, p1/Z, [x28, #32]\n" + "and z4.b, z4.b, #0xf0\n" + ".inst 0x451f98a9 // smmla z9.s, z5.b, z31.b\n" + ".inst 0x450698b6 // smmla z22.s, z5.b, z6.b\n" + "ld1rqb { z5.b }, p1/Z, [x28, #48]\n" + "and z17.b, z17.b, #0xf0\n" + "fcvt z23.s, p1/m, z23.h\n" + ".inst 0x450e9872 // smmla z18.s, z3.b, z14.b\n" + ".inst 0x45029867 // smmla z7.s, z3.b, z2.b\n" + "ld1rqb { z3.b }, p1/Z, [x28, #64]\n" + ".inst 0x450e98a9 // smmla z9.s, z5.b, z14.b\n" + ".inst 0x450298b6 // smmla z22.s, z5.b, z2.b\n" + "ld1rqb { z5.b }, p1/Z, [x28, #80]\n" + "fscale z23.s, p1/m, z23.s, z28.s\n" + ".inst 0x451e9872 // smmla z18.s, z3.b, z30.b\n" + ".inst 0x45159867 // smmla z7.s, z3.b, z21.b\n" + "ld1rqb { z3.b }, p1/Z, [x28, #96]\n" + ".inst 0x451e98a9 // smmla z9.s, z5.b, z30.b\n" + ".inst 0x451598b6 // smmla z22.s, z5.b, z21.b\n" + "ld1rqb { z5.b }, p1/Z, [x28, #112]\n" + "add x28, x28, #0x88\n" + ".inst 0x45049872 // smmla z18.s, z3.b, z4.b\n" + ".inst 0x45119867 // smmla z7.s, z3.b, z17.b\n" + "ld1h { z3.s }, p0/Z, [x23]\n" + ".inst 0x450498a9 // smmla z9.s, z5.b, z4.b\n" + ".inst 0x451198b6 // smmla z22.s, z5.b, z17.b\n" + "fcvt z3.s, p1/m, z3.h\n" + "uzp1 z5.d, z18.d, z7.d\n" + "uzp2 z18.d, z18.d, z7.d\n" + "mov z3.q, z3.q[0]\n" + "uzp1 z7.d, z9.d, z22.d\n" + "uzp2 z22.d, z9.d, z22.d\n" + "fmul z9.s, z23.s, z3.s[0]\n" + "scvtf z5.s, p1/m, z5.s\n" + "scvtf z18.s, p1/m, z18.s\n" + "scvtf z7.s, p1/m, z7.s\n" + "scvtf z22.s, p1/m, z22.s\n" + "fmla z24.s, p1/M, z5.s, z9.s\n" + "ld1rqb { z5.b }, p1/Z, [x26]\n" + "fmul z9.s, z23.s, z3.s[1]\n" + "fmla z15.s, p1/M, z18.s, z9.s\n" + "ld1rqb { z18.b }, p1/Z, [x26, #16]\n" + "fmul z9.s, z23.s, z3.s[2]\n" + "fmul z3.s, z23.s, z3.s[3]\n" + "fmla z12.s, p1/M, z7.s, z9.s\n" + "mov z9.s, #0x0\n" + "ld1h { z7.s }, p0/Z, [x22]\n" + ".inst 0x451f98a9 // smmla z9.s, z5.b, z31.b\n" + "fmla z0.s, p1/M, z22.s, z3.s\n" + "mov z22.s, #0x0\n" + "ld1h { z3.s }, p0/Z, [x21]\n" + ".inst 0x450698b6 // smmla z22.s, z5.b, z6.b\n" + "ld1rqb { z5.b }, p1/Z, [x26, #32]\n" + "fcvt z7.s, p1/m, z7.h\n" + "fcvt z3.s, p1/m, z3.h\n" + ".inst 0x450e98a9 // smmla z9.s, z5.b, z14.b\n" + ".inst 0x450298b6 // smmla z22.s, z5.b, z2.b\n" + "ld1rqb { z5.b }, p1/Z, [x26, #64]\n" + "mov z7.q, z7.q[0]\n" + "mov z3.q, z3.q[0]\n" + ".inst 0x451e98a9 // smmla z9.s, z5.b, z30.b\n" + ".inst 0x451598b6 // smmla z22.s, z5.b, z21.b\n" + "ld1rqb { z5.b }, p1/Z, [x26, #96]\n" + ".inst 0x450498a9 // smmla z9.s, z5.b, z4.b\n" + ".inst 0x451198b6 // smmla z22.s, z5.b, z17.b\n" + "uzp1 z5.d, z9.d, z22.d\n" + "scvtf z5.s, p1/m, z5.s\n" + "uzp2 z22.d, z9.d, z22.d\n" + "fmul z9.s, z23.s, z7.s[0]\n" + "scvtf z22.s, p1/m, z22.s\n" + "fmla z13.s, p1/M, z5.s, z9.s\n" + "ld1rqb { z9.b }, p1/Z, [x25]\n" + "fmul z5.s, z23.s, z7.s[1]\n" + "fmla z1.s, p1/M, z22.s, z5.s\n" + "mov z5.s, #0x0\n" + "mov z22.s, #0x0\n" + ".inst 0x451f9a45 // smmla z5.s, z18.b, z31.b\n" + ".inst 0x45069a56 // smmla z22.s, z18.b, z6.b\n" + "ld1rqb { z18.b }, p1/Z, [x26, #48]\n" + ".inst 0x450e9a45 // smmla z5.s, z18.b, z14.b\n" + ".inst 0x45029a56 // smmla z22.s, z18.b, z2.b\n" + "ld1rqb { z18.b }, p1/Z, [x26, #80]\n" + ".inst 0x451e9a45 // smmla z5.s, z18.b, z30.b\n" + ".inst 0x45159a56 // smmla z22.s, z18.b, z21.b\n" + "ld1rqb { z18.b }, p1/Z, [x26, #112]\n" + "add x26, x26, #0x88\n" + ".inst 0x45049a45 // smmla z5.s, z18.b, z4.b\n" + ".inst 0x45119a56 // smmla z22.s, z18.b, z17.b\n" + "uzp1 z18.d, z5.d, z22.d\n" + "scvtf z18.s, p1/m, z18.s\n" + "uzp2 z22.d, z5.d, z22.d\n" + "fmul z5.s, z23.s, z7.s[2]\n" + "fmul z7.s, z23.s, z7.s[3]\n" + "scvtf z22.s, p1/m, z22.s\n" + "fmla z20.s, p1/M, z18.s, z5.s\n" + "ld1rqb { z18.b }, p1/Z, [x25, #16]\n" + "ld1h { z5.s }, p0/Z, [x20]\n" + "fcvt z5.s, p1/m, z5.h\n" + "fmla z25.s, p1/M, z22.s, z7.s\n" + "mov z22.s, #0x0\n" + "mov z7.s, #0x0\n" + ".inst 0x451f9936 // smmla z22.s, z9.b, z31.b\n" + ".inst 0x45069927 // smmla z7.s, z9.b, z6.b\n" + "ld1rqb { z9.b }, p1/Z, [x25, #32]\n" + "mov z5.q, z5.q[0]\n" + ".inst 0x450e9936 // smmla z22.s, z9.b, z14.b\n" + ".inst 0x45029927 // smmla z7.s, z9.b, z2.b\n" + "ld1rqb { z9.b }, p1/Z, [x25, #64]\n" + ".inst 0x451e9936 // smmla z22.s, z9.b, z30.b\n" + ".inst 0x45159927 // smmla z7.s, z9.b, z21.b\n" + "ld1rqb { z9.b }, p1/Z, [x25, #96]\n" + ".inst 0x45049936 // smmla z22.s, z9.b, z4.b\n" + ".inst 0x45119927 // smmla z7.s, z9.b, z17.b\n" + "uzp1 z9.d, z22.d, z7.d\n" + "scvtf z9.s, p1/m, z9.s\n" + "uzp2 z22.d, z22.d, z7.d\n" + "fmul z7.s, z23.s, z3.s[0]\n" + "scvtf z22.s, p1/m, z22.s\n" + "fmla z11.s, p1/M, z9.s, z7.s\n" + "ld1rqb { z9.b }, p1/Z, [x24]\n" + "fmul z7.s, z23.s, z3.s[1]\n" + "fmla z16.s, p1/M, z22.s, z7.s\n" + "mov z22.s, #0x0\n" + "mov z7.s, #0x0\n" + ".inst 0x451f9a56 // smmla z22.s, z18.b, z31.b\n" + ".inst 0x45069a47 // smmla z7.s, z18.b, z6.b\n" + "ld1rqb { z18.b }, p1/Z, [x25, #48]\n" + ".inst 0x450e9a56 // smmla z22.s, z18.b, z14.b\n" + ".inst 0x45029a47 // smmla z7.s, z18.b, z2.b\n" + "ld1rqb { z18.b }, p1/Z, [x25, #80]\n" + ".inst 0x451e9a56 // smmla z22.s, z18.b, z30.b\n" + ".inst 0x45159a47 // smmla z7.s, z18.b, z21.b\n" + "ld1rqb { z18.b }, p1/Z, [x25, #112]\n" + "add x25, x25, #0x88\n" + ".inst 0x45049a56 // smmla z22.s, z18.b, z4.b\n" + ".inst 0x45119a47 // smmla z7.s, z18.b, z17.b\n" + "uzp1 z18.d, z22.d, z7.d\n" + "scvtf z18.s, p1/m, z18.s\n" + "uzp2 z7.d, z22.d, z7.d\n" + "fmul z22.s, z23.s, z3.s[2]\n" + "fmul z3.s, z23.s, z3.s[3]\n" + "scvtf z7.s, p1/m, z7.s\n" + "fmla z19.s, p1/M, z18.s, z22.s\n" + "ld1rqb { z18.b }, p1/Z, [x24, #16]\n" + "fmul z22.s, z23.s, z5.s[0]\n" + "fmla z26.s, p1/M, z7.s, z3.s\n" + "mov z3.s, #0x0\n" + "mov z7.s, #0x0\n" + ".inst 0x451f9923 // smmla z3.s, z9.b, z31.b\n" + ".inst 0x45069927 // smmla z7.s, z9.b, z6.b\n" + "ld1rqb { z9.b }, p1/Z, [x24, #32]\n" + ".inst 0x450e9923 // smmla z3.s, z9.b, z14.b\n" + ".inst 0x45029927 // smmla z7.s, z9.b, z2.b\n" + "mov z9.s, #0x0\n" + ".inst 0x451f9a49 // smmla z9.s, z18.b, z31.b\n" + "mov z31.s, #0x0\n" + ".inst 0x45069a5f // smmla z31.s, z18.b, z6.b\n" + "ld1rqb { z6.b }, p1/Z, [x24, #48]\n" + "ld1rqb { z18.b }, p1/Z, [x24, #64]\n" + ".inst 0x450e98c9 // smmla z9.s, z6.b, z14.b\n" + "fmul z14.s, z23.s, z5.s[1]\n" + ".inst 0x450298df // smmla z31.s, z6.b, z2.b\n" + "ld1rqb { z6.b }, p1/Z, [x24, #80]\n" + "fmul z2.s, z23.s, z5.s[2]\n" + "fmul z23.s, z23.s, z5.s[3]\n" + ".inst 0x451e9a43 // smmla z3.s, z18.b, z30.b\n" + ".inst 0x45159a47 // smmla z7.s, z18.b, z21.b\n" + "ld1rqb { z5.b }, p1/Z, [x24, #96]\n" + ".inst 0x451e98c9 // smmla z9.s, z6.b, z30.b\n" + ".inst 0x451598df // smmla z31.s, z6.b, z21.b\n" + "ld1rqb { z18.b }, p1/Z, [x24, #112]\n" + "add x24, x24, #0x88\n" + ".inst 0x450498a3 // smmla z3.s, z5.b, z4.b\n" + ".inst 0x451198a7 // smmla z7.s, z5.b, z17.b\n" + ".inst 0x45049a49 // smmla z9.s, z18.b, z4.b\n" + ".inst 0x45119a5f // smmla z31.s, z18.b, z17.b\n" + "uzp1 z18.d, z3.d, z7.d\n" + "uzp2 z5.d, z3.d, z7.d\n" + "scvtf z18.s, p1/m, z18.s\n" + "uzp1 z6.d, z9.d, z31.d\n" + "uzp2 z9.d, z9.d, z31.d\n" + "scvtf z5.s, p1/m, z5.s\n" + "fmla z8.s, p1/M, z18.s, z22.s\n" + "scvtf z6.s, p1/m, z6.s\n" + "scvtf z9.s, p1/m, z9.s\n" + "fmla z29.s, p1/M, z5.s, z14.s\n" + "fmla z27.s, p1/M, z6.s, z2.s\n" + "fmla z10.s, p1/M, z9.s, z23.s\n" + "bgt 3b\n" + "mov x20, %x[res_ptr]\n" + "subs x10, x10, #0x8\n" + "add %x[res_ptr], %x[res_ptr], #0x20\n" + "st1w { z24.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z15.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z12.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z0.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z13.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z1.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z20.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z25.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z11.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z16.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z19.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z26.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z8.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z29.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z27.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z10.s }, p1, [x20]\n" + "bne 2b\n" + "mov x20, #0x4\n" + "sub x13, x13, #0x10\n" + "cmp x13, #0x10\n" + "mov %x[res_ptr], x9\n" + "madd %x[a_ptr], x20, x12, %x[a_ptr]\n" + "bge 1b\n" + "4:" // Row loop skip + "cbz x13, 9f\n" + "5:" // Row tail: Row loop + "add x25, %x[b_ptr], #0x10\n" + "mov x24, %x[nc]\n" + "add x23, %x[res_ptr], %x[res_stride], LSL #2\n" + "6:" // Row tail: Column loop + "mov z24.b, #0x0\n" + "mov z15.b, #0x0\n" + "add x28, %x[a_ptr], #0x8\n" + "mov x22, %x[nb]\n" + "mov z12.b, #0x0\n" + "mov z0.b, #0x0\n" + "7:" // Row tail: Block loop + "ld1b { z3.b }, p1/Z, [x25]\n" + "ld1b { z6.b }, p1/Z, [x25, #1, MUL VL]\n" + "mov z2.s, #0x0\n" + "mov z25.s, #0x0\n" + "ld1rqb { z26.b }, p1/Z, [x28]\n" + "ld1rqb { z21.b }, p1/Z, [x28, #16]\n" + "mov z27.s, #0x0\n" + "mov z19.s, #0x0\n" + "ld1b { z29.b }, p1/Z, [x25, #2, MUL VL]\n" + "ld1b { z16.b }, p1/Z, [x25, #3, MUL VL]\n" + "sub x21, x25, #0x10\n" + "sub x20, x28, #0x8\n" + "lsl z20.b, z3.b, #0x4\n" + "lsl z4.b, z6.b, #0x4\n" + "ld1rqb { z10.b }, p1/Z, [x28, #32]\n" + "ld1rqb { z23.b }, p1/Z, [x28, #48]\n" + "and z3.b, z3.b, #0xf0\n" + "and z6.b, z6.b, #0xf0\n" + "ld1rqb { z11.b }, p1/Z, [x28, #64]\n" + "ld1rqb { z7.b }, p1/Z, [x28, #80]\n" + "lsl z8.b, z29.b, #0x4\n" + "lsl z14.b, z16.b, #0x4\n" + "ld1rqb { z18.b }, p1/Z, [x28, #96]\n" + "ld1rqb { z30.b }, p1/Z, [x28, #112]\n" + ".inst 0x45149b42 // smmla z2.s, z26.b, z20.b\n" + ".inst 0x45049b59 // smmla z25.s, z26.b, z4.b\n" + "and z29.b, z29.b, #0xf0\n" + "ld1h { z17.s }, p1/Z, [x21]\n" + ".inst 0x45149abb // smmla z27.s, z21.b, z20.b\n" + ".inst 0x45049ab3 // smmla z19.s, z21.b, z4.b\n" + "and z16.b, z16.b, #0xf0\n" + "ld1h { z4.s }, p0/Z, [x20]\n" + "subs x22, x22, #0x1\n" + "add x28, x28, #0x88\n" + "fcvt z17.s, p1/m, z17.h\n" + "add x25, x25, #0x90\n" + ".inst 0x45089942 // smmla z2.s, z10.b, z8.b\n" + ".inst 0x450e9959 // smmla z25.s, z10.b, z14.b\n" + "fcvt z4.s, p1/m, z4.h\n" + ".inst 0x45089afb // smmla z27.s, z23.b, z8.b\n" + ".inst 0x450e9af3 // smmla z19.s, z23.b, z14.b\n" + "fscale z17.s, p1/m, z17.s, z28.s\n" + "mov z4.q, z4.q[0]\n" + ".inst 0x45039962 // smmla z2.s, z11.b, z3.b\n" + ".inst 0x45069979 // smmla z25.s, z11.b, z6.b\n" + "fmul z23.s, z17.s, z4.s[0]\n" + "fmul z9.s, z17.s, z4.s[1]\n" + "fmul z21.s, z17.s, z4.s[2]\n" + "fmul z4.s, z17.s, z4.s[3]\n" + ".inst 0x450398fb // smmla z27.s, z7.b, z3.b\n" + ".inst 0x450698f3 // smmla z19.s, z7.b, z6.b\n" + ".inst 0x451d9a42 // smmla z2.s, z18.b, z29.b\n" + ".inst 0x45109a59 // smmla z25.s, z18.b, z16.b\n" + ".inst 0x451d9bdb // smmla z27.s, z30.b, z29.b\n" + ".inst 0x45109bd3 // smmla z19.s, z30.b, z16.b\n" + "uzp1 z31.d, z2.d, z25.d\n" + "uzp2 z13.d, z2.d, z25.d\n" + "scvtf z31.s, p1/m, z31.s\n" + "uzp1 z17.d, z27.d, z19.d\n" + "uzp2 z18.d, z27.d, z19.d\n" + "scvtf z13.s, p1/m, z13.s\n" + "fmla z24.s, p1/M, z31.s, z23.s\n" + "scvtf z17.s, p1/m, z17.s\n" + "scvtf z18.s, p1/m, z18.s\n" + "fmla z15.s, p1/M, z13.s, z9.s\n" + "fmla z12.s, p1/M, z17.s, z21.s\n" + "fmla z0.s, p1/M, z18.s, z4.s\n" + "bgt 7b\n" + "mov x20, %x[res_ptr]\n" + "cmp x13, #0x1\n" + "st1w { z24.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x13, #0x2\n" + "st1w { z15.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x13, #0x3\n" + "st1w { z12.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "st1w { z0.s }, p1, [x20]\n" + "8:" // Row tail: Accumulator store skip + "subs x24, x24, #0x8\n" + "add %x[res_ptr], %x[res_ptr], #0x20\n" + "bne 6b\n" + "subs x13, x13, #0x4\n" + "add %x[a_ptr], %x[a_ptr], x12\n" + "mov %x[res_ptr], x23\n" + "bgt 5b\n" + "9:" // Row tail: Row loop skip + : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) + : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) + : "cc", "memory", "p0", "p1", "x9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); + return; + } + else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { + GGML_ASSERT((ggml_cpu_has_sve() && (svcntw() == 8)) && + "__ARM_FEATURE_SVE for vector size of 256-bits not defined, use the Q4_0_4_8 quantization format for optimal " + "performance"); + } + else if (ggml_cpu_has_neon()) { + GGML_ASSERT(((ggml_cpu_has_sve() && (svcntw() == 8)) || ggml_cpu_has_matmul_int8()) && + "__ARM_FEATURE_SVE for vector size of 256-bits and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 " + "quantization format for optimal performance"); + } +#endif +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + GGML_ASSERT(ggml_cpu_has_sve() && + "__ARM_FEATURE_SVE not defined, use the Q4_0_4_8 quantization format for optimal performance"); +#elif defined(__ARM_NEON) && defined(__aarch64__) + GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) && + "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal " + "performance"); +#else + float sumf[4][8]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + } + sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +#endif +} diff --git a/ggml/src/ggml-aarch64.h b/ggml/src/ggml-aarch64.h new file mode 100644 index 000000000..65ead1efe --- /dev/null +++ b/ggml/src/ggml-aarch64.h @@ -0,0 +1,39 @@ +// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd. +#pragma once + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" + +#include "ggml.h" + +// GGML internal header + +#ifdef __cplusplus +extern "C" { +#endif + +// Quantization +void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); + +void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nrows, int64_t n_per_row, int64_t interleave_blcksize); + +// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization") +size_t quantize_q4_0_4x4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q4_0_4x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q4_0_8x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); + +// GEMV +void ggml_gemv_q4_0_4x4_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_0_4x8_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_0_8x8_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); + +// GEMM +void ggml_gemm_q4_0_4x4_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_0_4x8_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_0_8x8_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); + +#ifdef __cplusplus +} +#endif + diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index c74060cc4..fafd5fa7a 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -199,6 +199,30 @@ typedef struct { } block_q8_1; static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_half) + QK8_1, "wrong q8_1 block size/padding"); +typedef struct { + ggml_half d[4]; // deltas for 4 q4_0 blocks + uint8_t qs[QK4_0 * 2]; // nibbles / quants for 4 q4_0 blocks +} block_q4_0x4; +static_assert(sizeof(block_q4_0x4) == 4 * sizeof(ggml_half) + QK4_0 * 2, "wrong q4_0x4 block size/padding"); + +typedef struct { + ggml_half d[8]; // deltas for 8 q4_0 blocks + uint8_t qs[QK4_0 * 4]; // nibbles / quants for 8 q4_0 blocks +} block_q4_0x8; +static_assert(sizeof(block_q4_0x8) == 8 * sizeof(ggml_half) + QK4_0 * 4, "wrong q4_0x8 block size/padding"); + +typedef struct { + ggml_half d[4]; // deltas for 4 q8_0 blocks + int8_t qs[QK8_0 * 4]; // quants for 4 q8_0 blocks +} block_q8_0x4; +static_assert(sizeof(block_q8_0x4) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong q8_0x4 block size/padding"); + +typedef struct { + ggml_half d[8]; // deltas for 8 q8_0 blocks + int8_t qs[QK8_0 * 8]; // quants for 8 q8_0 blocks +} block_q8_0x8; +static_assert(sizeof(block_q8_0x8) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong q8_0x8 block size/padding"); + // // Super-block quantization structures // diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index 1d2336190..a2c8dbec0 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -609,6 +609,10 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { #endif // defined(__ARM_NEON) && (!defined(__MSC_VER) +#ifdef __ARM_FEATURE_SVE +#include +#endif // __ARM_FEATURE_SVE + // precomputed f32 table for f16 (256 KB) // defined in ggml.c, initialized in ggml_init() extern float ggml_table_f32_f16[1 << 16]; diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 0eb52e485..cbe377cf5 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -3814,43 +3814,47 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r } #endif #if defined(__ARM_FEATURE_SVE) - const svbool_t ptrueh = svptrue_pat_b8(SV_VL16); - const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh); + if (svcntb() == QK8_0) { + const svbool_t ptrueh = svptrue_pat_b8(SV_VL16); + const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh); - svfloat32_t sumv0 = svdup_n_f32(0.0f); - svfloat32_t sumv1 = svdup_n_f32(0.0f); + svfloat32_t sumv0 = svdup_n_f32(0.0f); + svfloat32_t sumv1 = svdup_n_f32(0.0f); - assert(nb % 2 == 0); // TODO: handle odd nb + assert(nb % 2 == 0); // TODO: handle odd nb - for (int i = 0; i < nb; i += 2) { - const block_q4_0 * restrict x0 = &x[i + 0]; - const block_q4_0 * restrict x1 = &x[i + 1]; - const block_q8_0 * restrict y0 = &y[i + 0]; - const block_q8_0 * restrict y1 = &y[i + 1]; + for (int i = 0; i < nb; i += 2) { + const block_q4_0 * restrict x0 = &x[i + 0]; + const block_q4_0 * restrict x1 = &x[i + 1]; + const block_q8_0 * restrict y0 = &y[i + 0]; + const block_q8_0 * restrict y1 = &y[i + 1]; - // load x - const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs); - const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs); + // load x + const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs); + const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs); - // 4-bit -> 8-bit - const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04)); - const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04)); + // 4-bit -> 8-bit + const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04)); + const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04)); - // sub 8 - const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8); - const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8); + // sub 8 + const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8); + const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8); - // load y - const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); - const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); + // load y + const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); + const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); - // dot product - sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + // dot product + sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } + + *s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); + return; } - - *s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); -#elif defined(__ARM_NEON) +#endif +#if defined(__ARM_NEON) float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f); @@ -5422,31 +5426,35 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r } #endif #if defined(__ARM_FEATURE_SVE) - svfloat32_t sumv0 = svdup_n_f32(0.0f); - svfloat32_t sumv1 = svdup_n_f32(0.0f); + if (svcntb() == QK8_0) { + svfloat32_t sumv0 = svdup_n_f32(0.0f); + svfloat32_t sumv1 = svdup_n_f32(0.0f); - assert(nb % 2 == 0); // TODO: handle odd nb + assert(nb % 2 == 0); // TODO: handle odd nb - for (int i = 0; i < nb; i += 2) { - const block_q8_0 * restrict x0 = &x[i + 0]; - const block_q8_0 * restrict x1 = &x[i + 1]; - const block_q8_0 * restrict y0 = &y[i + 0]; - const block_q8_0 * restrict y1 = &y[i + 1]; + for (int i = 0; i < nb; i += 2) { + const block_q8_0 * restrict x0 = &x[i + 0]; + const block_q8_0 * restrict x1 = &x[i + 1]; + const block_q8_0 * restrict y0 = &y[i + 0]; + const block_q8_0 * restrict y1 = &y[i + 1]; - // load x - const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs); - const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs); + // load x + const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs); + const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs); - // load y - const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); - const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); + // load y + const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); + const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); - sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } + + *s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); + return; } - - *s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); -#elif defined(__ARM_NEON) +#endif +#if defined(__ARM_NEON) float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f); @@ -14760,6 +14768,16 @@ static bool validate_fp16(ggml_fp16_t f, size_t i) { } \ } +#define VALIDATE_ROW_DATA_DVEC_F16_IMPL(type, data, nb, nr) \ + const type * q = (const type *) (data); \ + for (size_t i = 0; i < (nb); ++i) { \ + for (size_t j = 0; j < (nr); ++j) { \ + if (!validate_fp16(q[i].d[j], i)) { \ + return false; \ + } \ + } \ + } + bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes) { if (type < 0 || type >= GGML_TYPE_COUNT) { fprintf(stderr, "%s: invalid type %d\n", __func__, type); @@ -14977,6 +14995,16 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte { VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb); } break; + case GGML_TYPE_Q4_0_4_4: + case GGML_TYPE_Q4_0_4_8: + { + VALIDATE_ROW_DATA_DVEC_F16_IMPL(block_q4_0x4, data, nbytes / sizeof(block_q4_0x4), 4); + } break; + case GGML_TYPE_Q4_0_8_8: + { + VALIDATE_ROW_DATA_DVEC_F16_IMPL(block_q4_0x8, data, nbytes / sizeof(block_q4_0x8), 8); + } break; + case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index bc91ac3a7..c0aced3d2 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -4,6 +4,7 @@ #include "ggml-impl.h" #include "ggml-quants.h" #include "ggml.h" +#include "ggml-aarch64.h" #if defined(_MSC_VER) || defined(__MINGW32__) @@ -37,7 +38,7 @@ #include #endif -#ifdef __ARM_FEATURE_MATMUL_INT8 +#if defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_MATMUL_INT8) #undef GGML_USE_LLAMAFILE #endif @@ -692,6 +693,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { #else .nrows = 1, #endif + .from_float_to_mat = quantize_mat_q8_0, }, [GGML_TYPE_Q8_1] = { .type_name = "q8_1", @@ -889,6 +891,54 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16, .vec_dot_type = GGML_TYPE_BF16, .nrows = 1, + }, + [GGML_TYPE_Q4_0_4_4] = { + .type_name = "q4_0_4x4", + .blck_size = QK4_0, + .type_size = sizeof(block_q4_0), + .is_quantized = true, + .to_float = NULL, + .from_float = NULL, + .from_float_reference = NULL, + .vec_dot = NULL, + .vec_dot_type = GGML_TYPE_Q8_0, + .nrows = 1, + .ncols = 4, + .interleave_blcksize = 4, + .gemv = ggml_gemv_q4_0_4x4_q8_0, + .gemm = ggml_gemm_q4_0_4x4_q8_0, + }, + [GGML_TYPE_Q4_0_4_8] = { + .type_name = "q4_0_4x8", + .blck_size = QK4_0, + .type_size = sizeof(block_q4_0), + .is_quantized = true, + .to_float = NULL, + .from_float = NULL, + .from_float_reference = NULL, + .vec_dot = NULL, + .vec_dot_type = GGML_TYPE_Q8_0, + .nrows = 1, + .ncols = 4, + .interleave_blcksize = 8, + .gemv = ggml_gemv_q4_0_4x8_q8_0, + .gemm = ggml_gemm_q4_0_4x8_q8_0, + }, + [GGML_TYPE_Q4_0_8_8] = { + .type_name = "q4_0_8x8", + .blck_size = QK4_0, + .type_size = sizeof(block_q4_0), + .is_quantized = true, + .to_float = NULL, + .from_float = NULL, + .from_float_reference = NULL, + .vec_dot = NULL, + .vec_dot_type = GGML_TYPE_Q8_0, + .nrows = 1, + .ncols = 8, + .interleave_blcksize = 8, + .gemv = ggml_gemv_q4_0_8x8_q8_0, + .gemm = ggml_gemm_q4_0_8x8_q8_0, } }; @@ -3188,6 +3238,9 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break; case GGML_FTYPE_MOSTLY_IQ3_S: wtype = GGML_TYPE_IQ3_S; break; case GGML_FTYPE_MOSTLY_IQ2_S: wtype = GGML_TYPE_IQ2_S; break; + case GGML_FTYPE_MOSTLY_Q4_0_4_4: wtype = GGML_TYPE_Q4_0_4_4; break; + case GGML_FTYPE_MOSTLY_Q4_0_4_8: wtype = GGML_TYPE_Q4_0_4_8; break; + case GGML_FTYPE_MOSTLY_Q4_0_8_8: wtype = GGML_TYPE_Q4_0_8_8; break; case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break; case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break; } @@ -9432,6 +9485,9 @@ static void ggml_compute_forward_add( case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: + case GGML_TYPE_Q4_0_4_4: + case GGML_TYPE_Q4_0_4_8: + case GGML_TYPE_Q4_0_8_8: { ggml_compute_forward_add_q_f32(params, dst); } break; @@ -9807,6 +9863,9 @@ static void ggml_compute_forward_add1( case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: + case GGML_TYPE_Q4_0_4_4: + case GGML_TYPE_Q4_0_4_8: + case GGML_TYPE_Q4_0_8_8: { ggml_compute_forward_add1_q_f32(params, dst); } break; @@ -9932,6 +9991,9 @@ static void ggml_compute_forward_acc( case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: + case GGML_TYPE_Q4_0_4_4: + case GGML_TYPE_Q4_0_4_8: + case GGML_TYPE_Q4_0_8_8: default: { GGML_ASSERT(false); @@ -12134,6 +12196,12 @@ static void ggml_compute_forward_mul_mat( enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float; int64_t const vec_dot_num_rows = type_traits[type].nrows; + int64_t const matmul_num_cols = type_traits[type].ncols; + int64_t const interleave_blcksize = type_traits[type].interleave_blcksize; + ggml_from_float_to_mat_t const from_float_to_mat + = type_traits[vec_dot_type].from_float_to_mat; + ggml_gemv_t const gemv = type_traits[type].gemv; + ggml_gemm_t const gemm = type_traits[type].gemm; GGML_ASSERT(ne0 == ne01); GGML_ASSERT(ne1 == ne11); @@ -12192,7 +12260,16 @@ UseGgmlGemm1:; for (int64_t i13 = 0; i13 < ne13; ++i13) { for (int64_t i12 = 0; i12 < ne12; ++i12) { - for (int64_t i11 = ith; i11 < ne11; i11 += nth) { + int64_t i11_processed = 0; + if ((ggml_n_dims(src1) == 2) && from_float_to_mat && gemm) { + for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) { + from_float_to_mat((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), + (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), + 4, ne10, interleave_blcksize); + } + i11_processed = ne11 - ne11 % 4; + } + for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) { from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), ne10); @@ -12273,6 +12350,28 @@ UseGgmlGemm2:; const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0; const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1; + if ((ggml_n_dims(src0) == 2) && gemv) { + const void * src1_wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; + const size_t src1_col_stride = ggml_is_contiguous(src1) || src1->type != vec_dot_type ? ggml_row_size(vec_dot_type, ne10) : nb11; + int64_t src0_start = (ith * ne01) / nth; + int64_t src0_end = ((ith + 1) * ne01) / nth; + src0_start = (src0_start % matmul_num_cols) ? src0_start + matmul_num_cols - (src0_start % matmul_num_cols): src0_start; + src0_end = (src0_end % matmul_num_cols) ? src0_end + matmul_num_cols - (src0_end % matmul_num_cols): src0_end; + if (src0_start >= src0_end) return; + + // If there are more than three rows in src1, use gemm; otherwise, use gemv. + if (gemm && (ne11 > 3)) { + gemm(ne00, (float *)((char *) dst->data) + src0_start, ne01, (const char *) src0->data + src0_start * nb01, + (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start); + } + for (int iter = gemm ? ne11 - ne11 % 4 : 0; iter < ne11; iter++) { + gemv(ne00, (float *)((char *) dst->data + (iter * nb1)) + src0_start, ne01, + (const char *) src0->data + src0_start * nb01, (const char *) src1_wdata + (src1_col_stride * iter), 1, + src0_end - src0_start); + } + return; + } + // The first chunk comes from our thread_id, the rest will get auto-assigned. int current_chunk = ith; @@ -12318,6 +12417,8 @@ static void ggml_compute_forward_mul_mat_id( ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot; enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float; + int64_t const matmul_num_cols = type_traits[type].ncols; + ggml_gemv_t const gemv = type_traits[type].gemv; // we don't support permuted src0 or src1 GGML_ASSERT(nb00 == ggml_type_size(type)); @@ -12403,6 +12504,34 @@ static void ggml_compute_forward_mul_mat_id( const int64_t nr0 = ne01; // src0 rows const int64_t nr1 = cne1; // src1 rows + if (((ggml_n_dims(src0) - 1) == 2) && gemv) { + int64_t src0_cur_start = (ith * ne01) / nth; + int64_t src0_cur_end = ((ith + 1) * ne01) / nth; + src0_cur_start = (src0_cur_start % matmul_num_cols) ? src0_cur_start + matmul_num_cols - (src0_cur_start % matmul_num_cols): src0_cur_start; + src0_cur_end = (src0_cur_end % matmul_num_cols) ? src0_cur_end + matmul_num_cols - (src0_cur_end % matmul_num_cols): src0_cur_end; + if (src0_cur_start >= src0_cur_end) return; + + for (int ir1 = 0; ir1 < nr1; ir1++) { + struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1); + const int id = row_mapping.i1; // selected expert index + + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; // row index in src1 + + const int64_t i1 = id; // selected expert index + const int64_t i2 = i12; // row + + const char * src1_col = (const char *) wdata + + (src1_cont || src1->type != vec_dot_type + ? (i11 + i12 * ne11) * row_size + : (i11 * nb11 + i12 * nb12)); + + gemv(ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01, + (const char *) src0_cur + src0_cur_start * nb01, src1_col, 1, src0_cur_end - src0_cur_start); + } + continue; + } + // distribute the thread work across the inner or outer loop based on which one is larger const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows @@ -12704,6 +12833,9 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: + case GGML_TYPE_Q4_0_4_4: + case GGML_TYPE_Q4_0_4_8: + case GGML_TYPE_Q4_0_8_8: { ggml_compute_forward_out_prod_q_f32(params, dst); } break; @@ -12889,6 +13021,9 @@ static void ggml_compute_forward_set( case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: + case GGML_TYPE_Q4_0_4_4: + case GGML_TYPE_Q4_0_4_8: + case GGML_TYPE_Q4_0_8_8: default: { GGML_ASSERT(false); @@ -13148,6 +13283,9 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: + case GGML_TYPE_Q4_0_4_4: + case GGML_TYPE_Q4_0_4_8: + case GGML_TYPE_Q4_0_8_8: { ggml_compute_forward_get_rows_q(params, dst); } break; @@ -13734,6 +13872,9 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: case GGML_TYPE_Q8_K: + case GGML_TYPE_Q4_0_4_4: + case GGML_TYPE_Q4_0_4_8: + case GGML_TYPE_Q4_0_8_8: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -20457,6 +20598,9 @@ size_t ggml_quantize_chunk( case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q4_0_4_4: result = quantize_q4_0_4x4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q4_0_4_8: result = quantize_q4_0_4x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q4_0_8_8: result = quantize_q4_0_8x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_F16: { size_t elemsize = sizeof(ggml_fp16_t); @@ -21759,8 +21903,6 @@ int ggml_cpu_has_neon(void) { int ggml_cpu_has_sve(void) { #if defined(__ARM_FEATURE_SVE) - // TODO: Currently, SVE 256 bit is only supported. - GGML_ASSERT(svcntb() == QK8_0); return 1; #else return 0; diff --git a/include/llama.h b/include/llama.h index bb4b05ba6..3970c3aeb 100644 --- a/include/llama.h +++ b/include/llama.h @@ -162,6 +162,9 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_0_4_4 = 33, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_0_4_8 = 34, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; diff --git a/src/llama.cpp b/src/llama.cpp index 98a242081..b19d786e2 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3788,6 +3788,9 @@ struct llama_model_loader { case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break; case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break; case GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break; + case GGML_TYPE_Q4_0_4_4: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_4; break; + case GGML_TYPE_Q4_0_4_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_8; break; + case GGML_TYPE_Q4_0_8_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_8_8; break; default: { LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max)); @@ -4481,6 +4484,9 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw"; + case LLAMA_FTYPE_MOSTLY_Q4_0_4_4: return "Q4_0_4_4"; + case LLAMA_FTYPE_MOSTLY_Q4_0_4_8: return "Q4_0_4_8"; + case LLAMA_FTYPE_MOSTLY_Q4_0_8_8: return "Q4_0_8_8"; default: return "unknown, may not work"; } @@ -17768,6 +17774,10 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) { new_type = GGML_TYPE_IQ3_S; } + else if (new_type == GGML_TYPE_Q4_0_4_4 || new_type == GGML_TYPE_Q4_0_4_8 || + new_type == GGML_TYPE_Q4_0_8_8) { + new_type = GGML_TYPE_Q4_0; + } } } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) { @@ -18080,6 +18090,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break; case LLAMA_FTYPE_MOSTLY_IQ3_S: default_type = GGML_TYPE_IQ3_S; break; case LLAMA_FTYPE_MOSTLY_IQ3_M: default_type = GGML_TYPE_IQ3_S; break; + case LLAMA_FTYPE_MOSTLY_Q4_0_4_4: default_type = GGML_TYPE_Q4_0_4_4; break; + case LLAMA_FTYPE_MOSTLY_Q4_0_4_8: default_type = GGML_TYPE_Q4_0_4_8; break; + case LLAMA_FTYPE_MOSTLY_Q4_0_8_8: default_type = GGML_TYPE_Q4_0_8_8; break; default: throw std::runtime_error(format("invalid output file type %d\n", ftype)); } @@ -18390,6 +18403,14 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s f32_data = (float *) f32_conv_buf.data(); } + int chunk_size_multiplier = 1; + if (new_type == GGML_TYPE_Q4_0_4_4 || new_type == GGML_TYPE_Q4_0_4_8 || new_type == GGML_TYPE_Q4_0_8_8) { + if ((new_type == GGML_TYPE_Q4_0_8_8) && (tensor->ne[1] % 8 != 0)) new_type = GGML_TYPE_Q4_0; + else if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q4_0; + if (new_type == GGML_TYPE_Q4_0_8_8) chunk_size_multiplier = 8; + else if (new_type == GGML_TYPE_Q4_0_4_4 || new_type == GGML_TYPE_Q4_0_4_8) chunk_size_multiplier = 4; + } + LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type)); fflush(stdout); @@ -18402,7 +18423,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s const int64_t nrows = tensor->ne[1]; static const int64_t min_chunk_size = 32 * 512; - const int64_t chunk_size = n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row); + const int64_t chunk_size = (n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row)) * + chunk_size_multiplier; const int64_t nelements_matrix = tensor->ne[0] * tensor->ne[1]; const int64_t nchunk = (nelements_matrix + chunk_size - 1)/chunk_size; From 6b2a849d1f43d46b82d2f9c08c3275137b528784 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 10 Jul 2024 15:23:29 +0300 Subject: [PATCH 08/26] ggml : move sgemm sources to llamafile subfolder (#8394) ggml-ci --- Makefile | 8 ++++---- ggml/CMakeLists.txt | 2 +- ggml/src/CMakeLists.txt | 6 +++--- ggml/src/ggml.c | 3 +-- ggml/src/{ => llamafile}/sgemm.cpp | 0 ggml/src/{ => llamafile}/sgemm.h | 0 6 files changed, 9 insertions(+), 10 deletions(-) rename ggml/src/{ => llamafile}/sgemm.cpp (100%) rename ggml/src/{ => llamafile}/sgemm.h (100%) diff --git a/Makefile b/Makefile index b70ebaed5..68197fef8 100644 --- a/Makefile +++ b/Makefile @@ -554,7 +554,7 @@ endif # GGML_BLIS ifndef GGML_NO_LLAMAFILE MK_CPPFLAGS += -DGGML_USE_LLAMAFILE - OBJ_GGML += ggml/src/sgemm.o + OBJ_GGML += ggml/src/llamafile/sgemm.o endif ifdef GGML_RPC @@ -983,9 +983,9 @@ ggml/src/ggml-blas.o: \ $(CXX) $(CXXFLAGS) -c $< -o $@ ifndef GGML_NO_LLAMAFILE -ggml/src/sgemm.o: \ - ggml/src/sgemm.cpp \ - ggml/src/sgemm.h \ +ggml/src/llamafile/sgemm.o: \ + ggml/src/llamafile/sgemm.cpp \ + ggml/src/llamafile/sgemm.h \ ggml/include/ggml.h $(CXX) $(CXXFLAGS) -c $< -o $@ endif # GGML_NO_LLAMAFILE diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 0d0d52d57..649ac3dcc 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -104,7 +104,7 @@ option(GGML_ACCELERATE "ggml: enable Accelerate framework" option(GGML_BLAS "ggml: use BLAS" ${GGML_BLAS_DEFAULT}) set(GGML_BLAS_VENDOR ${GGML_BLAS_VENDOR_DEFAULT} CACHE STRING "ggml: BLAS library vendor") -option(GGML_LLAMAFILE "ggml: use ggml SGEMM" OFF) +option(GGML_LLAMAFILE "ggml: use LLAMAFILE" OFF) option(GGML_CUDA "ggml: use CUDA" OFF) option(GGML_CUDA_FORCE_DMMV "ggml: use dmmv instead of mmvq CUDA kernels" OFF) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index aae5b8e9f..c5ee7e425 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -238,12 +238,12 @@ if (GGML_BLAS) endif() if (GGML_LLAMAFILE) - message(STATUS "Using ggml SGEMM") + message(STATUS "Using llamafile") add_compile_definitions(GGML_USE_LLAMAFILE) - set(GGML_HEADERS_LLAMAFILE sgemm.h) - set(GGML_SOURCES_LLAMAFILE sgemm.cpp) + set(GGML_HEADERS_LLAMAFILE llamafile/sgemm.h) + set(GGML_SOURCES_LLAMAFILE llamafile/sgemm.cpp) endif() if (GGML_CUDA) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index c0aced3d2..1bb731e16 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -6,7 +6,6 @@ #include "ggml.h" #include "ggml-aarch64.h" - #if defined(_MSC_VER) || defined(__MINGW32__) #include // using malloc.h with MSC/MINGW #elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__) @@ -43,7 +42,7 @@ #endif #ifdef GGML_USE_LLAMAFILE -#include "sgemm.h" +#include #endif #if defined(_MSC_VER) diff --git a/ggml/src/sgemm.cpp b/ggml/src/llamafile/sgemm.cpp similarity index 100% rename from ggml/src/sgemm.cpp rename to ggml/src/llamafile/sgemm.cpp diff --git a/ggml/src/sgemm.h b/ggml/src/llamafile/sgemm.h similarity index 100% rename from ggml/src/sgemm.h rename to ggml/src/llamafile/sgemm.h From f4444d992c16b6b9442f4770c7c3a10b19a08343 Mon Sep 17 00:00:00 2001 From: AidanBeltonS Date: Wed, 10 Jul 2024 16:10:49 +0100 Subject: [PATCH 09/26] [SYCL] Use multi_ptr to clean up deprecated warnings (#8256) --- ggml/src/ggml-sycl/common.hpp | 6 ++ ggml/src/ggml-sycl/convert.cpp | 2 +- ggml/src/ggml-sycl/mmq.cpp | 184 ++++++++++++++++----------------- ggml/src/ggml-sycl/norm.cpp | 6 +- ggml/src/ggml-sycl/softmax.cpp | 2 +- 5 files changed, 103 insertions(+), 97 deletions(-) diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 9a1c161b6..68d41411b 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -346,4 +346,10 @@ inline sycl::vec vec_aligned_load(const Tp* aligned_ptr) { return *reinterpret_cast*>(aligned_ptr); } +// Helper for accessing pointers with no warnings +template +static __dpct_inline__ Tp* get_pointer(sycl::local_accessor acc) { + return acc.template get_multi_ptr().get(); +} + #endif // GGML_SYCL_COMMON_HPP diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index a15271b51..39c28753c 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -158,7 +158,7 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k, sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)), [=](sycl::nd_item<3> item_ct1) { - dequantize_block_q4_K(vx, y, scale_local_acc.get_pointer(), item_ct1); + dequantize_block_q4_K(vx, y, get_pointer(scale_local_acc), item_ct1); }); }); } diff --git a/ggml/src/ggml-sycl/mmq.cpp b/ggml/src/ggml-sycl/mmq.cpp index b514f0040..3107ba919 100644 --- a/ggml/src/ggml-sycl/mmq.cpp +++ b/ggml/src/ggml-sycl/mmq.cpp @@ -1835,10 +1835,10 @@ static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy, mul_mat_q4_0( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, - tile_x_qs_q4_0_acc_ct1.get_pointer(), - tile_x_d_q4_0_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); + get_pointer(tile_x_qs_q4_0_acc_ct1), + get_pointer(tile_x_d_q4_0_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); }); }); } @@ -1870,10 +1870,10 @@ static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy, mul_mat_q4_0( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, - tile_x_qs_q4_0_acc_ct1.get_pointer(), - tile_x_d_q4_0_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); + get_pointer(tile_x_qs_q4_0_acc_ct1), + get_pointer(tile_x_d_q4_0_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); }); }); } @@ -1950,10 +1950,10 @@ static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy, mul_mat_q4_1( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, - tile_x_qs_q4_1_acc_ct1.get_pointer(), - tile_x_dm_q4_1_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); + get_pointer(tile_x_qs_q4_1_acc_ct1), + get_pointer(tile_x_dm_q4_1_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); }); }); } @@ -1985,10 +1985,10 @@ static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy, mul_mat_q4_1( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, - tile_x_qs_q4_1_acc_ct1.get_pointer(), - tile_x_dm_q4_1_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); + get_pointer(tile_x_qs_q4_1_acc_ct1), + get_pointer(tile_x_dm_q4_1_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); }); }); } @@ -2065,10 +2065,10 @@ static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy, mul_mat_q5_0( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, - tile_x_ql_q5_0_acc_ct1.get_pointer(), - tile_x_d_q5_0_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); + get_pointer(tile_x_ql_q5_0_acc_ct1), + get_pointer(tile_x_d_q5_0_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); }); }); } @@ -2100,10 +2100,10 @@ static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy, mul_mat_q5_0( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, - tile_x_ql_q5_0_acc_ct1.get_pointer(), - tile_x_d_q5_0_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); + get_pointer(tile_x_ql_q5_0_acc_ct1), + get_pointer(tile_x_d_q5_0_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); }); }); } @@ -2180,10 +2180,10 @@ static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy, mul_mat_q5_1( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, - tile_x_ql_q5_1_acc_ct1.get_pointer(), - tile_x_dm_q5_1_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); + get_pointer(tile_x_ql_q5_1_acc_ct1), + get_pointer(tile_x_dm_q5_1_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); }); }); } @@ -2215,10 +2215,10 @@ static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy, mul_mat_q5_1( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, - tile_x_ql_q5_1_acc_ct1.get_pointer(), - tile_x_dm_q5_1_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); + get_pointer(tile_x_ql_q5_1_acc_ct1), + get_pointer(tile_x_dm_q5_1_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); }); }); } @@ -2295,10 +2295,10 @@ static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy, mul_mat_q8_0( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, - tile_x_qs_q8_0_acc_ct1.get_pointer(), - tile_x_d_q8_0_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); + get_pointer(tile_x_qs_q8_0_acc_ct1), + get_pointer(tile_x_d_q8_0_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); }); }); } @@ -2330,10 +2330,10 @@ static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy, mul_mat_q8_0( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, - tile_x_qs_q8_0_acc_ct1.get_pointer(), - tile_x_d_q8_0_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); + get_pointer(tile_x_qs_q8_0_acc_ct1), + get_pointer(tile_x_d_q8_0_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); }); }); } @@ -2412,11 +2412,11 @@ static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy, mul_mat_q2_K( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, - tile_x_ql_q2_K_acc_ct1.get_pointer(), - tile_x_dm_q2_K_acc_ct1.get_pointer(), - tile_x_sc_q2_K_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); + get_pointer(tile_x_ql_q2_K_acc_ct1), + get_pointer(tile_x_dm_q2_K_acc_ct1), + get_pointer(tile_x_sc_q2_K_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); }); }); } @@ -2450,11 +2450,11 @@ static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy, mul_mat_q2_K( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, - tile_x_ql_q2_K_acc_ct1.get_pointer(), - tile_x_dm_q2_K_acc_ct1.get_pointer(), - tile_x_sc_q2_K_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); + get_pointer(tile_x_ql_q2_K_acc_ct1), + get_pointer(tile_x_dm_q2_K_acc_ct1), + get_pointer(tile_x_sc_q2_K_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); }); }); } @@ -2537,12 +2537,12 @@ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy, mul_mat_q3_K( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, - tile_x_ql_q3_K_acc_ct1.get_pointer(), - tile_x_dm_q3_K_acc_ct1.get_pointer(), - tile_x_qh_q3_K_acc_ct1.get_pointer(), - tile_x_sc_q3_K_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); + get_pointer(tile_x_ql_q3_K_acc_ct1), + get_pointer(tile_x_dm_q3_K_acc_ct1), + get_pointer(tile_x_qh_q3_K_acc_ct1), + get_pointer(tile_x_sc_q3_K_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); }); }); } @@ -2578,12 +2578,12 @@ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy, mul_mat_q3_K( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, - tile_x_ql_q3_K_acc_ct1.get_pointer(), - tile_x_dm_q3_K_acc_ct1.get_pointer(), - tile_x_qh_q3_K_acc_ct1.get_pointer(), - tile_x_sc_q3_K_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); + get_pointer(tile_x_ql_q3_K_acc_ct1), + get_pointer(tile_x_dm_q3_K_acc_ct1), + get_pointer(tile_x_qh_q3_K_acc_ct1), + get_pointer(tile_x_sc_q3_K_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); }); }); } @@ -2663,11 +2663,11 @@ static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy, mul_mat_q4_K( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, - tile_x_ql_q4_K_acc_ct1.get_pointer(), - tile_x_dm_q4_K_acc_ct1.get_pointer(), - tile_x_sc_q4_K_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); + get_pointer(tile_x_ql_q4_K_acc_ct1), + get_pointer(tile_x_dm_q4_K_acc_ct1), + get_pointer(tile_x_sc_q4_K_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); }); }); } @@ -2701,11 +2701,11 @@ static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy, mul_mat_q4_K( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, - tile_x_ql_q4_K_acc_ct1.get_pointer(), - tile_x_dm_q4_K_acc_ct1.get_pointer(), - tile_x_sc_q4_K_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); + get_pointer(tile_x_ql_q4_K_acc_ct1), + get_pointer(tile_x_dm_q4_K_acc_ct1), + get_pointer(tile_x_sc_q4_K_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); }); }); } @@ -2784,11 +2784,11 @@ static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy, mul_mat_q5_K( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, - tile_x_ql_q5_K_acc_ct1.get_pointer(), - tile_x_dm_q5_K_acc_ct1.get_pointer(), - tile_x_sc_q5_K_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); + get_pointer(tile_x_ql_q5_K_acc_ct1), + get_pointer(tile_x_dm_q5_K_acc_ct1), + get_pointer(tile_x_sc_q5_K_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); }); }); } @@ -2822,11 +2822,11 @@ static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy, mul_mat_q5_K( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, - tile_x_ql_q5_K_acc_ct1.get_pointer(), - tile_x_dm_q5_K_acc_ct1.get_pointer(), - tile_x_sc_q5_K_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); + get_pointer(tile_x_ql_q5_K_acc_ct1), + get_pointer(tile_x_dm_q5_K_acc_ct1), + get_pointer(tile_x_sc_q5_K_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); }); }); } @@ -2905,11 +2905,11 @@ static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy, mul_mat_q6_K( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, - tile_x_ql_acc_ct1.get_pointer(), - tile_x_dm_acc_ct1.get_pointer(), - tile_x_sc_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); + get_pointer(tile_x_ql_acc_ct1), + get_pointer(tile_x_dm_acc_ct1), + get_pointer(tile_x_sc_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); }); }); } @@ -2943,11 +2943,11 @@ static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy, mul_mat_q6_K( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, - tile_x_ql_acc_ct1.get_pointer(), - tile_x_dm_acc_ct1.get_pointer(), - tile_x_sc_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); + get_pointer(tile_x_ql_acc_ct1), + get_pointer(tile_x_dm_acc_ct1), + get_pointer(tile_x_sc_acc_ct1), + get_pointer(tile_y_qs_acc_ct1), + get_pointer(tile_y_ds_acc_ct1)); }); }); } diff --git a/ggml/src/ggml-sycl/norm.cpp b/ggml/src/ggml-sycl/norm.cpp index e0c5dfeca..cccf87d06 100644 --- a/ggml/src/ggml-sycl/norm.cpp +++ b/ggml/src/ggml-sycl/norm.cpp @@ -218,7 +218,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols, [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { norm_f32(x, dst, ncols, eps, item_ct1, - s_sum_acc_ct1.get_pointer(), work_group_size); + get_pointer(s_sum_acc_ct1), work_group_size); }); }); } @@ -265,7 +265,7 @@ static void group_norm_f32_sycl(const float* x, float* dst, [[intel::reqd_sub_group_size(WARP_SIZE)]] { group_norm_f32(x, dst, group_size, ne_elements, eps_ct4, item_ct1, - s_sum_acc_ct1.get_pointer(), work_group_size); + get_pointer(s_sum_acc_ct1), work_group_size); }); }); } @@ -306,7 +306,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { rms_norm_f32(x, dst, ncols, eps, item_ct1, - s_sum_acc_ct1.get_pointer(), work_group_size); + get_pointer(s_sum_acc_ct1), work_group_size); }); }); } diff --git a/ggml/src/ggml-sycl/softmax.cpp b/ggml/src/ggml-sycl/softmax.cpp index e624b6ba3..c5d9a837e 100644 --- a/ggml/src/ggml-sycl/softmax.cpp +++ b/ggml/src/ggml-sycl/softmax.cpp @@ -136,7 +136,7 @@ static void soft_max_f32_submitter(const float * x, const float * mask, float * soft_max_f32(x, mask, dst, ncols_par, nrows_y, scale, max_bias, m0, m1, n_head_log2, item_ct1, - local_buf_acc.get_pointer()); + get_pointer(local_buf_acc)); }); }); } From dd07a123b79f9bd9e8a4ba0447427b3083e9347a Mon Sep 17 00:00:00 2001 From: Clint Herron Date: Wed, 10 Jul 2024 12:35:18 -0400 Subject: [PATCH 10/26] Name Migration: Build the deprecation-warning 'main' binary every time (#8404) * Modify the deprecation-warning 'main' binary to build every time, instead of only when a legacy binary is present. This is to help users of tutorials and other instruction sets from knowing what to do when the 'main' binary is missing and they are trying to follow instructions. * Adjusting 'server' name-deprecation binary to build all the time, similar to the 'main' legacy name binary. --- Makefile | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/Makefile b/Makefile index 68197fef8..668b38b99 100644 --- a/Makefile +++ b/Makefile @@ -1513,15 +1513,17 @@ llama-q8dot: pocs/vdot/q8dot.cpp ggml/src/ggml.o \ # Mark legacy binary targets as .PHONY so that they are always checked. .PHONY: main quantize perplexity embedding server finetune +# NOTE: We currently will always build the deprecation-warning `main` and `server` binaries to help users migrate. +# Eventually we will want to remove these target from building all the time. main: examples/deprecation-warning/deprecation-warning.cpp -ifneq (,$(wildcard main)) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) - @echo "#########" - @echo "WARNING: The 'main' binary is deprecated. Please use 'llama-cli' instead." - @echo " Remove the 'main' binary to remove this warning." - @echo "#########" -endif + @echo "NOTICE: The 'main' binary is deprecated. Please use 'llama-cli' instead." + +server: examples/deprecation-warning/deprecation-warning.cpp + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + @echo "NOTICE: The 'server' binary is deprecated. Please use 'llama-server' instead." quantize: examples/deprecation-warning/deprecation-warning.cpp ifneq (,$(wildcard quantize)) @@ -1553,16 +1555,6 @@ ifneq (,$(wildcard embedding)) @echo "#########" endif -server: examples/deprecation-warning/deprecation-warning.cpp -ifneq (,$(wildcard server)) - $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) - $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) - @echo "#########" - @echo "WARNING: The 'server' binary is deprecated. Please use 'llama-server' instead." - @echo " Remove the 'server' binary to remove this warning." - @echo "#########" -endif - finetune: examples/deprecation-warning/deprecation-warning.cpp ifneq (,$(wildcard finetune)) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) From 278d0e18469aacf505be18ce790a63c7cc31be26 Mon Sep 17 00:00:00 2001 From: Clint Herron Date: Wed, 10 Jul 2024 20:08:17 -0400 Subject: [PATCH 11/26] Initialize default slot sampling parameters from the global context. (#8418) --- examples/server/server.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 8feff6702..efd426289 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -737,6 +737,8 @@ struct server_context { slot.ga_n = ga_n; slot.ga_w = ga_w; + slot.sparams = params.sparams; + slot.reset(); slots.push_back(slot); From 7a221b672e49dfae459b1af27210ba3f2b5419b6 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 11 Jul 2024 10:21:30 +0300 Subject: [PATCH 12/26] llama : use F32 precision in Qwen2 attention and no FA (#8412) --- src/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama.cpp b/src/llama.cpp index b19d786e2..ed77ed918 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -8134,7 +8134,7 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); cb(kq, "kq", il); - if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) { + if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) { // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 ggml_mul_mat_set_prec(kq, GGML_PREC_F32); From 9a55ffe6fbf7f19c865f8f277ca7cd585cc0b094 Mon Sep 17 00:00:00 2001 From: compilade Date: Thu, 11 Jul 2024 03:41:48 -0400 Subject: [PATCH 13/26] tokenize : add --no-parse-special option (#8423) This should allow more easily explaining how parse_special affects tokenization. --- examples/tokenize/tokenize.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/examples/tokenize/tokenize.cpp b/examples/tokenize/tokenize.cpp index 0180c87d8..2afb6024c 100644 --- a/examples/tokenize/tokenize.cpp +++ b/examples/tokenize/tokenize.cpp @@ -29,6 +29,7 @@ static void print_usage_information(const char * argv0, FILE * stream) { fprintf(stream, " -p PROMPT, --prompt PROMPT read prompt from the argument.\n"); fprintf(stream, " --stdin read prompt from standard input.\n"); fprintf(stream, " --no-bos do not ever add a BOS token to the prompt, even if normally the model uses a BOS token.\n"); + fprintf(stream, " --no-parse-special do not parse control tokens.\n"); fprintf(stream, " --log-disable disable logs. Makes stderr quiet when loading the model.\n"); fprintf(stream, " --show-count print the total number of tokens.\n"); } @@ -195,6 +196,7 @@ int main(int raw_argc, char ** raw_argv) { // variables where to put any arguments we see. bool printing_ids = false; bool no_bos = false; + bool no_parse_special = false; bool disable_logging = false; bool show_token_count = false; const char * model_path = NULL; @@ -229,6 +231,9 @@ int main(int raw_argc, char ** raw_argv) { else if (arg == "--no-bos") { no_bos = true; } + else if (arg == "--no-parse-special") { + no_parse_special = true; + } else if (arg == "-p" || arg == "--prompt") { if (prompt_set) { fprintf(stderr, "Error: -p or --prompt specified multiple times.\n"); @@ -359,9 +364,10 @@ int main(int raw_argc, char ** raw_argv) { const bool model_wants_add_bos = llama_should_add_bos_token(model); const bool add_bos = model_wants_add_bos && !no_bos; + const bool parse_special = !no_parse_special; std::vector tokens; - tokens = ::llama_tokenize(model, prompt, add_bos, true); + tokens = ::llama_tokenize(model, prompt, add_bos, parse_special); if (printing_ids) { printf("["); From a977c115448e40856fb9cbe3ceb6d8ce802553b0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 11 Jul 2024 11:20:40 +0300 Subject: [PATCH 14/26] gitignore : deprecated binaries --- .gitignore | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.gitignore b/.gitignore index 4866f6122..7c7dee0c6 100644 --- a/.gitignore +++ b/.gitignore @@ -61,6 +61,11 @@ llama-batched-swift out/ tmp/ +# Deprecated + +/main +/server + # CI !.github/workflows/*.yml From 808aba39161e5d7ca2ff24110b5aa14d2e536988 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Thu, 11 Jul 2024 16:47:47 +0200 Subject: [PATCH 15/26] CUDA: optimize and refactor MMQ (#8416) * CUDA: optimize and refactor MMQ * explicit q8_1 memory layouts, add documentation --- ggml/src/ggml-cuda/mma.cuh | 4 + ggml/src/ggml-cuda/mmq.cuh | 1365 ++++++++++++++++--------------- ggml/src/ggml-cuda/quantize.cu | 107 ++- ggml/src/ggml-cuda/quantize.cuh | 6 +- ggml/src/ggml-cuda/vecdotq.cuh | 72 +- 5 files changed, 867 insertions(+), 687 deletions(-) diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 5d87dd8e6..a452a3cc3 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -70,6 +70,10 @@ struct mma_int_A_I16K8 { } #endif // defined(INT8_MMA_AVAILABLE) } + + __device__ __forceinline__ void load_low(const int * __restrict__ xs0, const int & stride) { + ((mma_int_A_I16K4 *) x)[0].load(xs0, stride); + } }; struct mma_int_B_J8K4 { diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 118e34d28..51c44d857 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -8,18 +8,70 @@ #include #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available. +#define MMQ_ITER_K 256 +#define MMQ_NWARPS 8 typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int & kbx0, const int & i_max, const int & stride); -typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0); +typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00); typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max); +enum mmq_q8_1_ds_layout { + MMQ_Q8_1_DS_LAYOUT_D4, + MMQ_Q8_1_DS_LAYOUT_DS4, + MMQ_Q8_1_DS_LAYOUT_D2S6, +}; + struct block_q8_1_mmq { - half2 ds[4]; - int8_t qs[4*QK8_1]; + // The y float data is converted to a data layout that can simply be copied to shared memory as a contiguous block. + // The y float data is first grouped as blocks of 128 values. + // These blocks are then treated as individual data values and transposed. + // + // To avoid shared memory bank conflicts each block is padded with 16 bytes. + // This padding is also used to store block scales/partial sums. + // The scales multiplied with the quantized data are equal to the unquantized values. + // The partial sums are obtained by summing up a subgroup of the contained values (prior to quantization) + // and are only needed for performance reasons. + // + // The exact data stored depends on the x data type. + union { + float d4[4]; // 1 32 bit scale per 32 values, stored as d0,d1,d2,d3 + half2 ds4[4]; // 1 16 bit scale + 1 16 bit partial sum per 32 values, stored as d0,s0,d1,s1,d2,s2,d3,s3 + half d2s6[8]; // 1 16 bit scale per 64 values + 1 16 bit partial sum per 16 values for the first 96 values, + // stored as d0,d1,s1,s2,s3,s4,s5 + }; + int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each }; static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size"); static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size"); +static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { + switch (type_x) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + return MMQ_Q8_1_DS_LAYOUT_DS4; + case GGML_TYPE_Q5_0: + return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_Q5_1: + return MMQ_Q8_1_DS_LAYOUT_DS4; + case GGML_TYPE_Q8_0: + return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_Q2_K: + return MMQ_Q8_1_DS_LAYOUT_D2S6; + case GGML_TYPE_Q3_K: + return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + return MMQ_Q8_1_DS_LAYOUT_DS4; + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + return MMQ_Q8_1_DS_LAYOUT_D4; + default: + GGML_ASSERT(false); + break; + } +} + struct tile_x_sizes { int qs; int dm; @@ -79,49 +131,46 @@ static constexpr __device__ int get_mmq_y_device() { #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) } -#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0} -#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0} -#define MMQ_DP4A_TXS_Q5_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_0 + mmq_y/QI5_0, 0} -#define MMQ_DP4A_TXS_Q5_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_1 + mmq_y/QI5_1, 0} -#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI8_0 + mmq_y/QI8_0, 0} -#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0} -#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI3_K + mmq_y/QI3_K, mmq_y*WARP_SIZE/4 + mmq_y/4} -#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8} -#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8} -#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0} +#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0} +#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0} +#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_1 + mmq_y/(QI8_1/2), 0} +#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0} +#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y, mmq_y*WARP_SIZE/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8} static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) { return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 : type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 : - type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q5_0 : - type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q5_1 : + type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q8_0 : + type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q8_1 : type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 : type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K : type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K : type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K : type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K : type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K : - type == GGML_TYPE_IQ4_XS ? MMQ_DP4A_TXS_Q5_0 : - type == GGML_TYPE_IQ4_NL ? MMQ_DP4A_TXS_Q5_0 : + type == GGML_TYPE_IQ4_XS ? MMQ_DP4A_TXS_Q8_0 : + type == GGML_TYPE_IQ4_NL ? MMQ_DP4A_TXS_Q8_0 : tile_x_sizes{0, 0, 0}; } -#define MMQ_MMA_TILE_X_K_Q4_0 (1*WARP_SIZE + WARP_SIZE/QI4_0 + 4) -#define MMQ_MMA_TILE_X_K_Q4_1 (1*WARP_SIZE + WARP_SIZE/QI4_1 + 4) -#define MMQ_MMA_TILE_X_K_Q5_0 (2*WARP_SIZE + WARP_SIZE/QI5_0 + 4) -#define MMQ_MMA_TILE_X_K_Q5_1 (2*WARP_SIZE + WARP_SIZE/QI5_1 + 4) -#define MMQ_MMA_TILE_X_K_Q8_0 (1*WARP_SIZE + WARP_SIZE/QI8_0 + 0) -#define MMQ_MMA_TILE_X_K_Q2_K (1*WARP_SIZE + WARP_SIZE + 4) -#define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/QI3_K + WARP_SIZE/4 + 2) -#define MMQ_MMA_TILE_X_K_Q4_K (1*WARP_SIZE + WARP_SIZE/QI4_K + WARP_SIZE/8 + 7) -#define MMQ_MMA_TILE_X_K_Q5_K (2*WARP_SIZE + WARP_SIZE/QI5_K + WARP_SIZE/8 + 7) -#define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K + WARP_SIZE/8 + 7) +#define MMQ_MMA_TILE_X_K_Q4_0 (1*WARP_SIZE + WARP_SIZE/QI4_0 + 4) +#define MMQ_MMA_TILE_X_K_Q4_1 (1*WARP_SIZE + WARP_SIZE/QI4_1 + 4) +#define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4) +#define MMQ_MMA_TILE_X_K_Q8_1 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4) +#define MMQ_MMA_TILE_X_K_Q2_K (2*WARP_SIZE + WARP_SIZE + 4) +#define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/(2*QI3_K) + WARP_SIZE/8 + 7) +#define MMQ_MMA_TILE_X_K_Q4_K (1*WARP_SIZE + WARP_SIZE/QI4_K + WARP_SIZE/8 + 7) +#define MMQ_MMA_TILE_X_K_Q5_K (2*WARP_SIZE + WARP_SIZE/QI5_K + WARP_SIZE/8 + 7) +#define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K + WARP_SIZE/8 + 7) static_assert(MMQ_MMA_TILE_X_K_Q4_0 % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q4_1 % 8 == 4, "Wrong padding."); -static_assert(MMQ_MMA_TILE_X_K_Q5_0 % 8 == 4, "Wrong padding."); -static_assert(MMQ_MMA_TILE_X_K_Q5_1 % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q4_K % 8 == 4, "Wrong padding."); @@ -131,21 +180,20 @@ static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding."); static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 : type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 : - type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q5_0 : - type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q5_1 : + type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q8_0 : + type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q8_1 : type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 : type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K : type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K : type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K : type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K : type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K : - type == GGML_TYPE_IQ4_XS ? MMQ_MMA_TILE_X_K_Q5_0 : - type == GGML_TYPE_IQ4_NL ? MMQ_MMA_TILE_X_K_Q5_0 : + type == GGML_TYPE_IQ4_XS ? MMQ_MMA_TILE_X_K_Q8_0 : + type == GGML_TYPE_IQ4_NL ? MMQ_MMA_TILE_X_K_Q8_0 : 0; } #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1) -#define MMQ_NWARPS 8 static int mmq_get_granularity_host(const int mmq_x, const int cc) { return int8_mma_available(cc) && mmq_x >= 48 ? 16 : 8; @@ -218,7 +266,7 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); const int * x_qs = (const int *) x; @@ -226,34 +274,39 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; +// #pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) { + const int k0 = k00 + k01; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2)); - - int u[2*VDR_Q4_0_Q8_1_MMQ]; + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; #pragma unroll - for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI4_0) % WARP_SIZE]; + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); + + int u[2*VDR_Q4_0_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l]; + u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)]; + } + + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl + (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_0], u, + x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); } - - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl - (&x_qs[i*(WARP_SIZE + 1) + k0], u, x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0], - y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); } } } template static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { #ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K8 mma_A; @@ -271,52 +324,60 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma( const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; - mma_A A[ntx]; - float dA[ntx][mma_C::ne/2]; + mma_A A[ntx][4]; + float dA[ntx][mma_C::ne/2][4]; const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int l = 0; l < mma_A::ne; ++l) { - const int i = i0 + n*mma_A::I + mma_A::get_i(l); - const int k = k0 + mma_A::get_k(l) % QI4_0; - const int shift = 4*(mma_A::get_k(l) / QI4_0); - - A[n].x[l] = __vsubss4((x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + k] >> shift) & 0x0F0F0F0F, 0x08080808); - } + for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*QI4_0) { + const int k0 = k00 + k01; #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); + for (int l = 0; l < mma_A::ne; ++l) { + const int i = i0 + n*mma_A::I + mma_A::get_i(l); + const int k = k0/QR4_0 + mma_A::get_k(l) % QI4_0; + const int shift = 4*(mma_A::get_k(l) / QI4_0); - dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q4_0 + k0/QI4_0]; + A[n][k01/(QR4_0*QI4_0)].x[l] = __vsubss4((x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + k] >> shift) & 0x0F0F0F0F, 0x08080808); + } + +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); + + dA[n][l][k01/(QR4_0*QI4_0)] = x_df[i*MMQ_MMA_TILE_X_K_Q4_0 + k0/(QR4_0*QI4_0)]; + } } } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { - mma_B B; - float dB[mma_C::ne/2]; +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*QI4_0) { + mma_B B; + float dB[mma_C::ne/2]; - B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K); + B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int j = j0 + mma_C::get_j(l); + for (int l = 0; l < mma_C::ne/2; ++l) { + const int j = j0 + mma_C::get_j(l); - dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); - } + dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } #pragma unroll - for (int n = 0; n < ntx; ++n) { - mma_C C; - C.mma_K8(A[n], B); + for (int n = 0; n < ntx; ++n) { + mma_C C; + C.mma_K8(A[n][k01/(QR4_0*QI4_0)], B); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2]*dB[l%2]*C.x[l]; + for (int l = 0; l < mma_C::ne; ++l) { + sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2][k01/(QR4_0*QI4_0)]*dB[l%2]*C.x[l]; + } } } } @@ -381,7 +442,7 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); const int * x_qs = (const int *) x; @@ -389,34 +450,39 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; +// #pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) { + const int k0 = k00 + k01; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2)); - - int u[2*VDR_Q4_1_Q8_1_MMQ]; + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; #pragma unroll - for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI4_1) % WARP_SIZE]; + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); + + int u[2*VDR_Q4_1_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l]; + u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)]; + } + + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl + (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_1], u, + x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); } - - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl - (&x_qs[i*(WARP_SIZE + 1) + k0], u, x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/QI4_1], - y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); } } } template static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { #ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K8 mma_A; @@ -435,50 +501,58 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma( const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; - mma_A A[ntx]; - half2 dmA[ntx][mma_C::ne/2]; + mma_A A[ntx][4]; + half2 dmA[ntx][mma_C::ne/2][4]; const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); #pragma unroll for (int n = 0; n < ntx; ++n) { - ((mma_A_K4 *) &A[n])[0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_1 + k0, MMQ_MMA_TILE_X_K_Q4_1); - A[n].x[2] = (A[n].x[0] >> 4) & 0x0F0F0F0F; - A[n].x[3] = (A[n].x[1] >> 4) & 0x0F0F0F0F; - A[n].x[0] &= 0x0F0F0F0F; - A[n].x[1] &= 0x0F0F0F0F; +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*QI4_1) { + const int k0 = k00 + k01; + + A[n][k01/(QR4_1*QI4_1)].load_low(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_1 + k0/QR4_1, MMQ_MMA_TILE_X_K_Q4_1); + A[n][k01/(QR4_1*QI4_1)].x[2] = (A[n][k01/(QR4_1*QI4_1)].x[0] >> 4) & 0x0F0F0F0F; + A[n][k01/(QR4_1*QI4_1)].x[3] = (A[n][k01/(QR4_1*QI4_1)].x[1] >> 4) & 0x0F0F0F0F; + A[n][k01/(QR4_1*QI4_1)].x[0] &= 0x0F0F0F0F; + A[n][k01/(QR4_1*QI4_1)].x[1] &= 0x0F0F0F0F; #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); - dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_1 + k0/QI4_1]; + dmA[n][l][k01/(QR4_1*QI4_1)] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_1 + k0/(QR4_1*QI4_1)]; + } } } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { - mma_B B; - half2 dsB[mma_C::ne/2]; +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*QI4_1) { + mma_B B; + half2 dsB[mma_C::ne/2]; - B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K); + B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int j = j0 + mma_C::get_j(l); + for (int l = 0; l < mma_C::ne/2; ++l) { + const int j = j0 + mma_C::get_j(l); - dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]; - } + dsB[l] = y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]; + } #pragma unroll - for (int n = 0; n < ntx; ++n) { - mma_C C; - C.mma_K8(A[n], B); + for (int n = 0; n < ntx; ++n) { + mma_C C; + C.mma_K8(A[n][k01/(QR4_1*QI4_1)], B); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - const half2 dmA_dsB = dmA[n][l/2]*dsB[l%2]; - sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB); + for (int l = 0; l < mma_C::ne; ++l) { + const half2 dmA_dsB = dmA[n][l/2][k01/(QR4_1*QI4_1)]*dsB[l%2]; + sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB); + } } } } @@ -531,8 +605,8 @@ template static __device__ __forceinlin qs1 = __vsubss4(qs1, 0x10101010); // subtract 16 #ifdef INT8_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0; - x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; #else x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0; x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; @@ -553,106 +627,13 @@ template static __device__ __forceinlin const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd; #ifdef INT8_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + kbxd] = bxi->d; + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d; #endif // INT8_MMA_AVAILABLE } } -template -static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y); - const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + txs.qs; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl - (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], &y_qs[j*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE], - x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0], y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); - } - } -} - -template -static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { -#ifdef INT8_MMA_AVAILABLE - - typedef mma_int_A_I16K8 mma_A; - typedef mma_int_B_J8K8 mma_B; - typedef mma_int_C_I16J8 mma_C; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = 2 * granularity; - constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + WARP_SIZE*2; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; - - mma_A A[ntx]; - float dA[ntx][mma_C::ne/2]; - - const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_0 + QR5_1*k0, MMQ_MMA_TILE_X_K_Q5_0); - -#pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + mma_C::get_i(2*l) + n*mma_C::I; - - dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + k0/QI5_0]; - } - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { - mma_B B; - float dB[mma_C::ne/2]; - - B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K); - -#pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int j = j0 + mma_C::get_j(l); - - dB[l] = y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]; - } - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - mma_C C; - C.mma_K8(A[n], B); - -#pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2]*dB[l%2]*C.x[l]; - } - } - } -#else - GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); - NO_DEVICE_CODE; -#endif // INT8_MMA_AVAILABLE -} - template static __device__ __forceinline__ void load_tiles_q5_1( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { @@ -694,8 +675,8 @@ template static __device__ __forceinlin qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 #ifdef INT8_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q5_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0; - x_qs[i*MMQ_MMA_TILE_X_K_Q5_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; #else x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0; x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; @@ -716,113 +697,19 @@ template static __device__ __forceinlin const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd; #ifdef INT8_MMA_AVAILABLE - x_dm[i*MMQ_MMA_TILE_X_K_Q5_1 + kbxd] = bxi->dm; + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; #else x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm; #endif // INT8_MMA_AVAILABLE } } -template -static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); - const int * x_qs = (const int *) x; - const half2 * x_dm = (const half2 *) x_qs + txs.qs; - const int * y_qs = (const int *) y + 4; - const half2 * y_ds = (const half2 *) y; - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl - (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], &y_qs[j*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE], - x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); - } - } -} - -template -static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { -#ifdef INT8_MMA_AVAILABLE - - typedef mma_int_A_I16K8 mma_A; - typedef mma_int_B_J8K8 mma_B; - typedef mma_int_C_I16J8 mma_C; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = 2 * granularity; - constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE; - const int * y_qs = (const int *) y + 4; - const half2 * y_ds = (const half2 *) y; - - mma_A A[ntx]; - half2 dmA[ntx][mma_C::ne/2]; - - const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_1 + QR5_1*k0, MMQ_MMA_TILE_X_K_Q5_1); - -#pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + mma_C::get_i(2*l) + n*mma_C::I; - - dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q5_1 + k0/QI5_1]; - } - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { - mma_B B; - half2 dsB[mma_C::ne/2]; - - B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K); - -#pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int j = j0 + mma_C::get_j(l); - - dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]; - } - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - mma_C C; - C.mma_K8(A[n], B); - -#pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - const half2 dmA_dsB = dmA[n][l/2]*dsB[l%2]; - sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB); - } - } - } -#else - GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); - NO_DEVICE_CODE; -#endif // INT8_MMA_AVAILABLE -} - template static __device__ __forceinline__ void load_tiles_q8_0( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { #ifdef INT8_MMA_AVAILABLE int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_tile + WARP_SIZE); + float * x_df = (float *) (x_tile + 2*WARP_SIZE); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); int * x_qs = (int *) x_tile; @@ -843,18 +730,20 @@ template static __device__ __forceinlin const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx; #ifdef INT8_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x] = get_int_b2(bxi->qs, kqsx); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx); #else - x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_b2(bxi->qs, kqsx); + x_qs[i*(2*WARP_SIZE + 1) + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx); + x_qs[i*(2*WARP_SIZE + 1) + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx); #endif // INT8_MMA_AVAILABLE } - const int blocks_per_tile_x_row = WARP_SIZE / QI8_0; + const int blocks_per_tile_x_row = 2*WARP_SIZE / QI8_0; const int kbxd = threadIdx.x % blocks_per_tile_x_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) { - int i = i0 + threadIdx.y * QI8_0 + threadIdx.x / blocks_per_tile_x_row; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0/2) { + int i = i0 + threadIdx.y * (QI8_0/2) + threadIdx.x / blocks_per_tile_x_row; if (need_check) { i = min(i, i_max); @@ -863,16 +752,16 @@ template static __device__ __forceinlin const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd; #ifdef INT8_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else - x_df[i*(WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d; + x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d; #endif // INT8_MMA_AVAILABLE } } template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); const int * x_qs = (const int *) x; @@ -880,24 +769,29 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; +// #pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) { + const int k0 = k00 + k01; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl - (&x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0], x_df[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0], - y_df[j*MMQ_TILE_Y_K + k0/QI8_1]); +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl + (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % WARP_SIZE], + x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)]); + } } } } template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { #ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K8 mma_A; @@ -911,49 +805,178 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + WARP_SIZE; + const float * x_df = (const float *) x_qs + 2*WARP_SIZE; const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; - mma_A A[ntx]; - float dA[ntx][mma_C::ne/2]; + mma_A A[ntx][WARP_SIZE/QI8_0]; + float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0]; const int i0 = (threadIdx.y/ntx)*rows_per_warp; #pragma unroll for (int n = 0; n < ntx; ++n) { - A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) { + const int k0 = k00 + k01; + + A[n][k01/QI8_0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); + } #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int i = i0 + n*mma_A::I + mma_C::get_i(2*l); - dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0]; +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) { + const int k0 = k00 + k01; + + dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0]; + } } } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { - mma_B B; - float dB[mma_C::ne/2]; +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) { + const int k0 = k00 + k01; - B.load(y_qs + j0*MMQ_TILE_Y_K + k0, MMQ_TILE_Y_K); + mma_B B; + float dB[mma_C::ne/2]; + + B.load(y_qs + j0*MMQ_TILE_Y_K + k0 % WARP_SIZE, MMQ_TILE_Y_K); #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int j = j0 + mma_C::get_j(l); + for (int l = 0; l < mma_C::ne/2; ++l) { + const int j = j0 + mma_C::get_j(l); - dB[l] = y_df[j*MMQ_TILE_Y_K + k0/QI8_1]; + dB[l] = y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)]; + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + mma_C C; + C.mma_K8(A[n][k01/QI8_0], B); + +#pragma unroll + for (int l = 0; l < mma_C::ne; ++l) { + sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2]; + } + } + } + } +#else + GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE +} + +template +static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + +// #pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl + (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], + x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +template +static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { +#ifdef INT8_MMA_AVAILABLE + + typedef mma_int_A_I16K8 mma_A; + typedef mma_int_B_J8K8 mma_B; + typedef mma_int_C_I16J8 mma_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE; + const int * y_qs = (const int *) y + 4; + const half2 * y_dm = (const half2 *) y; + + mma_A A[ntx][WARP_SIZE/QI8_1]; + half2 dmA[ntx][mma_C::ne/2][WARP_SIZE/QI8_1]; + + const int i0 = (threadIdx.y/ntx)*rows_per_warp; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { + const int k0 = k00 + k01; + + A[n][k01/QI8_1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); } #pragma unroll - for (int n = 0; n < ntx; ++n) { - mma_C C; - C.mma_K8(A[n], B); + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_A::I + mma_C::get_i(2*l); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2]*dB[l%2]; + for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { + const int k0 = k00 + k01; + + dmA[n][l][k01/QI8_1] = x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]; + } + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { + const int k0 = k00 + k01; + + mma_B B; + half2 dsB[mma_C::ne/2]; + + B.load(y_qs + j0*MMQ_TILE_Y_K + k0 % WARP_SIZE, MMQ_TILE_Y_K); + +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int j = j0 + mma_C::get_j(l); + + dsB[l] = y_dm[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)]; + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + mma_C C; + C.mma_K8(A[n][k01/QI8_1], B); + +#pragma unroll + for (int l = 0; l < mma_C::ne; ++l) { + const half2 dmA_dsB = dmA[n][l/2][k01/QI8_1]*dsB[l%2]; + sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB); + } } } } @@ -968,44 +991,37 @@ template static __device__ __forceinlin #ifdef INT8_MMA_AVAILABLE int * x_qs = (int *) x_tile; - half2 * x_dm = (half2 *) (x_qs + WARP_SIZE); + half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); #endif // INT8_MMA_AVAILABLE - const int kbx = threadIdx.x / QI2_K; const int kqsx = threadIdx.x % QI2_K; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + threadIdx.y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI2_K) { + int i = i0 + threadIdx.y*(WARP_SIZE/QI2_K) + threadIdx.x/QI2_K; if (need_check) { i = min(i, i_max); } - const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbx; + const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride; const int x_ql_0 = get_int_b2(bxi->qs, kqsx); #pragma unroll for (int l = 0; l < QR2_K; ++l) { - const int k = kbx*QI2_K + (kqsx/8)*8 + l*2 + (kqsx % 8)/4; + const int k = (kqsx/8)*32 + l*8 + kqsx % 8; - int x_qs_k = ((x_ql_0 >> (2*l)) & 0x03030303) << (2*(kqsx % 4)); - x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 1, WARP_SIZE); - x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 2, WARP_SIZE); - - if (kqsx % QR2_K != 0) { - continue; - } + const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303; #ifdef INT8_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k; #else - x_qs[i*(WARP_SIZE + 1) + k] = x_qs_k; + x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k; #endif // INT8_MMA_AVAILABLE } @@ -1018,44 +1034,68 @@ template static __device__ __forceinlin #endif // FAST_FP16_AVAILABLE #ifdef INT8_MMA_AVAILABLE - x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + threadIdx.x] = x_dm_ik; + x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik; #else - x_dm[i*(WARP_SIZE + 1) + threadIdx.x] = x_dm_ik; + x_dm[i*(WARP_SIZE + 1) + kqsx] = x_dm_ik; #endif // INT8_MMA_AVAILABLE } } template static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); const int * x_qs = (const int *) x; const half2 * x_dm = (const half2 *) x_qs + txs.qs; const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; + const half2 * y_ds = (const half2 *) y; + float2 y_df[mmq_x/nwarps]; #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { const int j = j0 + threadIdx.y; -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; + y_df[j0/nwarps] = __half22float2(y_ds[j*MMQ_TILE_Y_K]); + } - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq( - &x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR2_K*k0) % WARP_SIZE], - &x_dm[i*(WARP_SIZE + 1) + k0], y_df[j*MMQ_TILE_Y_K + ((QR2_K*k0) % WARP_SIZE)/QI8_1]); +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + if (k01 < WARP_SIZE/2) { + constexpr int ns = 2; + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq( + &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], + &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y, + &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); + } else { + constexpr int ns = 1; + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq( + &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], + &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y, + &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); + } + } } } } template static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { #ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K4 mma_A; + typedef mma_int_A_I16K8 mma_A_K8; typedef mma_int_B_J8K4 mma_B; typedef mma_int_C_I16J8 mma_C; @@ -1066,74 +1106,107 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); const int * x_qs = (const int *) x; - const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE; + const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2; const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; + const half2 * y_ds = (const half2 *) y; const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); - mma_A A[ntx][2]; - float dA[ntx][mma_C::ne/2][2]; - float mA[ntx][mma_C::ne/2][2]; + mma_A A[ntx][8]; + float dA[ntx][mma_C::ne/2][8]; + float mA[ntx][mma_C::ne/2][8]; #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int l = 0; l < mma_A::ne; ++l) { - const int i = i0 + n*mma_A::I + mma_A::get_i(l); - const int shift = 2*mma_A::get_k(l); + for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { + const int k0 = k00 + k01; - A[n][0].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k0 + 0] >> shift) & 0x03030303; - A[n][1].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k0 + 1] >> shift) & 0x03030303; + ((mma_A_K8 *) A[n])[k01/QI8_1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); } + } +#pragma unroll + for (int n = 0; n < ntx; ++n) { #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); #pragma unroll - for (int kdm = 0; kdm < 2; ++kdm) { - const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0 + kdm]); + for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1/2) { + const int k0 = k00 + k01; - dA[n][l][kdm] = dm.x; - mA[n][l][kdm] = dm.y; + const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/(QI8_1/2)]); + + dA[n][l][k01/(QI8_1/2)] = dm.x; + mA[n][l][k01/(QI8_1/2)] = dm.y; } } } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { - mma_B B[2]; - float dB[mma_C::ne/2]; - - B[0].load(y_qs + j0*MMQ_TILE_Y_K + (QR2_K*k0 + 0) % WARP_SIZE, MMQ_TILE_Y_K); - B[1].load(y_qs + j0*MMQ_TILE_Y_K + (QR2_K*k0 + mma_B::K) % WARP_SIZE, MMQ_TILE_Y_K); + float2 dB[mma_C::ne/2]; #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int j = j0 + mma_C::get_j(l); - dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)]; + dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]); } - mma_C Cm[2]; - mma_A A1; - A1.x[0] = 0x01010101; - A1.x[1] = 0x01010101; - Cm[0].mma_K4(A1, B[0]); - Cm[1].mma_K4(A1, B[1]); +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { + mma_B B[2]; + + B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); + B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K); + + mma_C Cm[2]; + if (k01 >= WARP_SIZE * 3/4) { + mma_A A1; + A1.x[0] = 0x01010101; + A1.x[1] = 0x01010101; + Cm[0].mma_K4(A1, B[0]); + Cm[1].mma_K4(A1, B[1]); + } #pragma unroll - for (int n = 0; n < ntx; ++n) { - mma_C Cd[2]; + for (int n = 0; n < ntx; ++n) { + mma_C Cd[2]; - Cd[0].mma_K4(A[n][0], B[0]); - Cd[1].mma_K4(A[n][1], B[1]); + Cd[0].mma_K4(A[n][k01/4 + 0], B[0]); + Cd[1].mma_K4(A[n][k01/4 + 1], B[1]); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_C::J + n)*mma_C::ne + l] += ( - Cd[0].x[l]*dA[n][l/2][0] + Cd[1].x[l]*dA[n][l/2][1] - Cm[0].x[l]*mA[n][l/2][0] - Cm[1].x[l]*mA[n][l/2][1])*dB[l%2]; + for (int l = 0; l < mma_C::ne; ++l) { + float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1]; + if (k01 >= WARP_SIZE * 3/4) { + tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1]; + } + sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y); + } + } + } + +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE * 3/4; k01 += QI8_1) { + float2 sB[mma_C::ne/2]; + +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int j = j0 + mma_C::get_j(l); + + sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int l = 0; l < mma_C::ne; ++l) { + sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x; + sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y; + } } } } @@ -1149,7 +1222,7 @@ template static __device__ __forceinlin #ifdef INT8_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + WARP_SIZE*2); - int * x_sc = (int *) (x_df + WARP_SIZE/QI3_K); + int * x_sc = (int *) (x_df + 1); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y); int * x_qs = (int *) x_tile; @@ -1157,75 +1230,66 @@ template static __device__ __forceinlin int * x_sc = (int *) (x_df + txs.dm); #endif // INT8_MMA_AVAILABLE - const int kbx = threadIdx.x / QI3_K; const int kqsx = threadIdx.x % QI3_K; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + threadIdx.y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI3_K) { + int i = i0 + threadIdx.y * (WARP_SIZE/QI3_K) + threadIdx.x / QI3_K; if (need_check) { i = min(i, i_max); } - const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbx; + const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride; const int x_ql_0 = get_int_b2(bxi->qs, kqsx); const int x_qh_0 = get_int_b2(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2))); #pragma unroll for (int l = 0; l < QR3_K; ++l) { - const int k = kbx*(QR3_K*QI3_K) + (kqsx/8)*32 + l*8 + kqsx % 8; + const int k = (kqsx/8)*32 + l*8 + kqsx % 8; const int x_ql_k = (x_ql_0 >> (2*l)) & 0x03030303; const int x_qh_k = ((x_qh_0 >> l) << 2) & 0x04040404; - int x_qs_k = (x_ql_k | x_qh_k) << (4*(k%2)); - x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 1, WARP_SIZE); - - if (kqsx % 2 != 0) { - continue; - } + const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404); #ifdef INT8_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k/2] = x_qs_k; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k; #else - x_qs[i*(2*WARP_SIZE + 1) + k/2] = x_qs_k; + x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k; #endif // INT8_MMA_AVAILABLE } } - const int blocks_per_tile_x_row = WARP_SIZE / QI3_K; - const int kbxd = threadIdx.x % blocks_per_tile_x_row; - #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) { - int i = (i0 + threadIdx.y * QI3_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) { + int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y; if (need_check) { i = min(i, i_max); } - const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbxd; + const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride; #ifdef INT8_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q3_K + kbxd] = bxi->d; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K] = bxi->d; #else - x_df[i*(WARP_SIZE/QI3_K) + i/QI3_K + kbxd] = bxi->d; + x_df[i] = bxi->d; #endif // INT8_MMA_AVAILABLE } #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { - int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4); + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) { + int i = i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8); if (need_check) { i = min(i, i_max); } - const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/4)) / (QI3_K/4); + const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride; - const int ksc = threadIdx.x % (QI3_K/4); + const int ksc = threadIdx.x % (WARP_SIZE/8); const int ksc_low = ksc % (QI3_K/8); const int shift_low = 4 * (ksc / (QI3_K/8)); @@ -1238,16 +1302,16 @@ template static __device__ __forceinlin const int sc = __vsubss4(sc_low | sc_high, 0x20202020); #ifdef INT8_MMA_AVAILABLE - x_sc[i*MMQ_MMA_TILE_X_K_Q3_K + threadIdx.x % (WARP_SIZE/4)] = sc; + x_sc[i*MMQ_MMA_TILE_X_K_Q3_K + threadIdx.x % (WARP_SIZE/8)] = sc; #else - x_sc[i*(WARP_SIZE/4) + i/4 + threadIdx.x % (WARP_SIZE/4)] = sc; + x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc; #endif // INT8_MMA_AVAILABLE } } template static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y); const int * x_qs = (const int *) x; @@ -1256,32 +1320,35 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a( const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; +// #pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) { + const int k0 = k00 + k01; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; - const int kbx = k0 / QI3_K; - const int ky = (k0 % QI3_K) * QR3_K; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; - const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4; + const int8_t * scales = ((const int8_t *) (x_sc + i*(WARP_SIZE/8) + i/8)) + k0/4; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq( - &x_qs[i*(2*WARP_SIZE + 1) + 2*k0], &y_qs[j*MMQ_TILE_Y_K + (k0*QR3_K) % WARP_SIZE], scales, - x_df[i*(WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[j*MMQ_TILE_Y_K + ((k0*QR3_K) % WARP_SIZE)/QI8_1]); + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq( + &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales, + x_df[i], y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); + } } } } template static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { #ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K4 mma_A; + typedef mma_int_A_I16K8 mma_A_K8; typedef mma_int_B_J8K4 mma_B; typedef mma_int_C_I16J8 mma_C; @@ -1293,73 +1360,74 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma( const int * x_qs = (const int *) x; const float * x_df = (const float *) x_qs + WARP_SIZE*2; - const int * x_sc = (const int *) x_df + WARP_SIZE/QI3_K; + const int * x_sc = (const int *) x_df + 1; const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); - mma_A A[ntx][2]; - int scA[ntx][mma_C::ne/2][2]; + mma_A A[ntx][8]; + int scA[ntx][mma_C::ne/2][8]; float dA[ntx][mma_C::ne/2]; #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int l = 0; l < mma_A::ne; ++l) { - const int i = i0 + n*mma_A::I + mma_A::get_i(l); - const int k = QR3_K*k0 + mma_A::get_k(l); + for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { + const int k0 = k00 + k01; - A[n][0].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k/2 + 0] >> (4*(k%2))) & 0x0F0F0F0F; - A[n][1].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k/2 + mma_A::K/2] >> (4*(k%2))) & 0x0F0F0F0F; - A[n][0].x[l] = __vsubss4(A[n][0].x[l], 0x04040404); - A[n][1].x[l] = __vsubss4(A[n][1].x[l], 0x04040404); + ((mma_A_K8 *) A[n])[k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); } #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); - const int kbx = k0 / QI3_K; - const int ky = (k0 % QI3_K) * QR3_K; - const int8_t * sc = ((const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q3_K + kbx*4)) + ky/4; +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) { + const int k0 = k00 + k01; - scA[n][l][0] = sc[0]; - scA[n][l][1] = sc[1]; - } + const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q3_K + k0/16]; + const int8_t * sc = (const int8_t *) &sc_packed; #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); + for (int ksc = 0; ksc < sizeof(int); ++ksc) { + scA[n][l][k01/4 + ksc] = sc[ksc]; + } + } - dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/QI3_K]; + dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K]; } } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { - mma_B B[2]; - float dB[mma_C::ne/2]; +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) { + mma_B B[2]; + float dB[mma_C::ne/2]; - B[0].load(y_qs + j0*MMQ_TILE_Y_K + (QR3_K*k0 + 0) % WARP_SIZE, MMQ_TILE_Y_K); - B[1].load(y_qs + j0*MMQ_TILE_Y_K + (QR3_K*k0 + mma_B::K) % WARP_SIZE, MMQ_TILE_Y_K); + B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); + B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K); #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int j = j0 + mma_C::get_j(l); + for (int l = 0; l < mma_C::ne/2; ++l) { + const int j = j0 + mma_C::get_j(l); - dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)]; - } + dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + } #pragma unroll - for (int n = 0; n < ntx; ++n) { - mma_C C[2]; - C[0].mma_K4(A[n][0], B[0]); - C[1].mma_K4(A[n][1], B[1]); + for (int n = 0; n < ntx; ++n) { + mma_C C[2]; + C[0].mma_K4(A[n][k01/4 + 0], B[0]); + C[1].mma_K4(A[n][k01/4 + 1], B[1]); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_C::J + n)*mma_C::ne + l] += (C[0].x[l]*scA[n][l/2][0] + C[1].x[l]*scA[n][l/2][1])*dA[n][l/2]*dB[l%2]; + for (int l = 0; l < mma_C::ne; ++l) { + sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2]*dB[l%2]* + (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1]); + } } } } @@ -1451,7 +1519,7 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y); const int * x_qs = (const int *) x; @@ -1460,26 +1528,31 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; +// #pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) { + const int k0 = k00 + k01; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; - const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2*((k0 % 16) / 8); +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq( - &x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR4_K*k0) % WARP_SIZE], sc, sc+8, - x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[j*MMQ_TILE_Y_K + ((QR4_K*k0) % WARP_SIZE)/QI8_1]); + const uint8_t * sc = (const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/32] + 2*(k01/16); + + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq( + &x_qs[i*(WARP_SIZE + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8, + x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } } } } template static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { #ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K8 mma_A; @@ -1500,35 +1573,40 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma( const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); - mma_A A[ntx][2]; - int scA[ntx][mma_C::ne/2][2]; - int mA[ntx][mma_C::ne/2][2]; + mma_A A[ntx][4]; + int scA[ntx][mma_C::ne/2][4]; + int mA[ntx][mma_C::ne/2][4]; half2 dmA[ntx][mma_C::ne/2]; #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 8) { - A[n][kvdr/4 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_K + k0, MMQ_MMA_TILE_X_K_Q4_K); + for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) { + const int k0 = k00 + k01; + + A[n][k01/8 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_K + k0/QR4_K, MMQ_MMA_TILE_X_K_Q4_K); #pragma unroll for (int l = 0; l < mma_A::ne; ++l) { - A[n][kvdr/4 + 1].x[l] = (A[n][kvdr/4 + 0].x[l] >> 4) & 0x0F0F0F0F; - A[n][kvdr/4 + 0].x[l] &= 0x0F0F0F0F; + A[n][k01/8 + 1].x[l] = (A[n][k01/8 + 0].x[l] >> 4) & 0x0F0F0F0F; + A[n][k01/8 + 0].x[l] &= 0x0F0F0F0F; } } #pragma unroll - for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 4) { + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); + + const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + (k00/32 + 0)]; + const int m_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + (k00/32 + 2)]; + + const uint8_t * sc = (const uint8_t *) &sc_packed; + const uint8_t * m = (const uint8_t *) &m_packed; + #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + n*mma_A::I + mma_C::get_i(2*l); - - const uint8_t * sc = ((const uint8_t *) &x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + k0/16]) + 2 * ((k0 % 16) / 8); - const uint8_t * m = sc + 8; - - scA[n][l][kvdr/4] = sc[kvdr/4]; - mA[n][l][kvdr/4] = m[kvdr/4]; + for (int ksc = 0; ksc < sizeof(int); ++ksc) { + scA[n][l][ksc] = sc[ksc]; + mA[n][l][ksc] = m[ksc]; } } @@ -1536,7 +1614,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma( for (int l = 0; l < mma_C::ne/2; ++l) { const int i = i0 + n*mma_A::I + mma_C::get_i(2*l); - dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_K + k0/QI4_K]; + dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_K]; } } @@ -1546,28 +1624,28 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma( float tmpm[ntx][mma_C::ne] = {{0.0f}}; #pragma unroll - for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 4) { + for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { mma_B B; half2 dsB[mma_C::ne/2]; - B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0 + 2*kvdr) % WARP_SIZE, MMQ_TILE_Y_K); + B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int j = j0 + mma_C::get_j(l); - dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)]; + dsB[l] = y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]; } #pragma unroll for (int n = 0; n < ntx; ++n) { mma_C C; - C.mma_K8(A[n][kvdr/4], B); + C.mma_K8(A[n][k01/8], B); #pragma unroll for (int l = 0; l < mma_C::ne; ++l) { - tmpd[n][l] += (C.x[l]*scA[n][l/2][kvdr/4]) * __low2float(dsB[l%2]); - tmpm[n][l] += mA[n][l/2][kvdr/4] * __high2float(dsB[l%2]); + tmpd[n][l] += (C.x[l]*scA[n][l/2][k01/8]) * __low2float(dsB[l%2]); + tmpm[n][l] += mA[n][l/2][k01/8] * __high2float(dsB[l%2]); } } } @@ -1682,7 +1760,7 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y); const int * x_qs = (const int *) x; @@ -1691,26 +1769,31 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a( const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; +// #pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) { + const int k0 = k00 + k01; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; - const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8); +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq( - &x_qs[i*(QR5_K*WARP_SIZE + 1) + QR5_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR5_K*k0) % WARP_SIZE], sc, sc+8, - x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[j*MMQ_TILE_Y_K + ((QR5_K*k0) % WARP_SIZE)/QI8_1]); + const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k00/32]) + 2*(k01/16); + + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq( + &x_qs[i*(QR5_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8, + x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } } } } template static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { #ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K8 mma_A; @@ -1731,26 +1814,34 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma( const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); - mma_A A[ntx][2]; - int scA[ntx][mma_C::ne/2][2]; - int mA[ntx][mma_C::ne/2][2]; + mma_A A[ntx][4]; + int scA[ntx][mma_C::ne/2][4]; + int mA[ntx][mma_C::ne/2][4]; half2 dmA[ntx][mma_C::ne/2]; #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) { - A[n][kvdr/4].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_K + (QR5_K*k0 + QR5_K*kvdr), MMQ_MMA_TILE_X_K_Q5_K); + for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { + const int k0 = k00 + k01; + + A[n][k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_K + k0, MMQ_MMA_TILE_X_K_Q5_K); + } #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); - const uint8_t * sc = ((const uint8_t *) &x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + k0/16]) + 2 * ((k0 % 16) / 8); - const uint8_t * m = sc + 8; + const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + (k00/32 + 0)]; + const int m_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + (k00/32 + 2)]; - scA[n][l][kvdr/4] = sc[kvdr/4]; - mA[n][l][kvdr/4] = m[kvdr/4]; + const uint8_t * sc = (const uint8_t *) &sc_packed; + const uint8_t * m = (const uint8_t *) &m_packed; + +#pragma unroll + for (int ksc = 0; ksc < sizeof(int); ++ksc) { + scA[n][l][ksc] = sc[ksc]; + mA[n][l][ksc] = m[ksc]; } } @@ -1758,7 +1849,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma( for (int l = 0; l < mma_C::ne/2; ++l) { const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); - dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q5_K + k0/QI5_K]; + dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q5_K]; } } @@ -1768,28 +1859,30 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma( float tmpm[ntx][mma_C::ne] = {{0.0f}}; #pragma unroll - for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) { + for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { + const int k0 = k00 + k01; + mma_B B; half2 dsB[mma_C::ne/2]; - B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0 + 2*kvdr) % WARP_SIZE, MMQ_TILE_Y_K); + B.load(y_qs + j0*MMQ_TILE_Y_K + k0 % WARP_SIZE, MMQ_TILE_Y_K); #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int j = j0 + mma_C::get_j(l); - dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)]; + dsB[l] = y_ds[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)]; } #pragma unroll for (int n = 0; n < ntx; ++n) { mma_C C; - C.mma_K8(A[n][kvdr/4], B); + C.mma_K8(A[n][k01/8], B); #pragma unroll for (int l = 0; l < mma_C::ne; ++l) { - tmpd[n][l] += (C.x[l]*scA[n][l/2][kvdr/4]) * __low2float(dsB[l%2]); - tmpm[n][l] += mA[n][l/2][kvdr/4] * __high2float(dsB[l%2]); + tmpd[n][l] += (C.x[l]*scA[n][l/2][k01/8]) * __low2float(dsB[l%2]); + tmpm[n][l] += mA[n][l/2][k01/8] * __high2float(dsB[l%2]); } } } @@ -1896,7 +1989,7 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y); const int * x_qs = (const int *) x; @@ -1905,26 +1998,31 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; +// #pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) { + const int k0 = k00 + k01; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; - const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]); +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq( - &x_qs[i*(QR6_K*WARP_SIZE + 1) + QR6_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR6_K*k0) % WARP_SIZE], sc, - x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + ((QR6_K*k0) % WARP_SIZE)/QI8_1]); + const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]); + + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq( + &x_qs[i*(QR6_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, + x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); + } } } } template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { #ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K4 mma_A; @@ -1945,25 +2043,35 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); - mma_A A[ntx][4]; - int scA[ntx][mma_C::ne/2][4]; + mma_A A[ntx][8]; + int scA[ntx][mma_C::ne/2][8]; float dA[ntx][mma_C::ne/2]; #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) { - A[n][kvdr/2 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (QR6_K*k0 + QR6_K*kvdr + 0), MMQ_MMA_TILE_X_K_Q6_K); - A[n][kvdr/2 + 1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (QR6_K*k0 + QR6_K*kvdr + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K); + for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { + const int k0 = k00 + k01; + + A[n][k01/4 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K); + A[n][k01/4 + 1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K); + } + +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) { + const int k0 = k00 + k01; #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); - const int8_t * sc = ((const int8_t *) &x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/8]); + const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16]; + const int8_t * sc = (const int8_t *) &sc_packed; - scA[n][l][kvdr/2 + 0] = sc[kvdr/2 + 0]; - scA[n][l][kvdr/2 + 1] = sc[kvdr/2 + 1]; +#pragma unroll + for (int ksc = 0; ksc < sizeof(int); ++ksc) { + scA[n][l][k01/4 + ksc] = sc[ksc]; + } } } @@ -1971,7 +2079,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( for (int l = 0; l < mma_C::ne/2; ++l) { const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); - dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K + k0/QI6_K]; + dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K]; } } @@ -1980,30 +2088,29 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( float tmp[ntx][mma_C::ne] = {{0.0f}}; #pragma unroll - for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) { + for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { mma_B B[2]; float dB[mma_C::ne/2]; - const int k0B = (2*k0 + 2*kvdr) % WARP_SIZE; - B[0].load(y_qs + j0*MMQ_TILE_Y_K + 0 + k0B, MMQ_TILE_Y_K); - B[1].load(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k0B, MMQ_TILE_Y_K); + B[0].load(y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K); + B[1].load(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k01, MMQ_TILE_Y_K); #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int j = j0 + mma_C::get_j(l); - dB[l] = y_df[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)]; + dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; } #pragma unroll for (int n = 0; n < ntx; ++n) { mma_C C[2]; - C[0].mma_K4(A[n][kvdr/2 + 0], B[0]); - C[1].mma_K4(A[n][kvdr/2 + 1], B[1]); + C[0].mma_K4(A[n][k01/4 + 0], B[0]); + C[1].mma_K4(A[n][k01/4 + 1], B[1]); #pragma unroll for (int l = 0; l < mma_C::ne; ++l) { - tmp[n][l] += (C[0].x[l]*scA[n][l/2][kvdr/2 + 0] + C[1].x[l]*scA[n][l/2][kvdr/2 + 1])*dB[l%2]; + tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2]; } } } @@ -2051,8 +2158,8 @@ template static __device__ __forceinlin const int2 v = get_int_from_table_16(aux_q4); const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4; #ifdef INT8_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 0] = v.x; - x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 4] = v.y; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y; #else x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x; x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y; @@ -2073,7 +2180,7 @@ template static __device__ __forceinlin const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd; #ifdef INT8_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + kbxd] = __half2float(bxi->d); + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d); #else x_df[i*(WARP_SIZE/4) + i/4 + kbxd] = __half2float(bxi->d); #endif // INT8_MMA_AVAILABLE @@ -2109,8 +2216,8 @@ template static __device__ __forceinlin const int2 v = get_int_from_table_16(aux_q4); const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4; #ifdef INT8_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 0] = v.x; - x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 4] = v.y; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y; #else x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x; x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y; @@ -2133,7 +2240,7 @@ template static __device__ __forceinlin | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4); #ifdef INT8_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + threadIdx.x % 8] = d * (ls - 32); + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32); #else x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * (ls - 32); #endif // INT8_MMA_AVAILABLE @@ -2229,16 +2336,16 @@ template struct mmq_type_traits { static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_0_q8_1_dp4a; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; template struct mmq_type_traits { static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_1_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_1_q8_1_dp4a; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a; }; template @@ -2293,45 +2400,18 @@ template struct mmq_type_traits { static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_0_q8_1_dp4a; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; template struct mmq_type_traits { static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_0_q8_1_dp4a; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; -static bool mmq_need_sum(const ggml_type type_x) { - switch (type_x) { - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - return true; - case GGML_TYPE_Q5_0: - return false; - case GGML_TYPE_Q5_1: - return true; - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - return false; - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - return true; - case GGML_TYPE_Q6_K: - case GGML_TYPE_IQ4_XS: - case GGML_TYPE_IQ4_NL: - return false; - default: - GGML_ASSERT(false); - break; - } - return false; -} - template static __device__ void mul_mat_q_process_tile( const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup, @@ -2339,10 +2419,7 @@ static __device__ void mul_mat_q_process_tile( const int & it, const int & jt, const int & kb0_start, const int & kb0_stop) { constexpr int qk = ggml_cuda_type_traits::qk; - constexpr int qr = ggml_cuda_type_traits::qr; - constexpr int qi = ggml_cuda_type_traits::qi; constexpr int mmq_y = get_mmq_y_device(); - constexpr int vdr = mmq_type_traits::vdr; constexpr load_tiles_mmq_t load_tiles = mmq_type_traits::load_tiles; extern __shared__ char data_mul_mat_q[]; @@ -2357,7 +2434,7 @@ static __device__ void mul_mat_q_process_tile( constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; #endif // INT8_MMA_AVAILABLE - constexpr int blocks_per_warp = WARP_SIZE / qi; + constexpr int blocks_per_iter = MMQ_ITER_K / qk; float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f}; @@ -2366,29 +2443,40 @@ static __device__ void mul_mat_q_process_tile( const int * y = (const int *) yc + jt*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int)); - for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_warp) { - + for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) { load_tiles(x, tile_x, stride01*it*mmq_y + kb0, tile_x_max_i, stride01); -#pragma unroll - for (int kr = 0; kr < qr; ++kr) { - const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + kr*sizeof(block_q8_1_mmq)/sizeof(int)); + { + const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int)); #pragma unroll for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) { int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x; tile_y[l] = by0[l]; } - - __syncthreads(); - -// #pragma unroll // unrolling this loop causes too much register pressure - for (int k0 = kr*WARP_SIZE/qr; k0 < (kr+1)*WARP_SIZE/qr; k0 += vdr) { - vec_dot(tile_x, tile_y, sum, k0); - } - - __syncthreads(); } + + __syncthreads(); + + vec_dot(tile_x, tile_y, sum, 0); + + __syncthreads(); + + { + const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int)); +#pragma unroll + for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) { + int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x; + + tile_y[l] = by0[l]; + } + } + + __syncthreads(); + + vec_dot(tile_x, tile_y, sum, WARP_SIZE); + + __syncthreads(); } if (fixup) { @@ -2424,7 +2512,6 @@ static __global__ void mul_mat_q( } constexpr int qk = ggml_cuda_type_traits::qk; - constexpr int qi = ggml_cuda_type_traits::qi; constexpr int mmq_y = get_mmq_y_device(); // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead: @@ -2439,7 +2526,7 @@ static __global__ void mul_mat_q( #endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA const int64_t blocks_per_ne00 = ne00 / qk; - constexpr int blocks_per_warp = WARP_SIZE / qi; + constexpr int blocks_per_iter = MMQ_ITER_K / qk; const int ntx = (ne11 + mmq_x - 1) / mmq_x; // Number of tiles x const int nty = (ne01 + mmq_y - 1) / mmq_y; // Number of tiles y @@ -2448,8 +2535,8 @@ static __global__ void mul_mat_q( int64_t kbc = (int64_t) blockIdx.x *blocks_per_ne00*ntx*nty / gridDim.x; int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*blocks_per_ne00*ntx*nty / gridDim.x; - kbc -= (kbc % blocks_per_ne00) % blocks_per_warp; - kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_warp; + kbc -= (kbc % blocks_per_ne00) % blocks_per_iter; + kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter; // kb0 == k index when doing the matrix multiplication for an output tile. int kb0_start = kbc % blocks_per_ne00; @@ -2490,8 +2577,7 @@ static __global__ void mul_mat_q_stream_k_fixup( constexpr int mmq_y = get_mmq_y_device(); constexpr int qk = ggml_cuda_type_traits::qk; - constexpr int qi = ggml_cuda_type_traits::qi; - constexpr int blocks_per_warp = WARP_SIZE / qi; + constexpr int blocks_per_iter = MMQ_ITER_K / qk; const int64_t blocks_per_ne00 = ne00 / qk; float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f}; @@ -2501,15 +2587,18 @@ static __global__ void mul_mat_q_stream_k_fixup( bool any_fixup = false; - const int bidx_start = (blockIdx.y*nty + blockIdx.x) * block_num_mmq / (gridDim.y*gridDim.x); - const int bidx_stop = (blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq / (gridDim.y*gridDim.x) + 1; + const int bidx_start = ((blockIdx.y*nty + blockIdx.x) * block_num_mmq) / (gridDim.y*gridDim.x); + const int bidx_stop = ((blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq + gridDim.y*gridDim.x - 1) / (gridDim.y*gridDim.x); + + int64_t kbc_0; + int64_t kbc_stop_0 = (int64_t) bidx_start*blocks_per_ne00*ntx*nty / block_num_mmq; for (int bidx = bidx_start; bidx < bidx_stop; ++bidx) { - int64_t kbc = (int64_t) bidx *blocks_per_ne00*ntx*nty / block_num_mmq; - int64_t kbc_stop = (int64_t)(bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq; + kbc_0 = kbc_stop_0; + kbc_stop_0 = (int64_t) (bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq; - kbc -= (kbc % blocks_per_ne00) % blocks_per_warp; - kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_warp; + const int64_t kbc = kbc_0 - (kbc_0 % blocks_per_ne00) % blocks_per_iter; + const int64_t kbc_stop = kbc_stop_0 - (kbc_stop_0 % blocks_per_ne00) % blocks_per_iter; // Skip fixup tile if the MMQ CUDA block never wrote anything to it: if (kbc == kbc_stop || kbc_stop % blocks_per_ne00 == 0) { diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index b46786822..aa7f1eff0 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -37,47 +37,92 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest reinterpret_cast(y[ib].ds.y) = sum; } -template +template static __global__ void quantize_mmq_q8_1( const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) { - const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; + constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32; + constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32; + + const int64_t ix0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4; if (ix0 >= kx0_padded) { return; } + const float4 * x4 = (const float4 *) x; + const int64_t ix1 = kx1*blockIdx.z + blockIdx.y; block_q8_1_mmq * y = (block_q8_1_mmq *) vy; - const int64_t ib0 = blockIdx.z*(gridDim.y*gridDim.x*blockDim.x/(4*QK8_1)); // first block of channel - const int64_t ib = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y; // block index in channel - const int64_t iqs = ix0 % (4*QK8_1); // quant index in block + const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel + const int64_t ib = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y; // block index in channel + const int64_t iqs = ix0 % (4*QK8_1); // quant index in block - const float xi = ix0 < kx0 ? x[ix1*kx0 + ix0] : 0.0f; - float amax = fabsf(xi); + // Load 4 floats per thread and calculate max. abs. value between them: + const float4 xi = ix0 < kx0 ? x4[(ix1*kx0 + ix0)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f); + float amax = fabsf(xi.x); + amax = fmaxf(amax, fabsf(xi.y)); + amax = fmaxf(amax, fabsf(xi.z)); + amax = fmaxf(amax, fabsf(xi.w)); - amax = warp_reduce_max(amax); - - float sum; - if (need_sum) { - sum = warp_reduce_sum(xi); + // Exchange max. abs. value between vals_per_scale/4 threads. +#pragma unroll + for (int mask = vals_per_scale/8; mask > 0; mask >>= 1) { + amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE)); } - const float d = amax / 127; - const int8_t q = amax == 0.0f ? 0 : roundf(xi / d); + float sum; + if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) { + sum = xi.x + xi.y + xi.z + xi.w; - y[ib].qs[iqs] = q; + // Exchange calculate sum across vals_per_sum/4 threads. +#pragma unroll + for (int mask = vals_per_sum/8; mask > 0; mask >>= 1) { + sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, WARP_SIZE); + } + } + + const float d_inv = 127.0f / amax; + char4 q; + q.x = roundf(xi.x*d_inv); + q.y = roundf(xi.y*d_inv); + q.z = roundf(xi.z*d_inv); + q.w = roundf(xi.w*d_inv); + + // Write back 4 int8 values as a single 32 bit value for better memroy bandwidth: + char4 * yqs4 = (char4 *) y[ib].qs; + yqs4[iqs/4] = q; + + if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6) { + if (iqs % 16 != 0 || iqs >= 96) { + return; + } + + y[ib].d2s6[2 + iqs/16] = sum; + + if (iqs % 64 != 0) { + return; + } + + const float d = 1.0f / d_inv; + + y[ib].d2s6[iqs/64] = d; - if (iqs % QK8_1 != 0) { return; } - if (need_sum) { - y[ib].ds[iqs/QK8_1] = make_half2(d, sum); + if (iqs % 32 != 0) { + return; + } + + const float d = 1.0f / d_inv; + + if (ds_layout == MMQ_Q8_1_DS_LAYOUT_DS4) { + y[ib].ds4[iqs/32] = make_half2(d, sum); } else { - ((float *) y[ib].ds)[iqs/QK8_1] = d; + y[ib].d4[iqs/32] = d; } } @@ -101,12 +146,24 @@ void quantize_mmq_q8_1_cuda( GGML_ASSERT(kx0_padded % (4*QK8_1) == 0); - const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; + const int64_t block_num_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ); const dim3 num_blocks(block_num_x, kx1, channels); - const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1); - if (mmq_need_sum(type_x)) { - quantize_mmq_q8_1<<>>(x, vy, kx0, kx1, kx0_padded); - } else { - quantize_mmq_q8_1<<>>(x, vy, kx0, kx1, kx0_padded); + const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1); + switch (mmq_get_q8_1_ds_layout(type_x)) { + case MMQ_Q8_1_DS_LAYOUT_D4: + quantize_mmq_q8_1 + <<>>(x, vy, kx0, kx1, kx0_padded); + break; + case MMQ_Q8_1_DS_LAYOUT_DS4: + quantize_mmq_q8_1 + <<>>(x, vy, kx0, kx1, kx0_padded); + break; + case MMQ_Q8_1_DS_LAYOUT_D2S6: + quantize_mmq_q8_1 + <<>>(x, vy, kx0, kx1, kx0_padded); + break; + default: + GGML_ASSERT(false); + break; } } diff --git a/ggml/src/ggml-cuda/quantize.cuh b/ggml/src/ggml-cuda/quantize.cuh index 486c9360a..03bf322b9 100644 --- a/ggml/src/ggml-cuda/quantize.cuh +++ b/ggml/src/ggml-cuda/quantize.cuh @@ -5,7 +5,11 @@ #include -#define CUDA_QUANTIZE_BLOCK_SIZE 256 +#define CUDA_QUANTIZE_BLOCK_SIZE 256 +#define CUDA_QUANTIZE_BLOCK_SIZE_MMQ 128 + +static_assert(MATRIX_ROW_PADDING % CUDA_QUANTIZE_BLOCK_SIZE == 0, "Risk of out-of-bounds access."); +static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access."); typedef void (*quantize_cuda_t)( const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded, diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index 1d510484a..6a17d0f3e 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -189,7 +189,7 @@ template static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp } #define VDR_Q2_K_Q8_1_MMVQ 1 -#define VDR_Q2_K_Q8_1_MMQ 2 +#define VDR_Q2_K_Q8_1_MMQ 4 // contiguous v/x values static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq( @@ -219,32 +219,56 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq( return dm2f.x*sumf_d - dm2f.y*sumf_m; } -// contiguous u/y values +// contiguous v/x + u/y values +template static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq( - const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8) { + const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8, const half2 * s8) { - float sumf_d = 0.0f; - float sumf_m = 0.0f; + float sumf = 0.0f; + float sumf_d8 = 0.0f; #pragma unroll - for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) { - const float2 dm2f = __half22float2(dm2[i0/(QI8_1/2)]); - int sumi_d = 0; - int sumi_m = 0; + for (int i0 = 0; i0 < QR2_K*VDR_Q2_K_Q8_1_MMQ; i0 += QI8_1) { + const float2 dm2f0 = __half22float2(dm2[i0/(QI8_1/2) + 0]); + int sumi_d0 = 0; + + const float2 dm2f1 = __half22float2(dm2[i0/(QI8_1/2) + 1]); + int sumi_d1 = 0; - const int vi0 = v[i0/(QI8_1/2)]; #pragma unroll for (int i = i0; i < i0 + QI8_1/2; ++i) { - const int vi = (vi0 >> (2*(i % (QI8_1/2)))) & 0x03030303; - sumi_d = ggml_cuda_dp4a(vi, u[i], sumi_d); // SIMD dot product - sumi_m = ggml_cuda_dp4a(0x01010101, u[i], sumi_m); + sumi_d0 = ggml_cuda_dp4a(v[i], u[i], sumi_d0); } + sumf_d8 += dm2f0.x * sumi_d0; - sumf_d += dm2f.x * sumi_d; - sumf_m += dm2f.y * sumi_m; +#pragma unroll + for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) { + sumi_d1 = ggml_cuda_dp4a(v[i], u[i], sumi_d1); + } + sumf_d8 += dm2f1.x * sumi_d1; + + if (i0/QI8_1 < ns8) { + const float2 s8f = __half22float2(s8[i0/QI8_1]); + sumf -= dm2f0.y*s8f.x; + sumf -= dm2f1.y*s8f.y; + } else { + int sumi_m0 = 0; +#pragma unroll + for (int i = i0; i < i0 + QI8_1/2; ++i) { + sumi_m0 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m0); + } + sumf_d8 -= dm2f0.y * sumi_m0; + + int sumi_m1 = 0; +#pragma unroll + for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) { + sumi_m1 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m1); + } + sumf_d8 -= dm2f1.y * sumi_m1; + } } - return d8*(sumf_d - sumf_m); + return sumf + d8*sumf_d8; } #define VDR_Q3_K_Q8_1_MMVQ 1 @@ -283,7 +307,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq( return d3 * sumf; } -// contiguous u/y values +// contiguous v/x + u/y values static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq( const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales, const float & d3, const float & d8) { @@ -296,8 +320,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq( #pragma unroll for (int i = i0; i < i0 + QI8_1/2; ++i) { - const int vi = __vsubss4((v[i/2] >> (4*(i%2))) & 0x0F0F0F0F, 0x04040404); - sumi_sc = ggml_cuda_dp4a(vi, u[i], sumi_sc); // SIMD dot product + sumi_sc = ggml_cuda_dp4a(v[i], u[i], sumi_sc); // SIMD dot product } sumi += sumi_sc * scales[i0 / (QI8_1/2)]; @@ -334,7 +357,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq( return dm4f.x*sumf_d - dm4f.y*sumf_m; } -// contiguous u/y values +// contiguous v/x + u/y values static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq( const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) { @@ -397,7 +420,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq( return dm5f.x*sumf_d - dm5f.y*sumf_m; } -// contiguous u/y values +// contiguous v/x + u/y values static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq( const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) { @@ -451,13 +474,16 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq( return d*sumf; } -// contiguous u/y values +// contiguous v/x + u/y values static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq( const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc, const float & d6, const float * __restrict__ d8) { float sumf_d = 0.0f; + const int sc_packed = get_int_b4(sc, 0); + const int8_t * sc_reg = (const int8_t *) &sc_packed; + #pragma unroll for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) { int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale @@ -471,7 +497,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq( sumi_d.y = ggml_cuda_dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product } - sumf_d += d8[i0/4] * (sc[i0/2+0]*sumi_d.x + sc[i0/2+1]*sumi_d.y); + sumf_d += d8[i0/4] * (sc_reg[i0/2+0]*sumi_d.x + sc_reg[i0/2+1]*sumi_d.y); } return d6 * sumf_d; From b078c619aa4e97fe726d61b0b5499a2e19a418a4 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Thu, 11 Jul 2024 17:53:42 +0200 Subject: [PATCH 16/26] cuda : suppress 'noreturn' warn in no_device_code (#8414) * cuda : suppress 'noreturn' warn in no_device_code This commit adds a while(true) loop to the no_device_code function in common.cuh. This is done to suppress the warning: ```console /ggml/src/ggml-cuda/template-instances/../common.cuh:346:1: warning: function declared 'noreturn' should not return [-Winvalid-noreturn] 346 | } | ^ ``` The motivation for this is to reduce the number of warnings when compilng with GGML_HIPBLAS=ON. Signed-off-by: Daniel Bevenius * squash! cuda : suppress 'noreturn' warn in no_device_code Update __trap macro instead of using a while loop to suppress the warning. Signed-off-by: Daniel Bevenius --------- Signed-off-by: Daniel Bevenius --- ggml/src/ggml-cuda/common.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 4ff06b871..26d9412a2 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -104,7 +104,7 @@ #define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags) #define cudaStream_t hipStream_t #define cudaSuccess hipSuccess -#define __trap abort +#define __trap() do { abort(); __builtin_unreachable(); } while(0) #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS #define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED #define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED From 368645698ab648e390dcd7c00a2bf60efa654f57 Mon Sep 17 00:00:00 2001 From: Nicholai Tukanov Date: Thu, 11 Jul 2024 11:49:15 -0500 Subject: [PATCH 17/26] ggml : add NVPL BLAS support (#8329) (#8425) * ggml : add NVPL BLAS support * ggml : replace `_ENABLE_CBLAS` with `GGML_BLAS_USE_` --------- Co-authored-by: ntukanov --- Makefile | 8 +++++++- ggml/src/ggml-blas.cpp | 13 +++++++++---- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/Makefile b/Makefile index 668b38b99..4869d2ecf 100644 --- a/Makefile +++ b/Makefile @@ -547,11 +547,17 @@ ifdef GGML_OPENBLAS64 endif # GGML_OPENBLAS64 ifdef GGML_BLIS - MK_CPPFLAGS += -DGGML_USE_BLAS -I/usr/local/include/blis -I/usr/include/blis + MK_CPPFLAGS += -DGGML_USE_BLAS -DGGML_BLAS_USE_BLIS -I/usr/local/include/blis -I/usr/include/blis MK_LDFLAGS += -lblis -L/usr/local/lib OBJ_GGML += ggml/src/ggml-blas.o endif # GGML_BLIS +ifdef GGML_NVPL + MK_CPPFLAGS += -DGGML_USE_BLAS -DGGML_BLAS_USE_NVPL -DNVPL_ILP64 -I/usr/local/include/nvpl_blas -I/usr/include/nvpl_blas + MK_LDFLAGS += -L/usr/local/lib -lnvpl_blas_core -lnvpl_blas_ilp64_gomp + OBJ_GGML += ggml/src/ggml-blas.o +endif # GGML_NVPL + ifndef GGML_NO_LLAMAFILE MK_CPPFLAGS += -DGGML_USE_LLAMAFILE OBJ_GGML += ggml/src/llamafile/sgemm.o diff --git a/ggml/src/ggml-blas.cpp b/ggml/src/ggml-blas.cpp index d709a357b..a37aa4072 100644 --- a/ggml/src/ggml-blas.cpp +++ b/ggml/src/ggml-blas.cpp @@ -8,11 +8,12 @@ # include #elif defined(GGML_BLAS_USE_MKL) # include +#elif defined(GGML_BLAS_USE_BLIS) +# include +#elif defined(GGML_BLAS_USE_NVPL) +# include #else # include -# ifdef BLIS_ENABLE_CBLAS -# include -# endif #endif struct ggml_backend_blas_context { @@ -140,10 +141,14 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg openblas_set_num_threads(ctx->n_threads); #endif -#if defined(BLIS_ENABLE_CBLAS) +#if defined(GGML_BLAS_USE_BLIS) bli_thread_set_num_threads(ctx->n_threads); #endif +#if defined(GGML_BLAS_USE_NVPL) + nvpl_blas_set_num_threads(ctx->n_threads); +#endif + for (int64_t i13 = 0; i13 < ne13; i13++) { for (int64_t i12 = 0; i12 < ne12; i12++) { const int64_t i03 = i13/r3; From b549a1bbefb2f1fbb8b558bac1f2ae7967e60964 Mon Sep 17 00:00:00 2001 From: Chen Xi Date: Fri, 12 Jul 2024 00:52:04 +0000 Subject: [PATCH 18/26] [SYCL] fix the mul_mat_id ut issues (#8427) * fix part of mul_mat_id * skip the bfloat 16 sycl ut Signed-off-by: Chen Xi --------- Signed-off-by: Chen Xi Co-authored-by: Meng, Hengyu Co-authored-by: Chen Xi --- ggml/src/ggml-backend.c | 2 +- ggml/src/ggml-sycl.cpp | 49 +++++++++++------------------------------ src/llama.cpp | 7 ------ 3 files changed, 14 insertions(+), 44 deletions(-) diff --git a/ggml/src/ggml-backend.c b/ggml/src/ggml-backend.c index 13c71c310..dbbaa3941 100644 --- a/ggml/src/ggml-backend.c +++ b/ggml/src/ggml-backend.c @@ -394,7 +394,7 @@ void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event) // backend registry -#define GGML_REG_MAX_BACKENDS 16 +#define GGML_REG_MAX_BACKENDS 64 struct ggml_backend_reg { char name[128]; diff --git a/ggml/src/ggml-sycl.cpp b/ggml/src/ggml-sycl.cpp index 9c419ba89..5a890237f 100644 --- a/ggml/src/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl.cpp @@ -3768,37 +3768,13 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids)))); SYCL_CHECK(CHECK_TRY_ERROR(stream->wait())); - const ggml_tensor_extra_gpu *src0_extra = - (const ggml_tensor_extra_gpu *)src0->extra; - const ggml_tensor_extra_gpu *src1_extra = - (const ggml_tensor_extra_gpu *)src1->extra; - const ggml_tensor_extra_gpu *dst_extra = - (const ggml_tensor_extra_gpu *)dst->extra; - - ggml_tensor_extra_gpu src0_row_extra; - ggml_tensor_extra_gpu src1_row_extra; - ggml_tensor_extra_gpu dst_row_extra; - ggml_tensor src0_row = *src0; ggml_tensor src1_row = *src1; ggml_tensor dst_row = *dst; - src1_row.backend = GGML_BACKEND_TYPE_GPU; - dst_row.backend = GGML_BACKEND_TYPE_GPU; - - src0_row.extra = &src0_row_extra; - src1_row.extra = &src1_row_extra; - dst_row.extra = &dst_row_extra; - - char *src0_original = src1->backend == GGML_BACKEND_TYPE_CPU - ? (char *)src0->data - : (char *)src0_extra->data_device[ctx.device]; - char *src1_original = src1->backend == GGML_BACKEND_TYPE_CPU - ? (char *)src1->data - : (char *)src1_extra->data_device[ctx.device]; - char *dst_original = dst->backend == GGML_BACKEND_TYPE_CPU - ? (char *)dst->data - : (char *)dst_extra->data_device[ctx.device]; + char *src0_original = (char *)src0->data; + char *src1_original = (char *)src1->data; + char *dst_original = (char *)dst->data; src0_row.ne[2] = 1; src0_row.ne[3] = 1; @@ -3827,12 +3803,9 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten const int64_t i1 = id; const int64_t i2 = i12; - src0_row_extra.data_device[ctx.device] = - src0_original + i02*nb02; - src1_row_extra.data_device[ctx.device] = - src1_original + + i11*nb11 + i12*nb12; - dst_row_extra.data_device[ctx.device] = - dst_original + i1*nb1 + i2*nb2; + src0_row.data = src0_original + i02*nb02; + src1_row.data = src1_original + + i11*nb11 + i12*nb12; + dst_row.data = dst_original + i1*nb1 + i2*nb2; ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row); } @@ -3841,8 +3814,8 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten ggml_sycl_pool_alloc src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1)); ggml_sycl_pool_alloc dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); - src1_row_extra.data_device[ctx.device] = src1_contiguous.get(); - dst_row_extra.data_device[ctx.device] = dst_contiguous.get(); + src1_row.data = src1_contiguous.get(); + dst_row.data = dst_contiguous.get(); for (int64_t i02 = 0; i02 < n_as; i02++) { int64_t num_src1_rows = 0; @@ -3898,7 +3871,7 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten }); } - src0_row_extra.data_device[ctx.device] = src0_original + i02*nb02; + src0_row.data = src0_original + i02*nb02; GGML_ASSERT(nb11 == sizeof(float)*ne10); GGML_ASSERT(nb1 == sizeof(float)*ne0); @@ -5221,6 +5194,10 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons return false; } } + ggml_type src0_type = op->src[0]->type; + if (src0_type == GGML_TYPE_BF16) { + return false; + } return true; } break; case GGML_OP_GET_ROWS: diff --git a/src/llama.cpp b/src/llama.cpp index ed77ed918..f91ac7779 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -5883,13 +5883,6 @@ static bool llm_load_tensors( auto & hparams = model.hparams; -#ifdef GGML_USE_SYCL - // disable MoE with SYCL until mul_mat_id is updated - if (hparams.n_expert > 0) { - n_gpu_layers = 0; - } -#endif - model.split_mode = split_mode; model.main_gpu = main_gpu; model.n_gpu_layers = n_gpu_layers; From 370b1f7e7a7514d9a63daf15f7b0f00319b7f908 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 12 Jul 2024 10:46:02 +0300 Subject: [PATCH 19/26] ggml : minor naming changes (#8433) * ggml : minor naming changes ggml-ci * ggml : use PRId64 [no ci] * ggml : revert FA K/Q names --- examples/quantize-stats/quantize-stats.cpp | 2 +- ggml/include/ggml.h | 50 +++++----- ggml/src/ggml-aarch64.c | 68 +++++++------ ggml/src/ggml-aarch64.h | 14 +-- ggml/src/ggml-quants.c | 94 ++++++++--------- ggml/src/ggml-quants.h | 34 +++---- ggml/src/ggml.c | 111 ++++++++++----------- tests/test-double-float.cpp | 4 +- tests/test-quantize-fns.cpp | 2 +- tests/test-quantize-perf.cpp | 2 +- 10 files changed, 192 insertions(+), 189 deletions(-) diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index 746df8446..68cf8d359 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -154,7 +154,7 @@ static void test_roundtrip_on_chunk( } if (use_reference) { - qfns.from_float_reference(input_scratch, quantized_scratch, chunk_size); + qfns.from_float_ref(input_scratch, quantized_scratch, chunk_size); } else { qfns.from_float(input_scratch, quantized_scratch, chunk_size); } diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 1e3677537..f2145ff35 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -714,9 +714,9 @@ extern "C" { GGML_API GGML_CALL size_t ggml_nbytes (const struct ggml_tensor * tensor); GGML_API size_t ggml_nbytes_pad (const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN - GGML_API GGML_CALL int ggml_blck_size(enum ggml_type type); - GGML_API GGML_CALL size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block - GGML_API GGML_CALL size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row + GGML_API GGML_CALL int64_t ggml_blck_size(enum ggml_type type); + GGML_API GGML_CALL size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block + GGML_API GGML_CALL size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row GGML_DEPRECATED( GGML_API double ggml_type_sizef(enum ggml_type type), // ggml_type_size()/ggml_blck_size() as float @@ -2410,31 +2410,31 @@ extern "C" { #endif typedef void (*ggml_to_float_t) (const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); - typedef void (*ggml_vec_dot_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx, - const void * GGML_RESTRICT y, size_t by, int nrc); - typedef void (*ggml_from_float_to_mat_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nr, - int64_t k, int64_t bx); - typedef void (*ggml_gemv_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, - const void * GGML_RESTRICT y, int nr, int nc); - typedef void (*ggml_gemm_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, - const void * GGML_RESTRICT y, int nr, int nc); + typedef void (*ggml_from_float_to_mat_t) + (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nr, int64_t k, int64_t bs); + typedef void (*ggml_vec_dot_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx, + const void * GGML_RESTRICT y, size_t by, int nrc); + typedef void (*ggml_gemv_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, + const void * GGML_RESTRICT y, int nr, int nc); + typedef void (*ggml_gemm_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, + const void * GGML_RESTRICT y, int nr, int nc); typedef struct { - const char * type_name; - int blck_size; - size_t type_size; - bool is_quantized; - ggml_to_float_t to_float; - ggml_from_float_t from_float; - ggml_from_float_t from_float_reference; - ggml_vec_dot_t vec_dot; - enum ggml_type vec_dot_type; - int64_t nrows; // number of rows to process simultaneously; - int64_t ncols; // number of columns to process simultaneously; - int64_t interleave_blcksize; // interleave elements in blocks of interleave_blcksize; + const char * type_name; + int64_t blck_size; + int64_t blck_size_interleave; // interleave elements in blocks + size_t type_size; + bool is_quantized; + ggml_to_float_t to_float; + ggml_from_float_t from_float; + ggml_from_float_t from_float_ref; ggml_from_float_to_mat_t from_float_to_mat; - ggml_gemv_t gemv; - ggml_gemm_t gemm; + ggml_vec_dot_t vec_dot; + enum ggml_type vec_dot_type; + int64_t nrows; // number of rows to process simultaneously + int64_t ncols; // number of columns to process simultaneously + ggml_gemv_t gemv; + ggml_gemm_t gemm; } ggml_type_traits_t; GGML_API ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type); diff --git a/ggml/src/ggml-aarch64.c b/ggml/src/ggml-aarch64.c index 008718634..40838cf4f 100644 --- a/ggml/src/ggml-aarch64.c +++ b/ggml/src/ggml-aarch64.c @@ -20,19 +20,19 @@ // Functions to create the interleaved data layout formats -// interleave 4 block_q4_0s in blocks of interleave_blcksize +// interleave 4 block_q4_0s in blocks of blck_size_interleave // returns an interleaved block_q4_0x4 // in the interleaved block_q4_0x4, place deltas for 4 block_q4_0 blocks -// first, then interleave quants from 4 block_q4_0s in blocks of interleave_blcksize +// first, then interleave quants from 4 block_q4_0s in blocks of blck_size_interleave // // - in : an array of block_q4_0 pointers -// - interleave_blcksize : the block_q4_0 quants bytes are interleaved in blocks of -// interleave_blcksize bytes +// - blck_size_interleave : the block_q4_0 quants bytes are interleaved in blocks of +// blck_size_interleave bytes // - xor_mask : the mask to convert the nibbles in block_q4_0 quants bytes // from bias offset form to pure sign form (this saves subtract // operations durin unpacking) // -static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int interleave_blcksize, unsigned int xor_mask) { +static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) { block_q4_0x4 out; for (int i = 0; i < 4; i++) { @@ -40,9 +40,9 @@ static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int interleave_b } for (int i = 0; i < QK4_0 * 2; i++) { - int src_offset = (i / (4 * interleave_blcksize)) * interleave_blcksize; - int src_id = (i % (4 * interleave_blcksize)) / interleave_blcksize; - src_offset += (i % interleave_blcksize); + int src_offset = (i / (4 * blck_size_interleave)) * blck_size_interleave; + int src_id = (i % (4 * blck_size_interleave)) / blck_size_interleave; + src_offset += (i % blck_size_interleave); out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask; } @@ -50,11 +50,11 @@ static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int interleave_b return out; } -// interleave 8 block_q4_0s in blocks of interleave_blcksize +// interleave 8 block_q4_0s in blocks of blck_size_interleave // returns an interleaved block_q4_0x8 // in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks -// first, then interleave quants from 8 block_q4_0s in blocks of interleave_blcksize -static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int interleave_blcksize, unsigned int xor_mask) { +// first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave +static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) { block_q4_0x8 out; for (int i = 0; i < 8; i++) { @@ -62,9 +62,9 @@ static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int interleave_b } for (int i = 0; i < QK4_0 * 4; i++) { - int src_offset = (i / (8 * interleave_blcksize)) * interleave_blcksize; - int src_id = (i % (8 * interleave_blcksize)) / interleave_blcksize; - src_offset += (i % interleave_blcksize); + int src_offset = (i / (8 * blck_size_interleave)) * blck_size_interleave; + int src_id = (i % (8 * blck_size_interleave)) / blck_size_interleave; + src_offset += (i % blck_size_interleave); out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask; } @@ -135,7 +135,7 @@ void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int64_t k) } #else // scalar - const int interleave_blcksize = 4; + const int blck_size_interleave = 4; float srcv[4][QK8_0]; float id[4]; @@ -155,12 +155,12 @@ void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int64_t k) } for (int j = 0; j < QK8_0 * 4; j++) { - int src_offset = (j / (4 * interleave_blcksize)) * interleave_blcksize; - int src_id = (j % (4 * interleave_blcksize)) / interleave_blcksize; - src_offset += (j % interleave_blcksize); + int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; + int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; + src_offset += (j % blck_size_interleave); float x0 = srcv[src_id][src_offset] * id[src_id]; - y[i].qs[j] = roundf(x0);; + y[i].qs[j] = roundf(x0); } } #endif @@ -253,7 +253,7 @@ void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int64_t k) } #else // scalar - const int interleave_blcksize = 8; + const int blck_size_interleave = 8; float srcv[4][QK8_0]; float id[4]; @@ -273,26 +273,30 @@ void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int64_t k) } for (int j = 0; j < QK8_0 * 4; j++) { - int src_offset = (j / (4 * interleave_blcksize)) * interleave_blcksize; - int src_id = (j % (4 * interleave_blcksize)) / interleave_blcksize; - src_offset += (j % interleave_blcksize); + int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; + int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; + src_offset += (j % blck_size_interleave); float x0 = srcv[src_id][src_offset] * id[src_id]; - y[i].qs[j] = roundf(x0);; + y[i].qs[j] = roundf(x0); } } #endif } -void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nrow, int64_t n_per_row, int64_t interleave_blcksize) { +void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) { assert(nrow == 4); UNUSED(nrow); - if (interleave_blcksize == 4) quantize_q8_0_4x4(x, vy, n_per_row); - else if (interleave_blcksize == 8) quantize_q8_0_4x8(x, vy, n_per_row); - else assert(false); + if (blck_size_interleave == 4) { + quantize_q8_0_4x4(x, vy, n_per_row); + } else if (blck_size_interleave == 8) { + quantize_q8_0_4x8(x, vy, n_per_row); + } else { + assert(false); + } } -static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, int nrows_interleaved, int interleave_blcksize) { +static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, int nrows_interleaved, int blck_size_interleave) { assert(n_per_row % QK4_0 == 0); const int nb = n_per_row / QK4_0; @@ -311,15 +315,15 @@ static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict ds for (int64_t x = 0; x < nb; x++) { for (int i = 0; i < nrows_interleaved; i++ ) { - quantize_row_q4_0_reference(src + b + i * n_per_row + x * QK4_0, (block_q4_0 *) dst_tmp + i, QK4_0); + quantize_row_q4_0_ref(src + b + i * n_per_row + x * QK4_0, (block_q4_0 *) dst_tmp + i, QK4_0); } if (nrows_interleaved == 8) { - *(block_q4_0x8 *) out_ptr = make_block_q4_0x8(dst_tmp, interleave_blcksize, 0x88); + *(block_q4_0x8 *) out_ptr = make_block_q4_0x8(dst_tmp, blck_size_interleave, 0x88); out_ptr = (block_q4_0x8 *) out_ptr + 1; } else if (nrows_interleaved == 4) { - *(block_q4_0x4 *) out_ptr = make_block_q4_0x4(dst_tmp, interleave_blcksize, 0x88); + *(block_q4_0x4 *) out_ptr = make_block_q4_0x4(dst_tmp, blck_size_interleave, 0x88); out_ptr = (block_q4_0x4 *) out_ptr + 1; } } diff --git a/ggml/src/ggml-aarch64.h b/ggml/src/ggml-aarch64.h index 65ead1efe..517babaf1 100644 --- a/ggml/src/ggml-aarch64.h +++ b/ggml/src/ggml-aarch64.h @@ -16,7 +16,7 @@ extern "C" { void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); -void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nrows, int64_t n_per_row, int64_t interleave_blcksize); +void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nrows, int64_t n_per_row, int64_t blck_size_interleave); // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization") size_t quantize_q4_0_4x4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); @@ -24,14 +24,14 @@ size_t quantize_q4_0_4x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT d size_t quantize_q4_0_8x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); // GEMV -void ggml_gemv_q4_0_4x4_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q4_0_4x8_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q4_0_8x8_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); // GEMM -void ggml_gemm_q4_0_4x4_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_q4_0_4x8_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_q4_0_8x8_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); #ifdef __cplusplus } diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index cbe377cf5..1839a722e 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -658,7 +658,7 @@ static inline __m128i packNibbles( __m256i bytes ) { #endif //__loongarch_asx // reference implementation for deterministic creation of model files -void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) { +void quantize_row_q4_0_ref(const float * restrict x, block_q4_0 * restrict y, int64_t k) { static const int qk = QK4_0; assert(k % qk == 0); @@ -696,11 +696,11 @@ void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict } void quantize_row_q4_0(const float * restrict x, void * restrict y, int64_t k) { - quantize_row_q4_0_reference(x, y, k); + quantize_row_q4_0_ref(x, y, k); } -void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int64_t k) { +void quantize_row_q4_1_ref(const float * restrict x, block_q4_1 * restrict y, int64_t k) { const int qk = QK4_1; assert(k % qk == 0); @@ -738,10 +738,10 @@ void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict } void quantize_row_q4_1(const float * restrict x, void * restrict y, int64_t k) { - quantize_row_q4_1_reference(x, y, k); + quantize_row_q4_1_ref(x, y, k); } -void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int64_t k) { +void quantize_row_q5_0_ref(const float * restrict x, block_q5_0 * restrict y, int64_t k) { static const int qk = QK5_0; assert(k % qk == 0); @@ -786,10 +786,10 @@ void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict } void quantize_row_q5_0(const float * restrict x, void * restrict y, int64_t k) { - quantize_row_q5_0_reference(x, y, k); + quantize_row_q5_0_ref(x, y, k); } -void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int64_t k) { +void quantize_row_q5_1_ref(const float * restrict x, block_q5_1 * restrict y, int64_t k) { const int qk = QK5_1; assert(k % qk == 0); @@ -834,11 +834,11 @@ void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict } void quantize_row_q5_1(const float * restrict x, void * restrict y, int64_t k) { - quantize_row_q5_1_reference(x, y, k); + quantize_row_q5_1_ref(x, y, k); } // reference implementation for deterministic creation of model files -void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int64_t k) { +void quantize_row_q8_0_ref(const float * restrict x, block_q8_0 * restrict y, int64_t k) { assert(k % QK8_0 == 0); const int nb = k / QK8_0; @@ -1144,12 +1144,12 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) #else GGML_UNUSED(nb); // scalar - quantize_row_q8_0_reference(x, y, k); + quantize_row_q8_0_ref(x, y, k); #endif } // reference implementation for deterministic creation of model files -void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int64_t k) { +void quantize_row_q8_1_ref(const float * restrict x, block_q8_1 * restrict y, int64_t k) { assert(QK8_1 == 32); assert(k % QK8_1 == 0); const int nb = k / QK8_1; @@ -1508,7 +1508,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) #else GGML_UNUSED(nb); // scalar - quantize_row_q8_1_reference(x, y, k); + quantize_row_q8_1_ref(x, y, k); #endif } @@ -1899,7 +1899,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * //========================- 2-bit (de)-quantization -void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int64_t k) { +void quantize_row_q2_K_ref(const float * restrict x, block_q2_K * restrict y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -2002,7 +2002,7 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int6 } void quantize_row_q2_K(const float * restrict x, void * restrict vy, int64_t k) { - quantize_row_q2_K_reference(x, vy, k); + quantize_row_q2_K_ref(x, vy, k); } static float make_qkx3_quants(int n, int nmax, const float * restrict x, const float * restrict weights, @@ -2226,7 +2226,7 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri size_t quantize_q2_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t row_size = ggml_row_size(GGML_TYPE_Q2_K, n_per_row); if (!quant_weights) { - quantize_row_q2_K_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q2_K_ref(src, dst, (int64_t)nrow*n_per_row); } else { char * qrow = (char *)dst; @@ -2241,7 +2241,7 @@ size_t quantize_q2_K(const float * restrict src, void * restrict dst, int64_t nr //========================= 3-bit (de)-quantization -void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int64_t k) { +void quantize_row_q3_K_ref(const float * restrict x, block_q3_K * restrict y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -2368,7 +2368,7 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int6 } void quantize_row_q3_K(const float * restrict x, void * restrict vy, int64_t k) { - quantize_row_q3_K_reference(x, vy, k); + quantize_row_q3_K_ref(x, vy, k); } static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restrict y, int64_t n_per_row, const float * restrict quant_weights) { @@ -2458,7 +2458,7 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri size_t quantize_q3_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t row_size = ggml_row_size(GGML_TYPE_Q3_K, n_per_row); if (!quant_weights) { - quantize_row_q3_K_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q3_K_ref(src, dst, (int64_t)nrow*n_per_row); } else { char * qrow = (char *)dst; @@ -2473,7 +2473,7 @@ size_t quantize_q3_K(const float * restrict src, void * restrict dst, int64_t nr // ====================== 4-bit (de)-quantization -void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int64_t k) { +void quantize_row_q4_K_ref(const float * restrict x, block_q4_K * restrict y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -2572,7 +2572,7 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int6 void quantize_row_q4_K(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_q4_K * restrict y = vy; - quantize_row_q4_K_reference(x, y, k); + quantize_row_q4_K_ref(x, y, k); } static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restrict y, int64_t n_per_row, const float * quant_weights) { @@ -2651,7 +2651,7 @@ static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restri size_t quantize_q4_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t row_size = ggml_row_size(GGML_TYPE_Q4_K, n_per_row); if (!quant_weights) { - quantize_row_q4_K_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q4_K_ref(src, dst, (int64_t)nrow*n_per_row); } else { char * qrow = (char *)dst; @@ -2666,7 +2666,7 @@ size_t quantize_q4_K(const float * restrict src, void * restrict dst, int64_t nr // ====================== 5-bit (de)-quantization -void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int64_t k) { +void quantize_row_q5_K_ref(const float * restrict x, block_q5_K * restrict y, int64_t k) { assert(k % QK_K == 0); const int64_t nb = k / QK_K; @@ -2783,7 +2783,7 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int6 void quantize_row_q5_K(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_q5_K * restrict y = vy; - quantize_row_q5_K_reference(x, y, k); + quantize_row_q5_K_ref(x, y, k); } static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restrict y, int64_t n_per_row, const float * quant_weights) { @@ -2882,7 +2882,7 @@ static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restri size_t quantize_q5_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t row_size = ggml_row_size(GGML_TYPE_Q5_K, n_per_row); if (!quant_weights) { - quantize_row_q5_K_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q5_K_ref(src, dst, (int64_t)nrow*n_per_row); } else { char * qrow = (char *)dst; @@ -2897,7 +2897,7 @@ size_t quantize_q5_K(const float * restrict src, void * restrict dst, int64_t nr // ====================== 6-bit (de)-quantization -void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int64_t k) { +void quantize_row_q6_K_ref(const float * restrict x, block_q6_K * restrict y, int64_t k) { assert(k % QK_K == 0); const int64_t nb = k / QK_K; @@ -3001,7 +3001,7 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int6 void quantize_row_q6_K(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_q6_K * restrict y = vy; - quantize_row_q6_K_reference(x, y, k); + quantize_row_q6_K_ref(x, y, k); } static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restrict y, int64_t n_per_row, const float * quant_weights) { @@ -3091,7 +3091,7 @@ static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restri size_t quantize_q6_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t row_size = ggml_row_size(GGML_TYPE_Q6_K, n_per_row); if (!quant_weights) { - quantize_row_q6_K_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q6_K_ref(src, dst, (int64_t)nrow*n_per_row); } else { char * qrow = (char *)dst; @@ -3108,7 +3108,7 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri static_assert(QK4_0 == 32, "QK4_0 must be 32"); if (!quant_weights) { - quantize_row_q4_0_reference(x, y, n_per_row); + quantize_row_q4_0_ref(x, y, n_per_row); return; } @@ -3134,7 +3134,7 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri size_t quantize_q4_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { if (!quant_weights) { - quantize_row_q4_0_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q4_0_ref(src, dst, (int64_t)nrow*n_per_row); return nrow * ggml_row_size(GGML_TYPE_Q4_0, n_per_row); } size_t row_size = ggml_row_size(GGML_TYPE_Q4_0, n_per_row); @@ -3151,7 +3151,7 @@ static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restri static_assert(QK4_1 == 32, "QK4_1 must be 32"); if (!quant_weights) { - quantize_row_q4_1_reference(x, y, n_per_row); + quantize_row_q4_1_ref(x, y, n_per_row); return; } @@ -3179,7 +3179,7 @@ static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restri size_t quantize_q4_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { if (!quant_weights) { - quantize_row_q4_1_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q4_1_ref(src, dst, (int64_t)nrow*n_per_row); return nrow * ggml_row_size(GGML_TYPE_Q4_1, n_per_row); } size_t row_size = ggml_row_size(GGML_TYPE_Q4_1, n_per_row); @@ -3196,7 +3196,7 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri static_assert(QK5_0 == 32, "QK5_0 must be 32"); if (!quant_weights) { - quantize_row_q5_0_reference(x, y, n_per_row); + quantize_row_q5_0_ref(x, y, n_per_row); return; } @@ -3233,7 +3233,7 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri size_t quantize_q5_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { if (!quant_weights) { - quantize_row_q5_0_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q5_0_ref(src, dst, (int64_t)nrow*n_per_row); return nrow * ggml_row_size(GGML_TYPE_Q5_0, n_per_row); } size_t row_size = ggml_row_size(GGML_TYPE_Q5_0, n_per_row); @@ -3250,7 +3250,7 @@ static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restri static_assert(QK5_1 == 32, "QK5_1 must be 32"); if (!quant_weights) { - quantize_row_q5_1_reference(x, y, n_per_row); + quantize_row_q5_1_ref(x, y, n_per_row); return; } @@ -3286,7 +3286,7 @@ static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restri size_t quantize_q5_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { if (!quant_weights) { - quantize_row_q5_1_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q5_1_ref(src, dst, (int64_t)nrow*n_per_row); return nrow * ggml_row_size(GGML_TYPE_Q5_1, n_per_row); } size_t row_size = ggml_row_size(GGML_TYPE_Q5_1, n_per_row); @@ -3302,7 +3302,7 @@ size_t quantize_q5_1(const float * restrict src, void * restrict dst, int64_t nr size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { (void)quant_weights; // not used const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0, n_per_row); - quantize_row_q8_0_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q8_0_ref(src, dst, (int64_t)nrow*n_per_row); return nrow * row_size; } @@ -3590,7 +3590,7 @@ void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y, //===================================== Q8_K ============================================== -void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int64_t k) { +void quantize_row_q8_K_ref(const float * restrict x, block_q8_K * restrict y, int64_t k) { assert(k % QK_K == 0); const int64_t nb = k / QK_K; @@ -3641,7 +3641,7 @@ void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int6 } void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) { - quantize_row_q8_K_reference(x, y, k); + quantize_row_q8_K_ref(x, y, k); } //===================================== Dot ptoducts ================================= @@ -13530,10 +13530,10 @@ size_t quantize_iq3_xxs(const float * restrict src, void * restrict dst, int64_t void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_iq3_xxs * restrict y = vy; - quantize_row_iq3_xxs_reference(x, y, k); + quantize_row_iq3_xxs_ref(x, y, k); } -void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int64_t k) { +void quantize_row_iq3_xxs_ref(const float * restrict x, block_iq3_xxs * restrict y, int64_t k) { assert(k % QK_K == 0); quantize_row_iq3_xxs_impl(256, x, y, k, NULL); } @@ -13746,10 +13746,10 @@ size_t quantize_iq3_s(const float * restrict src, void * restrict dst, int64_t n void quantize_row_iq3_s(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_iq3_s * restrict y = vy; - quantize_row_iq3_s_reference(x, y, k); + quantize_row_iq3_s_ref(x, y, k); } -void quantize_row_iq3_s_reference(const float * restrict x, block_iq3_s * restrict y, int64_t k) { +void quantize_row_iq3_s_ref(const float * restrict x, block_iq3_s * restrict y, int64_t k) { assert(k % QK_K == 0); quantize_iq3_s(x, y, 1, k, NULL); } @@ -14487,7 +14487,7 @@ void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int64_t k } } -void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * restrict y, int64_t k) { +void quantize_row_iq4_nl_ref(const float * restrict x, block_iq4_nl * restrict y, int64_t k) { assert(k % QK4_NL == 0); quantize_row_iq4_nl(x, y, k); } @@ -14515,10 +14515,10 @@ size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int64_t void quantize_row_iq4_xs(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_iq4_xs * restrict y = vy; - quantize_row_iq4_xs_reference(x, y, k); + quantize_row_iq4_xs_ref(x, y, k); } -void quantize_row_iq4_xs_reference(const float * restrict x, block_iq4_xs * restrict y, int64_t k) { +void quantize_row_iq4_xs_ref(const float * restrict x, block_iq4_xs * restrict y, int64_t k) { assert(k % QK_K == 0); quantize_iq4_xs(x, y, 1, k, NULL); } @@ -14705,7 +14705,7 @@ size_t quantize_iq2_s(const float * restrict src, void * restrict dst, int64_t n return nrow * nblock * sizeof(block_iq2_s); } -void quantize_row_iq2_s_reference(const float * restrict x, block_iq2_s * restrict y, int64_t k) { +void quantize_row_iq2_s_ref(const float * restrict x, block_iq2_s * restrict y, int64_t k) { assert(k % QK_K == 0); quantize_iq2_s(x, y, 1, k, NULL); } @@ -14713,7 +14713,7 @@ void quantize_row_iq2_s_reference(const float * restrict x, block_iq2_s * restri void quantize_row_iq2_s(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_iq2_s * restrict y = vy; - quantize_row_iq2_s_reference(x, y, k); + quantize_row_iq2_s_ref(x, y, k); } static bool validate_float(float f, size_t i) { diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index 30983b872..88b1f3269 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -12,25 +12,25 @@ extern "C" { #endif // Quantization -void quantize_row_q4_0_reference(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k); -void quantize_row_q4_1_reference(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k); -void quantize_row_q5_0_reference(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k); -void quantize_row_q5_1_reference(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t k); -void quantize_row_q8_0_reference(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k); -void quantize_row_q8_1_reference(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k); +void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k); +void quantize_row_q4_1_ref(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_0_ref(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_1_ref(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k); -void quantize_row_q2_K_reference(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k); -void quantize_row_q3_K_reference(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k); -void quantize_row_q4_K_reference(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k); -void quantize_row_q5_K_reference(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t k); -void quantize_row_q6_K_reference(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k); -void quantize_row_q8_K_reference(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_K_ref(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q6_K_ref(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_K_ref(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k); -void quantize_row_iq3_xxs_reference(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k); -void quantize_row_iq4_nl_reference (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int64_t k); -void quantize_row_iq4_xs_reference (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int64_t k); -void quantize_row_iq3_s_reference (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k); -void quantize_row_iq2_s_reference (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k); +void quantize_row_iq3_xxs_ref(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_nl_ref (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_xs_ref (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int64_t k); +void quantize_row_iq3_s_ref (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k); +void quantize_row_iq2_s_ref (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k); void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 1bb731e16..9a5414787 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -592,7 +592,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = false, .to_float = (ggml_to_float_t) ggml_fp16_to_fp32_row, .from_float = (ggml_from_float_t) ggml_fp32_to_fp16_row, - .from_float_reference = (ggml_from_float_t) ggml_fp32_to_fp16_row, + .from_float_ref = (ggml_from_float_t) ggml_fp32_to_fp16_row, .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f16, .vec_dot_type = GGML_TYPE_F16, .nrows = 1, @@ -604,7 +604,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q4_0, .from_float = quantize_row_q4_0, - .from_float_reference = (ggml_from_float_t) quantize_row_q4_0_reference, + .from_float_ref = (ggml_from_float_t) quantize_row_q4_0_ref, .vec_dot = ggml_vec_dot_q4_0_q8_0, .vec_dot_type = GGML_TYPE_Q8_0, #if defined (__ARM_FEATURE_MATMUL_INT8) @@ -620,7 +620,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q4_1, .from_float = quantize_row_q4_1, - .from_float_reference = (ggml_from_float_t) quantize_row_q4_1_reference, + .from_float_ref = (ggml_from_float_t) quantize_row_q4_1_ref, .vec_dot = ggml_vec_dot_q4_1_q8_1, .vec_dot_type = GGML_TYPE_Q8_1, #if defined (__ARM_FEATURE_MATMUL_INT8) @@ -636,7 +636,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = false, .to_float = NULL, .from_float = NULL, - .from_float_reference = NULL, + .from_float_ref = NULL, .vec_dot = NULL, .vec_dot_type = GGML_TYPE_COUNT, .nrows = 1, @@ -648,7 +648,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = false, .to_float = NULL, .from_float = NULL, - .from_float_reference = NULL, + .from_float_ref = NULL, .vec_dot = NULL, .vec_dot_type = GGML_TYPE_COUNT, .nrows = 1, @@ -660,7 +660,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q5_0, .from_float = quantize_row_q5_0, - .from_float_reference = (ggml_from_float_t) quantize_row_q5_0_reference, + .from_float_ref = (ggml_from_float_t) quantize_row_q5_0_ref, .vec_dot = ggml_vec_dot_q5_0_q8_0, .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, @@ -672,7 +672,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q5_1, .from_float = quantize_row_q5_1, - .from_float_reference = (ggml_from_float_t) quantize_row_q5_1_reference, + .from_float_ref = (ggml_from_float_t) quantize_row_q5_1_ref, .vec_dot = ggml_vec_dot_q5_1_q8_1, .vec_dot_type = GGML_TYPE_Q8_1, .nrows = 1, @@ -684,7 +684,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q8_0, .from_float = quantize_row_q8_0, - .from_float_reference = (ggml_from_float_t) quantize_row_q8_0_reference, + .from_float_ref = (ggml_from_float_t) quantize_row_q8_0_ref, + .from_float_to_mat = quantize_mat_q8_0, .vec_dot = ggml_vec_dot_q8_0_q8_0, .vec_dot_type = GGML_TYPE_Q8_0, #if defined (__ARM_FEATURE_MATMUL_INT8) @@ -692,7 +693,6 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { #else .nrows = 1, #endif - .from_float_to_mat = quantize_mat_q8_0, }, [GGML_TYPE_Q8_1] = { .type_name = "q8_1", @@ -700,7 +700,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_size = sizeof(block_q8_1), .is_quantized = true, .from_float = quantize_row_q8_1, - .from_float_reference = (ggml_from_float_t) quantize_row_q8_1_reference, + .from_float_ref = (ggml_from_float_t) quantize_row_q8_1_ref, .vec_dot_type = GGML_TYPE_Q8_1, .nrows = 1, }, @@ -711,7 +711,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q2_K, .from_float = quantize_row_q2_K, - .from_float_reference = (ggml_from_float_t) quantize_row_q2_K_reference, + .from_float_ref = (ggml_from_float_t) quantize_row_q2_K_ref, .vec_dot = ggml_vec_dot_q2_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -723,7 +723,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q3_K, .from_float = quantize_row_q3_K, - .from_float_reference = (ggml_from_float_t) quantize_row_q3_K_reference, + .from_float_ref = (ggml_from_float_t) quantize_row_q3_K_ref, .vec_dot = ggml_vec_dot_q3_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -735,7 +735,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q4_K, .from_float = quantize_row_q4_K, - .from_float_reference = (ggml_from_float_t) quantize_row_q4_K_reference, + .from_float_ref = (ggml_from_float_t) quantize_row_q4_K_ref, .vec_dot = ggml_vec_dot_q4_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -747,7 +747,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q5_K, .from_float = quantize_row_q5_K, - .from_float_reference = (ggml_from_float_t) quantize_row_q5_K_reference, + .from_float_ref = (ggml_from_float_t) quantize_row_q5_K_ref, .vec_dot = ggml_vec_dot_q5_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -759,7 +759,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q6_K, .from_float = quantize_row_q6_K, - .from_float_reference = (ggml_from_float_t) quantize_row_q6_K_reference, + .from_float_ref = (ggml_from_float_t) quantize_row_q6_K_ref, .vec_dot = ggml_vec_dot_q6_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -771,7 +771,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq2_xxs, .from_float = NULL, - .from_float_reference = NULL, + .from_float_ref = NULL, .vec_dot = ggml_vec_dot_iq2_xxs_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -783,7 +783,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq2_xs, .from_float = NULL, - .from_float_reference = NULL, + .from_float_ref = NULL, .vec_dot = ggml_vec_dot_iq2_xs_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -795,7 +795,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq3_xxs, .from_float = quantize_row_iq3_xxs, - .from_float_reference = (ggml_from_float_t)quantize_row_iq3_xxs_reference, + .from_float_ref = (ggml_from_float_t)quantize_row_iq3_xxs_ref, .vec_dot = ggml_vec_dot_iq3_xxs_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -807,7 +807,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq3_s, .from_float = quantize_row_iq3_s, - .from_float_reference = (ggml_from_float_t)quantize_row_iq3_s_reference, + .from_float_ref = (ggml_from_float_t)quantize_row_iq3_s_ref, .vec_dot = ggml_vec_dot_iq3_s_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -819,7 +819,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq2_s, .from_float = quantize_row_iq2_s, - .from_float_reference = (ggml_from_float_t)quantize_row_iq2_s_reference, + .from_float_ref = (ggml_from_float_t)quantize_row_iq2_s_ref, .vec_dot = ggml_vec_dot_iq2_s_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -831,7 +831,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq1_s, .from_float = NULL, - .from_float_reference = NULL, + .from_float_ref = NULL, .vec_dot = ggml_vec_dot_iq1_s_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -843,7 +843,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq1_m, .from_float = NULL, - .from_float_reference = NULL, + .from_float_ref = NULL, .vec_dot = ggml_vec_dot_iq1_m_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -855,7 +855,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq4_nl, .from_float = quantize_row_iq4_nl, - .from_float_reference = (ggml_from_float_t)quantize_row_iq4_nl_reference, + .from_float_ref = (ggml_from_float_t)quantize_row_iq4_nl_ref, .vec_dot = ggml_vec_dot_iq4_nl_q8_0, .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, @@ -867,7 +867,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq4_xs, .from_float = quantize_row_iq4_xs, - .from_float_reference = (ggml_from_float_t)quantize_row_iq4_xs_reference, + .from_float_ref = (ggml_from_float_t)quantize_row_iq4_xs_ref, .vec_dot = ggml_vec_dot_iq4_xs_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -886,7 +886,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = false, .to_float = (ggml_to_float_t) ggml_bf16_to_fp32_row, .from_float = (ggml_from_float_t) ggml_fp32_to_bf16_row, - .from_float_reference = (ggml_from_float_t) ggml_fp32_to_bf16_row, + .from_float_ref = (ggml_from_float_t) ggml_fp32_to_bf16_row, .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16, .vec_dot_type = GGML_TYPE_BF16, .nrows = 1, @@ -894,48 +894,48 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_0_4_4] = { .type_name = "q4_0_4x4", .blck_size = QK4_0, + .blck_size_interleave = 4, .type_size = sizeof(block_q4_0), .is_quantized = true, .to_float = NULL, .from_float = NULL, - .from_float_reference = NULL, + .from_float_ref = NULL, .vec_dot = NULL, .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, .ncols = 4, - .interleave_blcksize = 4, .gemv = ggml_gemv_q4_0_4x4_q8_0, .gemm = ggml_gemm_q4_0_4x4_q8_0, }, [GGML_TYPE_Q4_0_4_8] = { .type_name = "q4_0_4x8", .blck_size = QK4_0, + .blck_size_interleave = 8, .type_size = sizeof(block_q4_0), .is_quantized = true, .to_float = NULL, .from_float = NULL, - .from_float_reference = NULL, + .from_float_ref = NULL, .vec_dot = NULL, .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, .ncols = 4, - .interleave_blcksize = 8, .gemv = ggml_gemv_q4_0_4x8_q8_0, .gemm = ggml_gemm_q4_0_4x8_q8_0, }, [GGML_TYPE_Q4_0_8_8] = { .type_name = "q4_0_8x8", .blck_size = QK4_0, + .blck_size_interleave = 8, .type_size = sizeof(block_q4_0), .is_quantized = true, .to_float = NULL, .from_float = NULL, - .from_float_reference = NULL, + .from_float_ref = NULL, .vec_dot = NULL, .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, .ncols = 8, - .interleave_blcksize = 8, .gemv = ggml_gemv_q4_0_8x8_q8_0, .gemm = ggml_gemm_q4_0_8x8_q8_0, } @@ -3115,7 +3115,7 @@ size_t ggml_nbytes_pad(const struct ggml_tensor * tensor) { return GGML_PAD(ggml_nbytes(tensor), GGML_MEM_ALIGN); } -GGML_CALL int ggml_blck_size(enum ggml_type type) { +GGML_CALL int64_t ggml_blck_size(enum ggml_type type) { return type_traits[type].blck_size; } @@ -12192,15 +12192,14 @@ static void ggml_compute_forward_mul_mat( const enum ggml_type type = src0->type; - enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; - ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float; - int64_t const vec_dot_num_rows = type_traits[type].nrows; - int64_t const matmul_num_cols = type_traits[type].ncols; - int64_t const interleave_blcksize = type_traits[type].interleave_blcksize; - ggml_from_float_to_mat_t const from_float_to_mat - = type_traits[vec_dot_type].from_float_to_mat; - ggml_gemv_t const gemv = type_traits[type].gemv; - ggml_gemm_t const gemm = type_traits[type].gemm; + enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; + ggml_from_float_t const from_float = type_traits[vec_dot_type].from_float; + ggml_from_float_to_mat_t const from_float_to_mat = type_traits[vec_dot_type].from_float_to_mat; + int64_t const vec_dot_num_rows = type_traits[type].nrows; + int64_t const matmul_num_cols = type_traits[type].ncols; + int64_t const blck_size_interleave = type_traits[type].blck_size_interleave; + ggml_gemv_t const gemv = type_traits[type].gemv; + ggml_gemm_t const gemm = type_traits[type].gemm; GGML_ASSERT(ne0 == ne01); GGML_ASSERT(ne1 == ne11); @@ -12264,14 +12263,14 @@ UseGgmlGemm1:; for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) { from_float_to_mat((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), - 4, ne10, interleave_blcksize); + 4, ne10, blck_size_interleave); } i11_processed = ne11 - ne11 % 4; } for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) { - from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), - (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), - ne10); + from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), + (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), + ne10); } } } @@ -12355,7 +12354,7 @@ UseGgmlGemm2:; int64_t src0_start = (ith * ne01) / nth; int64_t src0_end = ((ith + 1) * ne01) / nth; src0_start = (src0_start % matmul_num_cols) ? src0_start + matmul_num_cols - (src0_start % matmul_num_cols): src0_start; - src0_end = (src0_end % matmul_num_cols) ? src0_end + matmul_num_cols - (src0_end % matmul_num_cols): src0_end; + src0_end = (src0_end % matmul_num_cols) ? src0_end + matmul_num_cols - (src0_end % matmul_num_cols): src0_end; if (src0_start >= src0_end) return; // If there are more than three rows in src1, use gemm; otherwise, use gemv. @@ -12413,11 +12412,11 @@ static void ggml_compute_forward_mul_mat_id( const bool src1_cont = ggml_is_contiguous(src1); - ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot; - enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; - ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float; - int64_t const matmul_num_cols = type_traits[type].ncols; - ggml_gemv_t const gemv = type_traits[type].gemv; + ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot; + enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; + ggml_from_float_t const from_float = type_traits[vec_dot_type].from_float; + int64_t const matmul_num_cols = type_traits[type].ncols; + ggml_gemv_t const gemv = type_traits[type].gemv; // we don't support permuted src0 or src1 GGML_ASSERT(nb00 == ggml_type_size(type)); @@ -12458,9 +12457,9 @@ static void ggml_compute_forward_mul_mat_id( for (int64_t i13 = 0; i13 < ne13; ++i13) { for (int64_t i12 = 0; i12 < ne12; ++i12) { for (int64_t i11 = ith; i11 < ne11; i11 += nth) { - from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), - (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), - ne10); + from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), + (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), + ne10); } } } @@ -21063,8 +21062,8 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p (int64_t) info->ne[3]; if (ne % ggml_blck_size(info->type) != 0) { - fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%d)\n", - __func__, info->name.data, (int)info->type, ggml_type_name(info->type), ne, ggml_blck_size(info->type)); + fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%" PRId64 ")\n", + __func__, info->name.data, (int) info->type, ggml_type_name(info->type), ne, ggml_blck_size(info->type)); fclose(file); gguf_free(ctx); return NULL; diff --git a/tests/test-double-float.cpp b/tests/test-double-float.cpp index 753dae911..6aac4737a 100644 --- a/tests/test-double-float.cpp +++ b/tests/test-double-float.cpp @@ -14,7 +14,7 @@ #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wdouble-promotion" -// ggml.c::quantize_row_q4_0_reference +// ggml.c::quantize_row_q4_0_ref inline static uint8_t round_orig(float v0) { return ((int8_t) (round(v0))) + 8; } // ggml.c::ggml_silu_f32 @@ -24,7 +24,7 @@ inline static float silu_orig(float x) { #pragma GCC diagnostic pop -// ggml.c::quantize_row_q4_0_reference +// ggml.c::quantize_row_q4_0_ref inline static uint8_t round_float(float v0) { return (int8_t)roundf(v0) + 8; } // ggml.c::ggml_silu_f32 diff --git a/tests/test-quantize-fns.cpp b/tests/test-quantize-fns.cpp index e690ac6c8..c97458d1d 100644 --- a/tests/test-quantize-fns.cpp +++ b/tests/test-quantize-fns.cpp @@ -60,7 +60,7 @@ static float reference_quantization_error(ggml_type_traits_t & qfns, size_t test qfns.from_float(test_data, tmp_q.data(), test_size); qfns.to_float(tmp_q.data(), tmp_out.data(), test_size); - qfns.from_float_reference(test_data, tmp_q.data(), test_size); + qfns.from_float_ref(test_data, tmp_q.data(), test_size); qfns.to_float(tmp_q.data(), tmp_out_ref.data(), test_size); return array_rmse(tmp_out.data(), tmp_out_ref.data(), test_size); diff --git a/tests/test-quantize-perf.cpp b/tests/test-quantize-perf.cpp index 48d9fae3d..24e066053 100644 --- a/tests/test-quantize-perf.cpp +++ b/tests/test-quantize-perf.cpp @@ -285,7 +285,7 @@ int main(int argc, char * argv[]) { for (size_t size : params.test_sizes) { printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024)); auto quantize_fn = [&](void) -> float { - qfns.from_float_reference(test_data1, test_q1, size); + qfns.from_float_ref(test_data1, test_q1, size); return test_q1[0]; }; size_t quantized_size = ggml_row_size(type, size); From 71c1121d11f1437be9421fd0cbaa011b9ef49098 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 12 Jul 2024 10:46:14 +0300 Subject: [PATCH 20/26] examples : sprintf -> snprintf (#8434) * examples : sprintf -> snprintf ggml-ci * examples : use sizeof() instead of hardcoded constants --- examples/eval-callback/eval-callback.cpp | 2 +- examples/gguf-hash/gguf-hash.cpp | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index 64cd338c2..c8a3016a4 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -99,7 +99,7 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) { char src1_str[128] = {0}; if (src1) { - sprintf(src1_str, "%s{%s}", src1->name, ggml_ne_string(src1).c_str()); + snprintf(src1_str, sizeof(src1_str), "%s{%s}", src1->name, ggml_ne_string(src1).c_str()); } printf("%s: %24s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__, diff --git a/examples/gguf-hash/gguf-hash.cpp b/examples/gguf-hash/gguf-hash.cpp index c34728c3d..e96c75117 100644 --- a/examples/gguf-hash/gguf-hash.cpp +++ b/examples/gguf-hash/gguf-hash.cpp @@ -347,7 +347,7 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) { char hex_result[17]; for (int offset = 0; offset < 8; offset++) { unsigned int shift_bits_by = (8 * (8 - offset - 1)); - sprintf( ( hex_result + (2*offset)), "%02x", (unsigned char) (hash >> shift_bits_by)&0xff); + snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", (unsigned char) (hash >> shift_bits_by)&0xff); } if (hash_params.manifest_is_usable) { @@ -384,7 +384,7 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) { char hex_result[41] = {0}; for (int offset = 0; offset < 20; offset++) { - sprintf( ( hex_result + (2*offset)), "%02x", result[offset]&0xff); + snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", result[offset]&0xff); } if (hash_params.manifest_is_usable) { @@ -421,7 +421,7 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) { char hex_result[SHA256_DIGEST_SIZE * 2 + 1] = {0}; for (int offset = 0; offset < SHA256_DIGEST_SIZE; offset++) { - sprintf( ( hex_result + (2*offset)), "%02x", result[offset]&0xff); + snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", result[offset]&0xff); } if (hash_params.manifest_is_usable) { @@ -460,7 +460,7 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) { char hex_result[17]; for (int offset = 0; offset < 8; offset++) { unsigned int shift_bits_by = (8 * (8 - offset - 1)); - sprintf( ( hex_result + (2*offset)), "%02x", (unsigned char) (hash >> shift_bits_by)&0xff); + snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", (unsigned char) (hash >> shift_bits_by)&0xff); } if (hash_params.manifest_is_usable) { @@ -490,7 +490,7 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) { char hex_result[41]; for (int offset = 0; offset < 20; offset++) { - sprintf( ( hex_result + (2*offset)), "%02x", result[offset]&0xff); + snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", result[offset]&0xff); } if (hash_params.manifest_is_usable) { @@ -520,7 +520,7 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) { char hex_result[SHA256_DIGEST_SIZE * 2 + 1] = {0}; for (int offset = 0; offset < SHA256_DIGEST_SIZE; offset++) { - sprintf( ( hex_result + (2*offset)), "%02x", result[offset]&0xff); + snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", result[offset]&0xff); } if (hash_params.manifest_is_usable) { @@ -552,7 +552,7 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) { generate_uuidv5(result, uuid); char string_buffer[37] = {0}; - sprintf(string_buffer, "%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x", + snprintf(string_buffer, sizeof(string_buffer), "%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x", uuid[0], uuid[1], uuid[2], uuid[3], uuid[4], uuid[5], uuid[6], uuid[7], uuid[8], uuid[9], uuid[10], uuid[11], From 5aefbce27a66473d4b1263ba3f7bdd3d14245975 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ji=C5=99=C3=AD=20Podiv=C3=ADn?= <66251151+jpodivin@users.noreply.github.com> Date: Fri, 12 Jul 2024 10:06:33 +0200 Subject: [PATCH 21/26] convert : remove fsep token from GPTRefactForCausalLM (#8237) The token used by Refact doesn't serve the same purpose as the from CodeGemma. Signed-off-by: Jiri Podivin --- convert_hf_to_gguf.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index ebb5ca376..cf930be17 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1203,11 +1203,10 @@ class RefactModel(Model): # TODO: how to determine special FIM tokens automatically? special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False, - special_token_types = ['prefix', 'suffix', 'middle', 'fsep', 'eot']) + special_token_types = ['prefix', 'suffix', 'middle', 'eot']) special_vocab._set_special_token("prefix", 1) special_vocab._set_special_token("suffix", 3) special_vocab._set_special_token("middle", 2) - special_vocab._set_special_token("fsep", 4) # is this correct? special_vocab.add_to_gguf(self.gguf_writer) def set_gguf_parameters(self): From 8a4441ea1a2564578134404f31158c318e9c0bf3 Mon Sep 17 00:00:00 2001 From: Armen Kaleshian Date: Fri, 12 Jul 2024 04:08:19 -0400 Subject: [PATCH 22/26] docker : fix filename for convert-hf-to-gguf.py in tools.sh (#8441) Commit b0a4699 changed the name of this script from convert-hf-to-gguf.py to convert_hf_to_gguf.py breaking how convert is called from within a Docker container. --- .devops/tools.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.devops/tools.sh b/.devops/tools.sh index 335382f69..cf0e8f32d 100755 --- a/.devops/tools.sh +++ b/.devops/tools.sh @@ -8,7 +8,7 @@ arg1="$1" shift if [[ "$arg1" == '--convert' || "$arg1" == '-c' ]]; then - python3 ./convert-hf-to-gguf.py "$@" + python3 ./convert_hf_to_gguf.py "$@" elif [[ "$arg1" == '--quantize' || "$arg1" == '-q' ]]; then ./llama-quantize "$@" elif [[ "$arg1" == '--run' || "$arg1" == '-r' ]]; then From c3ebcfa148e867a68e78fd5c4f0c23e8f84c788b Mon Sep 17 00:00:00 2001 From: Douglas Hanley Date: Fri, 12 Jul 2024 03:14:12 -0500 Subject: [PATCH 23/26] server : ensure batches are either all embed or all completion (#8420) * make sure batches are all embed or all non-embed * non-embedding batch for sampled tokens; fix unused params warning --- examples/server/server.cpp | 38 ++++++++++++++++++++++++++++++++------ 1 file changed, 32 insertions(+), 6 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index efd426289..badeb9121 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2005,6 +2005,11 @@ struct server_context { int32_t n_batch = llama_n_batch(ctx); int32_t n_ubatch = llama_n_ubatch(ctx); + // track if this is an embedding or non-embedding batch + // if we've added sampled tokens above, we are in non-embedding mode + // -1: none, 0: non-embedding, 1: embedding + int32_t batch_type = batch.n_tokens > 0 ? 0 : -1; + // next, batch any pending prompts without exceeding n_batch if (params.cont_batching || batch.n_tokens == 0) { for (auto & slot : slots) { @@ -2175,6 +2180,14 @@ struct server_context { } } + // check that we are in the right batch_type, if not defer the slot + bool slot_type = slot.embedding ? 1 : 0; + if (batch_type == -1) { + batch_type = slot_type; + } else if (batch_type != slot_type) { + continue; + } + // keep only the common part int p0 = (int) system_tokens.size() + slot.n_past; if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) { @@ -2276,6 +2289,9 @@ struct server_context { {"n_tokens", batch.n_tokens}, }); + // make sure we're in the right embedding mode + llama_set_embeddings(ctx, batch_type == 1); + // process the created batch of tokens for (int32_t i = 0; i < batch.n_tokens; i += n_batch) { const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); @@ -2990,6 +3006,11 @@ int main(int argc, char ** argv) { }; const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) { + if (ctx_server.params.embedding) { + res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); + return; + } + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); json data = json::parse(req.body); @@ -3085,6 +3106,11 @@ int main(int argc, char ** argv) { }; const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error](const httplib::Request & req, httplib::Response & res) { + if (ctx_server.params.embedding) { + res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); + return; + } + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template); @@ -3157,6 +3183,11 @@ int main(int argc, char ** argv) { }; const auto handle_infill = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) { + if (ctx_server.params.embedding) { + res_error(res, format_error_response("This server does not support infill. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); + return; + } + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); json data = json::parse(req.body); @@ -3243,13 +3274,8 @@ int main(int argc, char ** argv) { return res.set_content(data.dump(), "application/json; charset=utf-8"); }; - const auto handle_embeddings = [¶ms, &ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) { + const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - if (!params.embedding) { - res.status = 501; - res.set_content("This server does not support embeddings. Start it with `--embeddings`", "text/plain; charset=utf-8"); - return; - } const json body = json::parse(req.body); bool is_openai = false; From f53226245f421bd01b47cce43a47e791de82c636 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Fri, 12 Jul 2024 11:05:21 +0200 Subject: [PATCH 24/26] llama : suppress unary minus operator warning (#8448) This commit updates the _try_copy lambda and moves the unary minus operator to after the cast to int32_t. The motivation for this that currently the following warning is generated on windows: ```console llama.cpp\src\llama.cpp(21147,30): warning C4146: unary minus operator applied to unsigned type, result still unsigned ``` --- src/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama.cpp b/src/llama.cpp index f91ac7779..59b76a6d8 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -21144,7 +21144,7 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token size--; } if (length < (int32_t)size) { - return (int32_t) -size; + return -(int32_t) size; } memcpy(buf, token, size); return (int32_t) size; From 6af51c0d96e6268769fc05c98d5b1a5e832c0017 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 12 Jul 2024 14:48:04 +0300 Subject: [PATCH 25/26] main : print error on empty input (#8456) --- examples/main/main.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 4ef55c1e6..a0d817b1a 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -289,8 +289,13 @@ int main(int argc, char ** argv) { // Should not run without any tokens if (embd_inp.empty()) { - embd_inp.push_back(llama_token_bos(model)); - LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str()); + if (add_bos) { + embd_inp.push_back(llama_token_bos(model)); + LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str()); + } else { + LOG_TEE("error: input is empty\n"); + return -1; + } } // Tokenize negative prompt From 4e24cffd8cccd653634e24ee461c252bd77b1426 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 12 Jul 2024 14:48:15 +0300 Subject: [PATCH 26/26] server : handle content array in chat API (#8449) * server : handle content array in chat API * Update examples/server/utils.hpp Co-authored-by: Xuan Son Nguyen --------- Co-authored-by: Xuan Son Nguyen --- examples/server/utils.hpp | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 7ef2a519a..db6b3b74d 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -122,8 +122,26 @@ inline std::string format_chat(const struct llama_model * model, const std::stri for (size_t i = 0; i < messages.size(); ++i) { const auto & curr_msg = messages[i]; - std::string role = json_value(curr_msg, "role", std::string("")); - std::string content = json_value(curr_msg, "content", std::string("")); + + std::string role = json_value(curr_msg, "role", std::string("")); + + std::string content; + if (curr_msg.contains("content")) { + if (curr_msg["content"].is_string()) { + content = curr_msg["content"].get(); + } else if (curr_msg["content"].is_array()) { + for (const auto & part : curr_msg["content"]) { + if (part.contains("text")) { + content += "\n" + part["text"].get(); + } + } + } else { + throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); + } + } else { + throw std::runtime_error("Missing 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); + } + chat.push_back({role, content}); }