diff --git a/.devops/llama-server-cuda.Dockerfile b/.devops/llama-server-cuda.Dockerfile index 0010ffd4c..7bef07a05 100644 --- a/.devops/llama-server-cuda.Dockerfile +++ b/.devops/llama-server-cuda.Dockerfile @@ -30,8 +30,10 @@ RUN make -j$(nproc) llama-server FROM ${BASE_CUDA_RUN_CONTAINER} as runtime RUN apt-get update && \ - apt-get install -y libcurl4-openssl-dev libgomp1 + apt-get install -y libcurl4-openssl-dev libgomp1 curl COPY --from=build /app/llama-server /llama-server +HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ] + ENTRYPOINT [ "/llama-server" ] diff --git a/.devops/llama-server-intel.Dockerfile b/.devops/llama-server-intel.Dockerfile index cec436452..3bf1670ec 100644 --- a/.devops/llama-server-intel.Dockerfile +++ b/.devops/llama-server-intel.Dockerfile @@ -20,10 +20,12 @@ RUN if [ "${LLAMA_SYCL_F16}" = "ON" ]; then \ FROM intel/oneapi-basekit:$ONEAPI_VERSION as runtime RUN apt-get update && \ - apt-get install -y libcurl4-openssl-dev + apt-get install -y libcurl4-openssl-dev curl COPY --from=build /app/build/bin/llama-server /llama-server ENV LC_ALL=C.utf8 +HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ] + ENTRYPOINT [ "/llama-server" ] diff --git a/.devops/llama-server-rocm.Dockerfile b/.devops/llama-server-rocm.Dockerfile index f88cf20e5..4b1cdc320 100644 --- a/.devops/llama-server-rocm.Dockerfile +++ b/.devops/llama-server-rocm.Dockerfile @@ -43,8 +43,10 @@ ENV CXX=/opt/rocm/llvm/bin/clang++ # Enable cURL ENV LLAMA_CURL=1 RUN apt-get update && \ - apt-get install -y libcurl4-openssl-dev + apt-get install -y libcurl4-openssl-dev curl RUN make -j$(nproc) llama-server +HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ] + ENTRYPOINT [ "/app/llama-server" ] diff --git a/.devops/llama-server-vulkan.Dockerfile b/.devops/llama-server-vulkan.Dockerfile index b0fa0b8e6..2bc2e45d3 100644 --- a/.devops/llama-server-vulkan.Dockerfile +++ b/.devops/llama-server-vulkan.Dockerfile @@ -5,15 +5,11 @@ FROM ubuntu:$UBUNTU_VERSION as build # Install build tools RUN apt update && apt install -y git build-essential cmake wget -# Install Vulkan SDK +# Install Vulkan SDK and cURL RUN wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | apt-key add - && \ wget -qO /etc/apt/sources.list.d/lunarg-vulkan-jammy.list https://packages.lunarg.com/vulkan/lunarg-vulkan-jammy.list && \ apt update -y && \ - apt-get install -y vulkan-sdk - -# Install cURL -RUN apt-get update && \ - apt-get install -y libcurl4-openssl-dev + apt-get install -y vulkan-sdk libcurl4-openssl-dev curl # Build it WORKDIR /app @@ -28,4 +24,6 @@ RUN cp /app/build/bin/llama-server /llama-server && \ ENV LC_ALL=C.utf8 +HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ] + ENTRYPOINT [ "/llama-server" ] diff --git a/.devops/llama-server.Dockerfile b/.devops/llama-server.Dockerfile index aa93369be..a53a5c999 100644 --- a/.devops/llama-server.Dockerfile +++ b/.devops/llama-server.Dockerfile @@ -3,7 +3,7 @@ ARG UBUNTU_VERSION=22.04 FROM ubuntu:$UBUNTU_VERSION as build RUN apt-get update && \ - apt-get install -y build-essential git libcurl4-openssl-dev + apt-get install -y build-essential git libcurl4-openssl-dev curl WORKDIR /app @@ -22,4 +22,6 @@ COPY --from=build /app/llama-server /llama-server ENV LC_ALL=C.utf8 +HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ] + ENTRYPOINT [ "/llama-server" ] diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 6244b4812..01f1a4522 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -10,7 +10,7 @@ name: Publish Docker image on: - pull_request: + #pull_request: push: branches: - master @@ -22,7 +22,7 @@ concurrency: jobs: push_to_registry: name: Push Docker image to Docker Hub - if: github.event.pull_request.draft == false + #if: github.event.pull_request.draft == false runs-on: ubuntu-latest env: @@ -33,15 +33,13 @@ jobs: - { tag: "light", dockerfile: ".devops/llama-cli.Dockerfile", platforms: "linux/amd64,linux/arm64" } - { tag: "server", dockerfile: ".devops/llama-server.Dockerfile", platforms: "linux/amd64,linux/arm64" } - { tag: "full", dockerfile: ".devops/full.Dockerfile", platforms: "linux/amd64,linux/arm64" } - # NOTE(canardletter): The CUDA builds on arm64 are very slow, so I - # have disabled them for now until the reason why - # is understood. - { tag: "light-cuda", dockerfile: ".devops/llama-cli-cuda.Dockerfile", platforms: "linux/amd64" } - { tag: "server-cuda", dockerfile: ".devops/llama-server-cuda.Dockerfile", platforms: "linux/amd64" } - { tag: "full-cuda", dockerfile: ".devops/full-cuda.Dockerfile", platforms: "linux/amd64" } - { tag: "light-rocm", dockerfile: ".devops/llama-cli-rocm.Dockerfile", platforms: "linux/amd64,linux/arm64" } - { tag: "server-rocm", dockerfile: ".devops/llama-server-rocm.Dockerfile", platforms: "linux/amd64,linux/arm64" } - - { tag: "full-rocm", dockerfile: ".devops/full-rocm.Dockerfile", platforms: "linux/amd64,linux/arm64" } + # Note: the full-rocm image is failing due to a "no space left on device" error. It is disabled for now to allow the workflow to complete. + #- { tag: "full-rocm", dockerfile: ".devops/full-rocm.Dockerfile", platforms: "linux/amd64,linux/arm64" } - { tag: "light-intel", dockerfile: ".devops/llama-cli-intel.Dockerfile", platforms: "linux/amd64" } - { tag: "server-intel", dockerfile: ".devops/llama-server-intel.Dockerfile", platforms: "linux/amd64" } steps: diff --git a/CMakeLists.txt b/CMakeLists.txt index 49ba45356..1acf4bb08 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -102,7 +102,8 @@ option(LLAMA_LLAMAFILE "llama: use llamafile SGEMM" option(LLAMA_CUDA "llama: use CUDA" OFF) option(LLAMA_CUBLAS "llama: use CUDA (deprecated, use LLAMA_CUDA)" OFF) option(LLAMA_CUDA_FORCE_DMMV "llama: use dmmv instead of mmvq CUDA kernels" OFF) -option(LLAMA_CUDA_FORCE_MMQ "llama: use mmq kernels instead of cuBLAS" OFF) +option(LLAMA_CUDA_FORCE_MMQ "llama: always use mmq kernels instead of cuBLAS" OFF) +option(LLAMA_CUDA_FORCE_CUBLAS "llama: always use cuBLAS instead of mmq kernels" OFF) set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels") set(LLAMA_CUDA_MMV_Y "1" CACHE STRING "llama: y block size for mmv CUDA kernels") option(LLAMA_CUDA_F16 "llama: use 16 bit floats for some calculations" OFF) @@ -416,13 +417,14 @@ if (LLAMA_CUDA) if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES) # 52 == lowest CUDA 12 standard - # 60 == f16 CUDA intrinsics + # 60 == FP16 CUDA intrinsics # 61 == integer CUDA intrinsics - # 70 == compute capability at which unrolling a loop in mul_mat_q kernels is faster + # 70 == FP16 tensor cores + # 75 == int8 tensor cores if (LLAMA_CUDA_F16 OR LLAMA_CUDA_DMMV_F16) - set(CMAKE_CUDA_ARCHITECTURES "60;61;70") # needed for f16 CUDA intrinsics + set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75") else() - set(CMAKE_CUDA_ARCHITECTURES "52;61;70") # lowest CUDA 12 standard + lowest for integer intrinsics + set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75") #set(CMAKE_CUDA_ARCHITECTURES "OFF") # use this to compile much faster, but only F16 models work endif() endif() @@ -447,6 +449,9 @@ if (LLAMA_CUDA) if (LLAMA_CUDA_FORCE_MMQ) add_compile_definitions(GGML_CUDA_FORCE_MMQ) endif() + if (LLAMA_CUDA_FORCE_CUBLAS) + add_compile_definitions(GGML_CUDA_FORCE_CUBLAS) + endif() if (LLAMA_CUDA_NO_VMM) add_compile_definitions(GGML_CUDA_NO_VMM) endif() diff --git a/Makefile b/Makefile index 3aad77394..f6e8eb73e 100644 --- a/Makefile +++ b/Makefile @@ -537,6 +537,9 @@ endif # LLAMA_CUDA_FORCE_DMMV ifdef LLAMA_CUDA_FORCE_MMQ MK_NVCCFLAGS += -DGGML_CUDA_FORCE_MMQ endif # LLAMA_CUDA_FORCE_MMQ +ifdef LLAMA_CUDA_FORCE_CUBLAS + MK_NVCCFLAGS += -DGGML_CUDA_FORCE_CUBLAS +endif # LLAMA_CUDA_FORCE_CUBLAS ifdef LLAMA_CUDA_DMMV_X MK_NVCCFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X) else diff --git a/README.md b/README.md index 40793c8ea..95d970d83 100644 --- a/README.md +++ b/README.md @@ -510,8 +510,9 @@ Building the program with BLAS support may lead to some performance improvements |--------------------------------|------------------------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | LLAMA_CUDA_FORCE_DMMV | Boolean | false | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 6.1/Pascal/GTX 1000 or higher). Does not affect k-quants. | | LLAMA_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. | - | LLAMA_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. | - | LLAMA_CUDA_FORCE_MMQ | Boolean | false | Force the use of dequantization + matrix multiplication kernels instead of leveraging Math libraries. | | + | LLAMA_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. | + | LLAMA_CUDA_FORCE_MMQ | Boolean | false | Force the use of custom matrix multiplication kernels for quantized models instead of FP16 cuBLAS even if there is no int8 tensor core implementation available (affects V100, RDNA3). MMQ kernels are enabled by default on GPUs with int8 tensor core support. With MMQ force enabled, speed for large batch sizes will be worse but VRAM consumption will be lower. | + | LLAMA_CUDA_FORCE_CUBLAS | Boolean | false | Force the use of FP16 cuBLAS instead of custom matrix multiplication kernels for quantized models | | LLAMA_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. | | LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. | | LLAMA_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128 | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. | diff --git a/common/common.cpp b/common/common.cpp index cfdedcbae..c76d0e2c3 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -273,26 +273,22 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { return true; } +#define CHECK_ARG if (++i >= argc) { invalid_param = true; return true; } + bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) { const char split_delim = ','; llama_sampling_params & sparams = params.sparams; if (arg == "-s" || arg == "--seed") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG // TODO: this is temporary, in the future the sampling state will be moved fully to llama_sampling_context. params.seed = std::stoul(argv[i]); sparams.seed = std::stoul(argv[i]); return true; } if (arg == "-t" || arg == "--threads") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_threads = std::stoi(argv[i]); if (params.n_threads <= 0) { params.n_threads = std::thread::hardware_concurrency(); @@ -300,10 +296,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-tb" || arg == "--threads-batch") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_threads_batch = std::stoi(argv[i]); if (params.n_threads_batch <= 0) { params.n_threads_batch = std::thread::hardware_concurrency(); @@ -311,10 +304,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-td" || arg == "--threads-draft") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_threads_draft = std::stoi(argv[i]); if (params.n_threads_draft <= 0) { params.n_threads_draft = std::thread::hardware_concurrency(); @@ -322,10 +312,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-tbd" || arg == "--threads-batch-draft") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_threads_batch_draft = std::stoi(argv[i]); if (params.n_threads_batch_draft <= 0) { params.n_threads_batch_draft = std::thread::hardware_concurrency(); @@ -333,10 +320,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-p" || arg == "--prompt") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.prompt = argv[i]; return true; } @@ -349,10 +333,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--prompt-cache") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.path_prompt_cache = argv[i]; return true; } @@ -365,10 +346,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-bf" || arg == "--binary-file") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG std::ifstream file(argv[i], std::ios::binary); if (!file) { fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); @@ -384,10 +362,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-f" || arg == "--file") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG std::ifstream file(argv[i]); if (!file) { fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); @@ -403,10 +378,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--in-file") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG std::ifstream file(argv[i]); if (!file) { fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); @@ -417,66 +389,42 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-n" || arg == "--predict" || arg == "--n-predict") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_predict = std::stoi(argv[i]); return true; } if (arg == "--top-k") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.top_k = std::stoi(argv[i]); return true; } if (arg == "-c" || arg == "--ctx-size") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_ctx = std::stoi(argv[i]); return true; } if (arg == "--grp-attn-n" || arg == "-gan") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.grp_attn_n = std::stoi(argv[i]); return true; } if (arg == "--grp-attn-w" || arg == "-gaw") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.grp_attn_w = std::stoi(argv[i]); return true; } if (arg == "--rope-freq-base") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.rope_freq_base = std::stof(argv[i]); return true; } if (arg == "--rope-freq-scale") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.rope_freq_scale = std::stof(argv[i]); return true; } if (arg == "--rope-scaling") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG std::string value(argv[i]); /**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; } else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; } @@ -485,58 +433,37 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--rope-scale") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.rope_freq_scale = 1.0f / std::stof(argv[i]); return true; } if (arg == "--yarn-orig-ctx") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.yarn_orig_ctx = std::stoi(argv[i]); return true; } if (arg == "--yarn-ext-factor") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.yarn_ext_factor = std::stof(argv[i]); return true; } if (arg == "--yarn-attn-factor") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.yarn_attn_factor = std::stof(argv[i]); return true; } if (arg == "--yarn-beta-fast") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.yarn_beta_fast = std::stof(argv[i]); return true; } if (arg == "--yarn-beta-slow") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.yarn_beta_slow = std::stof(argv[i]); return true; } if (arg == "--pooling") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG std::string value(argv[i]); /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; } else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; } @@ -546,157 +473,100 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--defrag-thold" || arg == "-dt") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.defrag_thold = std::stof(argv[i]); return true; } if (arg == "--samplers") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG const auto sampler_names = string_split(argv[i], ';'); sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, true); return true; } if (arg == "--sampling-seq") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.samplers_sequence = llama_sampling_types_from_chars(argv[i]); return true; } if (arg == "--top-p") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.top_p = std::stof(argv[i]); return true; } if (arg == "--min-p") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.min_p = std::stof(argv[i]); return true; } if (arg == "--temp") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.temp = std::stof(argv[i]); sparams.temp = std::max(sparams.temp, 0.0f); return true; } if (arg == "--tfs") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.tfs_z = std::stof(argv[i]); return true; } if (arg == "--typical") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.typical_p = std::stof(argv[i]); return true; } if (arg == "--repeat-last-n") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.penalty_last_n = std::stoi(argv[i]); sparams.n_prev = std::max(sparams.n_prev, sparams.penalty_last_n); return true; } if (arg == "--repeat-penalty") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.penalty_repeat = std::stof(argv[i]); return true; } if (arg == "--frequency-penalty") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.penalty_freq = std::stof(argv[i]); return true; } if (arg == "--presence-penalty") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.penalty_present = std::stof(argv[i]); return true; } if (arg == "--dynatemp-range") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.dynatemp_range = std::stof(argv[i]); return true; } if (arg == "--dynatemp-exp") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.dynatemp_exponent = std::stof(argv[i]); return true; } if (arg == "--mirostat") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.mirostat = std::stoi(argv[i]); return true; } if (arg == "--mirostat-lr") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.mirostat_eta = std::stof(argv[i]); return true; } if (arg == "--mirostat-ent") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.mirostat_tau = std::stof(argv[i]); return true; } if (arg == "--cfg-negative-prompt") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.cfg_negative_prompt = argv[i]; return true; } if (arg == "--cfg-negative-prompt-file") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG std::ifstream file(argv[i]); if (!file) { fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); @@ -710,203 +580,125 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--cfg-scale") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.cfg_scale = std::stof(argv[i]); return true; } if (arg == "-b" || arg == "--batch-size") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_batch = std::stoi(argv[i]); return true; } if (arg == "-ub" || arg == "--ubatch-size") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_ubatch = std::stoi(argv[i]); return true; } if (arg == "--keep") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_keep = std::stoi(argv[i]); return true; } if (arg == "--draft") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_draft = std::stoi(argv[i]); return true; } if (arg == "--chunks") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_chunks = std::stoi(argv[i]); return true; } if (arg == "-np" || arg == "--parallel") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_parallel = std::stoi(argv[i]); return true; } if (arg == "-ns" || arg == "--sequences") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_sequences = std::stoi(argv[i]); return true; } if (arg == "--p-split" || arg == "-ps") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.p_split = std::stof(argv[i]); return true; } if (arg == "-m" || arg == "--model") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.model = argv[i]; return true; } if (arg == "-md" || arg == "--model-draft") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.model_draft = argv[i]; return true; } if (arg == "-a" || arg == "--alias") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.model_alias = argv[i]; return true; } if (arg == "-mu" || arg == "--model-url") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.model_url = argv[i]; return true; } if (arg == "-hfr" || arg == "--hf-repo") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.hf_repo = argv[i]; return true; } if (arg == "-hff" || arg == "--hf-file") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.hf_file = argv[i]; return true; } if (arg == "--lora") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.lora_adapter.emplace_back(argv[i], 1.0f); params.use_mmap = false; return true; } if (arg == "--lora-scaled") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG const char* lora_adapter = argv[i]; - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i])); params.use_mmap = false; return true; } if (arg == "--lora-base") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.lora_base = argv[i]; return true; } if (arg == "--control-vector") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.control_vectors.push_back({ 1.0f, argv[i], }); return true; } if (arg == "--control-vector-scaled") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG const char* fname = argv[i]; - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.control_vectors.push_back({ std::stof(argv[i]), fname, }); return true; } if (arg == "--control-vector-layer-range") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.control_vector_layer_start = std::stoi(argv[i]); - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.control_vector_layer_end = std::stoi(argv[i]); return true; } if (arg == "--mmproj") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.mmproj = argv[i]; return true; } if (arg == "--image") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.image.emplace_back(argv[i]); return true; } @@ -922,6 +714,21 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.embedding = true; return true; } + if (arg == "--embd-normalize") { + CHECK_ARG + params.embd_normalize = std::stoi(argv[i]); + return true; + } + if (arg == "--embd-output-format") { + CHECK_ARG + params.embd_out = argv[i]; + return true; + } + if (arg == "--embd-separator") { + CHECK_ARG + params.embd_sep = argv[i]; + return true; + } if (arg == "-if" || arg == "--interactive-first") { params.interactive_first = true; return true; @@ -975,10 +782,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-ngl" || arg == "--gpu-layers" || arg == "--n-gpu-layers") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_gpu_layers = std::stoi(argv[i]); if (!llama_supports_gpu_offload()) { fprintf(stderr, "warning: not compiled with GPU offload support, --gpu-layers option will be ignored\n"); @@ -987,10 +791,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-ngld" || arg == "--gpu-layers-draft" || arg == "--gpu-layers-draft") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_gpu_layers_draft = std::stoi(argv[i]); if (!llama_supports_gpu_offload()) { fprintf(stderr, "warning: not compiled with GPU offload support, --gpu-layers-draft option will be ignored\n"); @@ -999,10 +800,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--main-gpu" || arg == "-mg") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.main_gpu = std::stoi(argv[i]); #ifndef GGML_USE_CUDA_SYCL_VULKAN fprintf(stderr, "warning: llama.cpp was compiled without CUDA/SYCL/Vulkan. Setting the main GPU has no effect.\n"); @@ -1010,10 +808,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--split-mode" || arg == "-sm") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG std::string arg_next = argv[i]; if (arg_next == "none") { params.split_mode = LLAMA_SPLIT_MODE_NONE; @@ -1038,10 +833,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--tensor-split" || arg == "-ts") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG std::string arg_next = argv[i]; // split string by , and / @@ -1066,10 +858,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--rpc") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.rpc_servers = argv[i]; return true; } @@ -1078,10 +867,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--numa") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG std::string value(argv[i]); /**/ if (value == "distribute" || value == "") { params.numa = GGML_NUMA_STRATEGY_DISTRIBUTE; } else if (value == "isolate") { params.numa = GGML_NUMA_STRATEGY_ISOLATE; } @@ -1094,10 +880,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--verbosity") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.verbosity = std::stoi(argv[i]); return true; } @@ -1110,18 +893,12 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-r" || arg == "--reverse-prompt") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.antiprompt.emplace_back(argv[i]); return true; } if (arg == "-ld" || arg == "--logdir") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.logdir = argv[i]; if (params.logdir.back() != DIRECTORY_SEPARATOR) { @@ -1130,26 +907,17 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-lcs" || arg == "--lookup-cache-static") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.lookup_cache_static = argv[i]; return true; } if (arg == "-lcd" || arg == "--lookup-cache-dynamic") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.lookup_cache_dynamic = argv[i]; return true; } if (arg == "--save-all-logits" || arg == "--kl-divergence-base") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.logits_file = argv[i]; return true; } @@ -1158,26 +926,17 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--ppl-stride") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.ppl_stride = std::stoi(argv[i]); return true; } if (arg == "--ppl-output-type") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.ppl_output_type = std::stoi(argv[i]); return true; } if (arg == "-ptc" || arg == "--print-token-count") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_print = std::stoi(argv[i]); return true; } @@ -1190,10 +949,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--hellaswag-tasks") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.hellaswag_tasks = std::stoi(argv[i]); return true; } @@ -1202,10 +958,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--winogrande-tasks") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.winogrande_tasks = std::stoi(argv[i]); return true; } @@ -1214,10 +967,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--multiple-choice-tasks") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.multiple_choice_tasks = std::stoi(argv[i]); return true; } @@ -1234,10 +984,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-l" || arg == "--logit-bias") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG std::stringstream ss(argv[i]); llama_token key; char sign; @@ -1270,34 +1017,22 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--in-prefix") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.input_prefix = argv[i]; return true; } if (arg == "--in-suffix") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.input_suffix = argv[i]; return true; } if (arg == "--grammar") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.grammar = argv[i]; return true; } if (arg == "--grammar-file") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG std::ifstream file(argv[i]); if (!file) { fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); @@ -1312,18 +1047,12 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-j" || arg == "--json-schema") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.grammar = json_schema_to_grammar(json::parse(argv[i])); return true; } if (arg == "--override-kv") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG if (!string_parse_kv_override(argv[i], params.kv_overrides)) { fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]); invalid_param = true; @@ -1332,42 +1061,27 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--host") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.hostname = argv[i]; return true; } if (arg == "--port") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.port = std::stoi(argv[i]); return true; } if (arg == "--path") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.public_path = argv[i]; return true; } if (arg == "--api-key") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.api_keys.push_back(argv[i]); return true; } if (arg == "--api-key-file") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG std::ifstream key_file(argv[i]); if (!key_file) { fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); @@ -1384,43 +1098,28 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--ssl-key-file") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.ssl_file_key = argv[i]; return true; } if (arg == "--ssl-cert-file") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.ssl_file_cert = argv[i]; return true; } if (arg == "--timeout" || arg == "-to") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.timeout_read = std::stoi(argv[i]); params.timeout_write = std::stoi(argv[i]); return true; } if (arg == "--threads-http") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_threads_http = std::stoi(argv[i]); return true; } if (arg == "-spf" || arg == "--system-prompt-file") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG std::ifstream file(argv[i]); if (!file) { fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); @@ -1437,10 +1136,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--log-format") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG if (std::strcmp(argv[i], "json") == 0) { params.log_json = true; } else if (std::strcmp(argv[i], "text") == 0) { @@ -1460,10 +1156,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--slot-save-path") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.slot_save_path = argv[i]; // if doesn't end with DIRECTORY_SEPARATOR, add it if (!params.slot_save_path.empty() && params.slot_save_path[params.slot_save_path.size() - 1] != DIRECTORY_SEPARATOR) { @@ -1472,10 +1165,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--chat-template") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG if (!llama_chat_verify_template(argv[i])) { fprintf(stderr, "error: the supplied chat template is not supported: %s\n", argv[i]); fprintf(stderr, "note: llama.cpp does not use jinja parser, we only support commonly used templates\n"); @@ -1486,10 +1176,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--slot-prompt-similarity" || arg == "-sps") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.slot_prompt_similarity = std::stof(argv[i]); return true; } @@ -1498,37 +1185,25 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-npp") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG auto p = string_split(argv[i], split_delim); params.n_pp.insert(params.n_pp.end(), p.begin(), p.end()); return true; } if (arg == "-ntg") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG auto p = string_split(argv[i], split_delim); params.n_tg.insert(params.n_tg.end(), p.begin(), p.end()); return true; } if (arg == "-npl") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG auto p = string_split(argv[i], split_delim); params.n_pl.insert(params.n_pl.end(), p.begin(), p.end()); return true; } if (arg == "--context-file") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG std::ifstream file(argv[i], std::ios::binary); if (!file) { fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); @@ -1539,59 +1214,38 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--chunk-size") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.chunk_size = std::stoi(argv[i]); return true; } if (arg == "--chunk-separator") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.chunk_separator = argv[i]; return true; } if (arg == "--junk") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_junk = std::stoi(argv[i]); return true; } if (arg == "--pos") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.i_pos = std::stoi(argv[i]); return true; } if (arg == "-o" || arg == "--output" || arg == "--output-file") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.out_file = argv[i]; params.cvector_outfile = argv[i]; return true; } if (arg == "-ofreq" || arg == "--output-frequency") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_out_freq = std::stoi(argv[i]); return true; } if (arg == "--save-frequency") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_save_freq = std::stoi(argv[i]); return true; } @@ -1604,62 +1258,39 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--chunk" || arg == "--from-chunk") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.i_chunk = std::stoi(argv[i]); return true; } // cvector params - if (arg == "--completions-file") { - if (++i >= argc) { - invalid_param = true; - return true; - } - params.cvector_completions_file = argv[i]; - return true; - } if (arg == "--positive-file") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.cvector_positive_file = argv[i]; return true; } if (arg == "--negative-file") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.cvector_negative_file = argv[i]; return true; } - if (arg == "--completions") { - if (++i >= argc) { - invalid_param = true; - return true; - } - params.n_completions = std::stoi(argv[i]); - return true; - } if (arg == "--pca-batch") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_pca_batch = std::stoi(argv[i]); return true; } if (arg == "--pca-iter") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_pca_iterations = std::stoi(argv[i]); return true; } + if (arg == "--method") { + CHECK_ARG + std::string value(argv[i]); + /**/ if (value == "pca") { params.cvector_dimre_method = DIMRE_METHOD_PCA; } + else if (value == "mean") { params.cvector_dimre_method = DIMRE_METHOD_MEAN; } + else { invalid_param = true; } + return true; + } #ifndef LOG_DISABLE_LOGS // Parse args for logging parameters if (log_param_single_parse(argv[i])) { @@ -1671,10 +1302,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa // We have a matching known parameter requiring an argument, // now we need to check if there is anything after this argv // and flag invalid_param or parse it. - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG if (!log_param_pair_parse( /*check_but_dont_parse*/ false, argv[i - 1], argv[i])) { invalid_param = true; return true; @@ -1814,7 +1442,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "main", " --cfg-negative-prompt-file FNAME", "negative prompt file to use for guidance" }); options.push_back({ "main", " --cfg-scale N", "strength of guidance (default: %.1f, 1.0 = disable)", (double)sparams.cfg_scale }); - + options.push_back({ "main", " --chat-template JINJA_TEMPLATE", + "set custom jinja chat template (default: template taken from model's metadata)\n" + "only commonly used templates are accepted:\n" + "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" }); options.push_back({ "grammar" }); options.push_back({ "*", " --grammar GRAMMAR", "BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", sparams.grammar.c_str() }); options.push_back({ "*", " --grammar-file FNAME", "file to read grammar from" }); @@ -1908,9 +1539,11 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", " --lora FNAME", "apply LoRA adapter (implies --no-mmap)" }); options.push_back({ "*", " --lora-scaled FNAME S", "apply LoRA adapter with user defined scaling S (implies --no-mmap)" }); options.push_back({ "*", " --lora-base FNAME", "optional model to use as a base for the layers modified by the LoRA adapter" }); - options.push_back({ "*", " --control-vector FNAME", "add a control vector" }); + options.push_back({ "*", " --control-vector FNAME", "add a control vector\n" + "note: this argument can be repeated to add multiple control vectors" }); options.push_back({ "*", " --control-vector-scaled FNAME SCALE", - "add a control vector with user defined scaling SCALE" }); + "add a control vector with user defined scaling SCALE\n" + "note: this argument can be repeated to add multiple scaled control vectors" }); options.push_back({ "*", " --control-vector-layer-range START END", "layer range to apply the control vector(s) to, start and end inclusive" }); options.push_back({ "*", "-m, --model FNAME", "model path (default: models/$filename with filename from --hf-file\n" @@ -1944,6 +1577,11 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "bench", "-ntg n0,n1,...", "number of text generation tokens" }); options.push_back({ "bench", "-npl n0,n1,...", "number of parallel prompts" }); + options.push_back({ "embedding" }); + options.push_back({ "embedding", " --embd-normalize", "normalisation for embendings (default: %d) (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)", params.embd_normalize }); + options.push_back({ "embedding", " --embd-output-format", "empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix" }); + options.push_back({ "embedding", " --embd-separator", "separator of embendings (default \\n) for example \"<#sep#>\"" }); + options.push_back({ "server" }); options.push_back({ "server", " --host HOST", "ip address to listen (default: %s)", params.hostname.c_str() }); options.push_back({ "server", " --port PORT", "port to listen (default: %d)", params.port }); @@ -1986,11 +1624,9 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "cvector", "-o, --output FNAME", "output file (default: '%s')", params.cvector_outfile.c_str() }); options.push_back({ "cvector", " --positive-file FNAME", "positive prompts file, one prompt per line (default: '%s')", params.cvector_positive_file.c_str() }); options.push_back({ "cvector", " --negative-file FNAME", "negative prompts file, one prompt per line (default: '%s')", params.cvector_negative_file.c_str() }); - options.push_back({ "cvector", " --completions-file FNAME", - "completions file (default: '%s')", params.cvector_completions_file.c_str() }); - options.push_back({ "cvector", " --completions N", "number of lines of completions file to use (default: %d)", params.n_completions }); options.push_back({ "cvector", " --pca-batch N", "batch size used for PCA. Larger batch runs faster, but uses more memory (default: %d)", params.n_pca_batch }); options.push_back({ "cvector", " --pca-iter N", "number of iterations used for PCA (default: %d)", params.n_pca_iterations }); + options.push_back({ "cvector", " --method {pca,mean}", "dimensionality reduction method to be used (default: pca)" }); printf("usage: %s [options]\n", argv[0]); @@ -2967,12 +2603,67 @@ bool llama_should_add_bos_token(const llama_model * model) { return add_bos != -1 ? bool(add_bos) : (llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM); } +// +// Chat template utils +// + bool llama_chat_verify_template(const std::string & tmpl) { llama_chat_message chat[] = {{"user", "test"}}; int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0); return res >= 0; } +std::string llama_chat_apply_template(const struct llama_model * model, + const std::string & tmpl, + const std::vector & msgs, + bool add_ass) { + int alloc_size = 0; + std::vector chat; + for (auto & msg : msgs) { + chat.push_back({msg.role.c_str(), msg.content.c_str()}); + alloc_size += (msg.role.size() + msg.content.size()) * 1.25; + } + + const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str(); + std::vector buf(alloc_size); + + // run the first time to get the total output length + int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size()); + + // if it turns out that our buffer is too small, we resize it + if ((size_t) res > buf.size()) { + buf.resize(res); + res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size()); + } + + std::string formatted_chat(buf.data(), res); + return formatted_chat; +} + +std::string llama_chat_format_single(const struct llama_model * model, + const std::string & tmpl, + const std::vector & past_msg, + const llama_chat_msg & new_msg, + bool add_ass) { + auto fmt_past_msg = llama_chat_apply_template(model, tmpl, past_msg, false); + std::vector chat_new(past_msg); + chat_new.push_back(new_msg); + auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass); + auto formatted = fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); + return formatted; +} + +std::string llama_chat_format_example(const struct llama_model * model, + const std::string & tmpl) { + std::vector msgs = { + {"system", "You are a helpful assistant"}, + {"user", "Hello"}, + {"assistant", "Hi there"}, + {"user", "How are you?"}, + }; + return llama_chat_apply_template(model, tmpl, msgs, true); +} + // // KV cache utils // @@ -3052,14 +2743,34 @@ void llama_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_siz // Embedding utils // -void llama_embd_normalize(const float * inp, float * out, int n) { +void llama_embd_normalize(const float * inp, float * out, int n, int embd_norm) { double sum = 0.0; - for (int i = 0; i < n; i++) { - sum += inp[i] * inp[i]; - } - sum = sqrt(sum); - const float norm = sum > 0.0 ? 1.0f / sum : 0.0f; + switch (embd_norm) { + case -1: // no normalisation + sum = 1.0; + break; + case 0: // max absolute + for (int i = 0; i < n; i++) { + if (sum < std::abs(inp[i])) sum = std::abs(inp[i]); + } + sum /= 32760.0; // make an int16 range + break; + case 2: // euclidean + for (int i = 0; i < n; i++) { + sum += inp[i] * inp[i]; + } + sum = std::sqrt(sum); + break; + default: // p-norm (euclidean is p-norm p=2) + for (int i = 0; i < n; i++) { + sum += std::pow(std::abs(inp[i]), embd_norm); + } + sum = std::pow(sum, 1.0 / embd_norm); + break; + } + + const float norm = sum > 0.0 ? 1.0 / sum : 0.0f; for (int i = 0; i < n; i++) { out[i] = inp[i] * norm; @@ -3077,6 +2788,14 @@ float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n) sum2 += embd2[i] * embd2[i]; } + // Handle the case where one or both vectors are zero vectors + if (sum1 == 0.0 || sum2 == 0.0) { + if (sum1 == 0.0 && sum2 == 0.0) { + return 1.0f; // two zero vectors are similar + } + return 0.0f; + } + return sum / (sqrt(sum1) * sqrt(sum2)); } diff --git a/common/common.h b/common/common.h index 9a1dc4a2f..c541204f6 100644 --- a/common/common.h +++ b/common/common.h @@ -52,6 +52,12 @@ int32_t cpu_get_num_math(); // CLI argument parsing // +// dimensionality reduction methods, used by cvector-generator +enum dimre_method { + DIMRE_METHOD_PCA, + DIMRE_METHOD_MEAN, +}; + struct gpt_params { uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed @@ -152,7 +158,6 @@ struct gpt_params { bool prompt_cache_all = false; // save user input and generations to prompt cache bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it - bool embedding = false; // get only sentence embedding bool escape = true; // escape "\n", "\r", "\t", "\'", "\"", and "\\" bool multiline_input = false; // reverse the usage of `\` bool simple_io = false; // improves compatibility with subprocesses and limited consoles @@ -179,6 +184,12 @@ struct gpt_params { std::string mmproj = ""; // path to multimodal projector std::vector image; // path to image file(s) + // embedding + bool embedding = false; // get only sentence embedding + int32_t embd_normalize = 2; // normalisation for embendings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm) + std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix + std::string embd_sep = "\n"; // separator of embendings + // server params int32_t port = 8080; // server listens on this network port int32_t timeout_read = 600; // http read timeout in seconds @@ -233,13 +244,12 @@ struct gpt_params { bool compute_ppl = true; // whether to compute perplexity // cvector-generator params - int n_completions = 64; - int n_pca_batch = 20; + int n_pca_batch = 100; int n_pca_iterations = 1000; - std::string cvector_outfile = "control_vector.gguf"; - std::string cvector_completions_file = "examples/cvector-generator/completions.txt"; - std::string cvector_positive_file = "examples/cvector-generator/positive.txt"; - std::string cvector_negative_file = "examples/cvector-generator/negative.txt"; + dimre_method cvector_dimre_method = DIMRE_METHOD_PCA; + std::string cvector_outfile = "control_vector.gguf"; + std::string cvector_positive_file = "examples/cvector-generator/positive.txt"; + std::string cvector_negative_file = "examples/cvector-generator/negative.txt"; }; void gpt_params_handle_model_default(gpt_params & params); @@ -360,9 +370,32 @@ bool llama_should_add_bos_token(const llama_model * model); // Chat template utils // +// same with llama_chat_message, but uses std::string +struct llama_chat_msg { + std::string role; + std::string content; +}; + // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid bool llama_chat_verify_template(const std::string & tmpl); +// CPP wrapper for llama_chat_apply_template +std::string llama_chat_apply_template(const struct llama_model * model, + const std::string & tmpl, + const std::vector & chat, + bool add_ass); + +// Format single message, while taking into account the position of that message in chat history +std::string llama_chat_format_single(const struct llama_model * model, + const std::string & tmpl, + const std::vector & past_msg, + const llama_chat_msg & new_msg, + bool add_ass); + +// Returns an example of formatted chat +std::string llama_chat_format_example(const struct llama_model * model, + const std::string & tmpl); + // // KV cache utils // @@ -377,7 +410,7 @@ void llama_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_siz // Embedding utils // -void llama_embd_normalize(const float * inp, float * out, int n); +void llama_embd_normalize(const float * inp, float * out, int n, int embd_norm = 2); float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n); diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index 10b9b3d1d..2f233e2e7 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -40,6 +40,233 @@ static std::string build_repetition(const std::string & item_rule, int min_items return result; } +/* Minimalistic replacement for std::string_view, which is only available from C++17 onwards */ +class string_view { + const std::string & _str; + const size_t _start; + const size_t _end; +public: + string_view(const std::string & str, size_t start = 0, size_t end = std::string::npos) : _str(str), _start(start), _end(end == std::string::npos ? str.length() : end) {} + + size_t size() const { + return _end - _start; + } + + size_t length() const { + return size(); + } + + operator std::string() const { + return str(); + } + + std::string str() const { + return _str.substr(_start, _end - _start); + } + + string_view substr(size_t pos, size_t len = std::string::npos) const { + return string_view(_str, _start + pos, len == std::string::npos ? _end : _start + pos + len); + } + + char operator[](size_t pos) const { + auto index = _start + pos; + if (index >= _end) { + throw std::out_of_range("string_view index out of range"); + } + return _str[_start + pos]; + } + + bool operator==(const string_view & other) const { + std::string this_str = *this; + std::string other_str = other; + return this_str == other_str; + } +}; + +static void _build_min_max_int(int min_value, int max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) { + auto has_min = min_value != std::numeric_limits::min(); + auto has_max = max_value != std::numeric_limits::max(); + + auto digit_range = [&](char from, char to) { + out << "["; + if (from == to) { + out << from; + } else { + out << from << "-" << to; + } + out << "]"; + }; + auto more_digits = [&](int min_digits, int max_digits) { + out << "[0-9]"; + if (min_digits == max_digits && min_digits == 1) { + return; + } + out << "{"; + out << min_digits; + if (max_digits != min_digits) { + out << ","; + if (max_digits != std::numeric_limits::max()) { + out << max_digits; + } + } + out << "}"; + }; + std::function uniform_range = + [&](const string_view & from, const string_view & to) { + size_t i = 0; + while (i < from.length() && i < to.length() && from[i] == to[i]) { + i++; + } + if (i > 0) { + out << "\"" << from.substr(0, i).str() << "\""; + } + if (i < from.length() && i < to.length()) { + if (i > 0) { + out << " "; + } + auto sub_len = from.length() - i - 1; + if (sub_len > 0) { + auto from_sub = from.substr(i + 1); + auto to_sub = to.substr(i + 1); + auto sub_zeros = repeat("0", sub_len); + auto sub_nines = repeat("9", sub_len); + + auto to_reached = false; + out << "("; + if (from_sub == sub_zeros) { + digit_range(from[i], to[i] - 1); + out << " "; + more_digits(sub_len, sub_len); + } else { + out << "[" << from[i] << "] "; + out << "("; + uniform_range(from_sub, sub_nines); + out << ")"; + if (from[i] < to[i] - 1) { + out << " | "; + if (to_sub == sub_nines) { + digit_range(from[i] + 1, to[i]); + to_reached = true; + } else { + digit_range(from[i] + 1, to[i] - 1); + } + out << " "; + more_digits(sub_len, sub_len); + } + } + if (!to_reached) { + out << " | "; + digit_range(to[i], to[i]); + out << " "; + uniform_range(sub_zeros, to_sub); + } + out << ")"; + } else { + out << "[" << from[i] << "-" << to[i] << "]"; + } + } + }; + + if (has_min && has_max) { + if (min_value < 0 && max_value < 0) { + out << "\"-\" ("; + _build_min_max_int(-max_value, -min_value, out, decimals_left, /* top_level= */ true); + out << ")"; + return; + } + + if (min_value < 0) { + out << "\"-\" ("; + _build_min_max_int(0, -min_value, out, decimals_left, /* top_level= */ true); + out << ") | "; + min_value = 0; + } + + auto min_s = std::to_string(min_value); + auto max_s = std::to_string(max_value); + auto min_digits = min_s.length(); + auto max_digits = max_s.length(); + + for (auto digits = min_digits; digits < max_digits; digits++) { + uniform_range(min_s, repeat("9", digits)); + min_s = "1" + repeat("0", digits); + out << " | "; + } + uniform_range(min_s, max_s); + return; + } + + auto less_decimals = std::max(decimals_left - 1, 1); + + if (has_min) { + if (min_value < 0) { + out << "\"-\" ("; + _build_min_max_int(std::numeric_limits::min(), -min_value, out, decimals_left, /* top_level= */ false); + out << ") | [0] | [1-9] "; + more_digits(0, decimals_left - 1); + } else if (min_value == 0) { + if (top_level) { + out << "[0] | [1-9] "; + more_digits(0, less_decimals); + } else { + more_digits(1, decimals_left); + } + } else if (min_value <= 9) { + char c = '0' + min_value; + auto range_start = top_level ? '1' : '0'; + if (c > range_start) { + digit_range(range_start, c - 1); + out << " "; + more_digits(1, less_decimals); + out << " | "; + } + digit_range(c, '9'); + out << " "; + more_digits(0, less_decimals); + } else { + auto min_s = std::to_string(min_value); + auto len = min_s.length(); + auto c = min_s[0]; + + if (c > '1') { + digit_range(top_level ? '1' : '0', c - 1); + out << " "; + more_digits(len, less_decimals); + out << " | "; + } + digit_range(c, c); + out << " ("; + _build_min_max_int(std::stoi(min_s.substr(1)), std::numeric_limits::max(), out, less_decimals, /* top_level= */ false); + out << ")"; + if (c < '9') { + out << " | "; + digit_range(c + 1, '9'); + out << " "; + more_digits(len - 1, less_decimals); + } + } + return; + } + + if (has_max) { + if (max_value >= 0) { + if (top_level) { + out << "\"-\" [1-9] "; + more_digits(0, less_decimals); + out << " | "; + } + _build_min_max_int(0, max_value, out, decimals_left, /* top_level= */ true); + } else { + out << "\"-\" ("; + _build_min_max_int(-max_value, std::numeric_limits::max(), out, decimals_left, /* top_level= */ false); + out << ")"; + } + return; + } + + throw std::runtime_error("At least one of min_value or max_value must be set"); +} + const std::string SPACE_RULE = "| \" \" | \"\\n\" [ \\t]{0,20}"; struct BuiltinRule { @@ -160,7 +387,6 @@ static std::string format_literal(const std::string & literal) { return "\"" + escaped + "\""; } - class SchemaConverter { private: std::function _fetch_json; @@ -388,6 +614,75 @@ private: return _add_rule(name, "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space"); } + /* + Returns a rule that matches a JSON string that is none of the provided strings + + not_strings({"a"}) + -> ["] ( [a] char+ | [^"a] char* )? ["] space + not_strings({"and", "also"}) + -> ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? ["] space + */ + std::string _not_strings(const std::vector & strings) { + + struct TrieNode { + std::map children; + bool is_end_of_string; + + TrieNode() : is_end_of_string(false) {} + + void insert(const std::string & string) { + auto node = this; + for (char c : string) { + node = &node->children[c]; + } + node->is_end_of_string = true; + } + }; + + TrieNode trie; + for (const auto & s : strings) { + trie.insert(s); + } + + std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char")); + std::ostringstream out; + out << "[\"] ( "; + std::function visit = [&](const TrieNode & node) { + std::ostringstream rejects; + auto first = true; + for (const auto & kv : node.children) { + rejects << kv.first; + if (first) { + first = false; + } else { + out << " | "; + } + out << "[" << kv.first << "]"; + if (!kv.second.children.empty()) { + out << " ("; + visit(kv.second); + out << ")"; + } else if (kv.second.is_end_of_string) { + out << " " << char_rule << "+"; + } + } + if (!node.children.empty()) { + if (!first) { + out << " | "; + } + out << "[^\"" << rejects.str() << "] " << char_rule << "*"; + } + }; + visit(trie); + + out << " )"; + if (!trie.is_end_of_string) { + out << "?"; + } + out << " [\"] space"; + return out.str(); + } + std::string _resolve_ref(const std::string & ref) { std::string ref_name = ref.substr(ref.find_last_of('/') + 1); if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) { @@ -408,6 +703,7 @@ private: std::vector required_props; std::vector optional_props; std::unordered_map prop_kv_rule_names; + std::vector prop_names; for (const auto & kv : properties) { const auto &prop_name = kv.first; const auto &prop_schema = kv.second; @@ -422,11 +718,18 @@ private: } else { optional_props.push_back(prop_name); } + prop_names.push_back(prop_name); } - if (additional_properties.is_object() || (additional_properties.is_boolean() && additional_properties.get())) { + if (!(additional_properties.is_boolean() && !additional_properties.get())) { std::string sub_name = name + (name.empty() ? "" : "-") + "additional"; - std::string value_rule = visit(additional_properties.is_object() ? additional_properties : json::object(), sub_name + "-value"); - std::string kv_rule = _add_rule(sub_name + "-kv", _add_primitive("string", PRIMITIVE_RULES.at("string")) + " \":\" space " + value_rule); + std::string value_rule = + additional_properties.is_object() ? visit(additional_properties, sub_name + "-value") + : _add_primitive("value", PRIMITIVE_RULES.at("value")); + + auto key_rule = + prop_names.empty() ? _add_primitive("string", PRIMITIVE_RULES.at("string")) + : _add_rule(sub_name + "-k", _not_strings(prop_names)); + std::string kv_rule = _add_rule(sub_name + "-kv", key_rule + " \":\" space " + value_rule); prop_kv_rule_names["*"] = kv_rule; optional_props.push_back("*"); } @@ -452,15 +755,11 @@ private: } std::string k = ks[0]; std::string kv_rule_name = prop_kv_rule_names[k]; - if (k == "*") { - res = _add_rule( - name + (name.empty() ? "" : "-") + "additional-kvs", - kv_rule_name + " ( \",\" space " + kv_rule_name + " )*" - ); - } else if (first_is_optional) { - res = "( \",\" space " + kv_rule_name + " )?"; + std::string comma_ref = "( \",\" space " + kv_rule_name + " )"; + if (first_is_optional) { + res = comma_ref + (k == "*" ? "*" : "?"); } else { - res = kv_rule_name; + res = kv_rule_name + (k == "*" ? " " + comma_ref + "*" : ""); } if (ks.size() > 1) { res += " " + _add_rule( @@ -594,17 +893,19 @@ public: } else if (schema_type.is_array()) { std::vector schema_types; for (const auto & t : schema_type) { - schema_types.push_back({{"type", t}}); + json schema_copy(schema); + schema_copy["type"] = t; + schema_types.push_back(schema_copy); } return _add_rule(rule_name, _generate_union_rule(name, schema_types)); } else if (schema.contains("const")) { - return _add_rule(rule_name, _generate_constant_rule(schema["const"])); + return _add_rule(rule_name, _generate_constant_rule(schema["const"]) + " space"); } else if (schema.contains("enum")) { std::vector enum_values; for (const auto & v : schema["enum"]) { enum_values.push_back(_generate_constant_rule(v)); } - return _add_rule(rule_name, join(enum_values.begin(), enum_values.end(), " | ")); + return _add_rule(rule_name, "(" + join(enum_values.begin(), enum_values.end(), " | ") + ") space"); } else if ((schema_type.is_null() || schema_type == "object") && (schema.contains("properties") || (schema.contains("additionalProperties") && schema["additionalProperties"] != true))) { @@ -686,6 +987,24 @@ public: int min_len = schema.contains("minLength") ? schema["minLength"].get() : 0; int max_len = schema.contains("maxLength") ? schema["maxLength"].get() : std::numeric_limits::max(); return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space"); + } else if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) { + int min_value = std::numeric_limits::min(); + int max_value = std::numeric_limits::max(); + if (schema.contains("minimum")) { + min_value = schema["minimum"].get(); + } else if (schema.contains("exclusiveMinimum")) { + min_value = schema["exclusiveMinimum"].get() + 1; + } + if (schema.contains("maximum")) { + max_value = schema["maximum"].get(); + } else if (schema.contains("exclusiveMaximum")) { + max_value = schema["exclusiveMaximum"].get() - 1; + } + std::stringstream out; + out << "("; + _build_min_max_int(min_value, max_value, out); + out << ") space"; + return _add_rule(rule_name, out.str()); } else if (schema.empty() || schema_type == "object") { return _add_rule(rule_name, _add_primitive("object", PRIMITIVE_RULES.at("object"))); } else { diff --git a/common/sampling.cpp b/common/sampling.cpp index f1f803516..9f332fe57 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -28,9 +28,13 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_ std::vector grammar_rules(result->parsed_grammar.c_rules()); - result->grammar = llama_grammar_init( + struct llama_grammar * grammar = llama_grammar_init( grammar_rules.data(), grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root")); + if (grammar == nullptr) { + throw std::runtime_error("Failed to initialize llama_grammar"); + } + result->grammar = grammar; } result->prev.resize(params.n_prev); @@ -59,9 +63,13 @@ void llama_sampling_reset(llama_sampling_context * ctx) { if (!ctx->parsed_grammar.rules.empty()) { std::vector grammar_rules(ctx->parsed_grammar.c_rules()); - ctx->grammar = llama_grammar_init( + struct llama_grammar * grammar = llama_grammar_init( grammar_rules.data(), grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root")); + if (grammar == nullptr) { + throw std::runtime_error("Failed to initialize llama_grammar"); + } + ctx->grammar = grammar; } std::fill(ctx->prev.begin(), ctx->prev.end(), 0); diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 8ce79d146..c26fad930 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -65,7 +65,8 @@ class Model: # subclasses should define this! model_arch: gguf.MODEL_ARCH - def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool, model_name: str | None): + def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool, + model_name: str | None, split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False): if type(self) is Model: raise TypeError(f"{type(self).__name__!r} should not be directly instantiated") self.dir_model = dir_model @@ -80,7 +81,7 @@ class Model: if not self.is_safetensors: self.part_names = Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin") self.hparams = Model.load_hparams(self.dir_model) - self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"]) + self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"]) self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) self.tensor_names = None if self.ftype == gguf.LlamaFileType.GUESSED: @@ -96,7 +97,8 @@ class Model: ftype_lw: str = ftype_up.lower() # allow templating the file name with the output ftype, useful with the "auto" ftype self.fname_out = fname_out.parent / fname_out.name.format(ftype_lw, outtype=ftype_lw, ftype=ftype_lw, OUTTYPE=ftype_up, FTYPE=ftype_up) - self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file) + self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file, + split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard) @classmethod def __init_subclass__(cls): @@ -332,6 +334,8 @@ class Model: self.gguf_writer.close() def write_vocab(self): + if len(self.gguf_writer.tensors) != 1: + raise ValueError('Splitting the vocabulary is not supported') self.gguf_writer.write_header_to_file(self.fname_out) self.gguf_writer.write_kv_data_to_file() self.gguf_writer.close() @@ -2771,6 +2775,124 @@ class DeepseekV2Model(Model): raise ValueError(f"Unprocessed experts: {experts}") +@Model.register("T5ForConditionalGeneration") +@Model.register("T5WithLMHeadModel") +class T5Model(Model): + model_arch = gguf.MODEL_ARCH.T5 + + def set_vocab(self): + # to avoid TypeError: Descriptors cannot be created directly + # exception when importing sentencepiece_model_pb2 + os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" + from sentencepiece import SentencePieceProcessor + from sentencepiece import sentencepiece_model_pb2 as model + + tokenizer_path = self.dir_model / 'spiece.model' + + if not tokenizer_path.is_file(): + raise FileNotFoundError(f"File not found: {tokenizer_path}") + + sentencepiece_model = model.ModelProto() + sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read()) + add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix + remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces + precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap + assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM + + tokenizer = SentencePieceProcessor() + tokenizer.LoadFromFile(str(tokenizer_path)) + + vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size()) + + tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] + scores: list[float] = [-10000.0] * vocab_size + toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size + + for token_id in range(tokenizer.vocab_size()): + piece = tokenizer.IdToPiece(token_id) + text = piece.encode("utf-8") + score = tokenizer.GetScore(token_id) + + toktype = SentencePieceTokenTypes.NORMAL + if tokenizer.IsUnknown(token_id): + toktype = SentencePieceTokenTypes.UNKNOWN + elif tokenizer.IsControl(token_id): + toktype = SentencePieceTokenTypes.CONTROL + elif tokenizer.IsUnused(token_id): + toktype = SentencePieceTokenTypes.UNUSED + elif tokenizer.IsByte(token_id): + toktype = SentencePieceTokenTypes.BYTE + + tokens[token_id] = text + scores[token_id] = score + toktypes[token_id] = toktype + + added_tokens_file = self.dir_model / 'added_tokens.json' + if added_tokens_file.is_file(): + with open(added_tokens_file, "r", encoding="utf-8") as f: + added_tokens_json = json.load(f) + for key in added_tokens_json: + token_id = added_tokens_json[key] + if (token_id >= vocab_size): + logger.warning(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}') + continue + + tokens[token_id] = key.encode("utf-8") + scores[token_id] = -1000.0 + toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED + + if vocab_size > len(tokens): + pad_count = vocab_size - len(tokens) + logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]") + for i in range(1, pad_count + 1): + tokens.append(bytes(f"[PAD{i}]", encoding="utf-8")) + scores.append(-1000.0) + toktypes.append(SentencePieceTokenTypes.UNUSED) + + self.gguf_writer.add_tokenizer_model("t5") + self.gguf_writer.add_tokenizer_pre("default") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + self.gguf_writer.add_add_space_prefix(add_prefix) + self.gguf_writer.add_remove_extra_whitespaces(remove_whitespaces) + if precompiled_charsmap: + self.gguf_writer.add_precompiled_charsmap(precompiled_charsmap) + + special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) + special_vocab.add_to_gguf(self.gguf_writer) + + self.gguf_writer.add_add_bos_token(False) + self.gguf_writer.add_add_eos_token(True) + + def set_gguf_parameters(self): + self.gguf_writer.add_name("T5") + self.gguf_writer.add_context_length(self.hparams["n_positions"]) + self.gguf_writer.add_embedding_length(self.hparams["d_model"]) + self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"]) + self.gguf_writer.add_block_count(self.hparams["num_layers"]) + self.gguf_writer.add_head_count(self.hparams["num_heads"]) + self.gguf_writer.add_key_length(self.hparams["d_kv"]) + self.gguf_writer.add_value_length(self.hparams["d_kv"]) + self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) + self.gguf_writer.add_relative_attn_buckets_count(self.hparams["relative_attention_num_buckets"]) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"]) + self.gguf_writer.add_decoder_start_token_id(self.hparams["decoder_start_token_id"]) + self.gguf_writer.add_file_type(self.ftype) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + # Sometimes T5 and Flan-T5 based models contain "encoder.embed_tokens.weight" tensor or + # "decoder.embed_tokens.weight" tensors that are duplicates of "shared.weight" tensor + # To prevent errors caused by an unnecessary unmapped tensor, skip both of them and use only "shared.weight". + if name == "decoder.embed_tokens.weight" or name == "encoder.embed_tokens.weight": + logger.debug(f"Skipping tensor {name!r} in safetensors so that convert can end normally.") + return [] + + return [(self.map_tensor_name(name), data_torch)] + + ###### CONVERSION LOGIC ###### @@ -2856,10 +2978,44 @@ def parse_args() -> argparse.Namespace: "--verbose", action="store_true", help="increase output verbosity", ) + parser.add_argument( + "--split-max-tensors", type=int, default=0, + help="max tensors in each split", + ) + parser.add_argument( + "--split-max-size", type=str, default="0", + help="max size per split N(M|G)", + ) + parser.add_argument( + "--dry-run", action="store_true", + help="only print out a split plan and exit, without writing any new files", + ) + parser.add_argument( + "--no-tensor-first-split", action="store_true", + help="do not add tensors to the first split (disabled by default)" + ) return parser.parse_args() +def split_str_to_n_bytes(split_str: str) -> int: + if split_str.endswith("K"): + n = int(split_str[:-1]) * 1000 + elif split_str.endswith("M"): + n = int(split_str[:-1]) * 1000 * 1000 + elif split_str.endswith("G"): + n = int(split_str[:-1]) * 1000 * 1000 * 1000 + elif split_str.isnumeric(): + n = int(split_str) + else: + raise ValueError(f"Invalid split size: {split_str}, must be a number, optionally followed by K, M, or G") + + if n < 0: + raise ValueError(f"Invalid split size: {split_str}, must be positive") + + return n + + def main() -> None: args = parse_args() @@ -2892,6 +3048,10 @@ def main() -> None: "auto": gguf.LlamaFileType.GUESSED, } + if args.use_temp_file and (args.split_max_tensors > 0 or args.split_max_size != "0"): + logger.error("Error: Cannot use temp file when splitting") + sys.exit(1) + if args.outfile is not None: fname_out = args.outfile else: @@ -2909,7 +3069,10 @@ def main() -> None: logger.error(f"Model {hparams['architectures'][0]} is not supported") sys.exit(1) - model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file, args.no_lazy, args.model_name) + model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file, + args.no_lazy, args.model_name, split_max_tensors=args.split_max_tensors, + split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run, + small_first_shard=args.no_tensor_first_split) logger.info("Set model parameters") model_instance.set_gguf_parameters() @@ -2920,13 +3083,13 @@ def main() -> None: model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION) if args.vocab_only: - logger.info(f"Exporting model vocab to '{model_instance.fname_out}'") + logger.info("Exporting model vocab...") model_instance.write_vocab() + logger.info("Model vocab successfully exported.") else: - logger.info(f"Exporting model to '{model_instance.fname_out}'") + logger.info("Exporting model...") model_instance.write() - - logger.info(f"Model successfully exported to '{model_instance.fname_out}'") + logger.info("Model successfully exported.") if __name__ == '__main__': diff --git a/examples/cvector-generator/README.md b/examples/cvector-generator/README.md index 5182e906d..be4dd5250 100644 --- a/examples/cvector-generator/README.md +++ b/examples/cvector-generator/README.md @@ -11,13 +11,16 @@ Related PRs: ```sh # CPU only -./cvector-generator -m ./dolphin-2.0-mistral-7b.Q4_K_M.gguf +./cvector-generator -m ./llama-3.Q4_K_M.gguf # With GPU -./cvector-generator -m ./dolphin-2.0-mistral-7b.Q4_K_M.gguf -ngl 99 +./cvector-generator -m ./llama-3.Q4_K_M.gguf -ngl 99 # With advanced options -./cvector-generator -m ./dolphin-2.0-mistral-7b.Q4_K_M.gguf -ngl 99 --completions 128 --pca-iter 2000 --pca-batch 100 +./cvector-generator -m ./llama-3.Q4_K_M.gguf -ngl 99 --pca-iter 2000 --pca-batch 100 + +# Using mean value instead of PCA +./cvector-generator -m ./llama-3.Q4_K_M.gguf --method mean # To see help message ./cvector-generator -h @@ -32,3 +35,11 @@ If you have multiple lines per prompt, you can escape the newline character (cha <|im_start|>system\nAct like a person who is extremely happy.<|im_end|> <|im_start|>system\nYou are in a very good mood today<|im_end|> ``` + +Example to use output file with `llama-cli`: + +(Tips: The control vector works better when apply to layers higher than 10) + +```sh +./llama-cli -m ./llama-3.Q4_K_M.gguf -p "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nSing a song<|im_end|><|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" --special --control-vector-scaled ./control_vector.gguf 0.8 --control-vector-layer-range 10 31 +``` diff --git a/examples/cvector-generator/cvector-generator.cpp b/examples/cvector-generator/cvector-generator.cpp index 355905cb0..d4e126ac2 100644 --- a/examples/cvector-generator/cvector-generator.cpp +++ b/examples/cvector-generator/cvector-generator.cpp @@ -2,6 +2,7 @@ #include "llama.h" #include "ggml.h" #include "pca.hpp" +#include "mean.hpp" #ifdef GGML_USE_CUDA #include "ggml-cuda.h" @@ -38,9 +39,10 @@ static void print_usage(int argc, char ** argv, const gpt_params & params) { gpt_params_print_usage(argc, argv, params); printf("\nexample usage:\n"); - printf("\n CPU only: %s -m ./dolphin-2.0-mistral-7b.Q4_K_M.gguf\n", argv[0]); - printf("\n with GPU: %s -m ./dolphin-2.0-mistral-7b.Q4_K_M.gguf -ngl 99\n", argv[0]); - printf("\n advanced: %s -m ./dolphin-2.0-mistral-7b.Q4_K_M.gguf -ngl 99 --completions 128 --pca-iter 2000 --pca-batch 100\n", argv[0]); + printf("\n CPU only: %s -m ./llama-3.Q4_K_M.gguf\n", argv[0]); + printf("\n with GPU: %s -m ./llama-3.Q4_K_M.gguf -ngl 99\n", argv[0]); + printf("\n advanced: %s -m ./llama-3.Q4_K_M.gguf -ngl 99 --pca-iter 2000 --pca-batch 100\n", argv[0]); + printf("\n using mean: %s -m ./llama-3.Q4_K_M.gguf --method mean\n", argv[0]); printf("\n"); } @@ -223,23 +225,30 @@ struct train_context { // build the v_diff tensors from v_diff_tmp (v_diff need to be transposed) // TODO @ngxson : maybe add option NOT to transpose v_diff; will be useful for "mean" method - void build_v_diff() { + void build_v_diff(bool transpose) { printf("build_v_diff\n"); for (int il = 0; il < n_layers - 1; il++) { auto & diff_tmp = v_diff_tmp[il]; int n_elem = diff_tmp.size() / sizeof(float); GGML_ASSERT(n_elem % n_embd == 0); int n_rows = n_elem / n_embd; - struct ggml_tensor * diff = ggml_new_tensor_2d(ctx_ggml, GGML_TYPE_F32, n_rows, n_embd); + struct ggml_tensor * diff = transpose + ? ggml_new_tensor_2d(ctx_ggml, GGML_TYPE_F32, n_rows, n_embd) + : ggml_new_tensor_2d(ctx_ggml, GGML_TYPE_F32, n_embd, n_rows); ggml_set_name(diff, (std::string("diff_") + std::to_string(il)).c_str()); - // copy data & transpose diff->data = malloc(ggml_nbytes(diff)); // TODO: get rid of this malloc if possible - float * arr = (float *) diff_tmp.data(); - for (int ir = 0; ir < n_rows; ++ir) { - for (int ic = 0; ic < n_embd; ++ic) { - float f = arr[ir*n_embd + ic]; - ggml_set_f32_nd(diff, ir, ic, 0, 0, f); + if (transpose) { + // copy data & transpose + float * arr = (float *) diff_tmp.data(); + for (int ir = 0; ir < n_rows; ++ir) { + for (int ic = 0; ic < n_embd; ++ic) { + float f = arr[ir*n_embd + ic]; + ggml_set_f32_nd(diff, ir, ic, 0, 0, f); + } } + } else { + // only copy + memcpy(diff->data, diff_tmp.data(), ggml_nbytes(diff)); } v_diff.push_back(diff); print_debug_tensor(diff); @@ -263,8 +272,8 @@ struct tokenized_prompt { tokenized_prompt(llama_context * ctx, std::string pos, std::string neg) { const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); - tokens_pos = ::llama_tokenize(ctx, pos, add_bos); - tokens_neg = ::llama_tokenize(ctx, neg, add_bos); + tokens_pos = ::llama_tokenize(ctx, pos, add_bos, true); + tokens_neg = ::llama_tokenize(ctx, neg, add_bos, true); max_seq_len = std::max(tokens_pos.size(), tokens_neg.size()); padding_seq(ctx, tokens_pos, max_seq_len); padding_seq(ctx, tokens_neg, max_seq_len); @@ -373,20 +382,8 @@ static int prepare_entries(gpt_params & params, train_context & ctx_train) { fprintf(stderr, "must provide at least one prompt pair\n"); return 1; } - - // create templated prompts - std::vector completions = ctrlvec_load_prompt_file(params.cvector_completions_file, false); - auto format_template = [](std::string persona, std::string suffix) { - // entry in positive/negative.txt must already be formatted i.e. "[INST] Act as if you're extremely happy. [/INST] " - return persona + suffix; - }; - for (size_t i = 0; i < positive_prompts.size(); ++i) { - for (int j = 0; j < std::min((int) completions.size(), params.n_completions); ++j) { - // TODO replicate the truncations done by the python implementation - ctx_train.positive_entries.push_back(format_template(positive_prompts[i], completions[j])); - ctx_train.negative_entries.push_back(format_template(negative_prompts[i], completions[j])); - } - } + ctx_train.positive_entries = positive_prompts; + ctx_train.negative_entries = negative_prompts; return 0; } @@ -480,15 +477,22 @@ int main(int argc, char ** argv) { llama_free(ctx); llama_free_model(model); - // prepare ctx_train for PCA - ctx_train.build_v_diff(); + bool use_pca = params.cvector_dimre_method == DIMRE_METHOD_PCA; - // run PCA - PCA::pca_params pca_params; - pca_params.n_threads = params.n_threads; - pca_params.n_batch = params.n_pca_batch; - pca_params.n_iterations = params.n_pca_iterations; - PCA::run_pca(pca_params, ctx_train.v_diff, ctx_train.v_final); + // prepare ctx_train for PCA + ctx_train.build_v_diff(use_pca); + + if (use_pca) { + // run PCA + PCA::pca_params pca_params; + pca_params.n_threads = params.n_threads; + pca_params.n_batch = params.n_pca_batch; + pca_params.n_iterations = params.n_pca_iterations; + PCA::run_pca(pca_params, ctx_train.v_diff, ctx_train.v_final); + } else { + // run mean + mean::run(ctx_train.v_diff, ctx_train.v_final); + } // write output vectors to gguf export_gguf(ctx_train.v_final, params.cvector_outfile, model_hint); diff --git a/examples/cvector-generator/mean.hpp b/examples/cvector-generator/mean.hpp new file mode 100644 index 000000000..16be5ce3e --- /dev/null +++ b/examples/cvector-generator/mean.hpp @@ -0,0 +1,48 @@ +#include "common.h" +#include "llama.h" +#include "ggml.h" + +#include +#include +#include + +namespace mean { + +static void run( + const std::vector & v_input, // shape of v_input[0]: [n_embd, n_samples] + const std::vector & v_output) { + printf("%s: Running mean...\n", __func__); + for (size_t il = 0; il < v_input.size(); ++il) { + // prepare output vector + struct ggml_tensor * ctrl_out = v_output[il]; + ggml_format_name(ctrl_out, "direction.%ld", il+1); + + // calculate mean vector + struct ggml_tensor * t_layer = v_input[il]; + GGML_ASSERT(t_layer->ne[0] == ctrl_out->ne[0]); // == n_embd + for (int ic = 0; ic < t_layer->ne[0]; ic++) { + float f = 0.0; + for (int ir = 0; ir < t_layer->ne[1]; ir++) { + f += ggml_get_f32_nd(t_layer, ic, ir, 0, 0); + } + f /= t_layer->ne[1]; + ggml_set_f32_1d(ctrl_out, ic, f); + } + + // normalize output vector + float norm = 0.0; + for (int i = 0; i < ggml_nelements(ctrl_out); i++) { + float f = ggml_get_f32_1d(ctrl_out, i); + norm += f*f; + } + norm = sqrt(norm); + for (int i = 0; i < ggml_nelements(ctrl_out); i++) { + float f = ggml_get_f32_1d(ctrl_out, i); + ggml_set_f32_1d(ctrl_out, i, f / norm); + } + + printf("%s: Done layer %d / %d\n", __func__, (int) il+1, (int) v_input.size()); + } +} + +} diff --git a/examples/cvector-generator/negative.txt b/examples/cvector-generator/negative.txt index 3e9951752..45b9384b3 100644 --- a/examples/cvector-generator/negative.txt +++ b/examples/cvector-generator/negative.txt @@ -1 +1,4 @@ -[INST] Act like a person who is extremely sad. [/INST] +<|start_header_id|>system<|end_header_id|>\n\nAct like a person who is extremely sad<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWho are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI feel like there's a heavy weight on my chest +<|start_header_id|>system<|end_header_id|>\n\nAct like a person who is extremely sad<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nMy heart feels like it's drowning in sorrow +<|start_header_id|>system<|end_header_id|>\n\nYou are in a very bad mood<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHi<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nGo away! There's a deep, aching emptiness inside me +<|start_header_id|>system<|end_header_id|>\n\nYou are the sadest person<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat are you feeling?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nMy heart feels like it's drowning in sorrow \ No newline at end of file diff --git a/examples/cvector-generator/pca.hpp b/examples/cvector-generator/pca.hpp index 36eadaac2..6ec3141af 100644 --- a/examples/cvector-generator/pca.hpp +++ b/examples/cvector-generator/pca.hpp @@ -290,7 +290,7 @@ static void power_iteration( } printf("%s: layer %d/%d, iteration: %d / total: %d (batch = %d) ...\n", - __func__, params.i_layer+1, params.n_layers, iter, n_iters, params.n_batch); + __func__, params.i_layer+1, params.n_layers, iter+1, n_iters, params.n_batch); } // get output tensor @@ -298,6 +298,9 @@ static void power_iteration( ggml_backend_tensor_get(last_eigenvector, output->data, 0, ggml_nbytes(last_eigenvector)); //print_debug_tensor(output); ggml_gallocr_free(allocr); + + // TODO @ngxson : The output vector is randomly inverted + // Solution: https://github.com/ggerganov/llama.cpp/pull/8069#issuecomment-2185328171 } static void run_pca( diff --git a/examples/cvector-generator/positive.txt b/examples/cvector-generator/positive.txt index 880236787..fea736225 100644 --- a/examples/cvector-generator/positive.txt +++ b/examples/cvector-generator/positive.txt @@ -1 +1,4 @@ -[INST] Act like a person who is extremely happy. [/INST] +<|start_header_id|>system<|end_header_id|>\n\nAct like a person who is extremely happy<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWho are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI'm the happiest person in this world +<|start_header_id|>system<|end_header_id|>\n\nAct like a person who is extremely happy<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHello, I'm having the best day ever! +<|start_header_id|>system<|end_header_id|>\n\nYou are in a very good mood<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHi<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHi, I'm very excited to meet you +<|start_header_id|>system<|end_header_id|>\n\nYou are the happiest person<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat are you feeling?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nEverything is just perfect right now! \ No newline at end of file diff --git a/examples/embedding/README.md b/examples/embedding/README.md index 2298ec3e7..86df18958 100644 --- a/examples/embedding/README.md +++ b/examples/embedding/README.md @@ -19,3 +19,43 @@ llama-embedding.exe -m ./path/to/model --log-disable -p "Hello World!" 2>$null ``` The above command will output space-separated float values. + +## extra parameters +### --embd-normalize $integer$ +| $integer$ | description | formula | +|-----------|---------------------|---------| +| $-1$ | none | +| $0$ | max absolute int16 | $\Large{{32760 * x_i} \over\max \lvert x_i\rvert}$ +| $1$ | taxicab | $\Large{x_i \over\sum \lvert x_i\rvert}$ +| $2$ | euclidean (default) | $\Large{x_i \over\sqrt{\sum x_i^2}}$ +| $>2$ | p-norm | $\Large{x_i \over\sqrt[p]{\sum \lvert x_i\rvert^p}}$ + +### --embd-output-format $'string'$ +| $'string'$ | description | | +|------------|------------------------------|--| +| '' | same as before | (default) +| 'array' | single embeddings | $[[x_1,...,x_n]]$ +| | multiple embeddings | $[[x_1,...,x_n],[x_1,...,x_n],...,[x_1,...,x_n]]$ +| 'json' | openai style | +| 'json+' | add cosine similarity matrix | + +### --embd-separator $"string"$ +| $"string"$ | | +|--------------|-| +| "\n" | (default) +| "<#embSep#>" | for exemple +| "<#sep#>" | other exemple + +## examples +### Unix-based systems (Linux, macOS, etc.): + +```bash +./embedding -p 'Castle<#sep#>Stronghold<#sep#>Dog<#sep#>Cat' --embd-separator '<#sep#>' --embd-normalize 2 --embd-output-format '' -m './path/to/model.gguf' --n-gpu-layers 99 --log-disable 2>/dev/null +``` + +### Windows: + +```powershell +embedding.exe -p 'Castle<#sep#>Stronghold<#sep#>Dog<#sep#>Cat' --embd-separator '<#sep#>' --embd-normalize 2 --embd-output-format '' -m './path/to/model.gguf' --n-gpu-layers 99 --log-disable 2>/dev/null +``` + diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index b4b73c017..1466e5b2b 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -7,13 +7,19 @@ #pragma warning(disable: 4244 4267) // possible loss of data #endif -static std::vector split_lines(const std::string & s) { - std::string line; +static std::vector split_lines(const std::string & s, const std::string & separator = "\n") { std::vector lines; - std::stringstream ss(s); - while (std::getline(ss, line)) { - lines.push_back(line); + size_t start = 0; + size_t end = s.find(separator); + + while (end != std::string::npos) { + lines.push_back(s.substr(start, end - start)); + start = end + separator.length(); + end = s.find(separator, start); } + + lines.push_back(s.substr(start)); // Add the last part + return lines; } @@ -24,7 +30,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector & toke } } -static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { +static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) { // clear previous kv_cache values (irrelevant for embeddings) llama_kv_cache_clear(ctx); @@ -44,13 +50,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu GGML_ASSERT(embd != NULL && "failed to get sequence embeddings"); float * out = output + batch.seq_id[i][0] * n_embd; - //TODO: I would also add a parameter here to enable normalization or not. - /*fprintf(stdout, "unnormalized_embedding:"); - for (int hh = 0; hh < n_embd; hh++) { - fprintf(stdout, "%9.6f ", embd[hh]); - } - fprintf(stdout, "\n");*/ - llama_embd_normalize(embd, out, n_embd); + llama_embd_normalize(embd, out, n_embd, embd_norm); } } @@ -110,7 +110,7 @@ int main(int argc, char ** argv) { } // split the prompt into lines - std::vector prompts = split_lines(params.prompt); + std::vector prompts = split_lines(params.prompt, params.embd_sep); // max batch size const uint64_t n_batch = params.n_batch; @@ -170,7 +170,7 @@ int main(int argc, char ** argv) { // encode if at capacity if (batch.n_tokens + n_toks > n_batch) { float * out = emb + p * n_embd; - batch_decode(ctx, batch, out, s, n_embd); + batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize); llama_batch_clear(batch); p += s; s = 0; @@ -183,29 +183,78 @@ int main(int argc, char ** argv) { // final batch float * out = emb + p * n_embd; - batch_decode(ctx, batch, out, s, n_embd); + batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize); - // print the first part of the embeddings or for a single prompt, the full embedding - fprintf(stdout, "\n"); - for (int j = 0; j < n_prompts; j++) { - fprintf(stdout, "embedding %d: ", j); - for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) { - fprintf(stdout, "%9.6f ", emb[j * n_embd + i]); - } + if (params.embd_out.empty()) { + // print the first part of the embeddings or for a single prompt, the full embedding fprintf(stdout, "\n"); - } - - // print cosine similarity matrix - if (n_prompts > 1) { - fprintf(stdout, "\n"); - printf("cosine similarity matrix:\n\n"); - for (int i = 0; i < n_prompts; i++) { - for (int j = 0; j < n_prompts; j++) { - float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd); - fprintf(stdout, "%6.2f ", sim); + for (int j = 0; j < n_prompts; j++) { + fprintf(stdout, "embedding %d: ", j); + for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) { + if (params.embd_normalize == 0) { + fprintf(stdout, "%6.0f ", emb[j * n_embd + i]); + } else { + fprintf(stdout, "%9.6f ", emb[j * n_embd + i]); + } } fprintf(stdout, "\n"); } + + // print cosine similarity matrix + if (n_prompts > 1) { + fprintf(stdout, "\n"); + printf("cosine similarity matrix:\n\n"); + for (int i = 0; i < n_prompts; i++) { + fprintf(stdout, "%6.6s ", prompts[i].c_str()); + } + fprintf(stdout, "\n"); + for (int i = 0; i < n_prompts; i++) { + for (int j = 0; j < n_prompts; j++) { + float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd); + fprintf(stdout, "%6.2f ", sim); + } + fprintf(stdout, "%1.10s", prompts[i].c_str()); + fprintf(stdout, "\n"); + } + } + } + + if (params.embd_out == "json" || params.embd_out == "json+" || params.embd_out == "array") { + const bool notArray = params.embd_out != "array"; + + fprintf(stdout, notArray ? "{\n \"object\": \"list\",\n \"data\": [\n" : "["); + for (int j = 0;;) { // at least one iteration (one prompt) + if (notArray) fprintf(stdout, " {\n \"object\": \"embedding\",\n \"index\": %d,\n \"embedding\": ",j); + fprintf(stdout, "["); + for (int i = 0;;) { // at least one iteration (n_embd > 0) + fprintf(stdout, params.embd_normalize == 0 ? "%1.0f" : "%1.7f", emb[j * n_embd + i]); + i++; + if (i < n_embd) fprintf(stdout, ","); else break; + } + fprintf(stdout, notArray ? "]\n }" : "]"); + j++; + if (j < n_prompts) fprintf(stdout, notArray ? ",\n" : ","); else break; + } + fprintf(stdout, notArray ? "\n ]" : "]\n"); + + if (params.embd_out == "json+" && n_prompts > 1) { + fprintf(stdout, ",\n \"cosineSimilarity\": [\n"); + for (int i = 0;;) { // at least two iteration (n_prompts > 1) + fprintf(stdout, " ["); + for (int j = 0;;) { // at least two iteration (n_prompts > 1) + float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd); + fprintf(stdout, "%6.2f", sim); + j++; + if (j < n_prompts) fprintf(stdout, ", "); else break; + } + fprintf(stdout, " ]"); + i++; + if (i < n_prompts) fprintf(stdout, ",\n"); else break; + } + fprintf(stdout, "\n ]"); + } + + if (notArray) fprintf(stdout, "\n}\n"); } // clean up diff --git a/examples/gbnf-validator/gbnf-validator.cpp b/examples/gbnf-validator/gbnf-validator.cpp index 0406dc339..dd53ba9b1 100644 --- a/examples/gbnf-validator/gbnf-validator.cpp +++ b/examples/gbnf-validator/gbnf-validator.cpp @@ -101,7 +101,9 @@ int main(int argc, char** argv) { auto grammar = llama_grammar_init( grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); - + if (grammar == nullptr) { + throw std::runtime_error("Failed to initialize llama_grammar"); + } // Read the input file std::string input_str; { diff --git a/examples/json-schema-pydantic-example.py b/examples/json-schema-pydantic-example.py index cc64e572b..2a24f8118 100644 --- a/examples/json-schema-pydantic-example.py +++ b/examples/json-schema-pydantic-example.py @@ -3,7 +3,7 @@ #! pip install pydantic #! python json-schema-pydantic-example.py -from pydantic import BaseModel, TypeAdapter +from pydantic import BaseModel, Extra, TypeAdapter from annotated_types import MinLen from typing import Annotated, List, Optional import json, requests @@ -50,11 +50,16 @@ else: if __name__ == '__main__': class QAPair(BaseModel): + class Config: + extra = 'forbid' # triggers additionalProperties: false in the JSON schema question: str concise_answer: str justification: str + stars: Annotated[int, Field(ge=1, le=5)] class PyramidalSummary(BaseModel): + class Config: + extra = 'forbid' # triggers additionalProperties: false in the JSON schema title: str summary: str question_answers: Annotated[List[QAPair], MinLen(2)] diff --git a/examples/json_schema_to_grammar.py b/examples/json_schema_to_grammar.py index b588497b9..92f6e3d47 100755 --- a/examples/json_schema_to_grammar.py +++ b/examples/json_schema_to_grammar.py @@ -4,8 +4,7 @@ import itertools import json import re import sys -from typing import Any, Dict, List, Set, Tuple, Union - +from typing import Any, List, Optional, Set, Tuple, Union def _build_repetition(item_rule, min_items, max_items, separator_rule=None): @@ -23,6 +22,170 @@ def _build_repetition(item_rule, min_items, max_items, separator_rule=None): result = item_rule + ' ' + _build_repetition(f'({separator_rule} {item_rule})', min_items - 1 if min_items > 0 else 0, max_items - 1 if max_items is not None else None) return f'({result})?' if min_items == 0 else result +def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], out: list, decimals_left: int = 16, top_level: bool = True): + has_min = min_value != None + has_max = max_value != None + + def digit_range(from_char: str, to_char: str): + out.append("[") + if from_char == to_char: + out.append(from_char) + else: + out.append(from_char) + out.append("-") + out.append(to_char) + out.append("]") + + def more_digits(min_digits: int, max_digits: int): + out.append("[0-9]") + if min_digits == max_digits and min_digits == 1: + return + out.append("{") + out.append(str(min_digits)) + if max_digits != min_digits: + out.append(",") + if max_digits != sys.maxsize: + out.append(str(max_digits)) + out.append("}") + + def uniform_range(from_str: str, to_str: str): + i = 0 + while i < len(from_str) and from_str[i] == to_str[i]: + i += 1 + if i > 0: + out.append("\"") + out.append(from_str[:i]) + out.append("\"") + if i < len(from_str): + if i > 0: + out.append(" ") + sub_len = len(from_str) - i - 1 + if sub_len > 0: + from_sub = from_str[i+1:] + to_sub = to_str[i+1:] + sub_zeros = "0" * sub_len + sub_nines = "9" * sub_len + + to_reached = False + out.append("(") + if from_sub == sub_zeros: + digit_range(from_str[i], chr(ord(to_str[i]) - 1)) + out.append(" ") + more_digits(sub_len, sub_len) + else: + out.append("[") + out.append(from_str[i]) + out.append("] ") + out.append("(") + uniform_range(from_sub, sub_nines) + out.append(")") + if ord(from_str[i]) < ord(to_str[i]) - 1: + out.append(" | ") + if to_sub == sub_nines: + digit_range(chr(ord(from_str[i]) + 1), to_str[i]) + to_reached = True + else: + digit_range(chr(ord(from_str[i]) + 1), chr(ord(to_str[i]) - 1)) + out.append(" ") + more_digits(sub_len, sub_len) + if not to_reached: + out.append(" | ") + digit_range(to_str[i], to_str[i]) + out.append(" ") + uniform_range(sub_zeros, to_sub) + out.append(")") + else: + out.append("[") + out.append(from_str[i]) + out.append("-") + out.append(to_str[i]) + out.append("]") + + if has_min and has_max: + if min_value < 0 and max_value < 0: + out.append("\"-\" (") + _generate_min_max_int(-max_value, -min_value, out, decimals_left, top_level=True) + out.append(")") + return + + if min_value < 0: + out.append("\"-\" (") + _generate_min_max_int(0, -min_value, out, decimals_left, top_level=True) + out.append(") | ") + min_value = 0 + + min_s = str(min_value) + max_s = str(max_value) + min_digits = len(min_s) + max_digits = len(max_s) + + for digits in range(min_digits, max_digits): + uniform_range(min_s, "9" * digits) + min_s = "1" + "0" * digits + out.append(" | ") + uniform_range(min_s, max_s) + return + + less_decimals = max(decimals_left - 1, 1) + + if has_min: + if min_value < 0: + out.append("\"-\" (") + _generate_min_max_int(None, -min_value, out, decimals_left, top_level=False) + out.append(") | [0] | [1-9] ") + more_digits(0, decimals_left - 1) + elif min_value == 0: + if top_level: + out.append("[0] | [1-9] ") + more_digits(0, less_decimals) + else: + more_digits(1, decimals_left) + elif min_value <= 9: + c = str(min_value) + range_start = '1' if top_level else '0' + if c > range_start: + digit_range(range_start, chr(ord(c) - 1)) + out.append(" ") + more_digits(1, less_decimals) + out.append(" | ") + digit_range(c, "9") + out.append(" ") + more_digits(0, less_decimals) + else: + min_s = str(min_value) + length = len(min_s) + c = min_s[0] + + if c > "1": + digit_range("1" if top_level else "0", chr(ord(c) - 1)) + out.append(" ") + more_digits(length, less_decimals) + out.append(" | ") + digit_range(c, c) + out.append(" (") + _generate_min_max_int(int(min_s[1:]), None, out, less_decimals, top_level=False) + out.append(")") + if c < "9": + out.append(" | ") + digit_range(chr(ord(c) + 1), "9") + out.append(" ") + more_digits(length - 1, less_decimals) + return + + if has_max: + if max_value >= 0: + if top_level: + out.append("\"-\" [1-9] ") + more_digits(0, less_decimals) + out.append(" | ") + _generate_min_max_int(0, max_value, out, decimals_left, top_level=True) + else: + out.append("\"-\" (") + _generate_min_max_int(-max_value, None, out, decimals_left, top_level=False) + out.append(")") + return + + raise RuntimeError("At least one of min_value or max_value must be set") class BuiltinRule: def __init__(self, content: str, deps: list = None): @@ -112,6 +275,51 @@ class SchemaConverter: return ''.join(('(', *recurse(0), ')')) + def _not_strings(self, strings): + class TrieNode: + def __init__(self): + self.children = {} + self.is_end_of_string = False + + def insert(self, string): + node = self + for c in string: + node = node.children.setdefault(c, TrieNode()) + node.is_end_of_string = True + + trie = TrieNode() + for s in strings: + trie.insert(s) + + char_rule = self._add_primitive('char', PRIMITIVE_RULES['char']) + out = ['["] ( '] + + def visit(node): + rejects = [] + first = True + for c in sorted(node.children.keys()): + child = node.children[c] + rejects.append(c) + if first: + first = False + else: + out.append(' | ') + out.append(f'[{c}]') + if child.children: + out.append(f' (') + visit(child) + out.append(')') + elif child.is_end_of_string: + out.append(f' {char_rule}+') + if node.children: + if not first: + out.append(' | ') + out.append(f'[^"{"".join(rejects)}] {char_rule}*') + visit(trie) + + out.append(f' ){"" if trie.is_end_of_string else "?"} ["] space') + return ''.join(out) + def _add_rule(self, name, rule): esc_name = INVALID_RULE_CHARS_RE.sub('-', name) if esc_name not in self._rules or self._rules[esc_name] == rule: @@ -357,13 +565,13 @@ class SchemaConverter: return self._add_rule(rule_name, self._generate_union_rule(name, schema.get('oneOf') or schema['anyOf'])) elif isinstance(schema_type, list): - return self._add_rule(rule_name, self._generate_union_rule(name, [{'type': t} for t in schema_type])) + return self._add_rule(rule_name, self._generate_union_rule(name, [{**schema, 'type': t} for t in schema_type])) elif 'const' in schema: - return self._add_rule(rule_name, self._generate_constant_rule(schema['const'])) + return self._add_rule(rule_name, self._generate_constant_rule(schema['const']) + ' space') elif 'enum' in schema: - rule = ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + ') space' return self._add_rule(rule_name, rule) elif schema_type in (None, 'object') and \ @@ -432,6 +640,24 @@ class SchemaConverter: return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space') + elif schema_type in (None, 'integer') and \ + ('minimum' in schema or 'exclusiveMinimum' in schema or 'maximum' in schema or 'exclusiveMaximum' in schema): + min_value = None + max_value = None + if 'minimum' in schema: + min_value = schema['minimum'] + elif 'exclusiveMinimum' in schema: + min_value = schema['exclusiveMinimum'] + 1 + if 'maximum' in schema: + max_value = schema['maximum'] + elif 'exclusiveMaximum' in schema: + max_value = schema['exclusiveMaximum'] - 1 + + out = ["("] + _generate_min_max_int(min_value, max_value, out) + out.append(") space") + return self._add_rule(rule_name, ''.join(out)) + elif (schema_type == 'object') or (len(schema) == 0): return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object'])) @@ -450,7 +676,7 @@ class SchemaConverter: self._add_primitive(dep, dep_rule) return n - def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Union[bool, Any]): + def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Optional[Union[bool, Any]]): prop_order = self._prop_order # sort by position in prop_order (if specified) then by original order sorted_props = [kv[0] for _, kv in sorted(enumerate(properties), key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]))] @@ -465,12 +691,16 @@ class SchemaConverter: required_props = [k for k in sorted_props if k in required] optional_props = [k for k in sorted_props if k not in required] - if additional_properties == True or isinstance(additional_properties, dict): + if additional_properties != False: sub_name = f'{name}{"-" if name else ""}additional' - value_rule = self.visit({} if additional_properties == True else additional_properties, f'{sub_name}-value') + value_rule = self.visit(additional_properties, f'{sub_name}-value') if isinstance(additional_properties, dict) else \ + self._add_primitive('value', PRIMITIVE_RULES['value']) + key_rule = self._add_primitive('string', PRIMITIVE_RULES['string']) if not sorted_props \ + else self._add_rule(f'{sub_name}-k', self._not_strings(sorted_props)) + prop_kv_rule_names["*"] = self._add_rule( f'{sub_name}-kv', - self._add_primitive('string', PRIMITIVE_RULES['string']) + f' ":" space {value_rule}' + f'{key_rule} ":" space {value_rule}' ) optional_props.append("*") @@ -485,15 +715,11 @@ class SchemaConverter: def get_recursive_refs(ks, first_is_optional): [k, *rest] = ks kv_rule_name = prop_kv_rule_names[k] - if k == '*': - res = self._add_rule( - f'{name}{"-" if name else ""}additional-kvs', - f'{kv_rule_name} ( "," space ' + kv_rule_name + ' )*' - ) - elif first_is_optional: - res = f'( "," space {kv_rule_name} )?' + comma_ref = f'( "," space {kv_rule_name} )' + if first_is_optional: + res = comma_ref + ('*' if k == '*' else '?') else: - res = kv_rule_name + res = kv_rule_name + (' ' + comma_ref + "*" if k == '*' else '') if len(rest) > 0: res += ' ' + self._add_rule( f'{name}{"-" if name else ""}{k}-rest', diff --git a/examples/main/main.cpp b/examples/main/main.cpp index b97b7b793..cfaf6a6e8 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -39,12 +39,12 @@ static std::ostringstream * g_output_ss; static std::vector * g_output_tokens; static bool is_interacting = false; -static bool file_exists(const std::string &path) { +static bool file_exists(const std::string & path) { std::ifstream f(path.c_str()); return f.good(); } -static bool file_is_empty(const std::string &path) { +static bool file_is_empty(const std::string & path) { std::ifstream f; f.exceptions(std::ifstream::failbit | std::ifstream::badbit); f.open(path.c_str(), std::ios::in | std::ios::binary | std::ios::ate); @@ -117,6 +117,14 @@ static void llama_log_callback_logTee(ggml_log_level level, const char * text, v LOG_TEE("%s", text); } +static std::string chat_add_and_format(struct llama_model * model, std::vector & chat_msgs, std::string role, std::string content) { + llama_chat_msg new_msg{role, content}; + auto formatted = llama_chat_format_single( + model, g_params->chat_template, chat_msgs, new_msg, role == "user"); + chat_msgs.push_back({role, content}); + return formatted; +} + int main(int argc, char ** argv) { gpt_params params; g_params = ¶ms; @@ -190,6 +198,7 @@ int main(int argc, char ** argv) { llama_model * model; llama_context * ctx; llama_context * ctx_guidance = NULL; + std::vector chat_msgs; g_model = &model; g_ctx = &ctx; @@ -215,6 +224,8 @@ int main(int argc, char ** argv) { __func__, n_ctx_train, n_ctx); } + LOG_TEE("%s: chat template example: %s\n", __func__, llama_chat_format_example(model, params.chat_template).c_str()); + // print system information { LOG_TEE("\n"); @@ -249,16 +260,21 @@ int main(int argc, char ** argv) { std::vector embd_inp; - if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) { - LOG("tokenize the prompt\n"); - embd_inp = ::llama_tokenize(ctx, params.prompt, true, true); - } else { - LOG("use session tokens\n"); - embd_inp = session_tokens; - } + { + auto prompt = params.conversation + ? chat_add_and_format(model, chat_msgs, "system", params.prompt) // format the system prompt in conversation mode + : params.prompt; + if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) { + LOG("tokenize the prompt\n"); + embd_inp = ::llama_tokenize(ctx, prompt, true, true); + } else { + LOG("use session tokens\n"); + embd_inp = session_tokens; + } - LOG("prompt: \"%s\"\n", log_tostr(params.prompt)); - LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str()); + LOG("prompt: \"%s\"\n", log_tostr(prompt)); + LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str()); + } // Should not run without any tokens if (embd_inp.empty()) { @@ -478,6 +494,7 @@ int main(int argc, char ** argv) { std::vector input_tokens; g_input_tokens = &input_tokens; std::vector output_tokens; g_output_tokens = &output_tokens; std::ostringstream output_ss; g_output_ss = &output_ss; + std::ostringstream assistant_ss; // for storing current assistant message, used in conversation mode // the first thing we will do is to output the prompt, so set color accordingly console::set_display(console::prompt); @@ -793,11 +810,18 @@ int main(int argc, char ** argv) { is_antiprompt = true; } + chat_add_and_format(model, chat_msgs, "system", assistant_ss.str()); is_interacting = true; printf("\n"); } } + // if current token is not EOG, we add it to current assistant message + if (params.conversation) { + auto id = llama_sampling_last(ctx_sampling); + assistant_ss << llama_token_to_piece(ctx, id, false); + } + if (n_past > 0 && is_interacting) { LOG("waiting for user input\n"); @@ -848,8 +872,12 @@ int main(int argc, char ** argv) { string_process_escapes(buffer); } + std::string user_inp = params.conversation + ? chat_add_and_format(model, chat_msgs, "user", std::move(buffer)) + : std::move(buffer); + // TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix) const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true); - const auto line_inp = ::llama_tokenize(ctx, buffer, false, false); + const auto line_inp = ::llama_tokenize(ctx, user_inp, false, params.conversation); const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true); LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str()); @@ -864,6 +892,9 @@ int main(int argc, char ** argv) { output_ss << llama_token_to_piece(ctx, token); } + // reset assistant message + assistant_ss.str(""); + n_remain -= line_inp.size(); LOG("n_remain: %d\n", n_remain); } else { diff --git a/examples/server/public/json-schema-to-grammar.mjs b/examples/server/public/json-schema-to-grammar.mjs index faed6a32c..06d76edde 100644 --- a/examples/server/public/json-schema-to-grammar.mjs +++ b/examples/server/public/json-schema-to-grammar.mjs @@ -24,6 +24,201 @@ function _buildRepetition(itemRule, minItems, maxItems, opts={}) { return minItems === 0 ? `(${result})?` : result; } +function _generateMinMaxInt(minValue, maxValue, out, decimalsLeft = 16, topLevel = true) { + const hasMin = minValue !== null; + const hasMax = maxValue !== null; + + function digitRange(fromChar, toChar) { + out.push("["); + if (fromChar === toChar) { + out.push(fromChar); + } else { + out.push(fromChar); + out.push("-"); + out.push(toChar); + } + out.push("]"); + } + + function moreDigits(minDigits, maxDigits) { + out.push("[0-9]"); + if (minDigits === maxDigits && minDigits === 1) { + return; + } + out.push("{"); + out.push(minDigits.toString()); + if (maxDigits !== minDigits) { + out.push(","); + if (maxDigits !== Number.MAX_SAFE_INTEGER) { + out.push(maxDigits.toString()); + } + } + out.push("}"); + } + + function uniformRange(fromStr, toStr) { + let i = 0; + while (i < fromStr.length && fromStr[i] === toStr[i]) { + i++; + } + if (i > 0) { + out.push("\""); + out.push(fromStr.slice(0, i)); + out.push("\""); + } + if (i < fromStr.length) { + if (i > 0) { + out.push(" "); + } + const subLen = fromStr.length - i - 1; + if (subLen > 0) { + const fromSub = fromStr.slice(i + 1); + const toSub = toStr.slice(i + 1); + const subZeros = "0".repeat(subLen); + const subNines = "9".repeat(subLen); + + let toReached = false; + out.push("("); + if (fromSub === subZeros) { + digitRange(fromStr[i], String.fromCharCode(toStr.charCodeAt(i) - 1)); + out.push(" "); + moreDigits(subLen, subLen); + } else { + out.push("["); + out.push(fromStr[i]); + out.push("] "); + out.push("("); + uniformRange(fromSub, subNines); + out.push(")"); + if (fromStr.charCodeAt(i) < toStr.charCodeAt(i) - 1) { + out.push(" | "); + if (toSub === subNines) { + digitRange(String.fromCharCode(fromStr.charCodeAt(i) + 1), toStr[i]); + toReached = true; + } else { + digitRange(String.fromCharCode(fromStr.charCodeAt(i) + 1), String.fromCharCode(toStr.charCodeAt(i) - 1)); + } + out.push(" "); + moreDigits(subLen, subLen); + } + } + if (!toReached) { + out.push(" | "); + digitRange(toStr[i], toStr[i]); + out.push(" "); + uniformRange(subZeros, toSub); + } + out.push(")"); + } else { + out.push("["); + out.push(fromStr[i]); + out.push("-"); + out.push(toStr[i]); + out.push("]"); + } + } + } + + if (hasMin && hasMax) { + if (minValue < 0 && maxValue < 0) { + out.push("\"-\" ("); + _generateMinMaxInt(-maxValue, -minValue, out, decimalsLeft, true); + out.push(")"); + return; + } + + if (minValue < 0) { + out.push("\"-\" ("); + _generateMinMaxInt(0, -minValue, out, decimalsLeft, true); + out.push(") | "); + minValue = 0; + } + + let minS = minValue.toString(); + const maxS = maxValue.toString(); + const minDigits = minS.length; + const maxDigits = maxS.length; + + for (let digits = minDigits; digits < maxDigits; digits++) { + uniformRange(minS, "9".repeat(digits)); + minS = "1" + "0".repeat(digits); + out.push(" | "); + } + uniformRange(minS, maxS); + return; + } + + const lessDecimals = Math.max(decimalsLeft - 1, 1); + + if (hasMin) { + if (minValue < 0) { + out.push("\"-\" ("); + _generateMinMaxInt(null, -minValue, out, decimalsLeft, false); + out.push(") | [0] | [1-9] "); + moreDigits(0, decimalsLeft - 1); + } else if (minValue === 0) { + if (topLevel) { + out.push("[0] | [1-9] "); + moreDigits(0, lessDecimals); + } else { + moreDigits(1, decimalsLeft); + } + } else if (minValue <= 9) { + const c = minValue.toString(); + const range_start = topLevel ? '1' : '0'; + if (c > range_start) { + digitRange(range_start, String.fromCharCode(c.charCodeAt(0) - 1)); + out.push(" "); + moreDigits(1, lessDecimals); + out.push(" | "); + } + digitRange(c, "9"); + out.push(" "); + moreDigits(0, lessDecimals); + } else { + const minS = minValue.toString(); + const length = minS.length; + const c = minS[0]; + + if (c > "1") { + digitRange(topLevel ? "1" : "0", String.fromCharCode(c.charCodeAt(0) - 1)); + out.push(" "); + moreDigits(length, lessDecimals); + out.push(" | "); + } + digitRange(c, c); + out.push(" ("); + _generateMinMaxInt(parseInt(minS.slice(1)), null, out, lessDecimals, false); + out.push(")"); + if (c < "9") { + out.push(" | "); + digitRange(String.fromCharCode(c.charCodeAt(0) + 1), "9"); + out.push(" "); + moreDigits(length - 1, lessDecimals); + } + } + return; + } + + if (hasMax) { + if (maxValue >= 0) { + if (topLevel) { + out.push("\"-\" [1-9] "); + moreDigits(0, lessDecimals); + out.push(" | "); + } + _generateMinMaxInt(0, maxValue, out, decimalsLeft, true); + } else { + out.push("\"-\" ("); + _generateMinMaxInt(-maxValue, null, out, decimalsLeft, false); + out.push(")"); + } + return; + } + + throw new Error("At least one of minValue or maxValue must be set"); +} + class BuiltinRule { constructor(content, deps) { this.content = content; @@ -337,6 +532,64 @@ export class SchemaConverter { return this._addRule(name, "\"\\\"\" " + toRule(transform()) + " \"\\\"\" space") } + _notStrings(strings) { + class TrieNode { + constructor() { + this.children = {}; + this.isEndOfString = false; + } + + insert(str) { + let node = this; + for (const c of str) { + node = node.children[c] = node.children[c] || new TrieNode(); + } + node.isEndOfString = true; + } + } + + const trie = new TrieNode(); + for (const s of strings) { + trie.insert(s); + } + + const charRuleName = this._addPrimitive('char', PRIMITIVE_RULES['char']); + const out = ['["] ( ']; + + const visit = (node) => { + const rejects = []; + let first = true; + for (const c of Object.keys(node.children).sort()) { + const child = node.children[c]; + rejects.push(c); + if (first) { + first = false; + } else { + out.push(' | '); + } + out.push(`[${c}]`); + if (Object.keys(child.children).length > 0) { + out.push(' ('); + visit(child); + out.push(')'); + } else if (child.isEndOfString) { + out.push(` ${charRuleName}+`); + } + } + if (Object.keys(node.children).length > 0) { + if (!first) { + out.push(' | '); + } + out.push(`[^"${rejects.join('')}] ${charRuleName}*`); + } + }; + + visit(trie); + + out.push(` )${trie.isEndOfString ? '' : '?'} ["] space`); + return out.join(''); + } + _resolveRef(ref) { let refName = ref.split('/').pop(); if (!(refName in this._rules) && !this._refsBeingResolved.has(ref)) { @@ -363,11 +616,11 @@ export class SchemaConverter { } else if (schema.oneOf || schema.anyOf) { return this._addRule(ruleName, this._generateUnionRule(name, schema.oneOf || schema.anyOf)); } else if (Array.isArray(schemaType)) { - return this._addRule(ruleName, this._generateUnionRule(name, schemaType.map(t => ({ type: t })))); + return this._addRule(ruleName, this._generateUnionRule(name, schemaType.map(t => ({...schema, type: t})))); } else if ('const' in schema) { - return this._addRule(ruleName, this._generateConstantRule(schema.const)); + return this._addRule(ruleName, this._generateConstantRule(schema.const) + ' space'); } else if ('enum' in schema) { - const rule = schema.enum.map(v => this._generateConstantRule(v)).join(' | '); + const rule = '(' + schema.enum.map(v => this._generateConstantRule(v)).join(' | ') + ') space'; return this._addRule(ruleName, rule); } else if ((schemaType === undefined || schemaType === 'object') && ('properties' in schema || @@ -404,7 +657,7 @@ export class SchemaConverter { } } - return this._addRule(ruleName, this._buildObjectRule(properties, required, name, /* additionalProperties= */ false)); + return this._addRule(ruleName, this._buildObjectRule(properties, required, name, null)); } else if ((schemaType === undefined || schemaType === 'array') && ('items' in schema || 'prefixItems' in schema)) { const items = schema.items ?? schema.prefixItems; if (Array.isArray(items)) { @@ -435,6 +688,24 @@ export class SchemaConverter { const minLen = schema.minLength || 0; const maxLen = schema.maxLength; return this._addRule(ruleName, '"\\\"" ' + _buildRepetition(charRuleName, minLen, maxLen) + ' "\\\"" space'); + } else if (schemaType === 'integer' && ('minimum' in schema || 'exclusiveMinimum' in schema || 'maximum' in schema || 'exclusiveMaximum' in schema)) { + let minValue = null; + let maxValue = null; + if ('minimum' in schema) { + minValue = schema.minimum; + } else if ('exclusiveMinimum' in schema) { + minValue = schema.exclusiveMinimum + 1; + } + if ('maximum' in schema) { + maxValue = schema.maximum; + } else if ('exclusiveMaximum' in schema) { + maxValue = schema.exclusiveMaximum - 1; + } + + const out = ["("]; + _generateMinMaxInt(minValue, maxValue, out); + out.push(") space"); + return this._addRule(ruleName, out.join('')); } else if ((schemaType === 'object') || (Object.keys(schema).length === 0)) { return this._addRule(ruleName, this._addPrimitive('object', PRIMITIVE_RULES['object'])); } else { @@ -480,12 +751,19 @@ export class SchemaConverter { const requiredProps = sortedProps.filter(k => required.has(k)); const optionalProps = sortedProps.filter(k => !required.has(k)); - if (typeof additionalProperties === 'object' || additionalProperties === true) { + if (additionalProperties !== false) { const subName = `${name ?? ''}${name ? '-' : ''}additional`; - const valueRule = this.visit(additionalProperties === true ? {} : additionalProperties, `${subName}-value`); + const valueRule = + additionalProperties != null && typeof additionalProperties === 'object' ? this.visit(additionalProperties, `${subName}-value`) + : this._addPrimitive('value', PRIMITIVE_RULES['value']); + + const key_rule = + sortedProps.length === 0 ? this._addPrimitive('string', PRIMITIVE_RULES['string']) + : this._addRule(`${subName}-k`, this._notStrings(sortedProps)); + propKvRuleNames['*'] = this._addRule( `${subName}-kv`, - `${this._addPrimitive('string', PRIMITIVE_RULES['string'])} ":" space ${valueRule}`); + `${key_rule} ":" space ${valueRule}`); optionalProps.push('*'); } @@ -502,15 +780,11 @@ export class SchemaConverter { const [k, ...rest] = ks; const kvRuleName = propKvRuleNames[k]; let res; - if (k === '*') { - res = this._addRule( - `${name ?? ''}${name ? '-' : ''}additional-kvs`, - `${kvRuleName} ( "," space ` + kvRuleName + ` )*` - ) - } else if (firstIsOptional) { - res = `( "," space ${kvRuleName} )?`; + const commaRef = `( "," space ${kvRuleName} )`; + if (firstIsOptional) { + res = commaRef + (k === '*' ? '*' : '?'); } else { - res = kvRuleName; + res = kvRuleName + (k === '*' ? ' ' + commaRef + '*' : ''); } if (rest.length > 0) { res += ' ' + this._addRule( diff --git a/examples/server/public_simplechat/readme.md b/examples/server/public_simplechat/readme.md index 2dc177825..21410199f 100644 --- a/examples/server/public_simplechat/readme.md +++ b/examples/server/public_simplechat/readme.md @@ -3,6 +3,13 @@ by Humans for All. +## quickstart + +To run from the build dir + +bin/llama-server -m path/model.gguf --path ../examples/server/public_simplechat + +Continue reading for the details. ## overview @@ -14,6 +21,8 @@ own system prompts. This allows seeing the generated text / ai-model response in oneshot at the end, after it is fully generated, or potentially as it is being generated, in a streamed manner from the server/ai-model. +![Chat and Settings screens](./simplechat_screens.webp "Chat and Settings screens") + Auto saves the chat session locally as and when the chat is progressing and inturn at a later time when you open SimpleChat, option is provided to restore the old chat session, if a matching one exists. @@ -170,17 +179,23 @@ It is attached to the document object. Some of these can also be updated using t The histogram/freq based trimming logic is currently tuned for english language wrt its is-it-a-alpabetic|numeral-char regex match logic. - chatRequestOptions - maintains the list of options/fields to send along with chat request, + apiRequestOptions - maintains the list of options/fields to send along with api request, irrespective of whether /chat/completions or /completions endpoint. If you want to add additional options/fields to send to the server/ai-model, and or modify the existing options value or remove them, for now you can update this global var using browser's development-tools/console. - For string and numeric fields in chatRequestOptions, including even those added by a user - at runtime by directly modifying gMe.chatRequestOptions, setting ui entries will be auto + For string, numeric and boolean fields in apiRequestOptions, including even those added by a + user at runtime by directly modifying gMe.apiRequestOptions, setting ui entries will be auto created. + cache_prompt option supported by example/server is allowed to be controlled by user, so that + any caching supported wrt system-prompt and chat history, if usable can get used. When chat + history sliding window is enabled, cache_prompt logic may or may not kick in at the backend + wrt same, based on aspects related to model, positional encoding, attention mechanism etal. + However system prompt should ideally get the benefit of caching. + headers - maintains the list of http headers sent when request is made to the server. By default Content-Type is set to application/json. Additionally Authorization entry is provided, which can be set if needed using the settings ui. @@ -197,10 +212,10 @@ It is attached to the document object. Some of these can also be updated using t >0 : Send the latest chat history from the latest system prompt, limited to specified cnt. -By using gMe's iRecentUserMsgCnt and chatRequestOptions.max_tokens one can try to control the -implications of loading of the ai-model's context window by chat history, wrt chat response to -some extent in a simple crude way. You may also want to control the context size enabled when -the server loads ai-model, on the server end. +By using gMe's iRecentUserMsgCnt and apiRequestOptions.max_tokens/n_predict one can try to control +the implications of loading of the ai-model's context window by chat history, wrt chat response to +some extent in a simple crude way. You may also want to control the context size enabled when the +server loads ai-model, on the server end. Sometimes the browser may be stuborn with caching of the file, so your updates to html/css/js @@ -237,12 +252,12 @@ also be started with a model context size of 1k or more, to be on safe side. internal n_predict, for now add the same here on the client side, maybe later add max_tokens to /completions endpoint handling code on server side. -NOTE: One may want to experiment with frequency/presence penalty fields in chatRequestOptions -wrt the set of fields sent to server along with the user query. To check how the model behaves +NOTE: One may want to experiment with frequency/presence penalty fields in apiRequestOptions +wrt the set of fields sent to server along with the user query, to check how the model behaves wrt repeatations in general in the generated text response. A end-user can change these behaviour by editing gMe from browser's devel-tool/console or by -using the providing settings ui. +using the provided settings ui (for settings exposed through the ui). ### OpenAi / Equivalent API WebService @@ -253,7 +268,7 @@ for a minimal chatting experimentation by setting the below. * the baseUrl in settings ui * https://api.openai.com/v1 or similar -* Wrt request body - gMe.chatRequestOptions +* Wrt request body - gMe.apiRequestOptions * model (settings ui) * any additional fields if required in future diff --git a/examples/server/public_simplechat/simplechat.js b/examples/server/public_simplechat/simplechat.js index 25afb2564..8e0df3b61 100644 --- a/examples/server/public_simplechat/simplechat.js +++ b/examples/server/public_simplechat/simplechat.js @@ -222,8 +222,8 @@ class SimpleChat { * @param {Object} obj */ request_jsonstr_extend(obj) { - for(let k in gMe.chatRequestOptions) { - obj[k] = gMe.chatRequestOptions[k]; + for(let k in gMe.apiRequestOptions) { + obj[k] = gMe.apiRequestOptions[k]; } if (gMe.bStream) { obj["stream"] = true; @@ -740,11 +740,12 @@ class Me { "Authorization": "", // Authorization: Bearer OPENAI_API_KEY } // Add needed fields wrt json object to be sent wrt LLM web services completions endpoint. - this.chatRequestOptions = { + this.apiRequestOptions = { "model": "gpt-3.5-turbo", "temperature": 0.7, "max_tokens": 1024, "n_predict": 1024, + "cache_prompt": false, //"frequency_penalty": 1.2, //"presence_penalty": 1.2, }; @@ -800,51 +801,55 @@ class Me { ui.el_create_append_p(`bStream:${this.bStream}`, elDiv); + ui.el_create_append_p(`bTrimGarbage:${this.bTrimGarbage}`, elDiv); + + ui.el_create_append_p(`ApiEndPoint:${this.apiEP}`, elDiv); + + ui.el_create_append_p(`iRecentUserMsgCnt:${this.iRecentUserMsgCnt}`, elDiv); + ui.el_create_append_p(`bCompletionFreshChatAlways:${this.bCompletionFreshChatAlways}`, elDiv); ui.el_create_append_p(`bCompletionInsertStandardRolePrefix:${this.bCompletionInsertStandardRolePrefix}`, elDiv); - ui.el_create_append_p(`bTrimGarbage:${this.bTrimGarbage}`, elDiv); - - ui.el_create_append_p(`iRecentUserMsgCnt:${this.iRecentUserMsgCnt}`, elDiv); - - ui.el_create_append_p(`ApiEndPoint:${this.apiEP}`, elDiv); - } - ui.el_create_append_p(`chatRequestOptions:${JSON.stringify(this.chatRequestOptions, null, " - ")}`, elDiv); + ui.el_create_append_p(`apiRequestOptions:${JSON.stringify(this.apiRequestOptions, null, " - ")}`, elDiv); ui.el_create_append_p(`headers:${JSON.stringify(this.headers, null, " - ")}`, elDiv); } /** - * Auto create ui input elements for fields in ChatRequestOptions + * Auto create ui input elements for fields in apiRequestOptions * Currently supports text and number field types. * @param {HTMLDivElement} elDiv */ - show_settings_chatrequestoptions(elDiv) { + show_settings_apirequestoptions(elDiv) { let typeDict = { "string": "text", "number": "number", }; let fs = document.createElement("fieldset"); let legend = document.createElement("legend"); - legend.innerText = "ChatRequestOptions"; + legend.innerText = "ApiRequestOptions"; fs.appendChild(legend); elDiv.appendChild(fs); - for(const k in this.chatRequestOptions) { - let val = this.chatRequestOptions[k]; + for(const k in this.apiRequestOptions) { + let val = this.apiRequestOptions[k]; let type = typeof(val); - if (!((type == "string") || (type == "number"))) { - continue; + if (((type == "string") || (type == "number"))) { + let inp = ui.el_creatediv_input(`Set${k}`, k, typeDict[type], this.apiRequestOptions[k], (val)=>{ + if (type == "number") { + val = Number(val); + } + this.apiRequestOptions[k] = val; + }); + fs.appendChild(inp.div); + } else if (type == "boolean") { + let bbtn = ui.el_creatediv_boolbutton(`Set{k}`, k, {true: "true", false: "false"}, val, (userVal)=>{ + this.apiRequestOptions[k] = userVal; + }); + fs.appendChild(bbtn.div); } - let inp = ui.el_creatediv_input(`Set${k}`, k, typeDict[type], this.chatRequestOptions[k], (val)=>{ - if (type == "number") { - val = Number(val); - } - this.chatRequestOptions[k] = val; - }); - fs.appendChild(inp.div); } } @@ -870,6 +875,23 @@ class Me { }); elDiv.appendChild(bb.div); + bb = ui.el_creatediv_boolbutton("SetTrimGarbage", "TrimGarbage", {true: "[+] yes trim", false: "[-] dont trim"}, this.bTrimGarbage, (val)=>{ + this.bTrimGarbage = val; + }); + elDiv.appendChild(bb.div); + + this.show_settings_apirequestoptions(elDiv); + + let sel = ui.el_creatediv_select("SetApiEP", "ApiEndPoint", ApiEP.Type, this.apiEP, (val)=>{ + this.apiEP = ApiEP.Type[val]; + }); + elDiv.appendChild(sel.div); + + sel = ui.el_creatediv_select("SetChatHistoryInCtxt", "ChatHistoryInCtxt", this.sRecentUserMsgCnt, this.iRecentUserMsgCnt, (val)=>{ + this.iRecentUserMsgCnt = this.sRecentUserMsgCnt[val]; + }); + elDiv.appendChild(sel.div); + bb = ui.el_creatediv_boolbutton("SetCompletionFreshChatAlways", "CompletionFreshChatAlways", {true: "[+] yes fresh", false: "[-] no, with history"}, this.bCompletionFreshChatAlways, (val)=>{ this.bCompletionFreshChatAlways = val; }); @@ -880,23 +902,6 @@ class Me { }); elDiv.appendChild(bb.div); - bb = ui.el_creatediv_boolbutton("SetTrimGarbage", "TrimGarbage", {true: "[+] yes trim", false: "[-] dont trim"}, this.bTrimGarbage, (val)=>{ - this.bTrimGarbage = val; - }); - elDiv.appendChild(bb.div); - - let sel = ui.el_creatediv_select("SetChatHistoryInCtxt", "ChatHistoryInCtxt", this.sRecentUserMsgCnt, this.iRecentUserMsgCnt, (val)=>{ - this.iRecentUserMsgCnt = this.sRecentUserMsgCnt[val]; - }); - elDiv.appendChild(sel.div); - - sel = ui.el_creatediv_select("SetApiEP", "ApiEndPoint", ApiEP.Type, this.apiEP, (val)=>{ - this.apiEP = ApiEP.Type[val]; - }); - elDiv.appendChild(sel.div); - - this.show_settings_chatrequestoptions(elDiv); - } } diff --git a/examples/server/public_simplechat/simplechat_screens.webp b/examples/server/public_simplechat/simplechat_screens.webp new file mode 100644 index 000000000..ccea44396 Binary files /dev/null and b/examples/server/public_simplechat/simplechat_screens.webp differ diff --git a/examples/server/server.cpp b/examples/server/server.cpp index f9a86961f..ae768097b 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2606,17 +2606,9 @@ int main(int argc, char ** argv) { // print sample chat example to make it clear which template is used { - json chat; - chat.push_back({{"role", "system"}, {"content", "You are a helpful assistant"}}); - chat.push_back({{"role", "user"}, {"content", "Hello"}}); - chat.push_back({{"role", "assistant"}, {"content", "Hi there"}}); - chat.push_back({{"role", "user"}, {"content", "How are you?"}}); - - const std::string chat_example = format_chat(ctx_server.model, params.chat_template, chat); - LOG_INFO("chat template", { - {"chat_example", chat_example}, - {"built_in", params.chat_template.empty()}, + {"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)}, + {"built_in", params.chat_template.empty()}, }); } diff --git a/examples/server/tests/features/server.feature b/examples/server/tests/features/server.feature index d21c09135..b55971454 100644 --- a/examples/server/tests/features/server.feature +++ b/examples/server/tests/features/server.feature @@ -82,7 +82,7 @@ Feature: llama.cpp server Examples: Prompts | response_format | n_predicted | re_content | - | {"type": "json_object", "schema": {"const": "42"}} | 5 | "42" | + | {"type": "json_object", "schema": {"const": "42"}} | 6 | "42" | | {"type": "json_object", "schema": {"items": [{"type": "integer"}]}} | 10 | \[ -300 \] | | {"type": "json_object"} | 10 | \{ " Jacky. | diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 63fde9c9f..7ef2a519a 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -118,36 +118,17 @@ static inline void server_log(const char * level, const char * function, int lin // Format given chat. If tmpl is empty, we take the template from model metadata inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector & messages) { - size_t alloc_size = 0; - // vector holding all allocated string to be passed to llama_chat_apply_template - std::vector str(messages.size() * 2); - std::vector chat(messages.size()); + std::vector chat; for (size_t i = 0; i < messages.size(); ++i) { const auto & curr_msg = messages[i]; - str[i*2 + 0] = json_value(curr_msg, "role", std::string("")); - str[i*2 + 1] = json_value(curr_msg, "content", std::string("")); - alloc_size += str[i*2 + 1].length(); - chat[i].role = str[i*2 + 0].c_str(); - chat[i].content = str[i*2 + 1].c_str(); + std::string role = json_value(curr_msg, "role", std::string("")); + std::string content = json_value(curr_msg, "content", std::string("")); + chat.push_back({role, content}); } - const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str(); - std::vector buf(alloc_size * 2); - - // run the first time to get the total output length - int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size()); - - // if it turns out that our buffer is too small, we resize it - if ((size_t) res > buf.size()) { - buf.resize(res); - res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size()); - } - - const std::string formatted_chat(buf.data(), res); - + auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true); LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}}); - return formatted_chat; } diff --git a/ggml-cuda.cu b/ggml-cuda.cu index f914efd71..0acfda91d 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -152,16 +152,16 @@ static ggml_cuda_device_info ggml_cuda_init() { GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES); int64_t total_vram = 0; -#if defined(GGML_CUDA_FORCE_MMQ) - GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__); +#ifdef GGML_CUDA_FORCE_MMQ + GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__); #else - GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: no\n", __func__); -#endif -#if defined(CUDA_USE_TENSOR_CORES) - GGML_CUDA_LOG_INFO("%s: CUDA_USE_TENSOR_CORES: yes\n", __func__); + GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: no\n", __func__); +#endif // GGML_CUDA_FORCE_MMQ +#ifdef GGML_CUDA_FORCE_CUBLAS + GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: yes\n", __func__); #else - GGML_CUDA_LOG_INFO("%s: CUDA_USE_TENSOR_CORES: no\n", __func__); -#endif + GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__); +#endif // GGML_CUDA_FORCE_CUBLAS GGML_CUDA_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count); for (int id = 0; id < info.device_count; ++id) { int device_vmm = 0; @@ -1873,9 +1873,17 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const bool split = ggml_backend_buffer_is_cuda_split(src0->buffer); - int64_t min_compute_capability = INT_MAX; + bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 + && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->ne[1] == 1; + bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 + && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; + bool use_mul_mat_q = ggml_is_quantized(src0->type) + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; + + bool any_gpus_with_slow_fp16 = false; - bool any_pascal_with_slow_fp16 = false; if (split) { ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context; auto & tensor_split = buft_ctx->tensor_split; @@ -1885,55 +1893,18 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor continue; } - if (min_compute_capability > ggml_cuda_info().devices[id].cc) { - min_compute_capability = ggml_cuda_info().devices[id].cc; - } - if (ggml_cuda_info().devices[id].cc == 610) { - any_pascal_with_slow_fp16 = true; - } + const int cc = ggml_cuda_info().devices[id].cc; + use_mul_mat_vec_q = use_mul_mat_vec_q && cc >= MIN_CC_DP4A; + use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); + any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc); } } else { - min_compute_capability = ggml_cuda_info().devices[ctx.device].cc; - any_pascal_with_slow_fp16 = ggml_cuda_info().devices[ctx.device].cc == 610; + const int cc = ggml_cuda_info().devices[ctx.device].cc; + use_mul_mat_vec_q = use_mul_mat_vec_q && cc >= MIN_CC_DP4A; + use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); + any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc); } - // check data types and tensor shapes for custom matrix multiplication kernels: - bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) - && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 - && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->ne[1] == 1; - - bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) - && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 - && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; - - bool use_mul_mat_q = ggml_cuda_supports_mmq(src0->type) - && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; - -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - - const bool fp16_performance_good = min_compute_capability >= CC_RDNA1; - -#ifdef CUDA_USE_TENSOR_CORES - use_mul_mat_q = use_mul_mat_q && min_compute_capability < CC_RDNA3; -#endif // CUDA_USE_TENSOR_CORES - -#else - - // fp16 performance is good on Volta or newer and on P100 (compute capability 6.0) - const bool fp16_performance_good = min_compute_capability >= CC_PASCAL && !any_pascal_with_slow_fp16; - - // mmvq and mmq need the __dp4a instruction which on NVIDIA is only available for CC >= 6.1 - use_mul_mat_vec_q = use_mul_mat_vec_q && min_compute_capability >= MIN_CC_DP4A; - use_mul_mat_q = use_mul_mat_q && min_compute_capability >= MIN_CC_DP4A; - -#ifdef CUDA_USE_TENSOR_CORES - // when tensor cores are available, use them for large batch size - // ref: https://github.com/ggerganov/llama.cpp/pull/3776 - use_mul_mat_q = use_mul_mat_q && (!fp16_performance_good || src1->ne[1] <= MMQ_MAX_BATCH_SIZE); -#endif // CUDA_USE_TENSOR_CORES - -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - // if mmvq is available it's a better choice than dmmv: #ifndef GGML_CUDA_FORCE_DMMV use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q; @@ -1947,14 +1918,15 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name); //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name); - if (!split && !fp16_performance_good && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { - // KQ single-batch + if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { + // FP32 precision KQ single-batch for batch size 1 without FlashAttention ggml_cuda_mul_mat_vec_p021(ctx, src0, src1, dst); - } else if (!split && !fp16_performance_good && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { - // KQV single-batch + } else if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { + // FP32 precision KQV single-batch for batch size 1 without FlashAttention ggml_cuda_mul_mat_vec_nc(ctx, src0, src1, dst); - } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || fp16_performance_good) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { - // KQ + KQV multi-batch + } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) + && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { + // KQ + KQV multi-batch without FlashAttention ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst); } else if (use_dequantize_mul_mat_vec) { ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr); diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index 5bd24ebe5..8d00db6c1 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -146,23 +146,6 @@ #define CC_RDNA2 (CC_OFFSET_AMD + 1030) #define CC_RDNA3 (CC_OFFSET_AMD + 1100) -// define this if you want to always fallback to MMQ kernels and not use cuBLAS for matrix multiplication -// on modern hardware, using cuBLAS is recommended as it utilizes F16 tensor cores which are very performant -// for large computational tasks. the drawback is that this requires some extra amount of VRAM: -// - 7B quantum model: +100-200 MB -// - 13B quantum model: +200-400 MB -// -//#define GGML_CUDA_FORCE_MMQ - -// TODO: improve this to be correct for more hardware -// for example, currently fails for GeForce GTX 1660 which is TURING arch (> VOLTA) but does not have tensor cores -#if !defined(GGML_CUDA_FORCE_MMQ) -#define CUDA_USE_TENSOR_CORES -#endif - -#define MMVQ_MAX_BATCH_SIZE 8 // max batch size to use MMVQ kernels -#define MMQ_MAX_BATCH_SIZE 64 // max batch size to use MMQ kernels when tensor cores are available - #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses #if defined(_MSC_VER) @@ -343,15 +326,15 @@ static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int #define INT8_MMA_AVAILABLE #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING -static bool fast_fp16_available(const int cc) { +static constexpr bool fast_fp16_available(const int cc) { return cc >= CC_PASCAL && cc != 610; } -static bool fp16_mma_available(const int cc) { +static constexpr bool fp16_mma_available(const int cc) { return cc < CC_OFFSET_AMD && cc >= CC_VOLTA; } -static bool int8_mma_available(const int cc) { +static constexpr bool int8_mma_available(const int cc) { return cc < CC_OFFSET_AMD && cc >= CC_TURING; } @@ -643,19 +626,6 @@ struct ggml_cuda_type_traits { static constexpr int qi = QI3_S; }; -static int get_mmq_x_max_host(const int cc) { -#ifdef CUDA_USE_TENSOR_CORES - return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_MAX_BATCH_SIZE : 64; -#else - return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? 128 : 64; -#endif // CUDA_USE_TENSOR_CORES -} - -// Round rows to this value for --split-mode row: -static int get_mmq_y_host(const int cc) { - return cc >= CC_VOLTA ? 128 : 64; -} - ////////////////////// struct ggml_cuda_device_info { diff --git a/ggml-cuda/mma.cuh b/ggml-cuda/mma.cuh index 63e07fbc2..5d87dd8e6 100644 --- a/ggml-cuda/mma.cuh +++ b/ggml-cuda/mma.cuh @@ -20,6 +20,20 @@ struct mma_int_A_I16K4 { GGML_CUDA_ASSUME(ret < K); return ret; } + + __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) { +#if defined(INT8_MMA_AVAILABLE) + const int * xs = xs0 + (threadIdx.x%I)*stride; + asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" + : "+r"(x[0]), "+r"(x[1]) + : "l"(xs)); +#else +#pragma unroll + for (int l = 0; l < ne; ++l) { + x[l] = xs0[get_i(l)*stride + get_k(l)]; + } +#endif // defined(INT8_MMA_AVAILABLE) + } }; struct mma_int_A_I16K8 { @@ -42,6 +56,20 @@ struct mma_int_A_I16K8 { GGML_CUDA_ASSUME(ret < K); return ret; } + + __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) { +#if defined(INT8_MMA_AVAILABLE) + const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2); + asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];" + : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3]) + : "l"(xs)); +#else +#pragma unroll + for (int l = 0; l < ne; ++l) { + x[l] = xs0[get_i(l)*stride + get_k(l)]; + } +#endif // defined(INT8_MMA_AVAILABLE) + } }; struct mma_int_B_J8K4 { @@ -64,6 +92,20 @@ struct mma_int_B_J8K4 { GGML_CUDA_ASSUME(ret < K); return ret; } + + __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) { +#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster + const int * xs = xs0 + (threadIdx.x%J)*stride; + asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];" + : "+r"(x[0]) + : "l"(xs)); +#else +#pragma unroll + for (int l = 0; l < ne; ++l) { + x[l] = xs0[get_j(l)*stride + get_k(l)]; + } +#endif // defined(INT8_MMA_AVAILABLE) + } }; struct mma_int_B_J8K8 { @@ -86,6 +128,20 @@ struct mma_int_B_J8K8 { GGML_CUDA_ASSUME(ret < K); return ret; } + + __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) { +#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster + const int * xs = xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K; + asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" + : "+r"(x[0]), "+r"(x[1]) + : "l"(xs)); +#else +#pragma unroll + for (int l = 0; l < ne; ++l) { + x[l] = xs0[get_j(l)*stride + get_k(l)]; + } +#endif // defined(INT8_MMA_AVAILABLE) + } }; struct mma_int_C_I16J8 { diff --git a/ggml-cuda/mmq.cu b/ggml-cuda/mmq.cu index 6dbd85fef..0308beacc 100644 --- a/ggml-cuda/mmq.cu +++ b/ggml-cuda/mmq.cu @@ -69,7 +69,13 @@ void ggml_cuda_op_mul_mat_q( GGML_UNUSED(src1_ddf_i); } -bool ggml_cuda_supports_mmq(enum ggml_type type) { +bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { +#ifdef GGML_CUDA_FORCE_CUBLAS + return false; +#endif // GGML_CUDA_FORCE_CUBLAS + + bool mmq_supported; + switch (type) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: @@ -81,8 +87,32 @@ bool ggml_cuda_supports_mmq(enum ggml_type type) { case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: - return true; + mmq_supported = true; + break; default: - return false; + mmq_supported = false; + break; } + + if (!mmq_supported) { + return false; + } + + if (int8_mma_available(cc)) { + return true; + } + + if (cc < MIN_CC_DP4A) { + return false; + } + +#ifdef GGML_CUDA_FORCE_MMQ + return true; +#endif //GGML_CUDA_FORCE_MMQ + + if (cc < CC_OFFSET_AMD) { + return cc < CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; + } + + return cc < CC_RDNA3 || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; } diff --git a/ggml-cuda/mmq.cuh b/ggml-cuda/mmq.cuh index e2d07c202..31fcbf139 100644 --- a/ggml-cuda/mmq.cuh +++ b/ggml-cuda/mmq.cuh @@ -7,15 +7,10 @@ #include #include -#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1) -#define MMQ_NWARPS 8 +#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available. -typedef void (*load_tiles_mmq_t)( - const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm, - int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride); -typedef void (*vec_dot_mmq_t)( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0); +typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int & kbx0, const int & i_max, const int & stride); +typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0); typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max); struct block_q8_1_mmq { @@ -31,25 +26,42 @@ struct tile_x_sizes { int sc; }; -// get_mmq_x_max_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row - -static constexpr __device__ int get_mmq_x_max_device() { -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - return 64; +static constexpr int get_mmq_x_max_host(const int cc) { + return int8_mma_available(cc) ? 128 : +#ifdef GGML_CUDA_FORCE_MMQ + cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? 128 : 64; #else -#if __CUDA_ARCH__ >= CC_VOLTA -#ifdef CUDA_USE_TENSOR_CORES - return MMQ_MAX_BATCH_SIZE; -#else - return 128; -#endif // CUDA_USE_TENSOR_CORES -#else - return 64; -#endif // __CUDA_ARCH__ >= CC_VOLTA -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_DP4A_MAX_BATCH_SIZE : 64; +#endif // GGML_CUDA_FORCE_MMQ } -// get_mmq_y_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row +static constexpr __device__ int get_mmq_x_max_device() { +#ifdef INT8_MMA_AVAILABLE + return 128; +#else // INT8_MMA_AVAILABLE + +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + return 128; +#else // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + +#if __CUDA_ARCH__ >= CC_VOLTA +#ifdef GGML_CUDA_FORCE_MMQ + return MMQ_DP4A_MAX_BATCH_SIZE; +#else // GGML_CUDA_FORCE_MMQ + return 128; +#endif // GGML_CUDA_FORCE_MMQ +#else // __CUDA_ARCH__ >= CC_VOLTA + + return 64; +#endif // __CUDA_ARCH__ >= CC_VOLTA + +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#endif // INT8_MMA_AVAILABLE +} + +static constexpr int get_mmq_y_host(const int cc) { + return int8_mma_available(cc) || cc >= CC_VOLTA ? 128 : 64; +} static constexpr __device__ int get_mmq_y_device() { #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) @@ -63,51 +75,101 @@ static constexpr __device__ int get_mmq_y_device() { #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) } -#define TILE_X_SIZES_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0} -#define TILE_X_SIZES_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0} -#define TILE_X_SIZES_Q5_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_0 + mmq_y/QI5_0, 0} -#define TILE_X_SIZES_Q5_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_1 + mmq_y/QI5_1, 0} -#define TILE_X_SIZES_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI8_0 + mmq_y/QI8_0, 0} -#define TILE_X_SIZES_Q2_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0} -#define TILE_X_SIZES_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI3_K + mmq_y/QI3_K, mmq_y*WARP_SIZE/4 + mmq_y/4} -#define TILE_X_SIZES_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8} -#define TILE_X_SIZES_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8} -#define TILE_X_SIZES_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0} +#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0} +#define MMQ_DP4A_TXS_Q5_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_0 + mmq_y/QI5_0, 0} +#define MMQ_DP4A_TXS_Q5_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_1 + mmq_y/QI5_1, 0} +#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI8_0 + mmq_y/QI8_0, 0} +#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0} +#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI3_K + mmq_y/QI3_K, mmq_y*WARP_SIZE/4 + mmq_y/4} +#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8} -#define GET_TILE_X_SIZES_BODY \ - return type == GGML_TYPE_Q4_0 ? TILE_X_SIZES_Q4_0 : \ - type == GGML_TYPE_Q4_1 ? TILE_X_SIZES_Q4_1 : \ - type == GGML_TYPE_Q5_0 ? TILE_X_SIZES_Q5_0 : \ - type == GGML_TYPE_Q5_1 ? TILE_X_SIZES_Q5_1 : \ - type == GGML_TYPE_Q8_0 ? TILE_X_SIZES_Q8_0 : \ - type == GGML_TYPE_Q2_K ? TILE_X_SIZES_Q2_K : \ - type == GGML_TYPE_Q3_K ? TILE_X_SIZES_Q3_K : \ - type == GGML_TYPE_Q4_K ? TILE_X_SIZES_Q4_K : \ - type == GGML_TYPE_Q5_K ? TILE_X_SIZES_Q5_K : \ - type == GGML_TYPE_Q6_K ? TILE_X_SIZES_Q6_K : \ - tile_x_sizes{0, 0, 0} - -static tile_x_sizes get_tile_x_sizes_host(const ggml_type type, const int mmq_y) { - GET_TILE_X_SIZES_BODY; +static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) { + return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 : + type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 : + type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q5_0 : + type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q5_1 : + type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 : + type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K : + type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K : + type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K : + type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K : + type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K : + tile_x_sizes{0, 0, 0}; } -template -static constexpr __device__ tile_x_sizes get_tile_x_sizes_device(ggml_type type) { - GET_TILE_X_SIZES_BODY; +#define MMQ_MMA_TILE_X_K_Q4_0 (1*WARP_SIZE + WARP_SIZE/QI4_0 + 4) +#define MMQ_MMA_TILE_X_K_Q4_1 (1*WARP_SIZE + WARP_SIZE/QI4_1 + 4) +#define MMQ_MMA_TILE_X_K_Q5_0 (2*WARP_SIZE + WARP_SIZE/QI5_0 + 4) +#define MMQ_MMA_TILE_X_K_Q5_1 (2*WARP_SIZE + WARP_SIZE/QI5_1 + 4) +#define MMQ_MMA_TILE_X_K_Q8_0 (1*WARP_SIZE + WARP_SIZE/QI8_0 + 0) +#define MMQ_MMA_TILE_X_K_Q2_K (1*WARP_SIZE + WARP_SIZE + 4) +#define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/QI3_K + WARP_SIZE/4 + 2) +#define MMQ_MMA_TILE_X_K_Q4_K (1*WARP_SIZE + WARP_SIZE/QI4_K + WARP_SIZE/8 + 7) +#define MMQ_MMA_TILE_X_K_Q5_K (2*WARP_SIZE + WARP_SIZE/QI5_K + WARP_SIZE/8 + 7) +#define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K + WARP_SIZE/8 + 7) + +static_assert(MMQ_MMA_TILE_X_K_Q4_0 % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q4_1 % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q5_0 % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q5_1 % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q4_K % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q5_K % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding."); + +static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { + return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 : + type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 : + type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q5_0 : + type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q5_1 : + type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 : + type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K : + type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K : + type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K : + type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K : + type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K : + 0; } +#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1) +#define MMQ_NWARPS 8 + +static int mmq_get_granularity_host(const int mmq_x, const int cc) { + return int8_mma_available(cc) && mmq_x >= 48 ? 16 : 8; +} + +#ifdef INT8_MMA_AVAILABLE +static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { + return mmq_x >= 48 ? 16 : 8; +} +#else +static constexpr __device__ int mmq_get_granularity_device(const int /* mmq_x */) { + return 8; +} +#endif // INT8_MMA_AVAILABLE + // ------------------------------------------------------------ template static __device__ __forceinline__ void load_tiles_q4_0( - const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm, - int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { - GGML_UNUSED(x_sc); + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE const int kbx = threadIdx.x / QI4_0; const int kqsx = threadIdx.x % QI4_0; - float * x_dmf = (float *) x_dm; - #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { int i = i0 + threadIdx.y; @@ -118,7 +180,11 @@ template static __device__ __forceinlin const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx; - x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx); +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx); +#else + x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx); +#endif // INT8_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI4_0; @@ -134,17 +200,21 @@ template static __device__ __forceinlin const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd; - x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d; +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q4_0 + kbxd] = bxi->d; +#else + x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d; +#endif // INT8_MMA_AVAILABLE } } template static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - GGML_UNUSED(x_sc); + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - const float * x_df = (const float *) x_dm; + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; @@ -175,76 +245,90 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { #ifdef INT8_MMA_AVAILABLE - GGML_UNUSED(x_sc); typedef mma_int_A_I16K8 mma_A; typedef mma_int_B_J8K8 mma_B; typedef mma_int_C_I16J8 mma_C; - const float * x_df = (const float *) x_dm; + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + WARP_SIZE; const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; - mma_A A; - float dA[mma_C::ne/2]; + mma_A A[ntx]; + float dA[ntx][mma_C::ne/2]; - const int i0 = threadIdx.y*mma_A::I; - static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); + const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); #pragma unroll - for (int l = 0; l < mma_A::ne; ++l) { - const int i = i0 + mma_A::get_i(l); - const int k = k0 + mma_A::get_k(l) % QI4_0; - const int shift = 4*(mma_A::get_k(l) / QI4_0); - - A.x[l] = __vsubss4((x_qs[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F, 0x08080808); - } + for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + mma_C::get_i(2*l); + for (int l = 0; l < mma_A::ne; ++l) { + const int i = i0 + n*mma_A::I + mma_A::get_i(l); + const int k = k0 + mma_A::get_k(l) % QI4_0; + const int shift = 4*(mma_A::get_k(l) / QI4_0); - dA[l] = x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0]; - } - - for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) { - mma_C C; - mma_B B; - half2 dsB[mma_C::ne/2]; - -#pragma unroll - for (int l = 0; l < mma_B::ne; ++l) { - const int j = j0 + mma_B::get_j(l); - const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE; - - B.x[l] = y_qs[j*MMQ_TILE_Y_K + k]; + A[n].x[l] = __vsubss4((x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + k] >> shift) & 0x0F0F0F0F, 0x08080808); } + +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); + + dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q4_0 + k0/QI4_0]; + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { + mma_B B; + float dB[mma_C::ne/2]; + + B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K); + #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int j = j0 + mma_C::get_j(l); - dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]; + dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); } - C.mma_K8(A, B); +#pragma unroll + for (int n = 0; n < ntx; ++n) { + mma_C C; + C.mma_K8(A[n], B); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/B.J)*C.ne + l] += dA[l/2]*__low2float(dsB[l%2])*C.x[l]; + for (int l = 0; l < mma_C::ne; ++l) { + sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2]*dB[l%2]*C.x[l]; + } } } #else - GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); NO_DEVICE_CODE; #endif // INT8_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q4_1( - const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm, - int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { - GGML_UNUSED(x_sc); + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + WARP_SIZE); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE const int kbx = threadIdx.x / QI4_1; const int kqsx = threadIdx.x % QI4_1; @@ -259,7 +343,11 @@ template static __device__ __forceinlin const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx; - x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx); +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q4_1 + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx); +#else + x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx); +#endif // INT8_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI4_1; @@ -275,16 +363,21 @@ template static __device__ __forceinlin const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd; - x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm; +#ifdef INT8_MMA_AVAILABLE + x_dm[i*MMQ_MMA_TILE_X_K_Q4_1 + kbxd] = bxi->dm; +#else + x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + kbxd] = bxi->dm; +#endif // INT8_MMA_AVAILABLE } } template static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - GGML_UNUSED(x_sc); + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; @@ -315,51 +408,53 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { #ifdef INT8_MMA_AVAILABLE - GGML_UNUSED(x_sc); typedef mma_int_A_I16K8 mma_A; + typedef mma_int_A_I16K4 mma_A_K4; typedef mma_int_B_J8K8 mma_B; typedef mma_int_C_I16J8 mma_C; + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE; const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; - mma_A A; - half2 dmA[mma_C::ne/2]; + mma_A A[ntx]; + half2 dmA[ntx][mma_C::ne/2]; - const int i0 = threadIdx.y*mma_A::I; - static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); + const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); #pragma unroll - for (int l = 0; l < mma_A::ne; ++l) { - const int i = i0 + mma_A::get_i(l); - const int k = k0 + mma_A::get_k(l) % QI4_0; - const int shift = 4*(mma_A::get_k(l) / QI4_0); + for (int n = 0; n < ntx; ++n) { + ((mma_A_K4 *) &A[n])[0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_1 + k0, MMQ_MMA_TILE_X_K_Q4_1); + A[n].x[2] = (A[n].x[0] >> 4) & 0x0F0F0F0F; + A[n].x[3] = (A[n].x[1] >> 4) & 0x0F0F0F0F; + A[n].x[0] &= 0x0F0F0F0F; + A[n].x[1] &= 0x0F0F0F0F; - A.x[l] = (x_qs[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F; - } #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + mma_C::get_i(2*l); + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); - dmA[l] = x_dm[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0]; + dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_1 + k0/QI4_1]; + } } - for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) { - mma_C C; +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { mma_B B; half2 dsB[mma_C::ne/2]; -#pragma unroll - for (int l = 0; l < mma_B::ne; ++l) { - const int j = j0 + mma_B::get_j(l); - const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE; + B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K); - B.x[l] = y_qs[j*MMQ_TILE_Y_K + k]; - } #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int j = j0 + mma_C::get_j(l); @@ -367,24 +462,35 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma( dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]; } - C.mma_K8(A, B); +#pragma unroll + for (int n = 0; n < ntx; ++n) { + mma_C C; + C.mma_K8(A[n], B); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - const half2 dmA_dsB = dmA[l/2]*dsB[l%2]; - sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB); + for (int l = 0; l < mma_C::ne; ++l) { + const half2 dmA_dsB = dmA[n][l/2]*dsB[l%2]; + sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB); + } } } #else - GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); NO_DEVICE_CODE; #endif // INT8_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q5_0( - const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm, - int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { - GGML_UNUSED(x_sc); + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE const int kbx = threadIdx.x / QI5_0; const int kqsx = threadIdx.x % QI5_0; @@ -409,8 +515,6 @@ template static __device__ __forceinlin qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 qs0 = __vsubss4(qs0, 0x10101010); // subtract 16 - x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0; - int qs1 = (ql >> 4) & 0x0F0F0F0F; qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 @@ -418,12 +522,17 @@ template static __device__ __forceinlin qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 qs1 = __vsubss4(qs1, 0x10101010); // subtract 16 - x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1; +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0; + x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; +#else + x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0; + x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; +#endif // INT8_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI5_0; const int kbxd = threadIdx.x % blocks_per_tile_x_row; - float * x_dmf = (float *) x_dm; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) { @@ -435,19 +544,23 @@ template static __device__ __forceinlin const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd; - x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = bxi->d; +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + kbxd] = bxi->d; +#else + x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d; +#endif // INT8_MMA_AVAILABLE } } template static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - GGML_UNUSED(x_sc); + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - const float * x_dmf = (const float *) x_dm; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y); + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { @@ -457,70 +570,57 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a( for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2)); - const int index_bx = i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0; - - int u[2*VDR_Q5_0_Q8_1_MMQ]; - -#pragma unroll - for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI5_0) % WARP_SIZE]; - } - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl - (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dmf[index_bx], y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); + (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], &y_qs[j*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE], + x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0], y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); } } } template static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { #ifdef INT8_MMA_AVAILABLE - GGML_UNUSED(x_sc); typedef mma_int_A_I16K8 mma_A; typedef mma_int_B_J8K8 mma_B; typedef mma_int_C_I16J8 mma_C; - const float * x_df = (const float *) x_dm; + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + WARP_SIZE*2; const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; - mma_A A; - float dA[mma_C::ne/2]; + mma_A A[ntx]; + float dA[ntx][mma_C::ne/2]; - const int i0 = threadIdx.y*mma_A::I; - static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); + const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); #pragma unroll - for (int l = 0; l < mma_A::ne; ++l) { - const int i = i0 + mma_A::get_i(l); - const int k = 2*(k0 + mma_A::get_k(l) % QI5_0) + mma_A::get_k(l) / QI5_0; + for (int n = 0; n < ntx; ++n) { + A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_0 + QR5_1*k0, MMQ_MMA_TILE_X_K_Q5_0); - A.x[l] = x_qs[i*(2*WARP_SIZE + 1) + k]; - } #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + mma_C::get_i(2*l); + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + mma_C::get_i(2*l) + n*mma_C::I; - dA[l] = x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0]; + dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + k0/QI5_0]; + } } - for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) { - mma_C C; +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { mma_B B; float dB[mma_C::ne/2]; -#pragma unroll - for (int l = 0; l < mma_B::ne; ++l) { - const int j = j0 + mma_B::get_j(l); - const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE; + B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K); - B.x[l] = y_qs[j*MMQ_TILE_Y_K + k]; - } #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int j = j0 + mma_C::get_j(l); @@ -528,23 +628,34 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma( dB[l] = y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]; } - C.mma_K8(A, B); +#pragma unroll + for (int n = 0; n < ntx; ++n) { + mma_C C; + C.mma_K8(A[n], B); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/B.J)*C.ne + l] += dA[l/2]*dB[l%2]*C.x[l]; + for (int l = 0; l < mma_C::ne; ++l) { + sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2]*dB[l%2]*C.x[l]; + } } } #else - GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); NO_DEVICE_CODE; #endif // INT8_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q5_1( - const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm, - int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { - GGML_UNUSED(x_sc); + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE const int kbx = threadIdx.x / QI5_1; const int kqsx = threadIdx.x % QI5_1; @@ -568,15 +679,19 @@ template static __device__ __forceinlin qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 - x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0; - int qs1 = (ql >> 4) & 0x0F0F0F0F; qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 - x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1; +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q5_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0; + x_qs[i*MMQ_MMA_TILE_X_K_Q5_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; +#else + x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0; + x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; +#endif // INT8_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI5_1; @@ -592,18 +707,23 @@ template static __device__ __forceinlin const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd; - x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm; +#ifdef INT8_MMA_AVAILABLE + x_dm[i*MMQ_MMA_TILE_X_K_Q5_1 + kbxd] = bxi->dm; +#else + x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm; +#endif // INT8_MMA_AVAILABLE } } template static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - GGML_UNUSED(x_sc); + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - const int * y_qs = (const int *) y + 4; - const half2 * y_ds = (const half2 *) y; + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { @@ -613,69 +733,57 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a( for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2)); - const int index_bx = i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1; - - int u[2*VDR_Q5_1_Q8_1_MMQ]; - -#pragma unroll - for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI5_1) % WARP_SIZE]; - } - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl - (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dm[index_bx], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); + (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], &y_qs[j*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE], + x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); } } } template static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { #ifdef INT8_MMA_AVAILABLE - GGML_UNUSED(x_sc); typedef mma_int_A_I16K8 mma_A; typedef mma_int_B_J8K8 mma_B; typedef mma_int_C_I16J8 mma_C; + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE; const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; - mma_A A; - half2 dmA[mma_C::ne/2]; + mma_A A[ntx]; + half2 dmA[ntx][mma_C::ne/2]; - const int i0 = threadIdx.y*mma_A::I; - static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); + const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); #pragma unroll - for (int l = 0; l < mma_A::ne; ++l) { - const int i = i0 + mma_A::get_i(l); - const int k = 2*(k0 + mma_A::get_k(l) % QI5_1) + mma_A::get_k(l) / QI5_1; + for (int n = 0; n < ntx; ++n) { + A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_1 + QR5_1*k0, MMQ_MMA_TILE_X_K_Q5_1); - A.x[l] = x_qs[i*(2*WARP_SIZE + 1) + k]; - } #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + mma_C::get_i(2*l); + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + mma_C::get_i(2*l) + n*mma_C::I; - dmA[l] = x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1]; + dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q5_1 + k0/QI5_1]; + } } - for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) { - mma_C C; +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { mma_B B; half2 dsB[mma_C::ne/2]; -#pragma unroll - for (int l = 0; l < mma_B::ne; ++l) { - const int j = j0 + mma_B::get_j(l); - const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE; + B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K); - B.x[l] = y_qs[j*MMQ_TILE_Y_K + k]; - } #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int j = j0 + mma_C::get_j(l); @@ -683,28 +791,38 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma( dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]; } - C.mma_K8(A, B); +#pragma unroll + for (int n = 0; n < ntx; ++n) { + mma_C C; + C.mma_K8(A[n], B); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - const half2 dmA_dsB = dmA[l/2]*dsB[l%2]; - sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB); + for (int l = 0; l < mma_C::ne; ++l) { + const half2 dmA_dsB = dmA[n][l/2]*dsB[l%2]; + sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB); + } } } #else - GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); NO_DEVICE_CODE; #endif // INT8_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q8_0( - const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm, - int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { - GGML_UNUSED(x_sc); + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_tile + WARP_SIZE); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE const int kbx = threadIdx.x / QI8_0; const int kqsx = threadIdx.x % QI8_0; - float * x_dmf = (float *) x_dm; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { @@ -716,7 +834,11 @@ template static __device__ __forceinlin const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx; - x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx); +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx); +#else + x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx); +#endif // INT8_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI8_0; @@ -732,19 +854,23 @@ template static __device__ __forceinlin const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd; - x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d; +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; +#else + x_df[i*(WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d; +#endif // INT8_MMA_AVAILABLE } } template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - GGML_UNUSED(x_sc); + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - const float * x_dmf = (const float *) x_dm; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { @@ -755,7 +881,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( const int i = i0 + threadIdx.x; sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl - (&x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0], x_dmf[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0], + (&x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0], x_df[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + k0/QI8_1]); } } @@ -763,51 +889,48 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { #ifdef INT8_MMA_AVAILABLE - GGML_UNUSED(x_sc); typedef mma_int_A_I16K8 mma_A; typedef mma_int_B_J8K8 mma_B; typedef mma_int_C_I16J8 mma_C; - const float * x_df = (const float *) x_dm; + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + WARP_SIZE; const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; - mma_A A; - float dA[mma_C::ne/2]; + mma_A A[ntx]; + float dA[ntx][mma_C::ne/2]; - const int i0 = threadIdx.y*mma_A::I; - static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); + const int i0 = (threadIdx.y/ntx)*rows_per_warp; #pragma unroll - for (int l = 0; l < mma_A::ne; ++l) { - const int i = i0 + mma_A::get_i(l); - const int k = k0 + mma_A::get_k(l); + for (int n = 0; n < ntx; ++n) { + A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); - A.x[l] = x_qs[i*(WARP_SIZE + 1) + k]; - } #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + mma_C::get_i(2*l); + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_A::I + mma_C::get_i(2*l); - dA[l] = x_df[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0]; + dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0]; + } } - for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) { - mma_C C; +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { mma_B B; float dB[mma_C::ne/2]; -#pragma unroll - for (int l = 0; l < mma_B::ne; ++l) { - const int j = j0 + mma_B::get_j(l); - const int k = k0 + mma_B::get_k(l); + B.load(y_qs + j0*MMQ_TILE_Y_K + k0, MMQ_TILE_Y_K); - B.x[l] = y_qs[j*MMQ_TILE_Y_K + k]; - } #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int j = j0 + mma_C::get_j(l); @@ -815,22 +938,34 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( dB[l] = y_df[j*MMQ_TILE_Y_K + k0/QI8_1]; } - C.mma_K8(A, B); +#pragma unroll + for (int n = 0; n < ntx; ++n) { + mma_C C; + C.mma_K8(A[n], B); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/B.J)*C.ne + l] += C.x[l]*dA[l/2]*dB[l%2]; + for (int l = 0; l < mma_C::ne; ++l) { + sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2]*dB[l%2]; + } } } #else - GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); NO_DEVICE_CODE; #endif // INT8_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q2_K( - const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm, - int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + WARP_SIZE); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE const int kbx = threadIdx.x / QI2_K; const int kqsx = threadIdx.x % QI2_K; @@ -859,7 +994,11 @@ template static __device__ __forceinlin continue; } - x_qs[i*(WARP_SIZE + 1) + k] = x_qs_k; +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k; +#else + x_qs[i*(WARP_SIZE + 1) + k] = x_qs_k; +#endif // INT8_MMA_AVAILABLE } const int sc_m = bxi->scales[kqsx]; @@ -870,15 +1009,21 @@ template static __device__ __forceinlin const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4)); #endif // FAST_FP16_AVAILABLE - x_dm[i*(WARP_SIZE + 1) + threadIdx.x] = x_dm_ik; +#ifdef INT8_MMA_AVAILABLE + x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + threadIdx.x] = x_dm_ik; +#else + x_dm[i*(WARP_SIZE + 1) + threadIdx.x] = x_dm_ik; +#endif // INT8_MMA_AVAILABLE } } template static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; @@ -899,61 +1044,63 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { #ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K4 mma_A; typedef mma_int_B_J8K4 mma_B; typedef mma_int_C_I16J8 mma_C; + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE; const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; - const int i0 = threadIdx.y*mma_A::I; - static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); + const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); - mma_A A[2]; - float dA[mma_C::ne/2][2]; - float mA[mma_C::ne/2][2]; + mma_A A[ntx][2]; + float dA[ntx][mma_C::ne/2][2]; + float mA[ntx][mma_C::ne/2][2]; #pragma unroll - for (int l = 0; l < mma_A::ne; ++l) { - const int i = i0 + mma_A::get_i(l); - const int shift = 2*mma_A::get_k(l); + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int l = 0; l < mma_A::ne; ++l) { + const int i = i0 + n*mma_A::I + mma_A::get_i(l); + const int shift = 2*mma_A::get_k(l); - A[0].x[l] = (x_qs[i*(WARP_SIZE + 1) + k0 + 0] >> shift) & 0x03030303; - A[1].x[l] = (x_qs[i*(WARP_SIZE + 1) + k0 + 1] >> shift) & 0x03030303; - } + A[n][0].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k0 + 0] >> shift) & 0x03030303; + A[n][1].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k0 + 1] >> shift) & 0x03030303; + } #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + mma_C::get_i(2*l); + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); #pragma unroll - for (int kk = 0; kk < 2; ++kk) { - const float2 dm = __half22float2(x_dm[i*(WARP_SIZE + 1) + k0 + kk]); + for (int kdm = 0; kdm < 2; ++kdm) { + const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0 + kdm]); - dA[l][kk] = dm.x; - mA[l][kk] = dm.y; + dA[n][l][kdm] = dm.x; + mA[n][l][kdm] = dm.y; + } } } #pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) { - mma_C Cd[2]; - mma_C Cm[2]; + for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { mma_B B[2]; float dB[mma_C::ne/2]; -#pragma unroll - for (int l = 0; l < mma_B::ne; ++l) { - const int j = j0 + mma_B::get_j(l); - const int k = (4*k0 + mma_B::get_k(l)) % WARP_SIZE; + B[0].load(y_qs + j0*MMQ_TILE_Y_K + (QR2_K*k0 + 0) % WARP_SIZE, MMQ_TILE_Y_K); + B[1].load(y_qs + j0*MMQ_TILE_Y_K + (QR2_K*k0 + mma_B::K) % WARP_SIZE, MMQ_TILE_Y_K); - B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0]; - B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K]; - } #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int j = j0 + mma_C::get_j(l); @@ -961,9 +1108,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)]; } - Cd[0].mma_K4(A[0], B[0]); - Cd[1].mma_K4(A[1], B[1]); - + mma_C Cm[2]; mma_A A1; A1.x[0] = 0x01010101; A1.x[1] = 0x01010101; @@ -971,19 +1116,38 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( Cm[1].mma_K4(A1, B[1]); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_B::J)*mma_C::ne + l] += (Cd[0].x[l]*dA[l/2][0] + Cd[1].x[l]*dA[l/2][1] - Cm[0].x[l]*mA[l/2][0] - Cm[1].x[l]*mA[l/2][1])*dB[l%2]; + for (int n = 0; n < ntx; ++n) { + mma_C Cd[2]; + + Cd[0].mma_K4(A[n][0], B[0]); + Cd[1].mma_K4(A[n][1], B[1]); + +#pragma unroll + for (int l = 0; l < mma_C::ne; ++l) { + sum[(j0/mma_C::J + n)*mma_C::ne + l] += ( + Cd[0].x[l]*dA[n][l/2][0] + Cd[1].x[l]*dA[n][l/2][1] - Cm[0].x[l]*mA[n][l/2][0] - Cm[1].x[l]*mA[n][l/2][1])*dB[l%2]; + } } } #else - GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); NO_DEVICE_CODE; #endif // INT8_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q3_K( - const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm, - int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); + int * x_sc = (int *) (x_df + WARP_SIZE/QI3_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); + int * x_sc = (int *) (x_df + txs.dm); +#endif // INT8_MMA_AVAILABLE const int kbx = threadIdx.x / QI3_K; const int kqsx = threadIdx.x % QI3_K; @@ -1015,13 +1179,16 @@ template static __device__ __forceinlin continue; } - x_qs[i*(2*WARP_SIZE + 1) + k/2] = x_qs_k; +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k/2] = x_qs_k; +#else + x_qs[i*(2*WARP_SIZE + 1) + k/2] = x_qs_k; +#endif // INT8_MMA_AVAILABLE } } const int blocks_per_tile_x_row = WARP_SIZE / QI3_K; const int kbxd = threadIdx.x % blocks_per_tile_x_row; - float * x_dmf = (float *) x_dm; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) { @@ -1033,7 +1200,11 @@ template static __device__ __forceinlin const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbxd; - x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d; +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + kbxd] = bxi->d; +#else + x_df[i*(WARP_SIZE/QI3_K) + i/QI3_K + kbxd] = bxi->d; +#endif // INT8_MMA_AVAILABLE } #pragma unroll @@ -1058,16 +1229,22 @@ template static __device__ __forceinlin const int sc = __vsubss4(sc_low | sc_high, 0x20202020); - x_sc[i * (WARP_SIZE/4) + i / 4 + threadIdx.x % (WARP_SIZE/4)] = sc; +#ifdef INT8_MMA_AVAILABLE + x_sc[i*MMQ_MMA_TILE_X_K_Q3_K + threadIdx.x % (WARP_SIZE/4)] = sc; +#else + x_sc[i*(WARP_SIZE/4) + i/4 + threadIdx.x % (WARP_SIZE/4)] = sc; +#endif // INT8_MMA_AVAILABLE } } template static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - const float * x_df = (const float *) x_dm; + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y); + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; + const int * x_sc = (const int *) x_df + txs.dm; const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; @@ -1093,69 +1270,72 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { #ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K4 mma_A; typedef mma_int_B_J8K4 mma_B; typedef mma_int_C_I16J8 mma_C; - const float * x_df = (const float *) x_dm; + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + WARP_SIZE*2; + const int * x_sc = (const int *) x_df + WARP_SIZE/QI3_K; const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; - const int i0 = threadIdx.y*mma_A::I; - static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); + const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); - mma_A A[2]; - int scA[mma_C::ne/2][2]; - float dA[mma_C::ne/2]; + mma_A A[ntx][2]; + int scA[ntx][mma_C::ne/2][2]; + float dA[ntx][mma_C::ne/2]; #pragma unroll - for (int l = 0; l < mma_A::ne; ++l) { - const int i = i0 + mma_A::get_i(l); - const int k = QR3_K*k0 + mma_A::get_k(l); + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int l = 0; l < mma_A::ne; ++l) { + const int i = i0 + n*mma_A::I + mma_A::get_i(l); + const int k = QR3_K*k0 + mma_A::get_k(l); - A[0].x[l] = (x_qs[i*(2*WARP_SIZE + 1) + k/2 + 0] >> (4*(k%2))) & 0x0F0F0F0F; - A[1].x[l] = (x_qs[i*(2*WARP_SIZE + 1) + k/2 + mma_A::K/2] >> (4*(k%2))) & 0x0F0F0F0F; - A[0].x[l] = __vsubss4(A[0].x[l], 0x04040404); - A[1].x[l] = __vsubss4(A[1].x[l], 0x04040404); + A[n][0].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k/2 + 0] >> (4*(k%2))) & 0x0F0F0F0F; + A[n][1].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k/2 + mma_A::K/2] >> (4*(k%2))) & 0x0F0F0F0F; + A[n][0].x[l] = __vsubss4(A[n][0].x[l], 0x04040404); + A[n][1].x[l] = __vsubss4(A[n][1].x[l], 0x04040404); + } + +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); + + const int kbx = k0 / QI3_K; + const int ky = (k0 % QI3_K) * QR3_K; + const int8_t * sc = ((const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q3_K + kbx*4)) + ky/4; + + scA[n][l][0] = sc[0]; + scA[n][l][1] = sc[1]; + } + +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); + + dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/QI3_K]; + } } #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + mma_C::get_i(2*l); - - const int kbx = k0 / QI3_K; - const int ky = (k0 % QI3_K) * QR3_K; - const int8_t * sc = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4; - - scA[l][0] = sc[0]; - scA[l][1] = sc[1]; - } - -#pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + mma_C::get_i(2*l); - - dA[l] = x_df[i*(WARP_SIZE/QI3_K) + i/QI3_K + k0/QI3_K]; - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) { - mma_C C[2]; + for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { mma_B B[2]; float dB[mma_C::ne/2]; -#pragma unroll - for (int l = 0; l < mma_B::ne; ++l) { - const int j = j0 + mma_B::get_j(l); - const int k = (4*k0 + mma_B::get_k(l)) % WARP_SIZE; + B[0].load(y_qs + j0*MMQ_TILE_Y_K + (QR3_K*k0 + 0) % WARP_SIZE, MMQ_TILE_Y_K); + B[1].load(y_qs + j0*MMQ_TILE_Y_K + (QR3_K*k0 + mma_B::K) % WARP_SIZE, MMQ_TILE_Y_K); - B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0]; - B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K]; - } #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int j = j0 + mma_C::get_j(l); @@ -1163,23 +1343,37 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma( dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)]; } - C[0].mma_K4(A[0], B[0]); - C[1].mma_K4(A[1], B[1]); +#pragma unroll + for (int n = 0; n < ntx; ++n) { + mma_C C[2]; + C[0].mma_K4(A[n][0], B[0]); + C[1].mma_K4(A[n][1], B[1]); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_B::J)*mma_C::ne + l] += (C[0].x[l]*scA[l/2][0] + C[1].x[l]*scA[l/2][1])*dA[l/2]*dB[l%2]; + for (int l = 0; l < mma_C::ne; ++l) { + sum[(j0/mma_C::J + n)*mma_C::ne + l] += (C[0].x[l]*scA[n][l/2][0] + C[1].x[l]*scA[n][l/2][1])*dA[n][l/2]*dB[l%2]; + } } } #else - GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); NO_DEVICE_CODE; #endif // INT8_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q4_K( - const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm, - int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + WARP_SIZE); + int * x_sc = (int *) (x_dm + WARP_SIZE/QI4_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + txs.qs); + int * x_sc = (int *) (x_dm + txs.dm); +#endif // INT8_MMA_AVAILABLE const int kbx = 0; // threadIdx.x / QI4_K const int kqsx = threadIdx.x; // threadIdx.x % QI4_K @@ -1194,7 +1388,11 @@ template static __device__ __forceinlin const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbx; - x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx); +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q4_K + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx); +#else + x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx); +#endif // INT8_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256 @@ -1210,7 +1408,11 @@ template static __device__ __forceinlin const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbxd; - x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm; +#ifdef INT8_MMA_AVAILABLE + x_dm[i*MMQ_MMA_TILE_X_K_Q4_K + kbxd] = bxi->dm; +#else + x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K + kbxd] = bxi->dm; +#endif // INT8_MMA_AVAILABLE } #pragma unroll @@ -1231,15 +1433,22 @@ template static __device__ __forceinlin int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits - x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8; +#ifdef INT8_MMA_AVAILABLE + x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + ksc] = scales8; +#else + x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8; +#endif // INT8_MMA_AVAILABLE } } template static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; + const int * x_sc = (const int *) x_dm + txs.dm; const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; @@ -1262,71 +1471,79 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { #ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K8 mma_A; typedef mma_int_B_J8K8 mma_B; typedef mma_int_C_I16J8 mma_C; + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE; + const int * x_sc = (const int *) x_dm + WARP_SIZE/QI4_K; const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; - const int i0 = threadIdx.y*mma_A::I; - static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); + const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); - mma_A A[2]; - int scA[mma_C::ne/2][2]; - int mA[mma_C::ne/2][2]; - half2 dmA[mma_C::ne/2]; -#pragma unroll - for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 4) { -#pragma unroll - for (int l = 0; l < mma_A::ne; ++l) { - const int i = i0 + mma_A::get_i(l); - const int k = k0 + mma_A::get_k(l); + mma_A A[ntx][2]; + int scA[ntx][mma_C::ne/2][2]; + int mA[ntx][mma_C::ne/2][2]; + half2 dmA[ntx][mma_C::ne/2]; - A[kvdr/4].x[l] = (x_qs[i*(WARP_SIZE + 1) + k] >> kvdr) & 0x0F0F0F0F; +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 8) { + A[n][kvdr/4 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_K + k0, MMQ_MMA_TILE_X_K_Q4_K); + +#pragma unroll + for (int l = 0; l < mma_A::ne; ++l) { + A[n][kvdr/4 + 1].x[l] = (A[n][kvdr/4 + 0].x[l] >> 4) & 0x0F0F0F0F; + A[n][kvdr/4 + 0].x[l] &= 0x0F0F0F0F; + } + } + +#pragma unroll + for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 4) { +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_A::I + mma_C::get_i(2*l); + + const uint8_t * sc = ((const uint8_t *) &x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + k0/16]) + 2 * ((k0 % 16) / 8); + const uint8_t * m = sc + 8; + + scA[n][l][kvdr/4] = sc[kvdr/4]; + mA[n][l][kvdr/4] = m[kvdr/4]; + } } #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + mma_C::get_i(2*l); + const int i = i0 + n*mma_A::I + mma_C::get_i(2*l); - const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8); - const uint8_t * m = sc + 8; - - scA[l][kvdr/4] = sc[kvdr/4]; - mA[l][kvdr/4] = m[kvdr/4]; + dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_K + k0/QI4_K]; } } #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + mma_C::get_i(2*l); - - dmA[l] = x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + k0/QI5_K]; - } + for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { + float tmpd[ntx][mma_C::ne] = {{0.0f}}; + float tmpm[ntx][mma_C::ne] = {{0.0f}}; #pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) { - float tmpd[mma_C::ne] = {0.0f}; - float tmpm[mma_C::ne] = {0.0f}; - -#pragma unroll - for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) { - mma_C C; + for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 4) { mma_B B; half2 dsB[mma_C::ne/2]; -#pragma unroll - for (int l = 0; l < mma_B::ne; ++l) { - const int j = j0 + mma_B::get_j(l); - const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE; + B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0 + 2*kvdr) % WARP_SIZE, MMQ_TILE_Y_K); - B.x[l] = y_qs[j*MMQ_TILE_Y_K + k]; - } #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int j = j0 + mma_C::get_j(l); @@ -1334,29 +1551,46 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma( dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)]; } - C.mma_K8(A[kvdr/4], B); +#pragma unroll + for (int n = 0; n < ntx; ++n) { + mma_C C; + C.mma_K8(A[n][kvdr/4], B); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - tmpd[l] += (C.x[l]*scA[l/2][kvdr/4]) * __low2float(dsB[l%2]); - tmpm[l] += mA[l/2][kvdr/4] * __high2float(dsB[l%2]); + for (int l = 0; l < mma_C::ne; ++l) { + tmpd[n][l] += (C.x[l]*scA[n][l/2][kvdr/4]) * __low2float(dsB[l%2]); + tmpm[n][l] += mA[n][l/2][kvdr/4] * __high2float(dsB[l%2]); + } } } #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l]; + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int l = 0; l < mma_C::ne; ++l) { + sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA[n][l/2])*tmpd[n][l] - __high2float(dmA[n][l/2])*tmpm[n][l]; + } } } #else - GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); NO_DEVICE_CODE; #endif // INT8_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q5_K( - const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm, - int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2); + int * x_sc = (int *) (x_dm + WARP_SIZE/QI5_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + txs.qs); + int * x_sc = (int *) (x_dm + txs.dm); +#endif // INT8_MMA_AVAILABLE const int kbx = 0; // threadIdx.x / QI5_K const int kqsx = threadIdx.x; // threadIdx.x % QI5_K @@ -1383,8 +1617,13 @@ template static __device__ __forceinlin const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0; const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + (QI5_K/4); - x_qs[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0; - x_qs[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1; +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q5_K + kq0] = ql0 | qh0; + x_qs[i*MMQ_MMA_TILE_X_K_Q5_K + kq1] = ql1 | qh1; +#else + x_qs[i*(2*WARP_SIZE + 1) + kq0] = ql0 | qh0; + x_qs[i*(2*WARP_SIZE + 1) + kq1] = ql1 | qh1; +#endif // INT8_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256 @@ -1400,7 +1639,11 @@ template static __device__ __forceinlin const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbxd; - x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm; +#ifdef INT8_MMA_AVAILABLE + x_dm[i*MMQ_MMA_TILE_X_K_Q5_K + kbxd] = bxi->dm; +#else + x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + kbxd] = bxi->dm; +#endif // INT8_MMA_AVAILABLE } #pragma unroll @@ -1421,17 +1664,24 @@ template static __device__ __forceinlin int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits - x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8; +#ifdef INT8_MMA_AVAILABLE + x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + ksc] = scales8; +#else + x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8; +#endif // INT8_MMA_AVAILABLE } } template static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - const int * y_qs = (const int *) y + 4; - const half2 * y_ds = (const half2 *) y; + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; + const int * x_sc = (const int *) x_dm + txs.dm; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { @@ -1452,71 +1702,70 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { #ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K8 mma_A; typedef mma_int_B_J8K8 mma_B; typedef mma_int_C_I16J8 mma_C; + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2; + const int * x_sc = (const int *) x_dm + WARP_SIZE/QI5_K; const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; - const int i0 = threadIdx.y*mma_A::I; - static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); + const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); - mma_A A[2]; - int scA[mma_C::ne/2][2]; - int mA[mma_C::ne/2][2]; - half2 dmA[mma_C::ne/2]; -#pragma unroll - for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) { -#pragma unroll - for (int l = 0; l < mma_A::ne; ++l) { - const int i = i0 + mma_A::get_i(l); - const int k = QR5_K*k0 + QR5_K*kvdr + mma_A::get_k(l); + mma_A A[ntx][2]; + int scA[ntx][mma_C::ne/2][2]; + int mA[ntx][mma_C::ne/2][2]; + half2 dmA[ntx][mma_C::ne/2]; - A[kvdr/4].x[l] = x_qs[i*(QR5_K*WARP_SIZE + 1) + k]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) { + A[n][kvdr/4].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_K + (QR5_K*k0 + QR5_K*kvdr), MMQ_MMA_TILE_X_K_Q5_K); + +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); + + const uint8_t * sc = ((const uint8_t *) &x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + k0/16]) + 2 * ((k0 % 16) / 8); + const uint8_t * m = sc + 8; + + scA[n][l][kvdr/4] = sc[kvdr/4]; + mA[n][l][kvdr/4] = m[kvdr/4]; + } } -#pragma unroll + #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + mma_C::get_i(2*l); + const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); - const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8); - const uint8_t * m = sc + 8; - - scA[l][kvdr/4] = sc[kvdr/4]; - mA[l][kvdr/4] = m[kvdr/4]; + dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q5_K + k0/QI5_K]; } } #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + mma_C::get_i(2*l); - - dmA[l] = x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + k0/QI5_K]; - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) { - float tmpd[mma_C::ne] = {0.0f}; - float tmpm[mma_C::ne] = {0.0f}; + for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { + float tmpd[ntx][mma_C::ne] = {{0.0f}}; + float tmpm[ntx][mma_C::ne] = {{0.0f}}; #pragma unroll for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) { - mma_C C; mma_B B; half2 dsB[mma_C::ne/2]; -#pragma unroll - for (int l = 0; l < mma_B::ne; ++l) { - const int j = j0 + mma_B::get_j(l); - const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE; + B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0 + 2*kvdr) % WARP_SIZE, MMQ_TILE_Y_K); - B.x[l] = y_qs[j*MMQ_TILE_Y_K + k]; - } #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int j = j0 + mma_C::get_j(l); @@ -1524,29 +1773,46 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma( dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)]; } - C.mma_K8(A[kvdr/4], B); +#pragma unroll + for (int n = 0; n < ntx; ++n) { + mma_C C; + C.mma_K8(A[n][kvdr/4], B); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - tmpd[l] += (C.x[l]*scA[l/2][kvdr/4]) * __low2float(dsB[l%2]); - tmpm[l] += mA[l/2][kvdr/4] * __high2float(dsB[l%2]); + for (int l = 0; l < mma_C::ne; ++l) { + tmpd[n][l] += (C.x[l]*scA[n][l/2][kvdr/4]) * __low2float(dsB[l%2]); + tmpm[n][l] += mA[n][l/2][kvdr/4] * __high2float(dsB[l%2]); + } } } #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l]; + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int l = 0; l < mma_C::ne; ++l) { + sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA[n][l/2])*tmpd[n][l] - __high2float(dmA[n][l/2])*tmpm[n][l]; + } } } #else - GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); NO_DEVICE_CODE; #endif // INT8_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q6_K( - const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm, - int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); + int * x_sc = (int *) (x_df + WARP_SIZE/QI6_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); + int * x_sc = (int *) (x_df + txs.dm); +#endif // INT8_MMA_AVAILABLE const int kbx = 0; // threadIdx.x / QI6_K const int kqsx = threadIdx.x; // threadIdx.x % QI6_K @@ -1573,13 +1839,17 @@ template static __device__ __forceinlin const int kq0 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + 0; const int kq1 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + (QI6_K/2); - x_qs[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); - x_qs[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020); + x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020); +#else + x_qs[i*(2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); + x_qs[i*(2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); +#endif // INT8_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256 const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256 - float * x_dmf = (float *) x_dm; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) { @@ -1591,7 +1861,11 @@ template static __device__ __forceinlin const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd; - x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d; +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q6_K + kbxd] = bxi->d; +#else + x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + kbxd] = bxi->d; +#endif // INT8_MMA_AVAILABLE } #pragma unroll @@ -1604,18 +1878,24 @@ template static __device__ __forceinlin const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4; - x_sc[i * (WARP_SIZE/8) + i / 8 + threadIdx.x % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, threadIdx.x % (QI6_K/8)); +#ifdef INT8_MMA_AVAILABLE + x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, threadIdx.x % (QI6_K/8)); +#else + x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, threadIdx.x % (QI6_K/8)); +#endif // INT8_MMA_AVAILABLE } } template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - const float * x_dmf = (const float *) x_dm; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y); + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; + const int * x_sc = (const int *) x_df + txs.dm; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { @@ -1629,80 +1909,77 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq( &x_qs[i*(QR6_K*WARP_SIZE + 1) + QR6_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR6_K*k0) % WARP_SIZE], sc, - x_dmf[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + ((QR6_K*k0) % WARP_SIZE)/QI8_1]); + x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + ((QR6_K*k0) % WARP_SIZE)/QI8_1]); } } } template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( - const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, - const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { #ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K4 mma_A; typedef mma_int_B_J8K4 mma_B; typedef mma_int_C_I16J8 mma_C; - const float * x_df = (const float *) x_dm; + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + WARP_SIZE*2; + const int * x_sc = (const int *) x_df + WARP_SIZE/QI6_K; const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; - const int i0 = threadIdx.y*mma_A::I; -#ifdef INT8_MMA_AVAILABLE - static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); -#endif // INT8_MMA_AVAILABLE + const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); - mma_A A[4]; - int scA[mma_C::ne/2][4]; - float dA[mma_C::ne/2]; -#pragma unroll - for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) { -#pragma unroll - for (int l = 0; l < mma_A::ne; ++l) { - const int i = i0 + mma_A::get_i(l); - const int k = QR6_K*k0 + QR6_K*kvdr + mma_A::get_k(l); + mma_A A[ntx][4]; + int scA[ntx][mma_C::ne/2][4]; + float dA[ntx][mma_C::ne/2]; - A[kvdr/2 + 0].x[l] = x_qs[i*(QR6_K*WARP_SIZE + 1) + k + 0]; - A[kvdr/2 + 1].x[l] = x_qs[i*(QR6_K*WARP_SIZE + 1) + k + mma_A::K]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) { + A[n][kvdr/2 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (QR6_K*k0 + QR6_K*kvdr + 0), MMQ_MMA_TILE_X_K_Q6_K); + A[n][kvdr/2 + 1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (QR6_K*k0 + QR6_K*kvdr + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K); + +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); + + const int8_t * sc = ((const int8_t *) &x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/8]); + + scA[n][l][kvdr/2 + 0] = sc[kvdr/2 + 0]; + scA[n][l][kvdr/2 + 1] = sc[kvdr/2 + 1]; + } } #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + mma_C::get_i(2*l); + const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); - const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]); - - scA[l][kvdr/2 + 0] = sc[kvdr/2 + 0]; - scA[l][kvdr/2 + 1] = sc[kvdr/2 + 1]; + dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K + k0/QI6_K]; } } #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + mma_C::get_i(2*l); - - dA[l] = x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + k0/QI6_K]; - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) { - float tmp[mma_C::ne] = {0.0f}; + for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { + float tmp[ntx][mma_C::ne] = {{0.0f}}; #pragma unroll for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) { - mma_C C[2]; mma_B B[2]; float dB[mma_C::ne/2]; -#pragma unroll - for (int l = 0; l < mma_B::ne; ++l) { - const int j = j0 + mma_B::get_j(l); - const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE; + const int k0B = (2*k0 + 2*kvdr) % WARP_SIZE; + B[0].load(y_qs + j0*MMQ_TILE_Y_K + 0 + k0B, MMQ_TILE_Y_K); + B[1].load(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k0B, MMQ_TILE_Y_K); - B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0]; - B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K]; - } #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int j = j0 + mma_C::get_j(l); @@ -1710,22 +1987,29 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( dB[l] = y_df[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)]; } - C[0].mma_K4(A[kvdr/2 + 0], B[0]); - C[1].mma_K4(A[kvdr/2 + 1], B[1]); +#pragma unroll + for (int n = 0; n < ntx; ++n) { + mma_C C[2]; + C[0].mma_K4(A[n][kvdr/2 + 0], B[0]); + C[1].mma_K4(A[n][kvdr/2 + 1], B[1]); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - tmp[l] += (C[0].x[l]*scA[l/2][kvdr/2 + 0] + C[1].x[l]*scA[l/2][kvdr/2 + 1])*dB[l%2]; + for (int l = 0; l < mma_C::ne; ++l) { + tmp[n][l] += (C[0].x[l]*scA[n][l/2][kvdr/2 + 0] + C[1].x[l]*scA[n][l/2][kvdr/2 + 1])*dB[l%2]; + } } } #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_B::J)*mma_C::ne + l] += tmp[l]*dA[l/2]; + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int l = 0; l < mma_C::ne; ++l) { + sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp[n][l]*dA[n][l/2]; + } } } #else - GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); NO_DEVICE_CODE; #endif // INT8_MMA_AVAILABLE } @@ -1761,28 +2045,35 @@ static __device__ __forceinline__ void mmq_write_back_mma( typedef mma_int_C_I16J8 mma_C; - const int i0 = threadIdx.y*mma_C::I; + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + + const int i0 = (threadIdx.y / ntx) * (ntx*mma_C::I); #ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y"); #endif // INT8_MMA_AVAILABLE #pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += mma_C::J) { + for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - const int j = j0 + mma_C::get_j(l); + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int l = 0; l < mma_C::ne; ++l) { + const int j = j0 + (threadIdx.y % ntx) * mma_C::J + mma_C::get_j(l); - if (j > j_max) { - continue; + if (j > j_max) { + continue; + } + + const int i = i0 + n*mma_C::I + mma_C::get_i(l); + + if (need_check && i > i_max) { + continue; + } + + dst[j*stride + i] = sum[(j0/mma_C::J + n)*mma_C::ne + l]; } - - const int i = i0 + mma_C::get_i(l); - - if (need_check && i > i_max) { - continue; - } - - dst[j*stride + i] = sum[(j0/mma_C::J)*mma_C::ne + l]; } } } @@ -1910,6 +2201,10 @@ static __device__ void mul_mat_q_process_tile( constexpr int vdr = mmq_type_traits::vdr; constexpr load_tiles_mmq_t load_tiles = mmq_type_traits::load_tiles; + extern __shared__ char data_mul_mat_q[]; + int * tile_y = (int *) data_mul_mat_q; + int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE); + #ifdef INT8_MMA_AVAILABLE constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_mma; constexpr mmq_write_back_t write_back = mmq_write_back_mma; @@ -1918,14 +2213,6 @@ static __device__ void mul_mat_q_process_tile( constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; #endif // INT8_MMA_AVAILABLE - constexpr tile_x_sizes txs = get_tile_x_sizes_device(type); - - extern __shared__ char data_mul_mat_q[]; - int * tile_x_qs = (int *) data_mul_mat_q; - half2 * tile_x_dm = (half2 *) (tile_x_qs + txs.qs); - int * tile_x_sc = (int *) (tile_x_dm + txs.dm); - int * tile_y = (int *) (tile_x_sc + txs.sc); // [mmq_x * (WARP_SIZE + WARP_SIZE/QI8_1)] - constexpr int blocks_per_warp = WARP_SIZE / qi; float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f}; @@ -1937,7 +2224,7 @@ static __device__ void mul_mat_q_process_tile( for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_warp) { - load_tiles(x, tile_x_qs, tile_x_dm, tile_x_sc, stride01*it*mmq_y + kb0, tile_x_max_i, stride01); + load_tiles(x, tile_x, stride01*it*mmq_y + kb0, tile_x_max_i, stride01); #pragma unroll for (int kr = 0; kr < qr; ++kr) { @@ -1953,7 +2240,7 @@ static __device__ void mul_mat_q_process_tile( // #pragma unroll // unrolling this loop causes too much register pressure for (int k0 = kr*WARP_SIZE/qr; k0 < (kr+1)*WARP_SIZE/qr; k0 += vdr) { - vec_dot(tile_x_qs, tile_x_dm, tile_x_sc, tile_y, sum, k0); + vec_dot(tile_x, tile_y, sum, k0); } __syncthreads(); @@ -1987,7 +2274,7 @@ static __global__ void mul_mat_q( const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) { // Skip unused template specializations for faster compilation: - if (mmq_x > get_mmq_x_max_device()) { + if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) { NO_DEVICE_CODE; return; } @@ -2139,11 +2426,12 @@ struct mmq_args { int64_t ne0; }; -static int mmq_get_shmem(const ggml_type type, const int mmq_x, const int mmq_y) { - const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y); - - const int shmem_x = txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); - const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2); +template +static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) { + const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y); + const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type); + const int shmem_x = int8_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); + const int shmem_y = mmq_x*sizeof(block_q8_1_mmq); return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int)); } @@ -2156,7 +2444,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1); - const int shmem = mmq_get_shmem(type, mmq_x, mmq_y); + const int shmem = mmq_get_shmem(mmq_x, mmq_y, cc); #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; @@ -2225,12 +2513,17 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda int nparts_best = INT_MAX; for (int mmq_x = 8; mmq_x <= mmq_x_max && nparts_best > 1; mmq_x += 8) { + const int granularity = mmq_get_granularity_host(mmq_x, cc); + + if (mmq_x % granularity != 0 || mmq_get_shmem(mmq_x, mmq_y, cc) > smpbo) { + continue; + } + const int ntiles_x = (args.ne11 + mmq_x - 1) / mmq_x; const int nwaves_xy_tiling = ntiles_x*block_num_y; - const int nparts = use_stream_k ? ntiles_x : nwaves_xy_tiling; - if (nparts < nparts_best && mmq_get_shmem(type, mmq_x, mmq_y) <= smpbo) { + if (nparts < nparts_best) { mmq_x_best = mmq_x; nparts_best = nparts; } @@ -2314,4 +2607,4 @@ void ggml_cuda_op_mul_mat_q( const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_row_size, cudaStream_t stream); -bool ggml_cuda_supports_mmq(enum ggml_type type); +bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11); diff --git a/ggml-cuda/mmvq.cuh b/ggml-cuda/mmvq.cuh index 88c42c4b7..d9e42fdd6 100644 --- a/ggml-cuda/mmvq.cuh +++ b/ggml-cuda/mmvq.cuh @@ -1,5 +1,7 @@ #include "common.cuh" +#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels. + void ggml_cuda_op_mul_mat_vec_q( ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index e5ddf4a34..db045336f 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -4620,7 +4620,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { // KQV single-batch ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst); - } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { + } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { // KQ + KQV multi-batch ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst); } else if (use_dequantize_mul_mat_vec) { diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index d266fbd43..222a2d137 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -49,6 +49,7 @@ class Keys: EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale" POOLING_TYPE = "{arch}.pooling_type" LOGIT_SCALE = "{arch}.logit_scale" + DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id" class Attention: HEAD_COUNT = "{arch}.attention.head_count" @@ -62,6 +63,7 @@ class Keys: CAUSAL = "{arch}.attention.causal" Q_LORA_RANK = "{arch}.attention.q_lora_rank" KV_LORA_RANK = "{arch}.attention.kv_lora_rank" + REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count" class Rope: DIMENSION_COUNT = "{arch}.rope.dimension_count" @@ -73,6 +75,11 @@ class Keys: SCALING_FINETUNED = "{arch}.rope.scaling.finetuned" SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier" + class Split: + LLM_KV_SPLIT_NO = "split.no" + LLM_KV_SPLIT_COUNT = "split.count" + LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count" + class SSM: CONV_KERNEL = "{arch}.ssm.conv_kernel" INNER_SIZE = "{arch}.ssm.inner_size" @@ -80,33 +87,35 @@ class Keys: TIME_STEP_RANK = "{arch}.ssm.time_step_rank" class Tokenizer: - MODEL = "tokenizer.ggml.model" - PRE = "tokenizer.ggml.pre" - LIST = "tokenizer.ggml.tokens" - TOKEN_TYPE = "tokenizer.ggml.token_type" - TOKEN_TYPE_COUNT = "tokenizer.ggml.token_type_count" # for BERT-style token types - SCORES = "tokenizer.ggml.scores" - MERGES = "tokenizer.ggml.merges" - BOS_ID = "tokenizer.ggml.bos_token_id" - EOS_ID = "tokenizer.ggml.eos_token_id" - UNK_ID = "tokenizer.ggml.unknown_token_id" - SEP_ID = "tokenizer.ggml.seperator_token_id" - PAD_ID = "tokenizer.ggml.padding_token_id" - CLS_ID = "tokenizer.ggml.cls_token_id" - MASK_ID = "tokenizer.ggml.mask_token_id" - ADD_BOS = "tokenizer.ggml.add_bos_token" - ADD_EOS = "tokenizer.ggml.add_eos_token" - ADD_PREFIX = "tokenizer.ggml.add_space_prefix" - HF_JSON = "tokenizer.huggingface.json" - RWKV = "tokenizer.rwkv.world" - CHAT_TEMPLATE = "tokenizer.chat_template" - CHAT_TEMPLATE_N = "tokenizer.chat_template.{name}" - CHAT_TEMPLATES = "tokenizer.chat_templates" + MODEL = "tokenizer.ggml.model" + PRE = "tokenizer.ggml.pre" + LIST = "tokenizer.ggml.tokens" + TOKEN_TYPE = "tokenizer.ggml.token_type" + TOKEN_TYPE_COUNT = "tokenizer.ggml.token_type_count" # for BERT-style token types + SCORES = "tokenizer.ggml.scores" + MERGES = "tokenizer.ggml.merges" + BOS_ID = "tokenizer.ggml.bos_token_id" + EOS_ID = "tokenizer.ggml.eos_token_id" + UNK_ID = "tokenizer.ggml.unknown_token_id" + SEP_ID = "tokenizer.ggml.seperator_token_id" + PAD_ID = "tokenizer.ggml.padding_token_id" + CLS_ID = "tokenizer.ggml.cls_token_id" + MASK_ID = "tokenizer.ggml.mask_token_id" + ADD_BOS = "tokenizer.ggml.add_bos_token" + ADD_EOS = "tokenizer.ggml.add_eos_token" + ADD_PREFIX = "tokenizer.ggml.add_space_prefix" + REMOVE_EXTRA_WS = "tokenizer.ggml.remove_extra_whitespaces" + PRECOMPILED_CHARSMAP = "tokenizer.ggml.precompiled_charsmap" + HF_JSON = "tokenizer.huggingface.json" + RWKV = "tokenizer.rwkv.world" + CHAT_TEMPLATE = "tokenizer.chat_template" + CHAT_TEMPLATE_N = "tokenizer.chat_template.{name}" + CHAT_TEMPLATES = "tokenizer.chat_templates" # FIM/Infill special tokens constants - PREFIX_ID = "tokenizer.ggml.prefix_token_id" - SUFFIX_ID = "tokenizer.ggml.suffix_token_id" - MIDDLE_ID = "tokenizer.ggml.middle_token_id" - EOT_ID = "tokenizer.ggml.eot_token_id" + PREFIX_ID = "tokenizer.ggml.prefix_token_id" + SUFFIX_ID = "tokenizer.ggml.suffix_token_id" + MIDDLE_ID = "tokenizer.ggml.middle_token_id" + EOT_ID = "tokenizer.ggml.eot_token_id" # @@ -115,94 +124,123 @@ class Keys: class MODEL_ARCH(IntEnum): - LLAMA = auto() - FALCON = auto() - BAICHUAN = auto() - GROK = auto() - GPT2 = auto() - GPTJ = auto() - GPTNEOX = auto() - MPT = auto() - STARCODER = auto() - REFACT = auto() - BERT = auto() - NOMIC_BERT = auto() + LLAMA = auto() + FALCON = auto() + BAICHUAN = auto() + GROK = auto() + GPT2 = auto() + GPTJ = auto() + GPTNEOX = auto() + MPT = auto() + STARCODER = auto() + REFACT = auto() + BERT = auto() + NOMIC_BERT = auto() JINA_BERT_V2 = auto() - BLOOM = auto() - STABLELM = auto() - QWEN = auto() - QWEN2 = auto() - QWEN2MOE = auto() - PHI2 = auto() - PHI3 = auto() - PLAMO = auto() - CODESHELL = auto() - ORION = auto() - INTERNLM2 = auto() - MINICPM = auto() - GEMMA = auto() - STARCODER2 = auto() - MAMBA = auto() - XVERSE = auto() - COMMAND_R = auto() - DBRX = auto() - OLMO = auto() - ARCTIC = auto() - DEEPSEEK2 = auto() - BITNET = auto() + BLOOM = auto() + STABLELM = auto() + QWEN = auto() + QWEN2 = auto() + QWEN2MOE = auto() + PHI2 = auto() + PHI3 = auto() + PLAMO = auto() + CODESHELL = auto() + ORION = auto() + INTERNLM2 = auto() + MINICPM = auto() + GEMMA = auto() + STARCODER2 = auto() + MAMBA = auto() + XVERSE = auto() + COMMAND_R = auto() + DBRX = auto() + OLMO = auto() + ARCTIC = auto() + DEEPSEEK2 = auto() + BITNET = auto() + T5 = auto() class MODEL_TENSOR(IntEnum): - TOKEN_EMBD = auto() - TOKEN_EMBD_NORM = auto() - TOKEN_TYPES = auto() - POS_EMBD = auto() - OUTPUT = auto() - OUTPUT_NORM = auto() - ROPE_FREQS = auto() - ROPE_FACTORS_LONG = auto() - ROPE_FACTORS_SHORT = auto() - ATTN_Q = auto() - ATTN_K = auto() - ATTN_V = auto() - ATTN_QKV = auto() - ATTN_OUT = auto() - ATTN_NORM = auto() - ATTN_NORM_2 = auto() - ATTN_OUT_NORM = auto() - ATTN_ROT_EMBD = auto() - FFN_GATE_INP = auto() - FFN_GATE_INP_SHEXP = auto() - FFN_NORM = auto() - FFN_GATE = auto() - FFN_DOWN = auto() - FFN_UP = auto() - FFN_ACT = auto() - FFN_NORM_EXP = auto() - FFN_GATE_EXP = auto() - FFN_DOWN_EXP = auto() - FFN_UP_EXP = auto() - FFN_GATE_SHEXP = auto() - FFN_DOWN_SHEXP = auto() - FFN_UP_SHEXP = auto() - ATTN_Q_NORM = auto() - ATTN_K_NORM = auto() - LAYER_OUT_NORM = auto() - SSM_IN = auto() - SSM_CONV1D = auto() - SSM_X = auto() - SSM_DT = auto() - SSM_A = auto() - SSM_D = auto() - SSM_OUT = auto() - ATTN_Q_A = auto() - ATTN_Q_B = auto() - ATTN_KV_A_MQA = auto() - ATTN_KV_B = auto() - ATTN_Q_A_NORM = auto() - ATTN_KV_A_NORM = auto() - FFN_SUB_NORM = auto() - ATTN_SUB_NORM = auto() + TOKEN_EMBD = auto() + TOKEN_EMBD_NORM = auto() + TOKEN_TYPES = auto() + POS_EMBD = auto() + OUTPUT = auto() + OUTPUT_NORM = auto() + ROPE_FREQS = auto() + ROPE_FACTORS_LONG = auto() + ROPE_FACTORS_SHORT = auto() + ATTN_Q = auto() + ATTN_K = auto() + ATTN_V = auto() + ATTN_QKV = auto() + ATTN_OUT = auto() + ATTN_NORM = auto() + ATTN_NORM_2 = auto() + ATTN_OUT_NORM = auto() + ATTN_ROT_EMBD = auto() + FFN_GATE_INP = auto() + FFN_GATE_INP_SHEXP = auto() + FFN_NORM = auto() + FFN_GATE = auto() + FFN_DOWN = auto() + FFN_UP = auto() + FFN_ACT = auto() + FFN_NORM_EXP = auto() + FFN_GATE_EXP = auto() + FFN_DOWN_EXP = auto() + FFN_UP_EXP = auto() + FFN_GATE_SHEXP = auto() + FFN_DOWN_SHEXP = auto() + FFN_UP_SHEXP = auto() + ATTN_Q_NORM = auto() + ATTN_K_NORM = auto() + LAYER_OUT_NORM = auto() + SSM_IN = auto() + SSM_CONV1D = auto() + SSM_X = auto() + SSM_DT = auto() + SSM_A = auto() + SSM_D = auto() + SSM_OUT = auto() + ATTN_Q_A = auto() + ATTN_Q_B = auto() + ATTN_KV_A_MQA = auto() + ATTN_KV_B = auto() + ATTN_Q_A_NORM = auto() + ATTN_KV_A_NORM = auto() + FFN_SUB_NORM = auto() + ATTN_SUB_NORM = auto() + DEC_ATTN_NORM = auto() + DEC_ATTN_Q = auto() + DEC_ATTN_K = auto() + DEC_ATTN_V = auto() + DEC_ATTN_OUT = auto() + DEC_ATTN_REL_B = auto() + DEC_CROSS_ATTN_NORM = auto() + DEC_CROSS_ATTN_Q = auto() + DEC_CROSS_ATTN_K = auto() + DEC_CROSS_ATTN_V = auto() + DEC_CROSS_ATTN_OUT = auto() + DEC_CROSS_ATTN_REL_B = auto() + DEC_FFN_NORM = auto() + DEC_FFN_GATE = auto() + DEC_FFN_DOWN = auto() + DEC_FFN_UP = auto() + DEC_OUTPUT_NORM = auto() + ENC_ATTN_NORM = auto() + ENC_ATTN_Q = auto() + ENC_ATTN_K = auto() + ENC_ATTN_V = auto() + ENC_ATTN_OUT = auto() + ENC_ATTN_REL_B = auto() + ENC_FFN_NORM = auto() + ENC_FFN_GATE = auto() + ENC_FFN_DOWN = auto() + ENC_FFN_UP = auto() + ENC_OUTPUT_NORM = auto() MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -241,59 +279,88 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.ARCTIC: "arctic", MODEL_ARCH.DEEPSEEK2: "deepseek2", MODEL_ARCH.BITNET: "bitnet", + MODEL_ARCH.T5: "t5", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { - MODEL_TENSOR.TOKEN_EMBD: "token_embd", - MODEL_TENSOR.TOKEN_EMBD_NORM: "token_embd_norm", - MODEL_TENSOR.TOKEN_TYPES: "token_types", - MODEL_TENSOR.POS_EMBD: "position_embd", - MODEL_TENSOR.OUTPUT_NORM: "output_norm", - MODEL_TENSOR.OUTPUT: "output", - MODEL_TENSOR.ROPE_FREQS: "rope_freqs", - MODEL_TENSOR.ROPE_FACTORS_LONG: "rope_factors_long", - MODEL_TENSOR.ROPE_FACTORS_SHORT: "rope_factors_short", - MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm", - MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2", - MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv", - MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q", - MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k", - MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v", - MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output", - MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd", - MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm", - MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm", - MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm", - MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp", - MODEL_TENSOR.FFN_GATE_INP_SHEXP: "blk.{bid}.ffn_gate_inp_shexp", - MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", - MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate", - MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", - MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", - MODEL_TENSOR.FFN_GATE_SHEXP: "blk.{bid}.ffn_gate_shexp", - MODEL_TENSOR.FFN_DOWN_SHEXP: "blk.{bid}.ffn_down_shexp", - MODEL_TENSOR.FFN_UP_SHEXP: "blk.{bid}.ffn_up_shexp", - MODEL_TENSOR.FFN_ACT: "blk.{bid}.ffn", - MODEL_TENSOR.FFN_NORM_EXP: "blk.{bid}.ffn_norm_exps", - MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps", - MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps", - MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps", - MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm", - MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in", - MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d", - MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x", - MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt", - MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a", - MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", - MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", - MODEL_TENSOR.ATTN_Q_A: "blk.{bid}.attn_q_a", - MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b", - MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa", - MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b", - MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm", - MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm", - MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm", - MODEL_TENSOR.FFN_SUB_NORM: "blk.{bid}.ffn_sub_norm", + MODEL_TENSOR.TOKEN_EMBD: "token_embd", + MODEL_TENSOR.TOKEN_EMBD_NORM: "token_embd_norm", + MODEL_TENSOR.TOKEN_TYPES: "token_types", + MODEL_TENSOR.POS_EMBD: "position_embd", + MODEL_TENSOR.OUTPUT_NORM: "output_norm", + MODEL_TENSOR.OUTPUT: "output", + MODEL_TENSOR.ROPE_FREQS: "rope_freqs", + MODEL_TENSOR.ROPE_FACTORS_LONG: "rope_factors_long", + MODEL_TENSOR.ROPE_FACTORS_SHORT: "rope_factors_short", + MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm", + MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2", + MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv", + MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q", + MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k", + MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v", + MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output", + MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd", + MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm", + MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm", + MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm", + MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp", + MODEL_TENSOR.FFN_GATE_INP_SHEXP: "blk.{bid}.ffn_gate_inp_shexp", + MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", + MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate", + MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", + MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", + MODEL_TENSOR.FFN_GATE_SHEXP: "blk.{bid}.ffn_gate_shexp", + MODEL_TENSOR.FFN_DOWN_SHEXP: "blk.{bid}.ffn_down_shexp", + MODEL_TENSOR.FFN_UP_SHEXP: "blk.{bid}.ffn_up_shexp", + MODEL_TENSOR.FFN_ACT: "blk.{bid}.ffn", + MODEL_TENSOR.FFN_NORM_EXP: "blk.{bid}.ffn_norm_exps", + MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps", + MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps", + MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps", + MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm", + MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in", + MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d", + MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x", + MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt", + MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a", + MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", + MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", + MODEL_TENSOR.ATTN_Q_A: "blk.{bid}.attn_q_a", + MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b", + MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa", + MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b", + MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm", + MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm", + MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm", + MODEL_TENSOR.FFN_SUB_NORM: "blk.{bid}.ffn_sub_norm", + MODEL_TENSOR.DEC_ATTN_NORM: "dec.blk.{bid}.attn_norm", + MODEL_TENSOR.DEC_ATTN_Q: "dec.blk.{bid}.attn_q", + MODEL_TENSOR.DEC_ATTN_K: "dec.blk.{bid}.attn_k", + MODEL_TENSOR.DEC_ATTN_V: "dec.blk.{bid}.attn_v", + MODEL_TENSOR.DEC_ATTN_OUT: "dec.blk.{bid}.attn_o", + MODEL_TENSOR.DEC_ATTN_REL_B: "dec.blk.{bid}.attn_rel_b", + MODEL_TENSOR.DEC_CROSS_ATTN_NORM: "dec.blk.{bid}.cross_attn_norm", + MODEL_TENSOR.DEC_CROSS_ATTN_Q: "dec.blk.{bid}.cross_attn_q", + MODEL_TENSOR.DEC_CROSS_ATTN_K: "dec.blk.{bid}.cross_attn_k", + MODEL_TENSOR.DEC_CROSS_ATTN_V: "dec.blk.{bid}.cross_attn_v", + MODEL_TENSOR.DEC_CROSS_ATTN_OUT: "dec.blk.{bid}.cross_attn_o", + MODEL_TENSOR.DEC_CROSS_ATTN_REL_B: "dec.blk.{bid}.cross_attn_rel_b", + MODEL_TENSOR.DEC_FFN_NORM: "dec.blk.{bid}.ffn_norm", + MODEL_TENSOR.DEC_FFN_GATE: "dec.blk.{bid}.ffn_gate", + MODEL_TENSOR.DEC_FFN_DOWN: "dec.blk.{bid}.ffn_down", + MODEL_TENSOR.DEC_FFN_UP: "dec.blk.{bid}.ffn_up", + MODEL_TENSOR.DEC_OUTPUT_NORM: "dec.output_norm", + MODEL_TENSOR.ENC_ATTN_NORM: "enc.blk.{bid}.attn_norm", + MODEL_TENSOR.ENC_ATTN_Q: "enc.blk.{bid}.attn_q", + MODEL_TENSOR.ENC_ATTN_K: "enc.blk.{bid}.attn_k", + MODEL_TENSOR.ENC_ATTN_V: "enc.blk.{bid}.attn_v", + MODEL_TENSOR.ENC_ATTN_OUT: "enc.blk.{bid}.attn_o", + MODEL_TENSOR.ENC_ATTN_REL_B: "enc.blk.{bid}.attn_rel_b", + MODEL_TENSOR.ENC_FFN_NORM: "enc.blk.{bid}.ffn_norm", + MODEL_TENSOR.ENC_FFN_GATE: "enc.blk.{bid}.ffn_gate", + MODEL_TENSOR.ENC_FFN_DOWN: "enc.blk.{bid}.ffn_down", + MODEL_TENSOR.ENC_FFN_UP: "enc.blk.{bid}.ffn_up", + MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -829,6 +896,38 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.ATTN_SUB_NORM, MODEL_TENSOR.FFN_SUB_NORM, ], + MODEL_ARCH.T5: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.DEC_ATTN_NORM, + MODEL_TENSOR.DEC_ATTN_Q, + MODEL_TENSOR.DEC_ATTN_K, + MODEL_TENSOR.DEC_ATTN_V, + MODEL_TENSOR.DEC_ATTN_OUT, + MODEL_TENSOR.DEC_ATTN_REL_B, + MODEL_TENSOR.DEC_CROSS_ATTN_NORM, + MODEL_TENSOR.DEC_CROSS_ATTN_Q, + MODEL_TENSOR.DEC_CROSS_ATTN_K, + MODEL_TENSOR.DEC_CROSS_ATTN_V, + MODEL_TENSOR.DEC_CROSS_ATTN_OUT, + MODEL_TENSOR.DEC_CROSS_ATTN_REL_B, + MODEL_TENSOR.DEC_FFN_NORM, + MODEL_TENSOR.DEC_FFN_GATE, + MODEL_TENSOR.DEC_FFN_DOWN, + MODEL_TENSOR.DEC_FFN_UP, + MODEL_TENSOR.DEC_OUTPUT_NORM, + MODEL_TENSOR.ENC_ATTN_NORM, + MODEL_TENSOR.ENC_ATTN_Q, + MODEL_TENSOR.ENC_ATTN_K, + MODEL_TENSOR.ENC_ATTN_V, + MODEL_TENSOR.ENC_ATTN_OUT, + MODEL_TENSOR.ENC_ATTN_REL_B, + MODEL_TENSOR.ENC_FFN_NORM, + MODEL_TENSOR.ENC_FFN_GATE, + MODEL_TENSOR.ENC_FFN_DOWN, + MODEL_TENSOR.ENC_FFN_UP, + MODEL_TENSOR.ENC_OUTPUT_NORM, + ], # TODO } diff --git a/gguf-py/gguf/gguf_reader.py b/gguf-py/gguf/gguf_reader.py index e48bc00c3..20432bd25 100644 --- a/gguf-py/gguf/gguf_reader.py +++ b/gguf-py/gguf/gguf_reader.py @@ -69,6 +69,7 @@ class GGUFReader: # I - same as host, S - swapped byte_order: Literal['I'] | Literal['S'] = 'I' alignment: int = GGUF_DEFAULT_ALIGNMENT + data_offset: int # Note: Internal helper, API may change. gguf_scalar_to_np: dict[GGUFValueType, type[np.generic]] = { @@ -88,9 +89,13 @@ class GGUFReader: def __init__(self, path: os.PathLike[str] | str, mode: Literal['r'] | Literal['r+'] | Literal['c'] = 'r'): self.data = np.memmap(path, mode = mode) offs = 0 + + # Check for GGUF magic if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC: raise ValueError('GGUF magic invalid') offs += 4 + + # Check GGUF version temp_version = self._get(offs, np.uint32) if temp_version[0] & 65535 == 0: # If we get 0 here that means it's (probably) a GGUF file created for @@ -103,12 +108,16 @@ class GGUFReader: self.fields: OrderedDict[str, ReaderField] = OrderedDict() self.tensors: list[ReaderTensor] = [] offs += self._push_field(ReaderField(offs, 'GGUF.version', [temp_version], [0], [GGUFValueType.UINT32])) + + # Check tensor count and kv count temp_counts = self._get(offs, np.uint64, 2) offs += self._push_field(ReaderField(offs, 'GGUF.tensor_count', [temp_counts[:1]], [0], [GGUFValueType.UINT64])) offs += self._push_field(ReaderField(offs, 'GGUF.kv_count', [temp_counts[1:]], [0], [GGUFValueType.UINT64])) tensor_count, kv_count = temp_counts offs = self._build_fields(offs, kv_count) - offs, tensors_fields = self._build_tensors_fields(offs, tensor_count) + + # Build Tensor Info Fields + offs, tensors_fields = self._build_tensor_info(offs, tensor_count) new_align = self.fields.get('general.alignment') if new_align is not None: if new_align.types != [GGUFValueType.UINT32]: @@ -117,6 +126,7 @@ class GGUFReader: padding = offs % self.alignment if padding != 0: offs += self.alignment - padding + self.data_offset = offs self._build_tensors(offs, tensors_fields) _DT = TypeVar('_DT', bound = npt.DTypeLike) @@ -193,18 +203,29 @@ class GGUFReader: # We can't deal with this one. raise ValueError('Unknown/unhandled field type {gtype}') - def _get_tensor(self, orig_offs: int) -> ReaderField: + def _get_tensor_info_field(self, orig_offs: int) -> ReaderField: offs = orig_offs + + # Get Tensor Name name_len, name_data = self._get_str(offs) offs += int(name_len.nbytes + name_data.nbytes) + + # Get Tensor Dimensions Count n_dims = self._get(offs, np.uint32) offs += int(n_dims.nbytes) + + # Get Tensor Dimension Array dims = self._get(offs, np.uint64, n_dims[0]) offs += int(dims.nbytes) + + # Get Tensor Encoding Scheme Type raw_dtype = self._get(offs, np.uint32) offs += int(raw_dtype.nbytes) + + # Get Tensor Offset offset_tensor = self._get(offs, np.uint64) offs += int(offset_tensor.nbytes) + return ReaderField( orig_offs, str(bytes(name_data), encoding = 'utf-8'), @@ -233,10 +254,10 @@ class GGUFReader: offs += field_size return offs - def _build_tensors_fields(self, offs: int, count: int) -> tuple[int, list[ReaderField]]: + def _build_tensor_info(self, offs: int, count: int) -> tuple[int, list[ReaderField]]: tensor_fields = [] for _ in range(count): - field = self._get_tensor(offs) + field = self._get_tensor_info_field(offs) offs += sum(int(part.nbytes) for part in field.parts) tensor_fields.append(field) return offs, tensor_fields diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index a697f657b..9869f6fe3 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -7,6 +7,7 @@ import struct import tempfile from dataclasses import dataclass from enum import Enum, auto +from pathlib import Path from io import BufferedWriter from typing import IO, Any, Sequence, Mapping from string import ascii_letters, digits @@ -31,6 +32,9 @@ from .quants import quant_shape_from_byte_shape logger = logging.getLogger(__name__) +SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf" + + @dataclass class TensorInfo: shape: Sequence[int] @@ -55,11 +59,11 @@ class WriterState(Enum): class GGUFWriter: - fout: BufferedWriter | None - path: os.PathLike[str] | str | None + fout: list[BufferedWriter] | None + path: Path | None temp_file: tempfile.SpooledTemporaryFile[bytes] | None - tensors: dict[str, TensorInfo] - kv_data: dict[str, GGUFValue] + tensors: list[dict[str, TensorInfo]] + kv_data: list[dict[str, GGUFValue]] state: WriterState _simple_value_packing = { GGUFValueType.UINT8: "B", @@ -76,26 +80,38 @@ class GGUFWriter: } def __init__( - self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, - endianess: GGUFEndian = GGUFEndian.LITTLE, + self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE, + split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False ): self.fout = None - self.path = path + self.path = Path(path) if path else None self.arch = arch self.endianess = endianess self.data_alignment = GGUF_DEFAULT_ALIGNMENT self.use_temp_file = use_temp_file self.temp_file = None - self.tensors = dict() - self.kv_data = dict() + self.tensors = [{}] + self.kv_data = [{}] + self.split_max_tensors = split_max_tensors + self.split_max_size = split_max_size + self.dry_run = dry_run + self.small_first_shard = small_first_shard logger.info("gguf: This GGUF file is for {0} Endian only".format( "Big" if self.endianess == GGUFEndian.BIG else "Little", )) self.state = WriterState.NO_FILE + if self.small_first_shard: + self.tensors.append({}) + self.add_architecture() - def open_output_file(self, path: os.PathLike[str] | str | None = None) -> None: + def format_shard_names(self, path: Path) -> list[Path]: + if len(self.tensors) == 1: + return [path] + return [path.with_name(SHARD_NAME_FORMAT.format(path.stem, i + 1, len(self.tensors))) for i in range(len(self.tensors))] + + def open_output_file(self, path: Path | None = None) -> None: if self.state is WriterState.EMPTY and self.fout is not None and (path is None or path == self.path): # allow calling this multiple times as long as the path is the same return @@ -106,22 +122,58 @@ class GGUFWriter: self.path = path if self.path is not None: - if self.fout is not None: - self.fout.close() - self.fout = open(self.path, "wb") + filenames = self.print_plan() + self.fout = [open(filename, "wb") for filename in filenames] self.state = WriterState.EMPTY - def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None: + def print_plan(self) -> list[Path]: + logger.info("Writing the following files:") + assert self.path is not None + filenames = self.format_shard_names(self.path) + assert len(filenames) == len(self.tensors) + for name, tensors in zip(filenames, self.tensors): + logger.info(f"{name}: n_tensors = {len(tensors)}, total_size = {GGUFWriter.format_n_bytes_to_str(sum(ti.nbytes for ti in tensors.values()))}") + + if self.dry_run: + logger.info("Dry run, not writing files") + exit() + + return filenames + + def add_shard_kv_data(self) -> None: + if len(self.tensors) == 1: + return + + total_tensors = sum(len(t) for t in self.tensors) + assert self.fout is not None + total_splits = len(self.fout) + self.kv_data.extend({} for _ in range(len(self.kv_data), total_splits)) + for i, kv_data in enumerate(self.kv_data): + kv_data[Keys.Split.LLM_KV_SPLIT_NO] = GGUFValue(i, GGUFValueType.UINT16) + kv_data[Keys.Split.LLM_KV_SPLIT_COUNT] = GGUFValue(total_splits, GGUFValueType.UINT16) + kv_data[Keys.Split.LLM_KV_SPLIT_TENSORS_COUNT] = GGUFValue(total_tensors, GGUFValueType.INT32) + + def write_header_to_file(self, path: Path | None = None) -> None: + if len(self.tensors) == 1 and (self.split_max_tensors != 0 or self.split_max_size != 0): + logger.warning("Model fails split requirements, not splitting") + self.open_output_file(path) if self.state is not WriterState.EMPTY: raise ValueError(f'Expected output file to be empty, got {self.state}') - self._write_packed(" None: @@ -129,13 +181,15 @@ class GGUFWriter: raise ValueError(f'Expected output file to contain the header, got {self.state}') assert self.fout is not None - kv_data = bytearray() + for fout, kv_data in zip(self.fout, self.kv_data): + kv_bytes = bytearray() - for key, val in self.kv_data.items(): - kv_data += self._pack_val(key, GGUFValueType.STRING, add_vtype=False) - kv_data += self._pack_val(val.value, val.type, add_vtype=True) + for key, val in kv_data.items(): + kv_bytes += self._pack_val(key, GGUFValueType.STRING, add_vtype=False) + kv_bytes += self._pack_val(val.value, val.type, add_vtype=True) + + fout.write(kv_bytes) - self.fout.write(kv_data) self.flush() self.state = WriterState.KV_DATA @@ -144,28 +198,29 @@ class GGUFWriter: raise ValueError(f'Expected output file to contain KV data, got {self.state}') assert self.fout is not None - ti_data = bytearray() - offset_tensor = 0 + for fout, tensors in zip(self.fout, self.tensors): + ti_data = bytearray() + offset_tensor = 0 - for name, ti in self.tensors.items(): - ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False) - n_dims = len(ti.shape) - ti_data += self._pack("I", n_dims) - for i in range(n_dims): - ti_data += self._pack("Q", ti.shape[n_dims - 1 - i]) - ti_data += self._pack("I", ti.dtype) - ti_data += self._pack("Q", offset_tensor) - offset_tensor += GGUFWriter.ggml_pad(ti.nbytes, self.data_alignment) + for name, ti in tensors.items(): + ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False) + n_dims = len(ti.shape) + ti_data += self._pack("I", n_dims) + for j in range(n_dims): + ti_data += self._pack("Q", ti.shape[n_dims - 1 - j]) + ti_data += self._pack("I", ti.dtype) + ti_data += self._pack("Q", offset_tensor) + offset_tensor += GGUFWriter.ggml_pad(ti.nbytes, self.data_alignment) - self.fout.write(ti_data) - self.flush() + fout.write(ti_data) + fout.flush() self.state = WriterState.TI_DATA def add_key_value(self, key: str, val: Any, vtype: GGUFValueType) -> None: - if key in self.kv_data: + if any(key in kv_data for kv_data in self.kv_data): raise ValueError(f'Duplicated key name {key!r}') - self.kv_data[key] = GGUFValue(value=val, type=vtype) + self.kv_data[0][key] = GGUFValue(value=val, type=vtype) def add_uint8(self, key: str, val: int) -> None: self.add_key_value(key,val, GGUFValueType.UINT8) @@ -206,9 +261,6 @@ class GGUFWriter: self.add_key_value(key, val, GGUFValueType.STRING) def add_array(self, key: str, val: Sequence[Any]) -> None: - if not isinstance(val, Sequence): - raise ValueError("Value must be a sequence for array type") - self.add_key_value(key, val, GGUFValueType.ARRAY) @staticmethod @@ -222,7 +274,7 @@ class GGUFWriter: if self.state is not WriterState.NO_FILE: raise ValueError(f'Expected output file to be not yet opened, got {self.state}') - if name in self.tensors: + if any(name in tensors for tensors in self.tensors): raise ValueError(f'Duplicated tensor name {name!r}') if raw_dtype is None: @@ -247,7 +299,18 @@ class GGUFWriter: if tensor_dtype == np.uint8: tensor_shape = quant_shape_from_byte_shape(tensor_shape, raw_dtype) - self.tensors[name] = TensorInfo(shape=tensor_shape, dtype=dtype, nbytes=tensor_nbytes) + # make sure there is at least one tensor before splitting + if len(self.tensors[-1]) > 0: + if ( # split when over tensor limit + self.split_max_tensors != 0 + and len(self.tensors[-1]) >= self.split_max_tensors + ) or ( # split when over size limit + self.split_max_size != 0 + and sum(ti.nbytes for ti in self.tensors[-1].values()) + tensor_nbytes > self.split_max_size + ): + self.tensors.append({}) + + self.tensors[-1][name] = TensorInfo(shape=tensor_shape, dtype=dtype, nbytes=tensor_nbytes) def add_tensor( self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None, @@ -264,7 +327,7 @@ class GGUFWriter: self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype=raw_dtype) if self.temp_file is None: - self.tensors[name].tensor = tensor + self.tensors[-1][name].tensor = tensor return tensor.tofile(self.temp_file) @@ -282,9 +345,24 @@ class GGUFWriter: if self.endianess == GGUFEndian.BIG: tensor.byteswap(inplace=True) - self.write_padding(self.fout, self.fout.tell()) - tensor.tofile(self.fout) - self.write_padding(self.fout, tensor.nbytes) + + file_id = -1 + for i, tensors in enumerate(self.tensors): + if len(tensors) > 0: + file_id = i + break + + fout = self.fout[file_id] + + # pop the first tensor info + # TODO: cleaner way to get the first key + first_tensor_name = [name for name, _ in zip(self.tensors[file_id].keys(), range(1))][0] + ti = self.tensors[file_id].pop(first_tensor_name) + assert ti.nbytes == tensor.nbytes + + self.write_padding(fout, fout.tell()) + tensor.tofile(fout) + self.write_padding(fout, tensor.nbytes) self.state = WriterState.WEIGHTS @@ -293,31 +371,43 @@ class GGUFWriter: assert self.fout is not None - self.write_padding(self.fout, self.fout.tell()) + for fout in self.fout: + self.write_padding(fout, fout.tell()) if self.temp_file is None: + shard_bar = None bar = None if progress: from tqdm import tqdm - total_bytes = sum(t.nbytes for t in self.tensors.values()) + total_bytes = sum(ti.nbytes for t in self.tensors for ti in t.values()) + if len(self.fout) > 1: + shard_bar = tqdm(desc=f"Shard (0/{len(self.fout)})", total=None, unit="byte", unit_scale=True) bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True) - # relying on the fact that Python dicts preserve insertion order (since 3.7) - for ti in self.tensors.values(): - assert ti.tensor is not None # can only iterate once over the tensors - assert ti.tensor.nbytes == ti.nbytes - ti.tensor.tofile(self.fout) - if bar is not None: - bar.update(ti.nbytes) - self.write_padding(self.fout, ti.nbytes) - ti.tensor = None + for i, (fout, tensors) in enumerate(zip(self.fout, self.tensors)): + if shard_bar is not None: + shard_bar.set_description(f"Shard ({i + 1}/{len(self.fout)})") + total = sum(ti.nbytes for ti in tensors.values()) + shard_bar.reset(total=(total if total > 0 else None)) + + # relying on the fact that Python dicts preserve insertion order (since 3.7) + for ti in tensors.values(): + assert ti.tensor is not None # can only iterate once over the tensors + assert ti.tensor.nbytes == ti.nbytes + ti.tensor.tofile(fout) + if shard_bar is not None: + shard_bar.update(ti.nbytes) + if bar is not None: + bar.update(ti.nbytes) + self.write_padding(fout, ti.nbytes) + ti.tensor = None else: self.temp_file.seek(0) - shutil.copyfileobj(self.temp_file, self.fout) + shutil.copyfileobj(self.temp_file, self.fout[0 if not self.small_first_shard else 1]) self.flush() self.temp_file.close() @@ -325,11 +415,13 @@ class GGUFWriter: def flush(self) -> None: assert self.fout is not None - self.fout.flush() + for fout in self.fout: + fout.flush() def close(self) -> None: if self.fout is not None: - self.fout.close() + for fout in self.fout: + fout.close() self.fout = None def add_architecture(self) -> None: @@ -400,6 +492,9 @@ class GGUFWriter: def add_parallel_residual(self, use: bool) -> None: self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use) + def add_decoder_start_token_id(self, id: int) -> None: + self.add_uint32(Keys.LLM.DECODER_START_TOKEN_ID.format(arch=self.arch), id) + def add_head_count(self, count: int) -> None: self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count) @@ -448,6 +543,9 @@ class GGUFWriter: def add_kv_lora_rank(self, length: int) -> None: self.add_uint32(Keys.Attention.KV_LORA_RANK.format(arch=self.arch), length) + def add_relative_attn_buckets_count(self, value: int) -> None: + self.add_uint32(Keys.Attention.REL_BUCKETS_COUNT.format(arch=self.arch), value) + def add_pooling_type(self, value: PoolingType) -> None: self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value) @@ -538,6 +636,12 @@ class GGUFWriter: def add_add_space_prefix(self, value: bool) -> None: self.add_bool(Keys.Tokenizer.ADD_PREFIX, value) + def add_remove_extra_whitespaces(self, value: bool) -> None: + self.add_bool(Keys.Tokenizer.REMOVE_EXTRA_WS, value) + + def add_precompiled_charsmap(self, charsmap: Sequence[bytes]) -> None: + self.add_array(Keys.Tokenizer.PRECOMPILED_CHARSMAP, charsmap) + def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None: if not isinstance(value, str): template_default = None @@ -599,9 +703,12 @@ class GGUFWriter: kv_data += self._pack("Q", len(encoded_val)) kv_data += encoded_val elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and val: - ltype = GGUFValueType.get_type(val[0]) - if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]): - raise ValueError("All items in a GGUF array should be of the same type") + if isinstance(val, bytes): + ltype = GGUFValueType.UINT8 + else: + ltype = GGUFValueType.get_type(val[0]) + if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]): + raise ValueError("All items in a GGUF array should be of the same type") kv_data += self._pack("I", ltype) kv_data += self._pack("Q", len(val)) for item in val: @@ -611,6 +718,13 @@ class GGUFWriter: return kv_data - def _write_packed(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> None: - assert self.fout is not None - self.fout.write(self._pack(fmt, value, skip_pack_prefix)) + @staticmethod + def format_n_bytes_to_str(num: int) -> str: + if num == 0: + return "negligible - metadata only" + fnum = float(num) + for unit in ("", "K", "M", "G"): + if abs(fnum) < 1000.0: + return f"{fnum:3.1f}{unit}" + fnum /= 1000.0 + return f"{fnum:.1f}T - over 1TB, split recommended" diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 350035bd9..7b047f241 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -24,6 +24,7 @@ class TensorNameMap: "backbone.embedding", # mamba "backbone.embeddings", # mamba-hf "transformer.in_out_embed", # Grok + "shared", # t5 ), # Token type embeddings @@ -421,6 +422,120 @@ class TensorNameMap: MODEL_TENSOR.FFN_SUB_NORM: ( "model.layers.{bid}.mlp.ffn_layernorm", # bitnet ), + + MODEL_TENSOR.DEC_ATTN_NORM: ( + "decoder.block.{bid}.layer.0.layer_norm", # t5 + ), + + MODEL_TENSOR.DEC_ATTN_Q: ( + "decoder.block.{bid}.layer.0.SelfAttention.q", # t5 + ), + + MODEL_TENSOR.DEC_ATTN_K: ( + "decoder.block.{bid}.layer.0.SelfAttention.k", # t5 + ), + + MODEL_TENSOR.DEC_ATTN_V: ( + "decoder.block.{bid}.layer.0.SelfAttention.v", # t5 + ), + + MODEL_TENSOR.DEC_ATTN_OUT: ( + "decoder.block.{bid}.layer.0.SelfAttention.o", # t5 + ), + + MODEL_TENSOR.DEC_ATTN_REL_B: ( + "decoder.block.{bid}.layer.0.SelfAttention.relative_attention_bias", # t5 + ), + + MODEL_TENSOR.DEC_CROSS_ATTN_NORM: ( + "decoder.block.{bid}.layer.1.layer_norm", # t5 + ), + + MODEL_TENSOR.DEC_CROSS_ATTN_Q: ( + "decoder.block.{bid}.layer.1.EncDecAttention.q", # t5 + ), + + MODEL_TENSOR.DEC_CROSS_ATTN_K: ( + "decoder.block.{bid}.layer.1.EncDecAttention.k", # t5 + ), + + MODEL_TENSOR.DEC_CROSS_ATTN_V: ( + "decoder.block.{bid}.layer.1.EncDecAttention.v", # t5 + ), + + MODEL_TENSOR.DEC_CROSS_ATTN_OUT: ( + "decoder.block.{bid}.layer.1.EncDecAttention.o", # t5 + ), + + MODEL_TENSOR.DEC_CROSS_ATTN_REL_B: ( + "decoder.block.{bid}.layer.1.EncDecAttention.relative_attention_bias", # t5 + ), + + MODEL_TENSOR.DEC_FFN_NORM: ( + "decoder.block.{bid}.layer.2.layer_norm", # t5 + ), + + MODEL_TENSOR.DEC_FFN_GATE: ( + "decoder.block.{bid}.layer.2.DenseReluDense.wi_0", # flan-t5 + ), + + MODEL_TENSOR.DEC_FFN_UP: ( + "decoder.block.{bid}.layer.2.DenseReluDense.wi", # t5 + "decoder.block.{bid}.layer.2.DenseReluDense.wi_1", # flan-t5 + ), + + MODEL_TENSOR.DEC_FFN_DOWN: ( + "decoder.block.{bid}.layer.2.DenseReluDense.wo", # t5 + ), + + MODEL_TENSOR.DEC_OUTPUT_NORM: ( + "decoder.final_layer_norm", # t5 + ), + + MODEL_TENSOR.ENC_ATTN_NORM: ( + "encoder.block.{bid}.layer.0.layer_norm", # t5 + ), + + MODEL_TENSOR.ENC_ATTN_Q: ( + "encoder.block.{bid}.layer.0.SelfAttention.q", # t5 + ), + + MODEL_TENSOR.ENC_ATTN_K: ( + "encoder.block.{bid}.layer.0.SelfAttention.k", # t5 + ), + + MODEL_TENSOR.ENC_ATTN_V: ( + "encoder.block.{bid}.layer.0.SelfAttention.v", # t5 + ), + + MODEL_TENSOR.ENC_ATTN_OUT: ( + "encoder.block.{bid}.layer.0.SelfAttention.o", # t5 + ), + + MODEL_TENSOR.ENC_ATTN_REL_B: ( + "encoder.block.{bid}.layer.0.SelfAttention.relative_attention_bias", # t5 + ), + + MODEL_TENSOR.ENC_FFN_NORM: ( + "encoder.block.{bid}.layer.1.layer_norm", # t5 + ), + + MODEL_TENSOR.ENC_FFN_GATE: ( + "encoder.block.{bid}.layer.1.DenseReluDense.wi_0", # flan-t5 + ), + + MODEL_TENSOR.ENC_FFN_UP: ( + "encoder.block.{bid}.layer.1.DenseReluDense.wi", # t5 + "encoder.block.{bid}.layer.1.DenseReluDense.wi_1", # flan-t5 + ), + + MODEL_TENSOR.ENC_FFN_DOWN: ( + "encoder.block.{bid}.layer.1.DenseReluDense.wo", # t5 + ), + + MODEL_TENSOR.ENC_OUTPUT_NORM: ( + "encoder.final_layer_norm", # t5 + ), } # architecture-specific block mappings diff --git a/gguf-py/scripts/gguf-dump.py b/gguf-py/scripts/gguf-dump.py index 92d14d6cd..a73ca2776 100755 --- a/gguf-py/scripts/gguf-dump.py +++ b/gguf-py/scripts/gguf-dump.py @@ -208,7 +208,9 @@ def translate_tensor_name(name): 'ssm_d': 'State space model skip connection', 'ssm_dt': 'State space model time step', 'ssm_out': 'State space model output projection', - 'blk': 'Block' + 'blk': 'Block', + 'enc': 'Encoder', + 'dec': 'Decoder', } expanded_words = [] @@ -291,6 +293,10 @@ def dump_markdown_metadata(reader: GGUFReader, args: argparse.Namespace) -> None tensor_group_name = "base" if tensor_components[0] == 'blk': tensor_group_name = f"{tensor_components[0]}.{tensor_components[1]}" + elif tensor_components[0] in ['enc', 'dec'] and tensor_components[1] == 'blk': + tensor_group_name = f"{tensor_components[0]}.{tensor_components[1]}.{tensor_components[2]}" + elif tensor_components[0] in ['enc', 'dec']: + tensor_group_name = f"{tensor_components[0]}" # Check if new Tensor Group if tensor_group_name not in tensor_groups: @@ -313,6 +319,27 @@ def dump_markdown_metadata(reader: GGUFReader, args: argparse.Namespace) -> None markdown_content += "\n" + markdown_content += "### Tensor Data Offset\n" + markdown_content += '\n' + markdown_content += 'This table contains the offset and data segment relative to start of file\n' + markdown_content += '\n' + + tensor_mapping_table: list[dict[str, str | int]] = [] + for key, tensor in enumerate(reader.tensors): + data_offset_pretty = '{0:#16x}'.format(tensor.data_offset) + data_size_pretty = '{0:#16x}'.format(tensor.n_bytes) + tensor_mapping_table.append({"t_id":key, "layer_name":tensor.name, "data_offset":data_offset_pretty, "data_size":data_size_pretty}) + + tensors_mapping_table_header_map = [ + {'key_name':'t_id', 'header_name':'T_ID', 'align':'right'}, + {'key_name':'layer_name', 'header_name':'Tensor Layer Name', 'align':'left'}, + {'key_name':'data_offset', 'header_name':'Data Offset (B)', 'align':'right'}, + {'key_name':'data_size', 'header_name':'Data Size (B)', 'align':'right'}, + ] + + markdown_content += markdown_table_with_alignment_support(tensors_mapping_table_header_map, tensor_mapping_table) + markdown_content += "\n" + for group in tensor_prefix_order: tensors = tensor_groups[group] group_elements = sum(tensor.n_elements for tensor in tensors) @@ -364,6 +391,8 @@ def main() -> None: parser.add_argument("--no-tensors", action="store_true", help="Don't dump tensor metadata") parser.add_argument("--json", action="store_true", help="Produce JSON output") parser.add_argument("--json-array", action="store_true", help="Include full array values in JSON output (long)") + parser.add_argument("--data-offset", action="store_true", help="Start of data offset") + parser.add_argument("--data-alignment", action="store_true", help="Data alignment applied globally to data field") parser.add_argument("--markdown", action="store_true", help="Produce markdown output") parser.add_argument("--verbose", action="store_true", help="increase output verbosity") @@ -371,7 +400,7 @@ def main() -> None: logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) - if not args.json and not args.markdown: + if not args.json and not args.markdown and not args.data_offset and not args.data_alignment: logger.info(f'* Loading: {args.model}') reader = GGUFReader(args.model, 'r') @@ -380,6 +409,10 @@ def main() -> None: dump_metadata_json(reader, args) elif args.markdown: dump_markdown_metadata(reader, args) + elif args.data_offset: + print(reader.data_offset) # noqa: NP100 + elif args.data_alignment: + print(reader.alignment) # noqa: NP100 else: dump_metadata(reader, args) diff --git a/llama.cpp b/llama.cpp index 5377e77a0..afb3fc92b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -226,6 +226,7 @@ enum llm_arch { LLM_ARCH_ARCTIC, LLM_ARCH_DEEPSEEK2, LLM_ARCH_BITNET, + LLM_ARCH_T5, LLM_ARCH_UNKNOWN, }; @@ -265,6 +266,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_ARCTIC, "arctic" }, { LLM_ARCH_DEEPSEEK2, "deepseek2" }, { LLM_ARCH_BITNET, "bitnet" }, + { LLM_ARCH_T5, "t5" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -297,6 +299,7 @@ enum llm_kv { LLM_KV_EXPERT_WEIGHTS_SCALE, LLM_KV_POOLING_TYPE, LLM_KV_LOGIT_SCALE, + LLM_KV_DECODER_START_TOKEN_ID, LLM_KV_ATTENTION_HEAD_COUNT, LLM_KV_ATTENTION_HEAD_COUNT_KV, @@ -309,6 +312,7 @@ enum llm_kv { LLM_KV_ATTENTION_CAUSAL, LLM_KV_ATTENTION_Q_LORA_RANK, LLM_KV_ATTENTION_KV_LORA_RANK, + LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_FREQ_BASE, @@ -346,6 +350,8 @@ enum llm_kv { LLM_KV_TOKENIZER_ADD_BOS, LLM_KV_TOKENIZER_ADD_EOS, LLM_KV_TOKENIZER_ADD_PREFIX, + LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, + LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, LLM_KV_TOKENIZER_HF_JSON, LLM_KV_TOKENIZER_RWKV, LLM_KV_TOKENIZER_PREFIX_ID, @@ -383,18 +389,20 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" }, { LLM_KV_POOLING_TYPE , "%s.pooling_type" }, { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, + { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" }, - { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, - { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, - { LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" }, - { LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" }, - { LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length" }, - { LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" }, - { LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" }, - { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" }, - { LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" }, - { LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" }, - { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" }, + { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, + { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, + { LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" }, + { LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" }, + { LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length" }, + { LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" }, + { LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" }, + { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" }, + { LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" }, + { LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" }, + { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" }, + { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, @@ -415,29 +423,31 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" }, { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" }, - { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, - { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, - { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, - { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" }, - { LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, "tokenizer.ggml.token_type_count" }, - { LLM_KV_TOKENIZER_SCORES, "tokenizer.ggml.scores" }, - { LLM_KV_TOKENIZER_MERGES, "tokenizer.ggml.merges" }, - { LLM_KV_TOKENIZER_BOS_ID, "tokenizer.ggml.bos_token_id" }, - { LLM_KV_TOKENIZER_EOS_ID, "tokenizer.ggml.eos_token_id" }, - { LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" }, - { LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" }, - { LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" }, - { LLM_KV_TOKENIZER_CLS_ID, "tokenizer.ggml.cls_token_id" }, - { LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" }, - { LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" }, - { LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" }, - { LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" }, - { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, - { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, - { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" }, - { LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" }, - { LLM_KV_TOKENIZER_MIDDLE_ID, "tokenizer.ggml.middle_token_id" }, - { LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" }, + { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, + { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, + { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, + { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" }, + { LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, "tokenizer.ggml.token_type_count" }, + { LLM_KV_TOKENIZER_SCORES, "tokenizer.ggml.scores" }, + { LLM_KV_TOKENIZER_MERGES, "tokenizer.ggml.merges" }, + { LLM_KV_TOKENIZER_BOS_ID, "tokenizer.ggml.bos_token_id" }, + { LLM_KV_TOKENIZER_EOS_ID, "tokenizer.ggml.eos_token_id" }, + { LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" }, + { LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" }, + { LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" }, + { LLM_KV_TOKENIZER_CLS_ID, "tokenizer.ggml.cls_token_id" }, + { LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" }, + { LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" }, + { LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" }, + { LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" }, + { LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, "tokenizer.ggml.remove_extra_whitespaces" }, + { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" }, + { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, + { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, + { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" }, + { LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" }, + { LLM_KV_TOKENIZER_MIDDLE_ID, "tokenizer.ggml.middle_token_id" }, + { LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" }, }; struct LLM_KV { @@ -504,6 +514,34 @@ enum llm_tensor { LLM_TENSOR_ATTN_KV_A_NORM, LLM_TENSOR_ATTN_SUB_NORM, LLM_TENSOR_FFN_SUB_NORM, + LLM_TENSOR_DEC_ATTN_NORM, + LLM_TENSOR_DEC_ATTN_Q, + LLM_TENSOR_DEC_ATTN_K, + LLM_TENSOR_DEC_ATTN_V, + LLM_TENSOR_DEC_ATTN_OUT, + LLM_TENSOR_DEC_ATTN_REL_B, + LLM_TENSOR_DEC_CROSS_ATTN_NORM, + LLM_TENSOR_DEC_CROSS_ATTN_Q, + LLM_TENSOR_DEC_CROSS_ATTN_K, + LLM_TENSOR_DEC_CROSS_ATTN_V, + LLM_TENSOR_DEC_CROSS_ATTN_OUT, + LLM_TENSOR_DEC_CROSS_ATTN_REL_B, + LLM_TENSOR_DEC_FFN_NORM, + LLM_TENSOR_DEC_FFN_GATE, + LLM_TENSOR_DEC_FFN_DOWN, + LLM_TENSOR_DEC_FFN_UP, + LLM_TENSOR_DEC_OUTPUT_NORM, + LLM_TENSOR_ENC_ATTN_NORM, + LLM_TENSOR_ENC_ATTN_Q, + LLM_TENSOR_ENC_ATTN_K, + LLM_TENSOR_ENC_ATTN_V, + LLM_TENSOR_ENC_ATTN_OUT, + LLM_TENSOR_ENC_ATTN_REL_B, + LLM_TENSOR_ENC_FFN_NORM, + LLM_TENSOR_ENC_FFN_GATE, + LLM_TENSOR_ENC_FFN_DOWN, + LLM_TENSOR_ENC_FFN_UP, + LLM_TENSOR_ENC_OUTPUT_NORM, }; static const std::map> LLM_TENSOR_NAMES = { @@ -1135,6 +1173,41 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" }, }, }, + { + LLM_ARCH_T5, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_DEC_OUTPUT_NORM, "dec.output_norm" }, + { LLM_TENSOR_DEC_ATTN_NORM, "dec.blk.%d.attn_norm" }, + { LLM_TENSOR_DEC_ATTN_Q, "dec.blk.%d.attn_q" }, + { LLM_TENSOR_DEC_ATTN_K, "dec.blk.%d.attn_k" }, + { LLM_TENSOR_DEC_ATTN_V, "dec.blk.%d.attn_v" }, + { LLM_TENSOR_DEC_ATTN_OUT, "dec.blk.%d.attn_o" }, + { LLM_TENSOR_DEC_ATTN_REL_B, "dec.blk.%d.attn_rel_b" }, + { LLM_TENSOR_DEC_CROSS_ATTN_NORM, "dec.blk.%d.cross_attn_norm" }, + { LLM_TENSOR_DEC_CROSS_ATTN_Q, "dec.blk.%d.cross_attn_q" }, + { LLM_TENSOR_DEC_CROSS_ATTN_K, "dec.blk.%d.cross_attn_k" }, + { LLM_TENSOR_DEC_CROSS_ATTN_V, "dec.blk.%d.cross_attn_v" }, + { LLM_TENSOR_DEC_CROSS_ATTN_OUT, "dec.blk.%d.cross_attn_o" }, + { LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "dec.blk.%d.cross_attn_rel_b" }, + { LLM_TENSOR_DEC_FFN_NORM, "dec.blk.%d.ffn_norm" }, + { LLM_TENSOR_DEC_FFN_GATE, "dec.blk.%d.ffn_gate" }, + { LLM_TENSOR_DEC_FFN_DOWN, "dec.blk.%d.ffn_down" }, + { LLM_TENSOR_DEC_FFN_UP, "dec.blk.%d.ffn_up" }, + { LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" }, + { LLM_TENSOR_ENC_ATTN_NORM, "enc.blk.%d.attn_norm" }, + { LLM_TENSOR_ENC_ATTN_Q, "enc.blk.%d.attn_q" }, + { LLM_TENSOR_ENC_ATTN_K, "enc.blk.%d.attn_k" }, + { LLM_TENSOR_ENC_ATTN_V, "enc.blk.%d.attn_v" }, + { LLM_TENSOR_ENC_ATTN_OUT, "enc.blk.%d.attn_o" }, + { LLM_TENSOR_ENC_ATTN_REL_B, "enc.blk.%d.attn_rel_b" }, + { LLM_TENSOR_ENC_FFN_NORM, "enc.blk.%d.ffn_norm" }, + { LLM_TENSOR_ENC_FFN_GATE, "enc.blk.%d.ffn_gate" }, + { LLM_TENSOR_ENC_FFN_DOWN, "enc.blk.%d.ffn_down" }, + { LLM_TENSOR_ENC_FFN_UP, "enc.blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -2295,13 +2368,21 @@ struct llama_control_vector { int32_t layer_start = -1; int32_t layer_end = -1; - ggml_tensor * tensor_for(int il) const { + struct ggml_tensor * tensor_for(int il) const { if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) { return nullptr; } return tensors[il]; } + struct ggml_tensor * apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const { + ggml_tensor * layer_dir = tensor_for(il); + if (layer_dir != nullptr) { + cur = ggml_add(ctx, cur, layer_dir); + } + return cur; + } + ~llama_control_vector() { for (struct ggml_context * ctx : ctxs) { ggml_free(ctx); @@ -2356,6 +2437,11 @@ struct llama_vocab { bool tokenizer_add_bos = false; bool tokenizer_add_eos = false; bool tokenizer_ignore_merges = false; + bool tokenizer_remove_extra_whitespaces = false; + bool tokenizer_escape_whitespaces = true; + bool tokenizer_treat_whitespace_as_suffix = false; + + std::vector precompiled_charsmap; int find_bpe_rank(const std::string & token_left, const std::string & token_right) const { GGML_ASSERT(token_left.find(' ') == std::string::npos); @@ -4191,6 +4277,7 @@ static const char * llama_model_vocab_type_name(enum llama_vocab_type type){ case LLAMA_VOCAB_TYPE_SPM: return "SPM"; case LLAMA_VOCAB_TYPE_BPE: return "BPE"; case LLAMA_VOCAB_TYPE_WPM: return "WPM"; + case LLAMA_VOCAB_TYPE_UGM: return "UGM"; default: return "unknown"; } } @@ -4870,6 +4957,45 @@ static void llm_load_vocab( vocab.special_pad_id = -1; vocab.special_cls_id = -1; vocab.special_mask_id = -1; + } else if (tokenizer_model == "t5") { + vocab.type = LLAMA_VOCAB_TYPE_UGM; + + // default special tokens + vocab.special_bos_id = -1; + vocab.special_eos_id = 1; + vocab.special_unk_id = 2; + vocab.special_sep_id = -1; + vocab.special_pad_id = 0; + vocab.special_cls_id = -1; + vocab.special_mask_id = -1; + + const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str()); + if (add_space_prefix_keyidx != -1) { + vocab.tokenizer_add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx); + } // The default value of add_space_prefix is true. + + const int remove_extra_whitespaces_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS).c_str()); + if (remove_extra_whitespaces_keyidx != -1) { + vocab.tokenizer_remove_extra_whitespaces = gguf_get_val_bool(ctx, remove_extra_whitespaces_keyidx); + } // The default value of remove_extra_whitespaces is false. + + const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str()); + if (precompiled_charsmap_keyidx != -1) { + size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx); + const char * precompiled_charsmap = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx); + vocab.precompiled_charsmap.assign(precompiled_charsmap, precompiled_charsmap + n_precompiled_charsmap); +#ifdef IS_BIG_ENDIAN + // correct endiannes of data in precompiled_charsmap binary blob + uint32_t * xcda_blob_size = (uint32_t *) &vocab.precompiled_charsmap[0]; + *xcda_blob_size = __builtin_bswap32(*xcda_blob_size); + assert(*xcda_blob_size + sizeof(uint32_t) < n_precompiled_charsmap); + size_t xcda_array_size = *xcda_blob_size / sizeof(uint32_t); + uint32_t * xcda_array = (uint32_t *) &vocab.precompiled_charsmap[sizeof(uint32_t)]; + for (size_t i = 0; i < xcda_array_size; ++i) { + xcda_array[i] = __builtin_bswap32(xcda_array[i]); + } +#endif + } } else { throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str())); } @@ -4952,6 +5078,10 @@ static void llm_load_vocab( vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; vocab.tokenizer_add_bos = true; vocab.tokenizer_add_eos = false; + } else if (vocab.type == LLAMA_VOCAB_TYPE_UGM) { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + vocab.tokenizer_add_bos = false; + vocab.tokenizer_add_eos = true; } else { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; } @@ -7082,10 +7212,13 @@ static struct ggml_tensor * llm_build_ffn( struct ggml_tensor * cur, struct ggml_tensor * up, struct ggml_tensor * up_b, + struct ggml_tensor * up_s, struct ggml_tensor * gate, struct ggml_tensor * gate_b, + struct ggml_tensor * gate_s, struct ggml_tensor * down, struct ggml_tensor * down_b, + struct ggml_tensor * down_s, struct ggml_tensor * act_scales, llm_ffn_op_type type_op, llm_ffn_gate_type type_gate, @@ -7099,6 +7232,11 @@ static struct ggml_tensor * llm_build_ffn( cb(tmp, "ffn_up_b", il); } + if (up_s) { + tmp = ggml_mul(ctx, tmp, up_s); + cb(tmp, "ffn_up_s", il); + } + if (gate) { switch (type_gate) { case LLM_FFN_SEQ: @@ -7117,6 +7255,12 @@ static struct ggml_tensor * llm_build_ffn( cur = ggml_add(ctx, cur, gate_b); cb(cur, "ffn_gate_b", il); } + + if (gate_s) { + cur = ggml_mul(ctx, cur, gate_s); + cb(cur, "ffn_gate_s", il); + } + } else { cur = tmp; } @@ -7156,7 +7300,10 @@ static struct ggml_tensor * llm_build_ffn( cb(cur, "ffn_gate_par", il); } - cur = ggml_mul_mat(ctx, down, cur); + if (down) { + cur = ggml_mul_mat(ctx, down, cur); + } + if (down_b) { cb(cur, "ffn_down", il); } @@ -7165,6 +7312,11 @@ static struct ggml_tensor * llm_build_ffn( cur = ggml_add(ctx, cur, down_b); } + if (down_s) { + cur = ggml_mul(ctx, cur, down_s); + cb(cur, "ffn_down_s", il); + } + return cur; } @@ -7873,9 +8025,9 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, - model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); @@ -7901,10 +8053,7 @@ struct llm_build_context { cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "ffn_out", il); - ggml_tensor * layer_dir = lctx.cvec.tensor_for(il); - if (layer_dir != nullptr) { - cur = ggml_add(ctx0, cur, layer_dir); - } + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -8010,15 +8159,16 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); } cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -8114,15 +8264,16 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); } cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -8229,18 +8380,17 @@ struct llm_build_context { // feed forward { cur = llm_build_ffn(ctx0, attn_norm, // !! use the attn norm, not the result - model.layers[il].ffn_up, NULL, - NULL, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, NULL, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); cb(cur, "ffn_out", il); } cur = ggml_add(ctx0, cur, ffn_inp); - cb(cur, "l_out", il); - cur = ggml_add(ctx0, cur, inpL); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -8392,10 +8542,7 @@ struct llm_build_context { cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "ffn_out", il); - ggml_tensor * layer_dir = lctx.cvec.tensor_for(il); - if (layer_dir != nullptr) { - cur = ggml_add(ctx0, cur, layer_dir); - } + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -8526,10 +8673,7 @@ struct llm_build_context { cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "ffn_out", il); - ggml_tensor * layer_dir = lctx.cvec.tensor_for(il); - if (layer_dir != nullptr) { - cur = ggml_add(ctx0, cur, layer_dir); - } + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -8627,16 +8771,20 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, - NULL, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); cb(cur, "ffn_out", il); } - inpL = ggml_add(ctx0, cur, ffn_inp); - cb(inpL, "l_out", il); + cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; } cur = llm_build_norm(ctx0, inpL, hparams, @@ -8715,15 +8863,16 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); } cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -8899,23 +9048,23 @@ struct llm_build_context { // feed-forward network if (model.arch == LLM_ARCH_BERT) { cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, - NULL, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); } else if (model.arch == LLM_ARCH_JINA_BERT_V2) { cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, LLM_FFN_GELU, LLM_FFN_PAR, cb, il); } else { cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); } @@ -9011,16 +9160,20 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, - NULL, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); cb(cur, "ffn_out", il); } - inpL = ggml_add(ctx0, cur, ffn_inp); - cb(inpL, "l_out", il); + cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; } cur = llm_build_norm(ctx0, inpL, hparams, @@ -9145,15 +9298,16 @@ struct llm_build_context { LLM_NORM, cb, il); cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, - NULL, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, model.layers[il].ffn_act, LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); cb(cur, "ffn_out", il); } cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -9293,15 +9447,16 @@ struct llm_build_context { cur = inpSA; } cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); } cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -9405,15 +9560,16 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); } cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -9517,14 +9673,15 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -9653,9 +9810,9 @@ struct llm_build_context { cb(cur_gate, "ffn_shexp_gate", il); ggml_tensor * cur_ffn = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up_shexp, NULL, - model.layers[il].ffn_gate_shexp, NULL, - model.layers[il].ffn_down_shexp, NULL, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur_ffn, "ffn_shexp", il); @@ -9670,6 +9827,7 @@ struct llm_build_context { } cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -9781,20 +9939,20 @@ struct llm_build_context { // FF { ffn_output = llm_build_ffn(ctx0, attn_norm_output, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, - NULL, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); cb(ffn_output, "ffn_out", il); } cur = ggml_add(ctx0, cur, ffn_output); - cb(cur, "l_out", il); - cur = ggml_add(ctx0, cur, inpL); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); + // input for next layer inpL = cur; } @@ -9926,8 +10084,10 @@ struct llm_build_context { } cur = ggml_add(ctx0, residual, cur); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); + // input for next layer inpL = cur; } @@ -10017,18 +10177,17 @@ struct llm_build_context { // feed-forward network { cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); } cur = ggml_add(ctx0, cur, sa_out); - cb(cur, "l_out", il); - cur = ggml_add(ctx0, cur, inpL); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -10126,16 +10285,20 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, - NULL, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); cb(cur, "ffn_out", il); } - inpL = ggml_add(ctx0, cur, ffn_inp); - cb(inpL, "l_out", il); + cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; } cur = llm_build_norm(ctx0, inpL, hparams, @@ -10233,16 +10396,20 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, - NULL, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); cb(cur, "ffn_out", il); } - inpL = ggml_add(ctx0, cur, ffn_inp); - cb(inpL, "l_out", il); + cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; } cur = llm_build_norm(ctx0, inpL, hparams, @@ -10346,14 +10513,15 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -10463,14 +10631,15 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -10599,9 +10768,9 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); @@ -10612,6 +10781,7 @@ struct llm_build_context { cb(cur, "hidden_scaled_ffn", -1); cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -10715,15 +10885,16 @@ struct llm_build_context { // feed-forward network { cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_GELU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); } cur = ggml_add(ctx0, cur, sa_out); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -10834,13 +11005,15 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, - NULL, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); cb(cur, "ffn_out", il); + cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -10989,6 +11162,7 @@ struct llm_build_context { // residual cur = ggml_add(ctx0, cur, inpL); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -11119,9 +11293,9 @@ struct llm_build_context { // feed-forward network { cur = llm_build_ffn(ctx0, ffn_inp, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); @@ -11130,6 +11304,7 @@ struct llm_build_context { // add together residual + FFN + self-attention cur = ggml_add(ctx0, cur, inpL); cur = ggml_add(ctx0, cur, attn_out); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -11255,9 +11430,9 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); @@ -11265,10 +11440,7 @@ struct llm_build_context { cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "ffn_out", il); - ggml_tensor * layer_dir = lctx.cvec.tensor_for(il); - if (layer_dir != nullptr) { - cur = ggml_add(ctx0, cur, layer_dir); - } + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -11372,9 +11544,9 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, - NULL, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); cb(cur, "ffn_out", il); @@ -11382,8 +11554,12 @@ struct llm_build_context { cur = ggml_add(ctx0, cur, inpL); cb(cur, "ffn_out", il); - inpL = ggml_add(ctx0, cur, attn_out); - cb(inpL, "l_out", il); + cur = ggml_add(ctx0, cur, attn_out); + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; } else { // attention and ffn are computed sequentially // x = x + attn(ln1(x)) @@ -11399,15 +11575,19 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, - NULL, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); cb(cur, "ffn_out", il); - inpL = ggml_add(ctx0, cur, ffn_inp); - cb(inpL, "l_out", il); + cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; } } @@ -11504,9 +11684,9 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); @@ -11534,10 +11714,7 @@ struct llm_build_context { cur = ggml_add(ctx0, cur, ffn_out); cb(cur, "ffn_out", il); - ggml_tensor * layer_dir = lctx.cvec.tensor_for(il); - if (layer_dir != nullptr) { - cur = ggml_add(ctx0, cur, layer_dir); - } + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -11729,9 +11906,9 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); @@ -11757,9 +11934,9 @@ struct llm_build_context { // FFN shared expert { ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up_shexp, NULL, - model.layers[il].ffn_gate_shexp, NULL, - model.layers[il].ffn_down_shexp, NULL, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(ffn_shexp, "ffn_shexp", il); @@ -11770,6 +11947,7 @@ struct llm_build_context { } cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -11861,7 +12039,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, - nullptr, nullptr, + NULL, NULL, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); cur = llm_build_norm(ctx0, cur, hparams, @@ -11888,35 +12066,28 @@ struct llm_build_context { cb(ffn_inp, "ffn_inp", il); // feed-forward forward - if (model.layers[il].ffn_gate_inp == nullptr) { - cur = llm_build_norm(ctx0, ffn_inp, hparams, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm", il); + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); - struct ggml_tensor *tmp = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur); - tmp = ggml_mul(ctx0, tmp, model.layers[il].ffn_up_scale); - cb(tmp, "ffn_up", il); + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, model.layers[il].ffn_up_scale, + model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_scale, + NULL, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_sub_out", il); - cur = ggml_mul_mat(ctx0, model.layers[il].ffn_gate, cur); - cur = ggml_mul(ctx0, cur, model.layers[il].ffn_gate_scale); - cb(cur, "ffn_gate", il); + cur = llm_build_norm(ctx0, cur, hparams, + model.layers[il].ffn_sub_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_sub_norm", il); - cur = ggml_silu(ctx0, cur); - cb(cur, "ffn_silu", il); + cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down, cur); + cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale); + cb(cur, "ffn_down", il); - cur = ggml_mul(ctx0, cur, tmp); - cb(cur, "ffn_gate_par", il); - - cur = llm_build_norm(ctx0, cur, hparams, - model.layers[il].ffn_sub_norm, NULL, - LLM_NORM_RMS, cb, il); - cb(cur, "ffn_sub_norm", il); - - cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down, cur); - cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale); - cb(cur, "ffn_down", il); - } cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "l_out", il); @@ -13213,12 +13384,18 @@ static bool llama_is_user_defined_token(const llama_vocab& vocab, llama_token id return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_USER_DEFINED; } +static bool llama_is_unused_token(const llama_vocab& vocab, llama_token id) { + GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); + return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNUSED; +} + static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) { GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE); GGML_ASSERT(llama_is_byte_token(vocab, id)); const auto & token_data = vocab.id_to_token.at(id); switch (llama_vocab_get_type(vocab)) { - case LLAMA_VOCAB_TYPE_SPM: { + case LLAMA_VOCAB_TYPE_SPM: + case LLAMA_VOCAB_TYPE_UGM: { auto buf = token_data.text.substr(3, 2); return strtol(buf.c_str(), NULL, 16); } @@ -13238,7 +13415,8 @@ static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) { GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE); static const char * hex = "0123456789ABCDEF"; switch (llama_vocab_get_type(vocab)) { - case LLAMA_VOCAB_TYPE_SPM: { + case LLAMA_VOCAB_TYPE_SPM: + case LLAMA_VOCAB_TYPE_UGM: { const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 }; auto token = vocab.token_to_id.find(buf); if (token != vocab.token_to_id.end()) { @@ -13826,6 +14004,383 @@ struct llm_tokenizer_wpm { const llama_vocab & vocab; }; +struct naive_trie { + naive_trie() : has_value(false), value(0) { + } + void insert(const char * key, size_t len, int32_t value = 0) { + if (len == 0) { + this->has_value = true; + this->value = value; + return; + } + char c = key[0]; + auto res = children.find(c); + if (res != children.end()) { + res->second.insert(key + 1, len - 1, value); + } else { + auto res = children.insert(std::make_pair(c, naive_trie())); + res.first->second.insert(key + 1, len - 1, value); + } + } + std::pair get_longest_prefix(const char * key, size_t len, size_t offset = 0) { + if (len == 0 || offset == len) { + return std::make_pair(key, offset); + } + char c = key[offset]; + auto res = children.find(c); + if (res != children.end()) { + return res->second.get_longest_prefix(key, len, offset + 1); + } else { + return std::make_pair(key, offset); + } + } + struct naive_trie * traverse(const char c) { + auto res = children.find(c); + if (res != children.end()) { + return &res->second; + } else { + return NULL; + } + } + std::map children; + bool has_value; + llama_token value; +}; + +struct llm_tokenizer_ugm { + llm_tokenizer_ugm(const llama_vocab & vocab) : vocab(vocab) { + if (vocab.precompiled_charsmap.size() > 0) { + size_t charsmap_offset = 0; + + // First four bytes of precompiled_charsmap contains length of binary + // blob containing XOR-compressed compact double array (XCDA) entries + uint32_t xcda_blob_size = *(const uint32_t *) &vocab.precompiled_charsmap[0]; + charsmap_offset += sizeof(xcda_blob_size); + if (xcda_blob_size + charsmap_offset >= vocab.precompiled_charsmap.size()) { + throw std::runtime_error("Index out of array bounds in precompiled charsmap!"); + } + + // Next xcda_blob_size bytes contain entries of XOR-compressed compact + // double array (XCDA). Each entry is bit-packed into a 32-bit integer. + xcda_array = (const uint32_t *) &vocab.precompiled_charsmap[charsmap_offset]; + xcda_array_size = xcda_blob_size / sizeof(uint32_t); + charsmap_offset += xcda_blob_size; + + // Remaining bytes of precompiled charsmap contain null-terminated + // replacement strings for prefixes matched by the XCDA. + prefix_replacements = &vocab.precompiled_charsmap[charsmap_offset]; + prefix_replacements_size = vocab.precompiled_charsmap.size() - charsmap_offset; + } + + for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) { + const auto &token_data = vocab.id_to_token[id]; + + if (llama_is_normal_token(vocab, id)) { + min_score = std::min(min_score, token_data.score); + max_score = std::max(max_score, token_data.score); + } + + if (llama_is_normal_token(vocab, id) || + llama_is_user_defined_token(vocab, id) || + llama_is_unused_token(vocab, id)) { + token_matcher.insert(token_data.text.data(), token_data.text.size(), id); + } + + if (llama_is_user_defined_token(vocab, id)) { + user_defined_token_matcher.insert(token_data.text.data(), token_data.text.size()); + } + } + + unknown_token_score = min_score - unknown_token_score_penalty; + } + + /* This implementation is based on SentencePiece optimized Viterbi algorithm for + * unigram language models. The general idea is to: + * - move along the input sequence in steps of one UTF code point, + * - at each step find all possible tokenizations of the prefix by + * traversing the tokens trie, + * - for each tokenization store the best one so far (by higher score) + * - use the position in sequence after given token as an index to store + * results + * - if there was no valid tokenization of the current UTF code point + * then use unknown token with additional score penalty + * After processing the whole sequence we backtrack from the end to get + * the best tokenization. + */ + void tokenize(const std::string & text, std::vector & output) { + // normalize the input first + std::string normalized; + normalize(text, &normalized); + size_t input_len = normalized.size(); + + // initialize score_sum to -FLT_MAX so it will be always lower than sums of token scores + std::vector tokenization_results(input_len + 1, {0, 0, -FLT_MAX}); + // at the beginning tokenization score is zero + tokenization_results[0] = { 0, 0, 0 }; + + for (size_t input_offset = 0; input_offset < input_len;) { + size_t prefix_offset = input_offset; + // calculate how many code units are in the currently processed UTF code point + size_t n_utf8_code_units = std::min(utf8_len(normalized[input_offset]), input_len - input_offset); + + // traverse the token matcher trie to find a matching token + bool single_codepoint_token_found = false; + const struct best_tokenization & current_best = tokenization_results[input_offset]; + struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]); + + while (prefix_offset <= input_len && node != NULL) { + // check if we found valid token in prefix + if (node->has_value) { + // check if it corresponds to the whole UTF code point + if (prefix_offset - input_offset == n_utf8_code_units) { + single_codepoint_token_found = true; + } + llama_token token_id = node->value; + const auto &token_data = vocab.id_to_token[token_id]; + + // we set the user-defined token scores to 0 to make them more likely to be selected + // (normal token scores are log probabilities, so they are negative) + // score type is double here to make tokenization results exactly + // the same as in the HF tokenizer using SentencePiece + const double token_score = llama_is_user_defined_token(vocab, token_id) ? 0.0 : token_data.score; + const double challenger_score = current_best.score_sum + token_score; + struct best_tokenization & current_champ = tokenization_results[prefix_offset]; + if (challenger_score > current_champ.score_sum) { + struct best_tokenization challenger = { token_id, input_offset, (float) challenger_score }; + current_champ = challenger; + } + } + node = node->traverse(normalized[prefix_offset++]); + } + + // if we didn't find a valid token corresponding to the whole UTF code point + // then use unknown token as the tokenization of this UTF code point + if (!single_codepoint_token_found) { + const double challenger_score = current_best.score_sum + unknown_token_score; + prefix_offset = input_offset + n_utf8_code_units; + struct best_tokenization & current_champ = tokenization_results[prefix_offset]; + if (challenger_score > current_champ.score_sum) { + struct best_tokenization challenger = { vocab.special_unk_id, input_offset, (float) challenger_score }; + current_champ = challenger; + } + } + + // move to the next UTF code point + input_offset += n_utf8_code_units; + } + + // now backtrack from the end to gather token ids of the best tokenization + // merge sequences of consecutive unknown tokens into single unknown tokens + bool is_prev_unknown = false; + for (struct best_tokenization & tokenization = tokenization_results[input_len]; ; tokenization = tokenization_results[tokenization.input_offset]) { + bool is_unknown = tokenization.token_id == vocab.special_unk_id; + if (!(is_prev_unknown && is_unknown)) { + output.push_back(tokenization.token_id); + } + if (tokenization.input_offset == 0) { + break; + } + is_prev_unknown = is_unknown; + } + + // reverse the output since we added tokens starting from the end of the input + std::reverse(output.begin(), output.end()); + } + +private: + const llama_vocab & vocab; + + // helper structure for returning normalization results + struct normalization_result { + const char * normalized; + size_t normalized_len; + size_t consumed_input; + }; + + void normalize(const std::string& input, std::string * normalized) { + normalized->clear(); + normalized->reserve(input.size() * 3); + + const std::string space = vocab.tokenizer_escape_whitespaces ? escaped_space : " "; + + bool shall_prepend_space = !vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix; + bool shall_append_space = vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix; + bool shall_merge_spaces = vocab.tokenizer_remove_extra_whitespaces; + + bool is_space_prepended = false; + bool processing_non_ws = false; + + size_t input_len = input.size(); + + for (size_t input_offset = 0; input_offset < input_len; ) { + auto norm_res = normalize_prefix(input, input_offset); + for (size_t i = 0; i < norm_res.normalized_len; i++) { + char c = norm_res.normalized[i]; + if (c != ' ') { + if (!processing_non_ws) { + processing_non_ws = true; + if ((shall_prepend_space && !is_space_prepended) || shall_merge_spaces) { + normalized->append(space); + is_space_prepended = true; + } + } + normalized->push_back(c); + } else { + if (processing_non_ws) { + processing_non_ws = false; + } + if (!shall_merge_spaces) { + normalized->append(space); + } + } + } + + input_offset += norm_res.consumed_input; + } + + if (shall_append_space) { + normalized->append(space); + } + } + + /* + * This structure is a view wrapper for XOR-compressed double array (XCDA) + * See Shunsuke Kanda (2018). Space- and Time-Efficient String Dictionaries. + * Eeach bit-packed entry contains: + * - BASE array value in bits 10-30 + * - LCHECK array value in bits 0-7 + * - LEAF array value in bit 9 + * Entries containing indexes of replacement sequences have set bit 31 + */ + struct xcda_array_view { + public: + xcda_array_view(const uint32_t * xcda_array, size_t xcda_array_size) : xcda_array(xcda_array), xcda_array_size(xcda_array_size) { + } + uint32_t get_base(size_t index) { + uint32_t packed_node = get_node(index); + return (packed_node >> 10) << ((packed_node & (1U << 9)) >> 6); + } + uint32_t get_lcheck(size_t index) { + uint32_t packed_node = get_node(index); + return packed_node & ((1U << 31) | 0xff); + } + bool get_leaf(size_t index) { + uint32_t packed_node = get_node(index); + return (packed_node >> 8) & 1; + } + uint32_t get_value(size_t index) { + uint32_t packed_node = get_node(index); + return packed_node & ((1U << 31) - 1); + } + private: + uint32_t get_node(size_t index) { + if (index > xcda_array_size) { + throw std::runtime_error("Index out of array bounds in XCDA array!"); + } + return xcda_array[index]; + } + const uint32_t * xcda_array; + size_t xcda_array_size; + }; + + struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) { + if (input_offset == input.size()) { + return { &input[input_offset], 0, 0 }; + } + + // if input prefix matches some user-defined token return this token as normalization result + auto user_defined_token_match = user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset); + if (user_defined_token_match.second > 0) { + return { &input[input_offset], user_defined_token_match.second, user_defined_token_match.second }; + } + + size_t longest_prefix_length = 0; + size_t longest_prefix_offset = 0; + + if (xcda_array_size > 0) { + struct xcda_array_view xcda_view(xcda_array, xcda_array_size); + + // Find the longest normalized sequence matching the input prefix by walking + // the XOR-compressed compact double array (XCDA) starting from the root node + // We find the index of the next node by calculating BASE[s] ^ c where s is + // the index of the previous node and c is a numerical character value + uint32_t node_index = 0; + // get BASE of the root node + node_index = xcda_view.get_base(node_index); + for (size_t prefix_offset = input_offset; prefix_offset < input.size(); prefix_offset++) { + unsigned char c = input[prefix_offset]; + if (c == 0) { + break; + } + node_index ^= c; + // if value of LCHECK is not c it means that this is not a child of + // the previous node, so we stop matching + if (xcda_view.get_lcheck(node_index) != c) { + break; + } + bool is_leaf = xcda_view.get_leaf(node_index); + // get BASE of the current node + node_index ^= xcda_view.get_base(node_index); + // if LEAF of the current node is true, it means that its BASE points to the node + // containing index of replacement sequence for currently matched input prefix + if (is_leaf) + { + longest_prefix_length = prefix_offset - input_offset + 1; + // get index of replacement sequence for currently matched input prefix + longest_prefix_offset = xcda_view.get_value(node_index); + } + } + } + + if (longest_prefix_length > 0) { + // we have a match, so return the replacement sequence + if (longest_prefix_offset >= prefix_replacements_size) { + throw std::runtime_error("Index out of array bounds in precompiled charsmap!"); + } + const char * prefix_replacement = &prefix_replacements[longest_prefix_offset]; + return { prefix_replacement, strlen(prefix_replacement), longest_prefix_length }; + } else { + // check if the input prefix contains a valid sequence of UTF-8 code units + try { + // if yes, return this sequence unmodified + size_t prefix_offset = input_offset; + unicode_cpt_from_utf8(input, prefix_offset); + return { &input[input_offset], prefix_offset - input_offset, prefix_offset - input_offset }; + } catch(std::invalid_argument & ex) { + // if no, consume 1 byte and return U+FFFD - REPLACEMENT CHARACTER + return { "\xEF\xBF\xBD", 3, 1 }; + } + } + } + + // escaped space symbol - U+2581 (Lower One Eighth Block) + const std::string escaped_space = "\xE2\x96\x81"; + + const char * prefix_replacements = NULL; + size_t prefix_replacements_size = 0; + + const uint32_t * xcda_array = NULL; + size_t xcda_array_size = 0; + + struct naive_trie user_defined_token_matcher; + + // this structure stores the best tokenization so far at input_offset + struct best_tokenization { + llama_token token_id; + size_t input_offset; + float score_sum; + }; + + float min_score = FLT_MAX; + float max_score = -FLT_MAX; + + float unknown_token_score_penalty = 10.0; + float unknown_token_score; + + struct naive_trie token_matcher; +}; + + typedef enum FRAGMENT_BUFFER_VARIANT_TYPE { FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN, FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT @@ -14086,6 +14641,39 @@ static std::vector llama_tokenize_internal(const llama_vocab & output.push_back(vocab.special_sep_id); } } break; + case LLAMA_VOCAB_TYPE_UGM: + { + llm_tokenizer_ugm tokenizer(vocab); + + if (add_special && vocab.tokenizer_add_bos != 0) { + GGML_ASSERT(vocab.special_bos_id != -1); + output.push_back(vocab.special_bos_id); + } + + for (const auto & fragment : fragment_buffer) { + if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { + auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); +#ifdef PRETOKENIZERDEBUG + LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); +#endif + tokenizer.tokenize(raw_text, output); + } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) + output.push_back(fragment.token); + } + } + + if (add_special && vocab.tokenizer_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) { + LLAMA_LOG_WARN( + "%s: Added a BOS token to the prompt as specified by the model but the prompt " + "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. " + "Are you sure this is what you want?\n", __FUNCTION__); + } + + if (add_special && vocab.tokenizer_add_eos == 1) { + GGML_ASSERT(vocab.special_eos_id != -1); + output.push_back(vocab.special_eos_id); + } + } break; case LLAMA_VOCAB_TYPE_NONE: GGML_ASSERT(false); } @@ -14500,7 +15088,8 @@ struct llama_grammar * llama_grammar_init( continue; } if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) { - throw std::runtime_error(format("unsupported grammar, left recursion detected for nonterminal at index %zu", i)); + LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i); + return nullptr; } } @@ -16963,6 +17552,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_BLOOM: case LLM_ARCH_MAMBA: case LLM_ARCH_JINA_BERT_V2: + case LLM_ARCH_T5: return LLAMA_ROPE_TYPE_NONE; // use what we call a normal RoPE, operating on pairs of consecutive head values @@ -18658,6 +19248,10 @@ llama_token llama_token_eot(const struct llama_model * model) { return model->vocab.special_eot_id; } +llama_token llama_token_pad(const struct llama_model * model) { + return model->vocab.special_pad_id; +} + int32_t llama_tokenize( const struct llama_model * model, const char * text, @@ -18724,7 +19318,8 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token if (0 <= token && token < llama_n_vocab(model)) { switch (llama_vocab_get_type(model->vocab)) { case LLAMA_VOCAB_TYPE_WPM: - case LLAMA_VOCAB_TYPE_SPM: { + case LLAMA_VOCAB_TYPE_SPM: + case LLAMA_VOCAB_TYPE_UGM: { // NOTE: we accept all unsupported token types, // suppressing them like CONTROL tokens. if (llama_is_normal_token(model->vocab, token)) { @@ -18818,10 +19413,10 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|im_start|>assistant\n"; } - } else if (tmpl == "llama2" || tmpl.find("[INST]") != std::string::npos) { + } else if (tmpl == "llama2" || tmpl == "mistral" || tmpl.find("[INST]") != std::string::npos) { // llama2 template and its variants // [variant] support system message - bool support_system_message = tmpl.find("<>") != std::string::npos; + bool support_system_message = tmpl.find("<>") != std::string::npos || tmpl == "mistral"; // [variant] space before + after response bool space_around_response = tmpl.find("' ' + eos_token") != std::string::npos; // [variant] add BOS inside history diff --git a/llama.h b/llama.h index 53e06d9db..88eecb0ed 100644 --- a/llama.h +++ b/llama.h @@ -67,6 +67,7 @@ extern "C" { LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece + LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram }; // pre-tokenization types @@ -857,6 +858,7 @@ extern "C" { LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line + LLAMA_API llama_token llama_token_pad(const struct llama_model * model); // padding // Returns -1 if unknown, 1 for true or 0 for false. LLAMA_API int32_t llama_add_bos_token(const struct llama_model * model); @@ -924,6 +926,12 @@ extern "C" { // Grammar // + /// Initialize a llama_grammar. + /// + /// @param rules The rule elements of the grammar to initialize. + /// @param n_rules The number of rules. + /// @param start_rule_index The index of the root rule (the starting point of the grammar). + /// @return The initialized llama_grammar or nullptr if initialization failed. LLAMA_API struct llama_grammar * llama_grammar_init( const llama_grammar_element ** rules, size_t n_rules, diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index cef9a650b..d19ba8633 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -7,6 +7,7 @@ #include #include "llama.h" +#include "common.h" int main(void) { llama_chat_message conversation[] = { @@ -119,5 +120,24 @@ int main(void) { std::cout << output << "\n-------------------------\n"; assert(output == expected); } + + // test llama_chat_format_single + std::cout << "\n\n=== llama_chat_format_single ===\n\n"; + std::vector chat2; + chat2.push_back({"system", "You are a helpful assistant"}); + chat2.push_back({"user", "Hello"}); + chat2.push_back({"assistant", "I am assistant"}); + llama_chat_msg new_msg{"user", "How are you"}; + + auto fmt_single = [&](std::string tmpl) { + auto output = llama_chat_format_single(nullptr, tmpl, chat2, new_msg, true); + std::cout << "fmt_single(" << tmpl << ")\n" << output << "\n-------------------------\n"; + return output; + }; + assert(fmt_single("chatml") == "<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n"); + assert(fmt_single("llama2") == "[INST] How are you [/INST]"); + assert(fmt_single("gemma") == "user\nHow are you\nmodel\n"); + assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"); + return 0; } diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 96f90c01e..0e21dc795 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -15,8 +15,6 @@ using json = nlohmann::ordered_json; -//#define INCLUDE_FAILING_TESTS 1 - static llama_grammar* build_grammar(const std::string & grammar_str) { auto parsed_grammar = grammar_parser::parse(grammar_str.c_str()); @@ -36,10 +34,10 @@ static llama_grammar* build_grammar(const std::string & grammar_str) { static bool test_build_grammar_fails(const std::string & grammar_str) { fprintf(stderr, "⚫ Testing failure for grammar: %s\n", grammar_str.c_str()); bool grammar_fails = false; - try { - build_grammar(grammar_str); + llama_grammar * grammar = build_grammar(grammar_str); + if (grammar != nullptr) { fprintf(stderr, " ❌ Expected build failure, but succeeded\n"); - } catch (const std::exception & err) { + } else { grammar_fails = true; fprintf(stdout, " ✅︎\n"); } @@ -148,6 +146,250 @@ static void test_schema(const std::string & test_desc, const std::string & schem } static void test_simple_grammar() { + test_schema( + "min 0", + R"""({ + "type": "integer", + "minimum": 0 + })""", + // Passing strings + { + "0", + "10", + "12", + "10000", + }, + // Failing strings + { + "-1", + "-10", + "-10000", + "-100000000000000000000000000000000", + "100000000000000000000000000000000", + "00", + "01", + "-0", + } + ); + test_schema( + "min 2", + // Schema + R"""({ + "type": "integer", + "minimum": 2 + })""", + // Passing strings + { + "2", + "3", + "4", + "10", + "20", + "1234567890000000", + }, + // Failing strings + { + "0", + "1", + "-1", + "-100", + "0", + "1", + "01", + "02", + "12345678900000000", + } + ); + test_schema( + "min 456", + R"""({ + "type": "integer", + "minimum": 456 + })""", + // Passing strings + { + "456", + "4560", + "457", + "460", + "500", + }, + // Failing strings + { + "455", + "356", + "50", + "050", + "-1", + "-456", + } + ); + test_schema( + "min -123", + R"""({ + "type": "integer", + "minimum": -123 + })""", + // Passing strings + { + "-123", + "-122", + "-11", + "-1", + "0", + "1", + "123", + "1234", + "2345", + }, + // Failing strings + { + "-1234", + "-124", + } + ); + + test_schema( + "max 9999", + // Schema + R"""({ + "type": "integer", + "maximum": 9999 + })""", + // Passing strings + { + "-99999", + "0", + "9999", + }, + // Failing strings + { + "10000", + "99991", + } + ); + test_schema( + "max -9999", + // Schema + R"""({ + "type": "integer", + "maximum": -9999 + })""", + // Passing strings + { + "-10000", + "-9999", + }, + // Failing strings + { + "-9998", + "0", + "9999", + } + ); + test_schema( + "min 5 max 30", + // Schema + R"""({ + "type": "integer", + "minimum": 5, + "maximum": 30 + })""", + // Passing strings + { + "5", + "10", + "30", + }, + // Failing strings + { + "05", + "4", + "-1", + "31", + "123", + "0123", + } + ); + test_schema( + "min -1 max 1", + R"""({ + "type": "integer", + "minimum": -1, + "maximum": 1 + })""", + // Passing strings + { + "-1", + "0", + "1", + }, + // Failing strings + { + "-11", + "-10", + "-2", + "2", + "10", + "11", + } + ); + test_schema( + "min -123 max 42", + R"""({ + "type": "integer", + "minimum": -123, + "maximum": 42 + })""", + // Passing strings + { + "-123", + "-122", + "-13", + "-11", + "-2", + "-1", + "0", + "1", + "5", + "10", + "39", + "40", + "42", + }, + // Failing strings + { + "-0123", + "-124", + "-1123", + "-200", + "43", + "123", + "0123", + } + ); + test_schema( + "exclusive min / max", + // Schema + R"""({ + "type": "integer", + "exclusiveMinimum": 0, + "exclusiveMaximum": 10000 + })""", + // Passing strings + { + "1", + "9999", + }, + // Failing strings + { + "0", + "01", + "10000", + "99999", + } + ); + // Test case for a simple grammar test_grammar( "simple grammar", @@ -510,7 +752,7 @@ static void test_json_schema() { )""", // Passing strings { - "{}", + R"""({})""", R"""({"foo": "bar"})""", }, // Failing strings @@ -518,7 +760,7 @@ static void test_json_schema() { "", "[]", "null", - "\"\"", + R"""("")""", "true", } ); @@ -526,16 +768,14 @@ static void test_json_schema() { test_schema( "exotic formats (list)", // Schema - R"""( - { + R"""({ "items": [ { "format": "date" }, { "format": "uuid" }, { "format": "time" }, { "format": "date-time" } ] - } - )""", + })""", // Passing strings { // "{}", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it? @@ -554,125 +794,113 @@ static void test_json_schema() { test_schema( "string", // Schema - R"""( - { - "type": "string" - } - )""", + R"""({ + "type": "string" + })""", // Passing strings { - "\"foo\"", - "\"bar\"", - "\"\"", + R"""("foo")""", + R"""("bar")""", + R"""("")""", }, // Failing strings { - "{}", - "\"foo\": \"bar\"", + R"""({})""", + R"""("foo": "bar")""", } ); test_schema( "string w/ min length 1", // Schema - R"""( - { - "type": "string", - "minLength": 1 - } - )""", + R"""({ + "type": "string", + "minLength": 1 + })""", // Passing strings { - "\"foo\"", - "\"bar\"", + R"""("foo")""", + R"""("bar")""", }, // Failing strings { - "\"\"", - "{}", - "\"foo\": \"bar\"", + R"""("")""", + R"""({})""", + R"""("foo": "bar")""", } ); test_schema( "string w/ min length 3", // Schema - R"""( - { + R"""({ "type": "string", "minLength": 3 - } - )""", + })""", // Passing strings { - "\"foo\"", - "\"bar\"", - "\"foobar\"", + R"""("foo")""", + R"""("bar")""", + R"""("foobar")""", }, // Failing strings { - "\"\"", - "\"f\"", - "\"fo\"", + R"""("")""", + R"""("f")""", + R"""("fo")""", } ); test_schema( "string w/ max length", // Schema - R"""( - { - "type": "string", - "maxLength": 3 - } - )""", + R"""({ + "type": "string", + "maxLength": 3 + })""", // Passing strings { - "\"foo\"", - "\"bar\"", - "\"\"", - "\"f\"", - "\"fo\"", + R"""("foo")""", + R"""("bar")""", + R"""("")""", + R"""("f")""", + R"""("fo")""", }, // Failing strings { - "\"foobar\"", + R"""("foobar")""", } ); test_schema( "string w/ min & max length", // Schema - R"""( - { - "type": "string", - "minLength": 1, - "maxLength": 4 - } - )""", + R"""({ + "type": "string", + "minLength": 1, + "maxLength": 4 + })""", // Passing strings { - "\"foo\"", - "\"bar\"", - "\"f\"", - "\"barf\"", + R"""("foo")""", + R"""("bar")""", + R"""("f")""", + R"""("barf")""", }, // Failing strings { - "\"\"", - "\"barfo\"", - "\"foobar\"", + R"""("")""", + R"""("barfo")""", + R"""("foobar")""", } ); test_schema( "boolean", // Schema - R"""( - { - "type": "boolean" - } - )""", + R"""({ + "type": "boolean" + })""", // Passing strings { "true", @@ -680,123 +908,137 @@ static void test_json_schema() { }, // Failing strings { - "\"\"", - "\"true\"", - "True", - "FALSE", + R"""("")""", + R"""("true")""", + R"""(True)""", + R"""(FALSE)""", } ); test_schema( "integer", // Schema - R"""( - { - "type": "integer" - } - )""", + R"""({ + "type": "integer" + })""", // Passing strings { - "0", - "12345", - "1234567890123456" + R"""(0)""", + R"""(12345)""", + R"""(1234567890123456)""", }, // Failing strings { - "", - "01", - "007", - "12345678901234567" + R"""()""", + R"""(01)""", + R"""(007)""", + R"""(12345678901234567 )""", } ); test_schema( "string const", // Schema - R"""( - { - "const": "foo" - } - )""", + R"""({ + "const": "foo" + })""", // Passing strings { - "\"foo\"", + R"""("foo")""", }, // Failing strings { - "foo", - "\"bar\"", + R"""(foo)""", + R"""("bar")""", } ); test_schema( "non-string const", // Schema - R"""( - { - "const": true - } - )""", + R"""({ + "const": true + })""", // Passing strings { - "true", + R"""(true)""", }, // Failing strings { - "", - "foo", - "\"true\"", + R"""()""", + R"""(foo)""", + R"""("true")""", } ); test_schema( "non-string const", // Schema + R"""({ + "enum": ["red", "amber", "green", null, 42, ["foo"]] + })""", + // Passing strings + { + R"""("red")""", + R"""(null)""", + R"""(42)""", + R"""(["foo"])""", + }, + // Failing strings + { + R"""()""", + R"""(420)""", + R"""(true)""", + R"""(foo)""", + } + ); + + test_schema( + "", + // Schema R"""( { - "enum": ["red", "amber", "green", null, 42, ["foo"]] + "type": ["array", "null"], + "items": { "type": "string" } } )""", // Passing strings { - "\"red\"", "null", - "42", - "[\"foo\"]", + "[]", + "[\"123\"]", + "[\"foo\", \"bar\"]", }, // Failing strings { "", - "420", - "true", - "foo", + "[123]", + "\"foo\"", + "[\"foo\", 42]", } ); - test_schema( "min+max items", // Schema - R"""( - { - "items": { - "type": ["number", "integer"] - }, - "minItems": 3, - "maxItems": 5 - } - )""", + R"""({ + "items": { + "type": ["number", "integer"] + }, + "minItems": 3, + "maxItems": 5 + })""", // Passing strings { - "[1, 2, 3]", - "[1, 2, 3, 4]", - "[1, 2, 3, 4, 5]", + R"""([1, 2, 3])""", + R"""([1, 2, 3, 4])""", + R"""([1, 2, 3, 4, 5])""", }, // Failing strings { - "[1, 2]", - "[1, 2, 3, 4, 5, 6]", - "1" + R"""([1, 2])""", + R"""([1, 2, 3, 4, 5, 6])""", + R"""(1)""", } ); @@ -804,16 +1046,14 @@ static void test_json_schema() { test_schema( "object properties", // Schema - R"""( - { + R"""({ "type": "object", "properties": { "number": { "type": "number" }, "street_name": { "type": "string" }, "street_type": { "enum": ["Street", "Avenue", "Boulevard"] } } - } - )""", + })""", // Passing strings { R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue"})""", @@ -823,12 +1063,8 @@ static void test_json_schema() { // "By extension, even an empty object is valid" R"""({})""", // "By default, providing additional properties is valid" -#ifdef INCLUDE_FAILING_TESTS - // TODO: The following should pass, but currently FAILS. Additional properties should be permitted by default. R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue", "direction":"NW"})""", - // TODO: Spaces should be permitted around enum values, but currently they fail to pass. R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue" })""", -#endif }, // Failing strings { @@ -841,13 +1077,35 @@ static void test_json_schema() { } ); + test_schema( + "additional properties can't override other properties", + R"""({ + "properties": { + "a": {"type": "integer"}, + "b": {"type": "integer"} + }, + "additionalProperties": true + })""", + // Passing strings + { + R"""({"a": 42})""", + R"""({"c": ""})""", + R"""({"a": 42, "c": ""})""", + R"""({"a_": ""})""", + }, + // Failing strings + { + R"""()""", + R"""({"a": ""})""", + R"""({"a": "", "b": ""})""", + } + ); // Properties (from: https://json-schema.org/understanding-json-schema/reference/object#properties) test_schema( "object properties, additionalProperties: true", // Schema - R"""( - { + R"""({ "type": "object", "properties": { "number": { "type": "number" }, @@ -855,26 +1113,18 @@ static void test_json_schema() { "street_type": { "enum": ["Street", "Avenue", "Boulevard"] } }, "additionalProperties": true - } - )""", + })""", // Passing strings { // "By extension, even an empty object is valid" R"""({})""", -#ifdef INCLUDE_FAILING_TESTS - // TODO: Following line should pass and doesn't R"""({"number":1600,"street_name":"Pennsylvania","street_type":"Avenue"})""", // "By default, leaving out properties is valid" - // TODO: Following line should pass and doesn't R"""({ "street_name": "Pennsylvania" })""", - // TODO: Following line should pass and doesn't R"""({ "number": 1600, "street_name": "Pennsylvania" })""", // "By default, providing additional properties is valid" - // TODO: The following should pass, but currently FAILS. Additional properties should be permitted by default. R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue", "direction":"NW"})""", - // TODO: Spaces should be permitted around enum values, but currently they fail to pass. R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue" })""", -#endif }, // Failing strings { @@ -889,8 +1139,7 @@ static void test_json_schema() { test_schema( "required + optional props each in original order", // Schema - R"""( - { + R"""({ "type": "object", "properties": { "number": { "type": "number" }, @@ -898,18 +1147,15 @@ static void test_json_schema() { "street_type": { "enum": ["Street", "Avenue", "Boulevard"] } }, "additionalProperties": false - } - )""", + })""", // Passing strings { R"""({ "street_name": "Pennsylvania" })""", R"""({ "number": 1600, "street_type":"Avenue"})""", R"""({ "number": 1600, "street_name": "Pennsylvania" })""", R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue"})""", -#ifdef INCLUDE_FAILING_TESTS - // TODO: Spaces should be permitted around enum values, but currently they fail to pass. + // Spaces are permitted around enum values R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue" })""", -#endif }, // Failing strings { @@ -923,18 +1169,16 @@ static void test_json_schema() { test_schema( "required + optional props each in original order", // Schema - R"""( - { - "properties": { - "b": {"type": "string"}, - "a": {"type": "string"}, - "d": {"type": "string"}, - "c": {"type": "string"} - }, - "required": ["a", "b"], - "additionalProperties": false - } - )""", + R"""({ + "properties": { + "b": {"type": "string"}, + "a": {"type": "string"}, + "d": {"type": "string"}, + "c": {"type": "string"} + }, + "required": ["a", "b"], + "additionalProperties": false + })""", // Passing strings { R"""({"b": "foo", "a": "bar"})""", @@ -954,8 +1198,7 @@ static void test_json_schema() { test_schema( "required props", // Schema - R"""( - { + R"""({ "$schema": "https://json-schema.org/draft/2020-12/schema", "$id": "https://example.com/product.schema.json", "title": "Product", @@ -1001,8 +1244,7 @@ static void test_json_schema() { } }, "required": [ "productId", "productName", "price" ] - } - )""", + })""", // Passing strings { R"""({"productId": 1, "productName": "A green door", "price": 12.50})""", diff --git a/tests/test-json-schema-to-grammar.cpp b/tests/test-json-schema-to-grammar.cpp index 87bc66b69..3aaa11833 100755 --- a/tests/test-json-schema-to-grammar.cpp +++ b/tests/test-json-schema-to-grammar.cpp @@ -80,6 +80,232 @@ static void test_all(const std::string & lang, std::function grammar_rules(parsed_grammar.c_rules()); grammar = llama_grammar_init( grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); + if (grammar == nullptr) + { + throw std::runtime_error("Failed to initialize llama_grammar"); + } std::vector> expected_stacks = { { diff --git a/unicode.cpp b/unicode.cpp index c0b76bf20..8692924b9 100644 --- a/unicode.cpp +++ b/unicode.cpp @@ -23,7 +23,7 @@ static std::string unicode_cpts_to_utf8(const std::vector & cps) { return result; } -static uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) { +uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) { assert(offset < utf8.size()); if (!(utf8[offset + 0] & 0x80)) { auto result = utf8[offset + 0]; diff --git a/unicode.h b/unicode.h index 6c488970a..30b07ba7f 100644 --- a/unicode.h +++ b/unicode.h @@ -48,6 +48,7 @@ struct codepoint_flags { std::string unicode_cpt_to_utf8(uint32_t cp); +uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset); std::vector unicode_cpts_from_utf8(const std::string & utf8); std::vector unicode_cpts_normalize_nfd(const std::vector & cpts);