From ecb217db4fcfa3880300ad08531a5fb6bb142d45 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 4 Jun 2023 23:34:30 +0300 Subject: [PATCH 01/15] llama : Metal inference (#1642) * mtl : export the LLaMA computation graph * ci : disable temporary * mtl : adapt the MNIST example as starter * mtl : no need for mtl-export tool, add cli arg for main instead * mtl : export just a small part of the graph for now to make it easier * mtl : move MSL code into separate file for easy editing * mtl : initial get_rows_q4_0 kernel * mtl : confirmed get_rows_q4_0 is working correctly * mtl : add rms_norm kernel + confirm working * mtl : add mul kernel + confirm working * mtl : initial mul_mat Q4 kernel (wrong results) * mtl : mul_mat fixes (still wrong) * mtl : another mul_mat Q4 (still does not work) * mtl : working mul_mat q4 * ggml : fix handling of "view" ops in ggml_graph_import() * mtl : add rope kernel * mtl : add reshape and transpose handling * ggml : store offset as opt arg for ggml_view_xd() operators * mtl : add cpy kernel + handle view ops * mtl : confirm f16 x f32 attention mul mat * mtl : add scale kernel * mtl : add diag_mask_inf kernel * mtl : fix soft_max kernel * ggml : update ggml_nbytes() to handle non-contiguous tensors * mtl : verify V tensor contents * mtl : add f32 -> f32 cpy kernel * mtl : add silu kernel * mtl : add non-broadcast mul kernel * mtl : full GPU inference of the computation graph * mtl : optimize rms_norm and soft_max kernels * mtl : add f16 mat x f32 vec multiplication kernel * mtl : fix bug in f16 x f32 mul mat + speed-up computation * mtl : faster mul_mat_q4_0_f32 kernel * mtl : fix kernel signature + roll inner loop * mtl : more threads for rms_norm + better timing * mtl : remove printfs from inner loop * mtl : simplify implementation * mtl : add save/load vocab to ggml file * mtl : plug Metal inference into llama.cpp (very quick-n-dirty) * mtl : make it work with main example Lots of hacks but at least now it generates text * mtl : preparing for merge * mtl : clean-up ggml mtl interface + suport scratch / inplace * mtl : remove temp / debug code * metal : final refactoring and simplification * Revert "ci : disable temporary" This reverts commit 98c267fc77fe811082f672538fc91bcfc9072d63. * metal : add comments * metal : clean-up stuff, fix typos * readme : add Metal instructions * readme : add example for main --- .gitignore | 1 + CMakeLists.txt | 62 +++- Makefile | 33 +- README.md | 31 +- examples/CMakeLists.txt | 5 +- examples/common.cpp | 3 + examples/common.h | 1 + examples/main/main.cpp | 7 + examples/metal/CMakeLists.txt | 3 + examples/metal/metal.cpp | 102 ++++++ ggml-metal.h | 63 ++++ ggml-metal.m | 672 ++++++++++++++++++++++++++++++++++ ggml-metal.metal | 489 +++++++++++++++++++++++++ ggml.c | 153 ++++++-- ggml.h | 6 +- llama.cpp | 132 +++++-- llama.h | 8 +- 17 files changed, 1677 insertions(+), 94 deletions(-) create mode 100644 examples/metal/CMakeLists.txt create mode 100644 examples/metal/metal.cpp create mode 100644 ggml-metal.h create mode 100644 ggml-metal.m create mode 100644 ggml-metal.metal diff --git a/.gitignore b/.gitignore index d231f3ff8..e4561ad73 100644 --- a/.gitignore +++ b/.gitignore @@ -17,6 +17,7 @@ build-release/ build-static/ build-cublas/ build-opencl/ +build-metal/ build-no-accel/ build-sanitize-addr/ build-sanitize-thread/ diff --git a/CMakeLists.txt b/CMakeLists.txt index 21f4ec9dd..1f2e78c0f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -64,13 +64,14 @@ if (NOT MSVC) endif() # 3rd party libs -option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON) -option(LLAMA_BLAS "llama: use BLAS" OFF) +option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON) +option(LLAMA_BLAS "llama: use BLAS" OFF) set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor") -option(LLAMA_CUBLAS "llama: use cuBLAS" OFF) -set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels") -set(LLAMA_CUDA_DMMV_Y "1" CACHE STRING "llama: y block size for dmmv CUDA kernels") -option(LLAMA_CLBLAST "llama: use CLBlast" OFF) +option(LLAMA_CUBLAS "llama: use cuBLAS" OFF) +set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels") +set(LLAMA_CUDA_DMMV_Y "1" CACHE STRING "llama: y block size for dmmv CUDA kernels") +option(LLAMA_CLBLAST "llama: use CLBlast" OFF) +option(LLAMA_METAL "llama: use Metal" OFF) option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE}) option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE}) @@ -183,7 +184,7 @@ if (LLAMA_CUBLAS) enable_language(CUDA) - set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h) + set(GGML_SOURCES_CUDA ggml-cuda.cu ggml-cuda.h) add_compile_definitions(GGML_USE_CUBLAS) add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X}) @@ -200,12 +201,37 @@ if (LLAMA_CUBLAS) endif() endif() +if (LLAMA_METAL) + find_library(FOUNDATION_LIBRARY Foundation REQUIRED) + find_library(METAL_FRAMEWORK Metal REQUIRED) + find_library(METALKIT_FRAMEWORK MetalKit REQUIRED) + find_library(METALPERFORMANCE_FRAMEWORK MetalPerformanceShaders REQUIRED) + + set(GGML_SOURCES_METAL ggml-metal.m ggml-metal.h) + + add_compile_definitions(GGML_USE_METAL) + add_compile_definitions(GGML_METAL_NDEBUG) + + # get full path to the file + #add_compile_definitions(GGML_METAL_DIR_KERNELS="${CMAKE_CURRENT_SOURCE_DIR}/") + + # copy ggml-metal.metal to bin directory + configure_file(ggml-metal.metal bin/ggml-metal.metal COPYONLY) + + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} + ${FOUNDATION_LIBRARY} + ${METAL_FRAMEWORK} + ${METALKIT_FRAMEWORK} + ${METALPERFORMANCE_FRAMEWORK} + ) +endif() + if (LLAMA_CLBLAST) find_package(CLBlast) if (CLBlast_FOUND) message(STATUS "CLBlast found") - set(GGML_OPENCL_SOURCES ggml-opencl.cpp ggml-opencl.h) + set(GGML_SOURCES_OPENCL ggml-opencl.cpp ggml-opencl.h) add_compile_definitions(GGML_USE_CLBLAST) @@ -370,8 +396,10 @@ endif() add_library(ggml OBJECT ggml.c ggml.h - ${GGML_CUDA_SOURCES} - ${GGML_OPENCL_SOURCES}) + ${GGML_SOURCES_CUDA} + ${GGML_SOURCES_OPENCL} + ${GGML_SOURCES_METAL} + ) target_include_directories(ggml PUBLIC .) target_compile_features(ggml PUBLIC c_std_11) # don't bump @@ -384,21 +412,25 @@ endif() add_library(llama llama.cpp llama.h - llama-util.h) + llama-util.h + ) target_include_directories(llama PUBLIC .) target_compile_features(llama PUBLIC cxx_std_11) # don't bump -target_link_libraries(llama PRIVATE ggml ${LLAMA_EXTRA_LIBS}) +target_link_libraries(llama PRIVATE + ggml + ${LLAMA_EXTRA_LIBS} + ) if (BUILD_SHARED_LIBS) set_target_properties(llama PROPERTIES POSITION_INDEPENDENT_CODE ON) target_compile_definitions(llama PRIVATE LLAMA_SHARED LLAMA_BUILD) endif() -if (GGML_CUDA_SOURCES) +if (GGML_SOURCES_CUDA) message(STATUS "GGML CUDA sources found, configuring CUDA architecture") - set_property(TARGET ggml PROPERTY CUDA_ARCHITECTURES OFF) - set_property(TARGET ggml PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto") + set_property(TARGET ggml PROPERTY CUDA_ARCHITECTURES OFF) + set_property(TARGET ggml PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto") set_property(TARGET llama PROPERTY CUDA_ARCHITECTURES OFF) endif() diff --git a/Makefile b/Makefile index 8e8d426c5..1f910c3ec 100644 --- a/Makefile +++ b/Makefile @@ -105,6 +105,7 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686)) #CFLAGS += -mfma -mf16c -mavx #CXXFLAGS += -mfma -mf16c -mavx endif + ifneq ($(filter ppc64%,$(UNAME_M)),) POWER9_M := $(shell grep "POWER9" /proc/cpuinfo) ifneq (,$(findstring POWER9,$(POWER9_M))) @@ -116,6 +117,7 @@ ifneq ($(filter ppc64%,$(UNAME_M)),) CXXFLAGS += -std=c++23 -DGGML_BIG_ENDIAN endif endif + ifndef LLAMA_NO_ACCELERATE # Mac M1 - include Accelerate framework. # `-framework Accelerate` works on Mac Intel as well, with negliable performance boost (as of the predict time). @@ -123,7 +125,8 @@ ifndef LLAMA_NO_ACCELERATE CFLAGS += -DGGML_USE_ACCELERATE LDFLAGS += -framework Accelerate endif -endif +endif # LLAMA_NO_ACCELERATE + ifdef LLAMA_OPENBLAS CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas -I/usr/include/openblas ifneq ($(shell grep -e "Arch Linux" -e "ID_LIKE=arch" /etc/os-release 2>/dev/null),) @@ -131,11 +134,13 @@ ifdef LLAMA_OPENBLAS else LDFLAGS += -lopenblas endif -endif +endif # LLAMA_OPENBLAS + ifdef LLAMA_BLIS CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/blis -I/usr/include/blis LDFLAGS += -lblis -L/usr/local/lib -endif +endif # LLAMA_BLIS + ifdef LLAMA_CUBLAS CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include CXXFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include @@ -156,9 +161,10 @@ endif # LLAMA_CUDA_DMMV_Y ggml-cuda.o: ggml-cuda.cu ggml-cuda.h $(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@ endif # LLAMA_CUBLAS + ifdef LLAMA_CLBLAST - CFLAGS += -DGGML_USE_CLBLAST - CXXFLAGS += -DGGML_USE_CLBLAST + CFLAGS += -DGGML_USE_CLBLAST + CXXFLAGS += -DGGML_USE_CLBLAST # Mac provides OpenCL as a framework ifeq ($(UNAME_S),Darwin) LDFLAGS += -lclblast -framework OpenCL @@ -166,23 +172,38 @@ ifdef LLAMA_CLBLAST LDFLAGS += -lclblast -lOpenCL endif OBJS += ggml-opencl.o + ggml-opencl.o: ggml-opencl.cpp ggml-opencl.h $(CXX) $(CXXFLAGS) -c $< -o $@ -endif +endif # LLAMA_CLBLAST + +ifdef LLAMA_METAL + CFLAGS += -DGGML_USE_METAL -DGGML_METAL_NDEBUG + CXXFLAGS += -DGGML_USE_METAL + LDFLAGS += -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders + OBJS += ggml-metal.o + +ggml-metal.o: ggml-metal.m ggml-metal.h + $(CC) $(CFLAGS) -c $< -o $@ +endif # LLAMA_METAL + ifneq ($(filter aarch64%,$(UNAME_M)),) # Apple M1, M2, etc. # Raspberry Pi 3, 4, Zero 2 (64-bit) CFLAGS += -mcpu=native CXXFLAGS += -mcpu=native endif + ifneq ($(filter armv6%,$(UNAME_M)),) # Raspberry Pi 1, Zero CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access endif + ifneq ($(filter armv7%,$(UNAME_M)),) # Raspberry Pi 2 CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations endif + ifneq ($(filter armv8%,$(UNAME_M)),) # Raspberry Pi 3, 4, Zero 2 (32-bit) CFLAGS += -mfp16-format=ieee -mno-unaligned-access diff --git a/README.md b/README.md index aba22b9a0..4fc877c64 100644 --- a/README.md +++ b/README.md @@ -51,11 +51,10 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++ The main goal of `llama.cpp` is to run the LLaMA model using 4-bit integer quantization on a MacBook - Plain C/C++ implementation without dependencies -- Apple silicon first-class citizen - optimized via ARM NEON and Accelerate framework +- Apple silicon first-class citizen - optimized via ARM NEON, Accelerate and Metal frameworks - AVX, AVX2 and AVX512 support for x86 architectures - Mixed F16 / F32 precision - 4-bit, 5-bit and 8-bit integer quantization support -- Runs on the CPU - Supports OpenBLAS/Apple BLAS/ARM Performance Lib/ATLAS/BLIS/Intel MKL/NVHPC/ACML/SCSL/SGIMATH and [more](https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors) in BLAS - cuBLAS and CLBlast support @@ -236,6 +235,32 @@ In order to build llama.cpp you have three different options. zig build -Drelease-fast ``` +### Metal Build + +Using Metal allows the computation to be executed on the GPU for Apple devices: + +- Using `make`: + + ```bash + LLAMA_METAL=1 make + ``` + +- Using `CMake`: + + ```bash + mkdir build-metal + cd build-metal + cmake -DLLAMA_METAL=ON .. + cmake --build . --config Release + ``` + +When built with Metal support, you can enable GPU inference with the `--gpu-layers|-ngl` command-line argument. +Any value larger than 0 will offload the computation to the GPU. For example: + +```bash +./main -m ./models/7B/ggml-model-q4_0.bin -n 128 -ngl 1 +``` + ### BLAS Build Building the program with BLAS support may lead to some performance improvements in prompt processing using batch sizes higher than 32 (the default is 512). BLAS doesn't affect the normal generation performance. There are currently three different implementations of it: @@ -369,7 +394,7 @@ Building the program with BLAS support may lead to some performance improvements Running: - The CLBlast build supports `--gpu-layers|-ngl` like the CUDA version does. + The CLBlast build supports `--gpu-layers|-ngl` like the CUDA version does. To select the correct platform (driver) and device (GPU), you can use the environment variables `GGML_OPENCL_PLATFORM` and `GGML_OPENCL_DEVICE`. The selection can be a number (starting from 0) or a text string to search: diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index e4ce5aca7..3deff4077 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -37,7 +37,10 @@ else() add_subdirectory(save-load-state) add_subdirectory(benchmark) add_subdirectory(baby-llama) - if(LLAMA_BUILD_SERVER) + if (LLAMA_METAL) + add_subdirectory(metal) + endif() + if (LLAMA_BUILD_SERVER) add_subdirectory(server) endif() endif() diff --git a/examples/common.cpp b/examples/common.cpp index 32247cef7..b5810f28f 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -299,6 +299,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.use_mmap = false; } else if (arg == "--mtest") { params.mem_test = true; + } else if (arg == "--export") { + params.export_cgraph = true; } else if (arg == "--verbose-prompt") { params.verbose_prompt = true; } else if (arg == "-r" || arg == "--reverse-prompt") { @@ -438,6 +440,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " number of layers to store in VRAM\n"); #endif fprintf(stderr, " --mtest compute maximum memory usage\n"); + fprintf(stderr, " --export export the computation graph to 'llama.ggml'\n"); fprintf(stderr, " --verbose-prompt print prompt before generation\n"); fprintf(stderr, " --lora FNAME apply LoRA adapter (implies --no-mmap)\n"); fprintf(stderr, " --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n"); diff --git a/examples/common.h b/examples/common.h index fea9aa81a..66bdeb5e9 100644 --- a/examples/common.h +++ b/examples/common.h @@ -71,6 +71,7 @@ struct gpt_params { bool use_mmap = true; // use mmap for faster loads bool use_mlock = false; // use mlock to keep model in memory bool mem_test = false; // compute maximum memory usage + bool export_cgraph = false; // export the computation graph bool verbose_prompt = false; // print prompt tokens before generation }; diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 57cc1e486..b4d129393 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -134,6 +134,13 @@ int main(int argc, char ** argv) { return 0; } + // export the cgraph and exit + if (params.export_cgraph) { + llama_eval_export(ctx, "llama.ggml"); + llama_free(ctx); + + return 0; + } std::string path_session = params.path_prompt_cache; std::vector session_tokens; diff --git a/examples/metal/CMakeLists.txt b/examples/metal/CMakeLists.txt new file mode 100644 index 000000000..a8c4284a5 --- /dev/null +++ b/examples/metal/CMakeLists.txt @@ -0,0 +1,3 @@ +set(TEST_TARGET metal) +add_executable(${TEST_TARGET} metal.cpp) +target_link_libraries(${TEST_TARGET} PRIVATE ggml) diff --git a/examples/metal/metal.cpp b/examples/metal/metal.cpp new file mode 100644 index 000000000..77aca94a3 --- /dev/null +++ b/examples/metal/metal.cpp @@ -0,0 +1,102 @@ +// Evaluate a statically exported ggml computation graph with Metal +// +// - First, export a LLaMA graph: +// +// $ ./bin/main -m ../models/7B/ggml-model-q4_0.bin --export +// +// - Run this tool to evaluate the exported graph: +// +// $ ./bin/metal llama.ggml +// +// The purpose of this tool is mostly for debugging and demonstration purposes. +// The main limitation of exporting computation graphs is that their sizes are static which often +// can be a problem for real-world applications. +// + +#include "ggml.h" +#include "ggml-metal.h" + +#include +#include +#include + +int main(int argc, char ** argv) { + ggml_time_init(); + + if (argc != 2) { + fprintf(stderr, "Usage: %s llama.ggml\n", argv[0]); + return -1; + } + + const char * fname_cgraph = argv[1]; + + // load the compute graph + struct ggml_context * ctx_data = NULL; + struct ggml_context * ctx_eval = NULL; + + struct ggml_cgraph gf = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval); + gf.n_threads = 1; + + // this allocates all Metal resources and memory buffers + auto * ctx_metal = ggml_metal_init(); + + ggml_metal_add_buffer(ctx_metal, "data", ggml_get_mem_buffer(ctx_data), ggml_get_mem_size(ctx_data)); + ggml_metal_add_buffer(ctx_metal, "eval", ggml_get_mem_buffer(ctx_eval), ggml_get_mem_size(ctx_eval)); + + // main + { + struct ggml_tensor * input = ggml_graph_get_tensor(&gf, "embd"); + *(int32_t *) input->data = 1; // BOS + + ggml_metal_set_tensor(ctx_metal, input); + + // warmup + ggml_metal_graph_compute(ctx_metal, &gf); + + const int n_iter = 16; + + const int64_t t0 = ggml_time_us(); + + // the actual inference happens here + for (int i = 0; i < n_iter; ++i) { + ggml_metal_graph_compute(ctx_metal, &gf); + } + + const int64_t t1 = ggml_time_us(); + + printf("time: %.2f ms, %.2f ms/tok\n", (t1 - t0) / 1000.0, (t1 - t0) / 1000.0 / n_iter); + } + + // debug output + { + struct ggml_tensor * logits = gf.nodes[gf.n_nodes - 1]; + ggml_metal_get_tensor(ctx_metal, logits); + + float * ptr = (float *) ggml_get_data(logits); + + printf("logits: "); + for (int i = 0; i < 10; i++) { + printf("%8.4f ", ptr[i]); + } + printf("\n"); + int imax = 0; + double sum = 0.0; + double vmax = -1e9; + for (int i = 0; i < 32000; i++) { + sum += (double) ptr[i]; + if (ptr[i] > vmax) { + vmax = ptr[i]; + imax = i; + } + } + printf("sum: %f, imax = %d, vmax = %f\n", sum, imax, vmax); + } + + ggml_metal_free(ctx_metal); + + ggml_free(ctx_data); + ggml_free(ctx_eval); + + return 0; +} + diff --git a/ggml-metal.h b/ggml-metal.h new file mode 100644 index 000000000..a9441a9d4 --- /dev/null +++ b/ggml-metal.h @@ -0,0 +1,63 @@ +// An interface allowing to compute ggml_cgraph with Metal +// +// This is a fully functional interface that extends ggml with GPU support for Apple devices. +// A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, OpenCL, etc.) +// +// How it works? +// +// As long as your program can create and evaluate a ggml_cgraph on the CPU, you can use this +// interface to evaluate the same graph on the GPU. Instead of using ggml_graph_compute(), you +// use ggml_metal_graph_compute() (or ggml_vulkan_graph_compute(), etc.) +// +// You only need to make sure that all memory buffers that you used during the graph creation +// are mapped to the device memory with the ggml_metal_add_buffer() function. This mapping is +// used during the graph evaluation to determine the arguments of the compute kernels. +// +// Synchronization between device and host memory (for example for input and output tensors) +// is done with the ggml_metal_set_tensor() and ggml_metal_get_tensor() functions. +// + +#pragma once + +#include +#include + +// max memory buffers that can be mapped to the device +#define GGML_METAL_MAX_BUFFERS 16 + +struct ggml_tensor; +struct ggml_cgraph; + +#ifdef __cplusplus +extern "C" { +#endif + +struct ggml_metal_context; + +struct ggml_metal_context * ggml_metal_init(void); +void ggml_metal_free(struct ggml_metal_context * ctx); + +// creates a mapping between a host memory buffer and a device memory buffer +// - make sure to map all buffers used in the graph before calling ggml_metal_graph_compute +// - the mapping is used during computation to determine the arguments of the compute kernels +// - you don't need to keep the host memory buffer allocated as it is never accessed by Metal +// +bool ggml_metal_add_buffer( + struct ggml_metal_context * ctx, + const char * name, + void * data, + size_t size); + +// set data from host memory into the device +void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t); + +// get data from the device into host memory +void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t); + +// same as ggml_graph_compute but uses Metal +void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf); + +#ifdef __cplusplus +} +#endif + diff --git a/ggml-metal.m b/ggml-metal.m new file mode 100644 index 000000000..3cb423a01 --- /dev/null +++ b/ggml-metal.m @@ -0,0 +1,672 @@ +#import "ggml-metal.h" + +#import "ggml.h" + +#import + +#import +#import + +#ifdef GGML_METAL_NDEBUG +#define metal_printf(...) +#else +#define metal_printf(...) fprintf(stderr, __VA_ARGS__) +#endif + +#define UNUSED(x) (void)(x) + +struct ggml_metal_buffer { + const char * name; + + void * data; + size_t size; + + id metal; +}; + +struct ggml_metal_context { + float * logits; + + id device; + id queue; + id library; + + int n_buffers; + struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS]; + + // custom kernels +#define GGML_METAL_DECL_KERNEL(name) \ + id function_##name; \ + id pipeline_##name + + GGML_METAL_DECL_KERNEL(add); + GGML_METAL_DECL_KERNEL(mul); + GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast + GGML_METAL_DECL_KERNEL(scale); + GGML_METAL_DECL_KERNEL(silu); + GGML_METAL_DECL_KERNEL(relu); + GGML_METAL_DECL_KERNEL(soft_max); + GGML_METAL_DECL_KERNEL(diag_mask_inf); + GGML_METAL_DECL_KERNEL(get_rows_q4_0); + GGML_METAL_DECL_KERNEL(rms_norm); + GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32); + GGML_METAL_DECL_KERNEL(mul_mat_f16_f32); + GGML_METAL_DECL_KERNEL(rope); + GGML_METAL_DECL_KERNEL(cpy_f32_f16); + GGML_METAL_DECL_KERNEL(cpy_f32_f32); + +#undef GGML_METAL_DECL_KERNEL +}; + +// MSL code +// TODO: move the contents here when ready +// for now it is easier to work in a separate file +static NSString * const msl_library_source = @"see metal.metal"; + +struct ggml_metal_context * ggml_metal_init(void) { + fprintf(stderr, "%s: allocating\n", __func__); + + struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context)); + + ctx->device = MTLCreateSystemDefaultDevice(); + ctx->queue = [ctx->device newCommandQueue]; + + // determine if we can use MPS + if (MPSSupportsMTLDevice(ctx->device)) { + fprintf(stderr, "%s: using MPS\n", __func__); + } else { + fprintf(stderr, "%s: not using MPS\n", __func__); + GGML_ASSERT(false && "MPS not supported"); + } + +#if 0 + // compile from source string and show compile log + { + NSError * error = nil; + + ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error]; + if (error) { + fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]); + exit(1); + } + } +#else + UNUSED(msl_library_source); + + // read the source from "ggml-metal.metal" into a string and use newLibraryWithSource + { + NSError * error = nil; + + //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"]; + NSString * path = [[NSBundle mainBundle] pathForResource:@"ggml-metal" ofType:@"metal"]; + fprintf(stderr, "%s: loading '%s'\n", __func__, [path UTF8String]); + + NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error]; + if (error) { + fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]); + exit(1); + } + + ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error]; + if (error) { + fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]); + exit(1); + } + } +#endif + + // load kernels + { +#define GGML_METAL_ADD_KERNEL(name) \ + ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \ + ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:nil]; \ + fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name); + + GGML_METAL_ADD_KERNEL(add); + GGML_METAL_ADD_KERNEL(mul); + GGML_METAL_ADD_KERNEL(mul_row); + GGML_METAL_ADD_KERNEL(scale); + GGML_METAL_ADD_KERNEL(silu); + GGML_METAL_ADD_KERNEL(relu); + GGML_METAL_ADD_KERNEL(soft_max); + GGML_METAL_ADD_KERNEL(diag_mask_inf); + GGML_METAL_ADD_KERNEL(get_rows_q4_0); + GGML_METAL_ADD_KERNEL(rms_norm); + GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32); + GGML_METAL_ADD_KERNEL(mul_mat_f16_f32); + GGML_METAL_ADD_KERNEL(rope); + GGML_METAL_ADD_KERNEL(cpy_f32_f16); + GGML_METAL_ADD_KERNEL(cpy_f32_f32); + +#undef GGML_METAL_ADD_KERNEL + } + + return ctx; +} + +void ggml_metal_free(struct ggml_metal_context * ctx) { + fprintf(stderr, "%s: deallocating\n", __func__); + + free(ctx); +} + +// finds the Metal buffer that contains the tensor data on the GPU device +// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the +// Metal buffer based on the host memory pointer +// +static id ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) { + //fprintf(stderr, "%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach); + + for (int i = 0; i < ctx->n_buffers; ++i) { + const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data; + + if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) { + *offs = (size_t) ioffs; + + //fprintf(stderr, "%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs); + + return ctx->buffers[i].metal; + } + } + + fprintf(stderr, "%s: error: buffer is nil\n", __func__); + + return nil; +} + +bool ggml_metal_add_buffer( + struct ggml_metal_context * ctx, + const char * name, + void * data, + size_t size) { + if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) { + fprintf(stderr, "%s: too many buffers\n", __func__); + return false; + } + + if (data) { + // verify that the buffer does not overlap with any of the existing buffers + for (int i = 0; i < ctx->n_buffers; ++i) { + const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data; + + if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) { + fprintf(stderr, "%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name); + return false; + } + } + + ctx->buffers[ctx->n_buffers].name = name; + ctx->buffers[ctx->n_buffers].data = data; + ctx->buffers[ctx->n_buffers].size = size; + ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytes:data length:size options:MTLResourceStorageModeShared]; + + ++ctx->n_buffers; + + fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB\n", __func__, name, size / 1024.0 / 1024.0); + } + + return true; +} + +void ggml_metal_set_tensor( + struct ggml_metal_context * ctx, + struct ggml_tensor * t) { + metal_printf("%s: set input for tensor '%s'\n", __func__, t->name); + + size_t offs; + id id_dst = ggml_metal_get_buffer(ctx, t, &offs); + + memcpy((void *) ((uint8_t *) id_dst.contents + offs), t->data, ggml_nbytes(t)); +} + +void ggml_metal_get_tensor( + struct ggml_metal_context * ctx, + struct ggml_tensor * t) { + metal_printf("%s: extract results for tensor '%s'\n", __func__, t->name); + + size_t offs; + id id_src = ggml_metal_get_buffer(ctx, t, &offs); + + memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t)); +} + +void ggml_metal_graph_compute( + struct ggml_metal_context * ctx, + struct ggml_cgraph * gf) { + metal_printf("%s: evaluating graph\n", __func__); + + size_t offs_src0 = 0; + size_t offs_src1 = 0; + size_t offs_dst = 0; + + id command_buffer = [ctx->queue commandBuffer]; + id encoder = nil; + + for (int i = 0; i < gf->n_nodes; ++i) { + //metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op)); + + struct ggml_tensor * src0 = gf->nodes[i]->src0; + struct ggml_tensor * src1 = gf->nodes[i]->src1; + struct ggml_tensor * dst = gf->nodes[i]; + + const int64_t ne00 = src0 ? src0->ne[0] : 0; + const int64_t ne01 = src0 ? src0->ne[1] : 0; + const int64_t ne02 = src0 ? src0->ne[2] : 0; + const int64_t ne03 = src0 ? src0->ne[3] : 0; + + const uint64_t nb00 = src0 ? src0->nb[0] : 0; + const uint64_t nb01 = src0 ? src0->nb[1] : 0; + const uint64_t nb02 = src0 ? src0->nb[2] : 0; + const uint64_t nb03 = src0 ? src0->nb[3] : 0; + + const int64_t ne10 = src1 ? src1->ne[0] : 0; + const int64_t ne11 = src1 ? src1->ne[1] : 0; + const int64_t ne12 = src1 ? src1->ne[2] : 0; + const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13); + + const uint64_t nb10 = src1 ? src1->nb[0] : 0; + const uint64_t nb11 = src1 ? src1->nb[1] : 0; + const uint64_t nb12 = src1 ? src1->nb[2] : 0; + const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13); + + const int64_t ne0 = dst ? dst->ne[0] : 0; + const int64_t ne1 = dst ? dst->ne[1] : 0; + const int64_t ne2 = dst ? dst->ne[2] : 0; + const int64_t ne3 = dst ? dst->ne[3] : 0; + + const uint64_t nb0 = dst ? dst->nb[0] : 0; + const uint64_t nb1 = dst ? dst->nb[1] : 0; + const uint64_t nb2 = dst ? dst->nb[2] : 0; + const uint64_t nb3 = dst ? dst->nb[3] : 0; + + const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; + const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; + const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT; + + id id_src0 = src0 ? ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil; + id id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil; + id id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil; + + //metal_printf("%s: op - %s\n", __func__, ggml_op_name(dst->op)); + //if (src0) { + // metal_printf("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, + // ggml_is_contiguous(src0), src0->name); + //} + //if (src1) { + // metal_printf("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, + // ggml_is_contiguous(src1), src1->name); + //} + //if (dst) { + // metal_printf("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, + // dst->name); + //} + + switch (dst->op) { + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_TRANSPOSE: + case GGML_OP_PERMUTE: + { + // noop + } break; + case GGML_OP_ADD: + { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoder]; + } + + [encoder setComputePipelineState:ctx->pipeline_add]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_MUL: + { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoder]; + } + + if (ggml_nelements(src1) == ne10) { + // src1 is a row + [encoder setComputePipelineState:ctx->pipeline_mul_row]; + } else { + [encoder setComputePipelineState:ctx->pipeline_mul]; + } + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SCALE: + { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoder]; + } + + const float scale = *(const float *) src1->data; + + [encoder setComputePipelineState:ctx->pipeline_scale]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SILU: + { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoder]; + } + + [encoder setComputePipelineState:ctx->pipeline_silu]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_RELU: + { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoder]; + } + + [encoder setComputePipelineState:ctx->pipeline_relu]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SOFT_MAX: + { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoder]; + } + + const int nth = 32; + + [encoder setComputePipelineState:ctx->pipeline_soft_max]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_DIAG_MASK_INF: + { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoder]; + } + + const int n_past = ((int32_t *)(src1->data))[0]; + + [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&n_past length:sizeof(int) atIndex:4]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_MUL_MAT: + { + // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224 + + GGML_ASSERT(ne00 == ne10); + GGML_ASSERT(ne02 == ne12); + + if (ggml_is_contiguous(src0) && + ggml_is_contiguous(src1) && + (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) { + + if (encoder != nil) { + [encoder endEncoding]; + encoder = nil; + } + + MPSDataType src0dt = src0t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16; + MPSDataType src1dt = src1t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16; + + // for F32 x F32 we use MPS + MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor + matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:src0->nb[1] dataType:src0dt]; + + MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor + matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt]; + + MPSMatrixDescriptor * desc = [MPSMatrixDescriptor + matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32]; + + MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc] + initWithDevice:ctx->device transposeLeft:false transposeRight:true + resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0]; + + // we need to do ne02 multiplications + // TODO: is there a way to do this in parallel - currently very slow .. + // TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS + for (int64_t i02 = 0; i02 < ne02; ++i02) { + size_t offs_src0_cur = offs_src0 + i02*nb02; + size_t offs_src1_cur = offs_src1 + i02*nb12; + size_t offs_dst_cur = offs_dst + i02*nb2; + + MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0]; + MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1]; + MPSMatrix * mat_dst = [[MPSMatrix alloc] initWithBuffer:id_dst offset:offs_dst_cur descriptor:desc ]; + + [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst]; + } + } else { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoder]; + } + + int nth0 = 32; + int nth1 = 1; + + // use custom matrix x vector kernel + switch (src0t) { + case GGML_TYPE_Q4_0: + { + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne12 == 1); + + nth0 = 8; + nth1 = 4; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32]; + } break; + case GGML_TYPE_F16: + { + GGML_ASSERT(ne02 == ne12); + + nth0 = 32; + nth1 = 1; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32]; + } break; + default: GGML_ASSERT(false && "not implemented"); + }; + + + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; + + if (src0t == GGML_TYPE_Q4_0) { + [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } else { + [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + } + } break; + case GGML_OP_GET_ROWS: + { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoder]; + } + + switch (src0->type) { + case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; + default: GGML_ASSERT(false && "not implemented"); + } + + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4]; + [encoder setBytes:&(dst->nb[1]) length:sizeof(uint64_t) atIndex:5]; + + const int64_t n = ggml_nelements(src1); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_RMS_NORM: + { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoder]; + } + + const float eps = 1e-6f; + + const int nth = 256; + + [encoder setComputePipelineState:ctx->pipeline_rms_norm]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; + [encoder setBytes:&eps length:sizeof( float) atIndex:4]; + [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0]; + + const int64_t nrows = ggml_nrows(src0); + + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ROPE: + { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoder]; + } + + const int n_dims = ((int32_t *) src1->data)[1]; + const int mode = ((int32_t *) src1->data)[2]; + + const int n_past = ((int32_t *)(src1->data))[0]; + + [encoder setComputePipelineState:ctx->pipeline_rope]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&n_past length:sizeof( int) atIndex:18]; + [encoder setBytes:&n_dims length:sizeof( int) atIndex:19]; + [encoder setBytes:&mode length:sizeof( int) atIndex:20]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_CPY: + { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoder]; + } + + const int nth = 32; + + switch (src0t) { + case GGML_TYPE_F32: + { + switch (dstt) { + case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break; + case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break; + default: GGML_ASSERT(false && "not implemented"); + }; + } break; + default: GGML_ASSERT(false && "not implemented"); + } + + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + default: + fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); + GGML_ASSERT(false); + } + } + + if (encoder != nil) { + [encoder endEncoding]; + encoder = nil; + } + + [command_buffer commit]; + [command_buffer waitUntilCompleted]; + + { + const double time_elapsed = [command_buffer GPUEndTime] - [command_buffer GPUStartTime]; + UNUSED(time_elapsed); + + metal_printf("%s: time elapsed = %f ms\n", __func__, time_elapsed * 1000.0); + } +} diff --git a/ggml-metal.metal b/ggml-metal.metal new file mode 100644 index 000000000..4bedc8ea4 --- /dev/null +++ b/ggml-metal.metal @@ -0,0 +1,489 @@ +#include + +using namespace metal; + +#define MAX(x, y) ((x) > (y) ? (x) : (y)) + +#define QK4_0 32 +#define QR4_0 2 +typedef struct { + half d; // delta + uint8_t qs[QK4_0 / 2]; // nibbles / quants +} block_q4_0; + +static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) { + const int qk = QK4_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + const half d = x[i].d; + + for (int j = 0; j < qk/2; ++j) { + const int x0 = (x[i].qs[j] & 0x0F) - 8; + const int x1 = (x[i].qs[j] >> 4) - 8; + + y[i*qk + j + 0 ] = x0*d; + y[i*qk + j + qk/2] = x1*d; + } + } +} + +kernel void kernel_add( + device const float * src0, + device const float * src1, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] + src1[tpig]; +} + +kernel void kernel_mul( + device const float * src0, + device const float * src1, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * src1[tpig]; +} + +// assumption: src1 is a row +// broadcast src1 into src0 +kernel void kernel_mul_row( + device const float * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * src1[tpig % ne00]; +} + +kernel void kernel_scale( + device const float * src0, + device float * dst, + constant float & scale, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * scale; +} + +kernel void kernel_silu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + float x = src0[tpig]; + dst[tpig] = x / (1.0f + exp(-x)); +} + +kernel void kernel_relu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = max(0.0f, src0[tpig]); +} + +kernel void kernel_soft_max( + device const float * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + threadgroup float * buf [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + // parallel max + buf[tpitg[0]] = -INFINITY; + for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) { + buf[tpitg[0]] = MAX(buf[tpitg[0]], psrc0[i00]); + } + + // reduce + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint i = ntg[0]/2; i > 0; i /= 2) { + if (tpitg[0] < i) { + buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // broadcast + if (tpitg[0] == 0) { + buf[0] = buf[0]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + const float max = buf[0]; + + // parallel sum + buf[tpitg[0]] = 0.0f; + for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) { + buf[tpitg[0]] += exp(psrc0[i00] - max); + } + + // reduce + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint i = ntg[0]/2; i > 0; i /= 2) { + if (tpitg[0] < i) { + buf[tpitg[0]] += buf[tpitg[0] + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // broadcast + if (tpitg[0] == 0) { + buf[0] = buf[0]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + const float sum = buf[0]; + + for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) { + pdst[i00] = exp(psrc0[i00] - max) / sum; + } +} + +kernel void kernel_diag_mask_inf( + device const float * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int & n_past, + uint3 tpig[[thread_position_in_grid]]) { + const int64_t i02 = tpig[2]; + const int64_t i01 = tpig[1]; + const int64_t i00 = tpig[0]; + + if (i00 > n_past + i01) { + dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY; + } else { + dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00]; + } +} + +kernel void kernel_get_rows_q4_0( + device const void * src0, + device const int * src1, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb1, + uint tpig[[thread_position_in_grid]]) { + const int i = tpig; + const int r = ((device int32_t *) src1)[i]; + + dequantize_row_q4_0( + (device const block_q4_0 *) ((device char *) src0 + r*nb01), + (device float *) ((device char *) dst + i*nb1), ne00); +} + +kernel void kernel_rms_norm( + device const void * src0, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant float & eps, + threadgroup float * sum [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint ntg[[threads_per_threadgroup]]) { + device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01); + + // parallel sum + sum[tpitg] = 0.0f; + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + sum[tpitg] += x[i00] * x[i00]; + } + + // reduce + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint i = ntg/2; i > 0; i /= 2) { + if (tpitg < i) { + sum[tpitg] += sum[tpitg + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // broadcast + if (tpitg == 0) { + sum[0] /= ne00; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + const float mean = sum[0]; + const float scale = 1.0f/sqrt(mean + eps); + + device float * y = dst + tgpig*ne00; + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + y[i00] = x[i00] * scale; + } +} + +kernel void kernel_mul_mat_q4_0_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + threadgroup float * sum [[threadgroup(0)]], + uint2 tgpig[[threadgroup_position_in_grid]], + uint2 tpig[[thread_position_in_grid]], + uint2 tpitg[[thread_position_in_threadgroup]], + uint2 tptg[[threads_per_threadgroup]]) { + const int nb = ne00/QK4_0; + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + + device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb; + device const float * y = (device const float *) src1 + r1*ne10; + + const uint nth = tptg.x*tptg.y; + const uint ith = tptg.y*tpitg.x + tpitg.y; + + sum[ith] = 0.0f; + + for (int i = tpitg.x; i < nb; i += tptg.x) { + device const uchar4 * x0p = (device const uchar4 *) (x + i)->qs; + device const float4 * y0p = (device const float4 *) (y + i*QK4_0); + + const float d = (float)((x + i)->d); + + const uchar4 x0v = *(x0p + tpitg.y); + const float4 y0v = *(y0p + tpitg.y + 0); + const float4 y1v = *(y0p + tpitg.y + 4); + + float acc = 0.0f; + + for (int j = 0; j < 4; ++j) { + const int x0 = x0v[j] & 0x0F; + const int x1 = x0v[j] >> 4; + + const float y0 = y0v[j]; + const float y1 = y1v[j]; + + acc += (x0 - 8)*y0 + (x1 - 8)*y1; + } + + sum[ith] += acc*d; + } + + // accumulate the sum from all threads in the threadgroup + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint i = nth/2; i > 0; i /= 2) { + if (ith < i) { + sum[ith] += sum[ith + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + if (ith == 0) { + dst[r1*ne0 + r0] = sum[0]; + } +} + +kernel void kernel_mul_mat_f16_f32( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + threadgroup float * sum [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpig[[thread_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 tptg[[threads_per_threadgroup]]) { + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + const int64_t im = tgpig.z; + + device const half * x = (device const half *) (src0 + r0*nb01 + im*nb02); + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + + sum[tpitg.x] = 0.0f; + + for (int i = tpitg.x; i < ne00; i += tptg.x) { + sum[tpitg.x] += (float) x[i] * (float) y[i]; + } + + // accumulate the sum from all threads in the threadgroup + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint i = tptg.x/2; i > 0; i /= 2) { + if (tpitg.x < i) { + sum[tpitg.x] += sum[tpitg.x + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + if (tpitg.x == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0]; + } +} + +kernel void kernel_rope( + device const void * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int & n_past, + constant int & n_dims, + constant int & mode, + uint3 tpig[[thread_position_in_grid]]) { + const int64_t i3 = tpig[2]; + const int64_t i2 = tpig[1]; + const int64_t i1 = tpig[0]; + + const bool is_neox = mode & 2; + const float theta_scale = pow(10000.0, -2.0f/n_dims); + + const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); + + float theta = (float)p; + + if (!is_neox) { + for (int64_t i0 = 0; i0 < ne0; i0 += 2) { + const float cos_theta = cos(theta); + const float sin_theta = sin(theta); + + theta *= theta_scale; + + device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = src[0]; + const float x1 = src[1]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[1] = x0*sin_theta + x1*cos_theta; + } + } else { + // TODO: implement + } +} + +kernel void kernel_cpy_f32_f16( + device const float * src0, + device half * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + dst_data[i00] = src[0]; + } +} + +kernel void kernel_cpy_f32_f32( + device const float * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + dst_data[i00] = src[0]; + } +} diff --git a/ggml.c b/ggml.c index 91552c94c..00bbee503 100644 --- a/ggml.c +++ b/ggml.c @@ -3723,7 +3723,7 @@ int64_t ggml_nelements(const struct ggml_tensor * tensor) { return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3]; } -int ggml_nrows(const struct ggml_tensor * tensor) { +int64_t ggml_nrows(const struct ggml_tensor * tensor) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); return tensor->ne[1]*tensor->ne[2]*tensor->ne[3]; @@ -3732,7 +3732,14 @@ int ggml_nrows(const struct ggml_tensor * tensor) { size_t ggml_nbytes(const struct ggml_tensor * tensor) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - return (ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type]; + // this should handle cases where the tensor is not contiguous in memory + // probaby just: + // + // return tensor->ne[3]*tensor->nb[3] + // + // is enough, but just in case, adding the second part + + return MAX(tensor->ne[3]*tensor->nb[3], (ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type]); } int ggml_blck_size(enum ggml_type type) { @@ -3814,11 +3821,11 @@ size_t ggml_tensor_overhead(void) { return GGML_OBJECT_SIZE + GGML_TENSOR_SIZE + 16; } -static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) { +bool ggml_is_transposed(const struct ggml_tensor * tensor) { return tensor->nb[0] > tensor->nb[1]; } -static inline bool ggml_is_contiguous(const struct ggml_tensor * tensor) { +bool ggml_is_contiguous(const struct ggml_tensor * tensor) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); return @@ -5802,10 +5809,18 @@ struct ggml_tensor * ggml_view_1d( struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, &ne0, (char *) a->data + offset); + ggml_scratch_save(ctx); + + struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2); + memcpy(offs->data, &offset, 2*sizeof(int32_t)); + + ggml_scratch_load(ctx); + result->op = GGML_OP_VIEW; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = NULL; + result->opt[0] = offs; if (is_node) { memcpy(result->padding, &offset, sizeof(offset)); @@ -5834,6 +5849,13 @@ struct ggml_tensor * ggml_view_2d( struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, (char *) a->data + offset); + ggml_scratch_save(ctx); + + struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2); + memcpy(offs->data, &offset, 2*sizeof(int32_t)); + + ggml_scratch_load(ctx); + result->nb[1] = nb1; result->nb[2] = result->nb[1]*ne1; result->nb[3] = result->nb[2]; @@ -5842,6 +5864,7 @@ struct ggml_tensor * ggml_view_2d( result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = NULL; + result->opt[0] = offs; if (is_node) { memcpy(result->padding, &offset, sizeof(offset)); @@ -5872,6 +5895,13 @@ struct ggml_tensor * ggml_view_3d( struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, (char *) a->data + offset); + ggml_scratch_save(ctx); + + struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2); + memcpy(offs->data, &offset, 2*sizeof(int32_t)); + + ggml_scratch_load(ctx); + result->nb[1] = nb1; result->nb[2] = nb2; result->nb[3] = result->nb[2]*ne2; @@ -5880,6 +5910,7 @@ struct ggml_tensor * ggml_view_3d( result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = NULL; + result->opt[0] = offs; if (is_node) { memcpy(result->padding, &offset, sizeof(offset)); @@ -5912,6 +5943,13 @@ struct ggml_tensor * ggml_view_4d( struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, (char *) a->data + offset); + ggml_scratch_save(ctx); + + struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2); + memcpy(offs->data, &offset, 2*sizeof(int32_t)); + + ggml_scratch_load(ctx); + result->nb[1] = nb1; result->nb[2] = nb2; result->nb[3] = nb3; @@ -5920,6 +5958,7 @@ struct ggml_tensor * ggml_view_4d( result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = NULL; + result->opt[0] = offs; if (is_node) { memcpy(result->padding, &offset, sizeof(offset)); @@ -9252,7 +9291,7 @@ static void ggml_compute_forward_rms_norm_f32( sum += (ggml_float)(x[i00] * x[i00]); } - float mean = sum/ne00; + const float mean = sum/ne00; float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); @@ -11163,7 +11202,7 @@ static void ggml_compute_forward_rope_f32( theta *= theta_scale; const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); const float x0 = src[0]; const float x1 = src[1]; @@ -11184,7 +11223,7 @@ static void ggml_compute_forward_rope_f32( const int64_t i0 = ib*n_dims + ic/2; const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); const float x0 = src[0]; const float x1 = src[n_dims/2]; @@ -14588,7 +14627,7 @@ static void ggml_graph_export_leaf(const struct ggml_tensor * tensor, FILE * fou const int64_t * ne = tensor->ne; const size_t * nb = tensor->nb; - fprintf(fout, "%-6s %-12s %8d %8lld %8lld %8lld %8lld %16zu %16zu %16zu %16zu %16p %16s\n", + fprintf(fout, "%-6s %-12s %8d %8lld %8lld %8lld %8lld %16zu %16zu %16zu %16zu %16p %32s\n", ggml_type_name(tensor->type), ggml_op_name (tensor->op), tensor->n_dims, @@ -14602,7 +14641,7 @@ static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char const int64_t * ne = tensor->ne; const size_t * nb = tensor->nb; - fprintf(fout, "%-6s %-6s %-12s %8d %8lld %8lld %8lld %8lld %16zu %16zu %16zu %16zu %8d %16p %16s\n", + fprintf(fout, "%-6s %-6s %-12s %8d %8lld %8lld %8lld %8lld %16zu %16zu %16zu %16zu %8d %16p %32s\n", arg, ggml_type_name(tensor->type), ggml_op_name (tensor->op), @@ -14615,8 +14654,8 @@ static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char } void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) { - assert(cgraph->work == NULL); - assert(cgraph->work_size == 0); + //assert(cgraph->work == NULL); + //assert(cgraph->work_size == 0); uint64_t size_eval = 0; @@ -14837,7 +14876,6 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** // read file into data { FILE * fin = fopen(fname, "rb"); - if (!fin) { fprintf(stderr, "%s: failed to open %s\n", __func__, fname); return result; @@ -14977,6 +15015,8 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** op = *(const uint32_t *) ptr; ptr += sizeof(op); n_dims = *(const uint32_t *) ptr; ptr += sizeof(n_dims); + enum ggml_op eop = (enum ggml_op) op; + int64_t ne[GGML_MAX_DIMS]; size_t nb[GGML_MAX_DIMS]; @@ -14991,42 +15031,77 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** nb[j] = nb_cur; } - struct ggml_tensor * tensor = ggml_new_tensor(*ctx_eval, (enum ggml_type) type, n_dims, ne); + uint64_t ptr_cur = *(const uint64_t *) ptr; ptr += sizeof(ptr_cur); // TODO: not yet used - tensor->op = (enum ggml_op) op; + const char * ptr_name = ptr; ptr += GGML_MAX_NAME; - uint64_t ptr_cur = *(const uint64_t *) ptr; ptr += sizeof(ptr_cur); + const int32_t * ptr_arg_idx = (const int32_t *) ptr; ptr += (2 + GGML_MAX_OPT)*sizeof(int32_t); - memcpy(tensor->name, ptr, GGML_MAX_NAME); ptr += GGML_MAX_NAME; + struct ggml_tensor * args[2 + GGML_MAX_OPT] = { NULL }; + + // parse args + for (int j = 0; j < 2 + GGML_MAX_OPT; ++j) { + const int32_t arg_idx = ptr_arg_idx[j]; + + if (arg_idx == -1) { + continue; + } + + if (arg_idx < GGML_MAX_NODES) { + args[j] = result.leafs[arg_idx]; + } else { + args[j] = result.nodes[arg_idx - GGML_MAX_NODES]; + } + } + + // create the tensor + // "view" operations are handled differently + // TODO: handle inplace ops - currently a copy is always made + + struct ggml_tensor * tensor = NULL; + + switch (eop) { + // TODO: implement other view ops + case GGML_OP_RESHAPE: + { + tensor = ggml_reshape_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3]); + } break; + case GGML_OP_VIEW: + { + tensor = ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0); + + uint64_t offs; + memcpy(&offs, args[2]->data, sizeof(offs)); + + tensor->data = ((char *) tensor->data) + offs; + } break; + case GGML_OP_TRANSPOSE: + { + tensor = ggml_transpose(*ctx_eval, args[0]); + } break; + case GGML_OP_PERMUTE: + { + tensor = ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0); + } break; + default: + { + tensor = ggml_new_tensor(*ctx_eval, (enum ggml_type) type, n_dims, ne); + + tensor->op = eop; + } break; + } + + memcpy(tensor->name, ptr_name, GGML_MAX_NAME); for (int j = 0; j < GGML_MAX_DIMS; ++j) { tensor->nb[j] = nb[j]; } - // parse args - { - struct ggml_tensor ** args[2 + GGML_MAX_OPT] = { - &tensor->src0, - &tensor->src1, - }; + tensor->src0 = args[0]; + tensor->src1 = args[1]; - for (int j = 0; j < GGML_MAX_OPT; ++j) { - args[2 + j] = &tensor->opt[j]; - } - - for (int j = 0; j < 2 + GGML_MAX_OPT; ++j) { - const int32_t arg_idx = *(const int32_t *) ptr; ptr += sizeof(arg_idx); - - if (arg_idx == -1) { - continue; - } - - if (arg_idx < GGML_MAX_NODES) { - *args[j] = result.leafs[arg_idx]; - } else { - *args[j] = result.nodes[arg_idx - GGML_MAX_NODES]; - } - } + for (int j = 0; j < GGML_MAX_OPT; ++j) { + tensor->opt[j] = args[2 + j]; } result.nodes[i] = tensor; diff --git a/ggml.h b/ggml.h index 60c0ad8bf..2ea87ce9a 100644 --- a/ggml.h +++ b/ggml.h @@ -425,6 +425,7 @@ extern "C" { GGML_API void ggml_print_objects(const struct ggml_context * ctx); GGML_API int64_t ggml_nelements(const struct ggml_tensor * tensor); + GGML_API int64_t ggml_nrows (const struct ggml_tensor * tensor); GGML_API size_t ggml_nbytes (const struct ggml_tensor * tensor); GGML_API int ggml_blck_size (enum ggml_type type); @@ -441,13 +442,16 @@ extern "C" { // TODO: temporary until model loading of ggml examples is refactored GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype); + GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor); + GGML_API bool ggml_is_contiguous(const struct ggml_tensor * tensor); + // use this to compute the memory overhead of a tensor GGML_API size_t ggml_tensor_overhead(void); // main GGML_API struct ggml_context * ggml_init(struct ggml_init_params params); - GGML_API void ggml_free(struct ggml_context * ctx); + GGML_API void ggml_free(struct ggml_context * ctx); GGML_API size_t ggml_used_mem(const struct ggml_context * ctx); diff --git a/llama.cpp b/llama.cpp index f70b26c0f..bc58ad960 100644 --- a/llama.cpp +++ b/llama.cpp @@ -16,6 +16,10 @@ #include "ggml-opencl.h" #endif +#ifdef GGML_USE_METAL +#include "ggml-metal.h" +#endif + #include #include #include @@ -243,6 +247,10 @@ struct llama_context { llama_ctx_buffer buf_compute; llama_ctx_buffer buf_scratch[LLAMA_MAX_SCRATCH_BUFFERS]; +#ifdef GGML_USE_METAL + ggml_metal_context * ctx_metal = NULL; +#endif + int buf_last = 0; size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 }; @@ -1088,7 +1096,7 @@ static void llama_model_load_internal( mmapped_size - vram_total + // weights in VRAM not in memory MEM_REQ_SCRATCH0().at(model.type) + MEM_REQ_SCRATCH1().at(model.type) + - MEM_REQ_EVAL().at(model.type); + MEM_REQ_EVAL().at (model.type); // this is the memory required by one llama_state const size_t mem_required_state = @@ -1195,17 +1203,19 @@ static bool llama_model_load( // evaluate the transformer // -// - lctx: llama context -// - tokens: new batch of tokens to process -// - n_past: the context size so far -// - n_threads: number of threads to use +// - lctx: llama context +// - tokens: new batch of tokens to process +// - n_past: the context size so far +// - n_threads: number of threads to use +// - cgraph_fname: filename of the exported computation graph // static bool llama_eval_internal( - llama_context & lctx, - const llama_token * tokens, - const int n_tokens, - const int n_past, - const int n_threads) { + llama_context & lctx, + const llama_token * tokens, + const int n_tokens, + const int n_past, + const int n_threads, + const char * cgraph_fname) { // enforce that the first token is BOS if (n_past == 0 && tokens[0] != llama_token_bos()) { @@ -1251,13 +1261,18 @@ static bool llama_eval_internal( ggml_set_name(embd, "embd"); memcpy(embd->data, tokens, N*ggml_element_size(embd)); +#ifdef GGML_USE_METAL + if (lctx.ctx_metal && N == 1) { + ggml_metal_set_tensor(lctx.ctx_metal, embd); + } +#endif + + struct ggml_tensor * cur; struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd); for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; - struct ggml_tensor * cur; - lctx.use_buf(ctx0, 0); // norm @@ -1271,6 +1286,7 @@ static bool llama_eval_internal( // self-attention { // compute Q and K and RoPE them + struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0); struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0); ggml_set_name(Qcur, "Qcur"); @@ -1280,6 +1296,7 @@ static bool llama_eval_internal( { // compute the transposed [N, n_embd] V matrix struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), n_embd, N)); + ggml_set_name(Vcur, "Vcur"); struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past)); struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd, @@ -1325,7 +1342,6 @@ static bool llama_eval_internal( struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); ggml_set_name(KQ_soft_max, "KQ_soft_max"); - // split cached V into n_head heads struct ggml_tensor * V = ggml_view_3d(ctx0, kv_self.v, @@ -1407,26 +1423,53 @@ static bool llama_eval_internal( // norm { + cur = ggml_rms_norm(ctx0, inpL); - inpL = ggml_rms_norm(ctx0, inpL); + // cur = cur*norm(broadcasted) + cur = ggml_mul(ctx0, cur, model.norm); - // inpL = inpL*norm(broadcasted) - inpL = ggml_mul(ctx0, inpL, model.norm); - - embeddings = inpL; + embeddings = cur; } // lm_head - inpL = ggml_mul_mat(ctx0, model.output, inpL); + cur = ggml_mul_mat(ctx0, model.output, cur); lctx.use_buf(ctx0, -1); // logits -> probs - //inpL = ggml_soft_max_inplace(ctx0, inpL); + //cur = ggml_soft_max_inplace(ctx0, cur); // run the computation - ggml_build_forward_expand(&gf, inpL); - ggml_graph_compute (ctx0, &gf); + ggml_build_forward_expand(&gf, cur); + +#ifdef GGML_USE_METAL + if (lctx.ctx_metal && N == 1) { + ggml_metal_graph_compute(lctx.ctx_metal, &gf); + ggml_metal_get_tensor (lctx.ctx_metal, cur); + } else { + // IMPORTANT: + // Since we don't have efficient Matrix x Matrix Metal multiplication yet, we fallback to vanilla + // ggml_graph_compute(). It uses Apple's Accelerate CBLAS API which takes advantage of the ANE or the AMX + // coprocessor. + // + // When we implement Matrix x Matrix Metal multiplication, we can avoid this branch. + // But for now, we have focused only on Matrix x Vector Metal multiplication. + // + ggml_graph_compute(ctx0, &gf); + + if (lctx.ctx_metal) { + // We need to sync the CPU KV cache with the GPU KV cache + ggml_metal_set_tensor(lctx.ctx_metal, kv_self.k); + ggml_metal_set_tensor(lctx.ctx_metal, kv_self.v); + } + } +#else + ggml_graph_compute(ctx0, &gf); +#endif + + if (cgraph_fname) { + ggml_graph_export(&gf, cgraph_fname); + } #ifdef GGML_PERF // print timing information per ggml operation (for debugging purposes) @@ -1440,7 +1483,7 @@ static bool llama_eval_internal( //} //embd_w.resize(n_vocab*N); - //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N); + //memcpy(embd_w.data(), ggml_get_data(cur), sizeof(float)*n_vocab*N); // update kv token count lctx.model.kv_self.n = n_past + N; @@ -1451,11 +1494,11 @@ static bool llama_eval_internal( if (lctx.logits_all) { logits_out.resize(n_vocab * N); - memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N); + memcpy(logits_out.data(), (float *) ggml_get_data(cur), sizeof(float)*n_vocab*N); } else { // return result for just the last token logits_out.resize(n_vocab); - memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab); + memcpy(logits_out.data(), (float *) ggml_get_data(cur) + (n_vocab*(N-1)), sizeof(float)*n_vocab); } } @@ -2251,8 +2294,8 @@ struct llama_context * llama_init_from_file( ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_gpu_layers, memory_type, - params.use_mmap, params.use_mlock, params.vocab_only, - params.progress_callback, params.progress_callback_user_data)) { + params.use_mmap, params.use_mlock, params.vocab_only, + params.progress_callback, params.progress_callback_user_data)) { fprintf(stderr, "%s: failed to load model\n", __func__); llama_free(ctx); return nullptr; @@ -2290,6 +2333,25 @@ struct llama_context * llama_init_from_file( ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type)); } +#ifdef GGML_USE_METAL + if (params.n_gpu_layers > 0) { + // this allocates all Metal resources and memory buffers + ctx->ctx_metal = ggml_metal_init(); + + if (params.use_mmap) { + ggml_metal_add_buffer(ctx->ctx_metal, "data", ctx->model.mapping->addr, ctx->model.mapping->size); + ggml_metal_add_buffer(ctx->ctx_metal, "eval", ctx->buf_compute.addr, ctx->buf_compute.size); + } else { + ggml_metal_add_buffer(ctx->ctx_metal, "data", ggml_get_mem_buffer(ctx->model.ctx), ggml_get_mem_size(ctx->model.ctx)); + ggml_metal_add_buffer(ctx->ctx_metal, "eval", ctx->buf_compute.addr, ctx->buf_compute.size); + } + + ggml_metal_add_buffer(ctx->ctx_metal, "kv", ctx->model.kv_self.buf.addr, ctx->model.kv_self.buf.size); + ggml_metal_add_buffer(ctx->ctx_metal, "scr0", ctx->buf_scratch[0].addr, ctx->buf_scratch[0].size); + ggml_metal_add_buffer(ctx->ctx_metal, "scr1", ctx->buf_scratch[1].addr, ctx->buf_scratch[1].size); + } +#endif + return ctx; } @@ -2905,7 +2967,7 @@ int llama_eval( int n_tokens, int n_past, int n_threads) { - if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads)) { + if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads, nullptr)) { fprintf(stderr, "%s: failed to eval\n", __func__); return 1; } @@ -2920,6 +2982,20 @@ int llama_eval( return 0; } +int llama_eval_export(struct llama_context * ctx, const char * fname) { + const int n_batch = 1; + const int n_ctx = 512 - n_batch; + + const std::vector tmp(n_batch, llama_token_bos()); + + if (!llama_eval_internal(*ctx, tmp.data(), tmp.size(), n_ctx, 1, fname)) { + fprintf(stderr, "%s: failed to eval\n", __func__); + return 1; + } + + return 0; +} + int llama_tokenize( struct llama_context * ctx, const char * text, diff --git a/llama.h b/llama.h index c6b0a2889..87fa97367 100644 --- a/llama.h +++ b/llama.h @@ -31,7 +31,7 @@ #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN #define LLAMA_SESSION_VERSION 1 -#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) +#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL) // Defined when llama.cpp is compiled with support for offloading model layers to GPU. #define LLAMA_SUPPORTS_GPU_OFFLOAD #endif @@ -173,6 +173,12 @@ extern "C" { int n_past, int n_threads); + // Export a static computation graph for context of 511 and batch size of 1 + // NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these + // parameters here to keep things simple + // IMPORTANT: do not use for anything else other than debugging and testing! + LLAMA_API int llama_eval_export(struct llama_context * ctx, const char * fname); + // Convert the provided text into tokens. // The tokens pointer must be large enough to hold the resulting tokens. // Returns the number of tokens on success, no more than n_max_tokens From 827f5eda91e5b7299848ee2c7179d873bdee0f7b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 4 Jun 2023 23:38:19 +0300 Subject: [PATCH 02/15] readme : update hot topics --- README.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 4fc877c64..01f138127 100644 --- a/README.md +++ b/README.md @@ -9,9 +9,11 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++ **Hot topics:** -- Quantization formats `Q4` and `Q8` have changed again (19 May) - [(info)](https://github.com/ggerganov/llama.cpp/pull/1508) -- Quantization formats `Q4` and `Q5` have changed - requantize any old models [(info)](https://github.com/ggerganov/llama.cpp/pull/1405) -- [Roadmap May 2023](https://github.com/ggerganov/llama.cpp/discussions/1220) +- GPU support with Metal (Apple Silicon): https://github.com/ggerganov/llama.cpp/pull/1642 +- High-quality 2,3,4,5,6-bit quantization: https://github.com/ggerganov/llama.cpp/pull/1684 +- Multi-GPU support: https://github.com/ggerganov/llama.cpp/pull/1607 +- Training LLaMA models from scratch: https://github.com/ggerganov/llama.cpp/pull/1652 +- CPU threading improvements: https://github.com/ggerganov/llama.cpp/pull/1632
Table of Contents From d1f563a743a83dabc11e125d4a7d64189c16498c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 5 Jun 2023 10:19:03 +0300 Subject: [PATCH 03/15] llama : fix Metal KV cache sync (close #1695) --- llama.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/llama.cpp b/llama.cpp index bc58ad960..69bfdc1a1 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1455,6 +1455,14 @@ static bool llama_eval_internal( // When we implement Matrix x Matrix Metal multiplication, we can avoid this branch. // But for now, we have focused only on Matrix x Vector Metal multiplication. // + // TODO: avoid these syncs via shared memory (ref #1696) + // + if (lctx.ctx_metal) { + // We need to sync the GPU KV cache with the CPU KV cache + ggml_metal_get_tensor(lctx.ctx_metal, kv_self.k); + ggml_metal_get_tensor(lctx.ctx_metal, kv_self.v); + } + ggml_graph_compute(ctx0, &gf); if (lctx.ctx_metal) { From 5220a991a5e92bddad9542267ab445a2c033681c Mon Sep 17 00:00:00 2001 From: Henri Vasserman Date: Mon, 5 Jun 2023 13:43:08 +0300 Subject: [PATCH 04/15] Increase 3B scratch buffers. (#1698) The 128 MB was too optimistic. Too bad it is not dynamically computed. --- llama.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama.cpp b/llama.cpp index 69bfdc1a1..a16450173 100644 --- a/llama.cpp +++ b/llama.cpp @@ -63,7 +63,7 @@ static const size_t MB = 1024*1024; static const std::map & MEM_REQ_SCRATCH0() { static std::map k_sizes = { - { MODEL_3B, 128ull * MB }, + { MODEL_3B, 256ull * MB }, { MODEL_7B, 512ull * MB }, { MODEL_13B, 512ull * MB }, { MODEL_30B, 512ull * MB }, @@ -75,7 +75,7 @@ static const std::map & MEM_REQ_SCRATCH0() static const std::map & MEM_REQ_SCRATCH1() { static std::map k_sizes = { - { MODEL_3B, 128ull * MB }, + { MODEL_3B, 256ull * MB }, { MODEL_7B, 512ull * MB }, { MODEL_13B, 512ull * MB }, { MODEL_30B, 512ull * MB }, From 99009e72f8072fa552eb02efee436be596c71cdd Mon Sep 17 00:00:00 2001 From: Kawrakow <48489457+ikawrakow@users.noreply.github.com> Date: Mon, 5 Jun 2023 22:56:18 +0300 Subject: [PATCH 05/15] ggml : add SOTA 2,3,4,5,6 bit k-quantizations (#1684) * Starting to add k-quantization to ggml I think it is better to have quantization separate from ggml. For now just adding the k-quants there, but it would be better to also factor out the existing ggml quantizations. * Adding Q3_K and Q8_K (de)-quantization * Q3_K now working on CUDA and AVX2/scalar CUDA is not ideal - ~50% slower than Q4_0 for single token prediction, about the same in batch mode (perplexity). CPU single token is ~55 ms (on Ryzen 7950X). * Some improvement for Q3_K on CUDA It is now ~22.5 ms/token on my GPU, so ~30% slower than Q4_0. * Some more CUDA optimizations for Q3_K Single token is now 20.5 ms/token (~20% slower than Q4_0). Perplexity is on par with Q4_0. * Adding Q4_K - scalar, AVX2, CUDA Performance is the same or perhaps very slightly better than Q4_0 on the CPU. On the GPU, single token prediction is ~10% better than Q4_0, batch mode (perplexity is about the same). * Adding Q6_K - scalar, AVX2, CUDA Performance is ~40% lower compared to Q4_K on the CPU. This is to be expected, considering that we are memory bound on the CPU and the 6-bit model is ~44% larger than the 4-bit. On the GPU, single token prediction is ~6% lower than Q4_0, batch mode (perplexity) is even closer (but still slower). * Adding Q5_K - scalar, AVX2, CUDA Performance is ~20% lower compared to Q4_K on the CPU. This is to be expected, considering that we are memory bound on the CPU and the 5-bit model is ~22% larger than the 4-bit. On the GPU, single token prediction is about the same as Q4_0 for both, single token and batch prediction. * Per convention, all QX_K quantizations use Q5_K for output.weight * Adding quantization mixes * Quantization mixes: didn't quite get what I wanted in the last commit * Q4_K dot product for ARM_NEON * Q6_K dot product for ARM_NEON * Q5_K dot product for ARM_NEON * Adding Q3_K dot for ARM_NEON It is 22% slower than Q4_K, despite the smaller model size. On x86_64, where we are memory bound, the Q3_K model is quite a bit faster than Q4_K. * A very slightly faster ARM_NEON Q3_K dot * Adding Q2_K - just CUDA for now Token prediction is pretty good - about 15.5 ms on a RTX 4080. Perplexity is about the same as Q4_K. * Adding scalar and AVX2 Q2_K dot * Adding ARM_NEON Q2_K dot About the same performance as Q4_K. * A slightly faster ARM_NEON Q2_K dot Single token prediction is now ~36 ms on M2 Max. The code is much simpler too. * Fixed bug in Q2_K CUDA dot product kernel Stranegly enough, for the few prompts I tried with the 7B model the responses looked perfectly reasonable. Only realized something is not quite right when I tried the larger models and started getting nonse back. In any case, Q2_K single token evaluation time on an RTX 4080 in a Ryzen7950X box iusing CUDA and model fully loaded on the GPU are ~15.5 ms for 7B, ~25.4 ms for 13B, and ~55.8 ms for 30B. The max number of layers that fit in VRAM for The 65B is 32. With that, we get ~330 ms per token, which is not that much faster than just running on the CPU (~470 ms per token). * Don't print zeros/NaNs when no count histogram has been collected * A 10% faster CUDA vector dot kernel for Q3_K Q3_K is now running at ~18.5 ms / token on CUDA, so the gap to Q4_0 is only 10%. It seems memory acccess pattern is more important for performance than the amount of computation the kernel does. * A slightly daster Q4_K AVX2 dot product For perplexity, where we are less memory bound, time per pass drops by ~5%. Barely measurable difference for single token prediction. * A slightly faster ARM_NEON A4_K dot product * Minor * Fix quantization error test We cannot possibly be expecting rmse < 0.002 for 2- and 3-bit quantization variants. * Fix docker build I have been sloppy with vector reinterpret casts on ARM_NEON. It seems clang is very forgiving in that regard. * Added forgotten ggml.o dependence on k_quants.h to the Makefile * Had unintentionally committed the Makefile with -Ofast enabled * ggml : rename k_quants -> ggml-quants-k, use lowercase in code --------- Co-authored-by: Iwan Kawrakow Co-authored-by: Georgi Gerganov --- CMakeLists.txt | 2 + Makefile | 26 +- examples/quantize-stats/quantize-stats.cpp | 5 +- examples/quantize/quantize.cpp | 22 +- ggml-cuda.cu | 491 +++++ ggml-quants-k.c | 2246 ++++++++++++++++++++ ggml-quants-k.h | 122 ++ ggml.c | 150 +- ggml.h | 12 + llama.cpp | 85 +- llama.h | 9 + tests/test-quantize-fns.cpp | 7 +- 12 files changed, 3148 insertions(+), 29 deletions(-) create mode 100644 ggml-quants-k.c create mode 100644 ggml-quants-k.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 1f2e78c0f..da5913d69 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -396,6 +396,8 @@ endif() add_library(ggml OBJECT ggml.c ggml.h + ggml-quants-k.h + ggml-quants-k.c ${GGML_SOURCES_CUDA} ${GGML_SOURCES_OPENCL} ${GGML_SOURCES_METAL} diff --git a/Makefile b/Makefile index 1f910c3ec..7c9e7f739 100644 --- a/Makefile +++ b/Makefile @@ -40,8 +40,11 @@ endif # # keep standard at C11 and C++11 -CFLAGS = -I. -O3 -std=c11 -fPIC -CXXFLAGS = -I. -I./examples -O3 -std=c++11 -fPIC +# -Ofast tends to produce faster code, but may not be available for some compilers. +#OPT = -Ofast +OPT = -O3 +CFLAGS = -I. $(OPT) -std=c11 -fPIC +CXXFLAGS = -I. -I./examples $(OPT) -std=c++11 -fPIC LDFLAGS = ifdef LLAMA_DEBUG @@ -228,7 +231,10 @@ $(info ) # Build library # -ggml.o: ggml.c ggml.h ggml-cuda.h +ggml.o: ggml.c ggml.h ggml-cuda.h ggml-quants-k.h + $(CC) $(CFLAGS) -c $< -o $@ + +ggml-quants-k.o: ggml-quants-k.c ggml-quants-k.h ggml.h ggml-cuda.h $(CC) $(CFLAGS) -c $< -o $@ llama.o: llama.cpp ggml.h ggml-cuda.h llama.h llama-util.h @@ -247,25 +253,25 @@ clean: # Examples # -main: examples/main/main.cpp build-info.h ggml.o llama.o common.o $(OBJS) +main: examples/main/main.cpp build-info.h ggml.o ggml-quants-k.o llama.o common.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) @echo @echo '==== Run ./main -h for help. ====' @echo -quantize: examples/quantize/quantize.cpp build-info.h ggml.o llama.o $(OBJS) +quantize: examples/quantize/quantize.cpp build-info.h ggml.o llama.o ggml-quants-k.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -quantize-stats: examples/quantize-stats/quantize-stats.cpp build-info.h ggml.o llama.o $(OBJS) +quantize-stats: examples/quantize-stats/quantize-stats.cpp build-info.h ggml.o llama.o ggml-quants-k.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -perplexity: examples/perplexity/perplexity.cpp build-info.h ggml.o llama.o common.o $(OBJS) +perplexity: examples/perplexity/perplexity.cpp build-info.h ggml.o llama.o common.o ggml-quants-k.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -embedding: examples/embedding/embedding.cpp build-info.h ggml.o llama.o common.o $(OBJS) +embedding: examples/embedding/embedding.cpp build-info.h ggml.o llama.o common.o ggml-quants-k.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -save-load-state: examples/save-load-state/save-load-state.cpp build-info.h ggml.o llama.o common.o $(OBJS) +save-load-state: examples/save-load-state/save-load-state.cpp build-info.h ggml.o llama.o common.o ggml-quants-k.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp build-info.h ggml.o llama.o common.o $(OBJS) @@ -287,7 +293,7 @@ benchmark-matmult: examples/benchmark/benchmark-matmult.cpp build-info.h ggml.o $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) ./$@ -vdot: pocs/vdot/vdot.cpp ggml.o $(OBJS) +vdot: pocs/vdot/vdot.cpp ggml.o ggml-quants-k.o $(OBJS) $(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS) .PHONY: tests clean diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index 085fdde3c..6e4f7e1e0 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -282,8 +282,9 @@ int main(int argc, char ** argv) { break; } int j; - for (j = 0; j < GGML_TYPE_COUNT && strcmp(argv[i], ggml_type_name((ggml_type) j)) != 0; j++) { - // find match + for (j = 0; j < GGML_TYPE_COUNT; ++j) { + const auto * name = ggml_type_name((ggml_type) j); + if (name && strcmp(argv[i], name) == 0) break; } if (j < GGML_TYPE_COUNT) { params.include_types.push_back((ggml_type) j); diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 769dd36a4..947b40202 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -7,11 +7,23 @@ #include static const std::map LLAMA_FTYPE_MAP = { - {"q4_0", LLAMA_FTYPE_MOSTLY_Q4_0}, - {"q4_1", LLAMA_FTYPE_MOSTLY_Q4_1}, - {"q5_0", LLAMA_FTYPE_MOSTLY_Q5_0}, - {"q5_1", LLAMA_FTYPE_MOSTLY_Q5_1}, - {"q8_0", LLAMA_FTYPE_MOSTLY_Q8_0}, + {"q4_0", LLAMA_FTYPE_MOSTLY_Q4_0}, + {"q4_1", LLAMA_FTYPE_MOSTLY_Q4_1}, + {"q5_0", LLAMA_FTYPE_MOSTLY_Q5_0}, + {"q5_1", LLAMA_FTYPE_MOSTLY_Q5_1}, + {"q8_0", LLAMA_FTYPE_MOSTLY_Q8_0}, + {"q2_K", LLAMA_FTYPE_MOSTLY_Q2_K}, + {"q3_K", LLAMA_FTYPE_MOSTLY_Q3_K_M}, + {"q3_K_S", LLAMA_FTYPE_MOSTLY_Q3_K_S}, + {"q3_K_M", LLAMA_FTYPE_MOSTLY_Q3_K_M}, + {"q3_K_L", LLAMA_FTYPE_MOSTLY_Q3_K_L}, + {"q4_K", LLAMA_FTYPE_MOSTLY_Q4_K_M}, + {"q4_K_S", LLAMA_FTYPE_MOSTLY_Q4_K_S}, + {"q4_K_M", LLAMA_FTYPE_MOSTLY_Q4_K_M}, + {"q5_K", LLAMA_FTYPE_MOSTLY_Q5_K_M}, + {"q5_K_S", LLAMA_FTYPE_MOSTLY_Q5_K_S}, + {"q5_K_M", LLAMA_FTYPE_MOSTLY_Q5_K_M}, + {"q6_K", LLAMA_FTYPE_MOSTLY_Q6_K}, }; bool try_parse_ftype(const std::string & ftype_str, llama_ftype & ftype, std::string & ftype_str_out) { diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 98170a3ae..5385e0120 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -35,6 +36,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1); typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream); typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream); +typedef void (*dot_kernel_k_t)(const void * vx, const int ib, const int iqs, const float * y, float & v); // QK = number of values after dequantization // QR = QK / number of values before dequantization @@ -83,6 +85,51 @@ typedef struct { } block_q8_0; static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding"); +//================================= k-quants + +#define QK_K 256 + +typedef struct { + uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits + uint8_t qs[QK_K/4]; // quants + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins +} block_q2_k; +static_assert(sizeof(block_q2_k) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_k block size/padding"); + +typedef struct { + uint8_t hmask[QK_K/8]; + uint8_t qs[QK_K/4]; // nibbles / quants + uint8_t scales[3*QK_K/64]; + half d; +} block_q3_k; +static_assert(sizeof(block_q3_k) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_k block size/padding"); + +typedef struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits + uint8_t qs[QK_K/2]; // 4--bit quants +} block_q4_k; +static_assert(sizeof(block_q4_k) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_k block size/padding"); + +typedef struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_k; +static_assert(sizeof(block_q5_k) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2 + QK_K/8, "wrong q5_k block size/padding"); + +typedef struct { + uint8_t ql[QK_K/2]; // quants, lower 4 bits + uint8_t qh[QK_K/4]; // quants, upper 2 bits + int8_t scales[QK_K/16]; // scales + half d; // delta +} block_q6_k; +static_assert(sizeof(block_q6_k) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_k block size/padding"); + #define WARP_SIZE 32 #define CUDA_MUL_BLOCK_SIZE 256 @@ -184,6 +231,337 @@ static __device__ void dequantize_q8_0(const void * vx, const int ib, const int v1 = vi1*d; } +//================================== k-quants + +static __global__ void dequantize_block_q2_k(const void * vx, float * yy) { + + const int i = blockIdx.x; + const int tid = threadIdx.x; + const int n = tid/32; + const int l = tid - 32*n; + const int is = 8*n + l/16; + + const block_q2_k * x = (const block_q2_k *) vx; + + const uint8_t q = x[i].qs[32*n + l]; + float * y = yy + i*QK_K + 128*n; + + float dall = x[i].d; + float dmin = x[i].dmin; + y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4); + y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4); + y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4); + y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4); + +} + +static __device__ void vec_dot_q2_k(const void * vx, const int ib, const int iqs, const float * yy, float & result) { + + const block_q2_k * x = (const block_q2_k *) vx; + + // if n is 0, we want to do the lower 128, else the upper 128, + // covering y[l+0], y[l+32], y[l+64], y[l+96] and + // y[l+16], y[l+48], y[l+80], y[l+112] + int n = iqs/128; // 0 or 1 + int r = iqs - 128*n; // 0...120 in steps of 8 + int l = r/8; // 0...15 in steps of 1 + + const float * y = yy + 128*n + l; + const uint8_t * q = x[ib].qs + 32*n + l; + const uint8_t * s = x[ib].scales + 8*n; + + const float dall = x[ib].d; + const float dmin = x[ib].dmin; + + float sum = y[ 0] * (dall * ((s[0] & 0xF) * ((q[ 0] >> 0) & 3)) - dmin * (s[0] >> 4)) + + y[ 32] * (dall * ((s[2] & 0xF) * ((q[ 0] >> 2) & 3)) - dmin * (s[2] >> 4)) + + y[ 64] * (dall * ((s[4] & 0xF) * ((q[ 0] >> 4) & 3)) - dmin * (s[4] >> 4)) + + y[ 96] * (dall * ((s[6] & 0xF) * ((q[ 0] >> 6) & 3)) - dmin * (s[6] >> 4)) + + y[ 16] * (dall * ((s[1] & 0xF) * ((q[16] >> 0) & 3)) - dmin * (s[1] >> 4)) + + y[ 48] * (dall * ((s[3] & 0xF) * ((q[16] >> 2) & 3)) - dmin * (s[3] >> 4)) + + y[ 80] * (dall * ((s[5] & 0xF) * ((q[16] >> 4) & 3)) - dmin * (s[5] >> 4)) + + y[112] * (dall * ((s[7] & 0xF) * ((q[16] >> 6) & 3)) - dmin * (s[7] >> 4)); + + result = sum; + +} + +static __global__ void dequantize_block_q3_k(const void * vx, float * yy) { + + int r = threadIdx.x/4; + int i = blockIdx.x; + int tid = r/2; + int is0 = r%2; + int l0 = 16*is0 + 4*(threadIdx.x%4); + int n = tid / 4; + int j = tid - 4*n; + + const block_q3_k * x = (const block_q3_k *) vx; + + uint8_t m = 1 << (4*n + j); + int is = 8*n + 2*j + is0; + int shift = 2*j; + + int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) : + is < 8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) : + is < 12 ? (x[i].scales[is-8] >> 4) | (((x[i].scales[is+0] >> 4) & 3) << 4) : + (x[i].scales[is-8] >> 4) | (((x[i].scales[is-4] >> 6) & 3) << 4); + float d_all = x[i].d; + float dl = d_all * (us - 32); + + float * y = yy + i*QK_K + 128*n + 32*j; + const uint8_t * q = x[i].qs + 32*n; + const uint8_t * hm = x[i].hmask; + + for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)); + +} + +static __device__ void vec_dot_q3_k(const void * vx, const int ib, const int iqs, const float * yy, float & result) { + + const block_q3_k * x = (const block_q3_k *) vx; + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + uint32_t aux[3]; + uint32_t utmp[4]; + + // if n is 0, we want to do the lower 128, else the upper 128, + // covering y[l+0], y[l+32], y[l+64], y[l+96] and + // y[l+16], y[l+48], y[l+80], y[l+112] + int n = iqs/128; // 0 or 1 + int r = iqs - 128*n; // 0...120 in steps of 8 + int l = r/8; // 0...15 in steps of 1 + + const float * y = yy + 128*n + l; + const uint8_t * q = x[ib].qs + 32*n + l; + const uint8_t * hm = x[ib].hmask + l; + const int8_t * s = (const int8_t *)utmp + 8*n; + + memcpy(aux, x[ib].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + const float dall = x[ib].d; + + const uint8_t m = 1 << (4*n); + + float sum = y[ 0] * (s[0] - 32) * (((q[ 0] >> 0) & 3) - (hm[ 0] & (m << 0) ? 0 : 4)) + + y[ 32] * (s[2] - 32) * (((q[ 0] >> 2) & 3) - (hm[ 0] & (m << 1) ? 0 : 4)) + + y[ 64] * (s[4] - 32) * (((q[ 0] >> 4) & 3) - (hm[ 0] & (m << 2) ? 0 : 4)) + + y[ 96] * (s[6] - 32) * (((q[ 0] >> 6) & 3) - (hm[ 0] & (m << 3) ? 0 : 4)) + + y[ 16] * (s[1] - 32) * (((q[16] >> 0) & 3) - (hm[16] & (m << 0) ? 0 : 4)) + + y[ 48] * (s[3] - 32) * (((q[16] >> 2) & 3) - (hm[16] & (m << 1) ? 0 : 4)) + + y[ 80] * (s[5] - 32) * (((q[16] >> 4) & 3) - (hm[16] & (m << 2) ? 0 : 4)) + + y[112] * (s[7] - 32) * (((q[16] >> 6) & 3) - (hm[16] & (m << 3) ? 0 : 4)); + + result = sum * dall; + +} + +static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) { + if (j < 4) { + d = q[j] & 63; m = q[j + 4] & 63; + } else { + d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); + m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); + } +} + +static __global__ void dequantize_block_q4_k(const void * vx, float * yy) { + const block_q4_k * x = (const block_q4_k *) vx; + + const int i = blockIdx.x; + + //// assume 64 threads - this is very slightly better than the one below + //const int tid = threadIdx.x; + //const int il = tid/16; + //const int ir = tid%16; + //const int is = 2*il; + //const int n = 2; + + // assume 32 threads + const int tid = threadIdx.x; + const int il = tid/8; + const int ir = tid%8; + const int is = 2*il; + const int n = 4; + + float * y = yy + i*QK_K + 64*il + n*ir; + + const float dall = x[i].d; + const float dmin = x[i].dmin; + + const uint8_t * q = x[i].qs + 32*il + n*ir; + + uint8_t sc, m; + get_scale_min_k4(is + 0, x[i].scales, sc, m); + const float d1 = dall * sc; const float m1 = dmin * m; + get_scale_min_k4(is + 1, x[i].scales, sc, m); + const float d2 = dall * sc; const float m2 = dmin * m; + for (int l = 0; l < n; ++l) { + y[l + 0] = d1 * (q[l] & 0xF) - m1; + y[l +32] = d2 * (q[l] >> 4) - m2; + } +} + +static __device__ void vec_dot_q4_k(const void * vx, const int ib, const int iqs, const float * yy, float & result) { + + const block_q4_k * x = (const block_q4_k *) vx; + + // iqs is in 0...248 in steps of 8 => + const int j = iqs / 64; // j is in 0...3 + const int ir = (iqs - 64*j)/2; // ir is in 0...28 in steps of 4 + const int is = 2*j; // is is in 0...6 in steps of 2 + + const float * y = yy + 64*j + ir; + const uint8_t * q = x[ib].qs + 32*j + ir; + + const float dall = x[ib].d; + const float dmin = x[ib].dmin; + + uint8_t sc, m; + get_scale_min_k4(is + 0, x[ib].scales, sc, m); + const float d1 = dall * sc; + const float m1 = dmin * m; + get_scale_min_k4(is + 1, x[ib].scales, sc, m); + const float d2 = dall * sc; + const float m2 = dmin * m; + + float sum = 0; + for (int k = 0; k < 4; ++k) { + sum += y[k + 0] * (d1 * (q[k] & 0xF) - m1); + sum += y[k + 32] * (d2 * (q[k] >> 4) - m2); + } + result = sum; + +} + +static __global__ void dequantize_block_q5_k(const void * vx, float * yy) { + const block_q5_k * x = (const block_q5_k *) vx; + + const int i = blockIdx.x; + + // assume 64 threads - this is very slightly better than the one below + const int tid = threadIdx.x; + const int il = tid/16; // il is in 0...3 + const int ir = tid%16; // ir is in 0...15 + const int is = 2*il; // is is in 0...6 + + float * y = yy + i*QK_K + 64*il + 2*ir; + + const float dall = x[i].d; + const float dmin = x[i].dmin; + + const uint8_t * ql = x[i].qs + 32*il + 2*ir; + const uint8_t * qh = x[i].qh + 2*ir; + + uint8_t sc, m; + get_scale_min_k4(is + 0, x[i].scales, sc, m); + const float d1 = dall * sc; const float m1 = dmin * m; + get_scale_min_k4(is + 1, x[i].scales, sc, m); + const float d2 = dall * sc; const float m2 = dmin * m; + + uint8_t hm = 1 << (2*il); + y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm ? 16 : 0)) - m1; + y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm ? 16 : 0)) - m1; + hm <<= 1; + y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2; + y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2; +} + +static __device__ void vec_dot_q5_k(const void * vx, const int ib, const int iqs, const float * yy, float & result) { + + const block_q5_k * x = (const block_q5_k *) vx; + + // iqs is in 0...248 in steps of 8 => + const int j = iqs / 64; // j is in 0...3 + const int ir = (iqs - 64*j)/2; // ir is in 0...28 in steps of 4 + const int is = 2*j; // is is in 0...6 in steps of 2 + + const float * y = yy + 64*j + ir; + const uint8_t * ql = x[ib].qs + 32*j + ir; + const uint8_t * qh = x[ib].qh + ir; + + const float dall = x[ib].d; + const float dmin = x[ib].dmin; + + uint8_t sc, m; + get_scale_min_k4(is + 0, x[ib].scales, sc, m); + const float d1 = dall * sc; + const float m1 = dmin * m; + get_scale_min_k4(is + 1, x[ib].scales, sc, m); + const float d2 = dall * sc; + const float m2 = dmin * m; + + uint8_t hm = 1 << is; + float sum = 0; + for (int k = 0; k < 4; ++k) { + sum += y[k + 0] * (d1 * ((ql[k] & 0xF) + (qh[k] & hm ? 16 : 0)) - m1); + } + hm <<= 1; + for (int k = 0; k < 4; ++k) { + sum += y[k + 32] * (d2 * ((ql[k] >> 4) + (qh[k] & hm ? 16 : 0)) - m2); + } + result = sum; + +} + +static __global__ void dequantize_block_q6_k(const void * vx, float * yy) { + const block_q6_k * x = (const block_q6_k *) vx; + + const int i = blockIdx.x; + + // assume 64 threads - this is very slightly better than the one below + const int tid = threadIdx.x; + const int ip = tid/32; // ip is 0 or 1 + const int il = tid - 32*ip; // 0...32 + const int is = 8*ip + il/16; + + float * y = yy + i*QK_K + 128*ip + il; + + const float d = x[i].d; + + const uint8_t * ql = x[i].ql + 64*ip + il; + const uint8_t qh = x[i].qh[32*ip + il]; + const int8_t * sc = x[i].scales + is; + + y[ 0] = d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32); + y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32); + y[64] = d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32); + y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32); +} + +static __device__ void vec_dot_q6_k(const void * vx, const int ib, const int iqs, const float * yy, float & result) { + + const block_q6_k * x = (const block_q6_k *) vx; + + const int ip = iqs / 128; // 0 or 1 + const int il = (iqs - 128*ip)/8; // 0...15 + const int is = 8*ip; + + const float * y = yy + 128*ip + il; + + const float d = x[ib].d; + + const uint8_t * ql = x[ib].ql + 64*ip + il; + const uint8_t * qh = x[ib].qh + 32*ip + il; + const int8_t * sc = x[ib].scales + is; + + result = y[ 0] * d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh[ 0] >> 0) & 3) << 4)) - 32) + + y[ 32] * d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh[ 0] >> 2) & 3) << 4)) - 32) + + y[ 64] * d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh[ 0] >> 4) & 3) << 4)) - 32) + + y[ 96] * d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh[ 0] >> 6) & 3) << 4)) - 32) + + y[ 16] * d * sc[1] * ((int8_t)((ql[16] & 0xF) | (((qh[16] >> 0) & 3) << 4)) - 32) + + y[ 48] * d * sc[3] * ((int8_t)((ql[48] & 0xF) | (((qh[16] >> 2) & 3) << 4)) - 32) + + y[ 80] * d * sc[5] * ((int8_t)((ql[16] >> 4) | (((qh[16] >> 4) & 3) << 4)) - 32) + + y[112] * d * sc[7] * ((int8_t)((ql[48] >> 4) | (((qh[16] >> 6) & 3) << 4)) - 32); + +} + static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){ const half * x = (const half *) vx; @@ -258,6 +636,41 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, } } +template +static __global__ void dequantize_mul_mat_vec_k(const void * vx, const float * y, float * dst, const int ncols) { + const int row = blockIdx.x*blockDim.y + threadIdx.y; + const int tid = threadIdx.x; + + const int iter_stride = QK_K; + const int vals_per_iter = iter_stride / n_thread; + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + float tmp = 0; // partial sum for thread in warp + + for (int i = 0; i < ncols; i += iter_stride) { + const int col = i + vals_per_iter*tid; + const int ib = ib0 + col/QK_K; // x block index + const int iqs = col%QK_K; // x quant index + const int iybs = col - col%QK_K; // y block start index + + float v; + dot_kernel(vx, ib, iqs, y + iybs, v); + tmp += v; + } + + // sum up partial sums and write back result + __syncthreads(); +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (tid == 0) { + dst[row] = tmp; + } +} + static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) { const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE; mul_f32<<>>(x, y, dst, kx, ky); @@ -288,6 +701,31 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu dequantize_block<<>>(vx, y, k); } +static void dequantize_row_q2_k_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_q2_k<<>>(vx, y); +} + +static void dequantize_row_q3_k_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_q3_k<<>>(vx, y); +} + +static void dequantize_row_q4_k_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_q4_k<<>>(vx, y); +} + +static void dequantize_row_q5_k_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_q5_k<<>>(vx, y); +} + +static void dequantize_row_q6_k_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_q6_k<<>>(vx, y); +} + static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0); @@ -328,6 +766,37 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, f <<>>(vx, y, dst, ncols); } +static void dequantize_mul_mat_vec_q2_k_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int ny = 2; + const dim3 block_dims(32, ny, 1); + dequantize_mul_mat_vec_k<32, vec_dot_q2_k><<<(nrows + ny - 1)/ny, block_dims, 0, stream>>>(vx, y, dst, ncols); +} + +static void dequantize_mul_mat_vec_q3_k_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + const dim3 block_dims(32, 2, 1); + dequantize_mul_mat_vec_k<32, vec_dot_q3_k><<>>(vx, y, dst, ncols); +} + +static void dequantize_mul_mat_vec_q4_k_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + const dim3 block_dims(32, 2, 1); + dequantize_mul_mat_vec_k<32, vec_dot_q4_k><<>>(vx, y, dst, ncols); +} + +static void dequantize_mul_mat_vec_q5_k_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + const dim3 block_dims(32, 2, 1); + dequantize_mul_mat_vec_k<32, vec_dot_q5_k><<>>(vx, y, dst, ncols); +} + +static void dequantize_mul_mat_vec_q6_k_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + const dim3 block_dims(32, 2, 1); + dequantize_mul_mat_vec_k<32, vec_dot_q6_k><<>>(vx, y, dst, ncols); +} + static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; dequantize_block<32, 1, convert_f16><<>>(vx, y, k); @@ -353,6 +822,16 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_q5_1_cuda; case GGML_TYPE_Q8_0: return dequantize_row_q8_0_cuda; + case GGML_TYPE_Q2_K: + return dequantize_row_q2_k_cuda; + case GGML_TYPE_Q3_K: + return dequantize_row_q3_k_cuda; + case GGML_TYPE_Q4_K: + return dequantize_row_q4_k_cuda; + case GGML_TYPE_Q5_K: + return dequantize_row_q5_k_cuda; + case GGML_TYPE_Q6_K: + return dequantize_row_q6_k_cuda; case GGML_TYPE_F16: return convert_fp16_to_fp32_cuda; default: @@ -372,6 +851,16 @@ static dequantize_mul_mat_vec_cuda_t ggml_get_dequantize_mul_mat_vec_cuda(ggml_t return dequantize_mul_mat_vec_q5_1_cuda; case GGML_TYPE_Q8_0: return dequantize_mul_mat_vec_q8_0_cuda; + case GGML_TYPE_Q2_K: + return dequantize_mul_mat_vec_q2_k_cuda; + case GGML_TYPE_Q3_K: + return dequantize_mul_mat_vec_q3_k_cuda; + case GGML_TYPE_Q4_K: + return dequantize_mul_mat_vec_q4_k_cuda; + case GGML_TYPE_Q5_K: + return dequantize_mul_mat_vec_q5_k_cuda; + case GGML_TYPE_Q6_K: + return dequantize_mul_mat_vec_q6_k_cuda; case GGML_TYPE_F16: return convert_mul_mat_vec_f16_cuda; default: @@ -790,12 +1279,14 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0)); // compute + //printf("Calling dmmv\n"); dmmv(c_Q, c_Y, c_D, ne00, ne01, cudaStream); CUDA_CHECK(cudaGetLastError()); } else { // general dequantization kernel + cuBLAS matrix matrix multiplication float * c_X = d_X + i * x_ne; +//typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream); // convert src0 to fp32 on device to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2); CUDA_CHECK(cudaGetLastError()); diff --git a/ggml-quants-k.c b/ggml-quants-k.c new file mode 100644 index 000000000..dec00d371 --- /dev/null +++ b/ggml-quants-k.c @@ -0,0 +1,2246 @@ +#include "ggml-quants-k.h" +#include "ggml.h" + +#include +#include +#include + +#ifdef __ARM_NEON + +// if YCM cannot find , make a symbolic link to it, for example: +// +// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/ +// +#include + +#else + +#ifdef __wasm_simd128__ +#include +#else +#ifdef __POWER9_VECTOR__ +#include +#undef bool +#define bool _Bool +#else +#if defined(_MSC_VER) || defined(__MINGW32__) +#include +#else +#if !defined(__riscv) +#include +#endif +#endif +#endif +#endif +#endif + +#undef MIN +#undef MAX +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + +// +// 2-6 bit quantization in super-blocks +// + + +// +// ===================== Helper functions +// +static inline int nearest_int(float fval) { + assert(fval <= 4194303.f); + float val = fval + 12582912.f; + int i; memcpy(&i, &val, sizeof(int)); + return (i & 0x007fffff) - 0x00400000; +} + +static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type) { + float max = 0; + float amax = 0; + for (int i = 0; i < n; ++i) { + float ax = fabsf(x[i]); + if (ax > amax) { amax = ax; max = x[i]; } + } + if (!amax) { // all zero + for (int i = 0; i < n; ++i) { + L[i] = 0; + } + return 0.f; + } + float iscale = -nmax / max; + if (rmse_type == 0) { + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + L[i] = nmax + MAX(-nmax, MIN(nmax-1, l)); + } + return 1/iscale; + } + int weight_type = rmse_type%2; + float sumlx = 0; + float suml2 = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + l = MAX(-nmax, MIN(nmax-1, l)); + L[i] = l + nmax; + float w = weight_type == 1 ? x[i] * x[i] : 1; + sumlx += w*x[i]*l; + suml2 += w*l*l; + } + float scale = sumlx/suml2; + float best = scale * sumlx; + for (int itry = 0; itry < 3; ++itry) { + iscale = 1/scale; + float slx = 0; + float sl2 = 0; + bool changed = false; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + l = MAX(-nmax, MIN(nmax-1, l)); + if (l + nmax != L[i]) { changed = true; } + float w = weight_type == 1 ? x[i] * x[i] : 1.f; + slx += w*x[i]*l; + sl2 += w*l*l; + } + if (!changed || sl2 == 0 || slx*slx <= best*sl2) { break; } + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + L[i] = nmax + MAX(-nmax, MIN(nmax-1, l)); + } + sumlx = slx; suml2 = sl2; + scale = sumlx/suml2; + best = scale * sumlx; + } + for (int itry = 0; itry < 5; ++itry) { + int n_changed = 0; + for (int i = 0; i < n; ++i) { + float w = weight_type == 1 ? x[i]*x[i] : 1; + int l = L[i] - nmax; + float slx = sumlx - w*x[i]*l; + if (slx > 0) { + float sl2 = suml2 - w*l*l; + int new_l = nearest_int(x[i] * sl2 / slx); + new_l = MAX(-nmax, MIN(nmax-1, new_l)); + if (new_l != l) { + slx += w*x[i]*new_l; + sl2 += w*new_l*new_l; + if (sl2 > 0 && slx*slx*suml2 > sumlx*sumlx*sl2) { + L[i] = nmax + new_l; sumlx = slx; suml2 = sl2; + scale = sumlx / suml2; best = scale * sumlx; + ++n_changed; + } + } + } + } + if (!n_changed) { break; } + } + if (rmse_type < 3) { + return scale; + } + for (int is = -4; is <= 4; ++is) { + if (is == 0) { + continue; + } + iscale = -(nmax + 0.1f*is) / max; + sumlx = suml2 = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + l = MAX(-nmax, MIN(nmax-1, l)); + float w = weight_type == 1 ? x[i] * x[i] : 1; + sumlx += w*x[i]*l; + suml2 += w*l*l; + } + if (suml2 > 0 && sumlx*sumlx > best*suml2) { + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + L[i] = nmax + MAX(-nmax, MIN(nmax-1, l)); + } + scale = sumlx/suml2; best = scale*sumlx; + } + } + return scale; +} + +static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, bool do_rmse) { + float max = 0; + float amax = 0; + for (int i = 0; i < n; ++i) { + float ax = fabsf(x[i]); + if (ax > amax) { amax = ax; max = x[i]; } + } + if (!amax) { // all zero + for (int i = 0; i < n; ++i) { L[i] = 0; } + return 0.f; + } + float iscale = -nmax / max; + if (do_rmse) { + float sumlx = 0; + float suml2 = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + l = MAX(-nmax, MIN(nmax-1, l)); + L[i] = l; + float w = x[i]*x[i]; + sumlx += w*x[i]*l; + suml2 += w*l*l; + } + for (int itry = 0; itry < 5; ++itry) { + int n_changed = 0; + for (int i = 0; i < n; ++i) { + float w = x[i]*x[i]; + float slx = sumlx - w*x[i]*L[i]; + if (slx > 0) { + float sl2 = suml2 - w*L[i]*L[i]; + int new_l = nearest_int(x[i] * sl2 / slx); + new_l = MAX(-nmax, MIN(nmax-1, new_l)); + if (new_l != L[i]) { + slx += w*x[i]*new_l; + sl2 += w*new_l*new_l; + if (sl2 > 0 && slx*slx*suml2 > sumlx*sumlx*sl2) { + L[i] = new_l; sumlx = slx; suml2 = sl2; + ++n_changed; + } + } + } + } + if (!n_changed) { + break; + } + } + for (int i = 0; i < n; ++i) { + L[i] += nmax; + } + return sumlx / suml2; + } + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + l = MAX(-nmax, MIN(nmax-1, l)); + L[i] = l + nmax; + } + return 1/iscale; +} + +static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min, int ntry) { + float min = x[0]; + float max = x[0]; + for (int i = 1; i < n; ++i) { + if (x[i] < min) min = x[i]; + if (x[i] > max) max = x[i]; + } + if (max == min) { + for (int i = 0; i < n; ++i) L[i] = 0; + *the_min = 0; + return 0.f; + } + if (min > 0) min = 0; + float iscale = nmax/(max - min); + float scale = 1/iscale; + for (int itry = 0; itry < ntry; ++itry) { + float sumlx = 0; int suml2 = 0; + bool did_change = false; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale*(x[i] - min)); + l = MAX(0, MIN(nmax, l)); + if (l != L[i]) { + L[i] = l; + did_change = true; + } + sumlx += (x[i] - min)*l; + suml2 += l*l; + } + scale = sumlx/suml2; + float sum = 0; + for (int i = 0; i < n; ++i) { + sum += x[i] - scale*L[i]; + } + min = sum/n; + if (min > 0) min = 0; + iscale = 1/scale; + if (!did_change) break; + } + *the_min = -min; + return scale; +} + +static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) { + if (j < 4) { + *d = q[j] & 63; *m = q[j + 4] & 63; + } else { + *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); + *m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); + } +} + +//========================- 2-bit (de)-quantization + +void quantize_row_q2_k_reference(const float * restrict x, block_q2_k * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + uint8_t L[QK_K]; + float mins[QK_K/16]; + float scales[QK_K/16]; + + const float q4scale = 15.f; + + for (int i = 0; i < nb; i++) { + + float max_scale = 0; // as we are deducting the min, scales are always positive + float max_min = 0; + for (int j = 0; j < QK_K/16; ++j) { + scales[j] = make_qkx1_quants(16, 3, x + 16*j, L + 16*j, &mins[j], 5); + float scale = scales[j]; + if (scale > max_scale) { + max_scale = scale; + } + float min = mins[j]; + if (min > max_min) { + max_min = min; + } + } + + if (max_scale > 0) { + float iscale = q4scale/max_scale; + for (int j = 0; j < QK_K/16; ++j) { + int l = nearest_int(iscale*scales[j]); + y[i].scales[j] = l; + } + y[i].d = ggml_fp32_to_fp16(max_scale/q4scale); + } else { + for (int j = 0; j < QK_K/16; ++j) y[i].scales[j] = 0; + y[i].d = ggml_fp32_to_fp16(0.f); + } + if (max_min > 0) { + float iscale = q4scale/max_min; + for (int j = 0; j < QK_K/16; ++j) { + int l = nearest_int(iscale*mins[j]); + y[i].scales[j] |= (l << 4); + } + y[i].dmin = ggml_fp32_to_fp16(max_min/q4scale); + } else { + y[i].dmin = ggml_fp32_to_fp16(0.f); + } + for (int j = 0; j < QK_K/16; ++j) { + const float d = ggml_fp16_to_fp32(y[i].d) * (y[i].scales[j] & 0xF); + if (!d) continue; + const float dm = ggml_fp16_to_fp32(y[i].dmin) * (y[i].scales[j] >> 4); + for (int ii = 0; ii < 16; ++ii) { + int l = nearest_int((x[16*j + ii] + dm)/d); + l = MAX(0, MIN(3, l)); + L[16*j + ii] = l; + } + } + + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6); + } + } + + x += QK_K; + + } +} + +void dequantize_row_q2_k(const block_q2_k * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const float d = ggml_fp16_to_fp32(x[i].d); + const float min = ggml_fp16_to_fp32(x[i].dmin); + + const uint8_t * q = x[i].qs; + + int is = 0; + float dl, ml; + for (int n = 0; n < QK_K; n += 128) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + + uint8_t sc = x[i].scales[is++]; + dl = d * (sc & 0xF); ml = min * (sc >> 4); + for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml; + + sc = x[i].scales[is++]; + dl = d * (sc & 0xF); ml = min * (sc >> 4); + for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml; + + shift += 2; + } + q += 32; + } + + } +} + +void quantize_row_q2_k(const float * restrict x, void * restrict vy, int k) { + quantize_row_q2_k_reference(x, vy, k); +} + +size_t ggml_quantize_q2_k(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { + const int nb = k / QK_K; + + // TODO - collect histograms - although, at a second thought, I don't really care about them + (void)hist; + + for (int j = 0; j < nb; j += k) { + block_q2_k * restrict y = (block_q2_k *)dst + j/QK_K; + quantize_row_q2_k_reference(src + j, y, k); + } + return (n/QK_K*sizeof(block_q2_k)); +} + +//========================= 3-bit (de)-quantization + +void quantize_row_q3_k_reference(const float * restrict x, block_q3_k * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + int8_t L[QK_K]; + float scales[QK_K / 16]; + + for (int i = 0; i < nb; i++) { + + float max_scale = 0; + float amax = 0; + for (int j = 0; j < QK_K/16; ++j) { + scales[j] = make_q3_quants(16, 4, x + 16*j, L + 16*j, true); + float scale = fabsf(scales[j]); + if (scale > amax) { + amax = scale; max_scale = scales[j]; + } + } + + memset(y[i].scales, 0, 12); + if (max_scale) { + float iscale = -32.f/max_scale; + for (int j = 0; j < QK_K/16; ++j) { + int8_t l = nearest_int(iscale*scales[j]); + l = MAX(-32, MIN(31, l)) + 32; + if (j < 8) { + y[i].scales[j] = l & 0xF; + } else { + y[i].scales[j-8] |= ((l & 0xF) << 4); + } + l >>= 4; + y[i].scales[j%4 + 8] |= (l << (2*(j/4))); + } + y[i].d = ggml_fp32_to_fp16(1/iscale); + } else { + y[i].d = ggml_fp32_to_fp16(0.f); + } + + int8_t sc; + for (int j = 0; j < QK_K/16; ++j) { + sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4; + sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32; + float d = ggml_fp16_to_fp32(y[i].d) * sc; + if (!d) { + continue; + } + for (int ii = 0; ii < 16; ++ii) { + int l = nearest_int(x[16*j + ii]/d); + l = MAX(-4, MIN(3, l)); + L[16*j + ii] = l + 4; + } + } + + memset(y[i].hmask, 0, QK_K/8); + // We put the high-bit for the 1st 32 quants into bit 0, the next 32 into bit 1, etc. + int m = 0; + uint8_t hm = 1; + for (int j = 0; j < QK_K; ++j) { + if (L[j] > 3) { + y[i].hmask[m] |= hm; + L[j] -= 4; + } + if (++m == QK_K/8) { + m = 0; hm <<= 1; + } + } + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6); + } + } + + x += QK_K; + } +} + +void dequantize_row_q3_k(const block_q3_k * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + assert(QK_K == 256); + const int nb = k / QK_K; + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + uint32_t aux[4]; + const int8_t * scales = (const int8_t*)aux; + + for (int i = 0; i < nb; i++) { + + const float d_all = ggml_fp16_to_fp32(x[i].d); + + const uint8_t * restrict q = x[i].qs; + const uint8_t * restrict hm = x[i].hmask; + uint8_t m = 1; + + memcpy(aux, x[i].scales, 12); + uint32_t tmp = aux[2]; + aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + + int is = 0; + float dl; + for (int n = 0; n < QK_K; n += 128) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + + dl = d_all * (scales[is++] - 32); + for (int l = 0; l < 16; ++l) { + *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4)); + } + + dl = d_all * (scales[is++] - 32); + for (int l = 0; l < 16; ++l) { + *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4)); + } + + shift += 2; + m <<= 1; + } + q += 32; + } + + } +} + +void quantize_row_q3_k(const float * restrict x, void * restrict vy, int k) { + quantize_row_q3_k_reference(x, vy, k); +} + +size_t ggml_quantize_q3_k(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { + const int nb = k / QK_K; + + // TODO - collect histograms - although, at a second thought, I don't really care about them + (void)hist; + + for (int j = 0; j < nb; j += k) { + block_q3_k * restrict y = (block_q3_k *)dst + j/QK_K; + quantize_row_q3_k_reference(src + j, y, k); + } + return (n/QK_K*sizeof(block_q3_k)); +} + +// ====================== 4-bit (de)-quantization + +void quantize_row_q4_k_reference(const float * restrict x, block_q4_k * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + uint8_t L[QK_K]; + float mins[QK_K/32]; + float scales[QK_K/32]; + + for (int i = 0; i < nb; i++) { + + float max_scale = 0; // as we are deducting the min, scales are always positive + float max_min = 0; + for (int j = 0; j < QK_K/32; ++j) { + scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 5); + float scale = scales[j]; + if (scale > max_scale) { + max_scale = scale; + } + float min = mins[j]; + if (min > max_min) { + max_min = min; + } + } + + float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f; + float inv_min = max_min > 0 ? 63.f/max_min : 0.f; + for (int j = 0; j < QK_K/32; ++j) { + uint8_t ls = nearest_int(inv_scale*scales[j]); + uint8_t lm = nearest_int(inv_min*mins[j]); + ls = MIN(63, ls); + lm = MIN(63, lm); + if (j < 4) { + y[i].scales[j] = ls; + y[i].scales[j+4] = lm; + } else { + y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4); + y[i].scales[j-4] |= ((ls >> 4) << 6); + y[i].scales[j-0] |= ((lm >> 4) << 6); + } + } + y[i].d = ggml_fp32_to_fp16(max_scale/63.f); + y[i].dmin = ggml_fp32_to_fp16(max_min/63.f); + + uint8_t sc, m; + for (int j = 0; j < QK_K/32; ++j) { + get_scale_min_k4(j, y[i].scales, &sc, &m); + const float d = ggml_fp16_to_fp32(y[i].d) * sc; + if (!d) continue; + const float dm = ggml_fp16_to_fp32(y[i].dmin) * m; + for (int ii = 0; ii < 32; ++ii) { + int l = nearest_int((x[32*j + ii] + dm)/d); + l = MAX(0, MIN(15, l)); + L[32*j + ii] = l; + } + } + uint8_t * q = y[i].qs; + for (int j = 0; j < QK_K; j += 64) { + for (int l = 0; l < 32; ++l) *q++ = L[j + l] | (L[j + l + 32] << 4); + } + + x += QK_K; + + } +} + +void dequantize_row_q4_k(const block_q4_k * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const float d = ggml_fp16_to_fp32(x[i].d); + const float min = ggml_fp16_to_fp32(x[i].dmin); + + const uint8_t * q = x[i].qs; + + int is = 0; + uint8_t sc, m; + for (int j = 0; j < QK_K; j += 64) { + get_scale_min_k4(is + 0, x[i].scales, &sc, &m); + const float d1 = d * sc; const float m1 = min * m; + get_scale_min_k4(is + 1, x[i].scales, &sc, &m); + const float d2 = d * sc; const float m2 = min * m; + for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1; + for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2; + q += 32; is += 2; + } + + } +} + +void quantize_row_q4_k(const float * restrict x, void * restrict vy, int k) { + assert(k % QK_K == 0); + block_q4_k * restrict y = vy; + quantize_row_q4_k_reference(x, y, k); +} + +size_t ggml_quantize_q4_k(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + (void)hist; // TODO: collect histograms + for (int j = 0; j < nb; j += k) { + block_q4_k * restrict y = (block_q4_k *)dst + j/QK_K; + quantize_row_q4_k_reference(src + j, y, k); + } + return (n/QK_K*sizeof(block_q4_k)); +} + +// ====================== 5-bit (de)-quantization + +void quantize_row_q5_k_reference(const float * restrict x, block_q5_k * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + uint8_t L[QK_K]; + float mins[QK_K/32]; + float scales[QK_K/32]; + + for (int i = 0; i < nb; i++) { + + float max_scale = 0; // as we are deducting the min, scales are always positive + float max_min = 0; + for (int j = 0; j < QK_K/32; ++j) { + scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 5); + float scale = scales[j]; + if (scale > max_scale) { + max_scale = scale; + } + float min = mins[j]; + if (min > max_min) { + max_min = min; + } + } + + float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f; + float inv_min = max_min > 0 ? 63.f/max_min : 0.f; + for (int j = 0; j < QK_K/32; ++j) { + uint8_t ls = nearest_int(inv_scale*scales[j]); + uint8_t lm = nearest_int(inv_min*mins[j]); + ls = MIN(63, ls); + lm = MIN(63, lm); + if (j < 4) { + y[i].scales[j] = ls; + y[i].scales[j+4] = lm; + } else { + y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4); + y[i].scales[j-4] |= ((ls >> 4) << 6); + y[i].scales[j-0] |= ((lm >> 4) << 6); + } + } + y[i].d = ggml_fp32_to_fp16(max_scale/63.f); + y[i].dmin = ggml_fp32_to_fp16(max_min/63.f); + + uint8_t sc, m; + for (int j = 0; j < QK_K/32; ++j) { + get_scale_min_k4(j, y[i].scales, &sc, &m); + const float d = ggml_fp16_to_fp32(y[i].d) * sc; + if (!d) continue; + const float dm = ggml_fp16_to_fp32(y[i].dmin) * m; + for (int ii = 0; ii < 32; ++ii) { + int l = nearest_int((x[32*j + ii] + dm)/d); + l = MAX(0, MIN(31, l)); + L[32*j + ii] = l; + } + } + + uint8_t * restrict qh = y[i].qh; + uint8_t * restrict ql = y[i].qs; + memset(qh, 0, QK_K/8); + + uint8_t m1 = 1, m2 = 2; + for (int n = 0; n < QK_K; n += 64) { + for (int j = 0; j < 32; ++j) { + int l1 = L[n + j]; + if (l1 > 15) { + l1 -= 16; qh[j] |= m1; + } + int l2 = L[n + j + 32]; + if (l2 > 15) { + l2 -= 16; qh[j] |= m2; + } + ql[j] = l1 | (l2 << 4); + } + m1 <<= 2; m2 <<= 2; + ql += 32; + } + + x += QK_K; + + } +} + +void dequantize_row_q5_k(const block_q5_k * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const float d = ggml_fp16_to_fp32(x[i].d); + const float min = ggml_fp16_to_fp32(x[i].dmin); + + const uint8_t * ql = x[i].qs; + const uint8_t * qh = x[i].qh; + + int is = 0; + uint8_t sc, m; + uint8_t u1 = 1, u2 = 2; + for (int j = 0; j < QK_K; j += 64) { + get_scale_min_k4(is + 0, x[i].scales, &sc, &m); + const float d1 = d * sc; const float m1 = min * m; + get_scale_min_k4(is + 1, x[i].scales, &sc, &m); + const float d2 = d * sc; const float m2 = min * m; + for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1; + for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2; + ql += 32; is += 2; + u1 <<= 2; u2 <<= 2; + } + } +} + +void quantize_row_q5_k(const float * restrict x, void * restrict vy, int k) { + assert(k % QK_K == 0); + block_q5_k * restrict y = vy; + quantize_row_q5_k_reference(x, y, k); +} + +size_t ggml_quantize_q5_k(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + (void)hist; + for (int j = 0; j < nb; j += k) { + block_q5_k * restrict y = (block_q5_k *)dst + j/QK_K; + quantize_row_q5_k_reference(src + j, y, k); + } + return (n/QK_K*sizeof(block_q5_k)); +} + +// ====================== 6-bit (de)-quantization + +void quantize_row_q6_k_reference(const float * restrict x, block_q6_k * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + int8_t L[QK_K]; + float scales[QK_K/16]; + + for (int i = 0; i < nb; i++) { + + float max_scale = 0; + float max_abs_scale = 0; + + for (int ib = 0; ib < QK_K/16; ++ib) { + + const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1); + scales[ib] = scale; + + const float abs_scale = fabsf(scale); + if (abs_scale > max_abs_scale) { + max_abs_scale = abs_scale; + max_scale = scale; + } + + } + + float iscale = -128.f/max_scale; + y[i].d = ggml_fp32_to_fp16(1/iscale); + for (int ib = 0; ib < QK_K/16; ++ib) { + y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib])); + } + + for (int j = 0; j < QK_K/16; ++j) { + float d = ggml_fp16_to_fp32(y[i].d) * y[i].scales[j]; + if (!d) { + continue; + } + for (int ii = 0; ii < 16; ++ii) { + int l = nearest_int(x[16*j + ii]/d); + l = MAX(-32, MIN(31, l)); + L[16*j + ii] = l + 32; + } + } + + uint8_t * restrict ql = y[i].ql; + uint8_t * restrict qh = y[i].qh; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + const uint8_t q1 = L[j + l + 0] & 0xF; + const uint8_t q2 = L[j + l + 32] & 0xF; + const uint8_t q3 = L[j + l + 64] & 0xF; + const uint8_t q4 = L[j + l + 96] & 0xF; + ql[l+ 0] = q1 | (q3 << 4); + ql[l+32] = q2 | (q4 << 4); + qh[l] = (L[j + l] >> 4) | ((L[j + l + 32] >> 4) << 2) | ((L[j + l + 64] >> 4) << 4) | ((L[j + l + 96] >> 4) << 6); + } + ql += 64; + qh += 32; + } + + x += QK_K; + + } +} + +void dequantize_row_q6_k(const block_q6_k * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const float d = ggml_fp16_to_fp32(x[i].d); + + const uint8_t * restrict ql = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict sc = x[i].scales; + + for (int n = 0; n < QK_K; n += 128) { + for (int l = 0; l < 32; ++l) { + int is = l/16; + const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + y[l + 0] = d * sc[is + 0] * q1; + y[l + 32] = d * sc[is + 2] * q2; + y[l + 64] = d * sc[is + 4] * q3; + y[l + 96] = d * sc[is + 6] * q4; + } + y += 128; + ql += 64; + qh += 32; + sc += 8; + } + + } +} + +void quantize_row_q6_k(const float * restrict x, void * restrict vy, int k) { + assert(k % QK_K == 0); + block_q6_k * restrict y = vy; + quantize_row_q6_k_reference(x, y, k); +} + +size_t ggml_quantize_q6_k(const float * src, void * dst, int n, int k, int64_t * hist) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + (void)hist; // TODO + + for (int j = 0; j < nb; j += k) { + block_q6_k * restrict y = (block_q6_k *)dst + j/QK_K; + quantize_row_q6_k_reference(src + j, y, k); + } + return (n/QK_K*sizeof(block_q6_k)); +} + +//===================================== Q8_K ============================================== + +void quantize_row_q8_k_reference(const float * restrict x, block_q8_k * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + float max = 0; + float amax = 0; + for (int j = 0; j < QK_K; ++j) { + float ax = fabsf(x[j]); + if (ax > amax) { + amax = ax; max = x[j]; + } + } + if (!amax) { + y[i].d = 0; + memset(y[i].qs, 0, QK_K); + x += QK_K; + continue; + } + const float iscale = -128.f/max; + for (int j = 0; j < QK_K; ++j) { + int v = nearest_int(iscale*x[j]); + y[i].qs[j] = MIN(127, v); + } + for (int j = 0; j < QK_K/16; ++j) { + int sum = 0; + for (int ii = 0; ii < 16; ++ii) { + sum += y[i].qs[j*16 + ii]; + } + y[i].bsums[j] = sum; + } + y[i].d = 1/iscale; + x += QK_K; + } +} + +void dequantize_row_q8_k(const block_q8_k * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + for (int j = 0; j < QK_K; ++j) { + *y++ = x[i].d * x[i].qs[j]; + } + } +} + +void quantize_row_q8_k(const float * restrict x, void * restrict y, int k) { + quantize_row_q8_k_reference(x, y, k); +} + +//===================================== Dot ptoducts ================================= + +// +// Helper functions +// +#if __AVX__ || __AVX2__ || __AVX512F__ + +// horizontally add 8 floats +static inline float hsum_float_8(const __m256 x) { + __m128 res = _mm256_extractf128_ps(x, 1); + res = _mm_add_ps(res, _mm256_castps256_ps128(x)); + res = _mm_add_ps(res, _mm_movehl_ps(res, res)); + res = _mm_add_ss(res, _mm_movehdup_ps(res)); + return _mm_cvtss_f32(res); +} + +// shuffles to pick the required scales in dot products +static inline __m256i get_scale_shuffle_q3k(int i) { + static const uint8_t k_shuffle[128] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, + 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, + 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, + 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15, + }; + return _mm256_loadu_si256((const __m256i*)k_shuffle + i); +} +static inline __m256i get_scale_shuffle_k4(int i) { + static const uint8_t k_shuffle[256] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, + 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, + 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, + 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, + 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, + 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, + 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, + 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15 + }; + return _mm256_loadu_si256((const __m256i*)k_shuffle + i); +} +static inline __m128i get_scale_shuffle(int i) { + static const uint8_t k_shuffle[128] = { + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, + 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, + 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, + 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, + 10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11, + 12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13, + 14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15 + }; + return _mm_loadu_si128((const __m128i*)k_shuffle + i); +} +#endif + +void ggml_vec_dot_q2_k_q8_k(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + + const block_q2_k * restrict x = vx; + const block_q8_k * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + + const uint8x16_t m3 = vdupq_n_u8(0x3); + const uint8x16_t m4 = vdupq_n_u8(0xF); + const int32x4_t vzero = vdupq_n_s32(0); + + int8x16x2_t q2bytes; + uint8_t aux[16]; + + float sum = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + const uint8_t * restrict sc = x[i].scales; + + const uint8x16_t mins_and_scales = vld1q_u8(sc); + const uint8x16_t scales = vandq_u8(mins_and_scales, m4); + vst1q_u8(aux, scales); + + const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4); + const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums); + const int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}; + const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])), + vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0]))); + const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])), + vmull_s16(vget_high_s16(mins16.val[1]), vget_high_s16(q8sums.val[1]))); + sum += dmin * vaddvq_s32(vaddq_s32(s0, s1)); + + int isum = 0; + int is = 0; + +// We use this macro instead of a function call because for some reason +// the code runs 2-3% slower, even if the function is declared inline +#if defined(__ARM_FEATURE_DOTPROD) +#define MULTIPLY_ACCUM_WITH_SCALE(index)\ + isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\ + isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)]; +#else +#define MULTIPLY_ACCUM_WITH_SCALE(index)\ + {\ + const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),\ + vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));\ + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),\ + vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));\ + isum += vaddvq_s16(p1) * aux[is+(index)] + vaddvq_s16(p2) * aux[is+1+(index)];\ + } +#endif + +#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\ + q8bytes = vld1q_s8_x2(q8); q8 += 32;\ + q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\ + q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\ + MULTIPLY_ACCUM_WITH_SCALE((index)); + + + for (int j = 0; j < QK_K/128; ++j) { + + const uint8x16x2_t q2bits = vld1q_u8_x2(q2); q2 += 32; + + int8x16x2_t q8bytes = vld1q_s8_x2(q8); q8 += 32; + q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3)); + q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3)); + MULTIPLY_ACCUM_WITH_SCALE(0); + + SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2); + + SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4); + + SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6); + + is += 8; + } + sum += d * isum; + + } + + *s = sum; + +#elif defined __AVX2__ + + const __m256i m3 = _mm256_set1_epi8(3); + const __m128i m4 = _mm_set1_epi8(0xF); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales); + const __m128i scales8 = _mm_and_si128(mins_and_scales, m4); + const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); + const __m256i mins = _mm256_cvtepi8_epi16(mins8); + const __m256i prod = _mm256_madd_epi16(mins, _mm256_loadu_si256((const __m256i*)y[i].bsums)); + + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc); + + const __m256i all_scales = _mm256_cvtepi8_epi16(scales8); + const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); + const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); + const __m256i scales[2] = {_mm256_set_m128i(l_scales, l_scales), _mm256_set_m128i(h_scales, h_scales)}; + + __m256i sumi = _mm256_setzero_si256(); + + for (int j = 0; j < QK_K/128; ++j) { + + const __m256i q2bits = _mm256_loadu_si256((const __m256i*)q2); q2 += 32; + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + const __m256i q2_0 = _mm256_and_si256(q2bits, m3); + const __m256i q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3); + const __m256i q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3); + const __m256i q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3); + + __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0); + __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1); + __m256i p2 = _mm256_maddubs_epi16(q2_2, q8_2); + __m256i p3 = _mm256_maddubs_epi16(q2_3, q8_3); + + p0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(0)), p0); + p1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(1)), p1); + p2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(2)), p2); + p3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(3)), p3); + + p0 = _mm256_add_epi32(p0, p1); + p2 = _mm256_add_epi32(p2, p3); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2)); + } + + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + + } + + *s = hsum_float_8(acc); + +#else + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + + int summs = 0; + for (int j = 0; j < 16; ++j) { + summs += y[i].bsums[j] * (sc[j] >> 4); + } + + const float dall = y[i].d * ggml_fp16_to_fp32(x[i].d); + const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin); + + int isum = 0; + int is = 0; + int d; + for (int k = 0; k < QK_K/128; ++k) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + d = sc[is++] & 0xF; + int isuml = 0; + for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); + isum += d * isuml; + d = sc[is++] & 0xF; + isuml = 0; + for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); + isum += d * isuml; + shift += 2; + q8 += 32; + } + q2 += 32; + } + sumf += dall * isum - dmin * summs; + } + *s = sumf; +#endif +} + +void ggml_vec_dot_q3_k_q8_k(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_k * restrict x = vx; + const block_q8_k * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + + uint32_t aux[3]; + uint32_t utmp[4]; + + const uint8x16_t m3b = vdupq_n_u8(0x3); +#ifdef __ARM_FEATURE_DOTPROD + const int32x4_t vzero = vdupq_n_s32(0); +#endif + + const uint8x16_t m0 = vdupq_n_u8(1); + const uint8x16_t m1 = vshlq_n_u8(m0, 1); + const uint8x16_t m2 = vshlq_n_u8(m0, 2); + const uint8x16_t m3 = vshlq_n_u8(m0, 3); + const int8_t m32 = 32; + + int8x16x4_t q3bytes; + + float sum = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict qh = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + + uint8x16x2_t qhbits = vld1q_u8_x2(qh); + + uint8x16x4_t q3h; + + int32_t isum = 0; + + // Set up scales + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + int8_t * scale = (int8_t *)utmp; + for (int j = 0; j < 16; ++j) scale[j] -= m32; + + for (int j = 0; j < QK_K/128; ++j) { + + const uint8x16x2_t q3bits = vld1q_u8_x2(q3); q3 += 32; + const int8x16x4_t q8bytes_1 = vld1q_s8_x4(q8); q8 += 64; + const int8x16x4_t q8bytes_2 = vld1q_s8_x4(q8); q8 += 64; + + q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2); + q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2); + q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1); + q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1); + + q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0])); + q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1])); + q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2])); + q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3])); + +#if defined(__ARM_FEATURE_DOTPROD) + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3]; +#else + int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_1.val[0])), + vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_1.val[0]))); + int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_1.val[1])), + vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_1.val[1]))); + int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_1.val[2])), + vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_1.val[2]))); + int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_1.val[3])), + vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_1.val[3]))); + isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3]; +#endif + scale += 4; + + q3h.val[0] = vbicq_u8(m2, qhbits.val[0]); + q3h.val[1] = vbicq_u8(m2, qhbits.val[1]); + q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1); + q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1); + + q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0])); + q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1])); + q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2])); + q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3])); + +#if defined(__ARM_FEATURE_DOTPROD) + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3]; +#else + p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_2.val[0])), + vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_2.val[0]))); + p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_2.val[1])), + vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_2.val[1]))); + p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_2.val[2])), + vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_2.val[2]))); + p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_2.val[3])), + vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_2.val[3]))); + isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3]; +#endif + scale += 4; + + if (j == 0) { + qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4); + qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4); + } + + } + sum += d * isum; + + } + + *s = sum; + +#elif defined __AVX2__ + + const __m256i m3 = _mm256_set1_epi8(3); + const __m256i mone = _mm256_set1_epi8(1); + const __m128i m32 = _mm_set1_epi8(32); + + __m256 acc = _mm256_setzero_ps(); + + uint32_t aux[3]; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + // Set up scales + memcpy(aux, x[i].scales, 12); + __m128i scales128 = _mm_set_epi32( + ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), + ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), + (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), + (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); + scales128 = _mm_sub_epi8(scales128, m32); + const __m256i all_scales = _mm256_cvtepi8_epi16(scales128); + const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); + const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); + const __m256i scales[2] = {_mm256_set_m128i(l_scales, l_scales), _mm256_set_m128i(h_scales, h_scales)}; + + // high bit + const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask); + + // integer accumulator + __m256i sumi = _mm256_setzero_si256(); + + int bit = 0; + int is = 0; + + for (int j = 0; j < QK_K/128; ++j) { + // load low 2 bits + const __m256i q3bits = _mm256_loadu_si256((const __m256i*)q3); q3 += 32; + + // prepare low and high bits + const __m256i q3l_0 = _mm256_and_si256(q3bits, m3); + const __m256i q3h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3); + const __m256i q3h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + const __m256i q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3); + const __m256i q3h_2 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + const __m256i q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3); + const __m256i q3h_3 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + // load Q8 quants + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, + // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, + // and 2 if the high bit was set) + __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0); + __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1); + __m256i q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2); + __m256i q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3); + + __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1); + __m256i p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2); + __m256i p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3); + + p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + p16_2 = _mm256_sub_epi16(p16_2, q8s_2); + p16_3 = _mm256_sub_epi16(p16_3, q8s_3); + + // multiply with scales + p16_0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0); + p16_1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1); + p16_2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2); + p16_3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3); + + // accumulate + p16_0 = _mm256_add_epi32(p16_0, p16_1); + p16_2 = _mm256_add_epi32(p16_2, p16_3); + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2)); + + } + + // multiply with block scale and accumulate + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + + } + + *s = hsum_float_8(acc); + +#else + // scalar version + // This function is written like this so the compiler can manage to vectorize most of it + // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the + // manually vectorized version above. Every other version I tried would run at least 4 times slower. + // The ideal situation would be if we could just write the code once, and the compiler would + // automatically produce the best possible set of machine instructions, instead of us having to manually + // write vectorized versions for AVX, ARM_NEON, etc. + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + uint32_t auxs[4]; + const int8_t * scales = (const int8_t*)auxs; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict hm = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + uint8_t m = 1; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + q3 += 32; + } + a = aux8; + + memcpy(auxs, x[i].scales, 12); + uint32_t tmp = auxs[2]; + auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + for (int j = 0; j < QK_K/16; ++j) { + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + q8 += 8; a += 8; + } + const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; + +#endif + +} + +void ggml_vec_dot_q4_k_q8_k(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_q4_k * restrict x = vx; + const block_q8_k * restrict y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + +#ifdef __ARM_NEON + + const uint8x16_t m4b = vdupq_n_u8(0xf); +#ifdef __ARM_FEATURE_DOTPROD + const uint32x4_t mzero = vdupq_n_s32(0); +#endif + + int8x16x2_t q4bytes; + int8x16x2_t q8bytes; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin); + + const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); + + memcpy(utmp, x[i].scales, 12); + + const uint32x2_t mins8 = {utmp[1] & kmask1, ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4)}; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[0] &= kmask1; + + const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8))); + const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), + vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); + sumf -= dmin * vaddvq_s32(prod); + + const uint8_t * scales = (const uint8_t *)utmp; + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + //int32x4_t isum = mzero; + + int32_t sumi1 = 0; + int32_t sumi2 = 0; + + for (int j = 0; j < QK_K/64; ++j) { + + const uint8x16x2_t q4bits = vld1q_u8_x2(q4); q4 += 32; + +#ifdef __ARM_FEATURE_DOTPROD + q8bytes = vld1q_s8_x2(q8); q8 += 32; + q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); + q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); + + const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); + sumi1 += vaddvq_s32(p1) * scales[2*j+0]; + + q8bytes = vld1q_s8_x2(q8); q8 += 32; + q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); + q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); + + const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); + + sumi2 += vaddvq_s32(p2) * scales[2*j+1]; +#else + q8bytes = vld1q_s8_x2(q8); q8 += 32; + q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); + q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); + const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0]; + + q8bytes = vld1q_s8_x2(q8); q8 += 32; + q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); + q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) * scales[2*j+1]; + +#endif + } + + sumf += d * (sumi1 + sumi2); + + } + + *s = sumf; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + + __m256 acc = _mm256_setzero_ps(); + __m128 acc_m = _mm_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); + + const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); + const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); + const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s); + acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m); + + const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0); + const __m256i scales = _mm256_set_m128i(sc128, sc128); + + __m256i sumi = _mm256_setzero_si256(); + + for (int j = 0; j < QK_K/64; ++j) { + + const __m256i scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0)); + const __m256i scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1)); + + const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; + const __m256i q4l = _mm256_and_si256(q4bits, m4); + const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4); + + const __m256i q8l = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + __m256i p16l = _mm256_maddubs_epi16(q4l, q8l); + p16l = _mm256_madd_epi16(scale_l, p16l); + sumi = _mm256_add_epi32(sumi, p16l); + + const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + __m256i p16h = _mm256_maddubs_epi16(q4h, q8h); + p16h = _mm256_madd_epi16(scale_h, p16h); + sumi = _mm256_add_epi32(sumi, p16h); + + } + + __m256 vd = _mm256_set1_ps(d); + acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc); + + } + + acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m)); + acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m)); + + *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m); + +#else + + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + for (int j = 0; j < QK_K/64; ++j) { + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); + a += 32; + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); + a += 32; q4 += 32; + } + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + int sumi = 0; + for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/32; ++j) { + int32_t scale = scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + const float dmin = ggml_fp16_to_fp32(x[i].dmin) * y[i].d; + sumf -= dmin * sumi; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +void ggml_vec_dot_q5_k_q8_k(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_q5_k * restrict x = vx; + const block_q8_k * restrict y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + + +#ifdef __ARM_NEON + + const uint8x16_t m4b = vdupq_n_u8(0xf); + const uint32x4_t mzero = vdupq_n_u32(0); + const uint8x16_t mone = vdupq_n_u8(1); + const uint8x16_t mtwo = vdupq_n_u8(2); + + int8x16x4_t q5bytes; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin); + + const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8); + const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8)); + const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), + vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); + int32_t sumi_mins = vaddvq_s32(prod); + + const uint8_t * scales = (const uint8_t *)utmp; + + const uint8_t * restrict q5 = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + uint8x16x2_t qhbits = vld1q_u8_x2(qh); + + uint8x16x4_t q5h; + + int32_t sumi = 0; + + for (int j = 0; j < QK_K/64; ++j) { + + const uint8x16x2_t q5bits = vld1q_u8_x2(q5); q5 += 32; + const int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64; + + q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); + q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); + q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3); + q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3); + qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2); + qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2); + + q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0])); + q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1])); + q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2])); + q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3])); + +#if defined(__ARM_FEATURE_DOTPROD) + + sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++; + sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++; +#else + + const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + sumi += vaddvq_s16(vaddq_s16(p0, p1)) * *scales++; + + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2]))); + const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3]))); + sumi += vaddvq_s16(vaddq_s16(p2, p3)) * *scales++; +#endif + } + + sumf += d * sumi - dmin * sumi_mins; + + } + + *s = sumf; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + const __m128i mzero = _mm_setzero_si128(); + const __m256i mone = _mm256_set1_epi8(1); + + __m256 acc = _mm256_setzero_ps(); + + float summs = 0.f; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); + + const uint8_t * restrict q5 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); + + const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); + const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); + const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s); + const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero); + summs += dmin * _mm_extract_epi32(hsum, 0); + + const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0); + const __m256i scales = _mm256_set_m128i(sc128, sc128); + + const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh); + __m256i hmask = mone; + + __m256i sumi = _mm256_setzero_si256(); + + int bit = 0; + + for (int j = 0; j < QK_K/64; ++j) { + + const __m256i scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0)); + const __m256i scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1)); + + const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32; + + const __m256i q5l_0 = _mm256_and_si256(q5bits, m4); + const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4); + const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0); + hmask = _mm256_slli_epi16(hmask, 1); + + const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4); + const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4); + const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1); + hmask = _mm256_slli_epi16(hmask, 1); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + __m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1); + + p16_0 = _mm256_madd_epi16(scale_0, p16_0); + p16_1 = _mm256_madd_epi16(scale_1, p16_1); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); + + } + + __m256 vd = _mm256_set1_ps(d); + acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc); + + } + + *s = hsum_float_8(acc) + summs; + +#else + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].qs; + const uint8_t * restrict hm = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + uint8_t m = 1; + for (int j = 0; j < QK_K/64; ++j) { + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); + for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); + for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); + a += 32; m <<= 1; + q4 += 32; + } + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + int sumi = 0; + for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/32; ++j) { + int32_t scale = scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + const float dmin = ggml_fp16_to_fp32(x[i].dmin) * y[i].d; + sumf -= dmin * sumi; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + + + +void ggml_vec_dot_q6_k_q8_k(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_q6_k * restrict x = vx; + const block_q8_k * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + + float sum = 0; + + const uint8x16_t m4b = vdupq_n_u8(0xF); + const int32x4_t vzero = vdupq_n_s32(0); + //const int8x16_t m32s = vdupq_n_s8(32); + + const uint8x16_t mone = vdupq_n_u8(3); + + int8x16x4_t q6bytes; + uint8x16x4_t q6h; + + for (int i = 0; i < nb; ++i) { + + const float d_all = ggml_fp16_to_fp32(x[i].d); + + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const int8_t * restrict scale = x[i].scales; + + const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums); + const int8x16_t scales = vld1q_s8(scale); + const int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}; + + const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])), + vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))), + vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])), + vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1])))); + int32_t isum_mins = vaddvq_s32(prod); + + int32_t isum = 0; + + for (int j = 0; j < QK_K/128; ++j) { + + uint8x16x2_t qhbits = vld1q_u8_x2(qh); qh += 32; + uint8x16x4_t q6bits = vld1q_u8_x4(q6); q6 += 64; + int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64; + + q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); + q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); + uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2); + q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[1], 2); + q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + + //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s); + //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s); + //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])), m32s); + //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])), m32s); + q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])); + q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])); + q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])); + q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])); + +#if defined(__ARM_FEATURE_DOTPROD) + + isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; + scale += 4; + +#else + + int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1]; + scale += 2; + + int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2]))); + int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3]))); + isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1]; + scale += 2; +#endif + + q8bytes = vld1q_s8_x4(q8); q8 += 64; + + shifted = vshrq_n_u8(qhbits.val[0], 4); + q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[1], 4); + q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[0], 6); + q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[1], 6); + q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + + //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])), m32s); + //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])), m32s); + //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])), m32s); + //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])), m32s); + q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])); + q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])); + q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])); + q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])); + +#if defined(__ARM_FEATURE_DOTPROD) + + isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; + scale += 4; + + //for (int l = 0; l < 4; ++l) { + // const int32x4_t p = vdotq_s32(vzero, q6bytes.val[l], q8bytes.val[l]); + // isum += vaddvq_s32(p) * *scale++; + //} +#else + p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1]; + scale += 2; + + p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2]))); + p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3]))); + isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1]; + scale += 2; +#endif + + } + //sum += isum * d_all * y[i].d; + sum += d_all * y[i].d * (isum - 32 * isum_mins); + + } + *s = sum; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + const __m256i m2 = _mm256_set1_epi8(3); + const __m256i m32s = _mm256_set1_epi8(32); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); + + __m256i sumi = _mm256_setzero_si256(); + + int is = 0; + + for (int j = 0; j < QK_K/128; ++j) { + + const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0)); + const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1)); + const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2)); + const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3)); + is += 4; + + const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; + const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; + const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32; + + const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4); + const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4); + const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4); + const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4); + + const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0); + const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1); + const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2); + const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0); + __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1); + __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2); + __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3); + + __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1); + __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2); + __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3); + + p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + p16_2 = _mm256_sub_epi16(p16_2, q8s_2); + p16_3 = _mm256_sub_epi16(p16_3, q8s_3); + + p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0); + p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1); + p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2); + p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3)); + + } + + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + } + + *s = hsum_float_8(acc); + +#else + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + } + a += 128; + q4 += 64; + qh += 32; + } + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/16; ++j) { + int scale = x[i].scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + + diff --git a/ggml-quants-k.h b/ggml-quants-k.h new file mode 100644 index 000000000..d6f06013b --- /dev/null +++ b/ggml-quants-k.h @@ -0,0 +1,122 @@ +#pragma once + +#include "ggml.h" + +#include +#include +#include + +// Super-block size +#define QK_K 256 + +// +// Super-block quantization structures +// + +// 2-bit quantization +// weight is represented as x = a * q + b +// 16 blocks of 16 elemenets each +// Effectively 2.5625 bits per weight +typedef struct { + uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits + uint8_t qs[QK_K/4]; // quants + ggml_fp16_t d; // super-block scale for quantized scales + ggml_fp16_t dmin; // super-block scale for quantized mins +} block_q2_k; +static_assert(sizeof(block_q2_k) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_k block size/padding"); + +// 3-bit quantization +// weight is represented as x = a * q +// 16 blocks of 16 elemenets each +// Effectively 3.4375 bits per weight +typedef struct { + uint8_t hmask[QK_K/8]; // quants - high bit + uint8_t qs[QK_K/4]; // quants - low 2 bits + uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits + ggml_fp16_t d; // super-block scale +} block_q3_k; +static_assert(sizeof(block_q3_k) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_k block size/padding"); + +// 4-bit quantization +// 16 blocks of 32 elements each +// weight is represented as x = a * q + b +// Effectively 4.5 bits per weight +typedef struct { + ggml_fp16_t d; // super-block scale for quantized scales + ggml_fp16_t dmin; // super-block scale for quantized mins + uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits + uint8_t qs[QK_K/2]; // 4--bit quants +} block_q4_k; +static_assert(sizeof(block_q4_k) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_k block size/padding"); + +// 5-bit quantization +// 16 blocks of 32 elements each +// weight is represented as x = a * q + b +// Effectively 5.5 bits per weight +typedef struct { + ggml_fp16_t d; // super-block scale for quantized scales + ggml_fp16_t dmin; // super-block scale for quantized mins + uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_k; +static_assert(sizeof(block_q5_k) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2 + QK_K/8, "wrong q5_k block size/padding"); + +// 6-bit quantization +// weight is represented as x = a * q +// 16 blocks of 16 elemenets each +// Effectively 6.5625 bits per weight +typedef struct { + uint8_t ql[QK_K/2]; // quants, lower 4 bits + uint8_t qh[QK_K/4]; // quants, upper 2 bits + int8_t scales[QK_K/16]; // scales, quantized with 8 bits + ggml_fp16_t d; // super-block scale +} block_q6_k; +static_assert(sizeof(block_q6_k) == sizeof(ggml_fp16_t) + QK_K / 16 + 3*QK_K/4, "wrong q6_k block size/padding"); + +// This is only used for intermediate quantization and dot products +typedef struct { + float d; // delta + int8_t qs[QK_K]; // quants + int16_t bsums[QK_K/16]; // sum of quants in groups of 16 +} block_q8_k; +static_assert(sizeof(block_q8_k) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_k block size/padding"); + + +// Quantization +void quantize_row_q2_k_reference(const float * restrict x, block_q2_k * restrict y, int k); +void quantize_row_q3_k_reference(const float * restrict x, block_q3_k * restrict y, int k); +void quantize_row_q4_k_reference(const float * restrict x, block_q4_k * restrict y, int k); +void quantize_row_q5_k_reference(const float * restrict x, block_q5_k * restrict y, int k); +void quantize_row_q6_k_reference(const float * restrict x, block_q6_k * restrict y, int k); +void quantize_row_q8_k_reference(const float * restrict x, block_q8_k * restrict y, int k); + +void quantize_row_q2_k(const float * restrict x, void * restrict y, int k); +void quantize_row_q3_k(const float * restrict x, void * restrict y, int k); +void quantize_row_q4_k(const float * restrict x, void * restrict y, int k); +void quantize_row_q5_k(const float * restrict x, void * restrict y, int k); +void quantize_row_q6_k(const float * restrict x, void * restrict y, int k); +void quantize_row_q8_k(const float * restrict x, void * restrict y, int k); + +// Dequantization +void dequantize_row_q2_k(const block_q2_k * restrict x, float * restrict y, int k); +void dequantize_row_q3_k(const block_q3_k * restrict x, float * restrict y, int k); +void dequantize_row_q4_k(const block_q4_k * restrict x, float * restrict y, int k); +void dequantize_row_q5_k(const block_q5_k * restrict x, float * restrict y, int k); +void dequantize_row_q6_k(const block_q6_k * restrict x, float * restrict y, int k); +void dequantize_row_q8_k(const block_q8_k * restrict x, float * restrict y, int k); + +// Dot product +void ggml_vec_dot_q2_k_q8_k(int n, float * restrict s, const void * restrict vx, const void * restrict vy); +void ggml_vec_dot_q3_k_q8_k(int n, float * restrict s, const void * restrict vx, const void * restrict vy); +void ggml_vec_dot_q4_k_q8_k(int n, float * restrict s, const void * restrict vx, const void * restrict vy); +void ggml_vec_dot_q5_k_q8_k(int n, float * restrict s, const void * restrict vx, const void * restrict vy); +void ggml_vec_dot_q6_k_q8_k(int n, float * restrict s, const void * restrict vx, const void * restrict vy); + +// Quantization with histogram collection +size_t ggml_quantize_q2_k(const float * src, void * dst, int n, int k, int64_t * hist); +size_t ggml_quantize_q3_k(const float * src, void * dst, int n, int k, int64_t * hist); +size_t ggml_quantize_q4_k(const float * src, void * dst, int n, int k, int64_t * hist); +size_t ggml_quantize_q5_k(const float * src, void * dst, int n, int k, int64_t * hist); +size_t ggml_quantize_q6_k(const float * src, void * dst, int n, int k, int64_t * hist); + diff --git a/ggml.c b/ggml.c index 00bbee503..7f9bff995 100644 --- a/ggml.c +++ b/ggml.c @@ -2,6 +2,7 @@ #define _GNU_SOURCE #include "ggml.h" +#include "ggml-quants-k.h" #if defined(_MSC_VER) || defined(__MINGW32__) #include // using malloc.h with MSC/MINGW @@ -1565,6 +1566,46 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { .vec_dot_q = NULL, // TODO .vec_dot_type = GGML_TYPE_Q8_1, }, + [GGML_TYPE_Q2_K] = { + .dequantize_row_q = (dequantize_row_q_t) dequantize_row_q2_k, + .quantize_row_q = quantize_row_q2_k, + .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q2_k_reference, + .quantize_row_q_dot = quantize_row_q8_k, + .vec_dot_q = ggml_vec_dot_q2_k_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K, + }, + [GGML_TYPE_Q3_K] = { + .dequantize_row_q = (dequantize_row_q_t) dequantize_row_q3_k, + .quantize_row_q = quantize_row_q3_k, + .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q3_k_reference, + .quantize_row_q_dot = quantize_row_q8_k, + .vec_dot_q = ggml_vec_dot_q3_k_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K, + }, + [GGML_TYPE_Q4_K] = { + .dequantize_row_q = (dequantize_row_q_t) dequantize_row_q4_k, + .quantize_row_q = quantize_row_q4_k, + .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_k_reference, + .quantize_row_q_dot = quantize_row_q8_k, + .vec_dot_q = ggml_vec_dot_q4_k_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K, + }, + [GGML_TYPE_Q5_K] = { + .dequantize_row_q = (dequantize_row_q_t) dequantize_row_q5_k, + .quantize_row_q = quantize_row_q5_k, + .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_k_reference, + .quantize_row_q_dot = quantize_row_q8_k, + .vec_dot_q = ggml_vec_dot_q5_k_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K, + }, + [GGML_TYPE_Q6_K] = { + .dequantize_row_q = (dequantize_row_q_t) dequantize_row_q6_k, + .quantize_row_q = quantize_row_q6_k, + .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q6_k_reference, + .quantize_row_q_dot = quantize_row_q8_k, + .vec_dot_q = ggml_vec_dot_q6_k_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K, + }, }; // For internal test use @@ -3444,11 +3485,17 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_Q5_1] = QK5_1, [GGML_TYPE_Q8_0] = QK8_0, [GGML_TYPE_Q8_1] = QK8_1, + [GGML_TYPE_Q2_K] = QK_K, + [GGML_TYPE_Q3_K] = QK_K, + [GGML_TYPE_Q4_K] = QK_K, + [GGML_TYPE_Q5_K] = QK_K, + [GGML_TYPE_Q6_K] = QK_K, + [GGML_TYPE_Q8_K] = QK_K, [GGML_TYPE_I8] = 1, [GGML_TYPE_I16] = 1, [GGML_TYPE_I32] = 1, }; -static_assert(GGML_TYPE_COUNT == 13, "GGML_BLCK_SIZE is outdated"); +static_assert(GGML_TYPE_COUNT == 19, "GGML_BLCK_SIZE is outdated"); static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_F32] = sizeof(float), @@ -3459,11 +3506,17 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_Q5_1] = sizeof(block_q5_1), [GGML_TYPE_Q8_0] = sizeof(block_q8_0), [GGML_TYPE_Q8_1] = sizeof(block_q8_1), + [GGML_TYPE_Q2_K] = sizeof(block_q2_k), + [GGML_TYPE_Q3_K] = sizeof(block_q3_k), + [GGML_TYPE_Q4_K] = sizeof(block_q4_k), + [GGML_TYPE_Q5_K] = sizeof(block_q5_k), + [GGML_TYPE_Q6_K] = sizeof(block_q6_k), + [GGML_TYPE_Q8_K] = sizeof(block_q8_k), [GGML_TYPE_I8] = sizeof(int8_t), [GGML_TYPE_I16] = sizeof(int16_t), [GGML_TYPE_I32] = sizeof(int32_t), }; -static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_SIZE is outdated"); +static_assert(GGML_TYPE_COUNT == 19, "GGML_TYPE_SIZE is outdated"); static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = { @@ -3475,11 +3528,17 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = { [GGML_TYPE_Q5_1] = "q5_1", [GGML_TYPE_Q8_0] = "q8_0", [GGML_TYPE_Q8_1] = "q8_1", + [GGML_TYPE_Q2_K] = "q2_k", + [GGML_TYPE_Q3_K] = "q3_k", + [GGML_TYPE_Q4_K] = "q4_k", + [GGML_TYPE_Q5_K] = "q5_k", + [GGML_TYPE_Q6_K] = "q6_k", + [GGML_TYPE_Q8_K] = "q8_k", [GGML_TYPE_I8] = "i8", [GGML_TYPE_I16] = "i16", [GGML_TYPE_I32] = "i32", }; -static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_NAME is outdated"); +static_assert(GGML_TYPE_COUNT == 19, "GGML_TYPE_NAME is outdated"); static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = { [GGML_TYPE_F32] = false, @@ -3490,11 +3549,17 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = { [GGML_TYPE_Q5_1] = true, [GGML_TYPE_Q8_0] = true, [GGML_TYPE_Q8_1] = true, + [GGML_TYPE_Q2_K] = true, + [GGML_TYPE_Q3_K] = true, + [GGML_TYPE_Q4_K] = true, + [GGML_TYPE_Q5_K] = true, + [GGML_TYPE_Q6_K] = true, + [GGML_TYPE_Q8_K] = true, [GGML_TYPE_I8] = false, [GGML_TYPE_I16] = false, [GGML_TYPE_I32] = false, }; -static_assert(GGML_TYPE_COUNT == 13, "GGML_IS_QUANTIZED is outdated"); +static_assert(GGML_TYPE_COUNT == 19, "GGML_IS_QUANTIZED is outdated"); static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "NONE", @@ -3808,6 +3873,11 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break; case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break; case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break; + case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break; + case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break; + case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break; + case GGML_FTYPE_MOSTLY_Q5_K: wtype = GGML_TYPE_Q5_K; break; + case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break; case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break; case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break; } @@ -7623,6 +7693,11 @@ static void ggml_compute_forward_add( case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: { ggml_compute_forward_add_q_f32(params, src0, src1, dst); } break; @@ -7926,6 +8001,11 @@ static void ggml_compute_forward_add1( case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: { ggml_compute_forward_add1_q_f32(params, src0, src1, dst); } break; @@ -8048,6 +8128,11 @@ static void ggml_compute_forward_acc( case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: default: { GGML_ASSERT(false); @@ -10148,6 +10233,11 @@ static void ggml_compute_forward_mul_mat( case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: { ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst); } break; @@ -10331,6 +10421,11 @@ static void ggml_compute_forward_set( case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: default: { GGML_ASSERT(false); @@ -10496,6 +10591,11 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: { ggml_compute_forward_get_rows_q(params, src0, src1, dst); } break; @@ -11042,6 +11142,12 @@ static void ggml_compute_forward_alibi( case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_Q8_K: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -11113,6 +11219,12 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_Q8_K: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -16152,6 +16264,36 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i block_q8_0 * block = (block_q8_0*)dst + start / QK8_0; result = ggml_quantize_q8_0(src + start, block, n, n, hist); } break; + case GGML_TYPE_Q2_K: + { + GGML_ASSERT(start % QK_K == 0); + block_q2_k * block = (block_q2_k*)dst + start / QK_K; + result = ggml_quantize_q2_k(src + start, block, n, n, hist); + } break; + case GGML_TYPE_Q3_K: + { + GGML_ASSERT(start % QK_K == 0); + block_q3_k * block = (block_q3_k*)dst + start / QK_K; + result = ggml_quantize_q3_k(src + start, block, n, n, hist); + } break; + case GGML_TYPE_Q4_K: + { + GGML_ASSERT(start % QK_K == 0); + block_q4_k * block = (block_q4_k*)dst + start / QK_K; + result = ggml_quantize_q4_k(src + start, block, n, n, hist); + } break; + case GGML_TYPE_Q5_K: + { + GGML_ASSERT(start % QK_K == 0); + block_q5_k * block = (block_q5_k*)dst + start / QK_K; + result = ggml_quantize_q5_k(src + start, block, n, n, hist); + } break; + case GGML_TYPE_Q6_K: + { + GGML_ASSERT(start % QK_K == 0); + block_q6_k * block = (block_q6_k*)dst + start / QK_K; + result = ggml_quantize_q6_k(src + start, block, n, n, hist); + } break; default: assert(false); } diff --git a/ggml.h b/ggml.h index 2ea87ce9a..d1ba15f6a 100644 --- a/ggml.h +++ b/ggml.h @@ -241,6 +241,13 @@ extern "C" { GGML_TYPE_Q5_1 = 7, GGML_TYPE_Q8_0 = 8, GGML_TYPE_Q8_1 = 9, + // k-quantizations + GGML_TYPE_Q2_K = 10, + GGML_TYPE_Q3_K = 11, + GGML_TYPE_Q4_K = 12, + GGML_TYPE_Q5_K = 13, + GGML_TYPE_Q6_K = 14, + GGML_TYPE_Q8_K = 15, GGML_TYPE_I8, GGML_TYPE_I16, GGML_TYPE_I32, @@ -264,6 +271,11 @@ extern "C" { GGML_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors GGML_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors GGML_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors + GGML_FTYPE_MOSTLY_Q2_K = 10, // except 1d tensors + GGML_FTYPE_MOSTLY_Q3_K = 11, // except 1d tensors + GGML_FTYPE_MOSTLY_Q4_K = 12, // except 1d tensors + GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors + GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors }; // available tensor operations: diff --git a/llama.cpp b/llama.cpp index a16450173..e2511e533 100644 --- a/llama.cpp +++ b/llama.cpp @@ -515,6 +515,11 @@ struct llama_file_loader { case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: break; default: { throw format("unrecognized tensor type %u\n", shard.type); @@ -590,6 +595,11 @@ struct llama_file_saver { case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: break; default: LLAMA_ASSERT(false); } @@ -906,6 +916,16 @@ static const char *llama_ftype_name(enum llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_Q5_0: return "mostly Q5_0"; case LLAMA_FTYPE_MOSTLY_Q5_1: return "mostly Q5_1"; case LLAMA_FTYPE_MOSTLY_Q8_0: return "mostly Q8_0"; + // K-quants + case LLAMA_FTYPE_MOSTLY_Q2_K: return "mostly Q2_K"; + case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "mostly Q3_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q3_K_M: return "mostly Q3_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q3_K_L: return "mostly Q3_K - Large"; + case LLAMA_FTYPE_MOSTLY_Q4_K_S: return "mostly Q4_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q4_K_M: return "mostly Q4_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "mostly Q5_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "mostly Q5_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q6_K: return "mostly Q6_K"; default: return "unknown, may not work"; } } @@ -2113,8 +2133,18 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_Q5_0: quantized_type = GGML_TYPE_Q5_0; break; case LLAMA_FTYPE_MOSTLY_Q5_1: quantized_type = GGML_TYPE_Q5_1; break; case LLAMA_FTYPE_MOSTLY_Q8_0: quantized_type = GGML_TYPE_Q8_0; break; + // K-quants + case LLAMA_FTYPE_MOSTLY_Q2_K: quantized_type = GGML_TYPE_Q2_K; break; + case LLAMA_FTYPE_MOSTLY_Q3_K_S: + case LLAMA_FTYPE_MOSTLY_Q3_K_M: + case LLAMA_FTYPE_MOSTLY_Q3_K_L: quantized_type = GGML_TYPE_Q3_K; break; + case LLAMA_FTYPE_MOSTLY_Q4_K_S: + case LLAMA_FTYPE_MOSTLY_Q4_K_M: quantized_type = GGML_TYPE_Q4_K; break; + case LLAMA_FTYPE_MOSTLY_Q5_K_S: + case LLAMA_FTYPE_MOSTLY_Q5_K_M: quantized_type = GGML_TYPE_Q5_K; break; + case LLAMA_FTYPE_MOSTLY_Q6_K: quantized_type = GGML_TYPE_Q6_K; break; default: throw format("invalid output file type %d\n", ftype); - }; + } if (nthread <= 0) { nthread = std::thread::hardware_concurrency(); @@ -2124,6 +2154,20 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s /*vocab_only*/ false)); llama_file_saver file_saver(fname_out.c_str(), model_loader->file_loaders.at(0).get(), ftype); + int n_attention_wv = 0; + int n_feed_forward_w2 = 0; + for (auto& tensor : model_loader->tensors_map.tensors) { + if (tensor.name.find("attention.wv.weight") != std::string::npos) { + ++n_attention_wv; + } + else if (tensor.name.find("feed_forward.w2.weight") != std::string::npos) { + ++n_feed_forward_w2; + } + } + + int i_attention_wv = 0; + int i_feed_forward_w2 = 0; + size_t total_size_org = 0; size_t total_size_new = 0; std::vector hist_all(1 << 4, 0); @@ -2166,6 +2210,27 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s printf("size = %8.3f MB\n", tensor.size/1024.0/1024.0); } else { new_type = quantized_type; + if (tensor.name == "output.weight") new_type = GGML_TYPE_Q6_K; + else if (tensor.name.find("attention.wv.weight") != std::string::npos) { + if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; + else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) && + (i_attention_wv < n_attention_wv/8 || i_attention_wv >= 7*n_attention_wv/8 || + (i_attention_wv - n_attention_wv/8)%3 == 2)) new_type = GGML_TYPE_Q6_K; + ++i_attention_wv; + } + else if (tensor.name.find("feed_forward.w2.weight") != std::string::npos) { + if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; + else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) && + (i_feed_forward_w2 < n_feed_forward_w2/8 || i_feed_forward_w2 >= 7*n_feed_forward_w2/8 || + (i_feed_forward_w2 - n_feed_forward_w2/8)%3 == 2)) new_type = GGML_TYPE_Q6_K; + ++i_feed_forward_w2; + } + else if (tensor.name.find("attention.wo.weight") != std::string::npos) { + if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; + } float * f32_data; size_t nelements = tensor.ne.at(0) * tensor.ne.at(1); llama_buffer f32_conv_buf; @@ -2233,12 +2298,16 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } printf("size = %8.2f MB -> %8.2f MB | hist: ", tensor.size/1024.0/1024.0, new_size/1024.0/1024.0); + int64_t tot_count = 0; for (size_t i = 0; i < hist_cur.size(); i++) { hist_all[i] += hist_cur[i]; + tot_count += hist_cur[i]; } - for (size_t i = 0; i < hist_cur.size(); i++) { - printf("%5.3f ", hist_cur[i] / float(nelements)); + if (tot_count > 0) { + for (size_t i = 0; i < hist_cur.size(); i++) { + printf("%5.3f ", hist_cur[i] / float(nelements)); + } } printf("\n"); } @@ -2256,11 +2325,13 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s sum_all += hist_all[i]; } - printf("%s: hist: ", __func__); - for (size_t i = 0; i < hist_all.size(); i++) { - printf("%5.3f ", hist_all[i] / float(sum_all)); + if (sum_all > 0) { + printf("%s: hist: ", __func__); + for (size_t i = 0; i < hist_all.size(); i++) { + printf("%5.3f ", hist_all[i] / float(sum_all)); + } + printf("\n"); } - printf("\n"); } } diff --git a/llama.h b/llama.h index 87fa97367..7a6419738 100644 --- a/llama.h +++ b/llama.h @@ -94,6 +94,15 @@ extern "C" { LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q2_K = 10,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q3_K_S = 11,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q3_K_M = 12,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q3_K_L = 13,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_K_S = 14,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_K_M = 15,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_K_S = 16,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_K_M = 17,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q6_K = 18,// except 1d tensors }; LLAMA_API struct llama_context_params llama_context_default_params(); diff --git a/tests/test-quantize-fns.cpp b/tests/test-quantize-fns.cpp index a31a18827..728460b5e 100644 --- a/tests/test-quantize-fns.cpp +++ b/tests/test-quantize-fns.cpp @@ -12,6 +12,8 @@ const float MAX_QUANTIZATION_REFERENCE_ERROR = 0.0001; const float MAX_QUANTIZATION_TOTAL_ERROR = 0.002; +const float MAX_QUANTIZATION_TOTAL_ERROR_2BITS = 0.0075; +const float MAX_QUANTIZATION_TOTAL_ERROR_3BITS = 0.0040; const float MAX_DOT_PRODUCT_ERROR = 0.02; const char* RESULT_STR[] = {"ok", "FAILED"}; @@ -122,7 +124,10 @@ int main(int argc, char * argv[]) { if (qfns.quantize_row_q && qfns.dequantize_row_q) { const float total_error = total_quantization_error(qfns, test_size, test_data.data()); - failed = !(total_error < MAX_QUANTIZATION_TOTAL_ERROR); + const float max_quantization_error = + type == GGML_TYPE_Q2_K ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS : + type == GGML_TYPE_Q3_K ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS : MAX_QUANTIZATION_TOTAL_ERROR; + failed = !(total_error < max_quantization_error); num_failed += failed; if (failed || verbose) { printf("%5s absolute quantization error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], total_error); From e7fe66e670537990ccc075cce9286df88bba052a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 5 Jun 2023 23:05:05 +0300 Subject: [PATCH 06/15] ci : disable auto tidy (#1705) --- .github/workflows/tidy-post.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tidy-post.yml b/.github/workflows/tidy-post.yml index a58da0cd6..03652760c 100644 --- a/.github/workflows/tidy-post.yml +++ b/.github/workflows/tidy-post.yml @@ -1,7 +1,7 @@ name: clang-tidy review post comments on: - workflow_run: + workflow_dispatch: workflows: ["clang-tidy-review"] types: - completed From efe05076323f5c6bafece109e21cce046f5e4b07 Mon Sep 17 00:00:00 2001 From: grahameth <96447521+grahameth@users.noreply.github.com> Date: Mon, 5 Jun 2023 22:11:49 +0200 Subject: [PATCH 07/15] ggml : fix internal overflow in ggml_time_us on Windows (#1702) Co-authored-by: grahameth <-> --- ggml.c | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/ggml.c b/ggml.c index 7f9bff995..24f0d2fad 100644 --- a/ggml.c +++ b/ggml.c @@ -404,21 +404,27 @@ void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, size_t n) { // #if defined(_MSC_VER) || defined(__MINGW32__) -static int64_t timer_freq; +static int64_t timer_freq, timer_start; void ggml_time_init(void) { - LARGE_INTEGER frequency; - QueryPerformanceFrequency(&frequency); - timer_freq = frequency.QuadPart; + LARGE_INTEGER t; + QueryPerformanceFrequency(&t); + timer_freq = t.QuadPart; + + // The multiplication by 1000 or 1000000 below can cause an overflow if timer_freq + // and the uptime is high enough. + // We subtract the program start time to reduce the likelihood of that happening. + QueryPerformanceCounter(&t); + timer_start = t.QuadPart; } int64_t ggml_time_ms(void) { LARGE_INTEGER t; QueryPerformanceCounter(&t); - return (t.QuadPart * 1000) / timer_freq; + return ((t.QuadPart-timer_start) * 1000) / timer_freq; } int64_t ggml_time_us(void) { LARGE_INTEGER t; QueryPerformanceCounter(&t); - return (t.QuadPart * 1000000) / timer_freq; + return ((t.QuadPart-timer_start) * 1000000) / timer_freq; } #else void ggml_time_init(void) {} From 9d0693bce38013364b1042568d9083353bfff48f Mon Sep 17 00:00:00 2001 From: kiltyj Date: Mon, 5 Jun 2023 13:24:04 -0700 Subject: [PATCH 08/15] metal : use shared buffers between CPU and GPU (#1696) * Use MTLDevice.newBufferWithBytesNoCopy to share buffers between CPU and GPU * Page-align buffers used by Metal * Remove trailing whitespace * Only import unistd.h for Metal builds * metal : remove unnecessary copies --------- Co-authored-by: Georgi Gerganov --- ggml-metal.m | 17 ++++++++++++++--- ggml.c | 8 ++++++++ llama-util.h | 16 ++++++++++++++++ llama.cpp | 13 ------------- 4 files changed, 38 insertions(+), 16 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 3cb423a01..82c65963b 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -195,14 +195,25 @@ bool ggml_metal_add_buffer( } } + size_t page_size = getpagesize(); + size_t aligned_size = size; + if ((aligned_size % page_size) != 0) { + aligned_size += (page_size - (aligned_size % page_size)); + } + ctx->buffers[ctx->n_buffers].name = name; ctx->buffers[ctx->n_buffers].data = data; ctx->buffers[ctx->n_buffers].size = size; - ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytes:data length:size options:MTLResourceStorageModeShared]; + ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:aligned_size options:MTLResourceStorageModeShared deallocator:nil]; + + if (ctx->buffers[ctx->n_buffers].metal == nil) { + fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, aligned_size / 1024.0 / 1024.0); + return false; + } else { + fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB\n", __func__, name, aligned_size / 1024.0 / 1024.0); + } ++ctx->n_buffers; - - fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB\n", __func__, name, size / 1024.0 / 1024.0); } return true; diff --git a/ggml.c b/ggml.c index 24f0d2fad..4e3e7edb9 100644 --- a/ggml.c +++ b/ggml.c @@ -22,6 +22,10 @@ #include #include +#ifdef GGML_USE_METAL +#include +#endif + // if C99 - static_assert is noop // ref: https://stackoverflow.com/a/53923785/4039976 #ifndef static_assert @@ -122,7 +126,11 @@ typedef void* thread_ret_t; #else inline static void* ggml_aligned_malloc(size_t size) { void* aligned_memory = NULL; +#ifdef GGML_USE_METAL + int result = posix_memalign(&aligned_memory, getpagesize(), size); +#else int result = posix_memalign(&aligned_memory, GGML_MEM_ALIGN, size); +#endif if (result != 0) { // Handle allocation failure return NULL; diff --git a/llama-util.h b/llama-util.h index 3cac9f681..4f8a4296a 100644 --- a/llama-util.h +++ b/llama-util.h @@ -405,13 +405,29 @@ struct llama_buffer { llama_buffer() = default; void resize(size_t len) { +#ifdef GGML_USE_METAL + free(addr); + int result = posix_memalign((void **) &addr, getpagesize(), len); + if (result == 0) { + memset(addr, 0, len); + } + else { + addr = NULL; + } +#else delete[] addr; addr = new uint8_t[len]; +#endif size = len; } ~llama_buffer() { +#ifdef GGML_USE_METAL + free(addr); +#else delete[] addr; +#endif + addr = NULL; } // disable copy and move diff --git a/llama.cpp b/llama.cpp index e2511e533..d0e7151f4 100644 --- a/llama.cpp +++ b/llama.cpp @@ -53,7 +53,6 @@ enum e_model { MODEL_65B, }; - static const size_t MB = 1024*1024; // computed for n_ctx == 2048 @@ -1281,12 +1280,6 @@ static bool llama_eval_internal( ggml_set_name(embd, "embd"); memcpy(embd->data, tokens, N*ggml_element_size(embd)); -#ifdef GGML_USE_METAL - if (lctx.ctx_metal && N == 1) { - ggml_metal_set_tensor(lctx.ctx_metal, embd); - } -#endif - struct ggml_tensor * cur; struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd); @@ -1484,12 +1477,6 @@ static bool llama_eval_internal( } ggml_graph_compute(ctx0, &gf); - - if (lctx.ctx_metal) { - // We need to sync the CPU KV cache with the GPU KV cache - ggml_metal_set_tensor(lctx.ctx_metal, kv_self.k); - ggml_metal_set_tensor(lctx.ctx_metal, kv_self.v); - } } #else ggml_graph_compute(ctx0, &gf); From c2df36d60dc0ff1576541b965d751eadbacbeada Mon Sep 17 00:00:00 2001 From: mgroeber9110 <45620825+mgroeber9110@users.noreply.github.com> Date: Mon, 5 Jun 2023 22:24:29 +0200 Subject: [PATCH 09/15] llama : consistently catch and throw only exceptions deriving from std::exception (#1599) Co-authored-by: Georgi Gerganov --- llama.cpp | 59 ++++++++++++++++++++++++++++--------------------------- 1 file changed, 30 insertions(+), 29 deletions(-) diff --git a/llama.cpp b/llama.cpp index d0e7151f4..54545f01d 100644 --- a/llama.cpp +++ b/llama.cpp @@ -289,15 +289,15 @@ template static T checked_mul(T a, T b) { T ret = a * b; if (a != 0 && ret / a != b) { - throw format("overflow multiplying %llu * %llu", - (unsigned long long) a, (unsigned long long) b); + throw std::runtime_error(format("overflow multiplying %llu * %llu", + (unsigned long long) a, (unsigned long long) b)); } return ret; } static size_t checked_div(size_t a, size_t b) { if (b == 0 || a % b != 0) { - throw format("error dividing %zu / %zu", a, b); + throw std::runtime_error(format("error dividing %zu / %zu", a, b)); } return a / b; } @@ -361,7 +361,7 @@ struct llama_load_tensor { const auto & first_shard = shards.at(0); for (const auto & shard : shards) { if (shard.type != first_shard.type) { - throw format("inconsistent tensor shard type in '%s'", name.c_str()); + throw std::runtime_error(format("inconsistent tensor shard type in '%s'", name.c_str())); } } type = first_shard.type; @@ -384,8 +384,8 @@ struct llama_load_tensor { const auto & first_shard = shards.at(0); for (const auto & shard : shards) { if (shard.ne != first_shard.ne) { - throw format("inconsistent tensor shard shape in '%s': first was %s, other was %s", - name.c_str(), llama_format_tensor_shape(first_shard.ne).c_str(), llama_format_tensor_shape(shard.ne).c_str()); + throw std::runtime_error(format("inconsistent tensor shard shape in '%s': first was %s, other was %s", + name.c_str(), llama_format_tensor_shape(first_shard.ne).c_str(), llama_format_tensor_shape(shard.ne).c_str())); } } ne = first_shard.ne; @@ -463,8 +463,8 @@ struct llama_file_loader { } } - throw format("unknown (magic, version) combination: %08x, %08x; is this really a GGML file?", - magic, version); + throw std::runtime_error(format("unknown (magic, version) combination: %08x, %08x; is this really a GGML file?", + magic, version)); } void read_hparams() { hparams.n_vocab = file.read_u32(); @@ -504,7 +504,7 @@ struct llama_file_loader { file.read_raw(shard.ne.data(), sizeof(shard.ne[0]) * n_dims); std::string name = file.read_string(name_len); if (n_dims < 1 || n_dims > 2) { - throw format("llama.cpp: tensor '%s' should not be %u-dimensional", name.c_str(), n_dims); + throw std::runtime_error(format("llama.cpp: tensor '%s' should not be %u-dimensional", name.c_str(), n_dims)); } switch (shard.type) { case GGML_TYPE_F32: @@ -521,7 +521,7 @@ struct llama_file_loader { case GGML_TYPE_Q6_K: break; default: { - throw format("unrecognized tensor type %u\n", shard.type); + throw std::runtime_error(format("unrecognized tensor type %u\n", shard.type)); } } @@ -630,7 +630,7 @@ struct llama_model_loader { auto * ith_file = new llama_file_loader(fname.c_str(), i, tensors_map); file_loaders.emplace_back(ith_file); if (ith_file->hparams != first_file->hparams) { - throw format("llama.cpp: hparams inconsistent between files"); + throw std::runtime_error(format("llama.cpp: hparams inconsistent between files")); } } if (!llama_mmap::SUPPORTED) { @@ -660,7 +660,7 @@ struct llama_model_loader { uint32_t guess_n_parts() const { auto it = tensors_map.name_to_idx.find("tok_embeddings.weight"); if (it == tensors_map.name_to_idx.end()) { - throw std::string("missing tok_embeddings.weight"); + throw std::runtime_error(std::string("missing tok_embeddings.weight")); } const llama_load_tensor & lt = tensors_map.tensors.at(it->second); return file_loaders.at(0)->hparams.n_embd / lt.shards.at(0).ne.at(0); @@ -677,12 +677,12 @@ struct llama_model_loader { struct ggml_tensor * get_tensor(const std::string & name, const std::vector & ne, ggml_backend backend) { auto it = tensors_map.name_to_idx.find(name); if (it == tensors_map.name_to_idx.end()) { - throw format("llama.cpp: tensor '%s' is missing from model", name.c_str()); + throw std::runtime_error(std::runtime_error(format("llama.cpp: tensor '%s' is missing from model", name.c_str()))); } llama_load_tensor & lt = tensors_map.tensors.at(it->second); if (lt.ne != ne) { - throw format("llama.cpp: tensor '%s' has wrong shape; expected %s, got %s", - name.c_str(), llama_format_tensor_shape(ne).c_str(), llama_format_tensor_shape(lt.ne).c_str()); + throw std::runtime_error(format("llama.cpp: tensor '%s' has wrong shape; expected %s, got %s", + name.c_str(), llama_format_tensor_shape(ne).c_str(), llama_format_tensor_shape(lt.ne).c_str())); } return get_tensor_for(lt, backend); @@ -706,7 +706,7 @@ struct llama_model_loader { void done_getting_tensors() const { if (num_ggml_tensors_created != tensors_map.tensors.size()) { - throw std::string("llama.cpp: file contained more tensors than expected"); + throw std::runtime_error(std::string("llama.cpp: file contained more tensors than expected")); } } @@ -994,7 +994,7 @@ static void llama_model_load_internal( if (hparams.ftype != LLAMA_FTYPE_ALL_F32 && hparams.ftype != LLAMA_FTYPE_MOSTLY_F16 && hparams.ftype != LLAMA_FTYPE_MOSTLY_Q8_0) { - throw format("this format is no longer supported (see https://github.com/ggerganov/llama.cpp/pull/1405)"); + throw std::runtime_error(format("this format is no longer supported (see https://github.com/ggerganov/llama.cpp/pull/1405)")); } } @@ -1002,7 +1002,7 @@ static void llama_model_load_internal( if (hparams.ftype == LLAMA_FTYPE_MOSTLY_Q4_0 || hparams.ftype == LLAMA_FTYPE_MOSTLY_Q4_1 || hparams.ftype == LLAMA_FTYPE_MOSTLY_Q8_0) { - throw format("this format is no longer supported (see https://github.com/ggerganov/llama.cpp/pull/1508)"); + throw std::runtime_error(format("this format is no longer supported (see https://github.com/ggerganov/llama.cpp/pull/1508)")); } } @@ -1033,7 +1033,7 @@ static void llama_model_load_internal( model.ctx = ggml_init(params); if (!model.ctx) { - throw format("ggml_init() failed"); + throw std::runtime_error(format("ggml_init() failed")); } } @@ -1214,8 +1214,8 @@ static bool llama_model_load( llama_model_load_internal(fname, lctx, n_ctx, n_gpu_layers, memory_type, use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data); return true; - } catch (const std::string & err) { - fprintf(stderr, "error loading model: %s\n", err.c_str()); + } catch (const std::exception & err) { + fprintf(stderr, "error loading model: %s\n", err.what()); return false; } } @@ -2120,8 +2120,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_Q5_0: quantized_type = GGML_TYPE_Q5_0; break; case LLAMA_FTYPE_MOSTLY_Q5_1: quantized_type = GGML_TYPE_Q5_1; break; case LLAMA_FTYPE_MOSTLY_Q8_0: quantized_type = GGML_TYPE_Q8_0; break; + // K-quants - case LLAMA_FTYPE_MOSTLY_Q2_K: quantized_type = GGML_TYPE_Q2_K; break; + case LLAMA_FTYPE_MOSTLY_Q2_K: quantized_type = GGML_TYPE_Q2_K; break; case LLAMA_FTYPE_MOSTLY_Q3_K_S: case LLAMA_FTYPE_MOSTLY_Q3_K_M: case LLAMA_FTYPE_MOSTLY_Q3_K_L: quantized_type = GGML_TYPE_Q3_K; break; @@ -2129,8 +2130,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_Q4_K_M: quantized_type = GGML_TYPE_Q4_K; break; case LLAMA_FTYPE_MOSTLY_Q5_K_S: case LLAMA_FTYPE_MOSTLY_Q5_K_M: quantized_type = GGML_TYPE_Q5_K; break; - case LLAMA_FTYPE_MOSTLY_Q6_K: quantized_type = GGML_TYPE_Q6_K; break; - default: throw format("invalid output file type %d\n", ftype); + case LLAMA_FTYPE_MOSTLY_Q6_K: quantized_type = GGML_TYPE_Q6_K; break; + default: throw std::runtime_error(format("invalid output file type %d\n", ftype)); } if (nthread <= 0) { @@ -2231,7 +2232,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s f32_data[i] = ggml_fp16_to_fp32(f16_data[i]); } } else { - throw format("type %s unsupported for integer quantization", ggml_type_name(tensor.type)); + throw std::runtime_error(format("type %s unsupported for integer quantization", ggml_type_name(tensor.type))); } printf("quantizing .. "); @@ -2433,8 +2434,8 @@ int llama_model_quantize( try { llama_model_quantize_internal(fname_inp, fname_out, ftype, nthread); return 0; - } catch (const std::string & err) { - fprintf(stderr, "%s: failed to quantize: %s\n", __func__, err.c_str()); + } catch (const std::exception & err) { + fprintf(stderr, "%s: failed to quantize: %s\n", __func__, err.what()); return 1; } } @@ -2687,8 +2688,8 @@ int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char * int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lora, const char * path_base_model, int n_threads) { try { return llama_apply_lora_from_file_internal(ctx, path_lora, path_base_model, n_threads); - } catch (const std::string & err) { - fprintf(stderr, "%s: failed to apply lora adapter: %s\n", __func__, err.c_str()); + } catch (const std::exception & err) { + fprintf(stderr, "%s: failed to apply lora adapter: %s\n", __func__, err.what()); return 1; } } From f1465624c2cbc8ee65b566e3d87f2af27796d4c4 Mon Sep 17 00:00:00 2001 From: Foul-Tarnished <107711110+Foul-Tarnished@users.noreply.github.com> Date: Mon, 5 Jun 2023 22:28:37 +0200 Subject: [PATCH 10/15] readme : fix typo (#1700) Fix a typo in a command in README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 01f138127..77332ca0e 100644 --- a/README.md +++ b/README.md @@ -317,7 +317,7 @@ Building the program with BLAS support may lead to some performance improvements mkdir build cd build cmake .. -DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=Intel10_64lp -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx - cmake --build . -config Release + cmake --build . --config Release ``` - **cuBLAS** From f4c55d3bd7e124b101bc974cbbf0e0dbbc32d5a3 Mon Sep 17 00:00:00 2001 From: Yuval Peled <31162840+Yuval-Peled@users.noreply.github.com> Date: Mon, 5 Jun 2023 23:32:36 +0300 Subject: [PATCH 11/15] docs : add performance troubleshoot + example benchmark documentation (#1674) * test anchor link * test table * add benchmarks * Add performance troubleshoot & benchmark * add benchmarks * remove unneeded line --------- Co-authored-by: Georgi Gerganov --- README.md | 13 ++++---- BLIS.md => docs/BLIS.md | 0 docs/token_generation_performance_tips.md | 40 +++++++++++++++++++++++ 3 files changed, 47 insertions(+), 6 deletions(-) rename BLIS.md => docs/BLIS.md (100%) create mode 100644 docs/token_generation_performance_tips.md diff --git a/README.md b/README.md index 77332ca0e..28842e968 100644 --- a/README.md +++ b/README.md @@ -267,11 +267,11 @@ Any value larger than 0 will offload the computation to the GPU. For example: Building the program with BLAS support may lead to some performance improvements in prompt processing using batch sizes higher than 32 (the default is 512). BLAS doesn't affect the normal generation performance. There are currently three different implementations of it: -- **Accelerate Framework**: +- #### Accelerate Framework: This is only available on Mac PCs and it's enabled by default. You can just build using the normal instructions. -- **OpenBLAS**: +- #### OpenBLAS: This provides BLAS acceleration using only the CPU. Make sure to have OpenBLAS installed on your machine. @@ -305,11 +305,11 @@ Building the program with BLAS support may lead to some performance improvements cmake --build . --config Release ``` -- **BLIS** +- #### BLIS Check [BLIS.md](BLIS.md) for more information. -- **Intel MKL** +- #### Intel MKL By default, `LLAMA_BLAS_VENDOR` is set to `Generic`, so if you already sourced intel environment script and assign `-DLLAMA_BLAS=ON` in cmake, the mkl version of Blas will automatically been selected. You may also specify it by: @@ -320,7 +320,7 @@ Building the program with BLAS support may lead to some performance improvements cmake --build . --config Release ``` -- **cuBLAS** +- #### cuBLAS This provides BLAS acceleration using the CUDA cores of your Nvidia GPU. Make sure to have the CUDA toolkit installed. You can download it from your Linux distro's package manager or from here: [CUDA Toolkit](https://developer.nvidia.com/cuda-downloads). - Using `make`: @@ -339,7 +339,7 @@ Building the program with BLAS support may lead to some performance improvements The environment variable [`CUDA_VISIBLE_DEVICES`](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars) can be used to specify which GPU(s) will be used. -- **CLBlast** +- #### CLBlast OpenCL acceleration is provided by the matrix multiplication kernels from the [CLBlast](https://github.com/CNugteren/CLBlast) project and custom kernels for ggml that can generate tokens on the GPU. @@ -684,3 +684,4 @@ docker run -v /path/to/models:/models ghcr.io/ggerganov/llama.cpp:light -m /mode ### Docs - [GGML tips & tricks](https://github.com/ggerganov/llama.cpp/wiki/GGML-Tips-&-Tricks) +- [Performance troubleshooting](./docs/token_generation_performance_tips.md) diff --git a/BLIS.md b/docs/BLIS.md similarity index 100% rename from BLIS.md rename to docs/BLIS.md diff --git a/docs/token_generation_performance_tips.md b/docs/token_generation_performance_tips.md new file mode 100644 index 000000000..69ba6173c --- /dev/null +++ b/docs/token_generation_performance_tips.md @@ -0,0 +1,40 @@ +# Token generation performance troubleshooting + +## Verifying that the model is running on the GPU with cuBLAS +Make sure you compiled llama with the correct env variables according to [this guide](../README.md#cublas), so that llama accepts the `-ngl N` (or `--n-gpu-layers N`) flag. When running llama, you may configure `N` to be very large, and llama will offload the maximum possible number of layers to the GPU, even if it's less than the number you configured. For example: +```shell +./main -m "path/to/model.bin" -ngl 200000 -p "Please sir, may I have some " +``` + +When running llama, before it starts the inference work, it will output diagnostic information that shows whether cuBLAS is offloading work to the GPU. Look for these lines: +```shell +llama_model_load_internal: [cublas] offloading 60 layers to GPU +llama_model_load_internal: [cublas] offloading output layer to GPU +llama_model_load_internal: [cublas] total VRAM used: 17223 MB +... rest of inference +``` + +If you see these lines, then the GPU is being used. + +## Verifying that the CPU is not oversaturated +llama accepts a `-t N` (or `--threads N`) parameter. It's extremely important that this parameter is not too large. If your token generation is extremely slow, try setting this number to 1. If this significantly improves your token generation speed, then your CPU is being oversaturated and you need to explicitly set this parameter to the number of the physicial CPU cores on your machine (even if you utilize a GPU). If in doubt, start with 1 and double the amount until you hit a performance bottleneck, then scale the number down. + +# Example of runtime flags effect on inference speed benchmark +These runs were tested on the following machine: +GPU: A6000 (48GB VRAM) +CPU: 7 physical cores +RAM: 32GB + +Model: `TheBloke_Wizard-Vicuna-30B-Uncensored-GGML/Wizard-Vicuna-30B-Uncensored.ggmlv3.q4_0.bin` (30B parameters, 4bit quantization, GGML) + +Run command: `./main -m "path/to/model.bin" -p "-p "An extremely detailed description of the 10 best ethnic dishes will follow, with recipes: " -n 1000 [additional benchmark flags]` + +Result: + +| command | tokens/second (higher is better) | +| - | - | +| -ngl 2000000 | N/A (less than 0.1) | +| -t 7 | 1.7 | +| -t 1 -ngl 2000000 | 5.5 | +| -t 7 -ngl 2000000 | 8.7 | +| -t 4 -ngl 2000000 | 9.1 | From 590250f7a9847bc9c83aa063dbaac8fa0fea27c8 Mon Sep 17 00:00:00 2001 From: Spencer Sutton Date: Mon, 5 Jun 2023 23:28:17 -0400 Subject: [PATCH 12/15] metal : add checks for buffer size (#1706) Co-authored-by: Spencer Sutton --- ggml-metal.m | 5 +++++ llama.cpp | 27 ++++++++++++++++++++------- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 82c65963b..d721ac688 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -204,6 +204,11 @@ bool ggml_metal_add_buffer( ctx->buffers[ctx->n_buffers].name = name; ctx->buffers[ctx->n_buffers].data = data; ctx->buffers[ctx->n_buffers].size = size; + + if (ctx->device.maxBufferLength < aligned_size) { + fprintf(stderr, "%s: buffer '%s' size %zu is larger than buffer maximum of %zu\n", __func__, name, aligned_size, ctx->device.maxBufferLength); + return false; + } ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:aligned_size options:MTLResourceStorageModeShared deallocator:nil]; if (ctx->buffers[ctx->n_buffers].metal == nil) { diff --git a/llama.cpp b/llama.cpp index 54545f01d..568ce6aca 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2405,17 +2405,30 @@ struct llama_context * llama_init_from_file( // this allocates all Metal resources and memory buffers ctx->ctx_metal = ggml_metal_init(); + void *data_ptr = NULL; + size_t data_size = 0; if (params.use_mmap) { - ggml_metal_add_buffer(ctx->ctx_metal, "data", ctx->model.mapping->addr, ctx->model.mapping->size); - ggml_metal_add_buffer(ctx->ctx_metal, "eval", ctx->buf_compute.addr, ctx->buf_compute.size); + data_ptr = ctx->model.mapping->addr; + data_size= ctx->model.mapping->size; } else { - ggml_metal_add_buffer(ctx->ctx_metal, "data", ggml_get_mem_buffer(ctx->model.ctx), ggml_get_mem_size(ctx->model.ctx)); - ggml_metal_add_buffer(ctx->ctx_metal, "eval", ctx->buf_compute.addr, ctx->buf_compute.size); + data_ptr = ggml_get_mem_buffer(ctx->model.ctx); + data_size= ggml_get_mem_size(ctx->model.ctx); } - ggml_metal_add_buffer(ctx->ctx_metal, "kv", ctx->model.kv_self.buf.addr, ctx->model.kv_self.buf.size); - ggml_metal_add_buffer(ctx->ctx_metal, "scr0", ctx->buf_scratch[0].addr, ctx->buf_scratch[0].size); - ggml_metal_add_buffer(ctx->ctx_metal, "scr1", ctx->buf_scratch[1].addr, ctx->buf_scratch[1].size); +#define LLAMA_METAL_CHECK_BUF(result) \ + if (!(result)) { \ + fprintf(stderr, "%s: failed to add buffer\n", __func__); \ + llama_free(ctx); \ + return NULL; \ + } + + LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "data", data_ptr, data_size)); + LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "eval", ctx->buf_compute.addr, ctx->buf_compute.size)); + + LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "kv", ctx->model.kv_self.buf.addr, ctx->model.kv_self.buf.size)); + LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "scr0", ctx->buf_scratch[0].addr, ctx->buf_scratch[0].size)); + LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "scr1", ctx->buf_scratch[1].addr, ctx->buf_scratch[1].size)); +#undef LLAMA_METAL_CHECK_BUF } #endif From 7a74dee6b4e0e80862191141c0037abe28967d5c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 6 Jun 2023 09:39:38 +0300 Subject: [PATCH 13/15] llama : temporary disable Q6_K output quantization (#1711) --- llama.cpp | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/llama.cpp b/llama.cpp index 568ce6aca..70341d04f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2198,8 +2198,12 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s printf("size = %8.3f MB\n", tensor.size/1024.0/1024.0); } else { new_type = quantized_type; - if (tensor.name == "output.weight") new_type = GGML_TYPE_Q6_K; - else if (tensor.name.find("attention.wv.weight") != std::string::npos) { + // TODO: temporary disabled until Metal / OpenCL support is available + // ref: https://github.com/ggerganov/llama.cpp/issues/1711 + //if (tensor.name == "output.weight") { + // new_type = GGML_TYPE_Q6_K; + //} + if (tensor.name.find("attention.wv.weight") != std::string::npos) { if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) && @@ -2207,7 +2211,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s (i_attention_wv - n_attention_wv/8)%3 == 2)) new_type = GGML_TYPE_Q6_K; ++i_attention_wv; } - else if (tensor.name.find("feed_forward.w2.weight") != std::string::npos) { + if (tensor.name.find("feed_forward.w2.weight") != std::string::npos) { if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) && @@ -2215,10 +2219,11 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s (i_feed_forward_w2 - n_feed_forward_w2/8)%3 == 2)) new_type = GGML_TYPE_Q6_K; ++i_feed_forward_w2; } - else if (tensor.name.find("attention.wo.weight") != std::string::npos) { + if (tensor.name.find("attention.wo.weight") != std::string::npos) { if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; } + float * f32_data; size_t nelements = tensor.ne.at(0) * tensor.ne.at(1); llama_buffer f32_conv_buf; From 7ad7750c5c9f6bcea73a1895ffc57e7d21f2ab95 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 6 Jun 2023 09:55:10 +0300 Subject: [PATCH 14/15] gitignore : add .clang-tidy --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index e4561ad73..6cf5c45a7 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ .envrc .swiftpm .venv +.clang-tidy .vs/ .vscode/ From 2d43387dafe9c60f15f57aa23ee0b37864b98b32 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 6 Jun 2023 10:18:03 +0300 Subject: [PATCH 15/15] ggml : fix builds, add ggml-quants-k.o (close #1712, close #1710) --- .gitignore | 1 + Makefile | 18 +++++++++--------- ggml.c | 20 ++++++++++++-------- 3 files changed, 22 insertions(+), 17 deletions(-) diff --git a/.gitignore b/.gitignore index 6cf5c45a7..9b6905ed4 100644 --- a/.gitignore +++ b/.gitignore @@ -35,6 +35,7 @@ models/* /benchmark-matmult /vdot /Pipfile +/libllama.so build-info.h arm_neon.h diff --git a/Makefile b/Makefile index 7c9e7f739..0205f1959 100644 --- a/Makefile +++ b/Makefile @@ -243,7 +243,7 @@ llama.o: llama.cpp ggml.h ggml-cuda.h llama.h llama-util.h common.o: examples/common.cpp examples/common.h $(CXX) $(CXXFLAGS) -c $< -o $@ -libllama.so: llama.o ggml.o $(OBJS) +libllama.so: llama.o ggml.o ggml-quants-k.o $(OBJS) $(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS) clean: @@ -253,28 +253,28 @@ clean: # Examples # -main: examples/main/main.cpp build-info.h ggml.o ggml-quants-k.o llama.o common.o $(OBJS) +main: examples/main/main.cpp build-info.h ggml.o ggml-quants-k.o llama.o common.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) @echo @echo '==== Run ./main -h for help. ====' @echo -quantize: examples/quantize/quantize.cpp build-info.h ggml.o llama.o ggml-quants-k.o $(OBJS) +quantize: examples/quantize/quantize.cpp build-info.h ggml.o ggml-quants-k.o llama.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -quantize-stats: examples/quantize-stats/quantize-stats.cpp build-info.h ggml.o llama.o ggml-quants-k.o $(OBJS) +quantize-stats: examples/quantize-stats/quantize-stats.cpp build-info.h ggml.o ggml-quants-k.o llama.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -perplexity: examples/perplexity/perplexity.cpp build-info.h ggml.o llama.o common.o ggml-quants-k.o $(OBJS) +perplexity: examples/perplexity/perplexity.cpp build-info.h ggml.o ggml-quants-k.o llama.o common.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -embedding: examples/embedding/embedding.cpp build-info.h ggml.o llama.o common.o ggml-quants-k.o $(OBJS) +embedding: examples/embedding/embedding.cpp build-info.h ggml.o ggml-quants-k.o llama.o common.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -save-load-state: examples/save-load-state/save-load-state.cpp build-info.h ggml.o llama.o common.o ggml-quants-k.o $(OBJS) +save-load-state: examples/save-load-state/save-load-state.cpp build-info.h ggml.o ggml-quants-k.o llama.o common.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp build-info.h ggml.o llama.o common.o $(OBJS) +server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp build-info.h ggml.o ggml-quants-k.o llama.o common.o $(OBJS) $(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) build-info.h: $(wildcard .git/index) scripts/build-info.sh @@ -289,7 +289,7 @@ build-info.h: $(wildcard .git/index) scripts/build-info.sh # Tests # -benchmark-matmult: examples/benchmark/benchmark-matmult.cpp build-info.h ggml.o $(OBJS) +benchmark-matmult: examples/benchmark/benchmark-matmult.cpp build-info.h ggml.o ggml-quants-k.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) ./$@ diff --git a/ggml.c b/ggml.c index 4e3e7edb9..8308dd991 100644 --- a/ggml.c +++ b/ggml.c @@ -14753,7 +14753,7 @@ static void ggml_graph_export_leaf(const struct ggml_tensor * tensor, FILE * fou const int64_t * ne = tensor->ne; const size_t * nb = tensor->nb; - fprintf(fout, "%-6s %-12s %8d %8lld %8lld %8lld %8lld %16zu %16zu %16zu %16zu %16p %32s\n", + fprintf(fout, "%-6s %-12s %8d %8jd %jd %jd %jd %16zu %16zu %16zu %16zu %16p %32s\n", ggml_type_name(tensor->type), ggml_op_name (tensor->op), tensor->n_dims, @@ -14767,7 +14767,7 @@ static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char const int64_t * ne = tensor->ne; const size_t * nb = tensor->nb; - fprintf(fout, "%-6s %-6s %-12s %8d %8lld %8lld %8lld %8lld %16zu %16zu %16zu %16zu %8d %16p %32s\n", + fprintf(fout, "%-6s %-6s %-12s %8d %jd %jd %jd %jd %16zu %16zu %16zu %16zu %8d %16p %32s\n", arg, ggml_type_name(tensor->type), ggml_op_name (tensor->op), @@ -14796,11 +14796,11 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) { FILE * fout = stdout; fprintf(fout, "\n"); - fprintf(fout, "%-16s %8x\n", "magic", GGML_FILE_MAGIC); - fprintf(fout, "%-16s %8d\n", "version", GGML_FILE_VERSION); - fprintf(fout, "%-16s %8d\n", "leafs", cgraph->n_leafs); - fprintf(fout, "%-16s %8d\n", "nodes", cgraph->n_nodes); - fprintf(fout, "%-16s %8llu\n", "eval", size_eval); + fprintf(fout, "%-16s %8x\n", "magic", GGML_FILE_MAGIC); + fprintf(fout, "%-16s %8d\n", "version", GGML_FILE_VERSION); + fprintf(fout, "%-16s %8d\n", "leafs", cgraph->n_leafs); + fprintf(fout, "%-16s %8d\n", "nodes", cgraph->n_nodes); + fprintf(fout, "%-16s %8ju\n", "eval", size_eval); // header fprintf(fout, "\n"); @@ -15033,7 +15033,11 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** data = ggml_new_tensor_1d(*ctx_data, GGML_TYPE_I8, fsize); - fread(data->data, sizeof(char), fsize, fin); + const size_t ret = fread(data->data, sizeof(char), fsize, fin); + if (ret != fsize) { + fprintf(stderr, "%s: failed to read %s\n", __func__, fname); + return result; + } fclose(fin); }