diff --git a/.dockerignore b/.dockerignore
index 8916e2a66..064b7c7be 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -1,7 +1,7 @@
*.o
*.a
.cache/
-.git/
+# Do not ignore the .git directory, otherwise the reported build number will always be 0
.github/
.gitignore
.vs/
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 1777489ec..e6a977b60 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -956,6 +956,7 @@ jobs:
cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/sycl7.dll" ./build/bin
cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/svml_dispmd.dll" ./build/bin
cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/libmmd.dll" ./build/bin
+ cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/libiomp5md.dll" ./build/bin
echo "cp oneAPI running time dll files to ./build/bin done"
7z a llama-${{ steps.tag.outputs.name }}-bin-win-sycl-x64.zip ./build/bin/*
@@ -967,6 +968,7 @@ jobs:
name: llama-bin-win-sycl-x64.zip
windows-latest-cmake-hip:
+ if: ${{ github.event.inputs.create_release != 'true' }}
runs-on: windows-latest
steps:
@@ -994,8 +996,72 @@ jobs:
run: |
$env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path)
$env:CMAKE_PREFIX_PATH="${env:HIP_PATH}"
- cmake -G "Unix Makefiles" -B build -S . -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" -DGGML_HIPBLAS=ON
- cmake --build build --config Release
+ cmake -G "Unix Makefiles" -B build -S . -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" -DGGML_HIPBLAS=ON -DCMAKE_BUILD_TYPE=Release -DGGML_RPC=ON
+ cmake --build build -j ${env:NUMBER_OF_PROCESSORS}
+
+ windows-latest-cmake-hip-release:
+ if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
+ runs-on: windows-latest
+
+ strategy:
+ matrix:
+ gpu_target: [gfx1100, gfx1101, gfx1030]
+
+ steps:
+ - name: Clone
+ id: checkout
+ uses: actions/checkout@v4
+
+ - name: Install
+ id: depends
+ run: |
+ $ErrorActionPreference = "Stop"
+ write-host "Downloading AMD HIP SDK Installer"
+ Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
+ write-host "Installing AMD HIP SDK"
+ Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
+ write-host "Completed AMD HIP SDK installation"
+
+ - name: Verify ROCm
+ id: verify
+ run: |
+ & 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version
+
+ - name: Build
+ id: cmake_build
+ run: |
+ $env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path)
+ $env:CMAKE_PREFIX_PATH="${env:HIP_PATH}"
+ cmake -G "Unix Makefiles" -B build -S . -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" -DGGML_HIPBLAS=ON -DCMAKE_BUILD_TYPE=Release -DGPU_TARGETS=${{ matrix.gpu_target }} -DGGML_RPC=ON
+ cmake --build build -j ${env:NUMBER_OF_PROCESSORS}
+ md "build\bin\rocblas\library\"
+ cp "${env:HIP_PATH}\bin\hipblas.dll" "build\bin\"
+ cp "${env:HIP_PATH}\bin\rocblas.dll" "build\bin\"
+ cp "${env:HIP_PATH}\bin\rocblas\library\*" "build\bin\rocblas\library\"
+
+ - name: Determine tag name
+ id: tag
+ shell: bash
+ run: |
+ BUILD_NUMBER="$(git rev-list --count HEAD)"
+ SHORT_HASH="$(git rev-parse --short=7 HEAD)"
+ if [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then
+ echo "name=b${BUILD_NUMBER}" >> $GITHUB_OUTPUT
+ else
+ SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-')
+ echo "name=${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" >> $GITHUB_OUTPUT
+ fi
+
+ - name: Pack artifacts
+ id: pack_artifacts
+ run: |
+ 7z a llama-${{ steps.tag.outputs.name }}-bin-win-hip-x64-${{ matrix.gpu_target }}.zip .\build\bin\*
+
+ - name: Upload artifacts
+ uses: actions/upload-artifact@v4
+ with:
+ path: llama-${{ steps.tag.outputs.name }}-bin-win-hip-x64-${{ matrix.gpu_target }}.zip
+ name: llama-bin-win-hip-x64-${{ matrix.gpu_target }}.zip
ios-xcode-build:
runs-on: macos-latest
@@ -1060,6 +1126,7 @@ jobs:
- macOS-latest-cmake
- windows-latest-cmake
- windows-latest-cmake-cuda
+ - windows-latest-cmake-hip-release
- macOS-latest-cmake-arm64
- macOS-latest-cmake-x64
diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml
index 9044cd78b..a4ac9b217 100644
--- a/.github/workflows/docker.yml
+++ b/.github/workflows/docker.yml
@@ -15,11 +15,17 @@ on:
branches:
- master
paths: ['.github/workflows/docker.yml', '.devops/*.Dockerfile', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.cuh', '**/*.swift', '**/*.m', '**/*.metal']
+ workflow_dispatch: # allows manual triggering, useful for debugging
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
cancel-in-progress: true
+# Fine-grained permissions
+# https://docs.github.com/en/actions/security-for-github-actions/security-guides/automatic-token-authentication#modifying-the-permissions-for-the-github_token
+permissions:
+ packages: write
+
jobs:
push_to_registry:
name: Push Docker image to Docker Hub
@@ -46,6 +52,8 @@ jobs:
steps:
- name: Check out the repo
uses: actions/checkout@v4
+ with:
+ fetch-depth: 0 # preserve git history, so we can determine the build number
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
@@ -60,6 +68,34 @@ jobs:
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
+ - name: Determine tag name
+ id: tag
+ shell: bash
+ run: |
+ BUILD_NUMBER="$(git rev-list --count HEAD)"
+ SHORT_HASH="$(git rev-parse --short=7 HEAD)"
+ REPO_OWNER="${GITHUB_REPOSITORY_OWNER@L}" # to lower case
+ REPO_NAME="${{ github.event.repository.name }}"
+
+ # determine tag name postfix (build number, commit hash)
+ if [[ "${{ env.GITHUB_BRANCH_NAME }}" == "master" ]]; then
+ TAG_POSTFIX="b${BUILD_NUMBER}"
+ else
+ SAFE_NAME=$(echo "${{ env.GITHUB_BRANCH_NAME }}" | tr '/' '-')
+ TAG_POSTFIX="${SAFE_NAME}-${SHORT_HASH}"
+ fi
+
+          # list all possible tags
+ TAGS=""
+ TAGS="${TAGS}ghcr.io/${REPO_OWNER}/${REPO_NAME}:${{ matrix.config.tag }},"
+ TAGS="${TAGS}ghcr.io/${REPO_OWNER}/${REPO_NAME}:${{ matrix.config.tag }}-${TAG_POSTFIX}"
+
+ echo "output_tags=$TAGS" >> $GITHUB_OUTPUT
+ echo "output_tags=$TAGS" # print out for debugging
+ env:
+ GITHUB_BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
+ GITHUB_REPOSITORY_OWNER: '${{ github.repository_owner }}'
+
# https://github.com/jlumbroso/free-disk-space/tree/54081f138730dfa15788a46383842cd2f914a1be#example
- name: Free Disk Space (Ubuntu)
uses: jlumbroso/free-disk-space@main
@@ -77,25 +113,6 @@ jobs:
docker-images: true
swap-storage: true
- - name: Determine tag name
- id: tag
- shell: bash
- run: |
- BUILD_NUMBER="$(git rev-list --count HEAD)"
- SHORT_HASH="$(git rev-parse --short=7 HEAD)"
- if [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then
- echo "name=b${BUILD_NUMBER}" >> $GITHUB_OUTPUT
- else
- SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-')
- echo "name=${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" >> $GITHUB_OUTPUT
- fi
-
- - name: Downcase github.repository_owner
- run: |
- echo "repository_owner_lowercase=${GITHUB_REPOSITORY_OWNER@L}" >> $GITHUB_ENV
- env:
- GITHUB_REPOSITORY_OWNER: '${{ github.repository_owner }}'
-
- name: Build and push Docker image (tagged + versioned)
if: github.event_name == 'push'
uses: docker/build-push-action@v6
@@ -103,5 +120,6 @@ jobs:
context: .
push: true
platforms: ${{ matrix.config.platforms }}
- tags: "ghcr.io/${{ env.repository_owner_lowercase }}/llama.cpp:${{ matrix.config.tag }}-${{ env.COMMIT_SHA }},ghcr.io/${{ env.repository_owner_lowercase }}/llama.cpp:${{ matrix.config.tag }},ghcr.io/${{ env.repository_owner_lowercase }}/llama.cpp:${{ matrix.config.tag }}-${{ steps.tag.outputs.name }}"
+      # tag list is generated by the step above
+ tags: ${{ steps.tag.outputs.output_tags }}
file: ${{ matrix.config.dockerfile }}
diff --git a/.github/workflows/python-type-check.yml b/.github/workflows/python-type-check.yml
index e5ff5e6d7..373bb6010 100644
--- a/.github/workflows/python-type-check.yml
+++ b/.github/workflows/python-type-check.yml
@@ -4,11 +4,13 @@ on:
push:
paths:
- '.github/workflows/python-type-check.yml'
+ - 'pyrightconfig.json'
- '**.py'
- '**/requirements*.txt'
pull_request:
paths:
- '.github/workflows/python-type-check.yml'
+ - 'pyrightconfig.json'
- '**.py'
- '**/requirements*.txt'
@@ -33,6 +35,6 @@ jobs:
- name: Type-check with Pyright
uses: jakebailey/pyright-action@v2
with:
- version: 1.1.370
+ version: 1.1.382
level: warning
warnings: true
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 973907819..415743c2a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -62,6 +62,9 @@ option(LLAMA_SANITIZE_THREAD "llama: enable thread sanitizer" OFF)
option(LLAMA_SANITIZE_ADDRESS "llama: enable address sanitizer" OFF)
option(LLAMA_SANITIZE_UNDEFINED "llama: enable undefined sanitizer" OFF)
+# utils
+option(LLAMA_BUILD_COMMON "llama: build common utils library" ON)
+
# extra artifacts
option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
@@ -191,15 +194,17 @@ install(FILES "${CMAKE_CURRENT_BINARY_DIR}/llama.pc"
DESTINATION lib/pkgconfig)
#
-# programs, examples and tests
+# utils, programs, examples and tests
#
-add_subdirectory(common)
+if (LLAMA_BUILD_COMMON)
+ add_subdirectory(common)
+endif()
if (LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION)
include(CTest)
add_subdirectory(tests)
-endif ()
+endif()
if (LLAMA_BUILD_EXAMPLES)
add_subdirectory(examples)
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index a9e000e52..3d7c6f86c 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -27,3 +27,8 @@

+# Resources
+
+The GitHub issues, PRs and discussions contain a lot of information that can be useful for getting familiar with the codebase. For convenience, some of the more important information is referenced from GitHub projects:
+
+https://github.com/ggerganov/llama.cpp/projects
diff --git a/Makefile b/Makefile
index f922f7083..8a903d7ed 100644
--- a/Makefile
+++ b/Makefile
@@ -611,7 +611,7 @@ ifdef GGML_CUDA
MK_CPPFLAGS += -DGGML_USE_CUDA -I$(CUDA_PATH)/include
MK_LDFLAGS += -lmusa -lmublas -lmusart -lpthread -ldl -lrt -L$(CUDA_PATH)/lib -L/usr/lib64
- MK_NVCCFLAGS += -x musa -mtgpu --cuda-gpu-arch=mp_22
+ MK_NVCCFLAGS += -x musa -mtgpu --cuda-gpu-arch=mp_21 --cuda-gpu-arch=mp_22
else
ifneq ('', '$(wildcard /opt/cuda)')
CUDA_PATH ?= /opt/cuda
diff --git a/README.md b/README.md
index 4d24dd591..ecc2df8ca 100644
--- a/README.md
+++ b/README.md
@@ -17,7 +17,8 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)
## Hot topics
-- Huggingface GGUF editor: [discussion](https://github.com/ggerganov/llama.cpp/discussions/9268) | [tool](https://huggingface.co/spaces/CISCai/gguf-editor)
+- **Hugging Face Inference Endpoints now support GGUF out of the box! https://github.com/ggerganov/llama.cpp/discussions/9669**
+- Hugging Face GGUF editor: [discussion](https://github.com/ggerganov/llama.cpp/discussions/9268) | [tool](https://huggingface.co/spaces/CISCai/gguf-editor)
----
@@ -112,6 +113,7 @@ Typically finetunes of the base models below are supported as well.
- Go: [go-skynet/go-llama.cpp](https://github.com/go-skynet/go-llama.cpp)
- Node.js: [withcatai/node-llama-cpp](https://github.com/withcatai/node-llama-cpp)
- JS/TS (llama.cpp server client): [lgrammel/modelfusion](https://modelfusion.dev/integration/model-provider/llamacpp)
+- JS/TS (Programmable Prompt Engine CLI): [offline-ai/cli](https://github.com/offline-ai/cli)
- JavaScript/Wasm (works in browser): [tangledgroup/llama-cpp-wasm](https://github.com/tangledgroup/llama-cpp-wasm)
- Typescript/Wasm (nicer API, available on npm): [ngxson/wllama](https://github.com/ngxson/wllama)
- Ruby: [yoshoku/llama_cpp.rb](https://github.com/yoshoku/llama_cpp.rb)
@@ -172,6 +174,7 @@ Unless otherwise noted these projects are open-source with permissive licensing:
**Tools:**
- [akx/ggify](https://github.com/akx/ggify) – download PyTorch models from HuggingFace Hub and convert them to GGML
+- [akx/ollama-dl](https://github.com/akx/ollama-dl) – download models from the Ollama library to be used directly with llama.cpp
- [crashr/gppm](https://github.com/crashr/gppm) – launch llama.cpp instances utilizing NVIDIA Tesla P40 or P100 GPUs with reduced idle power consumption
- [gpustack/gguf-parser](https://github.com/gpustack/gguf-parser-go/tree/main/cmd/gguf-parser) - review/check the GGUF file and estimate the memory usage
- [Styled Lines](https://marketplace.unity.com/packages/tools/generative-ai/styled-lines-llama-cpp-model-292902) (proprietary licensed, async wrapper of inference part for game development in Unity3d with prebuild Mobile and Web platform wrappers and a model example)
@@ -440,7 +443,7 @@ To learn more how to measure perplexity using llama.cpp, [read this documentatio
- Contributors can open PRs
- Collaborators can push to branches in the `llama.cpp` repo and merge PRs into the `master` branch
- Collaborators will be invited based on contributions
-- Any help with managing issues and PRs is very appreciated!
+- Any help with managing issues, PRs and projects is very appreciated!
- See [good first issues](https://github.com/ggerganov/llama.cpp/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) for tasks suitable for first contributions
- Read the [CONTRIBUTING.md](CONTRIBUTING.md) for more information
- Make sure to read this: [Inference at the edge](https://github.com/ggerganov/llama.cpp/discussions/205)
diff --git a/ci/run.sh b/ci/run.sh
index 1ac08ee4e..7d241ecc0 100755
--- a/ci/run.sh
+++ b/ci/run.sh
@@ -712,6 +712,81 @@ function gg_run_embd_bge_small {
set +e
}
+function gg_sum_embd_bge_small {
+ gg_printf '### %s\n\n' "${ci}"
+
+ gg_printf 'BGE Small (BERT):\n'
+ gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)"
+ gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-f16.log)"
+ gg_printf '- q8_0:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q8_0.log)"
+}
+
+# rerank_tiny
+
+function gg_run_rerank_tiny {
+ cd ${SRC}
+
+ gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/config.json
+ gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/tokenizer.json
+ gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/tokenizer_config.json
+ gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/special_tokens_map.json
+ gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/resolve/main/pytorch_model.bin
+ gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/sentence_bert_config.json
+ gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/vocab.txt
+ gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/modules.json
+ gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/config.json
+
+ gg_wget models-mnt/rerank-tiny/1_Pooling https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/1_Pooling/config.json
+
+ path_models="../models-mnt/rerank-tiny"
+
+ rm -rf build-ci-release && mkdir build-ci-release && cd build-ci-release
+
+ set -e
+
+ (time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log
+ (time make -j$(nproc) ) 2>&1 | tee -a $OUT/${ci}-make.log
+
+ python3 ../convert_hf_to_gguf.py ${path_models} --outfile ${path_models}/ggml-model-f16.gguf
+
+ model_f16="${path_models}/ggml-model-f16.gguf"
+
+ (time ./bin/llama-embedding --model ${model_f16} -p "what is panda?hi\nwhat is panda?it's a bear\nwhat is panda?The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." --pooling rank --embd-normalize -1 --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log
+
+ # sample output
+ # rerank score 0: 0.029
+ # rerank score 1: 0.029
+ # rerank score 2: 0.135
+
+ # check that the score is in the range [$3, $4]
+ function check_score {
+ qnt="$1"
+ score=$(echo "$2" | grep -oE "[0-9]+\.[0-9]+" | tail -n 1)
+
+ if [ $(echo "$score < $3" | bc) -eq 1 ] || [ $(echo "$score > $4" | bc) -eq 1 ]; then
+ printf ' - %s @ %s (FAIL: score not in range [%s, %s])\n' "$qnt" "$score" "$3" "$4"
+ return 20
+ fi
+
+ printf ' - %s @ %s OK\n' "$qnt" "$score"
+ return 0
+ }
+
+ check_score "rerank score 0" "$(cat $OUT/${ci}-rk-f16.log | grep "rerank score 0")" "0.00" "0.05" | tee -a $OUT/${ci}-rk-f16.log
+ check_score "rerank score 1" "$(cat $OUT/${ci}-rk-f16.log | grep "rerank score 1")" "0.00" "0.05" | tee -a $OUT/${ci}-rk-f16.log
+ check_score "rerank score 2" "$(cat $OUT/${ci}-rk-f16.log | grep "rerank score 2")" "0.10" "0.15" | tee -a $OUT/${ci}-rk-f16.log
+
+ set +e
+}
+
+function gg_sum_rerank_tiny {
+ gg_printf '### %s\n\n' "${ci}"
+
+ gg_printf 'Rerank Tiny (Jina):\n'
+ gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)"
+ gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-rk-f16.log)"
+}
+
function gg_check_build_requirements {
if ! command -v cmake &> /dev/null; then
gg_printf 'cmake not found, please install'
@@ -726,15 +801,6 @@ function gg_check_build_requirements {
fi
}
-function gg_sum_embd_bge_small {
- gg_printf '### %s\n\n' "${ci}"
-
- gg_printf 'BGE Small (BERT):\n'
- gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)"
- gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-f16.log)"
- gg_printf '- q8_0:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q8_0.log)"
-}
-
## main
export LLAMA_LOG_PREFIX=1
@@ -762,6 +828,7 @@ test $ret -eq 0 && gg_run ctest_release
if [ -z ${GG_BUILD_LOW_PERF} ]; then
test $ret -eq 0 && gg_run embd_bge_small
+ test $ret -eq 0 && gg_run rerank_tiny
if [ -z ${GG_BUILD_CLOUD} ] || [ ${GG_BUILD_EXTRA_TESTS_0} ]; then
test $ret -eq 0 && gg_run test_scripts_debug
diff --git a/common/arg.cpp b/common/arg.cpp
index 60e37a89a..8266a16c2 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -284,6 +284,10 @@ static bool gpt_params_parse_ex(int argc, char ** argv, gpt_params_context & ctx
params.kv_overrides.back().key[0] = 0;
}
+ if (params.reranking && params.embedding) {
+ throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both");
+ }
+
return true;
}
@@ -391,7 +395,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
[](gpt_params & params) {
params.verbose_prompt = true;
}
- ).set_examples({LLAMA_EXAMPLE_MAIN}));
+ ));
add_opt(llama_arg(
{"--no-display-prompt"},
format("don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false"),
@@ -691,7 +695,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
[](gpt_params & params) {
params.ctx_shift = false;
}
- ).set_examples({LLAMA_EXAMPLE_MAIN}));
+ ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_NO_CONTEXT_SHIFT"));
add_opt(llama_arg(
{"--chunks"}, "N",
format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks),
@@ -1093,16 +1097,17 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
}
).set_sparam());
add_opt(llama_arg(
- {"--pooling"}, "{none,mean,cls,last}",
+ {"--pooling"}, "{none,mean,cls,last,rank}",
"pooling type for embeddings, use model default if unspecified",
[](gpt_params & params, const std::string & value) {
/**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
- else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
+ else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; }
+ else if (value == "rank") { params.pooling_type = LLAMA_POOLING_TYPE_RANK; }
else { throw std::invalid_argument("invalid value"); }
}
- ).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
+ ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_POOLING"));
add_opt(llama_arg(
{"--attention"}, "{causal,non,causal}",
"attention type for embeddings, use model default if unspecified",
@@ -1121,77 +1126,77 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; }
else { throw std::invalid_argument("invalid value"); }
}
- ));
+ ).set_env("LLAMA_ARG_ROPE_SCALING_TYPE"));
add_opt(llama_arg(
{"--rope-scale"}, "N",
"RoPE context scaling factor, expands context by a factor of N",
[](gpt_params & params, const std::string & value) {
params.rope_freq_scale = 1.0f / std::stof(value);
}
- ));
+ ).set_env("LLAMA_ARG_ROPE_SCALE"));
add_opt(llama_arg(
{"--rope-freq-base"}, "N",
"RoPE base frequency, used by NTK-aware scaling (default: loaded from model)",
[](gpt_params & params, const std::string & value) {
params.rope_freq_base = std::stof(value);
}
- ));
+ ).set_env("LLAMA_ARG_ROPE_FREQ_BASE"));
add_opt(llama_arg(
{"--rope-freq-scale"}, "N",
"RoPE frequency scaling factor, expands context by a factor of 1/N",
[](gpt_params & params, const std::string & value) {
params.rope_freq_scale = std::stof(value);
}
- ));
+ ).set_env("LLAMA_ARG_ROPE_FREQ_SCALE"));
add_opt(llama_arg(
{"--yarn-orig-ctx"}, "N",
format("YaRN: original context size of model (default: %d = model training context size)", params.yarn_orig_ctx),
[](gpt_params & params, int value) {
params.yarn_orig_ctx = value;
}
- ));
+ ).set_env("LLAMA_ARG_YARN_ORIG_CTX"));
add_opt(llama_arg(
{"--yarn-ext-factor"}, "N",
format("YaRN: extrapolation mix factor (default: %.1f, 0.0 = full interpolation)", (double)params.yarn_ext_factor),
[](gpt_params & params, const std::string & value) {
params.yarn_ext_factor = std::stof(value);
}
- ));
+ ).set_env("LLAMA_ARG_YARN_EXT_FACTOR"));
add_opt(llama_arg(
{"--yarn-attn-factor"}, "N",
format("YaRN: scale sqrt(t) or attention magnitude (default: %.1f)", (double)params.yarn_attn_factor),
[](gpt_params & params, const std::string & value) {
params.yarn_attn_factor = std::stof(value);
}
- ));
+ ).set_env("LLAMA_ARG_YARN_ATTN_FACTOR"));
add_opt(llama_arg(
{"--yarn-beta-slow"}, "N",
format("YaRN: high correction dim or alpha (default: %.1f)", (double)params.yarn_beta_slow),
[](gpt_params & params, const std::string & value) {
params.yarn_beta_slow = std::stof(value);
}
- ));
+ ).set_env("LLAMA_ARG_YARN_BETA_SLOW"));
add_opt(llama_arg(
{"--yarn-beta-fast"}, "N",
format("YaRN: low correction dim or beta (default: %.1f)", (double)params.yarn_beta_fast),
[](gpt_params & params, const std::string & value) {
params.yarn_beta_fast = std::stof(value);
}
- ));
+ ).set_env("LLAMA_ARG_YARN_BETA_FAST"));
add_opt(llama_arg(
{"-gan", "--grp-attn-n"}, "N",
format("group-attention factor (default: %d)", params.grp_attn_n),
[](gpt_params & params, int value) {
params.grp_attn_n = value;
}
- ));
+ ).set_env("LLAMA_ARG_GRP_ATTN_N"));
add_opt(llama_arg(
{"-gaw", "--grp-attn-w"}, "N",
format("group-attention width (default: %.1f)", (double)params.grp_attn_w),
[](gpt_params & params, int value) {
params.grp_attn_w = value;
}
- ));
+ ).set_env("LLAMA_ARG_GRP_ATTN_W"));
add_opt(llama_arg(
{"-dkvc", "--dump-kv-cache"},
"verbose print of the KV cache",
@@ -1205,7 +1210,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
[](gpt_params & params) {
params.no_kv_offload = true;
}
- ));
+ ).set_env("LLAMA_ARG_NO_KV_OFFLOAD"));
add_opt(llama_arg(
{"-ctk", "--cache-type-k"}, "TYPE",
format("KV cache data type for K (default: %s)", params.cache_type_k.c_str()),
@@ -1213,7 +1218,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
// TODO: get the type right here
params.cache_type_k = value;
}
- ));
+ ).set_env("LLAMA_ARG_CACHE_TYPE_K"));
add_opt(llama_arg(
{"-ctv", "--cache-type-v"}, "TYPE",
format("KV cache data type for V (default: %s)", params.cache_type_v.c_str()),
@@ -1221,7 +1226,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
// TODO: get the type right here
params.cache_type_v = value;
}
- ));
+ ).set_env("LLAMA_ARG_CACHE_TYPE_V"));
add_opt(llama_arg(
{"--perplexity", "--all-logits"},
format("return logits for all tokens in the batch (default: %s)", params.logits_all ? "true" : "false"),
@@ -1312,7 +1317,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
[](gpt_params & params, int value) {
params.n_parallel = value;
}
- ));
+ ).set_env("LLAMA_ARG_N_PARALLEL"));
add_opt(llama_arg(
{"-ns", "--sequences"}, "N",
format("number of sequences to decode (default: %d)", params.n_sequences),
@@ -1355,7 +1360,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
[](gpt_params & params, const std::string & value) {
params.rpc_servers = value;
}
- ));
+ ).set_env("LLAMA_ARG_RPC"));
#endif
add_opt(llama_arg(
{"--mlock"},
@@ -1363,14 +1368,14 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
[](gpt_params & params) {
params.use_mlock = true;
}
- ));
+ ).set_env("LLAMA_ARG_MLOCK"));
add_opt(llama_arg(
{"--no-mmap"},
"do not memory-map model (slower load but may reduce pageouts if not using mlock)",
[](gpt_params & params) {
params.use_mmap = false;
}
- ));
+ ).set_env("LLAMA_ARG_NO_MMAP"));
add_opt(llama_arg(
{"--numa"}, "TYPE",
"attempt optimizations that help on some NUMA systems\n"
@@ -1385,7 +1390,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
else if (value == "numactl") { params.numa = GGML_NUMA_STRATEGY_NUMACTL; }
else { throw std::invalid_argument("invalid value"); }
}
- ));
+ ).set_env("LLAMA_ARG_NUMA"));
add_opt(llama_arg(
{"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N",
"number of layers to store in VRAM",
@@ -1433,7 +1438,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
fprintf(stderr, "warning: llama.cpp was compiled without support for GPU offload. Setting the split mode has no effect.\n");
}
}
- ));
+ ).set_env("LLAMA_ARG_SPLIT_MODE"));
add_opt(llama_arg(
{"-ts", "--tensor-split"}, "N0,N1,N2,...",
"fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1",
@@ -1460,7 +1465,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
fprintf(stderr, "warning: llama.cpp was compiled without support for GPU offload. Setting a tensor split has no effect.\n");
}
}
- ));
+ ).set_env("LLAMA_ARG_TENSOR_SPLIT"));
add_opt(llama_arg(
{"-mg", "--main-gpu"}, "INDEX",
format("the GPU to use for the model (with split-mode = none), or for intermediate results and KV (with split-mode = row) (default: %d)", params.main_gpu),
@@ -1470,7 +1475,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
fprintf(stderr, "warning: llama.cpp was compiled without support for GPU offload. Setting the main GPU has no effect.\n");
}
}
- ));
+ ).set_env("LLAMA_ARG_MAIN_GPU"));
add_opt(llama_arg(
{"--check-tensors"},
format("check model tensor data for invalid values (default: %s)", params.check_tensors ? "true" : "false"),
@@ -1533,7 +1538,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
[](gpt_params & params, const std::string & value) {
params.model_alias = value;
}
- ).set_examples({LLAMA_EXAMPLE_SERVER}));
+ ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ALIAS"));
add_opt(llama_arg(
{"-m", "--model"}, "FNAME",
ex == LLAMA_EXAMPLE_EXPORT_LORA
@@ -1741,7 +1746,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
[](gpt_params & params, const std::string & value) {
params.public_path = value;
}
- ).set_examples({LLAMA_EXAMPLE_SERVER}));
+ ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_STATIC_PATH"));
add_opt(llama_arg(
{"--embedding", "--embeddings"},
format("restrict to only support embedding use case; use only with dedicated embedding models (default: %s)", params.embedding ? "enabled" : "disabled"),
@@ -1749,6 +1754,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
params.embedding = true;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS"));
+ add_opt(llama_arg(
+ {"--reranking", "--rerank"},
+ format("enable reranking endpoint on server (default: %s)", params.reranking ? "enabled" : "disabled"),
+ [](gpt_params & params) {
+ params.reranking = true;
+ }
+ ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_RERANKING"));
add_opt(llama_arg(
{"--api-key"}, "KEY",
"API key to use for authentication (default: none)",
@@ -1779,14 +1791,14 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
[](gpt_params & params, const std::string & value) {
params.ssl_file_key = value;
}
- ).set_examples({LLAMA_EXAMPLE_SERVER}));
+ ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SSL_KEY_FILE"));
add_opt(llama_arg(
{"--ssl-cert-file"}, "FNAME",
"path to file a PEM-encoded SSL certificate",
[](gpt_params & params, const std::string & value) {
params.ssl_file_cert = value;
}
- ).set_examples({LLAMA_EXAMPLE_SERVER}));
+ ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SSL_CERT_FILE"));
add_opt(llama_arg(
{"-to", "--timeout"}, "N",
format("server read/write timeout in seconds (default: %d)", params.timeout_read),
@@ -1794,7 +1806,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
params.timeout_read = value;
params.timeout_write = value;
}
- ).set_examples({LLAMA_EXAMPLE_SERVER}));
+ ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_TIMEOUT"));
add_opt(llama_arg(
{"--threads-http"}, "N",
format("number of threads used to process HTTP requests (default: %d)", params.n_threads_http),
diff --git a/common/common.cpp b/common/common.cpp
index 8d0ed4f95..a0611f3d1 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1023,6 +1023,11 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.flash_attn = params.flash_attn;
cparams.no_perf = params.no_perf;
+ if (params.reranking) {
+ cparams.embeddings = true;
+ cparams.pooling_type = LLAMA_POOLING_TYPE_RANK;
+ }
+
cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
@@ -1432,6 +1437,8 @@ void llama_batch_add(
llama_pos pos,
const std::vector<llama_seq_id> & seq_ids,
bool logits) {
+ GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");
+
batch.token [batch.n_tokens] = id;
batch.pos [batch.n_tokens] = pos;
batch.n_seq_id[batch.n_tokens] = seq_ids.size();
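
For context on the new assert: `llama_batch_init` allocates one extra `seq_id` slot and leaves it as a `nullptr` sentinel, which is what the check probes. A minimal sketch of the overflow it now catches, assuming the stock `llama.h`/`common.h` APIs (sizes are illustrative):

```cpp
// sketch: the out-of-bounds write the new GGML_ASSERT turns into a clean
// failure, assuming llama_batch_init() keeps its nullptr sentinel at
// seq_id[n_tokens_alloc]
#include "common.h"
#include "llama.h"

int main() {
    // room for exactly two tokens
    llama_batch batch = llama_batch_init(/*n_tokens_alloc =*/ 2, /*embd =*/ 0, /*n_seq_max =*/ 1);

    llama_batch_add(batch, /*id =*/ 0, /*pos =*/ 0, /*seq_ids =*/ {0}, /*logits =*/ false); // ok
    llama_batch_add(batch, /*id =*/ 1, /*pos =*/ 1, /*seq_ids =*/ {0}, /*logits =*/ false); // ok

    // a third add used to write past the arrays sized for two tokens;
    // now batch.seq_id[2] is the nullptr sentinel and the assert fires:
    //llama_batch_add(batch, 2, 2, {0}, false);

    llama_batch_free(batch);
    return 0;
}
```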
diff --git a/common/common.h b/common/common.h
index cb87c4479..8b84cf9ad 100644
--- a/common/common.h
+++ b/common/common.h
@@ -271,6 +271,7 @@ struct gpt_params {
int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
std::string embd_sep = "\n"; // separator of embeddings
+ bool reranking = false; // enable reranking support on server
// server params
int32_t port = 8080; // server listens on this network port
diff --git a/common/console.cpp b/common/console.cpp
index f65cbc6ed..078a8d678 100644
--- a/common/console.cpp
+++ b/common/console.cpp
@@ -94,6 +94,9 @@ namespace console {
simple_io = true;
}
}
+ if (simple_io) {
+ _setmode(_fileno(stdin), _O_U8TEXT);
+ }
#else
// POSIX-specific console initialization
if (!simple_io) {
diff --git a/common/log.cpp b/common/log.cpp
index 2825a227e..5a844ed59 100644
--- a/common/log.cpp
+++ b/common/log.cpp
@@ -82,7 +82,7 @@ struct gpt_log_entry {
}
}
- if (level != GGML_LOG_LEVEL_NONE && prefix) {
+ if (level != GGML_LOG_LEVEL_NONE && level != GGML_LOG_LEVEL_CONT && prefix) {
if (timestamp) {
// [M.s.ms.us]
fprintf(fcur, "%s%d.%02d.%03d.%03d%s ",
diff --git a/common/log.h b/common/log.h
index d13f72d89..84f9b3ed7 100644
--- a/common/log.h
+++ b/common/log.h
@@ -83,8 +83,10 @@ void gpt_log_set_timestamps(struct gpt_log * log, bool timestamps); // w
#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN, 0, __VA_ARGS__)
#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, 0, __VA_ARGS__)
#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, LOG_DEFAULT_DEBUG, __VA_ARGS__)
+#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT, 0, __VA_ARGS__)
#define LOG_INFV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_INFO, verbosity, __VA_ARGS__)
#define LOG_WRNV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_WARN, verbosity, __VA_ARGS__)
#define LOG_ERRV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, verbosity, __VA_ARGS__)
#define LOG_DBGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, verbosity, __VA_ARGS__)
+#define LOG_CNTV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_CONT, verbosity, __VA_ARGS__)
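
A short usage sketch of the new continuation macro (illustrative only; it mirrors the `main.cpp`/`infill.cpp` changes later in this patch):

```cpp
#include <string>
#include <vector>

#include "log.h"

// LOG_INF opens a prefixed, timestamped line; LOG_CNT appends to it without
// emitting a new prefix, thanks to the GGML_LOG_LEVEL_CONT check in log.cpp
static void print_static_prompt(const std::vector<std::string> & pieces) {
    LOG_INF("static prompt: '");
    for (const auto & p : pieces) {
        LOG_CNT("%s", p.c_str());
    }
    LOG_CNT("'\n"); // closes the same visual line
}
```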
diff --git a/common/sampling.cpp b/common/sampling.cpp
index e51d07611..3dc7f1120 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -209,7 +209,15 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
GGML_ASSERT(false && "unknown mirostat version");
}
} else {
- llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
+ if (params.n_probs > 0) {
+            // some use cases require sampling greedily, while still obtaining the probabilities of the top tokens
+ // ref: https://github.com/ggerganov/llama.cpp/pull/9605
+ //
+            // the following will not produce exactly the same probs as applying softmax to the full vocabulary, but
+            // it is much faster, since we avoid sorting all tokens and it should give a good approximation
+ llama_sampler_chain_add(result->chain, llama_sampler_init_top_k(params.n_probs));
+ llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
+ }
llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
}
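
To see why top-k followed by softmax is a good approximation here: realistic logit distributions concentrate nearly all probability mass in a few tokens, so the tail's contribution to the softmax normalizer is negligible. A standalone sketch with made-up numbers:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// softmax probability of the first (maximum) logit in the set
static double p_top(const std::vector<double> & logits) {
    double z = 0.0;
    for (double l : logits) {
        z += std::exp(l - logits[0]); // subtract the max for stability
    }
    return 1.0 / z;
}

int main() {
    std::vector<double> topk = {20.0, 18.5, 18.0}; // top-k = 3 logits
    std::vector<double> full = topk;
    full.resize(32000, 0.0);                       // long, flat tail

    // each tail token adds exp(0 - 20) ~ 2e-9 to the normalizer, ~6.6e-5 in
    // total, so the two estimates agree to about four decimal places
    printf("top-k only : %.4f\n", p_top(topk));    // ~0.7361
    printf("full vocab : %.4f\n", p_top(full));    // ~0.7361
    return 0;
}
```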
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 7c2c87e0b..da5feb25b 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -294,8 +294,13 @@ class Model:
bid = int(part)
break
- for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)):
- data: np.ndarray # type hint
+ for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)):
+ data = data_torch.squeeze().numpy()
+
+ # if data ends up empty, it means data_torch was a scalar tensor -> restore
+ if len(data.shape) == 0:
+ data = data_torch.numpy()
+
n_dims = len(data.shape)
data_qtype: gguf.GGMLQuantizationType | bool = self.tensor_force_quant(name, new_name, bid, n_dims)
@@ -595,6 +600,9 @@ class Model:
if chkhsh == "a8594e3edff7c29c003940395316294b2c623e09894deebbc65f33f1515df79e":
# ref: https://huggingface.co/databricks/dbrx-base
res = "dbrx"
+ if chkhsh == "c7699093ba4255a91e702aa38a596aa81669f3525dae06c2953267dde580f448":
+ # ref: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
+ res = "jina-v1-en"
if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f":
# ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-en
res = "jina-v2-en"
@@ -643,6 +651,9 @@ class Model:
if chkhsh == "fcace8b9cac38ce847670c970cd5892031a753a1ef381abd1d9af00f713da085":
# ref: https://huggingface.co/microsoft/phi-2
res = "phi-2"
+ if chkhsh == "60824e3c0d9401f89943cbb2fff727f0e2d4c545ba4df2d6e4f09a6db0f5b450":
+ # ref: https://huggingface.co/facebook/chameleon-7b
+ res = "chameleon"
if res is None:
logger.warning("\n")
@@ -2606,7 +2617,7 @@ class NomicBertModel(BertModel):
self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
-@Model.register("XLMRobertaModel")
+@Model.register("XLMRobertaModel", "XLMRobertaForSequenceClassification")
class XLMRobertaModel(BertModel):
model_arch = gguf.MODEL_ARCH.BERT
@@ -2704,6 +2715,11 @@ class XLMRobertaModel(BertModel):
self.gguf_writer.add_add_eos_token(True)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+ # if name starts with "roberta.", remove the prefix
+ # e.g. https://huggingface.co/BAAI/bge-reranker-v2-m3/tree/main
+ if name.startswith("roberta."):
+ name = name[8:]
+
# position embeddings start at pad_token_id + 1, so just chop down the weight tensor
if name == "embeddings.position_embeddings.weight":
if self._position_offset is not None:
@@ -3115,6 +3131,14 @@ class JinaBertV2Model(BertModel):
self.gguf_writer.add_add_bos_token(True)
self.gguf_writer.add_add_eos_token(True)
+ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+ # if name starts with "bert.", remove the prefix
+ # e.g. https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
+ if name.startswith("bert."):
+ name = name[5:]
+
+ return super().modify_tensors(data_torch, name, bid)
+
@Model.register("OpenELMForCausalLM")
class OpenELMModel(Model):
@@ -4085,8 +4109,109 @@ class ExaoneModel(Model):
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32))
+@Model.register("GraniteForCausalLM")
+class GraniteModel(LlamaModel):
+ """Conversion for IBM's GraniteForCausalLM"""
+ model_arch = gguf.MODEL_ARCH.GRANITE
+
+ def set_gguf_parameters(self):
+ """Granite uses standard llama parameters with the following differences:
+
+ - No head_dim support
+ - New multiplier params:
+ - attention_scale
+ - embedding_scale
+ - residual_scale
+ - logits_scaling
+ """
+ if head_dim := self.hparams.pop("head_dim", None):
+ logger.warning("Ignoring head_dim (%s) from config for Granite", head_dim)
+ super().set_gguf_parameters()
+ # NOTE: Convert _multiplier params to _scale params for naming
+ # consistency
+ if attention_scale := self.hparams.get("attention_multiplier"):
+ self.gguf_writer.add_attention_scale(attention_scale)
+ logger.info("gguf: (granite) attention_scale = %s", attention_scale)
+ if embedding_scale := self.hparams.get("embedding_multiplier"):
+ self.gguf_writer.add_embedding_scale(embedding_scale)
+ logger.info("gguf: (granite) embedding_scale = %s", embedding_scale)
+ if residual_scale := self.hparams.get("residual_multiplier"):
+ self.gguf_writer.add_residual_scale(residual_scale)
+ logger.info("gguf: (granite) residual_scale = %s", residual_scale)
+ if logits_scale := self.hparams.get("logits_scaling"):
+ self.gguf_writer.add_logit_scale(logits_scale)
+ logger.info("gguf: (granite) logits_scale = %s", logits_scale)
+
+
+@Model.register("GraniteMoeForCausalLM")
+class GraniteMoeModel(GraniteModel):
+ """Conversion for IBM's GraniteMoeForCausalLM"""
+ model_arch = gguf.MODEL_ARCH.GRANITE_MOE
+
+ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+ """In modeling_granitemoe, the JetMoe implementation of parallel experts
+ is used. This essentially merges w1 and w3 into a single tensor with 2x
+ the hidden size that is then split during forward. To keep compatibility
+ with existing mixtral support, we pull them apart here.
+ """
+
+ if name.endswith("block_sparse_moe.input_linear.weight"):
+ ffn_dim = self.hparams["intermediate_size"]
+ assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * intermediate_size"
+ gate, up = data_torch[..., :ffn_dim, :], data_torch[..., ffn_dim:, :]
+ return [
+ (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXP, bid), gate),
+ (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), up),
+ ]
+
+ return super().modify_tensors(data_torch, name, bid)
+
+
+@Model.register("ChameleonForConditionalGeneration")
+@Model.register("ChameleonForCausalLM") # obsolete
+class ChameleonModel(Model):
+ model_arch = gguf.MODEL_ARCH.CHAMELEON
+
+ def set_gguf_parameters(self):
+ super().set_gguf_parameters()
+ self.gguf_writer.add_swin_norm(self.hparams.get("swin_norm", False))
+
+ def set_vocab(self):
+ self._set_vocab_gpt2()
+
+ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+ # ignore image tokenizer for now
+ # TODO: remove this once image support is implemented for Chameleon
+ if name.startswith("model.vqmodel"):
+ return []
+
+ n_head = self.hparams["num_attention_heads"]
+ n_kv_head = self.hparams.get("num_key_value_heads")
+ hidden_dim = self.hparams.get("hidden_size")
+
+ if name.endswith(("q_proj.weight", "q_proj.bias")):
+ data_torch = LlamaModel.permute(data_torch, n_head, n_head)
+ if name.endswith(("k_proj.weight", "k_proj.bias")):
+ data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
+ if name.endswith(("q_norm.weight", "q_norm.bias")):
+ data_torch = ChameleonModel._reverse_hf_permute(data_torch, n_head, hidden_dim)
+ if name.endswith(("k_norm.weight", "k_norm.bias")):
+ data_torch = ChameleonModel._reverse_hf_permute(data_torch, n_kv_head, hidden_dim)
+
+ return [(self.map_tensor_name(name), data_torch)]
+
+ # see: https://github.com/huggingface/transformers/blob/72fb02c47dbbe1999ae105319f24631cad6e2e00/src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py#L176-L203
+ @staticmethod
+ def _reverse_hf_permute(data_torch, n_heads, hidden_dim):
+ head_dim = hidden_dim // n_heads
+ data_torch = data_torch[0].view(2, head_dim // 2).t().reshape(1, -1)
+ data_torch = data_torch.repeat_interleave(n_heads, 0)
+ return data_torch
+
+
###### CONVERSION LOGIC ######
+
# tree of lazy tensors
class LazyTorchTensor(gguf.LazyBase):
_tensor_type = torch.Tensor
diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py
index 021f65abd..022354a3b 100755
--- a/convert_hf_to_gguf_update.py
+++ b/convert_hf_to_gguf_update.py
@@ -81,6 +81,7 @@ models = [
{"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", },
{"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", },
{"name": "dbrx", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", },
+ {"name": "jina-v1-en", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-reranker-v1-tiny-en", },
{"name": "jina-v2-en", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-en", }, # WPM!
{"name": "jina-v2-es", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-es", },
{"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-de", },
@@ -99,6 +100,7 @@ models = [
{'name': "gpt3-finnish", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/TurkuNLP/gpt3-finnish-small", },
{"name": "exaone", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", },
{"name": "phi-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/microsoft/phi-2", },
+ {"name": "chameleon", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/facebook/chameleon-7b", },
]
diff --git a/docs/backend/SYCL.md b/docs/backend/SYCL.md
index e3b9572cc..bc266f7d8 100644
--- a/docs/backend/SYCL.md
+++ b/docs/backend/SYCL.md
@@ -636,6 +636,14 @@ use 1 SYCL GPUs: [0] with Max compute units:512
It's same for other projects including llama.cpp SYCL backend.
+- Meet the issue: `Native API failed. Native API returns: -6 (PI_ERROR_OUT_OF_HOST_MEMORY) -6 (PI_ERROR_OUT_OF_HOST_MEMORY) -999 (UNKNOWN PI error)` or `failed to allocate SYCL0 buffer`
+
+  There is not enough device memory.
+
+  |Reason|Solution|
+  |-|-|
+  |The default context is too big, which leads to higher memory usage.|Set `-c 8192` or a smaller value.|
+  |The model is too big and requires more memory than the device has.|Choose a smaller quantized model, like Q5 -> Q4;<br>use more than one device to load the model.|
### **GitHub contribution**:
Please add the **[SYCL]** prefix/tag in issues/PRs titles to help the SYCL-team check/address them without delay.
diff --git a/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp b/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp
index ecff95f9a..c140daed3 100644
--- a/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp
+++ b/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp
@@ -201,7 +201,7 @@ static void print_sample_weights(TransformerWeights *w){
//////////////////////////////////////// ggml structs and functions required to load models, configs and save the model.
-struct llama_vocab {
+struct my_llama_vocab {
using id = int32_t;
using token = std::string;
using ttype = llama_token_type;
@@ -525,7 +525,7 @@ static std::string llama_escape_whitespaces(const std::string & text) {
return out.str();
}
-static void load_vocab(const char * filename, const Config * config, struct llama_vocab * vocab) {
+static void load_vocab(const char * filename, const Config * config, struct my_llama_vocab * vocab) {
if (is_ggml_file(filename)) {
LOG_INF("%s: Loading vocabulary from gguf file %s\n", __func__, filename);
struct ggml_context * ctx_data = NULL;
@@ -583,13 +583,13 @@ static void load_vocab(const char * filename, const Config * config, struct llam
const int n_vocab = config->vocab_size;
/* uint32_t max_token_length = */ file.read_u32(); // unused
vocab->id_to_token.resize(n_vocab);
-    for (llama_vocab::id id=0; id<n_vocab; ++id) {
+    for (my_llama_vocab::id id=0; id<n_vocab; ++id) {
@@ -671,7 +671,7 @@ static void save_as_llama_model(
std::vector<const char*> tokens;
std::vector<float> scores;
std::vector<llama_token_type> token_types;
- for (const llama_vocab::token_data & token_data : vocab->id_to_token) {
+ for (const my_llama_vocab::token_data & token_data : vocab->id_to_token) {
tokens.push_back(token_data.text.c_str());
scores.push_back(token_data.score);
token_types.push_back(token_data.type);
@@ -905,7 +905,7 @@ int main(int argc, char ** argv) {
fclose(file);
}
- struct llama_vocab vocab;
+ struct my_llama_vocab vocab;
load_vocab(params.fn_vocab_model, &config, &vocab);
struct my_llama_model model;
diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index a438dcb5a..734926822 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -135,7 +135,7 @@ int main(int argc, char ** argv) {
// tokenize the prompts and trim
std::vector<std::vector<int32_t>> inputs;
for (const auto & prompt : prompts) {
- auto inp = ::llama_tokenize(ctx, prompt, true, false);
+ auto inp = ::llama_tokenize(ctx, prompt, true, true);
if (inp.size() > n_batch) {
LOG_ERR("%s: number of tokens in input line (%lld) exceeds batch size (%lld), increase batch size and re-run\n",
__func__, (long long int) inp.size(), (long long int) n_batch);
@@ -234,6 +234,11 @@ int main(int argc, char ** argv) {
}
LOG("\n");
}
+ } else if (pooling_type == LLAMA_POOLING_TYPE_RANK) {
+ for (int j = 0; j < n_embd_count; j++) {
+ // NOTE: if you change this log - update the tests in ci/run.sh
+ LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]);
+ }
} else {
// print the first part of the embeddings or for a single prompt, the full embedding
for (int j = 0; j < n_prompts; j++) {
diff --git a/examples/gen-docs/gen-docs.cpp b/examples/gen-docs/gen-docs.cpp
index b6d4725fd..4b19a9dc2 100644
--- a/examples/gen-docs/gen-docs.cpp
+++ b/examples/gen-docs/gen-docs.cpp
@@ -6,42 +6,73 @@
// Export usage message (-h) to markdown format
+static void write_table_header(std::ofstream & file) {
+ file << "| Argument | Explanation |\n";
+ file << "| -------- | ----------- |\n";
+}
+
+static void write_table_entry(std::ofstream & file, const llama_arg & opt) {
+ file << "| `";
+ // args
+ for (const auto & arg : opt.args) {
+ if (arg == opt.args.front()) {
+ file << arg;
+ if (opt.args.size() > 1) file << ", ";
+ } else {
+ file << arg << (arg != opt.args.back() ? ", " : "");
+ }
+ }
+ // value hint
+ if (opt.value_hint) {
+ std::string md_value_hint(opt.value_hint);
+ string_replace_all(md_value_hint, "|", "\\|");
+ file << " " << md_value_hint;
+ }
+ if (opt.value_hint_2) {
+ std::string md_value_hint_2(opt.value_hint_2);
+ string_replace_all(md_value_hint_2, "|", "\\|");
+ file << " " << md_value_hint_2;
+ }
+ // help text
+ std::string md_help(opt.help);
+    string_replace_all(md_help, "\n", "<br/> ");
+ string_replace_all(md_help, "|", "\\|");
+ file << "` | " << md_help << " |\n";
+}
+
+static void write_table(std::ofstream & file, std::vector<llama_arg *> & opts) {
+ write_table_header(file);
+ for (const auto & opt : opts) {
+ write_table_entry(file, *opt);
+ }
+}
+
static void export_md(std::string fname, llama_example ex) {
std::ofstream file(fname, std::ofstream::out | std::ofstream::trunc);
gpt_params params;
auto ctx_arg = gpt_params_parser_init(params, ex);
- file << "| Argument | Explanation |\n";
- file << "| -------- | ----------- |\n";
+    std::vector<llama_arg *> common_options;
+    std::vector<llama_arg *> sparam_options;
+    std::vector<llama_arg *> specific_options;
for (auto & opt : ctx_arg.options) {
- file << "| `";
- // args
- for (const auto & arg : opt.args) {
- if (arg == opt.args.front()) {
- file << arg;
- if (opt.args.size() > 1) file << ", ";
- } else {
- file << arg << (arg != opt.args.back() ? ", " : "");
- }
+        // in case multiple LLAMA_EXAMPLE_* are set, we prioritize the LLAMA_EXAMPLE_* matching the current example
+ if (opt.is_sparam) {
+ sparam_options.push_back(&opt);
+ } else if (opt.in_example(ctx_arg.ex)) {
+ specific_options.push_back(&opt);
+ } else {
+ common_options.push_back(&opt);
}
- // value hint
- if (opt.value_hint) {
- std::string md_value_hint(opt.value_hint);
- string_replace_all(md_value_hint, "|", "\\|");
- file << " " << md_value_hint;
- }
- if (opt.value_hint_2) {
- std::string md_value_hint_2(opt.value_hint_2);
- string_replace_all(md_value_hint_2, "|", "\\|");
- file << " " << md_value_hint_2;
- }
- // help text
- std::string md_help(opt.help);
-        string_replace_all(md_help, "\n", "<br/> ");
- string_replace_all(md_help, "|", "\\|");
- file << "` | " << md_help << " |\n";
}
+
+ file << "**Common params**\n\n";
+ write_table(file, common_options);
+ file << "\n\n**Sampling params**\n\n";
+ write_table(file, sparam_options);
+ file << "\n\n**Example-specific params**\n\n";
+ write_table(file, specific_options);
}
int main(int, char **) {
diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp
index 265281699..c8e273529 100644
--- a/examples/imatrix/imatrix.cpp
+++ b/examples/imatrix/imatrix.cpp
@@ -572,6 +572,7 @@ int main(int argc, char ** argv) {
params.n_ctx = 512;
params.logits_all = true;
+ params.escape = false;
if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_IMATRIX, print_usage)) {
return 1;
diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp
index b77b876cc..d52425ae6 100644
--- a/examples/infill/infill.cpp
+++ b/examples/infill/infill.cpp
@@ -97,6 +97,11 @@ static void sigint_handler(int signo) {
LOG("\n");
gpt_perf_print(*g_ctx, *g_smpl);
write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
+
+ // make sure all logs are flushed
+ LOG("Interrupted by user\n");
+ gpt_log_pause(gpt_log_main());
+
_exit(130);
}
}
@@ -258,9 +263,9 @@ int main(int argc, char ** argv) {
if (params.n_keep > 0) {
LOG_INF("%s: static prompt based on n_keep: '", __func__);
for (int i = 0; i < params.n_keep; i++) {
- LOG("%s", llama_token_to_piece(ctx, embd_inp[i]).c_str());
+ LOG_CNT("%s", llama_token_to_piece(ctx, embd_inp[i]).c_str());
}
- LOG("'\n");
+ LOG_CNT("'\n");
}
LOG_INF("\n");
}
@@ -301,8 +306,8 @@ int main(int argc, char ** argv) {
LOG_INF("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
- LOG("\n");
- LOG("\n##### Infill mode #####\n\n");
+ LOG_INF("\n");
+ LOG_INF("\n##### Infill mode #####\n\n");
if (params.interactive) {
const char *control_message;
if (params.multiline_input) {
@@ -313,11 +318,11 @@ int main(int argc, char ** argv) {
" - To return control without starting a new line, end your input with '/'.\n"
" - If you want to submit another line, end your input with '\\'.\n";
}
- LOG("== Running in interactive mode. ==\n");
+ LOG_INF("== Running in interactive mode. ==\n");
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
- LOG( " - Press Ctrl+C to interject at any time.\n");
+ LOG_INF( " - Press Ctrl+C to interject at any time.\n");
#endif
- LOG( "%s\n", control_message);
+ LOG_INF( "%s\n", control_message);
is_interacting = params.interactive_first;
}
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index 2d90f65a0..fb1d387b2 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -439,6 +439,9 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
}
types.push_back(gt);
}
+ if (invalid_param) {
+ break;
+ }
params.type_k.insert(params.type_k.end(), types.begin(), types.end());
} else if (arg == "-ctv" || arg == "--cache-type-v") {
if (++i >= argc) {
@@ -455,6 +458,9 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
}
types.push_back(gt);
}
+ if (invalid_param) {
+ break;
+ }
params.type_v.insert(params.type_v.end(), types.begin(), types.end());
} else if (arg == "-t" || arg == "--threads") {
if (++i >= argc) {
@@ -520,6 +526,9 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
}
modes.push_back(mode);
}
+ if (invalid_param) {
+ break;
+ }
params.split_mode.insert(params.split_mode.end(), modes.begin(), modes.end());
} else if (arg == "-mg" || arg == "--main-gpu") {
if (++i >= argc) {
diff --git a/examples/llava/convert_image_encoder_to_gguf.py b/examples/llava/convert_image_encoder_to_gguf.py
index 36f6b92fb..4fa1d6cea 100644
--- a/examples/llava/convert_image_encoder_to_gguf.py
+++ b/examples/llava/convert_image_encoder_to_gguf.py
@@ -274,7 +274,7 @@ fout.add_bool("clip.use_gelu", use_gelu)
if has_llava_projector:
- model.vision_model.encoder.layers.pop(-1) # pyright: ignore[reportAttributeAccessIssue]
+ model.vision_model.encoder.layers.pop(-1)
projector = torch.load(args.llava_projector)
for name, data in projector.items():
name = get_tensor_name(name)
@@ -288,7 +288,7 @@ if has_llava_projector:
print("Projector tensors added\n")
-state_dict = model.state_dict() # pyright: ignore[reportAttributeAccessIssue]
+state_dict = model.state_dict()
for name, data in state_dict.items():
if should_skip_tensor(name, has_text_encoder, has_vision_encoder, has_llava_projector):
# we don't need this
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 91fea9326..6bbb1e13e 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -116,6 +116,11 @@ static void sigint_handler(int signo) {
LOG("\n");
gpt_perf_print(*g_ctx, *g_smpl);
write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
+
+ // make sure all logs are flushed
+ LOG("Interrupted by user\n");
+ gpt_log_pause(gpt_log_main());
+
_exit(130);
}
}
@@ -380,9 +385,9 @@ int main(int argc, char ** argv) {
if (params.n_keep > add_bos) {
LOG_INF("%s: static prompt based on n_keep: '", __func__);
for (int i = 0; i < params.n_keep; i++) {
- LOG("%s", llama_token_to_piece(ctx, embd_inp[i]).c_str());
+ LOG_CNT("%s", llama_token_to_piece(ctx, embd_inp[i]).c_str());
}
- LOG("'\n");
+ LOG_CNT("'\n");
}
LOG_INF("\n");
}
@@ -404,40 +409,40 @@ int main(int argc, char ** argv) {
}
if (params.interactive) {
- LOG("%s: interactive mode on.\n", __func__);
+ LOG_INF("%s: interactive mode on.\n", __func__);
if (!params.antiprompt.empty()) {
for (const auto & antiprompt : params.antiprompt) {
- LOG("Reverse prompt: '%s'\n", antiprompt.c_str());
+ LOG_INF("Reverse prompt: '%s'\n", antiprompt.c_str());
if (params.verbose_prompt) {
auto tmp = ::llama_tokenize(ctx, antiprompt, false, true);
for (int i = 0; i < (int) tmp.size(); i++) {
- LOG("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str());
+ LOG_INF("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str());
}
}
}
}
if (params.input_prefix_bos) {
- LOG("Input prefix with BOS\n");
+ LOG_INF("Input prefix with BOS\n");
}
if (!params.input_prefix.empty()) {
- LOG("Input prefix: '%s'\n", params.input_prefix.c_str());
+ LOG_INF("Input prefix: '%s'\n", params.input_prefix.c_str());
if (params.verbose_prompt) {
auto tmp = ::llama_tokenize(ctx, params.input_prefix, true, true);
for (int i = 0; i < (int) tmp.size(); i++) {
- LOG("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str());
+ LOG_INF("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str());
}
}
}
if (!params.input_suffix.empty()) {
- LOG("Input suffix: '%s'\n", params.input_suffix.c_str());
+ LOG_INF("Input suffix: '%s'\n", params.input_suffix.c_str());
if (params.verbose_prompt) {
auto tmp = ::llama_tokenize(ctx, params.input_suffix, false, true);
for (int i = 0; i < (int) tmp.size(); i++) {
- LOG("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str());
+ LOG_INF("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str());
}
}
}
@@ -469,7 +474,7 @@ int main(int argc, char ** argv) {
//GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * grp_attn_n"); // NOLINT
LOG_INF("self-extend: n_ctx_train = %d, grp_attn_n = %d, grp_attn_w = %d\n", n_ctx_train, ga_n, ga_w);
}
- LOG("\n");
+ LOG_INF("\n");
if (params.interactive) {
const char * control_message;
@@ -481,11 +486,11 @@ int main(int argc, char ** argv) {
" - To return control without starting a new line, end your input with '/'.\n"
" - If you want to submit another line, end your input with '\\'.\n";
}
- LOG("== Running in interactive mode. ==\n");
+ LOG_INF("== Running in interactive mode. ==\n");
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
- LOG( " - Press Ctrl+C to interject at any time.\n");
+ LOG_INF( " - Press Ctrl+C to interject at any time.\n");
#endif
- LOG( "%s\n", control_message);
+ LOG_INF( "%s\n", control_message);
is_interacting = params.interactive_first;
}
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index 18e75a7a2..87347135e 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -444,7 +444,6 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
}
LOG("%.2f minutes\n", total_seconds / 60.0);
}
- LOG("\n");
//LOG_DBG("%s: using tokens %d...%d\n",__func__,params.n_ctx - params.ppl_stride + start, params.n_ctx + start);
for (int j = n_ctx - params.ppl_stride - 1; j < n_ctx - 1; ++j) {
@@ -638,7 +637,6 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
}
LOG("%.2f minutes\n", total_seconds / 60.0);
}
- LOG("\n");
for (int seq = 0; seq < n_seq_batch; seq++) {
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
@@ -1961,6 +1959,7 @@ int main(int argc, char ** argv) {
params.n_ctx = 512;
params.logits_all = true;
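+    // disable escape processing by default so the evaluation text is scored verbatim (-e re-enables it)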
+ params.escape = false;
if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
return 1;
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp
index a23bfb86b..b98993210 100644
--- a/examples/quantize/quantize.cpp
+++ b/examples/quantize/quantize.cpp
@@ -63,6 +63,16 @@ static const char * const LLM_KV_QUANTIZE_IMATRIX_DATASET = "quantize.imatrix
static const char * const LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES = "quantize.imatrix.entries_count";
static const char * const LLM_KV_QUANTIZE_IMATRIX_N_CHUNKS = "quantize.imatrix.chunks_count";
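+// case-insensitive comparison of two NUL-terminated strings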
+static bool striequals(const char * a, const char * b) {
+ while (*a && *b) {
+ if (std::tolower((unsigned char) *a) != std::tolower((unsigned char) *b)) {
+ return false;
+ }
+ a++; b++;
+ }
+ return *a == *b;
+}
+
static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftype, std::string & ftype_str_out) {
std::string ftype_str;
@@ -70,7 +80,7 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp
ftype_str.push_back(std::toupper(ch));
}
for (auto & it : QUANT_OPTIONS) {
- if (it.name == ftype_str) {
+ if (striequals(it.name.c_str(), ftype_str.c_str())) {
ftype = it.ftype;
ftype_str_out = it.name;
return true;
@@ -225,15 +235,15 @@ static int prepare_imatrix(const std::string & imatrix_file,
}
static ggml_type parse_ggml_type(const char * arg) {
- ggml_type result = GGML_TYPE_COUNT;
- for (int j = 0; j < GGML_TYPE_COUNT; ++j) {
- auto type = ggml_type(j);
+ for (int i = 0; i < GGML_TYPE_COUNT; ++i) {
+ auto type = (ggml_type)i;
const auto * name = ggml_type_name(type);
- if (name && strcmp(arg, name) == 0) {
- result = type; break;
+ if (name && striequals(name, arg)) {
+ return type;
}
}
- return result;
+ fprintf(stderr, "%s: invalid ggml_type '%s'\n", __func__, arg);
+ return GGML_TYPE_COUNT;
}
int main(int argc, char ** argv) {
@@ -254,12 +264,18 @@ int main(int argc, char ** argv) {
} else if (strcmp(argv[arg_idx], "--output-tensor-type") == 0) {
if (arg_idx < argc-1) {
params.output_tensor_type = parse_ggml_type(argv[++arg_idx]);
+ if (params.output_tensor_type == GGML_TYPE_COUNT) {
+ usage(argv[0]);
+ }
} else {
usage(argv[0]);
}
} else if (strcmp(argv[arg_idx], "--token-embedding-type") == 0) {
if (arg_idx < argc-1) {
params.token_embedding_type = parse_ggml_type(argv[++arg_idx]);
+ if (params.token_embedding_type == GGML_TYPE_COUNT) {
+ usage(argv[0]);
+ }
} else {
usage(argv[0]);
}
diff --git a/examples/server/README.md b/examples/server/README.md
index 168e14a99..951c4a44c 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -7,6 +7,7 @@ Set of LLM REST APIs and a simple web front end to interact with llama.cpp.
**Features:**
* LLM inference of F16 and quantized models on GPU and CPU
* [OpenAI API](https://github.com/openai/openai-openapi) compatible chat completions and embeddings routes
+ * Reranking endpoint (WIP: https://github.com/ggerganov/llama.cpp/pull/9510)
* Parallel decoding with multi-user support
* Continuous batching
* Multimodal (wip)
@@ -17,12 +18,13 @@ The project is under active development, and we are [looking for feedback and co
## Usage
+**Common params**
+
| Argument | Explanation |
| -------- | ----------- |
| `-h, --help, --usage` | print usage and exit |
| `--version` | show version and build info |
-| `-v, --verbose` | print verbose information |
-| `--verbosity N` | set specific verbosity level (default: 0) |
+| `--verbose-prompt` | print a verbose prompt before generation (default: false) |
| `-t, --threads N` | number of threads to use during generation (default: -1)<br/>(env: LLAMA_ARG_THREADS) |
| `-tb, --threads-batch N` | number of threads to use during batch and prompt processing (default: same as --threads) |
| `-C, --cpu-mask M` | CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: "") |
@@ -42,13 +44,63 @@ The project is under active development, and we are [looking for feedback and co
| `--keep N` | number of tokens to keep from the initial prompt (default: 0, -1 = all) |
| `-fa, --flash-attn` | enable Flash Attention (default: disabled)<br/>(env: LLAMA_ARG_FLASH_ATTN) |
| `-p, --prompt PROMPT` | prompt to start generation with |
+| `--no-perf` | disable internal libllama performance timings (default: false)<br/>(env: LLAMA_ARG_NO_PERF) |
| `-f, --file FNAME` | a file containing the prompt (default: none) |
| `-bf, --binary-file FNAME` | binary file containing the prompt (default: none) |
| `-e, --escape` | process escape sequences (\n, \r, \t, \', \", \\) (default: true) |
| `--no-escape` | do not process escape sequences |
-| `--spm-infill` | use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: disabled) |
+| `--rope-scaling {none,linear,yarn}` | RoPE frequency scaling method, defaults to linear unless specified by the model<br/>(env: LLAMA_ARG_ROPE_SCALING_TYPE) |
+| `--rope-scale N` | RoPE context scaling factor, expands context by a factor of N<br/>(env: LLAMA_ARG_ROPE_SCALE) |
+| `--rope-freq-base N` | RoPE base frequency, used by NTK-aware scaling (default: loaded from model)<br/>(env: LLAMA_ARG_ROPE_FREQ_BASE) |
+| `--rope-freq-scale N` | RoPE frequency scaling factor, expands context by a factor of 1/N<br/>(env: LLAMA_ARG_ROPE_FREQ_SCALE) |
+| `--yarn-orig-ctx N` | YaRN: original context size of model (default: 0 = model training context size)<br/>(env: LLAMA_ARG_YARN_ORIG_CTX) |
+| `--yarn-ext-factor N` | YaRN: extrapolation mix factor (default: -1.0, 0.0 = full interpolation)<br/>(env: LLAMA_ARG_YARN_EXT_FACTOR) |
+| `--yarn-attn-factor N` | YaRN: scale sqrt(t) or attention magnitude (default: 1.0)<br/>(env: LLAMA_ARG_YARN_ATTN_FACTOR) |
+| `--yarn-beta-slow N` | YaRN: high correction dim or alpha (default: 1.0)<br/>(env: LLAMA_ARG_YARN_BETA_SLOW) |
+| `--yarn-beta-fast N` | YaRN: low correction dim or beta (default: 32.0)<br/>(env: LLAMA_ARG_YARN_BETA_FAST) |
+| `-gan, --grp-attn-n N` | group-attention factor (default: 1)<br/>(env: LLAMA_ARG_GRP_ATTN_N) |
+| `-gaw, --grp-attn-w N` | group-attention width (default: 512.0)<br/>(env: LLAMA_ARG_GRP_ATTN_W) |
+| `-dkvc, --dump-kv-cache` | verbose print of the KV cache |
+| `-nkvo, --no-kv-offload` | disable KV offload<br/>(env: LLAMA_ARG_NO_KV_OFFLOAD) |
+| `-ctk, --cache-type-k TYPE` | KV cache data type for K (default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_K) |
+| `-ctv, --cache-type-v TYPE` | KV cache data type for V (default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_V) |
+| `-dt, --defrag-thold N` | KV cache defragmentation threshold (default: -1.0, < 0 - disabled)<br/>(env: LLAMA_ARG_DEFRAG_THOLD) |
+| `-np, --parallel N` | number of parallel sequences to decode (default: 1)<br/>(env: LLAMA_ARG_N_PARALLEL) |
+| `--mlock` | force system to keep model in RAM rather than swapping or compressing<br/>(env: LLAMA_ARG_MLOCK) |
+| `--no-mmap` | do not memory-map model (slower load but may reduce pageouts if not using mlock)<br/>(env: LLAMA_ARG_NO_MMAP) |
+| `--numa TYPE` | attempt optimizations that help on some NUMA systems<br/>- distribute: spread execution evenly over all nodes<br/>- isolate: only spawn threads on CPUs on the node that execution started on<br/>- numactl: use the CPU map provided by numactl<br/>if run without this previously, it is recommended to drop the system page cache before using this<br/>see https://github.com/ggerganov/llama.cpp/issues/1437<br/>(env: LLAMA_ARG_NUMA) |
+| `-ngl, --gpu-layers, --n-gpu-layers N` | number of layers to store in VRAM<br/>(env: LLAMA_ARG_N_GPU_LAYERS) |
+| `-sm, --split-mode {none,layer,row}` | how to split the model across multiple GPUs, one of:<br/>- none: use one GPU only<br/>- layer (default): split layers and KV across GPUs<br/>- row: split rows across GPUs<br/>(env: LLAMA_ARG_SPLIT_MODE) |
+| `-ts, --tensor-split N0,N1,N2,...` | fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1<br/>(env: LLAMA_ARG_TENSOR_SPLIT) |
+| `-mg, --main-gpu INDEX` | the GPU to use for the model (with split-mode = none), or for intermediate results and KV (with split-mode = row) (default: 0)<br/>(env: LLAMA_ARG_MAIN_GPU) |
+| `--check-tensors` | check model tensor data for invalid values (default: false) |
+| `--override-kv KEY=TYPE:VALUE` | advanced option to override model metadata by key. may be specified multiple times.<br/>types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false |
+| `--lora FNAME` | path to LoRA adapter (can be repeated to use multiple adapters) |
+| `--lora-scaled FNAME SCALE` | path to LoRA adapter with user defined scaling (can be repeated to use multiple adapters) |
+| `--control-vector FNAME` | add a control vector<br/>note: this argument can be repeated to add multiple control vectors |
+| `--control-vector-scaled FNAME SCALE` | add a control vector with user defined scaling SCALE<br/>note: this argument can be repeated to add multiple scaled control vectors |
+| `--control-vector-layer-range START END` | layer range to apply the control vector(s) to, start and end inclusive |
+| `-m, --model FNAME` | model path (default: `models/$filename` with filename from `--hf-file` or `--model-url` if set, otherwise models/7B/ggml-model-f16.gguf)<br/>(env: LLAMA_ARG_MODEL) |
+| `-mu, --model-url MODEL_URL` | model download url (default: unused)<br/>(env: LLAMA_ARG_MODEL_URL) |
+| `-hfr, --hf-repo REPO` | Hugging Face model repository (default: unused)<br/>(env: LLAMA_ARG_HF_REPO) |
+| `-hff, --hf-file FILE` | Hugging Face model file (default: unused)<br/>(env: LLAMA_ARG_HF_FILE) |
+| `-hft, --hf-token TOKEN` | Hugging Face access token (default: value from HF_TOKEN environment variable)<br/>(env: HF_TOKEN) |
+| `-ld, --logdir LOGDIR` | path under which to save YAML logs (no logging if unset) |
+| `--log-disable` | Log disable |
+| `--log-file FNAME` | Log to file |
+| `--log-colors` | Enable colored logging<br/>(env: LLAMA_LOG_COLORS) |
+| `-v, --verbose, --log-verbose` | Set verbosity level to infinity (i.e. log all messages, useful for debugging) |
+| `-lv, --verbosity, --log-verbosity N` | Set the verbosity threshold. Messages with a higher verbosity will be ignored.<br/>(env: LLAMA_LOG_VERBOSITY) |
+| `--log-prefix` | Enable prefix in log messages<br/>(env: LLAMA_LOG_PREFIX) |
+| `--log-timestamps` | Enable timestamps in log messages<br/>(env: LLAMA_LOG_TIMESTAMPS) |
+
+
+**Sampling params**
+
+| Argument | Explanation |
+| -------- | ----------- |
| `--samplers SAMPLERS` | samplers that will be used for generation in the order, separated by ';'<br/>(default: top_k;tfs_z;typ_p;top_p;min_p;temperature) |
-| `-s, --seed SEED` | RNG seed (default: -1, use random seed for < 0) |
+| `-s, --seed SEED` | RNG seed (default: 4294967295, use random seed for 4294967295) |
| `--sampling-seq SEQUENCE` | simplified sequence for samplers that will be used (default: kfypmt) |
| `--ignore-eos` | ignore end of stream token and continue generating (implies --logit-bias EOS-inf) |
| `--penalize-nl` | penalize newline tokens (default: false) |
@@ -71,54 +123,29 @@ The project is under active development, and we are [looking for feedback and co
| `--grammar GRAMMAR` | BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '') |
| `--grammar-file FNAME` | file to read grammar from |
| `-j, --json-schema SCHEMA` | JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object<br/>For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead |
-| `--rope-scaling {none,linear,yarn}` | RoPE frequency scaling method, defaults to linear unless specified by the model |
-| `--rope-scale N` | RoPE context scaling factor, expands context by a factor of N |
-| `--rope-freq-base N` | RoPE base frequency, used by NTK-aware scaling (default: loaded from model) |
-| `--rope-freq-scale N` | RoPE frequency scaling factor, expands context by a factor of 1/N |
-| `--yarn-orig-ctx N` | YaRN: original context size of model (default: 0 = model training context size) |
-| `--yarn-ext-factor N` | YaRN: extrapolation mix factor (default: -1.0, 0.0 = full interpolation) |
-| `--yarn-attn-factor N` | YaRN: scale sqrt(t) or attention magnitude (default: 1.0) |
-| `--yarn-beta-slow N` | YaRN: high correction dim or alpha (default: 1.0) |
-| `--yarn-beta-fast N` | YaRN: low correction dim or beta (default: 32.0) |
-| `-gan, --grp-attn-n N` | group-attention factor (default: 1) |
-| `-gaw, --grp-attn-w N` | group-attention width (default: 512.0) |
-| `-dkvc, --dump-kv-cache` | verbose print of the KV cache |
-| `-nkvo, --no-kv-offload` | disable KV offload |
-| `-ctk, --cache-type-k TYPE` | KV cache data type for K (default: f16) |
-| `-ctv, --cache-type-v TYPE` | KV cache data type for V (default: f16) |
-| `-dt, --defrag-thold N` | KV cache defragmentation threshold (default: -1.0, < 0 - disabled)<br/>(env: LLAMA_ARG_DEFRAG_THOLD) |
-| `-np, --parallel N` | number of parallel sequences to decode (default: 1) |
+
+
+**Example-specific params**
+
+| Argument | Explanation |
+| -------- | ----------- |
+| `--no-context-shift` | disables context shift on infinite text generation (default: disabled)<br/>(env: LLAMA_ARG_NO_CONTEXT_SHIFT) |
+| `-sp, --special` | special tokens output enabled (default: false) |
+| `--spm-infill` | use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: disabled) |
+| `--pooling {none,mean,cls,last,rank}` | pooling type for embeddings, use model default if unspecified<br/>(env: LLAMA_ARG_POOLING) |
| `-cb, --cont-batching` | enable continuous batching (a.k.a dynamic batching) (default: enabled)<br/>(env: LLAMA_ARG_CONT_BATCHING) |
| `-nocb, --no-cont-batching` | disable continuous batching<br/>(env: LLAMA_ARG_NO_CONT_BATCHING) |
-| `--mlock` | force system to keep model in RAM rather than swapping or compressing |
-| `--no-mmap` | do not memory-map model (slower load but may reduce pageouts if not using mlock) |
-| `--numa TYPE` | attempt optimizations that help on some NUMA systems<br/>- distribute: spread execution evenly over all nodes<br/>- isolate: only spawn threads on CPUs on the node that execution started on<br/>- numactl: use the CPU map provided by numactl<br/>if run without this previously, it is recommended to drop the system page cache before using this<br/>see https://github.com/ggerganov/llama.cpp/issues/1437 |
-| `-ngl, --gpu-layers, --n-gpu-layers N` | number of layers to store in VRAM<br/>(env: LLAMA_ARG_N_GPU_LAYERS) |
-| `-sm, --split-mode {none,layer,row}` | how to split the model across multiple GPUs, one of:<br/>- none: use one GPU only<br/>- layer (default): split layers and KV across GPUs<br/>- row: split rows across GPUs |
-| `-ts, --tensor-split N0,N1,N2,...` | fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1 |
-| `-mg, --main-gpu INDEX` | the GPU to use for the model (with split-mode = none), or for intermediate results and KV (with split-mode = row) (default: 0) |
-| `--check-tensors` | check model tensor data for invalid values (default: false) |
-| `--override-kv KEY=TYPE:VALUE` | advanced option to override model metadata by key. may be specified multiple times.<br/>types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false |
-| `--lora FNAME` | path to LoRA adapter (can be repeated to use multiple adapters) |
-| `--lora-scaled FNAME SCALE` | path to LoRA adapter with user defined scaling (can be repeated to use multiple adapters) |
-| `--control-vector FNAME` | add a control vector<br/>note: this argument can be repeated to add multiple control vectors |
-| `--control-vector-scaled FNAME SCALE` | add a control vector with user defined scaling SCALE<br/>note: this argument can be repeated to add multiple scaled control vectors |
-| `--control-vector-layer-range START END` | layer range to apply the control vector(s) to, start and end inclusive |
-| `-a, --alias STRING` | set alias for model name (to be used by REST API) |
-| `-m, --model FNAME` | model path (default: `models/$filename` with filename from `--hf-file` or `--model-url` if set, otherwise models/7B/ggml-model-f16.gguf)<br/>(env: LLAMA_ARG_MODEL) |
-| `-mu, --model-url MODEL_URL` | model download url (default: unused)<br/>(env: LLAMA_ARG_MODEL_URL) |
-| `-hfr, --hf-repo REPO` | Hugging Face model repository (default: unused)<br/>(env: LLAMA_ARG_HF_REPO) |
-| `-hff, --hf-file FILE` | Hugging Face model file (default: unused)<br/>(env: LLAMA_ARG_HF_FILE) |
-| `-hft, --hf-token TOKEN` | Hugging Face access token (default: value from HF_TOKEN environment variable)<br/>(env: HF_TOKEN) |
+| `-a, --alias STRING` | set alias for model name (to be used by REST API)<br/>(env: LLAMA_ARG_ALIAS) |
| `--host HOST` | ip address to listen (default: 127.0.0.1)<br/>(env: LLAMA_ARG_HOST) |
| `--port PORT` | port to listen (default: 8080)<br/>(env: LLAMA_ARG_PORT) |
-| `--path PATH` | path to serve static files from (default: ) |
+| `--path PATH` | path to serve static files from (default: )<br/>(env: LLAMA_ARG_STATIC_PATH) |
| `--embedding, --embeddings` | restrict to only support embedding use case; use only with dedicated embedding models (default: disabled)<br/>(env: LLAMA_ARG_EMBEDDINGS) |
+| `--reranking, --rerank` | enable reranking endpoint on server (default: disabled)<br/>(env: LLAMA_ARG_RERANKING) |
| `--api-key KEY` | API key to use for authentication (default: none)<br/>(env: LLAMA_API_KEY) |
| `--api-key-file FNAME` | path to file containing API keys (default: none) |
-| `--ssl-key-file FNAME` | path to file a PEM-encoded SSL private key |
-| `--ssl-cert-file FNAME` | path to file a PEM-encoded SSL certificate |
-| `-to, --timeout N` | server read/write timeout in seconds (default: 600) |
+| `--ssl-key-file FNAME` | path to a PEM-encoded SSL private key file<br/>(env: LLAMA_ARG_SSL_KEY_FILE) |
+| `--ssl-cert-file FNAME` | path to a PEM-encoded SSL certificate file<br/>(env: LLAMA_ARG_SSL_CERT_FILE) |
+| `-to, --timeout N` | server read/write timeout in seconds (default: 600)<br/>(env: LLAMA_ARG_TIMEOUT) |
| `--threads-http N` | number of threads used to process HTTP requests (default: -1)<br/>(env: LLAMA_ARG_THREADS_HTTP) |
| `-spf, --system-prompt-file FNAME` | set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications |
| `--metrics` | enable prometheus compatible metrics endpoint (default: disabled)<br/>(env: LLAMA_ARG_ENDPOINT_METRICS) |
@@ -127,13 +154,7 @@ The project is under active development, and we are [looking for feedback and co
| `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)<br/>if suffix/prefix are specified, template will be disabled<br/>only commonly used templates are accepted:<br/>https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template<br/>(env: LLAMA_ARG_CHAT_TEMPLATE) |
| `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.50, 0.0 = disabled)<br/>|
| `--lora-init-without-apply` | load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled) |
-| `-ld, --logdir LOGDIR` | path under which to save YAML logs (no logging if unset) |
-| `--log-test` | Log test |
-| `--log-disable` | Log disable |
-| `--log-enable` | Log enable |
-| `--log-new` | Log new |
-| `--log-append` | Log append |
-| `--log-file FNAME` | Log file |
+
Note: If a command line argument and an environment variable are both set for the same param, the argument will take precedence over the env var.
@@ -461,6 +482,39 @@ The same as [the embedding example](../embedding) does.
`image_data`: An array of objects to hold base64-encoded image `data` and its `id`s to be referenced in `content`. You can determine the place of the image in the content as in the following: `Image: [img-21].\nCaption: This is a picture of a house`. In this case, `[img-21]` will be replaced by the embeddings of the image with id `21` in the following `image_data` array: `{..., "image_data": [{"data": "", "id": 21}]}`. Use `image_data` only with multimodal models, e.g., LLaVA.
+### POST `/reranking`: Rerank documents according to a given query
+
+Similar to https://jina.ai/reranker/ but might change in the future.
+Requires a reranker model (such as [bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3)) and the `--reranking` option.
+
+ *Options:*
+
+ `query`: The query against which the documents will be ranked.
+
+ `documents`: An array of strings representing the documents to be ranked.
+
+ *Aliases:*
+ - `/rerank`
+ - `/v1/rerank`
+ - `/v1/reranking`
+
+ *Examples:*
+
+ ```shell
+ curl http://127.0.0.1:8012/v1/rerank \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "some-model",
+ "query": "What is panda?",
+ "top_n": 3,
+ "documents": [
+ "hi",
+ "it is a bear",
+ "The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China."
+ ]
+ }' | jq
+ ```
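+
+ A response might look like the following sketch (the shape follows `format_response_rerank` added in this PR; the scores are made-up values):
+
+ ```json
+ {
+   "model": "some-model",
+   "object": "list",
+   "usage": { "prompt_tokens": 0, "total_tokens": 0 },
+   "results": [
+     { "index": 0, "relevance_score": -5.1 },
+     { "index": 1, "relevance_score": 0.7 },
+     { "index": 2, "relevance_score": 9.8 }
+   ]
+ }
+ ```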
+
### POST `/infill`: For code infilling.
Takes a prefix and a suffix and returns the predicted completion as stream.
@@ -501,7 +555,7 @@ Given a ChatML-formatted json description in `messages`, it returns the predicte
See [OpenAI Chat Completions API documentation](https://platform.openai.com/docs/api-reference/chat). While some OpenAI-specific features such as function calling aren't supported, llama.cpp `/completion`-specific features such as `mirostat` are supported.
- The `response_format` parameter supports both plain JSON output (e.g. `{"type": "json_object"}`) and schema-constrained JSON (e.g. `{"type": "json_object", "schema": {"type": "string", "minLength": 10, "maxLength": 100}}`), similar to other OpenAI-inspired API providers.
+ The `response_format` parameter supports both plain JSON output (e.g. `{"type": "json_object"}`) and schema-constrained JSON (e.g. `{"type": "json_object", "schema": {"type": "string", "minLength": 10, "maxLength": 100}}` or `{"type": "json_schema", "json_schema": {"schema": {"properties": { "name": { "title": "Name", "type": "string" }, "date": { "title": "Date", "type": "string" }, "participants": { "items": {"type": "string"}, "title": "Participants", "type": "array" } } } }}`), similar to other OpenAI-inspired API providers.
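+
+ For instance, a minimal schema-constrained request could be sent as follows (an illustrative sketch; the host, port, and schema are assumptions, not part of this PR):
+
+ ```shell
+ curl http://localhost:8080/v1/chat/completions \
+     -H "Content-Type: application/json" \
+     -d '{
+       "messages": [{"role": "user", "content": "Name one planet."}],
+       "response_format": {
+         "type": "json_schema",
+         "json_schema": {
+           "schema": {"type": "object", "properties": {"name": {"type": "string"}}}
+         }
+       }
+     }'
+ ```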
*Examples:*
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index b5f264ff1..f343cc252 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -92,6 +92,7 @@ enum server_task_type {
enum server_task_cmpl_type {
SERVER_TASK_CMPL_TYPE_NORMAL,
SERVER_TASK_CMPL_TYPE_EMBEDDING,
+ SERVER_TASK_CMPL_TYPE_RERANK,
SERVER_TASK_CMPL_TYPE_INFILL,
};
@@ -172,6 +173,7 @@ struct server_slot {
std::vector<completion_token_output> generated_token_probs;
server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
+
bool has_next_token = true;
bool truncated = false;
bool stopped_eos = false;
@@ -531,26 +533,38 @@ struct server_response {
// add the id_task to the list of tasks waiting for response
void add_waiting_task_id(int id_task) {
- SRV_DBG("waiting for task id = %d\n", id_task);
+ SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size());
std::unique_lock<std::mutex> lock(mutex_results);
waiting_task_ids.insert(id_task);
}
void add_waiting_tasks(const std::vector<server_task> & tasks) {
- for (const auto & t : tasks) {
- add_waiting_task_id(t.id);
+ std::unique_lock<std::mutex> lock(mutex_results);
+
+ for (const auto & task : tasks) {
+ SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, (int) waiting_task_ids.size());
+ waiting_task_ids.insert(task.id);
}
}
// when the request is finished, we can remove task associated with it
void remove_waiting_task_id(int id_task) {
- SRV_DBG("task id = %d is done\n", id_task);
+ SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
std::unique_lock<std::mutex> lock(mutex_results);
waiting_task_ids.erase(id_task);
}
+ void remove_waiting_task_ids(const std::unordered_set<int> & id_tasks) {
+ std::unique_lock<std::mutex> lock(mutex_results);
+
+ for (const auto & id_task : id_tasks) {
+ SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
+ waiting_task_ids.erase(id_task);
+ }
+ }
+
// This function blocks the thread until there is a response for one of the id_tasks
server_task_result recv(const std::unordered_set<int> & id_tasks) {
while (true) {
@@ -942,8 +956,17 @@ struct server_context {
slot.prompt = *prompt;
} else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
slot.prompt = prompt->at(0);
+ } else if (prompt->is_array() && prompt->size() > 1) {
+ // array of strings
+ for (const auto & el : *prompt) {
+ if (!el.is_string()) {
+ send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
+ return false;
+ }
+ }
+ slot.prompt = *prompt;
} else {
- send_error(task, "\"prompt\" must be a string or an array of integers", ERROR_TYPE_INVALID_REQUEST);
+ send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
return false;
}
}
@@ -1168,6 +1191,15 @@ struct server_context {
SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
}
+ // if context shift is disabled, we stop when it reaches the context limit
+ if (slot.n_decoded >= slot.n_ctx) {
+ slot.truncated = true;
+ slot.stopped_limit = true;
+ slot.has_next_token = false;
+
+ SLT_DBG(slot, "stopped due to running out of context capacity, n_decoded = %d, n_ctx = %d\n", slot.n_decoded, slot.n_ctx);
+ }
+
if (llama_token_is_eog(model, result.tok)) {
slot.stopped_eos = true;
slot.has_next_token = false;
@@ -1368,6 +1400,7 @@ struct server_context {
res.data = json {
{"embedding", std::vector(n_embd, 0.0f)},
+ {"index", slot.index},
};
continue;
@@ -1386,6 +1419,44 @@ struct server_context {
queue_results.send(res);
}
+ void send_rerank(const server_slot & slot, const llama_batch & batch) {
+ server_task_result res;
+ res.id = slot.id_task;
+ res.error = false;
+ res.stop = true;
+
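+ // with pooling type 'rank', the relevance score is the first component of the sequence embedding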
+ for (int i = 0; i < batch.n_tokens; ++i) {
+ if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
+ continue;
+ }
+
+ const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+ if (embd == NULL) {
+ embd = llama_get_embeddings_ith(ctx, i);
+ }
+
+ if (embd == NULL) {
+ SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
+
+ res.data = json {
+ {"index", slot.index},
+ {"score", -1e6},
+ };
+
+ continue;
+ }
+
+ res.data = json {
+ {"index", slot.index},
+ {"score", embd[0]},
+ };
+ }
+
+ SLT_DBG(slot, "sending rerank result, res = '%s'\n", res.data.dump().c_str());
+
+ queue_results.send(res);
+ }
+
//
// Functions to create new task(s) and receive result(s)
//
@@ -1421,13 +1492,27 @@ struct server_context {
// otherwise, it's a multiple-prompt task, we break it into smaller tasks
else if (prompt.is_array()) {
std::vector<json> prompts = prompt;
- for (size_t i = 0; i < prompts.size(); i++) {
- const auto & e = prompts[i];
- if (e.is_string() || json_is_array_of_numbers(e)) {
- data["index"] = i;
- create_task(data, true, e);
- } else {
- throw std::runtime_error(error_msg);
+ if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+ // prompts[0] is the question
+ // the rest are the answers/documents
+ SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) prompts.size() - 1);
+ for (size_t i = 1; i < prompts.size(); i++) {
+ json qd;
+ qd.push_back(prompts[0]);
+ qd.push_back(prompts[i]);
+ data["index"] = i - 1;
+ create_task(data, true, qd);
+ }
+ } else {
+ SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) prompts.size());
+ for (size_t i = 0; i < prompts.size(); i++) {
+ const auto & e = prompts[i];
+ if (e.is_string() || json_is_array_of_numbers(e)) {
+ data["index"] = i;
+ create_task(data, true, e);
+ } else {
+ throw std::runtime_error(error_msg);
+ }
}
}
}
@@ -1468,10 +1553,12 @@ struct server_context {
if (result.error) {
error_handler(result.data);
cancel_tasks(id_tasks);
- break;
+ return;
}
- size_t idx = result.data["index"];
+ const size_t idx = result.data["index"];
+ GGML_ASSERT(idx < results.size() && "index out of range");
+
results[idx] = result;
}
result_handler(results);
@@ -1815,6 +1902,14 @@ struct server_context {
for (server_slot & slot : slots) {
if (slot.ga_n == 1) {
if (slot.is_processing() && (int) system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
+ if (!params.ctx_shift) {
+ // this check is redundant (kept as a safeguard)
+ // we should never get here, because generation should have already stopped in process_token()
+ slot.release();
+ send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
+ continue;
+ }
+
// Shift context
const int n_keep = slot.params.n_keep + add_bos_token;
const int n_left = (int) system_tokens.size() + slot.n_past - n_keep;
@@ -1874,6 +1969,7 @@ struct server_context {
// track if this is an embedding or non-embedding batch
// if we've added sampled tokens above, we are in non-embedding mode
// -1: none, 0: non-embedding, 1: embedding
+ // TODO: make enum
int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
// next, batch any pending prompts without exceeding n_batch
@@ -1922,6 +2018,29 @@ struct server_context {
}
prompt_tokens = embd_inp;
+ } else if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+ // require slot.prompt to be array of 2 strings
+ if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
+ SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
+ slot.release();
+ send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
+ continue;
+ }
+
+ // prompt: [BOS]query[EOS][BOS]doc[EOS]
+ prompt_tokens.clear();
+ prompt_tokens.push_back(llama_token_bos(model));
+ {
+ const auto part = tokenize(slot.prompt[0], false);
+ prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
+ }
+ prompt_tokens.push_back(llama_token_eos(model));
+ prompt_tokens.push_back(llama_token_bos(model));
+ {
+ const auto part = tokenize(slot.prompt[1], false);
+ prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
+ }
+ prompt_tokens.push_back(llama_token_eos(model));
} else {
prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
}
@@ -1941,7 +2060,7 @@ struct server_context {
continue;
}
- if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
+ if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
// this prompt is too large to process - discard it
if (slot.n_prompt_tokens > n_ubatch) {
slot.release();
@@ -1949,6 +2068,14 @@ struct server_context {
continue;
}
} else {
+ if (!params.ctx_shift) {
+ // if context shift is disabled, we make sure prompt size is smaller than KV size
+ if ((int) system_tokens.size() + slot.n_prompt_tokens >= slot.n_ctx) {
+ slot.release();
+ send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
+ continue;
+ }
+ }
if (slot.params.n_keep < 0) {
slot.params.n_keep = slot.n_prompt_tokens;
}
@@ -2011,7 +2138,8 @@ struct server_context {
slot.n_prompt_tokens_processed = 0;
}
- if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
+ // non-causal tasks require the entire prompt to fit in the physical batch
+ if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
// cannot fit the prompt in the current batch - will try next iter
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
continue;
@@ -2019,7 +2147,10 @@ struct server_context {
}
// check that we are in the right batch_type, if not defer the slot
- bool slot_type = slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ? 1 : 0;
+ const bool slot_type =
+ slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ||
+ slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0;
+
if (batch_type == -1) {
batch_type = slot_type;
} else if (batch_type != slot_type) {
@@ -2192,6 +2323,13 @@ struct server_context {
continue; // continue loop of slots
}
+ if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+ send_rerank(slot, batch_view);
+ slot.release();
+ slot.i_batch = -1;
+ continue; // continue loop of slots
+ }
+
// prompt evaluated for next-token prediction
slot.state = SLOT_STATE_GENERATING;
} else if (slot.state != SLOT_STATE_GENERATING) {
@@ -2254,14 +2392,6 @@ static void log_server_request(const httplib::Request & req, const httplib::Resp
return;
}
- //LOG_INFO("request", {
- // {"remote_addr", req.remote_addr},
- // {"remote_port", req.remote_port},
- // {"status", res.status},
- // {"method", req.method},
- // {"path", req.path},
- // {"params", req.params},
- //});
LOG_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status);
LOG_DBG("request: %s\n", req.body.c_str());
@@ -2318,15 +2448,19 @@ int main(int argc, char ** argv) {
std::unique_ptr<httplib::Server> svr;
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
- LOG_INFO("Running with SSL", {{"key", params.ssl_file_key}, {"cert", params.ssl_file_cert}});
+ LOG_INF("Running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str());
svr.reset(
new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str())
);
} else {
- LOG_INFO("Running without SSL", {});
+ LOG_INF("Running without SSL\n");
svr.reset(new httplib::Server());
}
#else
+ if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
+ LOG_ERR("Server is built without SSL support\n");
+ return 1;
+ }
svr.reset(new httplib::Server());
#endif
@@ -2754,8 +2888,8 @@ int main(int argc, char ** argv) {
};
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) {
- if (ctx_server.params.embedding) {
- res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+ if (ctx_server.params.embedding || ctx_server.params.reranking) {
+ res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
@@ -2782,6 +2916,8 @@ int main(int argc, char ** argv) {
}, [&](const json & error_data) {
res_error(res, error_data);
});
+
+ ctx_server.queue_results.remove_waiting_task_ids(task_ids);
} else {
const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) {
ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
@@ -2792,7 +2928,12 @@ int main(int argc, char ** argv) {
sink.done();
return false;
};
- res.set_chunked_content_provider("text/event-stream", chunked_content_provider);
+
+ auto on_complete = [task_ids, &ctx_server] (bool) {
+ ctx_server.queue_results.remove_waiting_task_ids(task_ids);
+ };
+
+ res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
}
};
@@ -2808,8 +2949,8 @@ int main(int argc, char ** argv) {
// TODO: maybe merge this function with "handle_completions_generic"
const auto handle_chat_completions = [&ctx_server, &params, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) {
- if (ctx_server.params.embedding) {
- res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+ if (ctx_server.params.embedding || ctx_server.params.reranking) {
+ res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
@@ -2831,6 +2972,8 @@ int main(int argc, char ** argv) {
}, [&](const json & error_data) {
res_error(res, error_data);
});
+
+ ctx_server.queue_results.remove_waiting_task_ids(task_ids);
} else {
const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
@@ -2852,7 +2995,12 @@ int main(int argc, char ** argv) {
sink.done();
return true;
};
- res.set_chunked_content_provider("text/event-stream", chunked_content_provider);
+
+ auto on_complete = [task_ids, &ctx_server] (bool) {
+ ctx_server.queue_results.remove_waiting_task_ids(task_ids);
+ };
+
+ res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
}
};
@@ -2926,6 +3074,11 @@ int main(int argc, char ** argv) {
};
const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+ // TODO: somehow clean up these checks in the future
+ if (!ctx_server.params.embedding || ctx_server.params.reranking) {
+ res_error(res, format_error_response("This server does not support embeddings. Start it with `--embeddings` and without `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
+ return;
+ }
const json body = json::parse(req.body);
bool is_openai = false;
@@ -2961,6 +3114,8 @@ int main(int argc, char ** argv) {
res_error(res, error_data);
error = true;
});
+
+ ctx_server.queue_results.remove_waiting_task_ids(task_ids);
}
if (error) {
@@ -2974,6 +3129,79 @@ int main(int argc, char ** argv) {
res_ok(res, root);
};
+ const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+ if (!ctx_server.params.reranking) {
+ res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
+ return;
+ }
+ const json body = json::parse(req.body);
+
+ // TODO: implement
+ //int top_n = 1;
+ //if (body.count("top_n") != 1) {
+ // top_n = body.at("top_n");
+ //} else {
+ // res_error(res, format_error_response("\"top_n\" must be provided", ERROR_TYPE_INVALID_REQUEST));
+ // return;
+ //}
+
+ json query;
+ if (body.count("query") == 1) {
+ query = body.at("query");
+ if (!query.is_string()) {
+ res_error(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
+ return;
+ }
+ } else {
+ res_error(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
+ return;
+ }
+
+ std::vector<std::string> documents = json_value(body, "documents", std::vector<std::string>());
+ if (documents.empty()) {
+ res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
+ return;
+ }
+
+ // construct prompt object: array of ["query", "doc0", "doc1", ...]
+ json prompt;
+ prompt.push_back(query);
+ for (const auto & doc : documents) {
+ prompt.push_back(doc);
+ }
+
+ LOG_DBG("rerank prompt: %s\n", prompt.dump().c_str());
+
+ // create and queue the task
+ json responses = json::array();
+ bool error = false;
+ {
+ std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK);
+ ctx_server.queue_results.add_waiting_tasks(tasks);
+ ctx_server.queue_tasks.post(tasks);
+
+ // get the result
+ std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
+
+ ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
+ for (const auto & res : results) {
+ responses.push_back(res.data);
+ }
+ }, [&](const json & error_data) {
+ res_error(res, error_data);
+ error = true;
+ });
+ }
+
+ if (error) {
+ return;
+ }
+
+ // write JSON response
+ json root = format_response_rerank(body, responses);
+ res_ok(res, root);
+ };
+
const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
json result = json::array();
for (size_t i = 0; i < ctx_server.loras.size(); ++i) {
@@ -3070,6 +3298,10 @@ int main(int argc, char ** argv) {
svr->Post("/embedding", handle_embeddings); // legacy
svr->Post("/embeddings", handle_embeddings);
svr->Post("/v1/embeddings", handle_embeddings);
+ svr->Post("/rerank", handle_rerank);
+ svr->Post("/reranking", handle_rerank);
+ svr->Post("/v1/rerank", handle_rerank);
+ svr->Post("/v1/reranking", handle_rerank);
svr->Post("/tokenize", handle_tokenize);
svr->Post("/detokenize", handle_detokenize);
// LoRA adapters hotswap
@@ -3108,7 +3340,6 @@ int main(int argc, char ** argv) {
std::thread t([&]() { svr->listen_after_bind(); });
svr->wait_until_ready();
- //LOG_INFO("HTTP server is listening", log_data);
LOG_INF("%s: HTTP server is listening, hostname: %s, port: %d, http threads: %d\n", __func__, params.hostname.c_str(), params.port, params.n_threads_http);
// load the model
@@ -3135,7 +3366,7 @@ int main(int argc, char ** argv) {
}
// print sample chat example to make it clear which template is used
- LOG_INF("%s: chat template, built_in: %d, chat_example: '%s\n'", __func__, params.chat_template.empty(), llama_chat_format_example(ctx_server.model, params.chat_template).c_str());
+ LOG_INF("%s: chat template, built_in: %d, chat_example: '%s'\n", __func__, params.chat_template.empty(), llama_chat_format_example(ctx_server.model, params.chat_template).c_str());
ctx_server.queue_tasks.on_new_task(std::bind(
&server_context::process_single_task, &ctx_server, std::placeholders::_1));
diff --git a/examples/server/tests/features/ctx_shift.feature b/examples/server/tests/features/ctx_shift.feature
new file mode 100644
index 000000000..ba3afcf06
--- /dev/null
+++ b/examples/server/tests/features/ctx_shift.feature
@@ -0,0 +1,62 @@
+@llama.cpp
+@ctx_shift
+Feature: llama.cpp server
+
+ Background: Server startup
+ Given a server listening on localhost:8080
+ And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
+ And a model file test-model.gguf
+ And a model alias tinyllama-2
+ And BOS token is 1
+ And 42 as server seed
+ And 256 KV cache size
+ And 32 as batch size
+ And 2 slots
+
+ Scenario: Inference with context shift
+ And 64 server max tokens to predict
+ Then the server is starting
+ Then the server is healthy
+ Given a prompt:
+ """
+ Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
+ Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
+ Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
+ Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
+ """
+ And a completion request with no api error
+ Then 64 tokens are predicted matching fun|Annaks|popcorns|pictry|bowl
+ And the completion is truncated
+ And 109 prompt tokens are processed
+
+ Scenario Outline: Inference without context shift
+ And <n_predict> server max tokens to predict
+ And disable context shifting
+ Then the server is starting
+ Then the server is healthy
+ Given a prompt:
+ """
+ Hi how are you
+ """
+ And a completion request with no api error
+ Then <n_token_output> tokens are predicted matching twind|Anna
+ And the completion is <truncated> truncated
+ And 8 prompt tokens are processed
+ Examples:
+ | n_predict | n_token_output | truncated |
+ | 64 | 64 | not |
+ | -1 | 120 | |
+
+ Scenario: Inference without context shift (expected error: prompt too long)
+ And disable context shifting
+ Then the server is starting
+ Then the server is healthy
+ Given a prompt:
+ """
+ Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
+ Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
+ Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
+ Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
+ """
+ And a completion request with 400 api error
+
diff --git a/examples/server/tests/features/embeddings.feature b/examples/server/tests/features/embeddings.feature
index e1eade6cd..f4fe2ee43 100644
--- a/examples/server/tests/features/embeddings.feature
+++ b/examples/server/tests/features/embeddings.feature
@@ -10,12 +10,12 @@ Feature: llama.cpp server
And 42 as server seed
And 2 slots
# the bert-bge-small model has context size of 512
- # since the generated prompts are as big as the batch size, we need to set the batch size to 512
+ # since the generated prompts are as big as the batch size, we need to set the batch size to <= 512
# ref: https://huggingface.co/BAAI/bge-small-en-v1.5/blob/5c38ec7c405ec4b44b94cc5a9bb96e735b38267a/config.json#L20
- And 512 as batch size
- And 512 as ubatch size
- And 2048 KV cache size
- And embeddings extraction
+ And 128 as batch size
+ And 128 as ubatch size
+ And 512 KV cache size
+ And enable embeddings endpoint
Then the server is starting
Then the server is healthy
@@ -26,6 +26,20 @@ Feature: llama.cpp server
"""
Then embeddings are generated
+ Scenario: Embedding (error: prompt too long)
+ When embeddings are computed for:
+ """
+ Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
+ Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
+ Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
+ Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
+ Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
+ Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
+ Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
+ Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
+ """
+ And embeddings request with 500 api error
+
Scenario: OAI Embeddings compatibility
Given a model bert-bge-small
When an OAI compatible embeddings computation request for:
diff --git a/examples/server/tests/features/rerank.feature b/examples/server/tests/features/rerank.feature
new file mode 100644
index 000000000..c36cc8e21
--- /dev/null
+++ b/examples/server/tests/features/rerank.feature
@@ -0,0 +1,42 @@
+@llama.cpp
+@rerank
+Feature: llama.cpp server
+
+ Background: Server startup
+ Given a server listening on localhost:8080
+ And a model url https://huggingface.co/ggml-org/models/resolve/main/jina-reranker-v1-tiny-en/ggml-model-f16.gguf
+ And a model file jina-reranker-v1-tiny-en.gguf
+ And a model alias jina-reranker-v1-tiny-en
+ And 42 as server seed
+ And 2 slots
+ And 512 as batch size
+ And 512 as ubatch size
+ And 512 KV cache size
+ And enable reranking endpoint
+ Then the server is starting
+ Then the server is healthy
+
+ Scenario: Rerank
+ Given a rerank query:
+ """
+ Machine learning is
+ """
+ And a rerank document:
+ """
+ A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.
+ """
+ And a rerank document:
+ """
+ Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.
+ """
+ And a rerank document:
+ """
+ Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.
+ """
+ And a rerank document:
+ """
+ Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine.
+ """
+ When reranking request
+ Then reranking results are returned
+ Then reranking highest score is index 2 and lowest score is index 3
diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py
index 062f084be..2611614ba 100644
--- a/examples/server/tests/features/steps/steps.py
+++ b/examples/server/tests/features/steps/steps.py
@@ -68,6 +68,7 @@ def step_server_config(context, server_fqdn: str, server_port: str):
context.server_api_key = None
context.server_continuous_batching = False
context.server_embeddings = False
+ context.server_reranking = False
context.server_metrics = False
context.server_process = None
context.seed = None
@@ -77,11 +78,16 @@ def step_server_config(context, server_fqdn: str, server_port: str):
context.response_format = None
context.temperature = None
context.lora_file = None
+ context.disable_ctx_shift = False
context.tasks_result = []
context.concurrent_tasks = []
context.prompts = []
+ context.reranking_query = None
+ context.reranking_documents = []
+ context.reranking_results = None
+
@step('a model file {hf_file} from HF repo {hf_repo}')
def step_download_hf_model(context, hf_file: str, hf_repo: str):
@@ -148,7 +154,7 @@ def step_n_slots(context, n_slots: int):
@step('{n_predict:d} server max tokens to predict')
def step_server_n_predict(context, n_predict: int):
- context.n_server_predict = n_predict
+ context.n_server_predict = n_predict if n_predict > 0 else None
@step('{slot_save_path} as slot save path')
@@ -171,15 +177,21 @@ def step_server_continuous_batching(context):
context.server_continuous_batching = True
-@step('embeddings extraction')
+@step('enable embeddings endpoint')
def step_server_embeddings(context):
context.server_embeddings = True
+@step('enable reranking endpoint')
+def step_server_reranking(context):
+ context.server_reranking = True
@step('prometheus compatible metrics exposed')
def step_server_metrics(context):
context.server_metrics = True
+@step('disable context shifting')
+def step_server_disable_ctx_shift(context):
+ context.disable_ctx_shift = True
@step("the server is starting")
def step_start_server(context):
@@ -257,7 +269,7 @@ async def step_all_slots_status(context, expected_slot_status_string: Literal['i
@step('a completion request with {api_error} api error')
@async_run_until_complete
async def step_request_completion(context, api_error: Literal['raised'] | str):
- expect_api_error = api_error == 'raised'
+ expect_api_error = api_error == 'raised' or api_error != 'no'
seeds = await completions_seed(context, num_seeds=1)
completion = await request_completion(context.prompts.pop(),
seeds[0] if seeds is not None else seeds,
@@ -272,8 +284,11 @@ async def step_request_completion(context, api_error: Literal['raised'] | str):
context.tasks_result.append(completion)
if context.debug:
print(f"Completion response: {completion}")
- if expect_api_error:
+ if api_error == 'raised':
assert completion == 401, f"completion must be an 401 status code: {completion}"
+ elif api_error.isdigit():
+ api_error_code = int(api_error)
+ assert completion == api_error_code, f"completion must be an {api_error_code} status code: {completion}"
@step('{predicted_n:d} tokens are predicted matching {re_content}')
@@ -445,6 +460,14 @@ def step_impl(context, n_ga_w):
def step_prompt_passkey(context):
context.prompt_passkey = context_text(context)
+@step('a rerank query')
+def step_set_rerank_query(context):
+ context.reranking_query = context_text(context)
+ context.reranking_documents = []
+
+@step('a rerank document')
+def step_set_rerank_document(context):
+ context.reranking_documents.append(context_text(context))
@step('{n_prompts:d} fixed prompts')
def step_fixed_prompts(context, n_prompts):
@@ -612,6 +635,22 @@ async def step_compute_embedding(context):
context.embeddings = await request_embedding(context_text(context), None, base_url=context.base_url)
+@step('reranking request')
+@async_run_until_complete
+async def step_compute_reranking(context):
+ async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
+ async with session.post(f'{context.base_url}/reranking',
+ json={
+ "query": context.reranking_query,
+ "documents": context.reranking_documents,
+ }) as response:
+ if response.status == 200:
+ response_json = await response.json()
+ context.reranking_results = response_json['results']
+ else:
+ context.reranking_results = response.status
+
+
@step('all embeddings are the same')
@async_run_until_complete
async def step_all_embeddings_are_the_same(context):
@@ -645,6 +684,9 @@ def step_assert_embeddings(context):
for embedding in context.embeddings:
assert_embeddings(embedding)
+@step('embeddings request with {api_error_code:d} api error')
+def step_embeddings_request_error(context, api_error_code: int):
+ assert context.embeddings == api_error_code, f"embeddings request must return code {api_error_code}, but got {context.embeddings}"
@step('an OAI compatible embeddings computation request for')
@async_run_until_complete
@@ -694,6 +736,24 @@ async def all_embeddings_are_generated(context):
for i in range(n_embedding_requests):
assert_embeddings(context.tasks_result.pop().pop())
+@step('reranking results are returned')
+def reranking_results_are_returned(context):
+ assert len(context.reranking_results) == len(context.reranking_documents)
+
+@step('reranking highest score is index {idx_high:d} and lowest score is index {idx_low:d}')
+def step_reranking_extremes(context, idx_high: int, idx_low: int):
+ max_score, max_idx = 0, 0
+ min_score, min_idx = 0, 0
+ for res in context.reranking_results:
+ if max_score < res['relevance_score']:
+ max_score = res['relevance_score']
+ max_idx = res['index']
+ if min_score > res['relevance_score']:
+ min_score = res['relevance_score']
+ min_idx = res['index']
+ print(context.reranking_results)
+ assert max_idx == idx_high
+ assert min_idx == idx_low
@step('adding special tokens')
def step_tokenize_set_add_special(context):
@@ -1089,15 +1149,17 @@ async def oai_chat_completions(user_prompt,
return completion_response
-async def request_embedding(content, seed, base_url=None) -> list[list[float]]:
+async def request_embedding(content, seed, base_url=None) -> list[list[float]] | int:
async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
async with session.post(f'{base_url}/embedding',
json={
"content": content,
}) as response:
- assert response.status == 200
- response_json = await response.json()
- return [response_json['embedding']]
+ if response.status == 200:
+ response_json = await response.json()
+ return [response_json['embedding']]
+ else:
+ return response.status
async def request_oai_embeddings(input, seed,
@@ -1350,6 +1412,8 @@ def start_server_background(context):
server_args.append('--cont-batching')
if context.server_embeddings:
server_args.append('--embedding')
+ if context.server_reranking:
+ server_args.append('--reranking')
if context.server_metrics:
server_args.append('--metrics')
if context.model_alias:
@@ -1372,6 +1436,8 @@ def start_server_background(context):
server_args.append('--verbose')
if context.lora_file:
server_args.extend(['--lora', context.lora_file])
+ if context.disable_ctx_shift:
+ server_args.extend(['--no-context-shift'])
args = [str(arg) for arg in [context.server_path, *server_args]]
print(f"bench: starting server with: {' '.join(args)}")
diff --git a/examples/server/tests/requirements.txt b/examples/server/tests/requirements.txt
index f2d7e5c57..553954872 100644
--- a/examples/server/tests/requirements.txt
+++ b/examples/server/tests/requirements.txt
@@ -1,6 +1,6 @@
aiohttp~=3.9.3
behave~=1.2.6
-huggingface_hub~=0.20.3
+huggingface_hub~=0.23.2
numpy~=1.26.4
openai~=1.30.3
prometheus-client~=0.20.0
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 537c8a223..47dfdfde5 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -331,6 +331,9 @@ static json oaicompat_completion_params_parse(
std::string response_type = json_value(response_format, "type", std::string());
if (response_type == "json_object") {
llama_params["json_schema"] = json_value(response_format, "schema", json::object());
+ } else if (response_type == "json_schema") {
+ json json_schema = json_value(response_format, "json_schema", json::object());
+ llama_params["json_schema"] = json_value(json_schema, "schema", json::object());
} else if (!response_type.empty() && response_type != "text") {
- throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type);
+ throw std::runtime_error("response_format type must be one of \"text\", \"json_object\" or \"json_schema\", but got: " + response_type);
}
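The new branch accepts the OpenAI-style nested layout, where the schema sits under `response_format.json_schema.schema`. A hedged Python sketch of a request body the parser above would now accept (model name, messages, and schema content are made-up examples):

```python
# Illustrative request body exercising the new "json_schema" branch above;
# the schema itself is an invented example.
body = {
    "model": "any-model-name",
    "messages": [{"role": "user", "content": "Describe a cat as JSON."}],
    "response_format": {
        "type": "json_schema",
        "json_schema": {
            "schema": {
                "type": "object",
                "properties": {"name": {"type": "string"}},
                "required": ["name"],
            }
        },
    },
}
```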
@@ -534,7 +537,7 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
json res = json {
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
{"object", "list"},
- {"usage", json {
+ {"usage", json { // TODO: fill
{"prompt_tokens", 0},
{"total_tokens", 0}
}},
@@ -544,6 +547,29 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
return res;
}
+static json format_response_rerank(const json & request, const json & ranks) {
+ json data = json::array();
+ int i = 0;
+ for (const auto & rank : ranks) {
+ data.push_back(json{
+ {"index", i++},
+ {"relevance_score", json_value(rank, "score", 0.0)},
+ });
+ }
+
+ json res = json {
+ {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
+ {"object", "list"},
+ {"usage", json { // TODO: fill
+ {"prompt_tokens", 0},
+ {"total_tokens", 0}
+ }},
+ {"results", data}
+ };
+
+ return res;
+}
+
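For reference, the JSON this helper emits has the following shape (a sketch reconstructed from the code above; the model name echoes the request or falls back to DEFAULT_OAICOMPAT_MODEL, the scores are placeholders, and the usage counts stay zeroed until the TODO is addressed):

```python
# Sketch of a format_response_rerank result; values are placeholders.
example_rerank_response = {
    "model": "<model name from the request, or DEFAULT_OAICOMPAT_MODEL>",
    "object": "list",
    "usage": {"prompt_tokens": 0, "total_tokens": 0},
    "results": [
        {"index": 0, "relevance_score": -3.1},
        {"index": 1, "relevance_score": 7.4},
    ],
}
```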
static bool is_valid_utf8(const std::string & str) {
const unsigned char* bytes = reinterpret_cast<const unsigned char*>(str.data());
const unsigned char* end = bytes + str.length();
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index fbac21811..adf6255e1 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -32,6 +32,9 @@ struct seq_draft {
int main(int argc, char ** argv) {
gpt_params params;
+ // needed to get candidate probs even for temp <= 0.0
+ params.sparams.n_probs = 128;
+
if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
return 1;
}
@@ -49,7 +52,7 @@ int main(int argc, char ** argv) {
// probability threshold for splitting a draft branch (only for n_seq_dft > 1)
const float p_split = params.p_split;
- std::default_random_engine rng(params.sparams.seed);
+ std::default_random_engine rng(params.sparams.seed == LLAMA_DEFAULT_SEED ? std::random_device()() : params.sparams.seed);
std::uniform_real_distribution<> u_dist;
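The intent of the change above: fall back to a nondeterministic seed only when the user left the default sentinel, so explicit seeds stay reproducible. A Python analogue of the same pattern (LLAMA_DEFAULT_SEED's value here is quoted from llama.h, not from this hunk):

```python
# Python analogue of the seeding fix above: use OS entropy only when the
# seed was left at the sentinel default, so explicit seeds reproduce runs.
import random

LLAMA_DEFAULT_SEED = 0xFFFFFFFF  # sentinel meaning "not set", per llama.h

def make_rng(seed: int) -> random.Random:
    if seed == LLAMA_DEFAULT_SEED:
        return random.Random()  # seeded from OS entropy
    return random.Random(seed)  # deterministic, reproducible
```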
// init llama.cpp
diff --git a/examples/sycl/run-llama2.sh b/examples/sycl/run-llama2.sh
index a8cf0aa64..3b9ba3b2d 100755
--- a/examples/sycl/run-llama2.sh
+++ b/examples/sycl/run-llama2.sh
@@ -11,16 +11,17 @@ source /opt/intel/oneapi/setvars.sh
#ZES_ENABLE_SYSMAN=1: enables querying the free GPU memory via sycl::aspect::ext_intel_free_memory. Recommended when --split-mode=layer.
INPUT_PROMPT="Building a website can be done in 10 simple steps:\nStep 1:"
-MODEL_FILE=llama-2-7b.Q4_0.gguf
+MODEL_FILE=models/llama-2-7b.Q4_0.gguf
NGL=33
+CONTEXT=8192
if [ $# -gt 0 ]; then
GGML_SYCL_DEVICE=$1
echo "use $GGML_SYCL_DEVICE as main GPU"
#use single GPU only
- ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m models/${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -mg $GGML_SYCL_DEVICE -sm none
+ ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none
else
#use multiple GPUs with same max compute units
- ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m models/${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0
+ ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT}
fi
diff --git a/flake.lock b/flake.lock
index 0db5ff92a..dde1ab527 100644
--- a/flake.lock
+++ b/flake.lock
@@ -20,11 +20,11 @@
},
"nixpkgs": {
"locked": {
- "lastModified": 1726062873,
- "narHash": "sha256-IiA3jfbR7K/B5+9byVi9BZGWTD4VSbWe8VLpp9B/iYk=",
+ "lastModified": 1727348695,
+ "narHash": "sha256-J+PeFKSDV+pHL7ukkfpVzCOO7mBSrrpJ3svwBFABbhI=",
"owner": "NixOS",
"repo": "nixpkgs",
- "rev": "4f807e8940284ad7925ebd0a0993d2a1791acb2f",
+ "rev": "1925c603f17fc89f4c8f6bf6f631a802ad85d784",
"type": "github"
},
"original": {
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
index e497b6d02..71c0bef8e 100644
--- a/ggml/include/ggml-backend.h
+++ b/ggml/include/ggml-backend.h
@@ -66,6 +66,7 @@ extern "C" {
// "offset" refers to the offset of the tensor data for setting/getting data
GGML_API GGML_CALL void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
GGML_API GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
+ GGML_API GGML_CALL void ggml_backend_tensor_memset( struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size);
GGML_API void ggml_backend_synchronize(ggml_backend_t backend);
@@ -122,7 +123,7 @@ extern "C" {
// The backend registry is a registry of all the available backends, and allows initializing backends in a generic way
GGML_API size_t ggml_backend_reg_get_count(void);
- GGML_API size_t ggml_backend_reg_find_by_name(const char * name);
+ GGML_API size_t ggml_backend_reg_find_by_name(const char * name); // returns index of backend with name, or SIZE_MAX if not found
GGML_API ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str); // str is backend_name:params (params is optional)
GGML_API const char * ggml_backend_reg_get_name(size_t i);
GGML_API ggml_backend_t ggml_backend_reg_init_backend(size_t i, const char * params); // params is backend-specific
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index a413df357..f46d4a8a6 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -229,14 +229,16 @@
#define GGML_MAX_PARAMS 2048
#define GGML_MAX_CONTEXTS 64
#define GGML_MAX_SRC 10
-#ifndef GGML_MAX_NAME
-#define GGML_MAX_NAME 64
#define GGML_MAX_N_THREADS 512
-
-#endif
#define GGML_MAX_OP_PARAMS 64
+
+#ifndef GGML_MAX_NAME
+# define GGML_MAX_NAME 64
+#endif
+
#define GGML_DEFAULT_N_THREADS 4
#define GGML_DEFAULT_GRAPH_SIZE 2048
+
#if UINTPTR_MAX == 0xFFFFFFFF
#define GGML_MEM_ALIGN 4
#else
@@ -259,21 +261,21 @@
#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
#ifndef NDEBUG
-#define GGML_UNREACHABLE() do { fprintf(stderr, "statement should be unreachable\n"); abort(); } while(0)
+# define GGML_UNREACHABLE() do { fprintf(stderr, "statement should be unreachable\n"); abort(); } while(0)
#elif defined(__GNUC__)
-#define GGML_UNREACHABLE() __builtin_unreachable()
+# define GGML_UNREACHABLE() __builtin_unreachable()
#elif defined(_MSC_VER)
-#define GGML_UNREACHABLE() __assume(0)
+# define GGML_UNREACHABLE() __assume(0)
#else
-#define GGML_UNREACHABLE() ((void) 0)
+# define GGML_UNREACHABLE() ((void) 0)
#endif
#ifdef __cplusplus
-#define GGML_NORETURN [[noreturn]]
+# define GGML_NORETURN [[noreturn]]
#elif defined(_MSC_VER)
-#define GGML_NORETURN __declspec(noreturn)
+# define GGML_NORETURN __declspec(noreturn)
#else
-#define GGML_NORETURN _Noreturn
+# define GGML_NORETURN _Noreturn
#endif
#define GGML_ABORT(...) ggml_abort(__FILE__, __LINE__, __VA_ARGS__)
@@ -534,6 +536,7 @@ extern "C" {
GGML_OP_CROSS_ENTROPY_LOSS,
GGML_OP_CROSS_ENTROPY_LOSS_BACK,
+ GGML_OP_OPT_STEP_ADAMW,
GGML_OP_COUNT,
};
@@ -569,12 +572,15 @@ extern "C" {
GGML_LOG_LEVEL_WARN = 2,
GGML_LOG_LEVEL_ERROR = 3,
GGML_LOG_LEVEL_DEBUG = 4,
+ GGML_LOG_LEVEL_CONT = 5, // continue previous log
};
+ // this tensor...
enum ggml_tensor_flag {
- GGML_TENSOR_FLAG_INPUT = 1,
- GGML_TENSOR_FLAG_OUTPUT = 2,
- GGML_TENSOR_FLAG_PARAM = 4,
+ GGML_TENSOR_FLAG_INPUT = 1, // ...is an input for the GGML compute graph
+ GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph
+ GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters
+ GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
};
// n-dimensional tensor
@@ -1976,6 +1982,9 @@ extern "C" {
typedef void (*ggml_custom2_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata);
typedef void (*ggml_custom3_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, const struct ggml_tensor * c, int ith, int nth, void * userdata);
+#define GGML_N_TASKS_MAX (-1)
+ // n_tasks == GGML_N_TASKS_MAX means to use max number of tasks
+
GGML_API struct ggml_tensor * ggml_map_custom1(
struct ggml_context * ctx,
struct ggml_tensor * a,
@@ -2037,23 +2046,44 @@ extern "C" {
struct ggml_tensor * b,
struct ggml_tensor * c);
+ // AdamW optimizer step
+ // Paper: https://arxiv.org/pdf/1711.05101v3.pdf
+ // PyTorch: https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html
+ GGML_API struct ggml_tensor * ggml_opt_step_adamw(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ float alpha,
+ float beta1,
+ float beta2,
+ float eps,
+ float wd); // weight decay
+
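For reference, the update this op is named after, in the form given by the cited paper (a sketch of the reference math only; ggml's exact scheduling of the bias-correction terms may differ):

```latex
% AdamW update per Loshchilov & Hutter (arXiv:1711.05101). Parameter mapping
% from the declaration above: alpha -> \alpha, beta1/beta2 -> \beta_1,\beta_2,
% eps -> \epsilon, wd -> \lambda.
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2 \\
\hat{m}_t &= m_t / (1-\beta_1^t), \qquad \hat{v}_t = v_t / (1-\beta_2^t) \\
\theta_t &= \theta_{t-1} - \alpha \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda\, \theta_{t-1} \right)
\end{aligned}
```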
//
// automatic differentiation
//
- GGML_API void ggml_set_param(
- struct ggml_context * ctx,
- struct ggml_tensor * tensor);
+ GGML_API void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor);
+ GGML_API void ggml_set_loss(struct ggml_tensor * tensor);
GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
- GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep);
+ GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate, bool keep);
+
+ GGML_API void ggml_build_opt_adamw(
+ struct ggml_context * ctx,
+ struct ggml_cgraph * gf,
+ struct ggml_cgraph * gb,
+ float alpha,
+ float beta1,
+ float beta2,
+ float eps,
+ float wd); // weight decay
// graph allocation in a context
GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
GGML_API struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t size, bool grads);
GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph);
GGML_API void ggml_graph_cpy (struct ggml_cgraph * src, struct ggml_cgraph * dst);
- GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // zero grads
+ GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1
GGML_API void ggml_graph_clear (struct ggml_cgraph * cgraph);
GGML_API int ggml_graph_size (struct ggml_cgraph * cgraph);
@@ -2479,6 +2509,9 @@ extern "C" {
GGML_API int ggml_cpu_has_cann (void);
GGML_API int ggml_cpu_has_llamafile (void);
+ // get the sve vector length in bytes
+ GGML_API int ggml_cpu_get_sve_cnt(void);
+
//
// Internal types and functions exposed for tests and benchmarks
//
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index 527c22c68..cbc349500 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -364,7 +364,7 @@ if (GGML_CUDA)
if (GGML_MUSA)
set_source_files_properties(${GGML_SOURCES_CUDA} PROPERTIES LANGUAGE CXX)
foreach(SOURCE ${GGML_SOURCES_CUDA})
- set_property(SOURCE ${SOURCE} PROPERTY COMPILE_FLAGS "-x musa -mtgpu --cuda-gpu-arch=mp_22")
+ set_property(SOURCE ${SOURCE} PROPERTY COMPILE_FLAGS "-x musa -mtgpu --cuda-gpu-arch=mp_21 --cuda-gpu-arch=mp_22")
endforeach()
endif()
@@ -1186,6 +1186,7 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
endif()
if (GGML_AVX512)
list(APPEND ARCH_FLAGS -mavx512f)
+ list(APPEND ARCH_FLAGS -mavx512dq)
list(APPEND ARCH_FLAGS -mavx512bw)
endif()
if (GGML_AVX512_VBMI)
diff --git a/ggml/src/ggml-aarch64.c b/ggml/src/ggml-aarch64.c
index 27375d0d7..b27f41147 100644
--- a/ggml/src/ggml-aarch64.c
+++ b/ggml/src/ggml-aarch64.c
@@ -1,4 +1,7 @@
-// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd.
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+// SPDX-License-Identifier: MIT
+//
+
#define GGML_COMMON_IMPL_C
#include "ggml-common.h"
@@ -39,11 +42,44 @@
//
#if defined(__AVX__)
#if defined(__F16C__)
+#if defined(__AVX512F__)
+#define GGML_F32Cx8x2_LOAD(x, y) _mm512_cvtph_ps(_mm256_set_m128i(_mm_loadu_si128((const __m128i *)(y)), _mm_loadu_si128((const __m128i *)(x))))
+#define GGML_F32Cx16_REPEAT_LOAD(x) _mm512_cvtph_ps(_mm256_set_m128i(x, x))
+#endif
// the _mm256_cvt intrinsics require F16C
#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
#define GGML_F32Cx8_REPEAT_LOAD(x, loadMask) _mm256_cvtph_ps(_mm_shuffle_epi32(_mm_maskload_epi32((int const*)(x), loadMask), 68))
#define GGML_F32Cx8_REARRANGE_LOAD(x, arrangeMask) _mm256_cvtph_ps(_mm_shuffle_epi8(_mm_loadu_si128((const __m128i *) x), arrangeMask))
#else
+#if defined(__AVX512F__)
+static inline __m512 __avx512_f32cx8x2_load(ggml_fp16_t *x, ggml_fp16_t *y) {
+ float tmp[16];
+
+ for (int i = 0; i < 8; i++) {
+ tmp[i] = GGML_FP16_TO_FP32(x[i]);
+ }
+
+ for (int i = 0; i < 8; i++) {
+ tmp[i + 8] = GGML_FP16_TO_FP32(y[i]);
+ }
+
+ return _mm512_loadu_ps(tmp);
+}
+static inline __m512 __avx512_repeat_f32cx16_load(__m128i x) {
+ float tmp[16];
+ uint16_t tmphalf[8];
+ _mm_storeu_si128((__m128i*)tmphalf, x);
+
+ for (int i = 0; i < 4; i++) {
+ tmp[i] = GGML_FP16_TO_FP32(tmphalf[i]);
+ tmp[i + 4] = GGML_FP16_TO_FP32(tmphalf[i]);
+ tmp[i + 8] = GGML_FP16_TO_FP32(tmphalf[i]);
+ tmp[i + 12] = GGML_FP16_TO_FP32(tmphalf[i]);
+ }
+
+ return _mm512_loadu_ps(tmp);
+}
+#endif
static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
float tmp[8];
@@ -78,30 +114,65 @@ static inline __m256 __avx_rearranged_f32cx8_load(ggml_fp16_t *x, __m128i arrang
#define GGML_F32Cx8_LOAD(x) __avx_f32cx8_load(x)
#define GGML_F32Cx8_REPEAT_LOAD(x, loadMask) __avx_repeat_f32cx8_load(x)
#define GGML_F32Cx8_REARRANGE_LOAD(x, arrangeMask) __avx_rearranged_f32cx8_load(x, arrangeMask)
+#if defined(__AVX512F__)
+#define GGML_F32Cx8x2_LOAD(x, y) __avx512_f32cx8x2_load(x, y)
+#define GGML_F32Cx16_REPEAT_LOAD(x) __avx512_repeat_f32cx16_load(x)
+#endif
#endif
#endif
#if defined(__AVX2__) || defined(__AVX512F__)
-static inline __m256i sum_i16_pairs_int(const __m256i x) {
+#if defined(__AVX512F__)
+// add int16_t pairwise and return as 512 bit int vector
+static inline __m512i sum_i16_pairs_int_32x16(const __m512i x) {
+ const __m512i ones = _mm512_set1_epi16(1);
+ return _mm512_madd_epi16(ones, x);
+}
+
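The madd-with-ones idiom used here is just a widening pairwise sum; a NumPy sketch of the equivalence (input values are arbitrary):

```python
# NumPy sketch of sum_i16_pairs_int_32x16's madd-with-ones trick:
# multiplying each i16 lane by 1 and adding adjacent products is a
# pairwise sum widened to i32.
import numpy as np

x = np.array([1, 2, 3, 4, -5, 6, 7, 8], dtype=np.int16)
pairs = x.astype(np.int32).reshape(-1, 2).sum(axis=1)
print(pairs)  # [ 3  7  1 15]
```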
+static inline __m512i mul_sum_us8_pairs_int32x16(const __m512i ax, const __m512i sy) {
+#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
+ const __m512i zero = _mm512_setzero_si512();
+ return _mm512_dpbusd_epi32(zero, ax, sy);
+#else
+ // Perform multiplication and create 16-bit values
+ const __m512i dot = _mm512_maddubs_epi16(ax, sy);
+ return sum_i16_pairs_int_32x16(dot);
+#endif
+}
+
+// multiply int8_t, add results pairwise twice and return as 512 bit int vector
+static inline __m512i mul_sum_i8_pairs_int32x16(const __m512i x, const __m512i y) {
+ const __m512i zero = _mm512_setzero_si512();
+ // Get absolute values of x vectors
+ const __m512i ax = _mm512_abs_epi8(x);
+ // Sign the values of the y vectors
+ __mmask64 blt0 = _mm512_movepi8_mask(x);
+ const __m512i sy = _mm512_mask_sub_epi8(y, blt0, zero, y);
+ return mul_sum_us8_pairs_int32x16(ax, sy);
+}
+#endif
+
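The abs/sign decomposition above relies on the identity x*y == |x| * (sign(x)*y), which lets the unsigned-times-signed dpbusd/maddubs path handle a signed x operand; a NumPy check of that identity (values arbitrary):

```python
# NumPy check of the abs/sign decomposition used by
# mul_sum_i8_pairs_int32x16: x*y == |x| * (sign(x)*y) element-wise.
import numpy as np

x = np.array([-3, 0, 5, -7], dtype=np.int8).astype(np.int32)
y = np.array([ 2, 9, -4, 6], dtype=np.int8).astype(np.int32)
assert np.array_equal(x * y, np.abs(x) * (np.sign(x) * y))
```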
+// add int16_t pairwise and return as 256 bit int vector
+static inline __m256i sum_i16_pairs_int32x8(const __m256i x) {
const __m256i ones = _mm256_set1_epi16(1);
return _mm256_madd_epi16(ones, x);
}
-static inline __m256i mul_sum_us8_pairs_int(const __m256i ax, const __m256i sy) {
+static inline __m256i mul_sum_us8_pairs_int32x8(const __m256i ax, const __m256i sy) {
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
const __m256i zero = _mm256_setzero_si256();
return _mm256_dpbusd_epi32(zero, ax, sy);
#else
// Perform multiplication and create 16-bit values
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
- return sum_i16_pairs_int(dot);
+ return sum_i16_pairs_int32x8(dot);
#endif
}
// Integer variant of the function defined in ggml-quants.c
-// multiply int8_t, add results pairwise twice and return as float vector
-static inline __m256i mul_sum_i8_pairs_int(const __m256i x, const __m256i y) {
+// multiply int8_t, add results pairwise twice and return as 256 bit int vector
+static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y) {
#if __AVXVNNIINT8__
const __m256i zero = _mm256_setzero_si256();
return _mm256_dpbssd_epi32(zero, x, y);
@@ -110,7 +181,7 @@ static inline __m256i mul_sum_i8_pairs_int(const __m256i x, const __m256i y) {
const __m256i ax = _mm256_sign_epi8(x, x);
// Sign the values of the y vectors
const __m256i sy = _mm256_sign_epi8(y, x);
- return mul_sum_us8_pairs_int(ax, sy);
+ return mul_sum_us8_pairs_int32x8(ax, sy);
#endif
}
#endif
@@ -546,73 +617,67 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
UNUSED(ncols_interleaved);
UNUSED(blocklen);
-#if defined(__ARM_FEATURE_SVE)
- if (ggml_sve_cnt_b == QK8_0) {
- GGML_ASSERT(!(ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) &&
- "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance");
- }
-#endif
-#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
- GGML_ASSERT(!(ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) &&
- "__ARM_NEON and __ARM_FEATURE_MATMUL_INT8 defined, use the Q4_0_4_8 quantization format for optimal performance");
-#elif defined(__ARM_NEON) && defined(__aarch64__) && ! ((defined(_MSC_VER)) && ! defined(__clang__))
- const void * b_ptr = vx;
- const void * a_ptr = vy;
- float * res_ptr = s;
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+ if (ggml_cpu_has_neon()) {
+ const void * b_ptr = vx;
+ const void * a_ptr = vy;
+ float * res_ptr = s;
- __asm__ __volatile__(
- "movi v31.16b, #0x4\n"
- "movi v30.16b, #0xf0\n"
- "add %x[b_ptr], %x[b_ptr], #0x8\n"
- "1:" // Column loop
- "add x22, %x[a_ptr], #0x2\n"
- "movi v29.16b, #0x0\n"
- "mov x21, %x[nb]\n"
- "2:" // Block loop
- "ldr q28, [%x[b_ptr], #0x0]\n"
- "ldr q27, [x22, #0x0]\n"
- "movi v26.4s, #0x0\n"
- "sub x20, x22, #0x2\n"
- "ldr q25, [x22, #0x10]\n"
- "ldr q24, [%x[b_ptr], #0x10]\n"
- "sub x21, x21, #0x1\n"
- "add x22, x22, #0x22\n"
- "ldr q23, [%x[b_ptr], #0x20]\n"
- "ldr q22, [%x[b_ptr], #0x30]\n"
- "ld1r { v21.8h }, [x20]\n"
- "ldr q20, [%x[b_ptr], #-0x8]\n"
- "sshl v16.16b, v28.16b, v31.16b\n"
- "and v28.16b, v28.16b, v30.16b\n"
- "sshl v19.16b, v24.16b, v31.16b\n"
- "and v24.16b, v24.16b, v30.16b\n"
- "add %x[b_ptr], %x[b_ptr], #0x48\n"
- "sshl v18.16b, v23.16b, v31.16b\n"
- "and v23.16b, v23.16b, v30.16b\n"
- ".inst 0x4f9be21a // sdot v26.4s, v16.16b, v27.4b[0]\n"
- "sshl v17.16b, v22.16b, v31.16b\n"
- "and v22.16b, v22.16b, v30.16b\n"
- "fcvtl v21.4s, v21.4h\n"
- "fcvtl v16.4s, v20.4h\n"
- ".inst 0x4f99e39a // sdot v26.4s, v28.16b, v25.4b[0]\n"
- "fmul v16.4s, v16.4s, v21.4s\n"
- ".inst 0x4fbbe27a // sdot v26.4s, v19.16b, v27.4b[1]\n"
- ".inst 0x4fb9e31a // sdot v26.4s, v24.16b, v25.4b[1]\n"
- ".inst 0x4f9bea5a // sdot v26.4s, v18.16b, v27.4b[2]\n"
- ".inst 0x4f99eafa // sdot v26.4s, v23.16b, v25.4b[2]\n"
- ".inst 0x4fbbea3a // sdot v26.4s, v17.16b, v27.4b[3]\n"
- ".inst 0x4fb9eada // sdot v26.4s, v22.16b, v25.4b[3]\n"
- "scvtf v26.4s, v26.4s, #0x4\n"
- "fmla v29.4s, v26.4s, v16.4s\n"
- "cbnz x21, 2b\n"
- "sub %x[nc], %x[nc], #0x4\n"
- "str q29, [%x[res_ptr], #0x0]\n"
- "add %x[res_ptr], %x[res_ptr], #0x10\n"
- "cbnz %x[nc], 1b\n"
- : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
- : [a_ptr] "r" (a_ptr), [nb] "r" (nb)
- : "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22"
- );
-#else
+ __asm__ __volatile__(
+ "movi v31.16b, #0x4\n"
+ "movi v30.16b, #0xf0\n"
+ "add %x[b_ptr], %x[b_ptr], #0x8\n"
+ "1:" // Column loop
+ "add x22, %x[a_ptr], #0x2\n"
+ "movi v29.16b, #0x0\n"
+ "mov x21, %x[nb]\n"
+ "2:" // Block loop
+ "ldr q28, [%x[b_ptr], #0x0]\n"
+ "ldr q27, [x22, #0x0]\n"
+ "movi v26.4s, #0x0\n"
+ "sub x20, x22, #0x2\n"
+ "ldr q25, [x22, #0x10]\n"
+ "ldr q24, [%x[b_ptr], #0x10]\n"
+ "sub x21, x21, #0x1\n"
+ "add x22, x22, #0x22\n"
+ "ldr q23, [%x[b_ptr], #0x20]\n"
+ "ldr q22, [%x[b_ptr], #0x30]\n"
+ "ld1r { v21.8h }, [x20]\n"
+ "ldr q20, [%x[b_ptr], #-0x8]\n"
+ "sshl v16.16b, v28.16b, v31.16b\n"
+ "and v28.16b, v28.16b, v30.16b\n"
+ "sshl v19.16b, v24.16b, v31.16b\n"
+ "and v24.16b, v24.16b, v30.16b\n"
+ "add %x[b_ptr], %x[b_ptr], #0x48\n"
+ "sshl v18.16b, v23.16b, v31.16b\n"
+ "and v23.16b, v23.16b, v30.16b\n"
+ ".inst 0x4f9be21a // sdot v26.4s, v16.16b, v27.4b[0]\n"
+ "sshl v17.16b, v22.16b, v31.16b\n"
+ "and v22.16b, v22.16b, v30.16b\n"
+ "fcvtl v21.4s, v21.4h\n"
+ "fcvtl v16.4s, v20.4h\n"
+ ".inst 0x4f99e39a // sdot v26.4s, v28.16b, v25.4b[0]\n"
+ "fmul v16.4s, v16.4s, v21.4s\n"
+ ".inst 0x4fbbe27a // sdot v26.4s, v19.16b, v27.4b[1]\n"
+ ".inst 0x4fb9e31a // sdot v26.4s, v24.16b, v25.4b[1]\n"
+ ".inst 0x4f9bea5a // sdot v26.4s, v18.16b, v27.4b[2]\n"
+ ".inst 0x4f99eafa // sdot v26.4s, v23.16b, v25.4b[2]\n"
+ ".inst 0x4fbbea3a // sdot v26.4s, v17.16b, v27.4b[3]\n"
+ ".inst 0x4fb9eada // sdot v26.4s, v22.16b, v25.4b[3]\n"
+ "scvtf v26.4s, v26.4s, #0x4\n"
+ "fmla v29.4s, v26.4s, v16.4s\n"
+ "cbnz x21, 2b\n"
+ "sub %x[nc], %x[nc], #0x4\n"
+ "str q29, [%x[res_ptr], #0x0]\n"
+ "add %x[res_ptr], %x[res_ptr], #0x10\n"
+ "cbnz %x[nc], 1b\n"
+ : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
+ : [a_ptr] "r" (a_ptr), [nb] "r" (nb)
+ : "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22"
+ );
+ return;
+ }
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
float sumf[4];
int sumi;
@@ -636,7 +701,6 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
}
for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
}
-#endif
}
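Structurally, this hunk (and the ones that follow) replaces compile-time GGML_ASSERT advice with run-time feature checks that return early, leaving the scalar loop as an always-compiled fallback. A hypothetical Python sketch of the dispatch shape (all names are stand-ins, not ggml API):

```python
# Hypothetical sketch of the run-time dispatch pattern adopted above:
# probe the CPU, take the fast path when available, otherwise fall
# through to the portable scalar loop.
def cpu_has_neon() -> bool:
    return False  # stand-in for ggml_cpu_has_neon()

def gemv_neon(rows, vec):
    raise NotImplementedError  # placeholder for the asm kernel

def gemv(rows, vec):
    if cpu_has_neon():
        return gemv_neon(rows, vec)  # fast path, returns early
    # generic fallback, mirrors the scalar loops kept in the C code
    return [sum(a * b for a, b in zip(row, vec)) for row in rows]
```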
void ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
@@ -658,79 +722,72 @@ void ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void *
UNUSED(ncols_interleaved);
UNUSED(blocklen);
-#if defined(__ARM_FEATURE_SVE)
- if (ggml_sve_cnt_b == QK8_0) {
- GGML_ASSERT(!(ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) &&
- "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance");
- }
-#endif
-#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) && ! ((defined(_MSC_VER)) && ! defined(__clang__))
- const void * b_ptr = vx;
- const void * a_ptr = vy;
- float * res_ptr = s;
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+ if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
+ const void * b_ptr = vx;
+ const void * a_ptr = vy;
+ float * res_ptr = s;
- __asm__ __volatile__(
- "movi v2.16b, #0x4\n"
- "movi v1.16b, #0xf0\n"
- "add %x[b_ptr], %x[b_ptr], #0x8\n"
- "1:" // Column loop
- "add x23, %x[a_ptr], #0x2\n"
- "movi v0.16b, #0x0\n"
- "mov x22, %x[nb]\n"
- "2:" // Block loop
- "ldr q31, [%x[b_ptr], #0x0]\n"
- "ldr q30, [%x[b_ptr], #0x10]\n"
- "mov x21, x23\n"
- "movi v29.4s, #0x0\n"
- "ldr q28, [%x[b_ptr], #0x20]\n"
- "ldr q27, [%x[b_ptr], #0x30]\n"
- "movi v26.4s, #0x0\n"
- "sub x20, x23, #0x2\n"
- "ld1r { v25.8h }, [x20]\n"
- "ldr q24, [%x[b_ptr], #-0x8]\n"
- "sub x22, x22, #0x1\n"
- "add x23, x23, #0x22\n"
- "ld1r { v23.2d }, [x21], #0x8\n"
- "sshl v22.16b, v31.16b, v2.16b\n"
- "sshl v16.16b, v30.16b, v2.16b\n"
- "add %x[b_ptr], %x[b_ptr], #0x48\n"
- "ld1r { v21.2d }, [x21], #0x8\n"
- "sshl v20.16b, v28.16b, v2.16b\n"
- "sshl v19.16b, v27.16b, v2.16b\n"
- "ld1r { v18.2d }, [x21], #0x8\n"
- "ld1r { v17.2d }, [x21], #0x8\n"
- "and v31.16b, v31.16b, v1.16b\n"
- "and v30.16b, v30.16b, v1.16b\n"
- ".inst 0x4e9796dd // sdot v29.4s, v22.16b, v23.16b\n"
- ".inst 0x4e97961a // sdot v26.4s, v16.16b, v23.16b\n"
- "and v28.16b, v28.16b, v1.16b\n"
- "and v27.16b, v27.16b, v1.16b\n"
- "fcvtl v25.4s, v25.4h\n"
- "fcvtl v16.4s, v24.4h\n"
- ".inst 0x4e95969d // sdot v29.4s, v20.16b, v21.16b\n"
- ".inst 0x4e95967a // sdot v26.4s, v19.16b, v21.16b\n"
- "fmul v16.4s, v16.4s, v25.4s\n"
- ".inst 0x4e9297fd // sdot v29.4s, v31.16b, v18.16b\n"
- ".inst 0x4e9297da // sdot v26.4s, v30.16b, v18.16b\n"
- ".inst 0x4e91979d // sdot v29.4s, v28.16b, v17.16b\n"
- ".inst 0x4e91977a // sdot v26.4s, v27.16b, v17.16b\n"
- "addp v29.4s, v29.4s, v26.4s\n"
- "scvtf v29.4s, v29.4s, #0x4\n"
- "fmla v0.4s, v29.4s, v16.4s\n"
- "cbnz x22, 2b\n"
- "sub %x[nc], %x[nc], #0x4\n"
- "str q0, [%x[res_ptr], #0x0]\n"
- "add %x[res_ptr], %x[res_ptr], #0x10\n"
- "cbnz %x[nc], 1b\n"
- : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
- : [a_ptr] "r" (a_ptr), [nb] "r" (nb)
- : "memory", "v0", "v1", "v2", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23"
- );
-#elif defined(__ARM_NEON) && defined(__aarch64__)
- GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) &&
- "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal "
- "performance");
-#else
+ __asm__ __volatile__(
+ "movi v2.16b, #0x4\n"
+ "movi v1.16b, #0xf0\n"
+ "add %x[b_ptr], %x[b_ptr], #0x8\n"
+ "1:" // Column loop
+ "add x23, %x[a_ptr], #0x2\n"
+ "movi v0.16b, #0x0\n"
+ "mov x22, %x[nb]\n"
+ "2:" // Block loop
+ "ldr q31, [%x[b_ptr], #0x0]\n"
+ "ldr q30, [%x[b_ptr], #0x10]\n"
+ "mov x21, x23\n"
+ "movi v29.4s, #0x0\n"
+ "ldr q28, [%x[b_ptr], #0x20]\n"
+ "ldr q27, [%x[b_ptr], #0x30]\n"
+ "movi v26.4s, #0x0\n"
+ "sub x20, x23, #0x2\n"
+ "ld1r { v25.8h }, [x20]\n"
+ "ldr q24, [%x[b_ptr], #-0x8]\n"
+ "sub x22, x22, #0x1\n"
+ "add x23, x23, #0x22\n"
+ "ld1r { v23.2d }, [x21], #0x8\n"
+ "sshl v22.16b, v31.16b, v2.16b\n"
+ "sshl v16.16b, v30.16b, v2.16b\n"
+ "add %x[b_ptr], %x[b_ptr], #0x48\n"
+ "ld1r { v21.2d }, [x21], #0x8\n"
+ "sshl v20.16b, v28.16b, v2.16b\n"
+ "sshl v19.16b, v27.16b, v2.16b\n"
+ "ld1r { v18.2d }, [x21], #0x8\n"
+ "ld1r { v17.2d }, [x21], #0x8\n"
+ "and v31.16b, v31.16b, v1.16b\n"
+ "and v30.16b, v30.16b, v1.16b\n"
+ ".inst 0x4e9796dd // sdot v29.4s, v22.16b, v23.16b\n"
+ ".inst 0x4e97961a // sdot v26.4s, v16.16b, v23.16b\n"
+ "and v28.16b, v28.16b, v1.16b\n"
+ "and v27.16b, v27.16b, v1.16b\n"
+ "fcvtl v25.4s, v25.4h\n"
+ "fcvtl v16.4s, v24.4h\n"
+ ".inst 0x4e95969d // sdot v29.4s, v20.16b, v21.16b\n"
+ ".inst 0x4e95967a // sdot v26.4s, v19.16b, v21.16b\n"
+ "fmul v16.4s, v16.4s, v25.4s\n"
+ ".inst 0x4e9297fd // sdot v29.4s, v31.16b, v18.16b\n"
+ ".inst 0x4e9297da // sdot v26.4s, v30.16b, v18.16b\n"
+ ".inst 0x4e91979d // sdot v29.4s, v28.16b, v17.16b\n"
+ ".inst 0x4e91977a // sdot v26.4s, v27.16b, v17.16b\n"
+ "addp v29.4s, v29.4s, v26.4s\n"
+ "scvtf v29.4s, v29.4s, #0x4\n"
+ "fmla v0.4s, v29.4s, v16.4s\n"
+ "cbnz x22, 2b\n"
+ "sub %x[nc], %x[nc], #0x4\n"
+ "str q0, [%x[res_ptr], #0x0]\n"
+ "add %x[res_ptr], %x[res_ptr], #0x10\n"
+ "cbnz %x[nc], 1b\n"
+ : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
+ : [a_ptr] "r" (a_ptr), [nb] "r" (nb)
+ : "memory", "v0", "v1", "v2", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23"
+ );
+ return;
+ }
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
float sumf[4];
int sumi;
@@ -754,7 +811,6 @@ void ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void *
}
for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
}
-#endif
}
void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
@@ -776,8 +832,9 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
UNUSED(ncols_interleaved);
UNUSED(blocklen);
-#if defined(__ARM_FEATURE_SVE) && ! ((defined(_MSC_VER)) && ! defined(__clang__))
- if (ggml_sve_cnt_b == QK8_0) {
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
+#if defined(__ARM_FEATURE_SVE)
+ if (ggml_cpu_has_sve() && ggml_cpu_get_sve_cnt() == QK8_0) {
const void * b_ptr = vx;
const void * a_ptr = vy;
float * res_ptr = s;
@@ -842,24 +899,7 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
);
return;
}
- else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
- GGML_ASSERT((ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) &&
- "__ARM_FEATURE_SVE for vector size of 256-bits not defined, use the Q4_0_4_8 quantization format for optimal "
- "performance");
- }
- else if (ggml_cpu_has_neon()) {
- GGML_ASSERT(((ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) || ggml_cpu_has_matmul_int8()) &&
- "__ARM_FEATURE_SVE for vector size of 256-bits and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 "
- "quantization format for optimal performance");
- }
-#endif
-#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
- GGML_ASSERT(ggml_cpu_has_sve() &&
- "__ARM_FEATURE_SVE not defined, use the Q4_0_4_8 quantization format for optimal performance");
-#elif defined(__ARM_NEON) && defined(__aarch64__)
- GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) &&
- "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal "
- "performance");
+#endif // #if defined(__ARM_FEATURE_SVE)
#elif defined(__AVX2__)
// Lookup table to convert signed nibbles to signed bytes
__m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
@@ -929,17 +969,17 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
// ...........................................................................
// B0(28-31) B4(28-31) B1(28-31) B5(28-31) B2(28-31) B6(28-31) B3(28-31) B7(28-31) with A0(28-31)
- iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(rhs_vec_0123_0 ,_mm256_shuffle_epi32(rhs_vec_4567_0, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 0)));
- iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_0, 177) ,rhs_vec_4567_0, 170), _mm256_shuffle_epi32(lhs_vec_0, 85)));
+ iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_0 ,_mm256_shuffle_epi32(rhs_vec_4567_0, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 0)));
+ iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_0, 177) ,rhs_vec_4567_0, 170), _mm256_shuffle_epi32(lhs_vec_0, 85)));
- iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(rhs_vec_0123_1 ,_mm256_shuffle_epi32(rhs_vec_4567_1, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 170)));
- iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_1, 177) ,rhs_vec_4567_1, 170), _mm256_shuffle_epi32(lhs_vec_0, 255)));
+ iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_1 ,_mm256_shuffle_epi32(rhs_vec_4567_1, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 170)));
+ iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_1, 177) ,rhs_vec_4567_1, 170), _mm256_shuffle_epi32(lhs_vec_0, 255)));
- iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(rhs_vec_0123_2 ,_mm256_shuffle_epi32(rhs_vec_4567_2, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 0)));
- iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_2, 177) ,rhs_vec_4567_2, 170), _mm256_shuffle_epi32(lhs_vec_1, 85)));
+ iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_2 ,_mm256_shuffle_epi32(rhs_vec_4567_2, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 0)));
+ iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_2, 177) ,rhs_vec_4567_2, 170), _mm256_shuffle_epi32(lhs_vec_1, 85)));
- iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(rhs_vec_0123_3 ,_mm256_shuffle_epi32(rhs_vec_4567_3, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 170)));
- iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_3, 177) ,rhs_vec_4567_3, 170), _mm256_shuffle_epi32(lhs_vec_1, 255)));
+ iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_3 ,_mm256_shuffle_epi32(rhs_vec_4567_3, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 170)));
+ iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_3, 177) ,rhs_vec_4567_3, 170), _mm256_shuffle_epi32(lhs_vec_1, 255)));
// Accumulated values multiplied with appropriate scales
acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row);
@@ -950,31 +990,33 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
_mm256_storeu_ps(s + (y * nr + x * 8), acc_row);
}
}
-#else
- float sumf[8];
- int sumi;
+ return;
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
+ {
+ float sumf[8];
+ int sumi;
- const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
- for (int x = 0; x < nc / ncols_interleaved; x++) {
- const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
+ const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
- for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
- for (int l = 0; l < nb; l++) {
- for (int k = 0; k < (qk / (2 * blocklen)); k++) {
- for (int j = 0; j < ncols_interleaved; j++) {
- sumi = 0;
- for (int i = 0; i < blocklen; ++i) {
- const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
- const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
- sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
+ for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+ for (int l = 0; l < nb; l++) {
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+ for (int j = 0; j < ncols_interleaved; j++) {
+ sumi = 0;
+ for (int i = 0; i < blocklen; ++i) {
+ const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+ const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+ sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
+ }
+ sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
}
- sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
}
}
+ for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
}
- for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
}
-#endif
}
void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
@@ -997,505 +1039,500 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
UNUSED(ncols_interleaved);
UNUSED(blocklen);
-#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
- if (ggml_sve_cnt_b == QK8_0) {
- GGML_ASSERT(!(ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) &&
- "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance");
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+ if (ggml_cpu_has_neon()) {
+ const void * b_ptr = vx;
+ const void * a_ptr = vy;
+ float * res_ptr = s;
+ size_t res_stride = bs * sizeof(float);
+
+ __asm__ __volatile__(
+ "mov x10, %x[nr]\n"
+ "mov x9, #0x88\n"
+ "cmp x10, #0x10\n"
+ "mul x9, %x[nb], x9\n"
+ "blt 4f\n"
+ "1:" // Row loop
+ "add x28, %x[b_ptr], #0x8\n"
+ "mov x27, %x[nc]\n"
+ "add x26, %x[res_ptr], %x[res_stride], LSL #4\n"
+ "2:" // Column loop
+ "add x25, %x[a_ptr], #0x8\n"
+ "movi v15.16b, #0x0\n"
+ "movi v19.16b, #0x0\n"
+ "mov x24, %x[nb]\n"
+ "add x23, x25, x9\n"
+ "movi v18.16b, #0x0\n"
+ "movi v14.16b, #0x0\n"
+ "add x22, x23, x9\n"
+ "movi v11.16b, #0x0\n"
+ "movi v13.16b, #0x0\n"
+ "add x21, x22, x9\n"
+ "movi v23.16b, #0x0\n"
+ "movi v16.16b, #0x0\n"
+ "movi v25.16b, #0x0\n"
+ "movi v7.16b, #0x0\n"
+ "movi v0.16b, #0x0\n"
+ "movi v4.16b, #0x0\n"
+ "movi v5.16b, #0x0\n"
+ "movi v21.16b, #0x0\n"
+ "movi v8.16b, #0x0\n"
+ "movi v1.16b, #0x0\n"
+ "3:" // Block loop
+ "ldr q3, [x28, #0x0]\n"
+ "ldr q31, [x25, #0x0]\n"
+ "movi v28.16b, #0x4\n"
+ "movi v10.4s, #0x0\n"
+ "ldr q22, [x28, #0x10]\n"
+ "ldr q6, [x25, #0x10]\n"
+ "movi v29.4s, #0x0\n"
+ "movi v9.4s, #0x0\n"
+ "ldr q27, [x28, #0x20]\n"
+ "ldr q30, [x28, #0x30]\n"
+ "movi v20.4s, #0x0\n"
+ "movi v24.16b, #0xf0\n"
+ "ldr d2, [x25, #-0x8]\n"
+ "ldr d26, [x23, #-0x8]\n"
+ "sshl v12.16b, v3.16b, v28.16b\n"
+ "sub x20, x28, #0x8\n"
+ "ldr d17, [x20, #0x0]\n"
+ "and v3.16b, v3.16b, v24.16b\n"
+ "subs x24, x24, #0x1\n"
+ "add x28, x28, #0x48\n"
+ ".inst 0x4f9fe18a // sdot v10.4s, v12.16b, v31.4b[0]\n"
+ ".inst 0x4fbfe19d // sdot v29.4s, v12.16b, v31.4b[1]\n"
+ ".inst 0x4f9fe989 // sdot v9.4s, v12.16b, v31.4b[2]\n"
+ ".inst 0x4fbfe994 // sdot v20.4s, v12.16b, v31.4b[3]\n"
+ "sshl v31.16b, v22.16b, v28.16b\n"
+ "and v22.16b, v22.16b, v24.16b\n"
+ "fcvtl v17.4s, v17.4h\n"
+ "fcvtl v2.4s, v2.4h\n"
+ "fcvtl v26.4s, v26.4h\n"
+ ".inst 0x4f86e3ea // sdot v10.4s, v31.16b, v6.4b[0]\n"
+ ".inst 0x4fa6e3fd // sdot v29.4s, v31.16b, v6.4b[1]\n"
+ ".inst 0x4f86ebe9 // sdot v9.4s, v31.16b, v6.4b[2]\n"
+ ".inst 0x4fa6ebf4 // sdot v20.4s, v31.16b, v6.4b[3]\n"
+ "sshl v6.16b, v27.16b, v28.16b\n"
+ "sshl v28.16b, v30.16b, v28.16b\n"
+ "and v27.16b, v27.16b, v24.16b\n"
+ "and v30.16b, v30.16b, v24.16b\n"
+ "ldr q24, [x25, #0x20]\n"
+ ".inst 0x4f98e0ca // sdot v10.4s, v6.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n"
+ ".inst 0x4f98e8c9 // sdot v9.4s, v6.16b, v24.4b[2]\n"
+ ".inst 0x4fb8e8d4 // sdot v20.4s, v6.16b, v24.4b[3]\n"
+ "ldr q24, [x25, #0x30]\n"
+ ".inst 0x4f98e38a // sdot v10.4s, v28.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e39d // sdot v29.4s, v28.16b, v24.4b[1]\n"
+ ".inst 0x4f98eb89 // sdot v9.4s, v28.16b, v24.4b[2]\n"
+ ".inst 0x4fb8eb94 // sdot v20.4s, v28.16b, v24.4b[3]\n"
+ "ldr q24, [x25, #0x40]\n"
+ ".inst 0x4f98e06a // sdot v10.4s, v3.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n"
+ ".inst 0x4f98e869 // sdot v9.4s, v3.16b, v24.4b[2]\n"
+ ".inst 0x4fb8e874 // sdot v20.4s, v3.16b, v24.4b[3]\n"
+ "ldr q24, [x25, #0x50]\n"
+ ".inst 0x4f98e2ca // sdot v10.4s, v22.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e2dd // sdot v29.4s, v22.16b, v24.4b[1]\n"
+ ".inst 0x4f98eac9 // sdot v9.4s, v22.16b, v24.4b[2]\n"
+ ".inst 0x4fb8ead4 // sdot v20.4s, v22.16b, v24.4b[3]\n"
+ "ldr q24, [x25, #0x60]\n"
+ ".inst 0x4f98e36a // sdot v10.4s, v27.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n"
+ ".inst 0x4f98eb69 // sdot v9.4s, v27.16b, v24.4b[2]\n"
+ ".inst 0x4fb8eb74 // sdot v20.4s, v27.16b, v24.4b[3]\n"
+ "ldr q24, [x25, #0x70]\n"
+ "add x25, x25, #0x88\n"
+ ".inst 0x4f98e3ca // sdot v10.4s, v30.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e3dd // sdot v29.4s, v30.16b, v24.4b[1]\n"
+ ".inst 0x4f98ebc9 // sdot v9.4s, v30.16b, v24.4b[2]\n"
+ ".inst 0x4fb8ebd4 // sdot v20.4s, v30.16b, v24.4b[3]\n"
+ "fmul v24.4s, v17.4s, v2.s[0]\n"
+ "scvtf v10.4s, v10.4s, #0x4\n"
+ "scvtf v29.4s, v29.4s, #0x4\n"
+ "scvtf v9.4s, v9.4s, #0x4\n"
+ "scvtf v20.4s, v20.4s, #0x4\n"
+ "fmla v15.4s, v10.4s, v24.4s\n"
+ "ldr q24, [x23, #0x0]\n"
+ "fmul v10.4s, v17.4s, v2.s[1]\n"
+ "fmla v19.4s, v29.4s, v10.4s\n"
+ "ldr q10, [x23, #0x10]\n"
+ "fmul v29.4s, v17.4s, v2.s[2]\n"
+ "fmul v2.4s, v17.4s, v2.s[3]\n"
+ "fmla v18.4s, v9.4s, v29.4s\n"
+ "movi v9.4s, #0x0\n"
+ "movi v29.4s, #0x0\n"
+ ".inst 0x4f98e189 // sdot v9.4s, v12.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e19d // sdot v29.4s, v12.16b, v24.4b[1]\n"
+ "fmla v14.4s, v20.4s, v2.4s\n"
+ "movi v20.4s, #0x0\n"
+ "movi v2.4s, #0x0\n"
+ ".inst 0x4f98e994 // sdot v20.4s, v12.16b, v24.4b[2]\n"
+ ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n"
+ "ldr q24, [x23, #0x20]\n"
+ ".inst 0x4f8ae3e9 // sdot v9.4s, v31.16b, v10.4b[0]\n"
+ ".inst 0x4faae3fd // sdot v29.4s, v31.16b, v10.4b[1]\n"
+ ".inst 0x4f8aebf4 // sdot v20.4s, v31.16b, v10.4b[2]\n"
+ ".inst 0x4faaebe2 // sdot v2.4s, v31.16b, v10.4b[3]\n"
+ "ldr q10, [x23, #0x30]\n"
+ ".inst 0x4f98e0c9 // sdot v9.4s, v6.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n"
+ ".inst 0x4f98e8d4 // sdot v20.4s, v6.16b, v24.4b[2]\n"
+ ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n"
+ "ldr q24, [x23, #0x40]\n"
+ ".inst 0x4f8ae389 // sdot v9.4s, v28.16b, v10.4b[0]\n"
+ ".inst 0x4faae39d // sdot v29.4s, v28.16b, v10.4b[1]\n"
+ ".inst 0x4f8aeb94 // sdot v20.4s, v28.16b, v10.4b[2]\n"
+ ".inst 0x4faaeb82 // sdot v2.4s, v28.16b, v10.4b[3]\n"
+ "ldr q10, [x23, #0x50]\n"
+ ".inst 0x4f98e069 // sdot v9.4s, v3.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n"
+ ".inst 0x4f98e874 // sdot v20.4s, v3.16b, v24.4b[2]\n"
+ ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n"
+ "ldr q24, [x23, #0x60]\n"
+ ".inst 0x4f8ae2c9 // sdot v9.4s, v22.16b, v10.4b[0]\n"
+ ".inst 0x4faae2dd // sdot v29.4s, v22.16b, v10.4b[1]\n"
+ ".inst 0x4f8aead4 // sdot v20.4s, v22.16b, v10.4b[2]\n"
+ ".inst 0x4faaeac2 // sdot v2.4s, v22.16b, v10.4b[3]\n"
+ "ldr q10, [x23, #0x70]\n"
+ "add x23, x23, #0x88\n"
+ ".inst 0x4f98e369 // sdot v9.4s, v27.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n"
+ ".inst 0x4f98eb74 // sdot v20.4s, v27.16b, v24.4b[2]\n"
+ ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n"
+ "ldr q24, [x22, #0x0]\n"
+ ".inst 0x4f8ae3c9 // sdot v9.4s, v30.16b, v10.4b[0]\n"
+ ".inst 0x4faae3dd // sdot v29.4s, v30.16b, v10.4b[1]\n"
+ ".inst 0x4f8aebd4 // sdot v20.4s, v30.16b, v10.4b[2]\n"
+ ".inst 0x4faaebc2 // sdot v2.4s, v30.16b, v10.4b[3]\n"
+ "fmul v10.4s, v17.4s, v26.s[0]\n"
+ "scvtf v9.4s, v9.4s, #0x4\n"
+ "scvtf v29.4s, v29.4s, #0x4\n"
+ "scvtf v20.4s, v20.4s, #0x4\n"
+ "scvtf v2.4s, v2.4s, #0x4\n"
+ "fmla v11.4s, v9.4s, v10.4s\n"
+ "ldr q9, [x22, #0x10]\n"
+ "fmul v10.4s, v17.4s, v26.s[1]\n"
+ "fmla v13.4s, v29.4s, v10.4s\n"
+ "ldr d29, [x22, #-0x8]\n"
+ "fmul v10.4s, v17.4s, v26.s[2]\n"
+ "fmul v26.4s, v17.4s, v26.s[3]\n"
+ "fcvtl v29.4s, v29.4h\n"
+ "fmla v23.4s, v20.4s, v10.4s\n"
+ "movi v20.4s, #0x0\n"
+ "movi v10.4s, #0x0\n"
+ "fmla v16.4s, v2.4s, v26.4s\n"
+ "movi v26.4s, #0x0\n"
+ "movi v2.4s, #0x0\n"
+ ".inst 0x4f98e194 // sdot v20.4s, v12.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n"
+ ".inst 0x4f98e99a // sdot v26.4s, v12.16b, v24.4b[2]\n"
+ ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n"
+ "ldr q24, [x22, #0x20]\n"
+ ".inst 0x4f89e3f4 // sdot v20.4s, v31.16b, v9.4b[0]\n"
+ ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n"
+ ".inst 0x4f89ebfa // sdot v26.4s, v31.16b, v9.4b[2]\n"
+ ".inst 0x4fa9ebe2 // sdot v2.4s, v31.16b, v9.4b[3]\n"
+ "ldr q9, [x22, #0x30]\n"
+ ".inst 0x4f98e0d4 // sdot v20.4s, v6.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e0ca // sdot v10.4s, v6.16b, v24.4b[1]\n"
+ ".inst 0x4f98e8da // sdot v26.4s, v6.16b, v24.4b[2]\n"
+ ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n"
+ "ldr q24, [x22, #0x40]\n"
+ ".inst 0x4f89e394 // sdot v20.4s, v28.16b, v9.4b[0]\n"
+ ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n"
+ ".inst 0x4f89eb9a // sdot v26.4s, v28.16b, v9.4b[2]\n"
+ ".inst 0x4fa9eb82 // sdot v2.4s, v28.16b, v9.4b[3]\n"
+ "ldr q9, [x22, #0x50]\n"
+ ".inst 0x4f98e074 // sdot v20.4s, v3.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e06a // sdot v10.4s, v3.16b, v24.4b[1]\n"
+ ".inst 0x4f98e87a // sdot v26.4s, v3.16b, v24.4b[2]\n"
+ ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n"
+ "ldr q24, [x22, #0x60]\n"
+ ".inst 0x4f89e2d4 // sdot v20.4s, v22.16b, v9.4b[0]\n"
+ ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n"
+ ".inst 0x4f89eada // sdot v26.4s, v22.16b, v9.4b[2]\n"
+ ".inst 0x4fa9eac2 // sdot v2.4s, v22.16b, v9.4b[3]\n"
+ "ldr q9, [x22, #0x70]\n"
+ "add x22, x22, #0x88\n"
+ ".inst 0x4f98e374 // sdot v20.4s, v27.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e36a // sdot v10.4s, v27.16b, v24.4b[1]\n"
+ ".inst 0x4f98eb7a // sdot v26.4s, v27.16b, v24.4b[2]\n"
+ ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n"
+ "ldr q24, [x21, #0x0]\n"
+ ".inst 0x4f89e3d4 // sdot v20.4s, v30.16b, v9.4b[0]\n"
+ ".inst 0x4fa9e3ca // sdot v10.4s, v30.16b, v9.4b[1]\n"
+ ".inst 0x4f89ebda // sdot v26.4s, v30.16b, v9.4b[2]\n"
+ ".inst 0x4fa9ebc2 // sdot v2.4s, v30.16b, v9.4b[3]\n"
+ "fmul v9.4s, v17.4s, v29.s[0]\n"
+ "scvtf v20.4s, v20.4s, #0x4\n"
+ "scvtf v10.4s, v10.4s, #0x4\n"
+ "scvtf v26.4s, v26.4s, #0x4\n"
+ "scvtf v2.4s, v2.4s, #0x4\n"
+ "fmla v25.4s, v20.4s, v9.4s\n"
+ "ldr q9, [x21, #0x10]\n"
+ "fmul v20.4s, v17.4s, v29.s[1]\n"
+ "fmla v7.4s, v10.4s, v20.4s\n"
+ "ldr d20, [x21, #-0x8]\n"
+ "fmul v10.4s, v17.4s, v29.s[2]\n"
+ "fmul v29.4s, v17.4s, v29.s[3]\n"
+ "fcvtl v20.4s, v20.4h\n"
+ "fmla v0.4s, v26.4s, v10.4s\n"
+ "movi v26.4s, #0x0\n"
+ "movi v10.4s, #0x0\n"
+ "fmla v4.4s, v2.4s, v29.4s\n"
+ "movi v2.4s, #0x0\n"
+ "movi v29.4s, #0x0\n"
+ ".inst 0x4f98e19a // sdot v26.4s, v12.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n"
+ ".inst 0x4f98e982 // sdot v2.4s, v12.16b, v24.4b[2]\n"
+ ".inst 0x4fb8e99d // sdot v29.4s, v12.16b, v24.4b[3]\n"
+ "ldr q12, [x21, #0x20]\n"
+ "fmul v24.4s, v17.4s, v20.s[0]\n"
+ ".inst 0x4f89e3fa // sdot v26.4s, v31.16b, v9.4b[0]\n"
+ ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n"
+ ".inst 0x4f89ebe2 // sdot v2.4s, v31.16b, v9.4b[2]\n"
+ ".inst 0x4fa9ebfd // sdot v29.4s, v31.16b, v9.4b[3]\n"
+ "ldr q9, [x21, #0x30]\n"
+ "fmul v31.4s, v17.4s, v20.s[1]\n"
+ ".inst 0x4f8ce0da // sdot v26.4s, v6.16b, v12.4b[0]\n"
+ ".inst 0x4face0ca // sdot v10.4s, v6.16b, v12.4b[1]\n"
+ ".inst 0x4f8ce8c2 // sdot v2.4s, v6.16b, v12.4b[2]\n"
+ ".inst 0x4face8dd // sdot v29.4s, v6.16b, v12.4b[3]\n"
+ "ldr q12, [x21, #0x40]\n"
+ "fmul v6.4s, v17.4s, v20.s[2]\n"
+ "fmul v20.4s, v17.4s, v20.s[3]\n"
+ ".inst 0x4f89e39a // sdot v26.4s, v28.16b, v9.4b[0]\n"
+ ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n"
+ ".inst 0x4f89eb82 // sdot v2.4s, v28.16b, v9.4b[2]\n"
+ ".inst 0x4fa9eb9d // sdot v29.4s, v28.16b, v9.4b[3]\n"
+ "ldr q9, [x21, #0x50]\n"
+ ".inst 0x4f8ce07a // sdot v26.4s, v3.16b, v12.4b[0]\n"
+ ".inst 0x4face06a // sdot v10.4s, v3.16b, v12.4b[1]\n"
+ ".inst 0x4f8ce862 // sdot v2.4s, v3.16b, v12.4b[2]\n"
+ ".inst 0x4face87d // sdot v29.4s, v3.16b, v12.4b[3]\n"
+ "ldr q12, [x21, #0x60]\n"
+ ".inst 0x4f89e2da // sdot v26.4s, v22.16b, v9.4b[0]\n"
+ ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n"
+ ".inst 0x4f89eac2 // sdot v2.4s, v22.16b, v9.4b[2]\n"
+ ".inst 0x4fa9eadd // sdot v29.4s, v22.16b, v9.4b[3]\n"
+ "ldr q17, [x21, #0x70]\n"
+ "add x21, x21, #0x88\n"
+ ".inst 0x4f8ce37a // sdot v26.4s, v27.16b, v12.4b[0]\n"
+ ".inst 0x4face36a // sdot v10.4s, v27.16b, v12.4b[1]\n"
+ ".inst 0x4f8ceb62 // sdot v2.4s, v27.16b, v12.4b[2]\n"
+ ".inst 0x4faceb7d // sdot v29.4s, v27.16b, v12.4b[3]\n"
+ ".inst 0x4f91e3da // sdot v26.4s, v30.16b, v17.4b[0]\n"
+ ".inst 0x4fb1e3ca // sdot v10.4s, v30.16b, v17.4b[1]\n"
+ ".inst 0x4f91ebc2 // sdot v2.4s, v30.16b, v17.4b[2]\n"
+ ".inst 0x4fb1ebdd // sdot v29.4s, v30.16b, v17.4b[3]\n"
+ "scvtf v26.4s, v26.4s, #0x4\n"
+ "scvtf v10.4s, v10.4s, #0x4\n"
+ "fmla v5.4s, v26.4s, v24.4s\n"
+ "scvtf v2.4s, v2.4s, #0x4\n"
+ "scvtf v29.4s, v29.4s, #0x4\n"
+ "fmla v21.4s, v10.4s, v31.4s\n"
+ "fmla v8.4s, v2.4s, v6.4s\n"
+ "fmla v1.4s, v29.4s, v20.4s\n"
+ "bgt 3b\n"
+ "mov x20, %x[res_ptr]\n"
+ "subs x27, x27, #0x4\n"
+ "add %x[res_ptr], %x[res_ptr], #0x10\n"
+ "str q15, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q19, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q18, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q14, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q11, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q13, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q23, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q16, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q25, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q7, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q0, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q4, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q5, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q21, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q8, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q1, [x20, #0x0]\n"
+ "bne 2b\n"
+ "mov x20, #0x4\n"
+ "sub x10, x10, #0x10\n"
+ "cmp x10, #0x10\n"
+ "mov %x[res_ptr], x26\n"
+ "madd %x[a_ptr], x20, x9, %x[a_ptr]\n"
+ "bge 1b\n"
+ "4:" // Row loop skip
+ "cbz x10, 9f\n"
+ "5:" // Row tail: Row loop
+ "add x24, %x[b_ptr], #0x8\n"
+ "mov x23, %x[nc]\n"
+ "add x22, %x[res_ptr], %x[res_stride], LSL #2\n"
+ "6:" // Row tail: Column loop
+ "movi v15.16b, #0x0\n"
+ "movi v19.16b, #0x0\n"
+ "add x25, %x[a_ptr], #0x8\n"
+ "mov x21, %x[nb]\n"
+ "movi v18.16b, #0x0\n"
+ "movi v14.16b, #0x0\n"
+ "7:" // Row tail: Block loop
+ "ldr q7, [x24, #0x0]\n"
+ "ldr q5, [x25, #0x0]\n"
+ "movi v9.16b, #0x4\n"
+ "movi v4.4s, #0x0\n"
+ "ldr q3, [x24, #0x10]\n"
+ "ldr q2, [x25, #0x10]\n"
+ "movi v1.4s, #0x0\n"
+ "movi v0.4s, #0x0\n"
+ "ldr q13, [x24, #0x20]\n"
+ "ldr q31, [x25, #0x20]\n"
+ "movi v30.4s, #0x0\n"
+ "movi v29.16b, #0xf0\n"
+ "ldr q28, [x24, #0x30]\n"
+ "ldr q27, [x25, #0x30]\n"
+ "sshl v20.16b, v7.16b, v9.16b\n"
+ "sub x20, x24, #0x8\n"
+ "ldr q26, [x25, #0x40]\n"
+ "ldr q25, [x25, #0x50]\n"
+ "sshl v17.16b, v3.16b, v9.16b\n"
+ "and v7.16b, v7.16b, v29.16b\n"
+ "ldr q24, [x25, #0x60]\n"
+ "ldr q16, [x25, #0x70]\n"
+ "sshl v22.16b, v13.16b, v9.16b\n"
+ "and v3.16b, v3.16b, v29.16b\n"
+ "ldr d21, [x20, #0x0]\n"
+ "ldr d12, [x25, #-0x8]\n"
+ ".inst 0x4f85e284 // sdot v4.4s, v20.16b, v5.4b[0]\n"
+ ".inst 0x4fa5e281 // sdot v1.4s, v20.16b, v5.4b[1]\n"
+ ".inst 0x4f85ea80 // sdot v0.4s, v20.16b, v5.4b[2]\n"
+ ".inst 0x4fa5ea9e // sdot v30.4s, v20.16b, v5.4b[3]\n"
+ "sshl v9.16b, v28.16b, v9.16b\n"
+ "subs x21, x21, #0x1\n"
+ "and v13.16b, v13.16b, v29.16b\n"
+ "and v28.16b, v28.16b, v29.16b\n"
+ "add x25, x25, #0x88\n"
+ "add x24, x24, #0x48\n"
+ "fcvtl v21.4s, v21.4h\n"
+ "fcvtl v12.4s, v12.4h\n"
+ ".inst 0x4f82e224 // sdot v4.4s, v17.16b, v2.4b[0]\n"
+ ".inst 0x4fa2e221 // sdot v1.4s, v17.16b, v2.4b[1]\n"
+ ".inst 0x4f82ea20 // sdot v0.4s, v17.16b, v2.4b[2]\n"
+ ".inst 0x4fa2ea3e // sdot v30.4s, v17.16b, v2.4b[3]\n"
+ "fmul v11.4s, v21.4s, v12.s[0]\n"
+ "fmul v23.4s, v21.4s, v12.s[1]\n"
+ "fmul v17.4s, v21.4s, v12.s[2]\n"
+ ".inst 0x4f9fe2c4 // sdot v4.4s, v22.16b, v31.4b[0]\n"
+ "fmul v6.4s, v21.4s, v12.s[3]\n"
+ ".inst 0x4fbfe2c1 // sdot v1.4s, v22.16b, v31.4b[1]\n"
+ ".inst 0x4f9feac0 // sdot v0.4s, v22.16b, v31.4b[2]\n"
+ ".inst 0x4fbfeade // sdot v30.4s, v22.16b, v31.4b[3]\n"
+ ".inst 0x4f9be124 // sdot v4.4s, v9.16b, v27.4b[0]\n"
+ ".inst 0x4fbbe121 // sdot v1.4s, v9.16b, v27.4b[1]\n"
+ ".inst 0x4f9be920 // sdot v0.4s, v9.16b, v27.4b[2]\n"
+ ".inst 0x4fbbe93e // sdot v30.4s, v9.16b, v27.4b[3]\n"
+ ".inst 0x4f9ae0e4 // sdot v4.4s, v7.16b, v26.4b[0]\n"
+ ".inst 0x4fbae0e1 // sdot v1.4s, v7.16b, v26.4b[1]\n"
+ ".inst 0x4f9ae8e0 // sdot v0.4s, v7.16b, v26.4b[2]\n"
+ ".inst 0x4fbae8fe // sdot v30.4s, v7.16b, v26.4b[3]\n"
+ ".inst 0x4f99e064 // sdot v4.4s, v3.16b, v25.4b[0]\n"
+ ".inst 0x4fb9e061 // sdot v1.4s, v3.16b, v25.4b[1]\n"
+ ".inst 0x4f99e860 // sdot v0.4s, v3.16b, v25.4b[2]\n"
+ ".inst 0x4fb9e87e // sdot v30.4s, v3.16b, v25.4b[3]\n"
+ ".inst 0x4f98e1a4 // sdot v4.4s, v13.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e1a1 // sdot v1.4s, v13.16b, v24.4b[1]\n"
+ ".inst 0x4f98e9a0 // sdot v0.4s, v13.16b, v24.4b[2]\n"
+ ".inst 0x4fb8e9be // sdot v30.4s, v13.16b, v24.4b[3]\n"
+ ".inst 0x4f90e384 // sdot v4.4s, v28.16b, v16.4b[0]\n"
+ ".inst 0x4fb0e381 // sdot v1.4s, v28.16b, v16.4b[1]\n"
+ ".inst 0x4f90eb80 // sdot v0.4s, v28.16b, v16.4b[2]\n"
+ ".inst 0x4fb0eb9e // sdot v30.4s, v28.16b, v16.4b[3]\n"
+ "scvtf v4.4s, v4.4s, #0x4\n"
+ "scvtf v1.4s, v1.4s, #0x4\n"
+ "scvtf v0.4s, v0.4s, #0x4\n"
+ "fmla v15.4s, v4.4s, v11.4s\n"
+ "scvtf v30.4s, v30.4s, #0x4\n"
+ "fmla v19.4s, v1.4s, v23.4s\n"
+ "fmla v18.4s, v0.4s, v17.4s\n"
+ "fmla v14.4s, v30.4s, v6.4s\n"
+ "bgt 7b\n"
+ "mov x20, %x[res_ptr]\n"
+ "cmp x10, #0x1\n"
+ "str q15, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "ble 8f\n"
+ "cmp x10, #0x2\n"
+ "str q19, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "ble 8f\n"
+ "cmp x10, #0x3\n"
+ "str q18, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "ble 8f\n"
+ "str q14, [x20, #0x0]\n"
+ "8:" // Row tail: Accumulator store skip
+ "subs x23, x23, #0x4\n"
+ "add %x[res_ptr], %x[res_ptr], #0x10\n"
+ "bne 6b\n"
+ "subs x10, x10, #0x4\n"
+ "add %x[a_ptr], %x[a_ptr], x9\n"
+ "mov %x[res_ptr], x22\n"
+ "bgt 5b\n"
+ "9:" // Row tail: Row loop skip
+ : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr)
+ : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
+ : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
+ );
+ return;
}
-#endif
-#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
- GGML_ASSERT(!(ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) &&
- "__ARM_NEON and __ARM_FEATURE_MATMUL_INT8 defined, use the Q4_0_4_8 quantization format for optimal performance");
-#elif defined(__ARM_NEON) && defined(__aarch64__) && ! ((defined(_MSC_VER)) && ! defined(__clang__))
- const void * b_ptr = vx;
- const void * a_ptr = vy;
- float * res_ptr = s;
- size_t res_stride = bs * sizeof(float);
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+ {
+ float sumf[4][4];
+ int sumi;
- __asm__ __volatile__(
- "mov x10, %x[nr]\n"
- "mov x9, #0x88\n"
- "cmp x10, #0x10\n"
- "mul x9, %x[nb], x9\n"
- "blt 4f\n"
- "1:" // Row loop
- "add x28, %x[b_ptr], #0x8\n"
- "mov x27, %x[nc]\n"
- "add x26, %x[res_ptr], %x[res_stride], LSL #4\n"
- "2:" // Column loop
- "add x25, %x[a_ptr], #0x8\n"
- "movi v15.16b, #0x0\n"
- "movi v19.16b, #0x0\n"
- "mov x24, %x[nb]\n"
- "add x23, x25, x9\n"
- "movi v18.16b, #0x0\n"
- "movi v14.16b, #0x0\n"
- "add x22, x23, x9\n"
- "movi v11.16b, #0x0\n"
- "movi v13.16b, #0x0\n"
- "add x21, x22, x9\n"
- "movi v23.16b, #0x0\n"
- "movi v16.16b, #0x0\n"
- "movi v25.16b, #0x0\n"
- "movi v7.16b, #0x0\n"
- "movi v0.16b, #0x0\n"
- "movi v4.16b, #0x0\n"
- "movi v5.16b, #0x0\n"
- "movi v21.16b, #0x0\n"
- "movi v8.16b, #0x0\n"
- "movi v1.16b, #0x0\n"
- "3:" // Block loop
- "ldr q3, [x28, #0x0]\n"
- "ldr q31, [x25, #0x0]\n"
- "movi v28.16b, #0x4\n"
- "movi v10.4s, #0x0\n"
- "ldr q22, [x28, #0x10]\n"
- "ldr q6, [x25, #0x10]\n"
- "movi v29.4s, #0x0\n"
- "movi v9.4s, #0x0\n"
- "ldr q27, [x28, #0x20]\n"
- "ldr q30, [x28, #0x30]\n"
- "movi v20.4s, #0x0\n"
- "movi v24.16b, #0xf0\n"
- "ldr d2, [x25, #-0x8]\n"
- "ldr d26, [x23, #-0x8]\n"
- "sshl v12.16b, v3.16b, v28.16b\n"
- "sub x20, x28, #0x8\n"
- "ldr d17, [x20, #0x0]\n"
- "and v3.16b, v3.16b, v24.16b\n"
- "subs x24, x24, #0x1\n"
- "add x28, x28, #0x48\n"
- ".inst 0x4f9fe18a // sdot v10.4s, v12.16b, v31.4b[0]\n"
- ".inst 0x4fbfe19d // sdot v29.4s, v12.16b, v31.4b[1]\n"
- ".inst 0x4f9fe989 // sdot v9.4s, v12.16b, v31.4b[2]\n"
- ".inst 0x4fbfe994 // sdot v20.4s, v12.16b, v31.4b[3]\n"
- "sshl v31.16b, v22.16b, v28.16b\n"
- "and v22.16b, v22.16b, v24.16b\n"
- "fcvtl v17.4s, v17.4h\n"
- "fcvtl v2.4s, v2.4h\n"
- "fcvtl v26.4s, v26.4h\n"
- ".inst 0x4f86e3ea // sdot v10.4s, v31.16b, v6.4b[0]\n"
- ".inst 0x4fa6e3fd // sdot v29.4s, v31.16b, v6.4b[1]\n"
- ".inst 0x4f86ebe9 // sdot v9.4s, v31.16b, v6.4b[2]\n"
- ".inst 0x4fa6ebf4 // sdot v20.4s, v31.16b, v6.4b[3]\n"
- "sshl v6.16b, v27.16b, v28.16b\n"
- "sshl v28.16b, v30.16b, v28.16b\n"
- "and v27.16b, v27.16b, v24.16b\n"
- "and v30.16b, v30.16b, v24.16b\n"
- "ldr q24, [x25, #0x20]\n"
- ".inst 0x4f98e0ca // sdot v10.4s, v6.16b, v24.4b[0]\n"
- ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n"
- ".inst 0x4f98e8c9 // sdot v9.4s, v6.16b, v24.4b[2]\n"
- ".inst 0x4fb8e8d4 // sdot v20.4s, v6.16b, v24.4b[3]\n"
- "ldr q24, [x25, #0x30]\n"
- ".inst 0x4f98e38a // sdot v10.4s, v28.16b, v24.4b[0]\n"
- ".inst 0x4fb8e39d // sdot v29.4s, v28.16b, v24.4b[1]\n"
- ".inst 0x4f98eb89 // sdot v9.4s, v28.16b, v24.4b[2]\n"
- ".inst 0x4fb8eb94 // sdot v20.4s, v28.16b, v24.4b[3]\n"
- "ldr q24, [x25, #0x40]\n"
- ".inst 0x4f98e06a // sdot v10.4s, v3.16b, v24.4b[0]\n"
- ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n"
- ".inst 0x4f98e869 // sdot v9.4s, v3.16b, v24.4b[2]\n"
- ".inst 0x4fb8e874 // sdot v20.4s, v3.16b, v24.4b[3]\n"
- "ldr q24, [x25, #0x50]\n"
- ".inst 0x4f98e2ca // sdot v10.4s, v22.16b, v24.4b[0]\n"
- ".inst 0x4fb8e2dd // sdot v29.4s, v22.16b, v24.4b[1]\n"
- ".inst 0x4f98eac9 // sdot v9.4s, v22.16b, v24.4b[2]\n"
- ".inst 0x4fb8ead4 // sdot v20.4s, v22.16b, v24.4b[3]\n"
- "ldr q24, [x25, #0x60]\n"
- ".inst 0x4f98e36a // sdot v10.4s, v27.16b, v24.4b[0]\n"
- ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n"
- ".inst 0x4f98eb69 // sdot v9.4s, v27.16b, v24.4b[2]\n"
- ".inst 0x4fb8eb74 // sdot v20.4s, v27.16b, v24.4b[3]\n"
- "ldr q24, [x25, #0x70]\n"
- "add x25, x25, #0x88\n"
- ".inst 0x4f98e3ca // sdot v10.4s, v30.16b, v24.4b[0]\n"
- ".inst 0x4fb8e3dd // sdot v29.4s, v30.16b, v24.4b[1]\n"
- ".inst 0x4f98ebc9 // sdot v9.4s, v30.16b, v24.4b[2]\n"
- ".inst 0x4fb8ebd4 // sdot v20.4s, v30.16b, v24.4b[3]\n"
- "fmul v24.4s, v17.4s, v2.s[0]\n"
- "scvtf v10.4s, v10.4s, #0x4\n"
- "scvtf v29.4s, v29.4s, #0x4\n"
- "scvtf v9.4s, v9.4s, #0x4\n"
- "scvtf v20.4s, v20.4s, #0x4\n"
- "fmla v15.4s, v10.4s, v24.4s\n"
- "ldr q24, [x23, #0x0]\n"
- "fmul v10.4s, v17.4s, v2.s[1]\n"
- "fmla v19.4s, v29.4s, v10.4s\n"
- "ldr q10, [x23, #0x10]\n"
- "fmul v29.4s, v17.4s, v2.s[2]\n"
- "fmul v2.4s, v17.4s, v2.s[3]\n"
- "fmla v18.4s, v9.4s, v29.4s\n"
- "movi v9.4s, #0x0\n"
- "movi v29.4s, #0x0\n"
- ".inst 0x4f98e189 // sdot v9.4s, v12.16b, v24.4b[0]\n"
- ".inst 0x4fb8e19d // sdot v29.4s, v12.16b, v24.4b[1]\n"
- "fmla v14.4s, v20.4s, v2.4s\n"
- "movi v20.4s, #0x0\n"
- "movi v2.4s, #0x0\n"
- ".inst 0x4f98e994 // sdot v20.4s, v12.16b, v24.4b[2]\n"
- ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n"
- "ldr q24, [x23, #0x20]\n"
- ".inst 0x4f8ae3e9 // sdot v9.4s, v31.16b, v10.4b[0]\n"
- ".inst 0x4faae3fd // sdot v29.4s, v31.16b, v10.4b[1]\n"
- ".inst 0x4f8aebf4 // sdot v20.4s, v31.16b, v10.4b[2]\n"
- ".inst 0x4faaebe2 // sdot v2.4s, v31.16b, v10.4b[3]\n"
- "ldr q10, [x23, #0x30]\n"
- ".inst 0x4f98e0c9 // sdot v9.4s, v6.16b, v24.4b[0]\n"
- ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n"
- ".inst 0x4f98e8d4 // sdot v20.4s, v6.16b, v24.4b[2]\n"
- ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n"
- "ldr q24, [x23, #0x40]\n"
- ".inst 0x4f8ae389 // sdot v9.4s, v28.16b, v10.4b[0]\n"
- ".inst 0x4faae39d // sdot v29.4s, v28.16b, v10.4b[1]\n"
- ".inst 0x4f8aeb94 // sdot v20.4s, v28.16b, v10.4b[2]\n"
- ".inst 0x4faaeb82 // sdot v2.4s, v28.16b, v10.4b[3]\n"
- "ldr q10, [x23, #0x50]\n"
- ".inst 0x4f98e069 // sdot v9.4s, v3.16b, v24.4b[0]\n"
- ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n"
- ".inst 0x4f98e874 // sdot v20.4s, v3.16b, v24.4b[2]\n"
- ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n"
- "ldr q24, [x23, #0x60]\n"
- ".inst 0x4f8ae2c9 // sdot v9.4s, v22.16b, v10.4b[0]\n"
- ".inst 0x4faae2dd // sdot v29.4s, v22.16b, v10.4b[1]\n"
- ".inst 0x4f8aead4 // sdot v20.4s, v22.16b, v10.4b[2]\n"
- ".inst 0x4faaeac2 // sdot v2.4s, v22.16b, v10.4b[3]\n"
- "ldr q10, [x23, #0x70]\n"
- "add x23, x23, #0x88\n"
- ".inst 0x4f98e369 // sdot v9.4s, v27.16b, v24.4b[0]\n"
- ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n"
- ".inst 0x4f98eb74 // sdot v20.4s, v27.16b, v24.4b[2]\n"
- ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n"
- "ldr q24, [x22, #0x0]\n"
- ".inst 0x4f8ae3c9 // sdot v9.4s, v30.16b, v10.4b[0]\n"
- ".inst 0x4faae3dd // sdot v29.4s, v30.16b, v10.4b[1]\n"
- ".inst 0x4f8aebd4 // sdot v20.4s, v30.16b, v10.4b[2]\n"
- ".inst 0x4faaebc2 // sdot v2.4s, v30.16b, v10.4b[3]\n"
- "fmul v10.4s, v17.4s, v26.s[0]\n"
- "scvtf v9.4s, v9.4s, #0x4\n"
- "scvtf v29.4s, v29.4s, #0x4\n"
- "scvtf v20.4s, v20.4s, #0x4\n"
- "scvtf v2.4s, v2.4s, #0x4\n"
- "fmla v11.4s, v9.4s, v10.4s\n"
- "ldr q9, [x22, #0x10]\n"
- "fmul v10.4s, v17.4s, v26.s[1]\n"
- "fmla v13.4s, v29.4s, v10.4s\n"
- "ldr d29, [x22, #-0x8]\n"
- "fmul v10.4s, v17.4s, v26.s[2]\n"
- "fmul v26.4s, v17.4s, v26.s[3]\n"
- "fcvtl v29.4s, v29.4h\n"
- "fmla v23.4s, v20.4s, v10.4s\n"
- "movi v20.4s, #0x0\n"
- "movi v10.4s, #0x0\n"
- "fmla v16.4s, v2.4s, v26.4s\n"
- "movi v26.4s, #0x0\n"
- "movi v2.4s, #0x0\n"
- ".inst 0x4f98e194 // sdot v20.4s, v12.16b, v24.4b[0]\n"
- ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n"
- ".inst 0x4f98e99a // sdot v26.4s, v12.16b, v24.4b[2]\n"
- ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n"
- "ldr q24, [x22, #0x20]\n"
- ".inst 0x4f89e3f4 // sdot v20.4s, v31.16b, v9.4b[0]\n"
- ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n"
- ".inst 0x4f89ebfa // sdot v26.4s, v31.16b, v9.4b[2]\n"
- ".inst 0x4fa9ebe2 // sdot v2.4s, v31.16b, v9.4b[3]\n"
- "ldr q9, [x22, #0x30]\n"
- ".inst 0x4f98e0d4 // sdot v20.4s, v6.16b, v24.4b[0]\n"
- ".inst 0x4fb8e0ca // sdot v10.4s, v6.16b, v24.4b[1]\n"
- ".inst 0x4f98e8da // sdot v26.4s, v6.16b, v24.4b[2]\n"
- ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n"
- "ldr q24, [x22, #0x40]\n"
- ".inst 0x4f89e394 // sdot v20.4s, v28.16b, v9.4b[0]\n"
- ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n"
- ".inst 0x4f89eb9a // sdot v26.4s, v28.16b, v9.4b[2]\n"
- ".inst 0x4fa9eb82 // sdot v2.4s, v28.16b, v9.4b[3]\n"
- "ldr q9, [x22, #0x50]\n"
- ".inst 0x4f98e074 // sdot v20.4s, v3.16b, v24.4b[0]\n"
- ".inst 0x4fb8e06a // sdot v10.4s, v3.16b, v24.4b[1]\n"
- ".inst 0x4f98e87a // sdot v26.4s, v3.16b, v24.4b[2]\n"
- ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n"
- "ldr q24, [x22, #0x60]\n"
- ".inst 0x4f89e2d4 // sdot v20.4s, v22.16b, v9.4b[0]\n"
- ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n"
- ".inst 0x4f89eada // sdot v26.4s, v22.16b, v9.4b[2]\n"
- ".inst 0x4fa9eac2 // sdot v2.4s, v22.16b, v9.4b[3]\n"
- "ldr q9, [x22, #0x70]\n"
- "add x22, x22, #0x88\n"
- ".inst 0x4f98e374 // sdot v20.4s, v27.16b, v24.4b[0]\n"
- ".inst 0x4fb8e36a // sdot v10.4s, v27.16b, v24.4b[1]\n"
- ".inst 0x4f98eb7a // sdot v26.4s, v27.16b, v24.4b[2]\n"
- ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n"
- "ldr q24, [x21, #0x0]\n"
- ".inst 0x4f89e3d4 // sdot v20.4s, v30.16b, v9.4b[0]\n"
- ".inst 0x4fa9e3ca // sdot v10.4s, v30.16b, v9.4b[1]\n"
- ".inst 0x4f89ebda // sdot v26.4s, v30.16b, v9.4b[2]\n"
- ".inst 0x4fa9ebc2 // sdot v2.4s, v30.16b, v9.4b[3]\n"
- "fmul v9.4s, v17.4s, v29.s[0]\n"
- "scvtf v20.4s, v20.4s, #0x4\n"
- "scvtf v10.4s, v10.4s, #0x4\n"
- "scvtf v26.4s, v26.4s, #0x4\n"
- "scvtf v2.4s, v2.4s, #0x4\n"
- "fmla v25.4s, v20.4s, v9.4s\n"
- "ldr q9, [x21, #0x10]\n"
- "fmul v20.4s, v17.4s, v29.s[1]\n"
- "fmla v7.4s, v10.4s, v20.4s\n"
- "ldr d20, [x21, #-0x8]\n"
- "fmul v10.4s, v17.4s, v29.s[2]\n"
- "fmul v29.4s, v17.4s, v29.s[3]\n"
- "fcvtl v20.4s, v20.4h\n"
- "fmla v0.4s, v26.4s, v10.4s\n"
- "movi v26.4s, #0x0\n"
- "movi v10.4s, #0x0\n"
- "fmla v4.4s, v2.4s, v29.4s\n"
- "movi v2.4s, #0x0\n"
- "movi v29.4s, #0x0\n"
- ".inst 0x4f98e19a // sdot v26.4s, v12.16b, v24.4b[0]\n"
- ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n"
- ".inst 0x4f98e982 // sdot v2.4s, v12.16b, v24.4b[2]\n"
- ".inst 0x4fb8e99d // sdot v29.4s, v12.16b, v24.4b[3]\n"
- "ldr q12, [x21, #0x20]\n"
- "fmul v24.4s, v17.4s, v20.s[0]\n"
- ".inst 0x4f89e3fa // sdot v26.4s, v31.16b, v9.4b[0]\n"
- ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n"
- ".inst 0x4f89ebe2 // sdot v2.4s, v31.16b, v9.4b[2]\n"
- ".inst 0x4fa9ebfd // sdot v29.4s, v31.16b, v9.4b[3]\n"
- "ldr q9, [x21, #0x30]\n"
- "fmul v31.4s, v17.4s, v20.s[1]\n"
- ".inst 0x4f8ce0da // sdot v26.4s, v6.16b, v12.4b[0]\n"
- ".inst 0x4face0ca // sdot v10.4s, v6.16b, v12.4b[1]\n"
- ".inst 0x4f8ce8c2 // sdot v2.4s, v6.16b, v12.4b[2]\n"
- ".inst 0x4face8dd // sdot v29.4s, v6.16b, v12.4b[3]\n"
- "ldr q12, [x21, #0x40]\n"
- "fmul v6.4s, v17.4s, v20.s[2]\n"
- "fmul v20.4s, v17.4s, v20.s[3]\n"
- ".inst 0x4f89e39a // sdot v26.4s, v28.16b, v9.4b[0]\n"
- ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n"
- ".inst 0x4f89eb82 // sdot v2.4s, v28.16b, v9.4b[2]\n"
- ".inst 0x4fa9eb9d // sdot v29.4s, v28.16b, v9.4b[3]\n"
- "ldr q9, [x21, #0x50]\n"
- ".inst 0x4f8ce07a // sdot v26.4s, v3.16b, v12.4b[0]\n"
- ".inst 0x4face06a // sdot v10.4s, v3.16b, v12.4b[1]\n"
- ".inst 0x4f8ce862 // sdot v2.4s, v3.16b, v12.4b[2]\n"
- ".inst 0x4face87d // sdot v29.4s, v3.16b, v12.4b[3]\n"
- "ldr q12, [x21, #0x60]\n"
- ".inst 0x4f89e2da // sdot v26.4s, v22.16b, v9.4b[0]\n"
- ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n"
- ".inst 0x4f89eac2 // sdot v2.4s, v22.16b, v9.4b[2]\n"
- ".inst 0x4fa9eadd // sdot v29.4s, v22.16b, v9.4b[3]\n"
- "ldr q17, [x21, #0x70]\n"
- "add x21, x21, #0x88\n"
- ".inst 0x4f8ce37a // sdot v26.4s, v27.16b, v12.4b[0]\n"
- ".inst 0x4face36a // sdot v10.4s, v27.16b, v12.4b[1]\n"
- ".inst 0x4f8ceb62 // sdot v2.4s, v27.16b, v12.4b[2]\n"
- ".inst 0x4faceb7d // sdot v29.4s, v27.16b, v12.4b[3]\n"
- ".inst 0x4f91e3da // sdot v26.4s, v30.16b, v17.4b[0]\n"
- ".inst 0x4fb1e3ca // sdot v10.4s, v30.16b, v17.4b[1]\n"
- ".inst 0x4f91ebc2 // sdot v2.4s, v30.16b, v17.4b[2]\n"
- ".inst 0x4fb1ebdd // sdot v29.4s, v30.16b, v17.4b[3]\n"
- "scvtf v26.4s, v26.4s, #0x4\n"
- "scvtf v10.4s, v10.4s, #0x4\n"
- "fmla v5.4s, v26.4s, v24.4s\n"
- "scvtf v2.4s, v2.4s, #0x4\n"
- "scvtf v29.4s, v29.4s, #0x4\n"
- "fmla v21.4s, v10.4s, v31.4s\n"
- "fmla v8.4s, v2.4s, v6.4s\n"
- "fmla v1.4s, v29.4s, v20.4s\n"
- "bgt 3b\n"
- "mov x20, %x[res_ptr]\n"
- "subs x27, x27, #0x4\n"
- "add %x[res_ptr], %x[res_ptr], #0x10\n"
- "str q15, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q19, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q18, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q14, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q11, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q13, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q23, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q16, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q25, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q7, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q0, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q4, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q5, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q21, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q8, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q1, [x20, #0x0]\n"
- "bne 2b\n"
- "mov x20, #0x4\n"
- "sub x10, x10, #0x10\n"
- "cmp x10, #0x10\n"
- "mov %x[res_ptr], x26\n"
- "madd %x[a_ptr], x20, x9, %x[a_ptr]\n"
- "bge 1b\n"
- "4:" // Row loop skip
- "cbz x10, 9f\n"
- "5:" // Row tail: Row loop
- "add x24, %x[b_ptr], #0x8\n"
- "mov x23, %x[nc]\n"
- "add x22, %x[res_ptr], %x[res_stride], LSL #2\n"
- "6:" // Row tail: Column loop
- "movi v15.16b, #0x0\n"
- "movi v19.16b, #0x0\n"
- "add x25, %x[a_ptr], #0x8\n"
- "mov x21, %x[nb]\n"
- "movi v18.16b, #0x0\n"
- "movi v14.16b, #0x0\n"
- "7:" // Row tail: Block loop
- "ldr q7, [x24, #0x0]\n"
- "ldr q5, [x25, #0x0]\n"
- "movi v9.16b, #0x4\n"
- "movi v4.4s, #0x0\n"
- "ldr q3, [x24, #0x10]\n"
- "ldr q2, [x25, #0x10]\n"
- "movi v1.4s, #0x0\n"
- "movi v0.4s, #0x0\n"
- "ldr q13, [x24, #0x20]\n"
- "ldr q31, [x25, #0x20]\n"
- "movi v30.4s, #0x0\n"
- "movi v29.16b, #0xf0\n"
- "ldr q28, [x24, #0x30]\n"
- "ldr q27, [x25, #0x30]\n"
- "sshl v20.16b, v7.16b, v9.16b\n"
- "sub x20, x24, #0x8\n"
- "ldr q26, [x25, #0x40]\n"
- "ldr q25, [x25, #0x50]\n"
- "sshl v17.16b, v3.16b, v9.16b\n"
- "and v7.16b, v7.16b, v29.16b\n"
- "ldr q24, [x25, #0x60]\n"
- "ldr q16, [x25, #0x70]\n"
- "sshl v22.16b, v13.16b, v9.16b\n"
- "and v3.16b, v3.16b, v29.16b\n"
- "ldr d21, [x20, #0x0]\n"
- "ldr d12, [x25, #-0x8]\n"
- ".inst 0x4f85e284 // sdot v4.4s, v20.16b, v5.4b[0]\n"
- ".inst 0x4fa5e281 // sdot v1.4s, v20.16b, v5.4b[1]\n"
- ".inst 0x4f85ea80 // sdot v0.4s, v20.16b, v5.4b[2]\n"
- ".inst 0x4fa5ea9e // sdot v30.4s, v20.16b, v5.4b[3]\n"
- "sshl v9.16b, v28.16b, v9.16b\n"
- "subs x21, x21, #0x1\n"
- "and v13.16b, v13.16b, v29.16b\n"
- "and v28.16b, v28.16b, v29.16b\n"
- "add x25, x25, #0x88\n"
- "add x24, x24, #0x48\n"
- "fcvtl v21.4s, v21.4h\n"
- "fcvtl v12.4s, v12.4h\n"
- ".inst 0x4f82e224 // sdot v4.4s, v17.16b, v2.4b[0]\n"
- ".inst 0x4fa2e221 // sdot v1.4s, v17.16b, v2.4b[1]\n"
- ".inst 0x4f82ea20 // sdot v0.4s, v17.16b, v2.4b[2]\n"
- ".inst 0x4fa2ea3e // sdot v30.4s, v17.16b, v2.4b[3]\n"
- "fmul v11.4s, v21.4s, v12.s[0]\n"
- "fmul v23.4s, v21.4s, v12.s[1]\n"
- "fmul v17.4s, v21.4s, v12.s[2]\n"
- ".inst 0x4f9fe2c4 // sdot v4.4s, v22.16b, v31.4b[0]\n"
- "fmul v6.4s, v21.4s, v12.s[3]\n"
- ".inst 0x4fbfe2c1 // sdot v1.4s, v22.16b, v31.4b[1]\n"
- ".inst 0x4f9feac0 // sdot v0.4s, v22.16b, v31.4b[2]\n"
- ".inst 0x4fbfeade // sdot v30.4s, v22.16b, v31.4b[3]\n"
- ".inst 0x4f9be124 // sdot v4.4s, v9.16b, v27.4b[0]\n"
- ".inst 0x4fbbe121 // sdot v1.4s, v9.16b, v27.4b[1]\n"
- ".inst 0x4f9be920 // sdot v0.4s, v9.16b, v27.4b[2]\n"
- ".inst 0x4fbbe93e // sdot v30.4s, v9.16b, v27.4b[3]\n"
- ".inst 0x4f9ae0e4 // sdot v4.4s, v7.16b, v26.4b[0]\n"
- ".inst 0x4fbae0e1 // sdot v1.4s, v7.16b, v26.4b[1]\n"
- ".inst 0x4f9ae8e0 // sdot v0.4s, v7.16b, v26.4b[2]\n"
- ".inst 0x4fbae8fe // sdot v30.4s, v7.16b, v26.4b[3]\n"
- ".inst 0x4f99e064 // sdot v4.4s, v3.16b, v25.4b[0]\n"
- ".inst 0x4fb9e061 // sdot v1.4s, v3.16b, v25.4b[1]\n"
- ".inst 0x4f99e860 // sdot v0.4s, v3.16b, v25.4b[2]\n"
- ".inst 0x4fb9e87e // sdot v30.4s, v3.16b, v25.4b[3]\n"
- ".inst 0x4f98e1a4 // sdot v4.4s, v13.16b, v24.4b[0]\n"
- ".inst 0x4fb8e1a1 // sdot v1.4s, v13.16b, v24.4b[1]\n"
- ".inst 0x4f98e9a0 // sdot v0.4s, v13.16b, v24.4b[2]\n"
- ".inst 0x4fb8e9be // sdot v30.4s, v13.16b, v24.4b[3]\n"
- ".inst 0x4f90e384 // sdot v4.4s, v28.16b, v16.4b[0]\n"
- ".inst 0x4fb0e381 // sdot v1.4s, v28.16b, v16.4b[1]\n"
- ".inst 0x4f90eb80 // sdot v0.4s, v28.16b, v16.4b[2]\n"
- ".inst 0x4fb0eb9e // sdot v30.4s, v28.16b, v16.4b[3]\n"
- "scvtf v4.4s, v4.4s, #0x4\n"
- "scvtf v1.4s, v1.4s, #0x4\n"
- "scvtf v0.4s, v0.4s, #0x4\n"
- "fmla v15.4s, v4.4s, v11.4s\n"
- "scvtf v30.4s, v30.4s, #0x4\n"
- "fmla v19.4s, v1.4s, v23.4s\n"
- "fmla v18.4s, v0.4s, v17.4s\n"
- "fmla v14.4s, v30.4s, v6.4s\n"
- "bgt 7b\n"
- "mov x20, %x[res_ptr]\n"
- "cmp x10, #0x1\n"
- "str q15, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "ble 8f\n"
- "cmp x10, #0x2\n"
- "str q19, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "ble 8f\n"
- "cmp x10, #0x3\n"
- "str q18, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "ble 8f\n"
- "str q14, [x20, #0x0]\n"
- "8:" // Row tail: Accumulator store skip
- "subs x23, x23, #0x4\n"
- "add %x[res_ptr], %x[res_ptr], #0x10\n"
- "bne 6b\n"
- "subs x10, x10, #0x4\n"
- "add %x[a_ptr], %x[a_ptr], x9\n"
- "mov %x[res_ptr], x22\n"
- "bgt 5b\n"
- "9:" // Row tail: Row loop skip
- : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr)
- : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
- : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
- );
-#else
- float sumf[4][4];
- int sumi;
-
- for (int y = 0; y < nr / 4; y++) {
- const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
- for (int x = 0; x < nc / ncols_interleaved; x++) {
- const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
- for (int m = 0; m < 4; m++) {
- for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
- }
- for (int l = 0; l < nb; l++) {
- for (int k = 0; k < (qk / (2 * blocklen)); k++) {
- for (int m = 0; m < 4; m++) {
- for (int j = 0; j < ncols_interleaved; j++) {
- sumi = 0;
- for (int i = 0; i < blocklen; ++i) {
- const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
- const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
- sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
- (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
+ for (int y = 0; y < nr / 4; y++) {
+ const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
+ const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
+ for (int m = 0; m < 4; m++) {
+ for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+ }
+ for (int l = 0; l < nb; l++) {
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+ for (int m = 0; m < 4; m++) {
+ for (int j = 0; j < ncols_interleaved; j++) {
+ sumi = 0;
+ for (int i = 0; i < blocklen; ++i) {
+ const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+ const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+ sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
+ (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
+ }
+ sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
}
- sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
}
}
}
- }
- for (int m = 0; m < 4; m++) {
- for (int j = 0; j < ncols_interleaved; j++)
- s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+ for (int m = 0; m < 4; m++) {
+ for (int j = 0; j < ncols_interleaved; j++)
+ s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+ }
}
}
}
-#endif
}
void ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
@@ -1518,413 +1555,406 @@ void ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void *
UNUSED(ncols_interleaved);
UNUSED(blocklen);
-#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
- if (ggml_sve_cnt_b == QK8_0) {
- GGML_ASSERT(!(ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) &&
- "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance");
- }
-#endif
-#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) && ! ((defined(_MSC_VER)) && ! defined(__clang__))
- const void * b_ptr = vx;
- const void * a_ptr = vy;
- float * res_ptr = s;
- size_t res_stride = bs * sizeof(float);
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
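+    // The compile-time GGML_ASSERT dispatch is replaced by a runtime feature
+    // check: the i8mm kernel runs only when NEON and int8 matmul support are
+    // detected, and the function falls through to the scalar loop otherwise.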
+ if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
+ const void * b_ptr = vx;
+ const void * a_ptr = vy;
+ float * res_ptr = s;
+ size_t res_stride = bs * sizeof(float);
- __asm__ __volatile__(
- "mov x10, %x[nr]\n"
- "mov x9, #0x88\n"
- "cmp x10, #0x10\n"
- "mul x9, %x[nb], x9\n"
- "blt 4f\n"
- "1:" // Row loop
- "add x28, %x[b_ptr], #0x8\n"
- "mov x27, %x[nc]\n"
- "add x26, %x[res_ptr], %x[res_stride], LSL #4\n"
- "2:" // Column loop
- "add x25, %x[a_ptr], #0x8\n"
- "movi v2.16b, #0x0\n"
- "movi v10.16b, #0x0\n"
- "mov x24, %x[nb]\n"
- "add x23, x25, x9\n"
- "movi v12.16b, #0x0\n"
- "movi v28.16b, #0x0\n"
- "add x22, x23, x9\n"
- "movi v11.16b, #0x0\n"
- "movi v13.16b, #0x0\n"
- "add x21, x22, x9\n"
- "movi v22.16b, #0x0\n"
- "movi v23.16b, #0x0\n"
- "movi v25.16b, #0x0\n"
- "movi v5.16b, #0x0\n"
- "movi v7.16b, #0x0\n"
- "movi v4.16b, #0x0\n"
- "movi v6.16b, #0x0\n"
- "movi v30.16b, #0x0\n"
- "movi v24.16b, #0x0\n"
- "movi v14.16b, #0x0\n"
- "3:" // Block loop
- "ldr q21, [x28, #0x0]\n"
- "ldr q16, [x28, #0x10]\n"
- "movi v1.16b, #0x4\n"
- "movi v19.4s, #0x0\n"
- "ldr q27, [x25, #0x0]\n"
- "ldr q15, [x25, #0x10]\n"
- "movi v26.4s, #0x0\n"
- "movi v18.4s, #0x0\n"
- "ldr q29, [x28, #0x20]\n"
- "ldr q3, [x28, #0x30]\n"
- "movi v17.4s, #0x0\n"
- "movi v0.16b, #0xf0\n"
- "ldr d20, [x25, #-0x8]\n"
- "ldr d9, [x23, #-0x8]\n"
- "sshl v8.16b, v21.16b, v1.16b\n"
- "sshl v31.16b, v16.16b, v1.16b\n"
- "and v21.16b, v21.16b, v0.16b\n"
- "and v16.16b, v16.16b, v0.16b\n"
- "sub x20, x28, #0x8\n"
- "subs x24, x24, #0x1\n"
- "add x28, x28, #0x48\n"
- ".inst 0x4e88a773 // smmla v19.4s, v27.16b, v8.16b\n"
- ".inst 0x4e9fa77a // smmla v26.4s, v27.16b, v31.16b\n"
- "ldr q27, [x25, #0x20]\n"
- ".inst 0x4e88a5f2 // smmla v18.4s, v15.16b, v8.16b\n"
- ".inst 0x4e9fa5f1 // smmla v17.4s, v15.16b, v31.16b\n"
- "sshl v15.16b, v29.16b, v1.16b\n"
- "sshl v1.16b, v3.16b, v1.16b\n"
- "and v29.16b, v29.16b, v0.16b\n"
- "and v3.16b, v3.16b, v0.16b\n"
- "ldr q0, [x25, #0x30]\n"
- "fcvtl v20.4s, v20.4h\n"
- ".inst 0x4e8fa773 // smmla v19.4s, v27.16b, v15.16b\n"
- "fcvtl v9.4s, v9.4h\n"
- ".inst 0x4e81a77a // smmla v26.4s, v27.16b, v1.16b\n"
- "ldr q27, [x25, #0x40]\n"
- ".inst 0x4e8fa412 // smmla v18.4s, v0.16b, v15.16b\n"
- ".inst 0x4e81a411 // smmla v17.4s, v0.16b, v1.16b\n"
- "ldr q0, [x25, #0x50]\n"
- ".inst 0x4e95a773 // smmla v19.4s, v27.16b, v21.16b\n"
- ".inst 0x4e90a77a // smmla v26.4s, v27.16b, v16.16b\n"
- "ldr q27, [x25, #0x60]\n"
- ".inst 0x4e95a412 // smmla v18.4s, v0.16b, v21.16b\n"
- ".inst 0x4e90a411 // smmla v17.4s, v0.16b, v16.16b\n"
- "ldr q0, [x25, #0x70]\n"
- "add x25, x25, #0x88\n"
- ".inst 0x4e9da773 // smmla v19.4s, v27.16b, v29.16b\n"
- ".inst 0x4e83a77a // smmla v26.4s, v27.16b, v3.16b\n"
- "ldr d27, [x20, #0x0]\n"
- ".inst 0x4e9da412 // smmla v18.4s, v0.16b, v29.16b\n"
- ".inst 0x4e83a411 // smmla v17.4s, v0.16b, v3.16b\n"
- "fcvtl v27.4s, v27.4h\n"
- "uzp1 v0.2d, v19.2d, v26.2d\n"
- "uzp2 v26.2d, v19.2d, v26.2d\n"
- "fmul v19.4s, v27.4s, v20.s[0]\n"
- "scvtf v0.4s, v0.4s, #0x4\n"
- "scvtf v26.4s, v26.4s, #0x4\n"
- "fmla v2.4s, v0.4s, v19.4s\n"
- "ldr q19, [x23, #0x0]\n"
- "uzp1 v0.2d, v18.2d, v17.2d\n"
- "uzp2 v18.2d, v18.2d, v17.2d\n"
- "fmul v17.4s, v27.4s, v20.s[1]\n"
- "scvtf v0.4s, v0.4s, #0x4\n"
- "scvtf v18.4s, v18.4s, #0x4\n"
- "fmla v10.4s, v26.4s, v17.4s\n"
- "ldr q17, [x23, #0x10]\n"
- "fmul v26.4s, v27.4s, v20.s[2]\n"
- "fmul v20.4s, v27.4s, v20.s[3]\n"
- "fmla v12.4s, v0.4s, v26.4s\n"
- "ldr d0, [x22, #-0x8]\n"
- "ldr d26, [x21, #-0x8]\n"
- "fcvtl v0.4s, v0.4h\n"
- "fmla v28.4s, v18.4s, v20.4s\n"
- "movi v20.4s, #0x0\n"
- "movi v18.4s, #0x0\n"
- ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n"
- ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n"
- "ldr q19, [x23, #0x20]\n"
- "fcvtl v26.4s, v26.4h\n"
- ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n"
- ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n"
- "ldr q19, [x23, #0x40]\n"
- ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n"
- ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n"
- "ldr q19, [x23, #0x60]\n"
- ".inst 0x4e9da674 // smmla v20.4s, v19.16b, v29.16b\n"
- ".inst 0x4e83a672 // smmla v18.4s, v19.16b, v3.16b\n"
- "uzp1 v19.2d, v20.2d, v18.2d\n"
- "scvtf v19.4s, v19.4s, #0x4\n"
- "uzp2 v20.2d, v20.2d, v18.2d\n"
- "fmul v18.4s, v27.4s, v9.s[0]\n"
- "scvtf v20.4s, v20.4s, #0x4\n"
- "fmla v11.4s, v19.4s, v18.4s\n"
- "ldr q18, [x22, #0x0]\n"
- "fmul v19.4s, v27.4s, v9.s[1]\n"
- "fmla v13.4s, v20.4s, v19.4s\n"
- "movi v19.4s, #0x0\n"
- "movi v20.4s, #0x0\n"
- ".inst 0x4e88a633 // smmla v19.4s, v17.16b, v8.16b\n"
- ".inst 0x4e9fa634 // smmla v20.4s, v17.16b, v31.16b\n"
- "ldr q17, [x23, #0x30]\n"
- ".inst 0x4e8fa633 // smmla v19.4s, v17.16b, v15.16b\n"
- ".inst 0x4e81a634 // smmla v20.4s, v17.16b, v1.16b\n"
- "ldr q17, [x23, #0x50]\n"
- ".inst 0x4e95a633 // smmla v19.4s, v17.16b, v21.16b\n"
- ".inst 0x4e90a634 // smmla v20.4s, v17.16b, v16.16b\n"
- "ldr q17, [x23, #0x70]\n"
- "add x23, x23, #0x88\n"
- ".inst 0x4e9da633 // smmla v19.4s, v17.16b, v29.16b\n"
- ".inst 0x4e83a634 // smmla v20.4s, v17.16b, v3.16b\n"
- "uzp1 v17.2d, v19.2d, v20.2d\n"
- "scvtf v17.4s, v17.4s, #0x4\n"
- "uzp2 v20.2d, v19.2d, v20.2d\n"
- "fmul v19.4s, v27.4s, v9.s[2]\n"
- "fmul v9.4s, v27.4s, v9.s[3]\n"
- "scvtf v20.4s, v20.4s, #0x4\n"
- "fmla v22.4s, v17.4s, v19.4s\n"
- "ldr q17, [x22, #0x10]\n"
- "movi v19.4s, #0x0\n"
- ".inst 0x4e88a653 // smmla v19.4s, v18.16b, v8.16b\n"
- "fmla v23.4s, v20.4s, v9.4s\n"
- "movi v20.4s, #0x0\n"
- "movi v9.4s, #0x0\n"
- ".inst 0x4e9fa654 // smmla v20.4s, v18.16b, v31.16b\n"
- "ldr q18, [x22, #0x20]\n"
- ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n"
- ".inst 0x4e8fa653 // smmla v19.4s, v18.16b, v15.16b\n"
- ".inst 0x4e81a654 // smmla v20.4s, v18.16b, v1.16b\n"
- "ldr q18, [x22, #0x40]\n"
- ".inst 0x4e95a653 // smmla v19.4s, v18.16b, v21.16b\n"
- ".inst 0x4e90a654 // smmla v20.4s, v18.16b, v16.16b\n"
- "ldr q18, [x22, #0x60]\n"
- ".inst 0x4e9da653 // smmla v19.4s, v18.16b, v29.16b\n"
- ".inst 0x4e83a654 // smmla v20.4s, v18.16b, v3.16b\n"
- "movi v18.4s, #0x0\n"
- ".inst 0x4e9fa632 // smmla v18.4s, v17.16b, v31.16b\n"
- "ldr q17, [x22, #0x30]\n"
- ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n"
- ".inst 0x4e81a632 // smmla v18.4s, v17.16b, v1.16b\n"
- "ldr q17, [x22, #0x50]\n"
- ".inst 0x4e95a629 // smmla v9.4s, v17.16b, v21.16b\n"
- ".inst 0x4e90a632 // smmla v18.4s, v17.16b, v16.16b\n"
- "ldr q17, [x22, #0x70]\n"
- "add x22, x22, #0x88\n"
- ".inst 0x4e9da629 // smmla v9.4s, v17.16b, v29.16b\n"
- ".inst 0x4e83a632 // smmla v18.4s, v17.16b, v3.16b\n"
- "uzp1 v17.2d, v19.2d, v20.2d\n"
- "uzp2 v20.2d, v19.2d, v20.2d\n"
- "fmul v19.4s, v27.4s, v0.s[0]\n"
- "scvtf v17.4s, v17.4s, #0x4\n"
- "scvtf v20.4s, v20.4s, #0x4\n"
- "fmla v25.4s, v17.4s, v19.4s\n"
- "ldr q19, [x21, #0x0]\n"
- "fmul v17.4s, v27.4s, v0.s[1]\n"
- "fmla v5.4s, v20.4s, v17.4s\n"
- "ldr q17, [x21, #0x10]\n"
- "uzp1 v20.2d, v9.2d, v18.2d\n"
- "uzp2 v9.2d, v9.2d, v18.2d\n"
- "fmul v18.4s, v27.4s, v0.s[2]\n"
- "fmul v0.4s, v27.4s, v0.s[3]\n"
- "scvtf v20.4s, v20.4s, #0x4\n"
- "scvtf v9.4s, v9.4s, #0x4\n"
- "fmla v7.4s, v20.4s, v18.4s\n"
- "movi v20.4s, #0x0\n"
- "movi v18.4s, #0x0\n"
- ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n"
- ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n"
- "ldr q19, [x21, #0x20]\n"
- "fmla v4.4s, v9.4s, v0.4s\n"
- "movi v9.4s, #0x0\n"
- "movi v0.4s, #0x0\n"
- ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n"
- "fmul v8.4s, v27.4s, v26.s[0]\n"
- ".inst 0x4e9fa620 // smmla v0.4s, v17.16b, v31.16b\n"
- "ldr q17, [x21, #0x30]\n"
- ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n"
- "fmul v31.4s, v27.4s, v26.s[1]\n"
- ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n"
- "ldr q19, [x21, #0x40]\n"
- ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n"
- "fmul v15.4s, v27.4s, v26.s[2]\n"
- "fmul v27.4s, v27.4s, v26.s[3]\n"
- ".inst 0x4e81a620 // smmla v0.4s, v17.16b, v1.16b\n"
- "ldr q1, [x21, #0x50]\n"
- ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n"
- ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n"
- "ldr q26, [x21, #0x60]\n"
- ".inst 0x4e95a429 // smmla v9.4s, v1.16b, v21.16b\n"
- ".inst 0x4e90a420 // smmla v0.4s, v1.16b, v16.16b\n"
- "ldr q21, [x21, #0x70]\n"
- "add x21, x21, #0x88\n"
- ".inst 0x4e9da754 // smmla v20.4s, v26.16b, v29.16b\n"
- ".inst 0x4e83a752 // smmla v18.4s, v26.16b, v3.16b\n"
- ".inst 0x4e9da6a9 // smmla v9.4s, v21.16b, v29.16b\n"
- ".inst 0x4e83a6a0 // smmla v0.4s, v21.16b, v3.16b\n"
- "uzp1 v29.2d, v20.2d, v18.2d\n"
- "uzp2 v21.2d, v20.2d, v18.2d\n"
- "scvtf v29.4s, v29.4s, #0x4\n"
- "uzp1 v18.2d, v9.2d, v0.2d\n"
- "uzp2 v16.2d, v9.2d, v0.2d\n"
- "scvtf v21.4s, v21.4s, #0x4\n"
- "fmla v6.4s, v29.4s, v8.4s\n"
- "scvtf v18.4s, v18.4s, #0x4\n"
- "scvtf v16.4s, v16.4s, #0x4\n"
- "fmla v30.4s, v21.4s, v31.4s\n"
- "fmla v24.4s, v18.4s, v15.4s\n"
- "fmla v14.4s, v16.4s, v27.4s\n"
- "bgt 3b\n"
- "mov x20, %x[res_ptr]\n"
- "subs x27, x27, #0x4\n"
- "add %x[res_ptr], %x[res_ptr], #0x10\n"
- "str q2, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q10, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q12, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q28, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q11, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q13, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q22, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q23, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q25, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q5, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q7, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q4, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q6, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q30, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q24, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q14, [x20, #0x0]\n"
- "bne 2b\n"
- "mov x20, #0x4\n"
- "sub x10, x10, #0x10\n"
- "cmp x10, #0x10\n"
- "mov %x[res_ptr], x26\n"
- "madd %x[a_ptr], x20, x9, %x[a_ptr]\n"
- "bge 1b\n"
- "4:" // Row loop skip
- "cbz x10, 9f\n"
- "5:" // Row tail: Row loop
- "add x24, %x[b_ptr], #0x8\n"
- "mov x23, %x[nc]\n"
- "add x22, %x[res_ptr], %x[res_stride], LSL #2\n"
- "6:" // Row tail: Column loop
- "movi v2.16b, #0x0\n"
- "movi v10.16b, #0x0\n"
- "add x25, %x[a_ptr], #0x8\n"
- "mov x21, %x[nb]\n"
- "movi v12.16b, #0x0\n"
- "movi v28.16b, #0x0\n"
- "7:" // Row tail: Block loop
- "ldr q6, [x24, #0x0]\n"
- "ldr q5, [x24, #0x10]\n"
- "movi v17.16b, #0x4\n"
- "movi v8.4s, #0x0\n"
- "ldr q4, [x25, #0x0]\n"
- "ldr q13, [x25, #0x10]\n"
- "movi v27.4s, #0x0\n"
- "movi v0.4s, #0x0\n"
- "ldr q31, [x24, #0x20]\n"
- "ldr q14, [x24, #0x30]\n"
- "movi v29.4s, #0x0\n"
- "movi v22.16b, #0xf0\n"
- "ldr q11, [x25, #0x20]\n"
- "ldr q23, [x25, #0x30]\n"
- "sshl v21.16b, v6.16b, v17.16b\n"
- "sshl v16.16b, v5.16b, v17.16b\n"
- "ldr q20, [x25, #0x40]\n"
- "ldr q26, [x25, #0x50]\n"
- "and v6.16b, v6.16b, v22.16b\n"
- "and v5.16b, v5.16b, v22.16b\n"
- "ldr q25, [x25, #0x60]\n"
- "ldr q3, [x25, #0x70]\n"
- "sshl v19.16b, v31.16b, v17.16b\n"
- "sshl v18.16b, v14.16b, v17.16b\n"
- "ldr d17, [x25, #-0x8]\n"
- ".inst 0x4e95a488 // smmla v8.4s, v4.16b, v21.16b\n"
- ".inst 0x4e90a49b // smmla v27.4s, v4.16b, v16.16b\n"
- "and v31.16b, v31.16b, v22.16b\n"
- ".inst 0x4e95a5a0 // smmla v0.4s, v13.16b, v21.16b\n"
- ".inst 0x4e90a5bd // smmla v29.4s, v13.16b, v16.16b\n"
- "and v14.16b, v14.16b, v22.16b\n"
- "sub x20, x24, #0x8\n"
- "ldr d16, [x20, #0x0]\n"
- "subs x21, x21, #0x1\n"
- "add x25, x25, #0x88\n"
- "fcvtl v17.4s, v17.4h\n"
- "add x24, x24, #0x48\n"
- ".inst 0x4e93a568 // smmla v8.4s, v11.16b, v19.16b\n"
- ".inst 0x4e92a57b // smmla v27.4s, v11.16b, v18.16b\n"
- ".inst 0x4e93a6e0 // smmla v0.4s, v23.16b, v19.16b\n"
- ".inst 0x4e92a6fd // smmla v29.4s, v23.16b, v18.16b\n"
- "fcvtl v16.4s, v16.4h\n"
- ".inst 0x4e86a688 // smmla v8.4s, v20.16b, v6.16b\n"
- ".inst 0x4e85a69b // smmla v27.4s, v20.16b, v5.16b\n"
- "fmul v23.4s, v16.4s, v17.s[0]\n"
- "fmul v21.4s, v16.4s, v17.s[1]\n"
- "fmul v1.4s, v16.4s, v17.s[2]\n"
- "fmul v20.4s, v16.4s, v17.s[3]\n"
- ".inst 0x4e86a740 // smmla v0.4s, v26.16b, v6.16b\n"
- ".inst 0x4e85a75d // smmla v29.4s, v26.16b, v5.16b\n"
- ".inst 0x4e9fa728 // smmla v8.4s, v25.16b, v31.16b\n"
- ".inst 0x4e8ea73b // smmla v27.4s, v25.16b, v14.16b\n"
- ".inst 0x4e9fa460 // smmla v0.4s, v3.16b, v31.16b\n"
- ".inst 0x4e8ea47d // smmla v29.4s, v3.16b, v14.16b\n"
- "uzp1 v19.2d, v8.2d, v27.2d\n"
- "uzp2 v18.2d, v8.2d, v27.2d\n"
- "scvtf v19.4s, v19.4s, #0x4\n"
- "uzp1 v17.2d, v0.2d, v29.2d\n"
- "uzp2 v16.2d, v0.2d, v29.2d\n"
- "scvtf v18.4s, v18.4s, #0x4\n"
- "fmla v2.4s, v19.4s, v23.4s\n"
- "scvtf v17.4s, v17.4s, #0x4\n"
- "scvtf v16.4s, v16.4s, #0x4\n"
- "fmla v10.4s, v18.4s, v21.4s\n"
- "fmla v12.4s, v17.4s, v1.4s\n"
- "fmla v28.4s, v16.4s, v20.4s\n"
- "bgt 7b\n"
- "mov x20, %x[res_ptr]\n"
- "cmp x10, #0x1\n"
- "str q2, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "ble 8f\n"
- "cmp x10, #0x2\n"
- "str q10, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "ble 8f\n"
- "cmp x10, #0x3\n"
- "str q12, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "ble 8f\n"
- "str q28, [x20, #0x0]\n"
- "8:" // Row tail: Accumulator store skip
- "subs x23, x23, #0x4\n"
- "add %x[res_ptr], %x[res_ptr], #0x10\n"
- "bne 6b\n"
- "subs x10, x10, #0x4\n"
- "add %x[a_ptr], %x[a_ptr], x9\n"
- "mov %x[res_ptr], x22\n"
- "bgt 5b\n"
- "9:" // Row tail: Row loop skip
- : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr)
- : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
- : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
- );
-#elif defined(__ARM_NEON) && defined(__aarch64__)
- GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) &&
- "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal "
- "performance");
-#else
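+        // Hand-written NEON i8mm kernel: the main loop covers 16-row tiles of
+        // 4 interleaved columns, with a 4-row tail loop. The smmla instructions
+        // are emitted as raw .inst words, which keeps assemblers without i8mm
+        // support happy.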
+ __asm__ __volatile__(
+ "mov x10, %x[nr]\n"
+ "mov x9, #0x88\n"
+ "cmp x10, #0x10\n"
+ "mul x9, %x[nb], x9\n"
+ "blt 4f\n"
+ "1:" // Row loop
+ "add x28, %x[b_ptr], #0x8\n"
+ "mov x27, %x[nc]\n"
+ "add x26, %x[res_ptr], %x[res_stride], LSL #4\n"
+ "2:" // Column loop
+ "add x25, %x[a_ptr], #0x8\n"
+ "movi v2.16b, #0x0\n"
+ "movi v10.16b, #0x0\n"
+ "mov x24, %x[nb]\n"
+ "add x23, x25, x9\n"
+ "movi v12.16b, #0x0\n"
+ "movi v28.16b, #0x0\n"
+ "add x22, x23, x9\n"
+ "movi v11.16b, #0x0\n"
+ "movi v13.16b, #0x0\n"
+ "add x21, x22, x9\n"
+ "movi v22.16b, #0x0\n"
+ "movi v23.16b, #0x0\n"
+ "movi v25.16b, #0x0\n"
+ "movi v5.16b, #0x0\n"
+ "movi v7.16b, #0x0\n"
+ "movi v4.16b, #0x0\n"
+ "movi v6.16b, #0x0\n"
+ "movi v30.16b, #0x0\n"
+ "movi v24.16b, #0x0\n"
+ "movi v14.16b, #0x0\n"
+ "3:" // Block loop
+ "ldr q21, [x28, #0x0]\n"
+ "ldr q16, [x28, #0x10]\n"
+ "movi v1.16b, #0x4\n"
+ "movi v19.4s, #0x0\n"
+ "ldr q27, [x25, #0x0]\n"
+ "ldr q15, [x25, #0x10]\n"
+ "movi v26.4s, #0x0\n"
+ "movi v18.4s, #0x0\n"
+ "ldr q29, [x28, #0x20]\n"
+ "ldr q3, [x28, #0x30]\n"
+ "movi v17.4s, #0x0\n"
+ "movi v0.16b, #0xf0\n"
+ "ldr d20, [x25, #-0x8]\n"
+ "ldr d9, [x23, #-0x8]\n"
+ "sshl v8.16b, v21.16b, v1.16b\n"
+ "sshl v31.16b, v16.16b, v1.16b\n"
+ "and v21.16b, v21.16b, v0.16b\n"
+ "and v16.16b, v16.16b, v0.16b\n"
+ "sub x20, x28, #0x8\n"
+ "subs x24, x24, #0x1\n"
+ "add x28, x28, #0x48\n"
+ ".inst 0x4e88a773 // smmla v19.4s, v27.16b, v8.16b\n"
+ ".inst 0x4e9fa77a // smmla v26.4s, v27.16b, v31.16b\n"
+ "ldr q27, [x25, #0x20]\n"
+ ".inst 0x4e88a5f2 // smmla v18.4s, v15.16b, v8.16b\n"
+ ".inst 0x4e9fa5f1 // smmla v17.4s, v15.16b, v31.16b\n"
+ "sshl v15.16b, v29.16b, v1.16b\n"
+ "sshl v1.16b, v3.16b, v1.16b\n"
+ "and v29.16b, v29.16b, v0.16b\n"
+ "and v3.16b, v3.16b, v0.16b\n"
+ "ldr q0, [x25, #0x30]\n"
+ "fcvtl v20.4s, v20.4h\n"
+ ".inst 0x4e8fa773 // smmla v19.4s, v27.16b, v15.16b\n"
+ "fcvtl v9.4s, v9.4h\n"
+ ".inst 0x4e81a77a // smmla v26.4s, v27.16b, v1.16b\n"
+ "ldr q27, [x25, #0x40]\n"
+ ".inst 0x4e8fa412 // smmla v18.4s, v0.16b, v15.16b\n"
+ ".inst 0x4e81a411 // smmla v17.4s, v0.16b, v1.16b\n"
+ "ldr q0, [x25, #0x50]\n"
+ ".inst 0x4e95a773 // smmla v19.4s, v27.16b, v21.16b\n"
+ ".inst 0x4e90a77a // smmla v26.4s, v27.16b, v16.16b\n"
+ "ldr q27, [x25, #0x60]\n"
+ ".inst 0x4e95a412 // smmla v18.4s, v0.16b, v21.16b\n"
+ ".inst 0x4e90a411 // smmla v17.4s, v0.16b, v16.16b\n"
+ "ldr q0, [x25, #0x70]\n"
+ "add x25, x25, #0x88\n"
+ ".inst 0x4e9da773 // smmla v19.4s, v27.16b, v29.16b\n"
+ ".inst 0x4e83a77a // smmla v26.4s, v27.16b, v3.16b\n"
+ "ldr d27, [x20, #0x0]\n"
+ ".inst 0x4e9da412 // smmla v18.4s, v0.16b, v29.16b\n"
+ ".inst 0x4e83a411 // smmla v17.4s, v0.16b, v3.16b\n"
+ "fcvtl v27.4s, v27.4h\n"
+ "uzp1 v0.2d, v19.2d, v26.2d\n"
+ "uzp2 v26.2d, v19.2d, v26.2d\n"
+ "fmul v19.4s, v27.4s, v20.s[0]\n"
+ "scvtf v0.4s, v0.4s, #0x4\n"
+ "scvtf v26.4s, v26.4s, #0x4\n"
+ "fmla v2.4s, v0.4s, v19.4s\n"
+ "ldr q19, [x23, #0x0]\n"
+ "uzp1 v0.2d, v18.2d, v17.2d\n"
+ "uzp2 v18.2d, v18.2d, v17.2d\n"
+ "fmul v17.4s, v27.4s, v20.s[1]\n"
+ "scvtf v0.4s, v0.4s, #0x4\n"
+ "scvtf v18.4s, v18.4s, #0x4\n"
+ "fmla v10.4s, v26.4s, v17.4s\n"
+ "ldr q17, [x23, #0x10]\n"
+ "fmul v26.4s, v27.4s, v20.s[2]\n"
+ "fmul v20.4s, v27.4s, v20.s[3]\n"
+ "fmla v12.4s, v0.4s, v26.4s\n"
+ "ldr d0, [x22, #-0x8]\n"
+ "ldr d26, [x21, #-0x8]\n"
+ "fcvtl v0.4s, v0.4h\n"
+ "fmla v28.4s, v18.4s, v20.4s\n"
+ "movi v20.4s, #0x0\n"
+ "movi v18.4s, #0x0\n"
+ ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n"
+ ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n"
+ "ldr q19, [x23, #0x20]\n"
+ "fcvtl v26.4s, v26.4h\n"
+ ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n"
+ ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n"
+ "ldr q19, [x23, #0x40]\n"
+ ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n"
+ ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n"
+ "ldr q19, [x23, #0x60]\n"
+ ".inst 0x4e9da674 // smmla v20.4s, v19.16b, v29.16b\n"
+ ".inst 0x4e83a672 // smmla v18.4s, v19.16b, v3.16b\n"
+ "uzp1 v19.2d, v20.2d, v18.2d\n"
+ "scvtf v19.4s, v19.4s, #0x4\n"
+ "uzp2 v20.2d, v20.2d, v18.2d\n"
+ "fmul v18.4s, v27.4s, v9.s[0]\n"
+ "scvtf v20.4s, v20.4s, #0x4\n"
+ "fmla v11.4s, v19.4s, v18.4s\n"
+ "ldr q18, [x22, #0x0]\n"
+ "fmul v19.4s, v27.4s, v9.s[1]\n"
+ "fmla v13.4s, v20.4s, v19.4s\n"
+ "movi v19.4s, #0x0\n"
+ "movi v20.4s, #0x0\n"
+ ".inst 0x4e88a633 // smmla v19.4s, v17.16b, v8.16b\n"
+ ".inst 0x4e9fa634 // smmla v20.4s, v17.16b, v31.16b\n"
+ "ldr q17, [x23, #0x30]\n"
+ ".inst 0x4e8fa633 // smmla v19.4s, v17.16b, v15.16b\n"
+ ".inst 0x4e81a634 // smmla v20.4s, v17.16b, v1.16b\n"
+ "ldr q17, [x23, #0x50]\n"
+ ".inst 0x4e95a633 // smmla v19.4s, v17.16b, v21.16b\n"
+ ".inst 0x4e90a634 // smmla v20.4s, v17.16b, v16.16b\n"
+ "ldr q17, [x23, #0x70]\n"
+ "add x23, x23, #0x88\n"
+ ".inst 0x4e9da633 // smmla v19.4s, v17.16b, v29.16b\n"
+ ".inst 0x4e83a634 // smmla v20.4s, v17.16b, v3.16b\n"
+ "uzp1 v17.2d, v19.2d, v20.2d\n"
+ "scvtf v17.4s, v17.4s, #0x4\n"
+ "uzp2 v20.2d, v19.2d, v20.2d\n"
+ "fmul v19.4s, v27.4s, v9.s[2]\n"
+ "fmul v9.4s, v27.4s, v9.s[3]\n"
+ "scvtf v20.4s, v20.4s, #0x4\n"
+ "fmla v22.4s, v17.4s, v19.4s\n"
+ "ldr q17, [x22, #0x10]\n"
+ "movi v19.4s, #0x0\n"
+ ".inst 0x4e88a653 // smmla v19.4s, v18.16b, v8.16b\n"
+ "fmla v23.4s, v20.4s, v9.4s\n"
+ "movi v20.4s, #0x0\n"
+ "movi v9.4s, #0x0\n"
+ ".inst 0x4e9fa654 // smmla v20.4s, v18.16b, v31.16b\n"
+ "ldr q18, [x22, #0x20]\n"
+ ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n"
+ ".inst 0x4e8fa653 // smmla v19.4s, v18.16b, v15.16b\n"
+ ".inst 0x4e81a654 // smmla v20.4s, v18.16b, v1.16b\n"
+ "ldr q18, [x22, #0x40]\n"
+ ".inst 0x4e95a653 // smmla v19.4s, v18.16b, v21.16b\n"
+ ".inst 0x4e90a654 // smmla v20.4s, v18.16b, v16.16b\n"
+ "ldr q18, [x22, #0x60]\n"
+ ".inst 0x4e9da653 // smmla v19.4s, v18.16b, v29.16b\n"
+ ".inst 0x4e83a654 // smmla v20.4s, v18.16b, v3.16b\n"
+ "movi v18.4s, #0x0\n"
+ ".inst 0x4e9fa632 // smmla v18.4s, v17.16b, v31.16b\n"
+ "ldr q17, [x22, #0x30]\n"
+ ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n"
+ ".inst 0x4e81a632 // smmla v18.4s, v17.16b, v1.16b\n"
+ "ldr q17, [x22, #0x50]\n"
+ ".inst 0x4e95a629 // smmla v9.4s, v17.16b, v21.16b\n"
+ ".inst 0x4e90a632 // smmla v18.4s, v17.16b, v16.16b\n"
+ "ldr q17, [x22, #0x70]\n"
+ "add x22, x22, #0x88\n"
+ ".inst 0x4e9da629 // smmla v9.4s, v17.16b, v29.16b\n"
+ ".inst 0x4e83a632 // smmla v18.4s, v17.16b, v3.16b\n"
+ "uzp1 v17.2d, v19.2d, v20.2d\n"
+ "uzp2 v20.2d, v19.2d, v20.2d\n"
+ "fmul v19.4s, v27.4s, v0.s[0]\n"
+ "scvtf v17.4s, v17.4s, #0x4\n"
+ "scvtf v20.4s, v20.4s, #0x4\n"
+ "fmla v25.4s, v17.4s, v19.4s\n"
+ "ldr q19, [x21, #0x0]\n"
+ "fmul v17.4s, v27.4s, v0.s[1]\n"
+ "fmla v5.4s, v20.4s, v17.4s\n"
+ "ldr q17, [x21, #0x10]\n"
+ "uzp1 v20.2d, v9.2d, v18.2d\n"
+ "uzp2 v9.2d, v9.2d, v18.2d\n"
+ "fmul v18.4s, v27.4s, v0.s[2]\n"
+ "fmul v0.4s, v27.4s, v0.s[3]\n"
+ "scvtf v20.4s, v20.4s, #0x4\n"
+ "scvtf v9.4s, v9.4s, #0x4\n"
+ "fmla v7.4s, v20.4s, v18.4s\n"
+ "movi v20.4s, #0x0\n"
+ "movi v18.4s, #0x0\n"
+ ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n"
+ ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n"
+ "ldr q19, [x21, #0x20]\n"
+ "fmla v4.4s, v9.4s, v0.4s\n"
+ "movi v9.4s, #0x0\n"
+ "movi v0.4s, #0x0\n"
+ ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n"
+ "fmul v8.4s, v27.4s, v26.s[0]\n"
+ ".inst 0x4e9fa620 // smmla v0.4s, v17.16b, v31.16b\n"
+ "ldr q17, [x21, #0x30]\n"
+ ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n"
+ "fmul v31.4s, v27.4s, v26.s[1]\n"
+ ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n"
+ "ldr q19, [x21, #0x40]\n"
+ ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n"
+ "fmul v15.4s, v27.4s, v26.s[2]\n"
+ "fmul v27.4s, v27.4s, v26.s[3]\n"
+ ".inst 0x4e81a620 // smmla v0.4s, v17.16b, v1.16b\n"
+ "ldr q1, [x21, #0x50]\n"
+ ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n"
+ ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n"
+ "ldr q26, [x21, #0x60]\n"
+ ".inst 0x4e95a429 // smmla v9.4s, v1.16b, v21.16b\n"
+ ".inst 0x4e90a420 // smmla v0.4s, v1.16b, v16.16b\n"
+ "ldr q21, [x21, #0x70]\n"
+ "add x21, x21, #0x88\n"
+ ".inst 0x4e9da754 // smmla v20.4s, v26.16b, v29.16b\n"
+ ".inst 0x4e83a752 // smmla v18.4s, v26.16b, v3.16b\n"
+ ".inst 0x4e9da6a9 // smmla v9.4s, v21.16b, v29.16b\n"
+ ".inst 0x4e83a6a0 // smmla v0.4s, v21.16b, v3.16b\n"
+ "uzp1 v29.2d, v20.2d, v18.2d\n"
+ "uzp2 v21.2d, v20.2d, v18.2d\n"
+ "scvtf v29.4s, v29.4s, #0x4\n"
+ "uzp1 v18.2d, v9.2d, v0.2d\n"
+ "uzp2 v16.2d, v9.2d, v0.2d\n"
+ "scvtf v21.4s, v21.4s, #0x4\n"
+ "fmla v6.4s, v29.4s, v8.4s\n"
+ "scvtf v18.4s, v18.4s, #0x4\n"
+ "scvtf v16.4s, v16.4s, #0x4\n"
+ "fmla v30.4s, v21.4s, v31.4s\n"
+ "fmla v24.4s, v18.4s, v15.4s\n"
+ "fmla v14.4s, v16.4s, v27.4s\n"
+ "bgt 3b\n"
+ "mov x20, %x[res_ptr]\n"
+ "subs x27, x27, #0x4\n"
+ "add %x[res_ptr], %x[res_ptr], #0x10\n"
+ "str q2, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q10, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q12, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q28, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q11, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q13, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q22, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q23, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q25, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q5, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q7, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q4, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q6, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q30, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q24, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q14, [x20, #0x0]\n"
+ "bne 2b\n"
+ "mov x20, #0x4\n"
+ "sub x10, x10, #0x10\n"
+ "cmp x10, #0x10\n"
+ "mov %x[res_ptr], x26\n"
+ "madd %x[a_ptr], x20, x9, %x[a_ptr]\n"
+ "bge 1b\n"
+ "4:" // Row loop skip
+ "cbz x10, 9f\n"
+ "5:" // Row tail: Row loop
+ "add x24, %x[b_ptr], #0x8\n"
+ "mov x23, %x[nc]\n"
+ "add x22, %x[res_ptr], %x[res_stride], LSL #2\n"
+ "6:" // Row tail: Column loop
+ "movi v2.16b, #0x0\n"
+ "movi v10.16b, #0x0\n"
+ "add x25, %x[a_ptr], #0x8\n"
+ "mov x21, %x[nb]\n"
+ "movi v12.16b, #0x0\n"
+ "movi v28.16b, #0x0\n"
+ "7:" // Row tail: Block loop
+ "ldr q6, [x24, #0x0]\n"
+ "ldr q5, [x24, #0x10]\n"
+ "movi v17.16b, #0x4\n"
+ "movi v8.4s, #0x0\n"
+ "ldr q4, [x25, #0x0]\n"
+ "ldr q13, [x25, #0x10]\n"
+ "movi v27.4s, #0x0\n"
+ "movi v0.4s, #0x0\n"
+ "ldr q31, [x24, #0x20]\n"
+ "ldr q14, [x24, #0x30]\n"
+ "movi v29.4s, #0x0\n"
+ "movi v22.16b, #0xf0\n"
+ "ldr q11, [x25, #0x20]\n"
+ "ldr q23, [x25, #0x30]\n"
+ "sshl v21.16b, v6.16b, v17.16b\n"
+ "sshl v16.16b, v5.16b, v17.16b\n"
+ "ldr q20, [x25, #0x40]\n"
+ "ldr q26, [x25, #0x50]\n"
+ "and v6.16b, v6.16b, v22.16b\n"
+ "and v5.16b, v5.16b, v22.16b\n"
+ "ldr q25, [x25, #0x60]\n"
+ "ldr q3, [x25, #0x70]\n"
+ "sshl v19.16b, v31.16b, v17.16b\n"
+ "sshl v18.16b, v14.16b, v17.16b\n"
+ "ldr d17, [x25, #-0x8]\n"
+ ".inst 0x4e95a488 // smmla v8.4s, v4.16b, v21.16b\n"
+ ".inst 0x4e90a49b // smmla v27.4s, v4.16b, v16.16b\n"
+ "and v31.16b, v31.16b, v22.16b\n"
+ ".inst 0x4e95a5a0 // smmla v0.4s, v13.16b, v21.16b\n"
+ ".inst 0x4e90a5bd // smmla v29.4s, v13.16b, v16.16b\n"
+ "and v14.16b, v14.16b, v22.16b\n"
+ "sub x20, x24, #0x8\n"
+ "ldr d16, [x20, #0x0]\n"
+ "subs x21, x21, #0x1\n"
+ "add x25, x25, #0x88\n"
+ "fcvtl v17.4s, v17.4h\n"
+ "add x24, x24, #0x48\n"
+ ".inst 0x4e93a568 // smmla v8.4s, v11.16b, v19.16b\n"
+ ".inst 0x4e92a57b // smmla v27.4s, v11.16b, v18.16b\n"
+ ".inst 0x4e93a6e0 // smmla v0.4s, v23.16b, v19.16b\n"
+ ".inst 0x4e92a6fd // smmla v29.4s, v23.16b, v18.16b\n"
+ "fcvtl v16.4s, v16.4h\n"
+ ".inst 0x4e86a688 // smmla v8.4s, v20.16b, v6.16b\n"
+ ".inst 0x4e85a69b // smmla v27.4s, v20.16b, v5.16b\n"
+ "fmul v23.4s, v16.4s, v17.s[0]\n"
+ "fmul v21.4s, v16.4s, v17.s[1]\n"
+ "fmul v1.4s, v16.4s, v17.s[2]\n"
+ "fmul v20.4s, v16.4s, v17.s[3]\n"
+ ".inst 0x4e86a740 // smmla v0.4s, v26.16b, v6.16b\n"
+ ".inst 0x4e85a75d // smmla v29.4s, v26.16b, v5.16b\n"
+ ".inst 0x4e9fa728 // smmla v8.4s, v25.16b, v31.16b\n"
+ ".inst 0x4e8ea73b // smmla v27.4s, v25.16b, v14.16b\n"
+ ".inst 0x4e9fa460 // smmla v0.4s, v3.16b, v31.16b\n"
+ ".inst 0x4e8ea47d // smmla v29.4s, v3.16b, v14.16b\n"
+ "uzp1 v19.2d, v8.2d, v27.2d\n"
+ "uzp2 v18.2d, v8.2d, v27.2d\n"
+ "scvtf v19.4s, v19.4s, #0x4\n"
+ "uzp1 v17.2d, v0.2d, v29.2d\n"
+ "uzp2 v16.2d, v0.2d, v29.2d\n"
+ "scvtf v18.4s, v18.4s, #0x4\n"
+ "fmla v2.4s, v19.4s, v23.4s\n"
+ "scvtf v17.4s, v17.4s, #0x4\n"
+ "scvtf v16.4s, v16.4s, #0x4\n"
+ "fmla v10.4s, v18.4s, v21.4s\n"
+ "fmla v12.4s, v17.4s, v1.4s\n"
+ "fmla v28.4s, v16.4s, v20.4s\n"
+ "bgt 7b\n"
+ "mov x20, %x[res_ptr]\n"
+ "cmp x10, #0x1\n"
+ "str q2, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "ble 8f\n"
+ "cmp x10, #0x2\n"
+ "str q10, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "ble 8f\n"
+ "cmp x10, #0x3\n"
+ "str q12, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "ble 8f\n"
+ "str q28, [x20, #0x0]\n"
+ "8:" // Row tail: Accumulator store skip
+ "subs x23, x23, #0x4\n"
+ "add %x[res_ptr], %x[res_ptr], #0x10\n"
+ "bne 6b\n"
+ "subs x10, x10, #0x4\n"
+ "add %x[a_ptr], %x[a_ptr], x9\n"
+ "mov %x[res_ptr], x22\n"
+ "bgt 5b\n"
+ "9:" // Row tail: Row loop skip
+ : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr)
+ : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
+ : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
+ );
+ return;
+ }
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
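+    // Portable scalar path, shared by non-NEON builds and the runtime fallback.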
float sumf[4][4];
int sumi;
@@ -1944,7 +1974,7 @@ void ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void *
const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
- (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
+ (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
}
sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
}
@@ -1957,7 +1987,6 @@ void ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void *
}
}
}
-#endif
}
void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
@@ -1980,8 +2009,9 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
UNUSED(ncols_interleaved);
UNUSED(blocklen);
-#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) && ! ((defined(_MSC_VER)) && ! defined(__clang__))
- if (ggml_sve_cnt_b == QK8_0) {
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
+#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
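+    // Besides SVE and int8 matmul support, the kernel requires a 256-bit SVE
+    // vector length at run time (ggml_cpu_get_sve_cnt() == QK8_0, i.e. 32 bytes).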
+ if (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0) {
const void * b_ptr = vx;
const void * a_ptr = vy;
float * res_ptr = s;
@@ -2391,134 +2421,682 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
);
return;
}
- else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
- GGML_ASSERT((ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) &&
- "__ARM_FEATURE_SVE for vector size of 256-bits not defined, use the Q4_0_4_8 quantization format for optimal "
- "performance");
- }
- else if (ggml_cpu_has_neon()) {
- GGML_ASSERT(((ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) || ggml_cpu_has_matmul_int8()) &&
- "__ARM_FEATURE_SVE for vector size of 256-bits and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 "
- "quantization format for optimal performance");
- }
-#endif
-#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
- GGML_ASSERT(ggml_cpu_has_sve() &&
- "__ARM_FEATURE_SVE not defined, use the Q4_0_4_8 quantization format for optimal performance");
-#elif defined(__ARM_NEON) && defined(__aarch64__)
- GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) &&
- "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal "
- "performance");
+#endif // #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
#elif defined(__AVX2__) || defined(__AVX512F__)
- const block_q4_0x8 * b_ptr_start = (const block_q4_0x8 *)vx;
- const block_q8_0x4 * a_ptr_start = (const block_q8_0x4 *)vy;
- int64_t b_nb = n / QK4_0;
- int64_t y = 0;
- // Mask to mask out nibbles from packed bytes
- const __m256i m4b = _mm256_set1_epi8(0x0F);
- const __m128i loadMask = _mm_blend_epi32(_mm_setzero_si128(), _mm_set1_epi32(0xFFFFFFFF), 3);
- // Lookup table to convert signed nibbles to signed bytes
- __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
- signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
- // Permute mask used for easier vector processing at later stages
- __m256i requiredOrder = _mm256_set_epi32(3 ,2 ,1 ,0, 7 ,6, 5, 4);
+ {
+ const block_q4_0x8 * b_ptr_start = (const block_q4_0x8 *)vx;
+ const block_q8_0x4 * a_ptr_start = (const block_q8_0x4 *)vy;
+ int64_t b_nb = n / QK4_0;
+ int64_t y = 0;
+        // Mask to isolate the nibbles in each packed byte
+ const __m256i m4b = _mm256_set1_epi8(0x0F);
+ const __m128i loadMask = _mm_blend_epi32(_mm_setzero_si128(), _mm_set1_epi32(0xFFFFFFFF), 3);
+ // Lookup table to convert signed nibbles to signed bytes
+ __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
+ signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
+ // Permute mask used for easier vector processing at later stages
+ __m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
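+        // First column block for the 256-bit remainder loop; the AVX512 section
+        // below is expected to advance it past the columns it already handles.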
+ int64_t xstart = 0;
+        int anr = nr - nr % 16; // nr rounded down to a multiple of 16
+ #ifdef __AVX512F__
+        int anc = nc - nc % 16; // nc rounded down to a multiple of 16
+        // Mask to isolate nibbles from packed bytes, expanded to 512-bit length
+        const __m512i m4bexpanded = _mm512_set1_epi8(0x0F);
+        // Lookup table to convert signed nibbles to signed bytes, expanded to 512-bit length
+ __m512i signextendlutexpanded = _mm512_inserti32x8(_mm512_castsi256_si512(signextendlut), signextendlut, 1);
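+        // The 512-bit variant pairs two adjacent block_q4_0x8 column groups per
+        // iteration, producing 16 output columns at a time.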
- // Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation
- int anr = nr - nr %16; // Used to align nr with boundary of 16
+ // Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation
+ for (; y < anr / 4; y += 4) {
- for (; y < anr / 4; y += 4) {
- const block_q8_0x4 * a_ptrs[4];
+ const block_q8_0x4 * a_ptrs[4];
- a_ptrs[0] = a_ptr_start + (y * nb);
- for (int i = 0; i < 3; ++i) {
- a_ptrs[i + 1] = a_ptrs[i] + nb;
- }
-
- // Take group of eight block_q4_0x8 structures at each pass of the loop and perform dot product operation
- for (int64_t x = 0; x < nc / 8; x++) {
-
- const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb);
-
- // Master FP accumulators
- __m256 acc_rows[16];
- for (int i = 0; i < 16; i++) {
- acc_rows[i] = _mm256_setzero_ps();
+ a_ptrs[0] = a_ptr_start + (y * nb);
+ for (int i = 0; i < 3; ++i) {
+ a_ptrs[i + 1] = a_ptrs[i] + nb;
}
- for (int64_t b = 0; b < nb; b++) {
- // Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
- const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs));
- const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32));
- const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64));
- const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96));
+ // Take group of two block_q4_0x8 structures at each pass of the loop and perform dot product operation
+ for (int64_t x = 0; x < anc / 8; x += 2) {
- // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
- const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
- const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
- const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
- const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
+ const block_q4_0x8 * b_ptr_0 = b_ptr_start + ((x) * b_nb);
+ const block_q4_0x8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb);
- // 4-bit -> 8-bit - Sign is maintained
- const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7)
- const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7)
+ // Master FP accumulators
+ __m512 acc_rows[16];
+ for (int i = 0; i < 16; i++) {
+ acc_rows[i] = _mm512_setzero_ps();
+ }
- const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15)
- const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15)
+ for (int64_t b = 0; b < nb; b++) {
+                // Load the sixteen block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ... BE,BF
+ const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs));
+ const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 32));
+ const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 64));
+ const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 96));
- const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23)
- const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23)
+ const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs));
+ const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 32));
+ const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 64));
+ const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 96));
- const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31)
- const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31)
+ // Save the values in the following vectors in the formats B0B1B4B5B8B9BCBD, B2B3B6B7BABBBEBF for further processing and storing of values
+ const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
+ const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
+ const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
+ const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
- // Shuffle pattern one - right side input
- const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3)
- const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3)
+ const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240);
+ const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240);
+ const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240);
+ const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240);
- const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11)
- const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11)
+ const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1);
+ const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1);
+ const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1);
+ const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1);
- const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19)
- const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19)
+ // 4-bit -> 8-bit - Sign is maintained
+ const __m512i rhs_mat_014589CD_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) B8(0-7) B9(0-7) BC(0-7) BD(0-7)
+ const __m512i rhs_mat_2367ABEF_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) BA(0-7) BB(0-7) BE(0-7) BF(0-7)
- const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27)
- const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27)
+ const __m512i rhs_mat_014589CD_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) B8(8-15) B9(8-15) BC(8-15) BD(8-15)
+ const __m512i rhs_mat_2367ABEF_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) BA(8-15) BB(8-15) BE(8-15) BF(8-15)
- // Shuffle pattern two - right side input
+ const __m512i rhs_mat_014589CD_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) B8(16-23) B9(16-23) BC(16-23) BD(16-23)
+ const __m512i rhs_mat_2367ABEF_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) BA(16-23) BB(16-23) BE(16-23) BF(16-23)
- const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7)
- const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7)
+ const __m512i rhs_mat_014589CD_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) B8(24-31) B9(24-31) BC(24-31) BD(24-31)
+ const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
- const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15)
- const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15)
+ // Shuffle pattern one - right side input
+ const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
+ const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
- const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23)
- const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23)
+ const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
+ const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
- const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31)
- const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31)
+ const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
+ const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
- // Scale values - Load the wight scale values of block_q4_0x8
- const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
+ const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
+ const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
+
+ // Shuffle pattern two - right side input
+
+ const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
+ const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
+
+ const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
+ const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
+
+ const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
+ const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
+
+ const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
+ const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
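+
+                            // Shuffle immediates 136 (0b10001000) and 221 (0b11011101) pick dwords {0,2,0,2} and
+                            // {1,3,1,3} of every 128 bit lane, so patterns one and two together cover each 4-byte
+                            // group of the block exactly once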
+
+                            // Scale values - Load the weight scale values of the two block_q4_0x8 structures
+ const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
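+                            // GGML_F32Cx8x2_LOAD converts the two groups of eight FP16 deltas into a single vector
+                            // of sixteen FP32 column scales, one per interleaved Q4_0 block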
+
+ // Process LHS in pairs of rows
+ for (int rp = 0; rp < 4; rp++) {
+
+                                // Load the four block_q8_0 quantized values of the LHS interleaved with each other in chunks of eight - A0,A1,A2,A3
+                                // Loaded as a set of 128 bit vectors, repeated into a 256 bit vector and then repeated again into a 512 bit vector
+ __m256i lhs_mat_ymm_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs)));
+ __m256i lhs_mat_ymm_01_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 0);
+ __m256i lhs_mat_ymm_23_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 17);
+ __m256i lhs_mat_ymm_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32)));
+ __m256i lhs_mat_ymm_01_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 0);
+ __m256i lhs_mat_ymm_23_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 17);
+ __m256i lhs_mat_ymm_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64)));
+ __m256i lhs_mat_ymm_01_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 0);
+ __m256i lhs_mat_ymm_23_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 17);
+ __m256i lhs_mat_ymm_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96)));
+ __m256i lhs_mat_ymm_01_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 0);
+ __m256i lhs_mat_ymm_23_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 17);
+
+ __m512i lhs_mat_01_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_0), lhs_mat_ymm_01_0, 1);
+ __m512i lhs_mat_23_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_0), lhs_mat_ymm_23_0, 1);
+ __m512i lhs_mat_01_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_1), lhs_mat_ymm_01_1, 1);
+ __m512i lhs_mat_23_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_1), lhs_mat_ymm_23_1, 1);
+ __m512i lhs_mat_01_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_2), lhs_mat_ymm_01_2, 1);
+ __m512i lhs_mat_23_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_2), lhs_mat_ymm_23_2, 1);
+ __m512i lhs_mat_01_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_3), lhs_mat_ymm_01_3, 1);
+ __m512i lhs_mat_23_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_3), lhs_mat_ymm_23_3, 1);
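+
+                                // Mirroring each 256 bit row pair into both halves of a 512 bit register lines the
+                                // same LHS rows up against all sixteen RHS columns held in the 512 bit B vectors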
+
+ // Shuffle pattern one - left side input
+
+ const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
+ const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
+
+ const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
+ const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
+
+ const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
+ const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
+
+ const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
+ const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
+
+ // Shuffle pattern two - left side input
+
+ const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
+ const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
+
+ const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
+ const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
+
+ const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
+ const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
+
+ const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
+ const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
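+
+                                // Immediates 160 (0b10100000) and 245 (0b11110101) broadcast dword pairs {0,0,2,2}
+                                // and {1,1,3,3}, duplicating each 4-byte group of A so it meets both B columns of a
+                                // 2x2 tile in the multiplies below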
+
+                                // The values arranged in shuffle patterns are operated upon with a dot product within each 32 bit lane, i.e. corresponding bytes are multiplied and the products accumulated into 32 bit integers within the lane
+                                // Resembles the MMLA instructions that produce 2x2 tiles in the Arm version
+ __m512i iacc_mat_00_sp1 =
+ _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1));
+ __m512i iacc_mat_01_sp1 =
+ _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1));
+ __m512i iacc_mat_10_sp1 =
+ _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1));
+ __m512i iacc_mat_11_sp1 =
+ _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1));
+ __m512i iacc_mat_00_sp2 =
+ _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2));
+ __m512i iacc_mat_01_sp2 =
+ _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2));
+ __m512i iacc_mat_10_sp2 =
+ _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2));
+ __m512i iacc_mat_11_sp2 =
+ _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2));
+
+ // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
+ __m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
+ __m512i iacc_mat_01 = _mm512_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
+ __m512i iacc_mat_10 = _mm512_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
+ __m512i iacc_mat_11 = _mm512_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
+
+
+ // Straighten out to make 4 row vectors
+ __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, 78));
+ __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01);
+ __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, 78));
+ __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11);
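+
+                                // Mask 0xCCCC selects dwords 2,3 of every group of four and shuffle immediate 78
+                                // (0b01001110) swaps 64 bit halves within each 128 bit lane, interleaving the 2x2
+                                // tile outputs back into plain per-row vectors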
+
+                                // Load the scale(d) values for all the 4 Q8_0 blocks and repeat them across lanes
+ const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptrs[rp][b].d), loadMask), 68);
+ const __m512 row_scale_f32 = GGML_F32Cx16_REPEAT_LOAD(row_scale_f16);
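+                                // The maskload reads the four FP16 row deltas (8 bytes), shuffle immediate 68
+                                // (0b01000100) duplicates that 64 bit group, and GGML_F32Cx16_REPEAT_LOAD expands
+                                // the result to sixteen FP32 values repeated across the register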
+
+                                // Multiply with appropriate scales and accumulate
+ acc_rows[rp * 4] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);
+ acc_rows[rp * 4 + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]);
+ acc_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
+ acc_rows[rp * 4 + 3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]);
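+
+                                // Shuffle immediates 0, 85 (0b01010101), 170 and 255 broadcast elements 0..3 of the
+                                // row scale vector, so each output row is scaled by its own Q8_0 delta times the
+                                // sixteen column deltas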
+ }
+ }
+
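+                        // Each of the 16 accumulators holds one output row across the sixteen columns covered by
+                        // the b_ptr_0/b_ptr_1 pair, so consecutive rows are stored bs floats apart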
+ // Store the accumulated values
+ for (int i = 0; i < 16; i++) {
+ _mm512_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
+ }
+ }
+ }
+                // Take a single block_q8_0x4 structure at each pass of the loop and perform dot product operation
+                for (; y < nr / 4; y++) {
+
+ const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb);
+
+ // Take group of two block_q4_0x8 structures at each pass of the loop and perform dot product operation
+ for (int64_t x = 0; x < anc / 8; x += 2) {
+
+ const block_q4_0x8 * b_ptr_0 = b_ptr_start + ((x) * b_nb);
+ const block_q4_0x8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb);
+
+ // Master FP accumulators
+ __m512 acc_rows[4];
+ for (int i = 0; i < 4; i++) {
+ acc_rows[i] = _mm512_setzero_ps();
+ }
+
+ for (int64_t b = 0; b < nb; b++) {
+ // Load the sixteen block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....BE,BF
+ const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs));
+ const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 32));
+ const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 64));
+ const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 96));
+
+ const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs));
+ const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 32));
+ const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 64));
+ const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 96));
+
+                            // Save the values in the following vectors in the formats B0B1B4B5B8B9BCBD, B2B3B6B7BABBBEBF for further processing and storing of values
+ const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
+ const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
+ const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
+ const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
+
+ const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240);
+ const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240);
+ const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240);
+ const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240);
+
+ const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1);
+ const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1);
+ const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1);
+ const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1);
+
+ // 4-bit -> 8-bit - Sign is maintained
+ const __m512i rhs_mat_014589CD_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) B8(0-7) B9(0-7) BC(0-7) BD(0-7)
+ const __m512i rhs_mat_2367ABEF_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) BA(0-7) BB(0-7) BE(0-7) BF(0-7)
+
+ const __m512i rhs_mat_014589CD_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) B8(8-15) B9(8-15) BC(8-15) BD(8-15)
+ const __m512i rhs_mat_2367ABEF_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) BA(8-15) BB(8-15) BE(8-15) BF(8-15)
+
+ const __m512i rhs_mat_014589CD_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) B8(16-23) B9(16-23) BC(16-23) BD(16-23)
+ const __m512i rhs_mat_2367ABEF_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) BA(16-23) BB(16-23) BE(16-23) BF(16-23)
+
+ const __m512i rhs_mat_014589CD_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) B8(24-31) B9(24-31) BC(24-31) BD(24-31)
+ const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
+
+ // Shuffle pattern one - right side input
+ const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
+ const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
+
+ const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
+ const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
+
+ const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
+ const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
+
+ const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
+ const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
+
+ // Shuffle pattern two - right side input
+
+ const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
+ const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
+
+ const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
+ const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
+
+ const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
+ const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
+
+ const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
+ const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
+
+
+ // Scale values - Load the weight scale values of two block_q4_0x8
+ const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
+
+                            // Load the four block_q8_0 quantized values of the LHS interleaved with each other in chunks of eight - A0,A1,A2,A3
+                            // Loaded as a set of 128 bit vectors, repeated into a 256 bit vector and then repeated again into a 512 bit vector
+ __m256i lhs_mat_ymm_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs)));
+ __m256i lhs_mat_ymm_01_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 0);
+ __m256i lhs_mat_ymm_23_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 17);
+ __m256i lhs_mat_ymm_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32)));
+ __m256i lhs_mat_ymm_01_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 0);
+ __m256i lhs_mat_ymm_23_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 17);
+ __m256i lhs_mat_ymm_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64)));
+ __m256i lhs_mat_ymm_01_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 0);
+ __m256i lhs_mat_ymm_23_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 17);
+ __m256i lhs_mat_ymm_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96)));
+ __m256i lhs_mat_ymm_01_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 0);
+ __m256i lhs_mat_ymm_23_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 17);
+
+ __m512i lhs_mat_01_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_0), lhs_mat_ymm_01_0, 1);
+ __m512i lhs_mat_23_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_0), lhs_mat_ymm_23_0, 1);
+ __m512i lhs_mat_01_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_1), lhs_mat_ymm_01_1, 1);
+ __m512i lhs_mat_23_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_1), lhs_mat_ymm_23_1, 1);
+ __m512i lhs_mat_01_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_2), lhs_mat_ymm_01_2, 1);
+ __m512i lhs_mat_23_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_2), lhs_mat_ymm_23_2, 1);
+ __m512i lhs_mat_01_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_3), lhs_mat_ymm_01_3, 1);
+ __m512i lhs_mat_23_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_3), lhs_mat_ymm_23_3, 1);
+
+ // Shuffle pattern one - left side input
+
+ const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
+ const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
+
+ const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
+ const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
+
+ const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
+ const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
+
+ const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
+ const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
+
+ // Shuffle pattern two - left side input
+
+ const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
+ const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
+
+ const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
+ const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
+
+ const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
+ const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
+
+ const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
+ const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
+
+                            // The values arranged in shuffle patterns are operated upon with a dot product within each 32 bit lane, i.e. corresponding bytes are multiplied and the products accumulated into 32 bit integers within the lane
+                            // Resembles the MMLA instructions that produce 2x2 tiles in the Arm version
+ __m512i iacc_mat_00_sp1 =
+ _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1));
+ __m512i iacc_mat_01_sp1 =
+ _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1));
+ __m512i iacc_mat_10_sp1 =
+ _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1));
+ __m512i iacc_mat_11_sp1 =
+ _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1));
+ __m512i iacc_mat_00_sp2 =
+ _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2));
+ __m512i iacc_mat_01_sp2 =
+ _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2));
+ __m512i iacc_mat_10_sp2 =
+ _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2));
+ __m512i iacc_mat_11_sp2 =
+ _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2));
+
+ // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
+ __m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
+ __m512i iacc_mat_01 = _mm512_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
+ __m512i iacc_mat_10 = _mm512_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
+ __m512i iacc_mat_11 = _mm512_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
+
+
+ // Straighten out to make 4 row vectors
+ __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, 78));
+ __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01);
+ __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, 78));
+ __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11);
+
+                            // Load the scale(d) values for all the 4 Q8_0 blocks and repeat them across lanes
+ const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptr[b].d), loadMask), 68);
+ const __m512 row_scale_f32 = GGML_F32Cx16_REPEAT_LOAD(row_scale_f16);
+
+                            // Multiply with appropriate scales and accumulate
+ acc_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]);
+ acc_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]);
+ acc_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
+ acc_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);
+ }
+
+ // Store the accumulated values
+ for (int i = 0; i < 4; i++) {
+ _mm512_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
+ }
+ }
+ }
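+
+                // When nc is not a multiple of sixteen the AVX512 path above leaves a final 8-wide column
+                // block unprocessed; xstart skips the columns already produced and y is reset so the AVX2
+                // loops below cover every row group for the remaining columns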
+ if (anc != nc) {
+ xstart = anc/8;
+ y = 0;
+ }
+ #endif // __AVX512F__
+
+ // Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation
+
+ for (; y < anr / 4; y += 4) {
+ const block_q8_0x4 * a_ptrs[4];
+
+ a_ptrs[0] = a_ptr_start + (y * nb);
+ for (int i = 0; i < 3; ++i) {
+ a_ptrs[i + 1] = a_ptrs[i] + nb;
+ }
+
+ // Take group of eight block_q4_0x8 structures at each pass of the loop and perform dot product operation
+ for (int64_t x = xstart; x < nc / 8; x++) {
+
+ const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb);
+
+ // Master FP accumulators
+ __m256 acc_rows[16];
+ for (int i = 0; i < 16; i++) {
+ acc_rows[i] = _mm256_setzero_ps();
+ }
+
+ for (int64_t b = 0; b < nb; b++) {
+ // Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
+ const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs));
+ const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32));
+ const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64));
+ const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96));
+
+ // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
+ const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
+ const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
+ const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
+ const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
+
+ // 4-bit -> 8-bit - Sign is maintained
+ const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7)
+ const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7)
+
+ const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15)
+ const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15)
+
+ const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23)
+ const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23)
+
+ const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31)
+ const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31)
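+
+                            // m4b keeps the low nibble of every byte so it is a valid 0..15 index for the pshufb
+                            // lookup, which maps each nibble to its sign-extended 8-bit value; the 4-bit right
+                            // shift before the second pair of lookups exposes the high nibbles the same way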
+
+ // Shuffle pattern one - right side input
+ const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3)
+ const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3)
+
+ const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11)
+ const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11)
+
+ const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19)
+ const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19)
+
+ const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27)
+ const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27)
+
+ // Shuffle pattern two - right side input
+
+ const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7)
+ const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7)
+
+ const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15)
+ const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15)
+
+ const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23)
+ const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23)
+
+ const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31)
+ const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31)
+
+                            // Scale values - Load the weight scale values of block_q4_0x8
+ const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
+
+ // Process LHS in groups of four
+ for (int rp = 0; rp < 4; rp++) {
+                                // Load the four block_q8_0 quantized values of the LHS interleaved with each other in chunks of eight - A0,A1,A2,A3
+                                // Loaded as a set of 128 bit vectors and repeated into a 256 bit vector
+ __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs)));
+ __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0);
+ __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17);
+ __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32)));
+ __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0);
+ __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17);
+ __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64)));
+ __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0);
+ __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17);
+ __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96)));
+ __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0);
+ __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17);
+
+ // Shuffle pattern one - left side input
+ const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
+ const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
+
+ const __m256i lhs_mat_01_1_sp1 = _mm256_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
+ const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
+
+ const __m256i lhs_mat_01_2_sp1 = _mm256_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
+ const __m256i lhs_mat_23_2_sp1 = _mm256_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
+
+ const __m256i lhs_mat_01_3_sp1 = _mm256_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
+ const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
+
+ // Shuffle pattern two - left side input
+ const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
+ const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
+
+ const __m256i lhs_mat_01_1_sp2 = _mm256_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
+ const __m256i lhs_mat_23_1_sp2 = _mm256_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
+
+ const __m256i lhs_mat_01_2_sp2 = _mm256_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
+ const __m256i lhs_mat_23_2_sp2 = _mm256_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
+
+ const __m256i lhs_mat_01_3_sp2 = _mm256_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
+ const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
+
+                                // The values arranged in shuffle patterns are operated upon with a dot product within each 32 bit lane, i.e. corresponding bytes are multiplied and the products accumulated into 32 bit integers within the lane
+                                // Resembles the MMLA instructions that produce 2x2 tiles in the Arm version
+ __m256i iacc_mat_00_sp1 =
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1));
+ __m256i iacc_mat_01_sp1 =
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1));
+ __m256i iacc_mat_10_sp1 =
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1));
+ __m256i iacc_mat_11_sp1 =
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1));
+ __m256i iacc_mat_00_sp2 =
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2));
+ __m256i iacc_mat_01_sp2 =
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2));
+ __m256i iacc_mat_10_sp2 =
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2));
+ __m256i iacc_mat_11_sp2 =
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2));
+
+ // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
+ __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
+ __m256i iacc_mat_01 = _mm256_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
+ __m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
+ __m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
+
+ // Straighten out to make 4 row vectors
+ __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204);
+ __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204);
+ __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204);
+ __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204);
+
+ // Load the scale values (d) for all 4 Q8_0 blocks and repeat them across lanes
+ const __m256 row_scale_f32 = GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask);
+
+ // Multiply with appropriate scales and accumulate
+ acc_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);
+ acc_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]);
+ acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
+ acc_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]);
+ }
+ }
+
+ // Store the accumulated values
+ for (int i = 0; i < 16; i++) {
+ _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
+ }
+ }
+ }
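Stripped of the interleaving and shuffle bookkeeping, each pass of the b loop above performs one 4x8 tile update: an integer dot product over the 32 values of a block, scaled by the per-row and per-column block scales. A minimal scalar reference, with illustrative names only (not ggml API):

```cpp
#include <cstdint>

// Reference for one (4 rows x 8 columns) tile update per quantized block:
// acc[r][c] += lhs_d[r] * rhs_d[c] * dot(lhs row r, rhs column c)
void tile_update_ref(float acc[4][8],
                     const int8_t lhs[4][32], const float lhs_d[4],   // 4 Q8_0 rows
                     const int8_t rhs[8][32], const float rhs_d[8]) { // 8 unpacked Q4_0 columns
    for (int r = 0; r < 4; r++) {
        for (int c = 0; c < 8; c++) {
            int32_t sum = 0;
            for (int k = 0; k < 32; k++) {
                sum += (int32_t) lhs[r][k] * (int32_t) rhs[c][k];
            }
            acc[r][c] += (float) sum * lhs_d[r] * rhs_d[c];
        }
    }
}
```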
+
+ // Take a block_q8_0x4 structure at each pass of the loop and perform the dot product operation
+ for (; y < nr / 4; y ++) {
+
+ const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb);
+
+ // Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
+ for (int64_t x = xstart; x < nc / 8; x++) {
+
+ const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb);
+
+ // Master FP accumulators
+ __m256 acc_rows[4];
+ for (int i = 0; i < 4; i++) {
+ acc_rows[i] = _mm256_setzero_ps();
+ }
+
+ for (int64_t b = 0; b < nb; b++) {
+ // Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
+ const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs));
+ const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32));
+ const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64));
+ const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96));
+
+ // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
+ const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
+ const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
+ const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
+ const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
+
+ // 4-bit -> 8-bit - Sign is maintained
+ const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7)
+ const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7)
+
+ const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15)
+ const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15)
+
+ const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23)
+ const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23)
+
+ const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31)
+ const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31)
+
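The 4-bit to 8-bit step avoids explicit shifts for sign handling by using _mm256_shuffle_epi8 as a 16-entry table lookup: each masked nibble indexes signextendlut, which reproduces 4-bit two's-complement sign extension (0..7 map to 0..7, 8..15 map to -8..-1). A scalar model of the same mapping, as a hypothetical standalone sketch:

```cpp
#include <cassert>
#include <cstdint>

// What the byte shuffle computes per lane: a 16-entry sign-extension table.
int8_t sign_extend_int4(uint8_t nibble) {            // nibble in 0..15
    static const int8_t lut[16] = { 0, 1, 2, 3, 4, 5, 6, 7,
                                   -8,-7,-6,-5,-4,-3,-2,-1 };
    return lut[nibble & 0x0F];
}

int main() {
    for (int v = 0; v < 16; v++) {
        // equivalent closed form for int4 sign extension
        assert(sign_extend_int4((uint8_t) v) == ((v ^ 8) - 8));
    }
    return 0;
}
```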
+ // Shuffle pattern one - right side input
+ const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3)
+ const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3)
+
+ const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11)
+ const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11)
+
+ const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19)
+ const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19)
+
+ const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27)
+ const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27)
+
+ // Shuffle pattern two - right side input
+
+ const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7)
+ const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7)
+
+ const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15)
+ const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15)
+
+ const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23)
+ const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23)
+
+ const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31)
+ const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31)
+
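The shuffle patterns split each register into its even and odd 32-bit groups so the later per-lane dot products line up 4-byte chunks of A against 4-byte chunks of B. Each 2-bit field of an _mm256_shuffle_epi32 immediate selects one 32-bit element within each 128-bit lane; a small decoder makes the immediates used here readable:

```cpp
#include <cstdint>
#include <cstdio>

// Decode a _mm256_shuffle_epi32 immediate into its per-lane element order.
static void decode(const char * name, uint8_t imm) {
    printf("%s (=%3u): {%d,%d,%d,%d}\n", name, imm,
           imm & 3, (imm >> 2) & 3, (imm >> 4) & 3, (imm >> 6) & 3);
}

int main() {
    decode("rhs sp1", 136); // {0,2,0,2}: even 4-byte groups -> bytes 0-3, 8-11, ...
    decode("rhs sp2", 221); // {1,3,1,3}: odd  4-byte groups -> bytes 4-7, 12-15, ...
    decode("lhs sp1", 160); // {0,0,2,2}: duplicates each row's even group
    decode("lhs sp2", 245); // {1,1,3,3}: duplicates each row's odd  group
    decode("rows",     78); // {2,3,0,1}: swaps 64-bit halves when straightening rows
    return 0;
}
```

The blend immediate 204 (0b11001100) used in the straightening step takes, per 128-bit lane, the low 64 bits from the first operand and the high 64 bits from the second.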
+ // Scale values - Load the weight scale values of block_q4_0x8
+ const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
- // Process LHS in groups of four
- for (int rp = 0; rp < 4; rp++) {
// Load the four block_q8_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
// Loaded as set of 128 bit vectors and repeated into a 256 bit vector
- __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs)));
+ __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs)));
__m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0);
__m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17);
- __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32)));
+ __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32)));
__m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0);
__m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17);
- __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64)));
+ __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64)));
__m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0);
__m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17);
- __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96)));
+ __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96)));
__m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0);
__m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17);
// Shuffle pattern one - left side input
+
const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
@@ -2532,6 +3110,7 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
// Shuffle pattern two - left side input
+
const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
@@ -2547,21 +3126,21 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
// The values arranged in the shuffle patterns are combined with a dot-product operation within each 32-bit lane, i.e. the corresponding bytes are multiplied and the products accumulated into a 32-bit integer per lane
// Resembles the MMLA instructions operating on 2x2 matrices in the ARM version
__m256i iacc_mat_00_sp1 =
- _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1));
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1));
__m256i iacc_mat_01_sp1 =
- _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1));
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1));
__m256i iacc_mat_10_sp1 =
- _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1));
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1));
__m256i iacc_mat_11_sp1 =
- _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1));
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1));
__m256i iacc_mat_00_sp2 =
- _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2));
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2));
__m256i iacc_mat_01_sp2 =
- _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2));
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2));
__m256i iacc_mat_10_sp2 =
- _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2));
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2));
__m256i iacc_mat_11_sp2 =
- _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2));
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2));
// Outputs of both shuffle patterns are added in order to sum the dot products of all 32 values in the block
__m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
@@ -2569,6 +3148,7 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
__m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
__m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
+
// Straighten out to make 4 row vectors
__m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204);
__m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204);
@@ -2576,187 +3156,24 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
__m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204);
// Load the scale values (d) for all 4 Q8_0 blocks and repeat them across lanes
- const __m256 row_scale_f32 = GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask);
+ const __m256 row_scale_f32 = GGML_F32Cx8_REPEAT_LOAD(a_ptr[b].d, loadMask);
// Multiply with appropriate scales and accumulate
- acc_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);
- acc_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]);
- acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
- acc_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]);
+ acc_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]);
+ acc_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]);
+ acc_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
+ acc_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);
+ }
+
+ // Store the accumulated values
+ for (int i = 0; i < 4; i++) {
+ _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
}
}
-
- // Store the accumulated values
- for (int i = 0; i < 16; i++) {
- _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
- }
}
+ return;
}
-
- // Take a block_q8_0x4 structures at each pass of the loop and perform dot product operation
- for (; y < nr / 4; y ++) {
-
- const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb);
-
- // Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
- for (int64_t x = 0; x < nc / 8; x++) {
-
- const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb);
-
- // Master FP accumulators
- __m256 acc_rows[4];
- for (int i = 0; i < 4; i++) {
- acc_rows[i] = _mm256_setzero_ps();
- }
-
- for (int64_t b = 0; b < nb; b++) {
- // Load the eight block_q8_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
- const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs));
- const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32));
- const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64));
- const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96));
-
- // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of valuess
- const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
- const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
- const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
- const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
-
- // 4-bit -> 8-bit - Sign is maintained
- const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7)
- const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7)
-
- const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15)
- const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15)
-
- const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23)
- const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23)
-
- const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31)
- const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31)
-
- // Shuffle pattern one - right side input
- const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3)
- const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3)
-
- const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11)
- const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11)
-
- const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19)
- const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19)
-
- const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27)
- const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27)
-
- // Shuffle pattern two - right side input
-
- const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7)
- const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7)
-
- const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15)
- const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15)
-
- const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23)
- const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23)
-
- const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31)
- const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31)
-
- // Scale values - Load the wight scale values of block_q4_0x8
- const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
-
- // Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
- // Loaded as set of 128 bit vectors and repeated into a 256 bit vector
- __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs)));
- __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0);
- __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17);
- __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32)));
- __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0);
- __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17);
- __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64)));
- __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0);
- __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17);
- __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96)));
- __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0);
- __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17);
-
- // Shuffle pattern one - left side input
-
- const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
- const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
-
- const __m256i lhs_mat_01_1_sp1 = _mm256_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
- const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
-
- const __m256i lhs_mat_01_2_sp1 = _mm256_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
- const __m256i lhs_mat_23_2_sp1 = _mm256_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
-
- const __m256i lhs_mat_01_3_sp1 = _mm256_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
- const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
-
- // Shuffle pattern two - left side input
-
- const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
- const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
-
- const __m256i lhs_mat_01_1_sp2 = _mm256_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
- const __m256i lhs_mat_23_1_sp2 = _mm256_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
-
- const __m256i lhs_mat_01_2_sp2 = _mm256_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
- const __m256i lhs_mat_23_2_sp2 = _mm256_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
-
- const __m256i lhs_mat_01_3_sp2 = _mm256_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
- const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
-
- // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
- // Resembles MMLAs into 2x2 matrices in ARM Version
- __m256i iacc_mat_00_sp1 =
- _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1));
- __m256i iacc_mat_01_sp1 =
- _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1));
- __m256i iacc_mat_10_sp1 =
- _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1));
- __m256i iacc_mat_11_sp1 =
- _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1));
- __m256i iacc_mat_00_sp2 =
- _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2));
- __m256i iacc_mat_01_sp2 =
- _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2));
- __m256i iacc_mat_10_sp2 =
- _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2));
- __m256i iacc_mat_11_sp2 =
- _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2));
-
- // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
- __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
- __m256i iacc_mat_01 = _mm256_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
- __m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
- __m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
-
-
- // Straighten out to make 4 row vectors
- __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204);
- __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204);
- __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204);
- __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204);
-
- // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
- const __m256 row_scale_f32 = GGML_F32Cx8_REPEAT_LOAD(a_ptr[b].d, loadMask);
-
- // Multiply with appropiate scales and accumulate
- acc_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]);
- acc_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]);
- acc_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
- acc_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);
- }
-
- // Store the accumulated values
- for (int i = 0; i < 4; i++) {
- _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
- }
- }
- }
-#else
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
float sumf[4][8];
int sumi;
@@ -2789,5 +3206,4 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
}
}
}
-#endif
}
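The integer dot products throughout the hunks above go through mul_sum_i8_pairs_int32x8. Assuming it follows the usual maddubs/madd reduction pattern, each of the eight 32-bit lanes receives the sum of the four signed-byte products in that lane; a scalar model of that assumed semantics:

```cpp
#include <cstdint>

// Assumed semantics of mul_sum_i8_pairs_int32x8: per 32-bit lane,
// out[lane] = sum over the four byte pairs in that lane of x*y.
void mul_sum_i8_groups_ref(const int8_t x[32], const int8_t y[32], int32_t out[8]) {
    for (int lane = 0; lane < 8; lane++) {
        int32_t sum = 0;
        for (int j = 0; j < 4; j++) {
            sum += (int32_t) x[4*lane + j] * (int32_t) y[4*lane + j];
        }
        out[lane] = sum;
    }
}
```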
diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c
index e485326ab..70187b9b6 100644
--- a/ggml/src/ggml-alloc.c
+++ b/ggml/src/ggml-alloc.c
@@ -294,6 +294,12 @@ static void ggml_dyn_tallocr_reset(struct ggml_dyn_tallocr * alloc) {
alloc->free_blocks[0].offset = 0;
alloc->free_blocks[0].size = SIZE_MAX/2; // restrict maximum size of a measure allocator to half size_t max to avoid overflows
alloc->max_size = 0;
+
+#ifdef GGML_ALLOCATOR_DEBUG
+ for (int i = 0; i < 1024; i++) {
+ alloc->allocated_tensors[i].tensor = NULL;
+ }
+#endif
}
static struct ggml_dyn_tallocr * ggml_dyn_tallocr_new(size_t alignment) {
diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h
index 36ca37086..b0d4141cc 100644
--- a/ggml/src/ggml-backend-impl.h
+++ b/ggml/src/ggml-backend-impl.h
@@ -38,15 +38,16 @@ extern "C" {
typedef void * ggml_backend_buffer_context_t;
struct ggml_backend_buffer_i {
- const char * (*GGML_CALL get_name) (ggml_backend_buffer_t buffer);
- void (*GGML_CALL free_buffer)(ggml_backend_buffer_t buffer);
- void * (*GGML_CALL get_base) (ggml_backend_buffer_t buffer);
- void (*GGML_CALL init_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
- void (*GGML_CALL set_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
- void (*GGML_CALL get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
- bool (*GGML_CALL cpy_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst); // dst is in the buffer, src may be in any buffer
- void (*GGML_CALL clear) (ggml_backend_buffer_t buffer, uint8_t value);
- void (*GGML_CALL reset) (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras
+ const char * (*GGML_CALL get_name) (ggml_backend_buffer_t buffer);
+ void (*GGML_CALL free_buffer) (ggml_backend_buffer_t buffer);
+ void * (*GGML_CALL get_base) (ggml_backend_buffer_t buffer);
+ void (*GGML_CALL init_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+ void (*GGML_CALL memset_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size);
+ void (*GGML_CALL set_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+ void (*GGML_CALL get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
+ bool (*GGML_CALL cpy_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst); // dst is in the buffer, src may be in any buffer
+ void (*GGML_CALL clear) (ggml_backend_buffer_t buffer, uint8_t value);
+ void (*GGML_CALL reset) (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras
};
struct ggml_backend_buffer {
diff --git a/ggml/src/ggml-backend.c b/ggml/src/ggml-backend.c
index b5d9301a7..ba280e064 100644
--- a/ggml/src/ggml-backend.c
+++ b/ggml/src/ggml-backend.c
@@ -246,6 +246,22 @@ GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void *
buf->iface.get_tensor(buf, tensor, data, offset, size);
}
+GGML_API GGML_CALL void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+ ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+
+ GGML_ASSERT(buf != NULL && "tensor buffer not set");
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+
+ if (!size) {
+ return;
+ }
+
+ GGML_ASSERT(buf->iface.memset_tensor != NULL && "memset not supported by backend buffer");
+
+ buf->iface.memset_tensor(buf, tensor, value, offset, size);
+}
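With the new entry point, callers can clear or pattern-fill a tensor region without staging a host buffer through ggml_backend_tensor_set. A usage sketch, assuming the matching GGML_API declaration lands in ggml-backend.h:

```cpp
#include "ggml.h"
#include "ggml-backend.h"

// Zero an entire backend tensor; value is a byte pattern, so 0 clears any type.
// Offset and size are in bytes, mirroring ggml_backend_tensor_set.
static void zero_tensor(struct ggml_tensor * t) {
    ggml_backend_tensor_memset(t, 0, 0, ggml_nbytes(t));
}
```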
+
void ggml_backend_synchronize(ggml_backend_t backend) {
if (backend->iface.synchronize == NULL) {
return;
@@ -569,6 +585,12 @@ GGML_CALL static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t
free(buffer->context);
}
+GGML_CALL static void ggml_backend_cpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+ memset((char *)tensor->data + offset, value, size);
+
+ GGML_UNUSED(buffer);
+}
+
GGML_CALL static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
memcpy((char *)tensor->data + offset, data, size);
@@ -600,6 +622,7 @@ static struct ggml_backend_buffer_i cpu_backend_buffer_i = {
/* .free_buffer = */ ggml_backend_cpu_buffer_free_buffer,
/* .get_base = */ ggml_backend_cpu_buffer_get_base,
/* .init_tensor = */ NULL, // no initialization required
+ /* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor,
/* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor,
/* .cpy_tensor = */ ggml_backend_cpu_buffer_cpy_tensor,
@@ -613,6 +636,7 @@ static struct ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = {
/* .free_buffer = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed
/* .get_base = */ ggml_backend_cpu_buffer_get_base,
/* .init_tensor = */ NULL, // no initialization required
+ /* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor,
/* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor,
/* .cpy_tensor = */ ggml_backend_cpu_buffer_cpy_tensor,
@@ -980,6 +1004,7 @@ static struct ggml_backend_buffer_i ggml_backend_multi_buffer_context_interface(
/* .free_buffer = */ ggml_backend_multi_buffer_free_buffer,
/* .get_base = */ NULL,
/* .init_tensor = */ NULL,
+ /* .memset_tensor = */ NULL,
/* .set_tensor = */ NULL,
/* .get_tensor = */ NULL,
/* .cpy_tensor = */ NULL,
diff --git a/ggml/src/ggml-cann.cpp b/ggml/src/ggml-cann.cpp
index aa315b83f..d3ab78006 100644
--- a/ggml/src/ggml-cann.cpp
+++ b/ggml/src/ggml-cann.cpp
@@ -1037,6 +1037,7 @@ static ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
/* .free_buffer = */ ggml_backend_cann_buffer_free_buffer,
/* .get_base = */ ggml_backend_cann_buffer_get_base,
/* .init_tensor = */ ggml_backend_cann_buffer_init_tensor,
+ /* .memset_tensor = */ NULL,
/* .set_tensor = */ ggml_backend_cann_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_cann_buffer_get_tensor,
/* .cpy_tensor = */ ggml_backend_cann_buffer_cpy_tensor,
diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h
index e6a570107..edfa49614 100644
--- a/ggml/src/ggml-cann/common.h
+++ b/ggml/src/ggml-cann/common.h
@@ -227,6 +227,7 @@ struct ggml_backend_cann_context {
* @brief Destructor for cleaning up resources.
*/
~ggml_backend_cann_context() {
+ ggml_cann_set_device(device);
if (copy_event != nullptr) {
ACL_CHECK(aclrtDestroyEvent(copy_event));
}
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 54f1a7c2d..6efdab14c 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -21,6 +21,8 @@
#include "ggml-cuda/mmq.cuh"
#include "ggml-cuda/mmvq.cuh"
#include "ggml-cuda/norm.cuh"
+#include "ggml-cuda/opt-step-adamw.cuh"
+#include "ggml-cuda/out-prod.cuh"
#include "ggml-cuda/pad.cuh"
#include "ggml-cuda/pool2d.cuh"
#include "ggml-cuda/quantize.cuh"
@@ -32,6 +34,7 @@
#include "ggml-cuda/tsembd.cuh"
#include "ggml-cuda/unary.cuh"
#include "ggml-cuda/upscale.cuh"
+#include "ggml-cuda/rwkv-wkv.cuh"
#include <algorithm>
#include <array>
@@ -133,7 +136,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
return res;
#else
-#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+#if !defined(GGML_USE_HIPBLAS)
cudaError_t err;
if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
{
@@ -146,7 +149,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
return err;
#else
return cudaMalloc(ptr, size);
-#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+#endif // !defined(GGML_USE_HIPBLAS)
#endif
}
@@ -184,7 +187,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
for (int id = 0; id < info.device_count; ++id) {
int device_vmm = 0;
-#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
CUdevice device;
CU_CHECK(cuDeviceGet(&device, id));
CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));
@@ -196,7 +199,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
alloc_prop.location.id = id;
CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
}
-#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
+#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
info.devices[id].vmm = !!device_vmm;
cudaDeviceProp prop;
@@ -332,7 +335,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
};
// pool with virtual memory
-#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
@@ -426,14 +429,14 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
GGML_ASSERT(ptr == (void *) (pool_addr + pool_used));
}
};
-#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
+#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
-#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
if (ggml_cuda_info().devices[device].vmm) {
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
}
-#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
+#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
}
@@ -493,6 +496,14 @@ GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t
}
}
+GGML_CALL static void ggml_backend_cuda_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+ ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
+
+ ggml_cuda_set_device(ctx->device);
+ CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + offset, value, size, cudaStreamPerThread));
+ CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
+}
+
GGML_CALL static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
@@ -544,6 +555,7 @@ static ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = {
/* .free_buffer = */ ggml_backend_cuda_buffer_free_buffer,
/* .get_base = */ ggml_backend_cuda_buffer_get_base,
/* .init_tensor = */ ggml_backend_cuda_buffer_init_tensor,
+ /* .memset_tensor = */ ggml_backend_cuda_buffer_memset_tensor,
/* .set_tensor = */ ggml_backend_cuda_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_cuda_buffer_get_tensor,
/* .cpy_tensor = */ ggml_backend_cuda_buffer_cpy_tensor,
@@ -860,6 +872,7 @@ static struct ggml_backend_buffer_i ggml_backend_cuda_split_buffer_interface = {
/* .free_buffer = */ ggml_backend_cuda_split_buffer_free_buffer,
/* .get_base = */ ggml_backend_cuda_split_buffer_get_base,
/* .init_tensor = */ ggml_backend_cuda_split_buffer_init_tensor,
+ /* .memset_tensor = */ NULL,
/* .set_tensor = */ ggml_backend_cuda_split_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_cuda_split_buffer_get_tensor,
/* .cpy_tensor = */ NULL,
@@ -2168,6 +2181,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_REPEAT:
ggml_cuda_op_repeat(ctx, dst);
break;
+ case GGML_OP_REPEAT_BACK:
+ ggml_cuda_op_repeat_back(ctx, dst);
+ break;
case GGML_OP_GET_ROWS:
ggml_cuda_op_get_rows(ctx, dst);
break;
@@ -2201,6 +2217,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_UNARY_OP_NEG:
ggml_cuda_op_neg(ctx, dst);
break;
+ case GGML_UNARY_OP_STEP:
+ ggml_cuda_op_step(ctx, dst);
+ break;
case GGML_UNARY_OP_GELU:
ggml_cuda_op_gelu(ctx, dst);
break;
@@ -2225,6 +2244,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_UNARY_OP_HARDSWISH:
ggml_cuda_op_hardswish(ctx, dst);
break;
+ case GGML_UNARY_OP_EXP:
+ ggml_cuda_op_exp(ctx, dst);
+ break;
default:
return false;
}
@@ -2267,6 +2289,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_MUL_MAT_ID:
ggml_cuda_mul_mat_id(ctx, dst);
break;
+ case GGML_OP_OUT_PROD:
+ ggml_cuda_out_prod(ctx, dst);
+ break;
case GGML_OP_SCALE:
ggml_cuda_op_scale(ctx, dst);
break;
@@ -2324,6 +2349,15 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_CROSS_ENTROPY_LOSS:
ggml_cuda_cross_entropy_loss(ctx, dst);
break;
+ case GGML_OP_RWKV_WKV:
+ ggml_cuda_op_rwkv_wkv(ctx, dst);
+ break;
+ case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+ ggml_cuda_cross_entropy_loss_back(ctx, dst);
+ break;
+ case GGML_OP_OPT_STEP_ADAMW:
+ ggml_cuda_opt_step_adamw(ctx, dst);
+ break;
default:
return false;
}
@@ -2451,6 +2485,7 @@ static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_p
for (int i = 0; i < GGML_MAX_SRC; i++) {
graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
}
+ memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS);
}
static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
@@ -2482,6 +2517,12 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
return false;
}
}
+
+ if (node->op == GGML_OP_SCALE &&
+ memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
+ return false;
+ }
+
return true;
}
@@ -2693,7 +2734,9 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
// First call with null argument gets number of nodes in graph
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
// Subsequent call with non-null argument gets nodes
+ cuda_ctx->cuda_graph->nodes.clear();
cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
+ cuda_ctx->cuda_graph->params.clear();
cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
if (cuda_ctx->cuda_graph->num_nodes > 0) {
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
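The added clear() calls matter because std::vector::resize value-initializes only elements appended beyond the current size; without clearing first, node and parameter handles from a previous graph capture would survive into the new one whenever the node count shrinks or stays the same. A minimal illustration of the semantics:

```cpp
#include <cassert>
#include <vector>

int main() {
    std::vector<int> v = {1, 2, 3};
    v.resize(3);   // no-op: the stale values 1,2,3 remain
    assert(v[0] == 1);
    v.clear();
    v.resize(3);   // now every element is value-initialized to 0
    assert(v[0] == 0);
    return 0;
}
```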
@@ -2761,6 +2804,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_NEG:
+ case GGML_UNARY_OP_STEP:
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_RELU:
@@ -2769,6 +2813,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_TANH:
+ case GGML_UNARY_OP_EXP:
return ggml_is_contiguous(op->src[0]);
default:
return false;
@@ -2785,6 +2830,12 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {
return false;
}
+#ifdef GGML_USE_MUSA
+ if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
+ !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
+ return false;
+ }
+#endif // GGML_USE_MUSA
switch (a->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
@@ -2808,11 +2859,18 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
+#ifdef GGML_USE_MUSA
+ if (a->type == GGML_TYPE_Q3_K) {
+ return false;
+ }
+#endif // GGML_USE_MUSA
return true;
default:
return false;
}
} break;
+ case GGML_OP_OUT_PROD:
+ return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
case GGML_OP_GET_ROWS:
{
switch (op->src[0]->type) {
@@ -2841,6 +2899,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
return true;
}
+ if (src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_F32) {
+ return true;
+ }
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
return true;
}
@@ -2869,6 +2930,12 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
} break;
case GGML_OP_DUP:
case GGML_OP_REPEAT:
+ {
+ ggml_type src0_type = op->src[0]->type;
+ return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
+ } break;
+ case GGML_OP_REPEAT_BACK:
+ return op->type == GGML_TYPE_F32 && op->src[0]->ne[3] == 1;
case GGML_OP_CONCAT:
{
ggml_type src0_type = op->src[0]->type;
@@ -2922,22 +2989,28 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_LEAKY_RELU:
+ case GGML_OP_RWKV_WKV:
return true;
- case GGML_OP_FLASH_ATTN_EXT:
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
- return (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) || op->src[0]->ne[0] == 128;
-#else
- if (op->src[0]->ne[0] == 128) {
- return true;
- }
+ case GGML_OP_FLASH_ATTN_EXT: {
+#ifndef FLASH_ATTN_AVAILABLE
+ return false;
+#endif
if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
return true;
}
- return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
- op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
+ if (op->src[0]->ne[0] == 128) {
+ return true;
+ }
+ if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
+ return true;
+ }
+ const int cc = ggml_cuda_info().devices[cuda_ctx->device].cc;
+ return cc >= CC_VOLTA && cc < CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
+ }
case GGML_OP_CROSS_ENTROPY_LOSS:
+ case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+ case GGML_OP_OPT_STEP_ADAMW:
return true;
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
default:
return false;
}
diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu
index e1390a041..c7b6be4e2 100644
--- a/ggml/src/ggml-cuda/binbcast.cu
+++ b/ggml/src/ggml-cuda/binbcast.cu
@@ -1,4 +1,5 @@
#include "binbcast.cuh"
+#include <cstdint>
static __device__ __forceinline__ float op_repeat(const float a, const float b) {
return b;
@@ -90,6 +91,30 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
}
+template <typename T>
+static __global__ void k_repeat_back(
+ const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
+ const int64_t ne0, const int64_t ne1, const int64_t ne2) {
+
+ const int64_t tid0 = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
+ const int64_t tid1 = (int64_t) blockIdx.y*blockDim.y + threadIdx.y;
+ const int64_t tid2 = (int64_t) blockIdx.z*blockDim.z + threadIdx.z;
+
+ if (tid0 >= ne0) {
+ return;
+ }
+
+ T sum = 0;
+ for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
+ for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
+ for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
+ sum += src[i2*ne01*ne00 + i1*ne00 + i0];
+ }
+ }
+ }
+ dst[tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
+}
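k_repeat_back reduces a repeated tensor back onto its source shape: each destination cell accumulates every source cell that maps to it modulo the destination extents, with one thread per destination element along dim 0 and grid strides along dims 1 and 2. A scalar reference of the same reduction for contiguous 3-D float tensors:

```cpp
#include <cstdint>

// Scalar reference for the repeat_back reduction implemented by the kernel above.
void repeat_back_ref(const float * src, float * dst,
                     int64_t ne00, int64_t ne01, int64_t ne02,  // src extents
                     int64_t ne0,  int64_t ne1,  int64_t ne2) { // dst extents
    for (int64_t i2 = 0; i2 < ne2; i2++) {
        for (int64_t i1 = 0; i1 < ne1; i1++) {
            for (int64_t i0 = 0; i0 < ne0; i0++) {
                float sum = 0.0f;
                for (int64_t j2 = i2; j2 < ne02; j2 += ne2) {
                    for (int64_t j1 = i1; j1 < ne01; j1 += ne1) {
                        for (int64_t j0 = i0; j0 < ne00; j0 += ne0) {
                            sum += src[j2*ne01*ne00 + j1*ne00 + j0];
                        }
                    }
                }
                dst[i2*ne1*ne0 + i1*ne0 + i0] = sum;
            }
        }
    }
}
```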
+
template<float (*bin_op)(const float, const float)>
struct bin_bcast_cuda {
template<typename src0_t, typename src1_t, typename dst_t>
@@ -247,6 +272,16 @@ struct bin_bcast_cuda {
}
};
+template <typename T>
+static void repeat_back_cuda(
+ const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
+ const int64_t ne0, const int64_t ne1, const int64_t ne2, cudaStream_t stream) {
+
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2);
+ k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>(src, dst, ne00, ne01, ne02, ne0, ne1, ne2);
+}
+
template<class op>
static void ggml_cuda_op_bin_bcast(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
@@ -286,3 +321,35 @@ void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_div>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}
+
+void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+
+ GGML_ASSERT(src0->type == dst->type);
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(dst));
+ GGML_ASSERT(ggml_can_repeat(dst, src0));
+
+ cudaStream_t stream = ctx.stream();
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t ne01 = src0->ne[1];
+ const int64_t ne02 = src0->ne[2];
+ GGML_ASSERT(src0->ne[3] == 1);
+
+ const int64_t ne0 = dst->ne[0];
+ const int64_t ne1 = dst->ne[1];
+ const int64_t ne2 = dst->ne[2];
+ GGML_ASSERT(dst->ne[3] == 1);
+
+ switch (dst->type) {
+ case GGML_TYPE_F32: {
+ const float * src0_d = (const float *) src0->data;
+ float * dst_d = (float *) dst->data;
+ repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne0, ne1, ne2, stream);
+ } break;
+ default: {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
diff --git a/ggml/src/ggml-cuda/binbcast.cuh b/ggml/src/ggml-cuda/binbcast.cuh
index 198c9ef6f..3ac1c9b03 100644
--- a/ggml/src/ggml-cuda/binbcast.cuh
+++ b/ggml/src/ggml-cuda/binbcast.cuh
@@ -5,3 +5,5 @@ void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index eb39b6d23..6a4bcdba0 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -50,6 +50,8 @@
#define CC_RDNA1 (CC_OFFSET_AMD + 1010)
#define CC_RDNA2 (CC_OFFSET_AMD + 1030)
#define CC_RDNA3 (CC_OFFSET_AMD + 1100)
+#define CC_QY1 210
+#define CC_QY2 220
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
@@ -134,6 +136,10 @@ typedef float2 dfloat2;
#define INT8_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
+#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
+#define FLASH_ATTN_AVAILABLE
+#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
+
static constexpr bool fast_fp16_available(const int cc) {
return cc >= CC_PASCAL && cc != 610;
}
@@ -569,6 +575,7 @@ struct ggml_graph_node_properties {
int64_t ne[GGML_MAX_DIMS];
size_t nb[GGML_MAX_DIMS];
void * src_address[GGML_MAX_SRC];
+ int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
};
struct ggml_cuda_graph {
diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
index 51deb75fd..54c0f66d2 100644
--- a/ggml/src/ggml-cuda/cpy.cu
+++ b/ggml/src/ggml-cuda/cpy.cu
@@ -81,6 +81,17 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
}
}
+static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
+ const block_q8_0 * xi = (const block_q8_0 *) cxi;
+ float * dsti = (float *) cdsti;
+
+ const float d = (float)xi->d;
+
+ for (int j = 0; j < QK8_0; j++) {
+ dsti[j] = xi->qs[j] * d;
+ }
+}
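The new Q8_0 to F32 copy path is a plain per-block dequantization: a Q8_0 block stores one scale and 32 signed byte quants, and each float is qs[j] * d. A self-contained scalar equivalent (the real block_q8_0 stores d as ggml_half; float is used here for brevity):

```cpp
#include <cstdint>

constexpr int QK8_0_REF = 32; // Q8_0 block size assumed here

struct block_q8_0_ref {
    float  d;               // per-block scale
    int8_t qs[QK8_0_REF];   // quantized values
};

// Scalar dequantization matching cpy_blck_q8_0_f32 above.
void dequant_q8_0_ref(const block_q8_0_ref * x, float * y) {
    for (int j = 0; j < QK8_0_REF; j++) {
        y[j] = x->qs[j] * x->d;
    }
}
```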
+
static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q4_0 * dsti = (block_q4_0 *) cdsti;
@@ -288,6 +299,32 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
cpy_blck(cx + x_offset, cdst + dst_offset);
}
+template <cpy_kernel_t cpy_blck, int qk>
+static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+ const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+ const int nb12, const int nb13) {
+ const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
+
+ if (i >= ne) {
+ return;
+ }
+
+ const int i03 = i/(ne00 * ne01 * ne02);
+ const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+ const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
+ const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+ const int x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+
+ const int i13 = i/(ne10 * ne11 * ne12);
+ const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+ const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+ const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+ const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
+
+ cpy_blck(cx + x_offset, cdst + dst_offset);
+}
+
static void ggml_cpy_f16_f32_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -329,6 +366,16 @@ static void ggml_cpy_f32_q8_0_cuda(
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
+static void ggml_cpy_q8_0_f32_cuda(
+ const char * cx, char * cdst, const int ne,
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+ const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+ const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
+ (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
static void ggml_cpy_f32_q4_0_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -437,6 +484,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
+ ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
@@ -471,6 +520,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
        return (void*) cpy_f32_f16<cpy_1_f32_f16>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
        return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
+ } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
        return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
diff --git a/ggml/src/ggml-cuda/cross-entropy-loss.cu b/ggml/src/ggml-cuda/cross-entropy-loss.cu
index 5575a90f6..ed09406a8 100644
--- a/ggml/src/ggml-cuda/cross-entropy-loss.cu
+++ b/ggml/src/ggml-cuda/cross-entropy-loss.cu
@@ -71,6 +71,32 @@ static __global__ void cross_entropy_loss_f32(const float * logits, const float
dst[blockIdx.x] = loss;
}
+static __global__ void cross_entropy_loss_back_f32(const float * logits, const float * labels, const float * loss, float * dst, const int nclasses) {
+ extern __shared__ float tmp[];
+
+ float maxval = -INFINITY;
+ for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
+ const float val = logits[blockIdx.x*nclasses + i];
+ maxval = fmaxf(maxval, val);
+ tmp[i] = val;
+ }
+ maxval = warp_reduce_max(maxval);
+
+ float sum = 0.0f;
+ for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
+ const float val = expf(tmp[i] - maxval);
+ sum += val;
+ tmp[i] = val;
+ }
+ sum = warp_reduce_sum(sum);
+ const float sm_scale = 1.0f/sum;
+
+ const float d_by_nrows = *loss/gridDim.x;
+ for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
+ dst[blockIdx.x*nclasses + i] = (tmp[i]*sm_scale - labels[blockIdx.x*nclasses + i])*d_by_nrows;
+ }
+}
+
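The kernel above is the usual fused softmax cross-entropy gradient: with logits z, labels y, the incoming scalar loss gradient d = *loss and N = gridDim.x rows, each row b receives

    dst[b][i] = (softmax(z[b])[i] - y[b][i]) * d / N

The dynamic shared-memory buffer tmp caches one row of logits so global memory is read only once across the max, exponentiation/sum and write passes.
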
void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
@@ -104,3 +130,37 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
// Combine results from individual blocks:
sum_f32_cuda(pool, dst_tmp.ptr, dst_d, blocks_num.x, stream);
}
+
+void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ const ggml_tensor * opt0 = dst->src[2];
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT(opt0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(src1));
+ GGML_ASSERT(ggml_is_contiguous(opt0));
+ GGML_ASSERT(ggml_is_contiguous(dst));
+ GGML_ASSERT(ggml_are_same_shape(src0, src1));
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t nrows = ggml_nrows(src0);
+
+ const float * src0_d = (const float *) src0->data;
+ const float * src1_d = (const float *) src1->data;
+ const float * opt0_d = (const float *) opt0->data;
+ float * dst_d = (float *) dst->data;
+
+ cudaStream_t stream = ctx.stream();
+
+ const dim3 blocks_dim(WARP_SIZE, 1, 1);
+ const dim3 blocks_num(nrows, 1, 1);
+ const int shmem = ne00*sizeof(float);
+
+    cross_entropy_loss_back_f32<<<blocks_num, blocks_dim, shmem, stream>>>(src0_d, src1_d, opt0_d, dst_d, ne00);
+}
diff --git a/ggml/src/ggml-cuda/cross-entropy-loss.cuh b/ggml/src/ggml-cuda/cross-entropy-loss.cuh
index 9d7b8b0f0..9ec7152ff 100644
--- a/ggml/src/ggml-cuda/cross-entropy-loss.cuh
+++ b/ggml/src/ggml-cuda/cross-entropy-loss.cuh
@@ -3,3 +3,5 @@
#define CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE 256
void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu
index 827437ca0..f402195ce 100644
--- a/ggml/src/ggml-cuda/fattn-tile-f32.cu
+++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu
@@ -44,13 +44,17 @@ static __global__ void flash_attn_tile_ext_f32(
const int ne1,
const int ne2,
const int ne3) {
+#ifndef FLASH_ATTN_AVAILABLE
+ NO_DEVICE_CODE;
+ return;
+#endif // FLASH_ATTN_AVAILABLE
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(D == 128 || D == 256)) {
NO_DEVICE_CODE;
return;
}
- //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+ // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index f28a19d40..83e5589a1 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -314,7 +314,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
}
if (!fast_fp16_available(cc)) {
- if (Q->ne[1] <= 8) {
+ if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
diff --git a/ggml/src/ggml-cuda/im2col.cu b/ggml/src/ggml-cuda/im2col.cu
index 3d0d8d4e6..16463ab0f 100644
--- a/ggml/src/ggml-cuda/im2col.cu
+++ b/ggml/src/ggml-cuda/im2col.cu
@@ -69,7 +69,6 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
diff --git a/ggml/src/ggml-cuda/opt-step-adamw.cu b/ggml/src/ggml-cuda/opt-step-adamw.cu
new file mode 100644
index 000000000..d6f13a9c6
--- /dev/null
+++ b/ggml/src/ggml-cuda/opt-step-adamw.cu
@@ -0,0 +1,80 @@
+#include "opt-step-adamw.cuh"
+
+#include <cstdint>
+
+static __global__ void opt_step_adamw_f32(
+ float * __restrict__ x, const float * __restrict__ g, float * __restrict__ g_m, float * __restrict__ g_v, const int64_t k,
+ const float alpha, const float beta1, const float beta2, const float eps, const float wd,
+ const float beta1h, const float beta2h) {
+
+ const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+
+ const float gi = g[i];
+ const float gmi = g_m[i]*beta1 + gi*(1.0f - beta1);
+ const float gvi = g_v[i]*beta2 + gi*gi*(1.0f - beta2);
+
+ g_m[i] = gmi;
+ g_v[i] = gvi;
+
+ const float mh = gmi*beta1h;
+ const float vh = sqrtf(gvi*beta2h) + eps;
+
+ x[i] = x[i]*(1.0f - alpha*wd) - mh/vh;
+}
+
+static void opt_step_adamw_f32_cuda(
+ float * x, const float * g, float * g_m, float * g_v, const int64_t k,
+ const float alpha, const float beta1, const float beta2, const float eps, const float wd,
+ const float beta1h, const float beta2h, cudaStream_t stream) {
+
+ const dim3 block_dims(CUDA_OPT_STEP_ADAMW_BLOCK_SIZE, 1, 1);
+ const dim3 block_nums((k + CUDA_OPT_STEP_ADAMW_BLOCK_SIZE - 1) / CUDA_OPT_STEP_ADAMW_BLOCK_SIZE, 1, 1);
+    opt_step_adamw_f32<<<block_nums, block_dims, 0, stream>>>(x, g, g_m, g_v, k, alpha, beta1, beta2, eps, wd, beta1h, beta2h);
+}
+
+void ggml_cuda_opt_step_adamw(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src0_grad = dst->src[1];
+ const ggml_tensor * src0_grad_m = dst->src[2];
+ const ggml_tensor * src0_grad_v = dst->src[3];
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src0_grad->type == GGML_TYPE_F32);
+ GGML_ASSERT(src0_grad_m->type == GGML_TYPE_F32);
+ GGML_ASSERT(src0_grad_v->type == GGML_TYPE_F32);
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(src0_grad));
+ GGML_ASSERT(ggml_is_contiguous(src0_grad_m));
+ GGML_ASSERT(ggml_is_contiguous(src0_grad_v));
+ GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+ GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
+ GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
+
+ float * src0_d = (float *) src0->data;
+ const float * src0_grad_d = (const float *) src0_grad->data;
+ float * src0_grad_m_d = (float *) src0_grad_m->data;
+ float * src0_grad_v_d = (float *) src0_grad_v->data;
+
+ cudaStream_t stream = ctx.stream();
+
+ const int64_t ne = ggml_nelements(src0);
+
+ int64_t iter; memcpy(&iter, &dst->op_params[0], sizeof(int64_t));
+ float alpha; memcpy(&alpha, &dst->op_params[2], sizeof(float));
+ float beta1; memcpy(&beta1, &dst->op_params[3], sizeof(float));
+ float beta2; memcpy(&beta2, &dst->op_params[4], sizeof(float));
+ float eps; memcpy(&eps, &dst->op_params[5], sizeof(float));
+ float wd; memcpy(&wd, &dst->op_params[6], sizeof(float));
+
+ const float beta1h = alpha/(1.0f - powf(beta1, iter));
+ const float beta2h = 1.0f/(1.0f - powf(beta2, iter));
+
+ opt_step_adamw_f32_cuda(src0_d, src0_grad_d, src0_grad_m_d, src0_grad_v_d, ne, alpha, beta1, beta2, eps, wd, beta1h, beta2h, stream);
+
+ iter++;
+ memcpy(&dst->op_params[0], &iter, sizeof(int64_t));
+}
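
For reference, this kernel applies the standard AdamW update with the bias corrections folded into two host-side constants. With gradient g, moments m and v, and step number t:

    m = beta1*m + (1 - beta1)*g
    v = beta2*v + (1 - beta2)*g*g
    x = x*(1 - alpha*wd) - alpha * (m / (1 - beta1^t)) / (sqrt(v / (1 - beta2^t)) + eps)

beta1h = alpha/(1 - beta1^t) and beta2h = 1/(1 - beta2^t) precompute exactly those corrections, so the per-element work is two FMAs, a sqrt and a multiply-add; note that eps is added after the bias-corrected sqrt, matching the CPU path later in this patch. The (1 - alpha*wd) factor is decoupled weight decay in the sense of Loshchilov & Hutter, not L2 regularization.
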
diff --git a/ggml/src/ggml-cuda/opt-step-adamw.cuh b/ggml/src/ggml-cuda/opt-step-adamw.cuh
new file mode 100644
index 000000000..58d6f6e5d
--- /dev/null
+++ b/ggml/src/ggml-cuda/opt-step-adamw.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_OPT_STEP_ADAMW_BLOCK_SIZE 256
+
+void ggml_cuda_opt_step_adamw(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/out-prod.cu b/ggml/src/ggml-cuda/out-prod.cu
new file mode 100644
index 000000000..619cfdcb5
--- /dev/null
+++ b/ggml/src/ggml-cuda/out-prod.cu
@@ -0,0 +1,51 @@
+#include "out-prod.cuh"
+
+#include <cstdint>
+
+void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(dst));
+
+ GGML_ASSERT(ne01 == ne11);
+ GGML_ASSERT(ne0 == ne00);
+ GGML_ASSERT(ne1 == ne10);
+
+ GGML_ASSERT(ne2 == src0->ne[2]);
+ GGML_ASSERT(ne2 == src1->ne[2]);
+ GGML_ASSERT(ne3 == src0->ne[3]);
+ GGML_ASSERT(ne3 == src1->ne[3]);
+
+ const float * src0_d = (const float *) src0->data;
+ const float * src1_d = (const float *) src1->data;
+ float * dst_d = (float *) dst->data;
+
+ cudaStream_t stream = ctx.stream();
+ cublasHandle_t handle = ctx.cublas_handle();
+
+ const float alpha = 1.0f;
+ const float beta = 0.0f;
+
+ GGML_ASSERT(ne2 == 1);
+ GGML_ASSERT(ne3 == 1);
+ CUBLAS_CHECK(cublasSetStream(handle, stream));
+
+ const bool src1_T = ggml_is_transposed(src1);
+ const cublasOperation_t src1_cublas_op = src1_T ? CUBLAS_OP_N : CUBLAS_OP_T;
+ const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
+ GGML_ASSERT( (src1_T ? nb11 : nb10) == sizeof(float));
+
+ CUBLAS_CHECK(
+ cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
+ ne0, ne1, ne01,
+ &alpha, src0_d, ne00,
+ src1_d, ldb,
+ &beta, dst_d, ne0));
+}
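
Shape-wise this computes dst = src0 · src1^T over the shared dimension ne01 == ne11. Because ggml keeps ne[0] contiguous and cuBLAS is column-major, each ggml matrix maps directly onto a column-major operand with ld equal to its row stride, which gives the call above:

    cublasSgemm(N, T, m = ne0, n = ne1, k = ne01,
                A = src0 (lda = ne00), B = src1 (ldb from nb11), C = dst (ldc = ne0))

When src1 is already transposed in memory the B operand needs no further transpose, so the op flips to CUBLAS_OP_N and ldb is taken from nb10 instead, which is what the src1_T branch selects.
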
diff --git a/ggml/src/ggml-cuda/out-prod.cuh b/ggml/src/ggml-cuda/out-prod.cuh
new file mode 100644
index 000000000..a0046f5f8
--- /dev/null
+++ b/ggml/src/ggml-cuda/out-prod.cuh
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/rwkv-wkv.cu b/ggml/src/ggml-cuda/rwkv-wkv.cu
new file mode 100644
index 000000000..098e92d35
--- /dev/null
+++ b/ggml/src/ggml-cuda/rwkv-wkv.cu
@@ -0,0 +1,89 @@
+#include "common.cuh"
+#include "rwkv-wkv.cuh"
+
+static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
+ const int tid = threadIdx.x;
+ const int bid = blockIdx.x;
+
+ const int head_size = CUDA_WKV_BLOCK_SIZE;
+ const int batch_i = bid / H;
+ const int head_i = bid % H;
+ const int state_size = C * head_size;
+ const int n_seq_tokens = T / B;
+
+ float state[head_size];
+ __shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
+
+ #pragma unroll
+ for (int i = 0; i < head_size; i++) {
+ state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
+ }
+
+ __syncthreads();
+ _tf[tid] = tf[head_i * head_size + tid];
+ __syncthreads();
+
+ for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
+ __syncthreads();
+ _k[tid] = k[t];
+ _r[tid] = r[t];
+ _td[tid] = td[t];
+ __syncthreads();
+
+ const float _v = v[t];
+ float y = 0;
+ for (int j = 0; j < head_size; j += 4) {
+ const float4& k = (float4&)(_k[j]);
+ const float4& r = (float4&)(_r[j]);
+ const float4& tf = (float4&)(_tf[j]);
+ const float4& td = (float4&)(_td[j]);
+ float4& s = (float4&)(state[j]);
+ float4 kv;
+
+ kv.x = k.x * _v;
+ kv.y = k.y * _v;
+ kv.z = k.z * _v;
+ kv.w = k.w * _v;
+
+ y += r.x * (tf.x * kv.x + s.x);
+ y += r.y * (tf.y * kv.y + s.y);
+ y += r.z * (tf.z * kv.z + s.z);
+ y += r.w * (tf.w * kv.w + s.w);
+
+ s.x = s.x * td.x + kv.x;
+ s.y = s.y * td.y + kv.y;
+ s.z = s.z * td.z + kv.z;
+ s.w = s.w * td.w + kv.w;
+ }
+ dst[t] = y;
+ }
+
+ #pragma unroll
+ for (int i = 0; i < head_size; i++) {
+ dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
+ }
+}
+
+void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const float * k_d = (const float *)dst->src[0]->data;
+ const float * v_d = (const float *)dst->src[1]->data;
+ const float * r_d = (const float *)dst->src[2]->data;
+ const float * tf_d = (const float *)dst->src[3]->data;
+ const float * td_d = (const float *)dst->src[4]->data;
+ const float * s_d = (const float *)dst->src[5]->data;
+
+ const int64_t B = dst->src[5]->ne[1];
+ const int64_t T = dst->src[0]->ne[3];
+ const int64_t C = dst->ne[0];
+ const int64_t H = dst->src[0]->ne[2];
+
+ float * dst_d = (float *)dst->data;
+
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
+ GGML_ASSERT(C % H == 0);
+ GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE);
+
+    rwkv_wkv_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
+}
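
Each block handles one (batch, head) pair and each thread one output channel: the kernel is a time-major scan of the RWKV wkv recurrence, staging the per-token k, r and decay td through shared memory while the head_size x head_size state stays in registers (vectorized as float4 above). A scalar sketch of what one head computes (illustrative only; S is the running state, indexed [inner][output]):

    // Scalar reference for one head of the wkv recurrence (sketch, not patch code).
    for (int t = 0; t < n_seq_tokens; t++) {
        for (int c = 0; c < head_size; c++) {        // output channel (thread tid above)
            float y = 0.0f;
            for (int j = 0; j < head_size; j++) {    // inner channel (the j loop above)
                const float kv = k[t][j] * v[t][c];
                y      += r[t][j] * (tf[j] * kv + S[j][c]);
                S[j][c] = S[j][c] * td[t][j] + kv;
            }
            dst[t][c] = y;
        }
    }
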
diff --git a/ggml/src/ggml-cuda/rwkv-wkv.cuh b/ggml/src/ggml-cuda/rwkv-wkv.cuh
new file mode 100644
index 000000000..13795247f
--- /dev/null
+++ b/ggml/src/ggml-cuda/rwkv-wkv.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_WKV_BLOCK_SIZE 64
+
+void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/sum.cu b/ggml/src/ggml-cuda/sum.cu
index 21da63509..0583e4fe0 100644
--- a/ggml/src/ggml-cuda/sum.cu
+++ b/ggml/src/ggml-cuda/sum.cu
@@ -1,9 +1,13 @@
-#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11700
+#define USE_CUB
+#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11700
+
+#ifdef USE_CUB
// On Windows CUB uses libraries with variables called CC_PASCAL which conflict with the define in common.cuh.
// For this reason CUB must be included BEFORE anything else.
#include <cub/cub.cuh>
using namespace cub;
-#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+#endif // USE_CUB
#include "sumrows.cuh"
#include "sum.cuh"
@@ -11,7 +15,7 @@ using namespace cub;
#include <cstdint>
void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream) {
-#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+#ifdef USE_CUB
size_t tmp_size = 0;
DeviceReduce::Sum(nullptr, tmp_size, x, dst, ne, stream);
    ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);
@@ -21,7 +25,7 @@ void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int
// For AMD there is rocPRIM which could be used as a drop-in replacement via hipcub but this would require C++11 -> C++14.
sum_rows_f32_cuda(x, dst, ne, 1, stream);
GGML_UNUSED(pool);
-#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+#endif // USE_CUB
}
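
DeviceReduce::Sum follows CUB's two-phase protocol, which is why the call appears twice: the first call with a null temporary-storage pointer only reports the required scratch size, and the second call with a real allocation performs the reduction. The pattern in isolation (a sketch assuming CUDA >= 11.2 for cudaMallocAsync; the code above uses the ggml buffer pool instead):

    size_t tmp_size = 0;
    cub::DeviceReduce::Sum(nullptr, tmp_size, d_in, d_out, n, stream); // phase 1: query scratch size
    void * d_tmp = nullptr;
    CUDA_CHECK(cudaMallocAsync(&d_tmp, tmp_size, stream));
    cub::DeviceReduce::Sum(d_tmp, tmp_size, d_in, d_out, n, stream);   // phase 2: run the reduction
    CUDA_CHECK(cudaFreeAsync(d_tmp, stream));
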
void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu
index 8ac669f94..81fc92202 100644
--- a/ggml/src/ggml-cuda/unary.cu
+++ b/ggml/src/ggml-cuda/unary.cu
@@ -10,6 +10,16 @@ static __global__ void neg_f32(const float * x, float * dst, const int k) {
dst[i] = -x[i];
}
+static __global__ void step_f32(const float * x, float * dst, const int k) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+
+ dst[i] = x[i] > 0.0f;
+}
+
static __global__ void gelu_f32(const float * x, float * dst, const int k) {
const float GELU_COEF_A = 0.044715f;
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
@@ -85,6 +95,15 @@ static __global__ void hardswish_f32(const float * x, float * dst, const int k)
dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
}
+static __global__ void exp_f32(const float * x, float * dst, const int k) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+ dst[i] = expf(x[i]);
+}
+
static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
@@ -134,6 +153,11 @@ static void neg_f32_cuda(const float * x, float * dst, const int k, cudaStream_t
neg_f32<<>>(x, dst, k);
}
+static void step_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_STEP_BLOCK_SIZE - 1) / CUDA_STEP_BLOCK_SIZE;
+ step_f32<<>>(x, dst, k);
+}
+
static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
gelu_f32<<>>(x, dst, k);
@@ -174,6 +198,11 @@ static void hardswish_f32_cuda(const float * x, float * dst, const int k, cudaSt
hardswish_f32<<>>(x, dst, k);
}
+static void exp_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_EXP_BLOCK_SIZE - 1) / CUDA_EXP_BLOCK_SIZE;
+ exp_f32<<>>(x, dst, k);
+}
+
static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
leaky_relu_f32<<>>(x, dst, k, negative_slope);
@@ -213,6 +242,20 @@ void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
neg_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
+void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ step_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
@@ -325,6 +368,20 @@ void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
hardswish_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
+void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ exp_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh
index ed2ffc461..c91936728 100644
--- a/ggml/src/ggml-cuda/unary.cuh
+++ b/ggml/src/ggml-cuda/unary.cuh
@@ -1,12 +1,14 @@
#include "common.cuh"
#define CUDA_NEG_BLOCK_SIZE 256
+#define CUDA_STEP_BLOCK_SIZE 256
#define CUDA_GELU_BLOCK_SIZE 256
#define CUDA_SILU_BLOCK_SIZE 256
#define CUDA_TANH_BLOCK_SIZE 256
#define CUDA_RELU_BLOCK_SIZE 256
#define CUDA_SIGMOID_BLOCK_SIZE 256
#define CUDA_HARDSIGMOID_BLOCK_SIZE 256
+#define CUDA_EXP_BLOCK_SIZE 256
#define CUDA_HARDSWISH_BLOCK_SIZE 256
#define CUDA_SQR_BLOCK_SIZE 256
#define CUDA_SQRT_BLOCK_SIZE 256
@@ -15,6 +17,8 @@
void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -29,6 +33,8 @@ void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
index d0c377255..1f3c70c2e 100644
--- a/ggml/src/ggml-cuda/vendors/hip.h
+++ b/ggml/src/ggml-cuda/vendors/hip.h
@@ -30,6 +30,7 @@
#define cublasSetStream hipblasSetStream
#define cublasSgemm hipblasSgemm
#define cublasStatus_t hipblasStatus_t
+#define cublasOperation_t hipblasOperation_t
#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
diff --git a/ggml/src/ggml-cuda/vendors/musa.h b/ggml/src/ggml-cuda/vendors/musa.h
index 8df571149..1604b8229 100644
--- a/ggml/src/ggml-cuda/vendors/musa.h
+++ b/ggml/src/ggml-cuda/vendors/musa.h
@@ -26,6 +26,7 @@
#define cublasSetStream mublasSetStream
#define cublasSgemm mublasSgemm
#define cublasStatus_t mublasStatus_t
+#define cublasOperation_t mublasOperation_t
#define cublasGetStatusString mublasStatus_to_string
#define cudaDataType_t musaDataType_t
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
@@ -56,6 +57,7 @@
#define cudaLaunchHostFunc musaLaunchHostFunc
#define cudaMalloc musaMalloc
#define cudaMallocHost musaMallocHost
+#define cudaMallocManaged musaMallocManaged
#define cudaMemcpy musaMemcpy
#define cudaMemcpyAsync musaMemcpyAsync
#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
diff --git a/ggml/src/ggml-kompute.cpp b/ggml/src/ggml-kompute.cpp
index 7f0bd82d5..9cbc57a64 100644
--- a/ggml/src/ggml-kompute.cpp
+++ b/ggml/src/ggml-kompute.cpp
@@ -1872,6 +1872,7 @@ static ggml_backend_buffer_i ggml_backend_kompute_buffer_i = {
/* .free_buffer = */ ggml_backend_kompute_buffer_free_buffer,
/* .get_base = */ ggml_backend_kompute_buffer_get_base,
/* .init_tensor = */ NULL,
+ /* .memset_tensor = */ NULL,
/* .set_tensor = */ ggml_backend_kompute_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_kompute_buffer_get_tensor,
/* .cpy_tensor = */ NULL,
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index f87181d19..ef3b7f0e8 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -3167,6 +3167,7 @@ static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = {
/* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
/* .get_base = */ ggml_backend_metal_buffer_get_base,
/* .init_tensor = */ NULL,
+ /* .memset_tensor = */ NULL,
/* .set_tensor = */ ggml_backend_metal_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_metal_buffer_get_tensor,
/* .cpy_tensor = */ ggml_backend_metal_buffer_cpy_tensor,
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index f323ab5f4..2b2000323 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -2631,11 +2631,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
const short iv3 = iq3 / rv3;
// load the queries from shared memory into local memory
- half4 mq[D4];
+ float4 mq[D4];
for (short ii = 0; ii < D4; ii += NW) {
short i = ii + tiisg;
- mq[i] = sq4[i];
+ mq[i] = (float4) sq4[i];
}
// pointer to the mask
@@ -2661,11 +2661,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
for (short ii = 0; ii < D4; ii += NW) {
const short i = ii + tiisg;
- half4x4 mk;
- mk[0] = pk4[i + 0*(nb11/8)];
- mk[1] = pk4[i + 1*(nb11/8)];
- mk[2] = pk4[i + 2*(nb11/8)];
- mk[3] = pk4[i + 3*(nb11/8)];
+ float4x4 mk;
+ mk[0] = (float4) pk4[i + 0*(nb11/8)];
+ mk[1] = (float4) pk4[i + 1*(nb11/8)];
+ mk[2] = (float4) pk4[i + 2*(nb11/8)];
+ mk[3] = (float4) pk4[i + 3*(nb11/8)];
mqk += (float4) (mq[i] * mk);
}
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index 8bffce860..7aa6dce89 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -4013,7 +4013,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
svfloat32_t sumv0 = svdup_n_f32(0.0f);
svfloat32_t sumv1 = svdup_n_f32(0.0f);
- const int vector_length = ggml_sve_cnt_b*8;
+ const int vector_length = ggml_cpu_get_sve_cnt()*8;
// VLA Implementation using switch case
switch (vector_length) {
@@ -5597,7 +5597,7 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
svfloat32_t sumv0 = svdup_n_f32(0.0f);
svfloat32_t sumv1 = svdup_n_f32(0.0f);
- const int vector_length = ggml_sve_cnt_b*8;
+ const int vector_length = ggml_cpu_get_sve_cnt()*8;
    //VLA Implementation for SVE
switch (vector_length) {
diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h
index e96ce2b5e..df9c4b24a 100644
--- a/ggml/src/ggml-quants.h
+++ b/ggml/src/ggml-quants.h
@@ -142,10 +142,6 @@ void iq2xs_free_impl(enum ggml_type type);
void iq3xs_init_impl(int grid_size);
void iq3xs_free_impl(int grid_size);
-#if defined(__ARM_FEATURE_SVE)
-extern int ggml_sve_cnt_b;
-#endif
-
#ifdef __cplusplus
}
#endif
diff --git a/ggml/src/ggml-rpc.cpp b/ggml/src/ggml-rpc.cpp
index a8a2eb85a..49b3fa911 100644
--- a/ggml/src/ggml-rpc.cpp
+++ b/ggml/src/ggml-rpc.cpp
@@ -469,6 +469,7 @@ static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
/* .free_buffer = */ ggml_backend_rpc_buffer_free_buffer,
/* .get_base = */ ggml_backend_rpc_buffer_get_base,
/* .init_tensor = */ ggml_backend_rpc_buffer_init_tensor,
+ /* .memset_tensor = */ NULL,
/* .set_tensor = */ ggml_backend_rpc_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_rpc_buffer_get_tensor,
/* .cpy_tensor = */ ggml_backend_rpc_buffer_cpy_tensor,
diff --git a/ggml/src/ggml-sycl.cpp b/ggml/src/ggml-sycl.cpp
index acef7c6d4..6978a3192 100644
--- a/ggml/src/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl.cpp
@@ -3496,8 +3496,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
- && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE
- && (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda || src1->ne[1] > MMVQ_MIN_BATCH_SIZE);
+ && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
@@ -4323,6 +4322,7 @@ static struct ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = {
/* .free_buffer = */ ggml_backend_sycl_buffer_free_buffer,
/* .get_base = */ ggml_backend_sycl_buffer_get_base,
/* .init_tensor = */ ggml_backend_sycl_buffer_init_tensor,
+ /* .memset_tensor = */ NULL,
/* .set_tensor = */ ggml_backend_sycl_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_sycl_buffer_get_tensor,
/* .cpy_tensor = */ ggml_backend_sycl_buffer_cpy_tensor,
@@ -4734,6 +4734,7 @@ static struct ggml_backend_buffer_i ggml_backend_sycl_split_buffer_interface = {
/* .free_buffer = */ ggml_backend_sycl_split_buffer_free_buffer,
/* .get_base = */ ggml_backend_sycl_split_buffer_get_base,
/* .init_tensor = */ ggml_backend_sycl_split_buffer_init_tensor,
+ /* .memset_tensor = */ NULL,
/* .set_tensor = */ ggml_backend_sycl_split_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_sycl_split_buffer_get_tensor,
/* .cpy_tensor = */ NULL,
diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp
index 05947ccb7..bc0faa867 100644
--- a/ggml/src/ggml-sycl/common.hpp
+++ b/ggml/src/ggml-sycl/common.hpp
@@ -134,7 +134,6 @@ typedef sycl::float2 dfloat2;
#endif // GGML_SYCL_F16
#define MMVQ_MAX_BATCH_SIZE 8
-#define MMVQ_MIN_BATCH_SIZE 4
static const int8_t kvalues_iq4nl[16]={-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
diff --git a/ggml/src/ggml-vulkan.cpp b/ggml/src/ggml-vulkan.cpp
index bad960510..c677a2728 100644
--- a/ggml/src/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan.cpp
@@ -20,6 +20,8 @@
#include
#include
#include
+#include <future>
+#include <thread>
#include "ggml-impl.h"
#include "ggml-backend-impl.h"
@@ -607,13 +609,16 @@ typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx
GGML_CALL static void ggml_backend_vk_free(ggml_backend_t backend);
-static void ggml_vk_create_pipeline(vk_device& device, vk_pipeline& pipeline, const std::string& name, size_t spv_size, const void* spv_data, const std::string& entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t>&& specialization_constants, uint32_t align) {
+// variables to track number of compiles in progress
+static uint32_t compile_count = 0;
+static std::mutex compile_count_mutex;
+static std::condition_variable compile_count_cond;
+
+static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants, uint32_t align) {
VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")");
GGML_ASSERT(parameter_count > 0);
GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT
-    std::lock_guard<std::mutex> guard(device->mutex);
-
    pipeline = std::make_shared<vk_pipeline_struct>();
pipeline->name = name;
pipeline->parameter_count = parameter_count;
@@ -681,7 +686,17 @@ static void ggml_vk_create_pipeline(vk_device& device, vk_pipeline& pipeline, co
pipeline->layout);
pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
- device->pipelines.insert({ pipeline->name, pipeline });
+ {
+        std::lock_guard<std::mutex> guard(device->mutex);
+ device->pipelines.insert({ pipeline->name, pipeline });
+ }
+
+ {
+        std::lock_guard<std::mutex> guard(compile_count_mutex);
+ assert(compile_count > 0);
+ compile_count--;
+ }
+ compile_count_cond.notify_all();
}
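
compile_count together with its mutex and condition variable forms a counting semaphore capped at std::thread::hardware_concurrency(): the shader-loading thread blocks before launching another std::async compile while all slots are taken, and every finishing compile releases its slot and wakes the waiters. The same gate written as a standalone class (a sketch assuming <mutex> and <condition_variable>, not patch code):

    class concurrency_gate {
        std::mutex              m;
        std::condition_variable cv;
        uint32_t                in_flight = 0;
        const uint32_t          limit;
    public:
        explicit concurrency_gate(uint32_t n) : limit(n) {}
        void acquire() {                               // call before launching a worker
            std::unique_lock<std::mutex> lock(m);
            cv.wait(lock, [&] { return in_flight < limit; });
            in_flight++;
        }
        void release() {                               // call at the end of the worker
            { std::lock_guard<std::mutex> lock(m); in_flight--; }
            cv.notify_all();
        }
    };
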
static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) {
@@ -1079,7 +1094,8 @@ static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) {
// Fall back to host memory type
buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
} else {
- buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal);
+            // use ReBAR if available, otherwise fall back to device-only memory
+ buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal);
}
} catch (const vk::SystemError& e) {
std::cerr << "ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl;
@@ -1193,6 +1209,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
    device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K] = std::make_shared<vk_matmul_pipeline_struct>();
    device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL] = std::make_shared<vk_matmul_pipeline_struct>();
+    std::vector<std::future<void>> compiles;
+    auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t>&& specialization_constants, uint32_t align) {
+ {
+ // wait until fewer than N compiles are in progress
+ uint32_t N = std::max(1u, std::thread::hardware_concurrency());
+            std::unique_lock<std::mutex> guard(compile_count_mutex);
+ while (compile_count >= N) {
+ compile_count_cond.wait(guard);
+ }
+ compile_count++;
+ }
+ compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, parameter_count, push_constant_size, wg_denoms, specialization_constants, align));
+ };
+
if (device->fp16) {
ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
@@ -1742,6 +1772,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
+
+ for (auto &c : compiles) {
+ c.wait();
+ }
}
static vk_device ggml_vk_get_device(size_t idx) {
@@ -2806,7 +2840,11 @@ static void ggml_vk_buffer_read_async(vk_context subctx, vk_buffer& src, size_t
static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) {
VK_LOG_DEBUG("ggml_vk_buffer_read(" << src->buffer << ", " << offset << ", " << size << ")");
- if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
+
+    // If the device is not a UMA device, its memory is still host-accessible through ReBAR. While
+    // writing through PCIe is sufficiently fast, reading data back over PCIe is slower than going
+    // through the hardware device-to-host copy path.
+ if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && src->device->uma) {
GGML_ASSERT(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent);
memcpy(dst, (uint8_t *) src->ptr + offset, size);
@@ -5008,6 +5046,8 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
}
}
+ ggml_pipeline_allocate_descriptor_sets(ctx->device);
+
vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
vk_buffer d_Y = ggml_vk_create_buffer_check(ctx->device, sizeof(Y_TYPE) * y_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
vk_buffer d_D = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
@@ -5124,7 +5164,9 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
avg_err /= m * n;
- std::cerr << "TEST " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time / num_it << "ms avg_err=" << avg_err << std::endl;
+ double tflops = 2.0*m*n*k*batch*num_it / (time / 1000.0) / (1000.0*1000.0*1000.0*1000.0);
+
+ std::cerr << "TEST " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
if (avg_err > 0.1) {
std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
@@ -5246,12 +5288,14 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
ggml_pipeline_request_descriptor_sets(ctx->device, p, 1);
+ ggml_pipeline_allocate_descriptor_sets(ctx->device);
+
ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
ggml_vk_ctx_begin(ctx->device, subctx);
    const std::vector<uint32_t> pc = { 1, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne };
- ggml_vk_dispatch_pipeline(ctx, subctx, p, { { qx_buf, 0, qx_sz }, { x_buf, 0, x_sz_f16 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)ne, 1, 1});
+ ggml_vk_dispatch_pipeline(ctx, subctx, p, { vk_subbuffer{ qx_buf, 0, qx_sz }, vk_subbuffer{ x_buf, 0, x_sz_f16 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)ne, 1, 1});
ggml_vk_ctx_end(subctx);
auto begin = std::chrono::high_resolution_clock::now();
@@ -5378,6 +5422,8 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
}
}
+ ggml_pipeline_allocate_descriptor_sets(ctx->device);
+
ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
ggml_vk_buffer_write(y_buf, 0, y, y_sz);
@@ -5445,7 +5491,9 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
avg_err /= m * n;
- std::cerr << "TEST MMQ " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms avg_err=" << avg_err << std::endl;
+ double tflops = 2.0*m*n*k*batch*num_it / (time_ms / 1000.0) / (1000.0*1000.0*1000.0*1000.0);
+
+ std::cerr << "TEST MMQ " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
if (avg_err > 0.01 || std::isnan(avg_err)) {
std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
@@ -5497,9 +5545,6 @@ static ggml_tensor_extra_gpu * ggml_vk_tensor_create_extra(ggml_tensor * tensor)
static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
#if defined(GGML_VULKAN_RUN_TESTS)
- ctx->staging = ggml_vk_create_buffer_check(ctx->device, 100ul * 1024ul * 1024ul,
- vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
- vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_F32);
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_0);
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_1);
@@ -6246,6 +6291,7 @@ static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = {
/* .free_buffer = */ ggml_backend_vk_buffer_free_buffer,
/* .get_base = */ ggml_backend_vk_buffer_get_base,
/* .init_tensor = */ ggml_backend_vk_buffer_init_tensor,
+ /* .memset_tensor = */ NULL,
/* .set_tensor = */ ggml_backend_vk_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_vk_buffer_get_tensor,
/* .cpy_tensor = */ ggml_backend_vk_buffer_cpy_tensor,
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 3a8aadae8..81b651c6a 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1,6 +1,7 @@
#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
#define _USE_MATH_DEFINES // For M_PI on MSVC
+#include "ggml-backend.h"
#include "ggml-impl.h"
#include "ggml-cpu-impl.h"
#include "ggml-quants.h"
@@ -38,9 +39,6 @@
#include
#endif
-#if defined(__ARM_FEATURE_SVE)
-int ggml_sve_cnt_b = 0;
-#endif
#if defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_MATMUL_INT8)
#undef GGML_USE_LLAMAFILE
#endif
@@ -62,6 +60,25 @@ int ggml_sve_cnt_b = 0;
#pragma warning(disable: 4702)
#endif
+// Note: once we move threading into a separate C++ file
+// we will use std::hardware_destructive_interference_size instead of hardcoding it here
+// and we'll use C++ attribute syntax.
+#define GGML_CACHE_LINE 64
+
+#if defined(__clang__) || defined(__GNUC__)
+#define GGML_CACHE_ALIGN __attribute__((aligned(GGML_CACHE_LINE)))
+#endif
+
+#if defined(__has_feature)
+#if __has_feature(thread_sanitizer)
+#define GGML_TSAN_ENABLED 1
+#endif
+#else // __has_feature
+#if defined(__SANITIZE_THREAD__)
+#define GGML_TSAN_ENABLED 1
+#endif
+#endif // __has_feature
+
#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
@@ -71,6 +88,8 @@ int ggml_sve_cnt_b = 0;
#include <windows.h>
#if !defined(__clang__)
+#define GGML_CACHE_ALIGN __declspec(align(GGML_CACHE_LINE))
+
typedef volatile LONG atomic_int;
typedef atomic_int atomic_bool;
typedef atomic_int atomic_flag;
@@ -113,6 +132,9 @@ static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) {
static void atomic_flag_clear(atomic_flag * ptr) {
InterlockedExchange(ptr, 0);
}
+static void atomic_thread_fence(memory_order mo) {
+ MemoryBarrier();
+}
#else // clang
#include <stdatomic.h>
#endif
@@ -288,7 +310,6 @@ void ggml_abort(const char * file, int line, const char * fmt, ...) {
#define GGML_DEBUG 0
#define GGML_GELU_FP16
#define GGML_GELU_QUICK_FP16
-#define GGML_N_TASKS_MAX (-1)
#define GGML_SOFT_MAX_UNROLL 4
#define GGML_VEC_DOT_UNROLL 2
@@ -431,6 +452,15 @@ static ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
// precomputed f32 table for f16 (256 KB) (ggml-impl.h)
float ggml_table_f32_f16[1 << 16];
+#if defined(__ARM_ARCH)
+struct ggml_arm_arch_features_type {
+ int has_neon;
+ int has_i8mm;
+ int has_sve;
+ int sve_cnt;
+} ggml_arm_arch_features = {-1, -1, -1, 0};
+#endif
+
GGML_CALL const char * ggml_status_to_string(enum ggml_status status) {
switch (status) {
case GGML_STATUS_ALLOC_FAILED: return "GGML status: error (failed to allocate memory)";
@@ -2006,17 +2036,18 @@ struct ggml_threadpool {
// synchronization primitives
atomic_int n_graph; // incremented when there is work to be done (i.e each graph)
- atomic_int n_barrier;
- atomic_int n_barrier_passed;
+ atomic_int GGML_CACHE_ALIGN n_barrier;
+ atomic_int GGML_CACHE_ALIGN n_barrier_passed;
atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
// these are atomic as an annotation for thread-sanitizer
atomic_bool stop; // Used for stopping the threadpool altogether
atomic_bool pause; // Used for pausing the threadpool or individual threads
+ atomic_bool abort; // Used for aborting processing of a graph
struct ggml_compute_state * workers; // per thread state
int n_threads_max; // number of threads in the pool
- int n_threads_cur; // number of threads used in the current graph
+ atomic_int n_threads_cur; // number of threads used in the current graph
int32_t prio; // Scheduling priority
uint32_t poll; // Polling level (0 - no polling)
@@ -2996,9 +3027,10 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS",
"CROSS_ENTROPY_LOSS_BACK",
+ "OPT_STEP_ADAMW",
};
-static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
+static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -3089,9 +3121,10 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss(x,y)",
"cross_entropy_loss_back(x,y)",
+ "adamw(x)",
};
-static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
+static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -3178,41 +3211,43 @@ inline static void ggml_critical_section_start(void) {
}
}
+static void ggml_barrier(struct ggml_threadpool * tp) {
+ int n_threads = atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed);
+ if (n_threads == 1) {
+ return;
+ }
+
#ifdef GGML_USE_OPENMP
-static void ggml_barrier(struct ggml_threadpool * threadpool) {
- if (threadpool->n_threads_cur == 1) {
- return;
- }
-
#pragma omp barrier
-}
#else
-static void ggml_barrier(struct ggml_threadpool * threadpool) {
- if (threadpool->n_threads_cur == 1) {
+ int n_passed = atomic_load_explicit(&tp->n_barrier_passed, memory_order_relaxed);
+
+ // enter barrier (full seq-cst fence)
+ int n_barrier = atomic_fetch_add_explicit(&tp->n_barrier, 1, memory_order_seq_cst);
+
+ if (n_barrier == (n_threads - 1)) {
+ // last thread
+ atomic_store_explicit(&tp->n_barrier, 0, memory_order_relaxed);
+
+        // exit barrier (full seq-cst fence)
+ atomic_fetch_add_explicit(&tp->n_barrier_passed, 1, memory_order_seq_cst);
return;
}
- atomic_int * n_barrier = &threadpool->n_barrier;
- atomic_int * n_barrier_passed = &threadpool->n_barrier_passed;
-
- int n_threads = threadpool->n_threads_cur;
- int passed_old = atomic_load_explicit(n_barrier_passed, memory_order_relaxed);
-
- if (atomic_fetch_add(n_barrier, 1) == n_threads - 1) {
- // last thread
- atomic_store(n_barrier, 0);
- atomic_fetch_add_explicit(n_barrier_passed, 1, memory_order_relaxed);
- } else {
- // wait for other threads
- while (true) {
- if (atomic_load_explicit(n_barrier_passed, memory_order_relaxed) != passed_old) {
- return;
- }
- ggml_thread_cpu_relax();
- }
+ // wait for other threads
+ while (atomic_load_explicit(&tp->n_barrier_passed, memory_order_relaxed) == n_passed) {
+ ggml_thread_cpu_relax();
}
-}
+
+ // exit barrier (full seq-cst fence)
+ // TSAN doesn't support standalone fence yet, we use a dummy read-modify-write instead
+ #ifdef GGML_TSAN_ENABLED
+ atomic_fetch_add_explicit(&tp->n_barrier_passed, 0, memory_order_seq_cst);
+ #else
+ atomic_thread_fence(memory_order_seq_cst);
+ #endif
#endif
+}
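
The rewrite turns ggml_barrier into a generation-counting barrier: n_barrier counts arrivals in the current round and n_barrier_passed counts completed rounds. The last thread to arrive resets the arrival counter and bumps the generation with a seq-cst read-modify-write; the other threads spin with relaxed loads until the generation changes and then issue a seq-cst fence, so writes made before the barrier on any thread are visible after it on every thread. Under TSAN the standalone fence is replaced by atomic_fetch_add_explicit(&tp->n_barrier_passed, 0, memory_order_seq_cst), a no-op read-modify-write with the same ordering effect that the sanitizer can model.
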
// TODO: make this somehow automatically executed
// some sort of "sentry" mechanism
@@ -3644,6 +3679,70 @@ static inline int ggml_up(int n, int m) {
////////////////////////////////////////////////////////////////////////////////
+#if defined(__ARM_ARCH)
+
+#if defined(__linux__) && defined(__aarch64__)
+#include <sys/auxv.h>
+#elif defined(__APPLE__)
+#include <sys/sysctl.h>
+#endif
+
+#if !defined(HWCAP2_I8MM)
+#define HWCAP2_I8MM 0
+#endif
+
+static void ggml_init_arm_arch_features(void) {
+#if defined(__linux__) && defined(__aarch64__)
+ uint32_t hwcap = getauxval(AT_HWCAP);
+ uint32_t hwcap2 = getauxval(AT_HWCAP2);
+
+ ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD);
+ ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM);
+ ggml_arm_arch_features.has_sve = !!(hwcap & HWCAP_SVE);
+
+#if defined(__ARM_FEATURE_SVE)
+ ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
+#endif
+#elif defined(__APPLE__)
+ int oldp = 0;
+ size_t size = sizeof(oldp);
+ if (sysctlbyname("hw.optional.AdvSIMD", &oldp, &size, NULL, 0) != 0) {
+ oldp = 0;
+ }
+ ggml_arm_arch_features.has_neon = oldp;
+
+ if (sysctlbyname("hw.optional.arm.FEAT_I8MM", &oldp, &size, NULL, 0) != 0) {
+ oldp = 0;
+ }
+ ggml_arm_arch_features.has_i8mm = oldp;
+
+ ggml_arm_arch_features.has_sve = 0;
+ ggml_arm_arch_features.sve_cnt = 0;
+#else
+// Run-time CPU feature detection is not implemented for this platform; fall back to compile-time feature macros
+#if defined(__ARM_NEON)
+ ggml_arm_arch_features.has_neon = 1;
+#else
+ ggml_arm_arch_features.has_neon = 0;
+#endif
+
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+ ggml_arm_arch_features.has_i8mm = 1;
+#else
+ ggml_arm_arch_features.has_i8mm = 0;
+#endif
+
+#if defined(__ARM_FEATURE_SVE)
+ ggml_arm_arch_features.has_sve = 1;
+ ggml_arm_arch_features.sve_cnt = 16;
+#else
+ ggml_arm_arch_features.has_sve = 0;
+ ggml_arm_arch_features.sve_cnt = 0;
+#endif
+#endif
+}
+#endif
+
struct ggml_context * ggml_init(struct ggml_init_params params) {
// make this function thread safe
ggml_critical_section_start();
@@ -3694,6 +3793,10 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
}
+#if defined(__ARM_ARCH)
+ ggml_init_arm_arch_features();
+#endif
+
is_first_call = false;
}
@@ -3742,12 +3845,6 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
GGML_ASSERT_ALIGNED(ctx->mem_buffer);
-#if defined(__ARM_FEATURE_SVE)
- if (!ggml_sve_cnt_b) {
- ggml_sve_cnt_b = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
- }
-#endif
-
GGML_PRINT_DEBUG("%s: context initialized\n", __func__);
ggml_critical_section_end();
@@ -4098,7 +4195,11 @@ static void ggml_set_op_params_f32(struct ggml_tensor * tensor, uint32_t i, floa
}
struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
- memset(tensor->data, 0, ggml_nbytes(tensor));
+ if (tensor->buffer) {
+ ggml_backend_tensor_memset(tensor, 0, 0, ggml_nbytes(tensor));
+ } else {
+ memset(tensor->data, 0, ggml_nbytes(tensor));
+ }
return tensor;
}
@@ -8324,11 +8425,46 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
return result;
}
+// opt_step_adamw
+
+struct ggml_tensor * ggml_opt_step_adamw(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ float alpha,
+ float beta1,
+ float beta2,
+ float eps,
+ float wd) {
+ GGML_ASSERT(a->grad);
+ GGML_ASSERT(alpha > 0.0f);
+ GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f);
+ GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f);
+ GGML_ASSERT(eps >= 0.0f);
+ GGML_ASSERT(wd >= 0.0f && wd <= 1.0f);
+
+ struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+ result->op = GGML_OP_OPT_STEP_ADAMW;
+ result->grad = NULL;
+ result->src[0] = a;
+ result->src[1] = a->grad;
+ result->src[2] = ggml_dup_tensor(ctx, a->grad);
+ result->src[3] = ggml_dup_tensor(ctx, a->grad);
+
+ const int64_t iter = 1;
+ memcpy(&result->op_params[0], &iter, sizeof(int64_t));
+ ggml_set_op_params_f32(result, 2, alpha);
+ ggml_set_op_params_f32(result, 3, beta1);
+ ggml_set_op_params_f32(result, 4, beta2);
+ ggml_set_op_params_f32(result, 5, eps);
+ ggml_set_op_params_f32(result, 6, wd);
+
+ return result;
+}
+
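In a training graph the new op is appended once per parameter after the backward pass has produced that parameter's gradient (the parameter must have been registered with ggml_set_param so a->grad exists). A hypothetical wiring sketch built from the APIs touched by this patch; the tensor names and hyperparameters are made up:

    // Hypothetical single training step (sketch, not patch code).
    struct ggml_tensor * loss = ggml_cross_entropy_loss(ctx, logits, labels);
    ggml_set_loss(loss);                                   // mark the scalar to differentiate
    ggml_build_forward_expand(gf, loss);

    ggml_graph_cpy(gf, gb);
    ggml_build_backward_expand(ctx, gf, gb, /*accumulate =*/ false, /*keep =*/ true);

    // AdamW step for one trainable parameter, consuming its gradient in-place:
    struct ggml_tensor * step = ggml_opt_step_adamw(ctx, weights,
            /*alpha =*/ 1e-3f, /*beta1 =*/ 0.9f, /*beta2 =*/ 0.999f,
            /*eps   =*/ 1e-8f, /*wd    =*/ 1e-2f);
    ggml_build_forward_expand(gb, step);
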
////////////////////////////////////////////////////////////////////////////////
-void ggml_set_param(
- struct ggml_context * ctx,
- struct ggml_tensor * tensor) {
+void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor) {
tensor->flags |= GGML_TENSOR_FLAG_PARAM;
GGML_ASSERT(tensor->grad == NULL);
@@ -8336,6 +8472,13 @@ void ggml_set_param(
ggml_format_name(tensor->grad, "%s (grad)", tensor->name);
}
+void ggml_set_loss(struct ggml_tensor * tensor) {
+ GGML_ASSERT(ggml_is_scalar(tensor));
+ GGML_ASSERT(tensor->type == GGML_TYPE_F32);
+ GGML_ASSERT(tensor->grad);
+ tensor->flags |= GGML_TENSOR_FLAG_LOSS;
+}
+
// ggml_compute_forward_dup
static void ggml_compute_forward_dup_same_cont(
@@ -17410,7 +17553,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
const int64_t ir0 = dr*ith;
const int64_t ir1 = MIN(ir0 + dr, nr);
- float * d = (float *) opt0->data;
+ const float d_by_nr = ((const float *) opt0->data)[0] / (float) nr;
for (int64_t i1 = ir0; i1 < ir1; i1++) {
float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
@@ -17434,7 +17577,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
// grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
ggml_vec_sub_f32(nc, ds0, ds0, s1);
- ggml_vec_scale_f32(nc, ds0, d[0] / (float) nr);
+ ggml_vec_scale_f32(nc, ds0, d_by_nr);
#ifndef NDEBUG
for (int i = 0; i < nc; ++i) {
@@ -17463,6 +17606,94 @@ static void ggml_compute_forward_cross_entropy_loss_back(
}
}
+static void ggml_compute_forward_opt_step_adamw_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src0_grad = dst->src[1];
+ const struct ggml_tensor * src0_grad_m = dst->src[2];
+ const struct ggml_tensor * src0_grad_v = dst->src[3];
+ GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nr = ggml_nrows(src0);
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+ GGML_ASSERT(nb00 == sizeof(float));
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ /* const float gnorm = 1.0f; */
+ int64_t iter; memcpy(&iter, &dst->op_params[0], sizeof(int64_t));
+ const float alpha = ggml_get_op_params_f32(dst, 2);
+ const float beta1 = ggml_get_op_params_f32(dst, 3);
+ const float beta2 = ggml_get_op_params_f32(dst, 4);
+ const float eps = ggml_get_op_params_f32(dst, 5);
+ const float wd = ggml_get_op_params_f32(dst, 6);
+
+ const float beta1h = alpha/(1.0f - powf(beta1, iter));
+ const float beta2h = 1.0f/(1.0f - powf(beta2, iter));
+
+ for (int ir = ir0; ir < ir1; ++ir) {
+ const int64_t i03 = ir/(ne02*ne01);
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+ const size_t offset = i03*nb03 + i02*nb02 + i01*nb01;
+
+ float * w = (float *) ((char *) src0->data + offset); // weight
+ const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad
+ float * m = (float *) ((char *) src0_grad_m->data + offset);
+ float * v = (float *) ((char *) src0_grad_v->data + offset);
+
+ for (int i00 = 0; i00 < ne00; ++i00) {
+ m[i00] = m[i00]*beta1 + g[i00]*(1.0f - beta1);
+ v[i00] = v[i00]*beta2 + g[i00]*g[i00]*(1.0f - beta2);
+
+ const float mh = m[i00]*beta1h;
+ const float vh = sqrtf(v[i00]*beta2h) + eps;
+
+ // The weight decay is applied independently of the Adam momenta m and v.
+ // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
+ // See: https://arxiv.org/pdf/1711.05101v3.pdf
+ w[i00] = w[i00]*(1.0f - alpha*wd) - mh/vh;
+ }
+ }
+
+ ggml_barrier(params->threadpool);
+ if (ith != 0) {
+ return;
+ }
+
+ iter++;
+ memcpy(&dst->op_params[0], &iter, sizeof(int64_t));
+}
+
+static void ggml_compute_forward_opt_step_adamw(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_opt_step_adamw_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+}
/////////////////////////////////
static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
@@ -17808,6 +18039,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
ggml_compute_forward_cross_entropy_loss_back(params, tensor);
}
break;
+ case GGML_OP_OPT_STEP_ADAMW:
+ {
+ ggml_compute_forward_opt_step_adamw(params, tensor);
+ }
+ break;
case GGML_OP_NONE:
{
// nop
@@ -17962,7 +18198,7 @@ void ggml_build_backward_gradient_checkpointing(
struct ggml_tensor * * checkpoints,
int n_checkpoints) {
ggml_graph_cpy(gf, gb_tmp);
- ggml_build_backward_expand(ctx, gf, gb_tmp, true);
+ ggml_build_backward_expand(ctx, gf, gb_tmp, false, true);
if (n_checkpoints <= 0) {
ggml_graph_cpy(gb_tmp, gb);
@@ -18000,42 +18236,93 @@ void ggml_build_backward_gradient_checkpointing(
ggml_hash_map_free(replacements);
}
-// functions to change gradients considering the case that input a might be initial gradient with zero value
+// utility functions to change gradients
+// if a is in acc_table, modify gradients in-place and mark result as gradient accumulator
+// else if a is in zero_table, replace a
+// else, just add/subtract/etc. the gradients
-static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set * zero_table) {
+static struct ggml_tensor * ggml_add_or_set(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ struct ggml_hash_set * zero_table,
+ struct ggml_hash_set * acc_table) {
+ if (ggml_hash_contains(acc_table, a)) {
+ struct ggml_tensor * ret = ggml_add_impl(ctx, a, b, true);
+ const size_t insert_result = ggml_hash_insert(acc_table, ret);
+ GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
+ GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
+ return ret;
+ }
if (ggml_hash_contains(zero_table, a)) {
return b;
- } else {
- return ggml_add_impl(ctx, a, b, false);
}
+ return ggml_add_impl(ctx, a, b, false);
}
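// Editor's note (hedged): the three cases above differ in how the result is
// materialized, not in the arithmetic. If `a` is a gradient accumulator
// (acc_table), the add happens in-place so the same buffer keeps accumulating
// across ubatches; if `a` is known to still be all zeros (zero_table), the add
// is skipped entirely and `b` is returned; otherwise a regular new add node is
// created. ggml_acc_or_set, ggml_add1_or_set, and ggml_sub_or_set below follow
// the same three-way pattern.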
-static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, struct ggml_hash_set * zero_table) {
+static struct ggml_tensor * ggml_acc_or_set(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ const size_t nb1,
+ const size_t nb2,
+ const size_t nb3,
+ const size_t offset,
+ struct ggml_hash_set * zero_table,
+ struct ggml_hash_set * acc_table) {
+ if (ggml_hash_contains(acc_table, a)) {
+ struct ggml_tensor * ret = ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true);
+ const size_t insert_result = ggml_hash_insert(acc_table, ret);
+ GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
+ GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
+ return ret;
+ }
if (ggml_hash_contains(zero_table, a)) {
- struct ggml_tensor * a_zero = ggml_scale(ctx, a, 0.0f);
+ struct ggml_tensor * a_zero = ggml_scale(ctx, a, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN
return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
- } else {
- return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
}
+ return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
}
-static struct ggml_tensor * ggml_add1_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set * zero_table) {
+static struct ggml_tensor * ggml_add1_or_set(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ struct ggml_hash_set * zero_table,
+ struct ggml_hash_set * acc_table) {
+ if (ggml_hash_contains(acc_table, a)) {
+ struct ggml_tensor * ret = ggml_add1_impl(ctx, a, b, true);
+ const size_t insert_result = ggml_hash_insert(acc_table, ret);
+ GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
+ GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
+ return ret;
+ }
if (ggml_hash_contains(zero_table, a)) {
return ggml_repeat(ctx, b, a);
- } else {
- return ggml_add1_impl(ctx, a, b, false);
}
+ return ggml_add1_impl(ctx, a, b, false);
}
-static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set * zero_table) {
+static struct ggml_tensor * ggml_sub_or_set(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ struct ggml_hash_set * zero_table,
+ struct ggml_hash_set * acc_table) {
+ if (ggml_hash_contains(acc_table, a)) {
+ struct ggml_tensor * ret = ggml_sub_impl(ctx, a, b, true);
+ const size_t insert_result = ggml_hash_insert(acc_table, ret);
+ GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
+ GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
+ return ret;
+ }
if (ggml_hash_contains(zero_table, a)) {
return ggml_neg(ctx, b);
- } else {
- return ggml_sub_impl(ctx, a, b, false);
}
+ return ggml_sub_impl(ctx, a, b, false);
}
-static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set * zero_table) {
+static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set * zero_table, struct ggml_hash_set * acc_table) {
struct ggml_tensor * src0 = tensor->src[0];
struct ggml_tensor * src1 = tensor->src[1];
struct ggml_tensor * src2 = tensor->src[2];
@@ -18044,38 +18331,38 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
case GGML_OP_DUP:
{
if (src0->grad) {
- src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
}
} break;
case GGML_OP_ADD:
{
if (src0->grad) {
- src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
}
if (src1->grad) {
if (ggml_are_same_shape(src0, src1)) {
- src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table);
+ src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table, acc_table);
} else {
- src1->grad = ggml_add_or_set(ctx, src1->grad, ggml_repeat_back(ctx, tensor->grad, src1), zero_table);
+ src1->grad = ggml_add_or_set(ctx, src1->grad, ggml_repeat_back(ctx, tensor->grad, src1), zero_table, acc_table);
}
}
} break;
case GGML_OP_ADD1:
{
if (src0->grad) {
- src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
}
if (src1->grad) {
src1->grad = ggml_add_or_set(ctx,
src1->grad,
ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean
- zero_table);
+ zero_table, acc_table);
}
} break;
case GGML_OP_ACC:
{
if (src0->grad) {
- src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
}
if (src1->grad) {
const size_t nb1 = ((int32_t *) tensor->op_params)[0];
@@ -18097,16 +18384,16 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
ggml_reshape(ctx,
ggml_cont(ctx, tensor_grad_view),
src1->grad),
- zero_table);
+ zero_table, acc_table);
}
} break;
case GGML_OP_SUB:
{
if (src0->grad) {
- src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
}
if (src1->grad) {
- src1->grad = ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table);
+ src1->grad = ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table, acc_table);
}
} break;
case GGML_OP_MUL:
@@ -18116,14 +18403,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
ggml_add_or_set(ctx,
src0->grad,
ggml_mul(ctx, src1, tensor->grad),
- zero_table);
+ zero_table, acc_table);
}
if (src1->grad) {
src1->grad =
ggml_add_or_set(ctx,
src1->grad,
ggml_mul(ctx, src0, tensor->grad),
- zero_table);
+ zero_table, acc_table);
}
} break;
case GGML_OP_DIV:
@@ -18133,7 +18420,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
ggml_add_or_set(ctx,
src0->grad,
ggml_div(ctx, tensor->grad, src1),
- zero_table);
+ zero_table, acc_table);
}
if (src1->grad) {
src1->grad =
@@ -18142,7 +18429,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
ggml_mul(ctx,
tensor->grad,
ggml_div(ctx, tensor, src1)),
- zero_table);
+ zero_table, acc_table);
}
} break;
case GGML_OP_SQR:
@@ -18154,7 +18441,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
ggml_scale(ctx,
ggml_mul(ctx, src0, tensor->grad),
2.0f),
- zero_table);
+ zero_table, acc_table);
}
} break;
case GGML_OP_SQRT:
@@ -18168,7 +18455,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
tensor->grad,
tensor),
0.5f),
- zero_table);
+ zero_table, acc_table);
}
} break;
case GGML_OP_LOG:
@@ -18180,7 +18467,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
ggml_div(ctx,
tensor->grad,
src0),
- zero_table);
+ zero_table, acc_table);
}
} break;
case GGML_OP_SIN:
@@ -18192,7 +18479,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
ggml_mul(ctx,
tensor->grad,
ggml_cos(ctx, src0)),
- zero_table);
+ zero_table, acc_table);
}
} break;
case GGML_OP_COS:
@@ -18204,7 +18491,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
ggml_mul(ctx,
tensor->grad,
ggml_sin(ctx, src0)),
- zero_table);
+ zero_table, acc_table);
}
} break;
case GGML_OP_SUM:
@@ -18214,7 +18501,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
ggml_add1_or_set(ctx,
src0->grad,
tensor->grad,
- zero_table);
+ zero_table, acc_table);
}
} break;
case GGML_OP_SUM_ROWS:
@@ -18226,7 +18513,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
ggml_repeat(ctx,
tensor->grad,
src0->grad),
- zero_table);
+ zero_table, acc_table);
}
} break;
case GGML_OP_MEAN:
@@ -18241,7 +18528,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
src0->grad = ggml_add_or_set(ctx,
src0->grad,
ggml_repeat_back(ctx, tensor->grad, src0->grad),
- zero_table);
+ zero_table, acc_table);
}
} break;
case GGML_OP_REPEAT_BACK:
@@ -18251,7 +18538,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
src0->grad = ggml_add_or_set(ctx,
src0->grad,
ggml_repeat(ctx, tensor->grad, src0->grad),
- zero_table);
+ zero_table, acc_table);
}
} break;
case GGML_OP_CONCAT:
@@ -18276,7 +18563,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
src0->grad = ggml_add_or_set(ctx,
src0->grad,
ggml_rms_norm_back(ctx, src0, tensor->grad, eps),
- zero_table);
+ zero_table, acc_table);
}
} break;
case GGML_OP_RMS_NORM_BACK:
@@ -18324,7 +18611,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
ggml_add_or_set(ctx,
src0->grad, // [n,m,q1,r1]
s1_tg, // [n,m,q1,r1]
- zero_table);
+ zero_table, acc_table);
}
if (src1->grad) {
src1->grad =
@@ -18342,7 +18629,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
src0, // [n,m,q1,r1]
ggml_transpose(ctx, // [p,m,qq,rr]
tensor->grad)), // [m,p,qq,rr]
- zero_table);
+ zero_table, acc_table);
}
} break;
case GGML_OP_MUL_MAT_ID:
@@ -18364,7 +18651,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
ggml_add_or_set(ctx,
src0->grad,
ggml_scale_impl(ctx, tensor->grad, s, false),
- zero_table);
+ zero_table, acc_table);
}
} break;
case GGML_OP_SET:
@@ -18393,7 +18680,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
tensor->grad,
ggml_neg(ctx, tensor_grad_view),
nb1, nb2, nb3, offset, false),
- zero_table);
+ zero_table, acc_table);
}
if (src1->grad) {
@@ -18403,7 +18690,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
ggml_reshape(ctx,
ggml_cont(ctx, tensor_grad_view),
src1->grad),
- zero_table);
+ zero_table, acc_table);
}
} break;
case GGML_OP_CPY:
@@ -18414,7 +18701,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
// tensor = src0 * 1 + src1 * 0
if (src0->grad) {
// dsrc0 = dtensor * 1
- src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
}
if (src1->grad) {
// dsrc1 = dtensor * 0 -> noop
@@ -18426,7 +18713,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
if (src0->grad) {
GGML_ASSERT(ggml_is_contiguous(src0->grad));
GGML_ASSERT(ggml_is_contiguous(tensor->grad));
- src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
}
} break;
case GGML_OP_RESHAPE:
@@ -18440,7 +18727,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
? tensor->grad
: ggml_cont(ctx, tensor->grad),
src0->grad),
- zero_table);
+ zero_table, acc_table);
}
} break;
case GGML_OP_VIEW:
@@ -18469,7 +18756,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
nb3 = (nb3 / n0) * ng;
}
- src0->grad = ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table);
+ src0->grad = ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table, acc_table);
}
} break;
case GGML_OP_PERMUTE:
@@ -18494,7 +18781,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
axes_backward[1],
axes_backward[2],
axes_backward[3]),
- zero_table);
+ zero_table, acc_table);
}
} break;
case GGML_OP_TRANSPOSE:
@@ -18504,7 +18791,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
src0->grad =
ggml_add_or_set(ctx, src0->grad,
ggml_transpose(ctx, tensor->grad),
- zero_table);
+ zero_table, acc_table);
}
} break;
case GGML_OP_GET_ROWS:
@@ -18516,7 +18803,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
// last ggml_get_rows_back argument src0->grad is only
// necessary to setup correct output shape
ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad),
- zero_table);
+ zero_table, acc_table);
}
if (src1->grad) {
// noop
@@ -18540,7 +18827,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
/* ggml_diag_mask_inf_impl() shouldn't be here */
/* ref: https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */
ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
- zero_table);
+ zero_table, acc_table);
}
} break;
case GGML_OP_DIAG_MASK_ZERO:
@@ -18551,7 +18838,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
src0->grad =
ggml_add_or_set(ctx, src0->grad,
ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
- zero_table);
+ zero_table, acc_table);
}
} break;
case GGML_OP_SOFT_MAX:
@@ -18561,7 +18848,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
src0->grad =
ggml_add_or_set(ctx, src0->grad,
ggml_soft_max_back(ctx, tensor->grad, tensor),
- zero_table);
+ zero_table, acc_table);
}
} break;
@@ -18602,7 +18889,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
attn_factor,
beta_fast,
beta_slow),
- zero_table);
+ zero_table, acc_table);
}
} break;
case GGML_OP_ROPE_BACK:
@@ -18638,7 +18925,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
beta_fast,
beta_slow,
false),
- zero_table);
+ zero_table, acc_table);
}
} break;
case GGML_OP_CLAMP:
@@ -18663,7 +18950,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
src1->grad = ggml_add_or_set(ctx,
src1->grad,
ggml_im2col_back(ctx, src0, tensor->grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D),
- zero_table);
+ zero_table, acc_table);
}
} break;
case GGML_OP_IM2COL_BACK:
@@ -18692,7 +18979,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
src0->grad = ggml_add_or_set(ctx,
src0->grad,
ggml_pool_2d_back(ctx, tensor->grad, src0, op, k0, k1, s0, s1, p0, p1),
- zero_table);
+ zero_table, acc_table);
}
} break;
case GGML_OP_POOL_2D_BACK:
@@ -18757,7 +19044,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
src0->grad = ggml_add_or_set(ctx,
src0->grad,
grad_q,
- zero_table);
+ zero_table, acc_table);
}
if (src1->grad) {
struct ggml_tensor * view_k = ggml_view_1d(ctx, flash_grad, elem_k, offs_k);
@@ -18765,7 +19052,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
src1->grad = ggml_add_or_set(ctx,
src1->grad,
grad_k,
- zero_table);
+ zero_table, acc_table);
}
if (src2->grad) {
struct ggml_tensor * view_v = ggml_view_1d(ctx, flash_grad, elem_v, offs_v);
@@ -18773,7 +19060,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
src2->grad = ggml_add_or_set(ctx,
src2->grad,
grad_v,
- zero_table);
+ zero_table, acc_table);
}
} break;
case GGML_OP_FLASH_ATTN_BACK:
@@ -18799,7 +19086,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
ggml_mul(ctx,
ggml_sgn(ctx, src0),
tensor->grad),
- zero_table);
+ zero_table, acc_table);
}
} break;
case GGML_UNARY_OP_SGN:
@@ -18811,7 +19098,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
case GGML_UNARY_OP_NEG:
{
if (src0->grad) {
- src0->grad = ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table);
+ src0->grad = ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
}
} break;
case GGML_UNARY_OP_STEP:
@@ -18836,7 +19123,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
ggml_mul(ctx,
ggml_step(ctx, src0),
tensor->grad),
- zero_table);
+ zero_table, acc_table);
}
} break;
case GGML_UNARY_OP_SIGMOID:
@@ -18858,7 +19145,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
src0->grad = ggml_add_or_set(ctx,
src0->grad,
ggml_silu_back(ctx, src0, tensor->grad),
- zero_table);
+ zero_table, acc_table);
}
} break;
case GGML_UNARY_OP_EXP:
@@ -18867,7 +19154,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
src0->grad = ggml_add_or_set(ctx,
src0->grad,
ggml_mul(ctx, tensor, tensor->grad),
- zero_table);
+ zero_table, acc_table);
}
} break;
default:
@@ -18897,13 +19184,17 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
src0,
src1,
tensor->grad),
- zero_table);
+ zero_table, acc_table);
}
} break;
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
{
GGML_ABORT("fatal error"); // not supported
}
+ case GGML_OP_OPT_STEP_ADAMW:
+ {
+ GGML_ABORT("fatal error"); // not supported
+ }
case GGML_OP_NONE:
{
// nop
@@ -18993,7 +19284,7 @@ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor *
ggml_build_forward_impl(cgraph, tensor, true);
}
-void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep) {
+void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate, bool keep) {
GGML_ASSERT(gf->n_nodes > 0);
GGML_ASSERT(gf->grads);
@@ -19009,21 +19300,35 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
}
}
- // remember original gradients which start with zero values
+ // keep tables of original gradients for replacement/accumulation logic
struct ggml_hash_set zero_table = ggml_hash_set_new(gf->size);
+ struct ggml_hash_set acc_table = ggml_hash_set_new(gf->size);
for (int i = 0; i < gf->n_nodes; i++) {
- if (gf->grads[i]) {
- ggml_hash_insert(&zero_table, gf->grads[i]);
+ struct ggml_tensor * node = gf->nodes[i];
+
+ if (node->grad) {
+ {
+ const size_t insert_result = ggml_hash_insert(&zero_table, node->grad);
+ GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
+ GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
+ }
+
+ // only gradients of trainable parameters should be accumulated
+ if (accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) {
+ const size_t insert_result = ggml_hash_insert(&acc_table, node->grad);
+ GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
+ GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
+ }
}
}
for (int i = gf->n_nodes - 1; i >= 0; i--) {
struct ggml_tensor * node = gf->nodes[i];
- // inplace operations to add gradients are not created by ggml_compute_backward
+ // inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation
// use allocator to automatically make inplace operations
if (node->grad) {
- ggml_compute_backward(ctx, node, &zero_table);
+ ggml_compute_backward(ctx, node, &zero_table, &acc_table);
}
}
@@ -19037,8 +19342,30 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
}
ggml_hash_set_free(&zero_table);
+ ggml_hash_set_free(&acc_table);
}
+void ggml_build_opt_adamw(
+ struct ggml_context * ctx,
+ struct ggml_cgraph * gf,
+ struct ggml_cgraph * gb,
+ float alpha,
+ float beta1,
+ float beta2,
+ float eps,
+ float wd) {
+ for (int i = 0; i < gf->n_nodes; i++) {
+ struct ggml_tensor * node = gf->nodes[i];
+
+ if (node->flags & GGML_TENSOR_FLAG_PARAM) {
+ GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
+ struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx, node, alpha, beta1, beta2, eps, wd);
+ ggml_build_forward_expand(gb, opt_step);
+ }
+ }
+}
+
+
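// Editor's sketch (hedged, not part of the patch): the intended composition of
// the new entry points, assuming `loss` is a scalar F32 tensor carrying
// GGML_TENSOR_FLAG_LOSS and all trainable tensors were registered with
// ggml_set_param(ctx, t):
//
//     struct ggml_cgraph * gf = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true);
//     ggml_build_forward_expand(gf, loss);
//
//     struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf);
//     ggml_build_backward_expand(ctx, gf, gb, /*accumulate =*/ false, /*keep =*/ true);
//     ggml_build_opt_adamw(ctx, gf, gb, /*alpha =*/ 1e-3f, /*beta1 =*/ 0.9f, /*beta2 =*/ 0.999f, /*eps =*/ 1e-8f, /*wd =*/ 0.0f);
//
//     ggml_graph_reset(gb); // seeds dloss/dloss = 1 and clears the AdamW state (see below)
//     // per batch: set inputs, then ggml_graph_compute(gb, &cplan);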
static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
void * ptr = *p;
ptr = (void *) GGML_PAD((uintptr_t) ptr, align);
@@ -19166,10 +19493,28 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) {
GGML_ASSERT(cgraph->grads != NULL);
for (int i = 0; i < cgraph->n_nodes; i++) {
- struct ggml_tensor * grad = cgraph->grads[i];
+ struct ggml_tensor * node = cgraph->nodes[i];
- if (grad) {
- ggml_set_zero(grad);
+ // initial gradients of loss should be 1, 0 otherwise
+ if (node->grad) {
+ if (node->flags & GGML_TENSOR_FLAG_LOSS) {
+ GGML_ASSERT(node->grad->buffer);
+ GGML_ASSERT(node->type == GGML_TYPE_F32);
+ GGML_ASSERT(ggml_is_scalar(node));
+
+ const float onef = 1.0f;
+ ggml_backend_tensor_set(node->grad, &onef, 0, ggml_nbytes(node->grad));
+ } else {
+ ggml_set_zero(node->grad);
+ }
+ }
+
+ GGML_ASSERT(node);
+ if (node->op == GGML_OP_OPT_STEP_ADAMW) {
+ // set iteration to 1 and clear momenta
+ ggml_set_op_params_i32(node, 0, 1);
+ ggml_set_zero(node->src[2]);
+ ggml_set_zero(node->src[3]);
}
}
}
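// Editor's note (hedged): seeding the loss gradient with 1.0f is the chain-rule
// base case d(loss)/d(loss) = 1; all other gradients start at 0 and are built
// up by the backward pass. Resetting the AdamW iteration to 1 (not 0) matches
// the bias corrections beta1h/beta2h in ggml_compute_forward_opt_step_adamw_f32,
// which divide by (1 - beta^t) and would divide by zero at t = 0.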
@@ -19462,6 +19807,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
} break;
case GGML_OP_CROSS_ENTROPY_LOSS:
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+ case GGML_OP_OPT_STEP_ADAMW:
{
n_tasks = n_threads;
} break;
@@ -19757,8 +20103,8 @@ void ggml_threadpool_resume(struct ggml_threadpool * threadpool) {
struct ggml_cplan ggml_graph_plan(
const struct ggml_cgraph * cgraph,
- int n_threads,
- struct ggml_threadpool * threadpool) {
+ int n_threads,
+ struct ggml_threadpool * threadpool) {
if (threadpool == NULL) {
GGML_PRINT_DEBUG("Threadpool is not specified. Will create a disposable threadpool : n_threads %d\n", n_threads);
@@ -19933,34 +20279,33 @@ struct ggml_cplan ggml_graph_plan(
static thread_ret_t ggml_graph_compute_thread(void * data) {
struct ggml_compute_state * state = (struct ggml_compute_state *) data;
+ struct ggml_threadpool * tp = state->threadpool;
- const struct ggml_cgraph * cgraph = state->threadpool->cgraph;
- const struct ggml_cplan * cplan = state->threadpool->cplan;
+ const struct ggml_cgraph * cgraph = tp->cgraph;
+ const struct ggml_cplan * cplan = tp->cplan;
set_numa_thread_affinity(state->ith);
struct ggml_compute_params params = {
/*.ith =*/ state->ith,
- /*.nth =*/ state->threadpool->n_threads_cur,
+ /*.nth =*/ atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed),
/*.wsize =*/ cplan->work_size,
/*.wdata =*/ cplan->work_data,
- /*.threadpool=*/ state->threadpool,
+ /*.threadpool=*/ tp,
};
- for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {
+ for (int node_n = 0; node_n < cgraph->n_nodes && !tp->abort; node_n++) {
struct ggml_tensor * node = cgraph->nodes[node_n];
ggml_compute_forward(¶ms, node);
- if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
- state->threadpool->ec = GGML_STATUS_ABORTED;
+ if (state->ith == 0 && cplan->abort_callback &&
+ cplan->abort_callback(cplan->abort_callback_data)) {
+ tp->abort = true;
+ tp->ec = GGML_STATUS_ABORTED;
}
ggml_barrier(state->threadpool);
-
- if (state->threadpool->ec != GGML_STATUS_SUCCESS) {
- break;
- }
}
return 0;
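// Editor's note (hedged): abort handling moved from a per-thread break on
// threadpool->ec into the shared tp->abort flag tested in the loop condition.
// Every worker passes the same ggml_barrier() after each node, so all threads
// observe the flag at the same node boundary and leave the graph together
// instead of breaking out at different points.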
@@ -19968,7 +20313,15 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
#ifndef GGML_USE_OPENMP
-static inline bool ggml_graph_compute_ready(struct ggml_compute_state * state) {
+// check if thread is active
+static inline bool ggml_graph_compute_thread_active(struct ggml_compute_state * state) {
+ struct ggml_threadpool * threadpool = state->threadpool;
+ int n_threads = atomic_load_explicit(&threadpool->n_threads_cur, memory_order_relaxed);
+ return (state->ith < n_threads);
+}
+
+// check if thread is ready to proceed (exit from polling or sleeping)
+static inline bool ggml_graph_compute_thread_ready(struct ggml_compute_state * state) {
struct ggml_threadpool * threadpool = state->threadpool;
if (state->pending || threadpool->stop || threadpool->pause) { return true; }
@@ -19976,21 +20329,37 @@ static inline bool ggml_graph_compute_ready(struct ggml_compute_state * state) {
// check for new graph/work
int new_graph = atomic_load_explicit(&threadpool->n_graph, memory_order_relaxed);
if (new_graph != state->last_graph) {
- state->pending = (state->ith < threadpool->n_threads_cur);
+ state->pending = ggml_graph_compute_thread_active(state);
state->last_graph = new_graph;
}
return state->pending;
}
+// sync thread state after polling
+static inline void ggml_graph_compute_thread_sync(struct ggml_compute_state * state) {
+ // TSAN doesn't support standalone fence yet, we use a dummy read-modify-write instead
+ #ifdef GGML_TSAN_ENABLED
+ atomic_fetch_add_explicit(&state->threadpool->n_graph, 0, memory_order_seq_cst);
+ #else
+ atomic_thread_fence(memory_order_seq_cst);
+ #endif
+ UNUSED(state);
+}
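// Editor's note (hedged): ThreadSanitizer does not model standalone
// atomic_thread_fence(), so under TSAN the fence is replaced by an idempotent
// read-modify-write (fetch_add of 0) on n_graph. An RMW on the same atomic that
// the kickoff path increments with memory_order_seq_cst gives the analyzer an
// operation it can order, with the same synchronization effect here.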
+
static inline bool ggml_graph_compute_poll_for_work(struct ggml_compute_state * state) {
struct ggml_threadpool * threadpool = state->threadpool;
+ // Skip polling for unused threads
+ if (!ggml_graph_compute_thread_active(state)) {
+ return state->pending;
+ }
+
// This seems to make 0 ... 100 a decent range for polling level across modern processors.
// Perhaps, we can adjust it dynamically based on load and things.
const uint64_t n_rounds = 1024UL * 128 * threadpool->poll;
- for (uint64_t i=0; !ggml_graph_compute_ready(state) && i < n_rounds; i++) {
+ for (uint64_t i=0; !ggml_graph_compute_thread_ready(state) && i < n_rounds; i++) {
// No new work. Keep polling.
ggml_thread_cpu_relax();
}

return state->pending;
}

static inline bool ggml_graph_compute_check_for_work(struct ggml_compute_state * state) {
struct ggml_threadpool * threadpool = state->threadpool;

if (ggml_graph_compute_poll_for_work(state)) {
+ ggml_graph_compute_thread_sync(state);
return state->pending;
}
ggml_mutex_lock_shared(&threadpool->mutex);
- while (!ggml_graph_compute_ready(state)) {
+ while (!ggml_graph_compute_thread_ready(state)) {
// No new work. Wait for the signal.
- GGML_PRINT_DEBUG("thread #%d waiting for work\n", state->ith);
+ GGML_PRINT_DEBUG("thread #%d waiting for work (sleeping)\n", state->ith);
ggml_cond_wait(&threadpool->cond, &threadpool->mutex);
}
ggml_mutex_unlock_shared(&threadpool->mutex);
@@ -20055,13 +20425,20 @@ static thread_ret_t ggml_graph_compute_secondary_thread(void* data) {
}
// Start processing new graph
-static void ggml_graph_compute_kickoff(struct ggml_threadpool * threadpool)
+static void ggml_graph_compute_kickoff(struct ggml_threadpool * threadpool, int n_threads)
{
- // always take the mutex here because the worker threads are doing hybrid poll/wait
+ // Always take the mutex here because the worker threads are doing hybrid poll/wait
ggml_mutex_lock(&threadpool->mutex);
- atomic_fetch_add_explicit(&threadpool->n_graph, 1, memory_order_relaxed);
+ GGML_PRINT_DEBUG("threadpool: n_threads_cur %d n_threads %d\n", threadpool->n_threads_cur, n_threads);
+
+ // Update the number of active threads
+ atomic_store_explicit(&threadpool->n_threads_cur, n_threads, memory_order_relaxed);
+
+ // Indicate the graph is ready to be processed
+ // We need the full seq-cst fence here because of the polling threads (used in thread_sync)
+ atomic_fetch_add_explicit(&threadpool->n_graph, 1, memory_order_seq_cst);
if (threadpool->pause) {
// Update main thread prio and affinity to match the threadpool settings
@@ -20120,6 +20497,7 @@ static struct ggml_threadpool * ggml_threadpool_new_impl(
threadpool->current_chunk = 0;
threadpool->stop = false;
threadpool->pause = tpp->paused;
+ threadpool->abort = false;
threadpool->workers = NULL;
threadpool->n_threads_max = tpp->n_threads;
threadpool->n_threads_cur = tpp->n_threads;
@@ -20195,15 +20573,11 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
// No worker threads should be accessing the parameters below at this stage
threadpool->cgraph = cgraph;
threadpool->cplan = cplan;
- threadpool->n_threads_cur = n_threads;
threadpool->current_chunk = 0;
+ threadpool->abort = false;
threadpool->ec = GGML_STATUS_SUCCESS;
}
- if (n_threads > threadpool->n_threads_max) {
- GGML_PRINT("WARNING: cplan is requesting more threads than the threadpool contains. Expect a bad time!\n");
- }
-
#ifdef GGML_USE_OPENMP
if (n_threads > 1) {
#pragma omp parallel num_threads(n_threads)
@@ -20212,17 +20586,23 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
{
// update the number of threads from the actual number of threads that we got from OpenMP
n_threads = omp_get_num_threads();
- threadpool->n_threads_cur = n_threads;
+ atomic_store_explicit(&threadpool->n_threads_cur, n_threads, memory_order_relaxed);
}
ggml_graph_compute_thread(&threadpool->workers[omp_get_thread_num()]);
}
} else {
+ atomic_store_explicit(&threadpool->n_threads_cur, 1, memory_order_relaxed);
ggml_graph_compute_thread(&threadpool->workers[0]);
}
#else
+ if (n_threads > threadpool->n_threads_max) {
+ GGML_PRINT("WARNING: cplan requested more threads (%d) than available (%d)\n", n_threads, threadpool->n_threads_max);
+ n_threads = threadpool->n_threads_max;
+ }
+
// Kick all threads to start the new graph
- ggml_graph_compute_kickoff(threadpool);
+ ggml_graph_compute_kickoff(threadpool, n_threads);
// This is a work thread too
ggml_graph_compute_thread(&threadpool->workers[0]);
@@ -21824,7 +22204,7 @@ enum ggml_opt_result ggml_opt_resume(
ggml_build_forward_expand(gf, f);
struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf);
- ggml_build_backward_expand(ctx, gf, gb, true);
+ ggml_build_backward_expand(ctx, gf, gb, false, true);
return ggml_opt_resume_g(ctx, opt, f, gf, gb, NULL, NULL);
}
@@ -23266,16 +23646,16 @@ int ggml_cpu_has_fma(void) {
}
int ggml_cpu_has_neon(void) {
-#if defined(__ARM_NEON)
- return 1;
+#if defined(__ARM_ARCH)
+ return ggml_arm_arch_features.has_neon;
#else
return 0;
#endif
}
int ggml_cpu_has_sve(void) {
-#if defined(__ARM_FEATURE_SVE)
- return 1;
+#if defined(__ARM_ARCH)
+ return ggml_arm_arch_features.has_sve;
#else
return 0;
#endif
@@ -23422,11 +23802,18 @@ int ggml_cpu_has_vsx(void) {
}
int ggml_cpu_has_matmul_int8(void) {
-#if defined(__ARM_FEATURE_MATMUL_INT8)
- return 1;
+#if defined(__ARM_ARCH)
+ return ggml_arm_arch_features.has_i8mm;
#else
return 0;
#endif
}
+int ggml_cpu_get_sve_cnt(void) {
+#if defined(__ARM_ARCH)
+ return ggml_arm_arch_features.sve_cnt;
+#else
+ return 0;
+#endif
+}
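// Editor's sketch (hedged): these getters imply that ggml_arm_arch_features is
// now filled in at runtime rather than fixed at compile time. The field names
// are taken from the getters above; the detection itself might look roughly
// like this on aarch64 Linux (the real initializer in ggml.c may differ, e.g.
// sysctl on Apple platforms):
//
//     #include <sys/auxv.h>
//     #include <asm/hwcap.h>
//     static void ggml_init_arm_arch_features_sketch(void) {
//         const unsigned long hwcap  = getauxval(AT_HWCAP);
//         const unsigned long hwcap2 = getauxval(AT_HWCAP2);
//         ggml_arm_arch_features.has_neon = !!(hwcap  & HWCAP_ASIMD);
//         ggml_arm_arch_features.has_sve  = !!(hwcap  & HWCAP_SVE);
//         ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM);
//     }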
////////////////////////////////////////////////////////////////////////////////
diff --git a/ggml/src/vulkan-shaders/argsort.comp b/ggml/src/vulkan-shaders/argsort.comp
index e55414b03..d4fa45b1e 100644
--- a/ggml/src/vulkan-shaders/argsort.comp
+++ b/ggml/src/vulkan-shaders/argsort.comp
@@ -29,20 +29,18 @@ void main() {
const int col = int(gl_LocalInvocationID.x);
const uint row = gl_WorkGroupID.y;
- if (col >= p.ncols_pad) {
- return;
- }
-
const uint row_offset = row * p.ncols;
// initialize indices
- dst_row[col] = col;
+ if (col < p.ncols_pad) {
+ dst_row[col] = col;
+ }
barrier();
for (uint k = 2; k <= p.ncols_pad; k *= 2) {
for (uint j = k / 2; j > 0; j /= 2) {
const uint ixj = col ^ j;
- if (ixj > col) {
+ if (col < p.ncols_pad && ixj > col) {
if ((col & k) == 0) {
if (dst_row[col] >= p.ncols ||
(dst_row[ixj] < p.ncols && (p.order == ASC ?
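// Editor's note (hedged): barrier() must be reached by every invocation in the
// workgroup, so the old early return for col >= p.ncols_pad made the barriers
// in the bitonic-sort loop divergent (undefined behavior). The fix keeps all
// invocations running through the barriers and instead guards the shared-memory
// accesses with `col < p.ncols_pad`.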
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 7aaf6745a..e08617ba2 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -94,9 +94,12 @@ class Keys:
DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping"
FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping"
+ SWIN_NORM = "{arch}.swin_norm"
RESCALE_EVERY_N_LAYERS = "{arch}.rescale_every_n_layers"
TIME_MIX_EXTRA_DIM = "{arch}.time_mix_extra_dim"
TIME_DECAY_EXTRA_DIM = "{arch}.time_decay_extra_dim"
+ RESIDUAL_SCALE = "{arch}.residual_scale"
+ EMBEDDING_SCALE = "{arch}.embedding_scale"
class Attention:
HEAD_COUNT = "{arch}.attention.head_count"
@@ -112,6 +115,7 @@ class Keys:
KV_LORA_RANK = "{arch}.attention.kv_lora_rank"
REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
SLIDING_WINDOW = "{arch}.attention.sliding_window"
+ SCALE = "{arch}.attention.scale"
class Rope:
DIMENSION_COUNT = "{arch}.rope.dimension_count"
@@ -231,6 +235,9 @@ class MODEL_ARCH(IntEnum):
JAIS = auto()
NEMOTRON = auto()
EXAONE = auto()
+ GRANITE = auto()
+ GRANITE_MOE = auto()
+ CHAMELEON = auto()
class MODEL_TENSOR(IntEnum):
@@ -338,6 +345,8 @@ class MODEL_TENSOR(IntEnum):
ENC_FFN_DOWN = auto()
ENC_FFN_UP = auto()
ENC_OUTPUT_NORM = auto()
+ CLS = auto() # classifier
+ CLS_OUT = auto() # classifier output projection
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@@ -387,6 +396,9 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.JAIS: "jais",
MODEL_ARCH.NEMOTRON: "nemotron",
MODEL_ARCH.EXAONE: "exaone",
+ MODEL_ARCH.GRANITE: "granite",
+ MODEL_ARCH.GRANITE_MOE: "granitemoe",
+ MODEL_ARCH.CHAMELEON: "chameleon",
}
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -494,6 +506,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.ENC_FFN_DOWN: "enc.blk.{bid}.ffn_down",
MODEL_TENSOR.ENC_FFN_UP: "enc.blk.{bid}.ffn_up",
MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm",
+ MODEL_TENSOR.CLS: "cls",
+ MODEL_TENSOR.CLS_OUT: "cls.output",
}
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -603,6 +617,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.LAYER_OUT_NORM,
+ MODEL_TENSOR.CLS,
+ MODEL_TENSOR.CLS_OUT,
],
MODEL_ARCH.NOMIC_BERT: [
MODEL_TENSOR.TOKEN_EMBD,
@@ -634,6 +650,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.LAYER_OUT_NORM,
+ MODEL_TENSOR.CLS,
],
MODEL_ARCH.MPT: [
MODEL_TENSOR.TOKEN_EMBD,
@@ -1228,6 +1245,51 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
+ MODEL_ARCH.GRANITE: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.GRANITE_MOE: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ ],
+ MODEL_ARCH.CHAMELEON: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
# TODO
}
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 3c95c2673..5c460ef1b 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -670,6 +670,9 @@ class GGUFWriter:
def add_expert_weights_scale(self, value: float) -> None:
self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value)
+ def add_swin_norm(self, value: bool) -> None:
+ self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value)
+
def add_rescale_every_n_layers(self, count: int) -> None:
self.add_uint32(Keys.LLM.RESCALE_EVERY_N_LAYERS.format(arch=self.arch), count)
@@ -679,6 +682,12 @@ class GGUFWriter:
def add_time_decay_extra_dim(self, dim: int) -> None:
self.add_uint32(Keys.LLM.TIME_DECAY_EXTRA_DIM.format(arch=self.arch), dim)
+ def add_residual_scale(self, value: float) -> None:
+ self.add_float32(Keys.LLM.RESIDUAL_SCALE.format(arch=self.arch), value)
+
+ def add_embedding_scale(self, value: float) -> None:
+ self.add_float32(Keys.LLM.EMBEDDING_SCALE.format(arch=self.arch), value)
+
def add_wkv_head_size(self, size: int) -> None:
self.add_uint32(Keys.WKV.HEAD_SIZE.format(arch=self.arch), size)
@@ -703,6 +712,9 @@ class GGUFWriter:
def add_sliding_window(self, value: int) -> None:
self.add_uint32(Keys.Attention.SLIDING_WINDOW.format(arch=self.arch), value)
+ def add_attention_scale(self, value: float) -> None:
+ self.add_float32(Keys.Attention.SCALE.format(arch=self.arch), value)
+
def add_pooling_type(self, value: PoolingType) -> None:
self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 4639f2f9b..f4a787c56 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -254,11 +254,12 @@ class TensorNameMap:
),
MODEL_TENSOR.FFN_GATE_INP: (
- "layers.{bid}.feed_forward.gate", # mixtral
- "model.layers.{bid}.block_sparse_moe.gate", # mixtral
- "model.layers.{bid}.mlp.gate", # qwen2moe olmoe
- "transformer.decoder_layer.{bid}.router", # Grok
- "transformer.blocks.{bid}.ffn.router.layer", # dbrx
+ "layers.{bid}.feed_forward.gate", # mixtral
+ "model.layers.{bid}.block_sparse_moe.gate", # mixtral
+ "model.layers.{bid}.mlp.gate", # qwen2moe olmoe
+ "transformer.decoder_layer.{bid}.router", # Grok
+ "transformer.blocks.{bid}.ffn.router.layer", # dbrx
+ "model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe
),
MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
@@ -367,10 +368,11 @@ class TensorNameMap:
),
MODEL_TENSOR.FFN_DOWN_EXP: (
- "layers.{bid}.feed_forward.experts.w2", # mixtral (merged)
- "transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged)
- "transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx
- "model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged)
+ "layers.{bid}.feed_forward.experts.w2", # mixtral (merged)
+ "transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged)
+ "transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx
+ "model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged)
+ "model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
),
MODEL_TENSOR.FFN_DOWN_SHEXP: (
@@ -381,7 +383,7 @@ class TensorNameMap:
MODEL_TENSOR.ATTN_Q_NORM: (
"language_model.encoder.layers.{bid}.self_attention.q_layernorm",
"model.layers.{bid}.self_attn.q_layernorm", # persimmon
- "model.layers.{bid}.self_attn.q_norm", # cohere olmoe
+ "model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon
"transformer.blocks.{bid}.attn.q_ln", # sea-lion
"encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2
"transformer.layers.{bid}.attn.q_norm", # openelm
@@ -390,7 +392,7 @@ class TensorNameMap:
MODEL_TENSOR.ATTN_K_NORM: (
"language_model.encoder.layers.{bid}.self_attention.k_layernorm",
"model.layers.{bid}.self_attn.k_layernorm", # persimmon
- "model.layers.{bid}.self_attn.k_norm", # cohere olmoe
+ "model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon
"transformer.blocks.{bid}.attn.k_ln", # sea-lion
"encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2
"transformer.layers.{bid}.attn.k_norm", # openelm
@@ -680,6 +682,15 @@ class TensorNameMap:
MODEL_TENSOR.ENC_OUTPUT_NORM: (
"encoder.final_layer_norm", # t5
),
+
+ MODEL_TENSOR.CLS: (
+ "classifier", # jina
+ "classifier.dense", # roberta
+ ),
+
+ MODEL_TENSOR.CLS_OUT: (
+ "classifier.out_proj", # roberta
+ ),
}
# architecture-specific block mappings
diff --git a/grammars/README.md b/grammars/README.md
index 7ec815471..4e8b4e2fc 100644
--- a/grammars/README.md
+++ b/grammars/README.md
@@ -120,7 +120,7 @@ You can use GBNF grammars:
- In [llama-server](../examples/server):
- For any completion endpoints, passed as the `json_schema` body field
- - For the `/chat/completions` endpoint, passed inside the `response_format` body field (e.g. `{"type", "json_object", "schema": {"items": {}}}`)
+ - For the `/chat/completions` endpoint, passed inside the `response_format` body field (e.g. `{"type": "json_object", "schema": {"items": {}}}` or `{"type": "json_schema", "json_schema": {"schema": ...}}`)
- In [llama-cli](../examples/main), passed as the `--json` / `-j` flag
- To convert to a grammar ahead of time:
- in CLI, with [examples/json_schema_to_grammar.py](../examples/json_schema_to_grammar.py)
diff --git a/include/llama.h b/include/llama.h
index cfc8d85dc..7cae1bbe2 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -102,6 +102,7 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_BLOOM = 23,
LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24,
LLAMA_VOCAB_PRE_TYPE_EXAONE = 25,
+ LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
};
enum llama_rope_type {
@@ -192,6 +193,7 @@ extern "C" {
LLAMA_POOLING_TYPE_MEAN = 1,
LLAMA_POOLING_TYPE_CLS = 2,
LLAMA_POOLING_TYPE_LAST = 3,
+ LLAMA_POOLING_TYPE_RANK = 4, // used by reranking models to attach the classification head to the graph
};
enum llama_attention_type {
@@ -201,9 +203,9 @@ extern "C" {
};
enum llama_split_mode {
- LLAMA_SPLIT_MODE_NONE = 0, // single GPU
- LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
- LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs
+ LLAMA_SPLIT_MODE_NONE = 0, // single GPU
+ LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
+ LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs
};
// TODO: simplify (https://github.com/ggerganov/llama.cpp/pull/9294#pullrequestreview-2286561979)
@@ -441,6 +443,7 @@ extern "C" {
LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
LLAMA_API int32_t llama_n_embd (const struct llama_model * model);
LLAMA_API int32_t llama_n_layer (const struct llama_model * model);
+ LLAMA_API int32_t llama_n_head (const struct llama_model * model);
LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
@@ -870,7 +873,8 @@ extern "C" {
// Get the embeddings for a sequence id
// Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
- // shape: [n_embd] (1-dimensional)
+ // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1] with the rank of the sequence
+ // otherwise: float[n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
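// Editor's sketch (hedged): consuming the rank score through this API,
// assuming a context created with pooling_type == LLAMA_POOLING_TYPE_RANK
// and a decoded query/document pair in sequence seq_id:
//
//     const float * score = llama_get_embeddings_seq(ctx, seq_id);
//     if (score != NULL) {
//         printf("rerank score: %.3f\n", score[0]); // shape: float[1]
//     }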
//
@@ -909,6 +913,8 @@ extern "C" {
//
// Tokenization
//
+ // The API is thread-safe.
+ //
/// @details Convert the provided text into tokens.
/// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
@@ -1065,6 +1071,7 @@ extern "C" {
LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed);
/// @details Sorts candidate tokens by their logits in descending order and calculates probabilities based on logits.
+ /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void);
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
diff --git a/models/ggml-vocab-chameleon.gguf.inp b/models/ggml-vocab-chameleon.gguf.inp
new file mode 100644
index 000000000..9baf7d77a
--- /dev/null
+++ b/models/ggml-vocab-chameleon.gguf.inp
@@ -0,0 +1,112 @@
+ied 4 ½ months
+__ggml_vocab_test__
+Führer
+__ggml_vocab_test__
+
+__ggml_vocab_test__
+
+__ggml_vocab_test__
+
+__ggml_vocab_test__
+
+__ggml_vocab_test__
+
+__ggml_vocab_test__
+
+
+__ggml_vocab_test__
+
+
+
+__ggml_vocab_test__
+
+
+
+
+__ggml_vocab_test__
+
+
+__ggml_vocab_test__
+Hello world
+__ggml_vocab_test__
+ Hello world
+__ggml_vocab_test__
+Hello World
+__ggml_vocab_test__
+ Hello World
+__ggml_vocab_test__
+ Hello World!
+__ggml_vocab_test__
+Hello, world!
+__ggml_vocab_test__
+ Hello, world!
+__ggml_vocab_test__
+ this is 🦙.cpp
+__ggml_vocab_test__
+w048 7tuijk dsdfhu
+__ggml_vocab_test__
+нещо на Български
+__ggml_vocab_test__
+កាន់តែពិសេសអាចខលចេញ
+__ggml_vocab_test__
+🚀 (normal) 😶🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)
+__ggml_vocab_test__
+Hello
+__ggml_vocab_test__
+ Hello
+__ggml_vocab_test__
+ Hello
+__ggml_vocab_test__
+ Hello
+__ggml_vocab_test__
+ Hello
+__ggml_vocab_test__
+ Hello
+ Hello
+__ggml_vocab_test__
+ (
+__ggml_vocab_test__
+
+ =
+__ggml_vocab_test__
+' era
+__ggml_vocab_test__
+Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
+__ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
+3
+__ggml_vocab_test__
+33
+__ggml_vocab_test__
+333
+__ggml_vocab_test__
+3333
+__ggml_vocab_test__
+33333
+__ggml_vocab_test__
+333333
+__ggml_vocab_test__
+3333333
+__ggml_vocab_test__
+33333333
+__ggml_vocab_test__
+333333333
+__ggml_vocab_test__
+Cửa Việt
+__ggml_vocab_test__
+ discards
+__ggml_vocab_test__
+
+
+
+
+
+
+
+
+
+
+
+🚀 (normal) 😶🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL
+__ggml_vocab_test__
diff --git a/models/ggml-vocab-chameleon.gguf.out b/models/ggml-vocab-chameleon.gguf.out
new file mode 100644
index 000000000..7c5413fee
--- /dev/null
+++ b/models/ggml-vocab-chameleon.gguf.out
@@ -0,0 +1,46 @@
+ 17245 16604 16403 16604 33583 18355
+ 16421 51153
+
+ 16604
+ 16650
+ 16650 16604
+ 16581
+ 16582
+ 16582 16582
+ 16582 16582 16582
+ 16581 16582
+ 31596 17394
+ 34926 17394
+ 31596 18671
+ 34926 18671
+ 34926 18671 16384
+ 31596 16395 17394 16384
+ 34926 16395 17394 16384
+ 16811 16704 20410 16483 16631 16397 52854
+ 16470 16399 16403 16407 16604 16406 35764 38185 51595 22592 26639
+ 29479 23955 17012 20103 25527 27670 17408 19005 21473 24774
+ 54254 42231 48084 29409 16617 61889 29409 16608 21954 16628 21954 16499 58445 29409 16607 58445 21954 16479 42231 21954 16611 21954 16607 21954 16633 21954 16611 29409 16607 21954 16615
+ 52351 16604 16391 25825 16392 23686 16498 39161 18885 16618 16488 30853 16604 16391 54124 17153 25134 16656 18476 26169 16895 16392 62193 16611 16604 16391 24664 17153 57169 16721 16872 17073 17304 28729 16392
+ 31596
+ 34926
+ 16650 31596
+ 16650 34926
+ 16696 31596
+ 16696 31596 16582 16696 31596
+ 16604 16391
+ 16582 16604 16412
+ 16390 22623
+ 31596 16395 16712 16390 16828 16384 17674 16769 16732 23686 16607 16604 16414 24427 16623 41809 16495 28999 36469 45292 30197 16400 16402 16400 16403 16400 16404 16400 43969 65211 16636
+ 16384 16384 16384 16384 16384 16384
+ 16402
+ 16402 16402
+ 16402 16402 16402
+ 16402 16402 16402 16402
+ 16402 16402 16402 16402 16402
+ 16402 16402 16402 16402 16402 16402
+ 16402 16402 16402 16402 16402 16402 16402
+ 16402 16402 16402 16402 16402 16402 16402 16402
+ 16402 16402 16402 16402 16402 16402 16402 16402 16402
+ 16418 19038 16639 16448 24315 33727 16467
+ 18765 17981
+ 16582 16604 16582 16582 16604 16582 16582 16582 16604 16581 16604 16581 16581 16604 16581 16582 16650 16582 16650 16604 16582 16696 16582 16696 16604 16582 52351 16604 16391 25825 16392 23686 16498 39161 18885 16618 16488 30853 16604 16391 54124 17153 25134 16656 18476 26169 16895 16392 62193 16611 20410 16483 16631 18885 16483 16631 16604 16402 16604 16402 16402 16604 16402 16402 16402 16604 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16402 16402 16402 16604 16402 16397 16402 16604 16402 16397 16397 16402 16604 16402 16397 16397 16397 16402 16604 54254 42231 48084 29409 16617 61889 29409 16608 21954 16628 21954 16499 58445 29409 16607 58445 21954 16479 42231 21954 16611 27683 16607 16604 16414 24427 16623 41809 16495 28999 36469 45292 30197 16400 16402 16400 16403 16400 16404 16400 43969 65211 16636 16604 16396 16396 16396 16396 16396 16396 16412 16412 16412 16412 16412 16412 16412 27268 23955 17012 20103 25527 27670 17408 19005 21473 24774 16604 16390 16390 16390 16390 16390 16390 16447 16447 16447 16447 16447 16447 16447 16385 16385 16385 16385 16397 16397 16397 16397 16397 16397 16384 16384 16384 16384 16384 16384 16414 16414 16414 16414 16414 16414 16687 16390 16690 16992 16604 16390 61797 16733 16390 16466 16986 16395 16604 16390 17879 16732 17811 16414 16604 16390 16428 16804 17811 16687 16390 16683 17190 16728 16395 16604 16390 16419 16732 16945 16991 25251 16414 17119 16390 38127 16641 16390 16459 16427
diff --git a/pyrightconfig.json b/pyrightconfig.json
index 6016f4b6d..9acbbeb78 100644
--- a/pyrightconfig.json
+++ b/pyrightconfig.json
@@ -5,7 +5,8 @@
"reportUnusedImport": "warning",
"reportDuplicateImport": "error",
"reportDeprecated": "warning",
- "reportUnnecessaryTypeIgnoreComment": "warning",
+ "reportUnnecessaryTypeIgnoreComment": "information",
+ "disableBytesTypePromotions": false, // TODO: change once Python 3.12 is the minimum
"executionEnvironments": [
{
// TODO: make this version override work correctly
diff --git a/requirements/requirements-convert_legacy_llama.txt b/requirements/requirements-convert_legacy_llama.txt
index 1d07b0952..859204b27 100644
--- a/requirements/requirements-convert_legacy_llama.txt
+++ b/requirements/requirements-convert_legacy_llama.txt
@@ -1,5 +1,5 @@
numpy~=1.26.4
sentencepiece~=0.2.0
-transformers>=4.40.1,<5.0.0
+transformers>=4.45.1,<5.0.0
gguf>=0.1.0
protobuf>=4.21.0,<5.0.0
diff --git a/scripts/compare-commits.sh b/scripts/compare-commits.sh
index 70679f4e5..8b9b1ad39 100755
--- a/scripts/compare-commits.sh
+++ b/scripts/compare-commits.sh
@@ -8,6 +8,9 @@ fi
set -e
set -x
+# verify at the start that the compare script has all the necessary dependencies installed
+./scripts/compare-llama-bench.py --check
+
bench_args="${@:3}"
rm -f llama-bench.sqlite > /dev/null
diff --git a/scripts/compare-llama-bench.py b/scripts/compare-llama-bench.py
index 92b9e682a..e45e83ce8 100755
--- a/scripts/compare-llama-bench.py
+++ b/scripts/compare-llama-bench.py
@@ -92,6 +92,7 @@ help_s = (
"If the columns are manually specified, then the results for each unique combination of the "
"specified values are averaged WITHOUT weighing by the --repetitions parameter of llama-bench."
)
+parser.add_argument("--check", action="store_true", help="check if all required Python libraries are installed")
parser.add_argument("-s", "--show", help=help_s)
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
@@ -99,6 +100,10 @@ known_args, unknown_args = parser.parse_known_args()
logging.basicConfig(level=logging.DEBUG if known_args.verbose else logging.INFO)
+if known_args.check:
+ # Check if all required Python libraries are installed. Would have failed earlier if not.
+ sys.exit(0)
+
if unknown_args:
logger.error(f"Received unknown args: {unknown_args}.\n")
parser.print_help()
diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last
index 3d2dfb413..aa301462a 100644
--- a/scripts/sync-ggml.last
+++ b/scripts/sync-ggml.last
@@ -1 +1 @@
-10e83a412717c20d57ba19f025248e18e43addf3
+9a24b8c8c40eab7262d067e91d08df160678df8d
diff --git a/src/llama-impl.h b/src/llama-impl.h
index 2bde75ec1..70f16b61c 100644
--- a/src/llama-impl.h
+++ b/src/llama-impl.h
@@ -28,6 +28,8 @@ void llama_log_callback_default(ggml_log_level level, const char * text, void *
#define LLAMA_LOG_INFO(...) llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
#define LLAMA_LOG_WARN(...) llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
#define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
+#define LLAMA_LOG_DEBUG(...) llama_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
+#define LLAMA_LOG_CONT(...) llama_log_internal(GGML_LOG_LEVEL_CONT , __VA_ARGS__)
//
// helpers
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 5275b1d60..e255a8fc4 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -3,13 +3,14 @@
#include "llama-vocab.h"
#include "llama-grammar.h"
-#include
#include
-#include
-#include
+#include
#include
#include
#include
+#include
+#include
+#include
#include
#include
#include
@@ -236,9 +237,10 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
// TODO: do not allocate each time
- std::vector<llama_token_data> cur(n_vocab);
+ std::vector<llama_token_data> cur;
+ cur.reserve(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
- cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+ cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}
llama_token_data_array cur_p = {
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index 2c007477e..d2f34ddd6 100644
--- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp
@@ -50,7 +50,7 @@ struct naive_trie {
res.first->second.insert(key + 1, len - 1, value);
}
}
- std::pair<const char *, size_t> get_longest_prefix(const char * key, size_t len, size_t offset = 0) {
+ std::pair<const char *, size_t> get_longest_prefix(const char * key, size_t len, size_t offset = 0) const {
if (len == 0 || offset == len) {
return std::make_pair(key, offset);
}
@@ -79,6 +79,15 @@ struct naive_trie {
// impl
//
+struct llm_tokenizer {
+ llm_tokenizer() {}
+ virtual ~llm_tokenizer() = default;
+};
+
+llama_vocab::~llama_vocab() {
+ delete tokenizer;
+}
+
int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string & token_right) const {
GGML_ASSERT(token_left.find(' ') == std::string::npos);
GGML_ASSERT(token_left.find('\n') == std::string::npos);
@@ -187,10 +196,15 @@ struct llm_bigram_spm {
size_t size;
};
-struct llm_tokenizer_spm {
- llm_tokenizer_spm(const llama_vocab & vocab) : vocab(vocab) {}
+struct llm_tokenizer_spm : llm_tokenizer {
+ llm_tokenizer_spm(const llama_vocab & /*vocab*/) : llm_tokenizer() {}
+};
+
+struct llm_tokenizer_spm_session {
+ llm_tokenizer_spm_session(const llama_vocab & vocab) : vocab(vocab) {}
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+
// split string into utf8 chars
int index = 0;
size_t offs = 0;
@@ -271,7 +285,7 @@ private:
return;
}
- resegment(symbols[p->second.first], output);
+ resegment(symbols[p->second.first], output);
resegment(symbols[p->second.second], output);
}
@@ -279,7 +293,6 @@ private:
if (left == -1 || right == -1) {
return;
}
-
const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n);
auto token = vocab.token_to_id.find(text);
@@ -306,10 +319,11 @@ private:
}
const llama_vocab & vocab;
+ // currently unused
+ // const llm_tokenizer_spm * spm_tokenizer;
std::vector<llm_symbol> symbols;
llm_bigram_spm::queue work_queue;
-
std::map<std::string, std::pair<int, int>> rev_merge;
};
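// Editor's note (hedged): this tokenizer/session split is what backs the new
// "the API is thread-safe" guarantee in llama.h: the llm_tokenizer_* object is
// constructed once and shared read-only via vocab.tokenizer, while all mutable
// scratch state (symbols, work_queue, rev_merge) lives in a session object that
// each call constructs locally. A minimal usage sketch:
//
//     std::vector<llama_vocab::id> out;
//     llm_tokenizer_spm_session session(vocab);
//     session.tokenize(text, out);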
@@ -352,8 +366,8 @@ struct llm_bigram_bpe {
size_t size;
};
-struct llm_tokenizer_bpe {
- llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) {
+struct llm_tokenizer_bpe : llm_tokenizer {
+ llm_tokenizer_bpe(const llama_vocab & vocab) : llm_tokenizer() {
GGML_ASSERT(vocab.type == LLAMA_VOCAB_TYPE_BPE);
switch (vocab.type_pre) {
case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
@@ -450,6 +464,20 @@ struct llm_tokenizer_bpe {
"[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
};
break;
+ case LLAMA_VOCAB_PRE_TYPE_CHAMELEON:
+ // Note: in theory, the special token (sentinel and image token) regex_exprs below
+ // are unnecessary, as they are split in `tokenizer_st_partition` anyway.
+ // However, since the upstream pre-tokenizer uses them, they are also
+ // included here (see https://huggingface.co/facebook/chameleon-7b).
+ regex_exprs = {
+ "", // Sentinel tokens
+ "(IMGIMG)((A|B|C|D|E|F|G|H|I){1,4})Z", // Image tokens
+ "([\\t\\n]| | )", // directly from tokenizer.json
+ "\\p{N}", // Individual digits
+ "[\\p{P}!-/:-@\\[-`{-~]", // Punctuation, Isolated
+ "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
+ };
+ break;
default:
// default regex for BPE tokenization pre-processing
regex_exprs = {
@@ -462,7 +490,14 @@ struct llm_tokenizer_bpe {
}
}
- void append(const llama_vocab::id token_id, std::vector & output) const {
+ std::vector<std::string> regex_exprs;
+};
+
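+// per-call BPE state; the constructor recovers the shared regex list by
+// downcasting vocab.tokenizer, which is safe because init_tokenizer() picks
+// the concrete tokenizer type from vocab.type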
+struct llm_tokenizer_bpe_session {
+ llm_tokenizer_bpe_session(const llama_vocab & vocab) : vocab(vocab),
+ bpe_tokenizer(static_cast<const llm_tokenizer_bpe *>(vocab.tokenizer)) {}
+
+ static void append(const llama_vocab::id token_id, std::vector<llama_vocab::id> & output) {
output.push_back(token_id);
}
@@ -501,12 +536,11 @@ struct llm_tokenizer_bpe {
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
int final_prev_index = -1;
-
- const auto word_collection = unicode_regex_split(text, regex_exprs);
+ const auto word_collection = unicode_regex_split(text, bpe_tokenizer->regex_exprs);
symbols_final.clear();
- for (auto & word : word_collection) {
+ for (const auto & word : word_collection) {
work_queue = llm_bigram_bpe::queue();
symbols.clear();
@@ -609,7 +643,6 @@ private:
if (left == -1 || right == -1) {
return;
}
-
std::string left_token = std::string(symbols[left].text, symbols[left].n);
std::string right_token = std::string(symbols[right].text, symbols[right].n);
@@ -633,12 +666,10 @@ private:
}
const llama_vocab & vocab;
-
- std::vector<std::string> regex_exprs;
+ const llm_tokenizer_bpe * bpe_tokenizer;
std::vector<llm_symbol> symbols;
std::vector<llm_symbol> symbols_final;
-
llm_bigram_bpe::queue work_queue;
};
@@ -646,15 +677,17 @@ private:
// WPM tokenizer
//
-struct llm_tokenizer_wpm {
- llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {}
+struct llm_tokenizer_wpm : llm_tokenizer {
+ llm_tokenizer_wpm(const llama_vocab & /*vocab*/) : llm_tokenizer() {}
+};
- void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) const {
+struct llm_tokenizer_wpm_session {
+ llm_tokenizer_wpm_session(const llama_vocab & vocab) : vocab(vocab) {}
+
+ void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
const auto & token_map = vocab.token_to_id;
-
// normalize and split by whitespace
std::vector<std::string> words = preprocess(text);
-
// bos token prepended already
// find the longest tokens that form the words
@@ -699,7 +732,7 @@ struct llm_tokenizer_wpm {
}
// TODO: reduce string copies by using cpts_offs array
- std::vector<std::string> preprocess(const std::string & text) const {
+ static std::vector<std::string> preprocess(const std::string & text) {
const std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
std::vector<std::string> words(1, "");
@@ -751,15 +784,18 @@ struct llm_tokenizer_wpm {
//(cpt >= 0xFF00 && cpt <= 0xFFEF);
}
+private:
const llama_vocab & vocab;
+ // currently unused
+ // const llm_tokenizer_wpm * wpm_tokenizer;
};
//
// UGM tokenizer
//
-struct llm_tokenizer_ugm {
- llm_tokenizer_ugm(const llama_vocab & vocab) : vocab(vocab) {
+struct llm_tokenizer_ugm : llm_tokenizer {
+ llm_tokenizer_ugm(const llama_vocab & vocab) : llm_tokenizer() {
if (vocab.precompiled_charsmap.size() > 0) {
size_t charsmap_offset = 0;
@@ -805,6 +841,30 @@ struct llm_tokenizer_ugm {
unknown_token_score = min_score - unknown_token_score_penalty;
}
+ // escaped space symbol - U+2581 (Lower One Eighth Block)
+ const std::string escaped_space = "\xE2\x96\x81";
+
+ const char * prefix_replacements = NULL;
+ size_t prefix_replacements_size = 0;
+
+ const uint32_t * xcda_array = NULL;
+ size_t xcda_array_size = 0;
+
+ struct naive_trie user_defined_token_matcher;
+
+ float min_score = FLT_MAX;
+ float max_score = -FLT_MAX;
+
+ float unknown_token_score_penalty = 10.0;
+ float unknown_token_score;
+
+ struct naive_trie token_matcher;
+};
+
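+// the session only references the shared UGM data: the precompiled charsmap,
+// tries and token scores above are read-only after construction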
+struct llm_tokenizer_ugm_session {
+ llm_tokenizer_ugm_session(const llama_vocab & vocab) : vocab(vocab),
+ ugm_tokenizer(static_cast<const llm_tokenizer_ugm *>(vocab.tokenizer)) {}
+
/* This implementation is based on SentencePiece optimized Viterbi algorithm for
* unigram language models. The general idea is to:
* - move along the input sequence in steps of one UTF code point,
@@ -843,7 +903,7 @@ struct llm_tokenizer_ugm {
// traverse the token matcher trie to find a matching token
bool single_codepoint_token_found = false;
const struct best_tokenization & current_best = tokenization_results[input_offset];
- const struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]);
+ const struct naive_trie * node = ugm_tokenizer->token_matcher.traverse(normalized[prefix_offset++]);
while (prefix_offset <= input_len && node != NULL) {
// check if we found valid token in prefix
@@ -873,7 +933,7 @@ struct llm_tokenizer_ugm {
// if we didn't find a valid token corresponding to the whole UTF code point
// then use unknown token as the tokenization of this UTF code point
if (!single_codepoint_token_found) {
- const double challenger_score = current_best.score_sum + unknown_token_score;
+ const double challenger_score = current_best.score_sum + ugm_tokenizer->unknown_token_score;
prefix_offset = input_offset + n_utf8_code_units;
struct best_tokenization & current_champ = tokenization_results[prefix_offset];
if (challenger_score > current_champ.score_sum) {
@@ -905,7 +965,6 @@ struct llm_tokenizer_ugm {
}
private:
- const llama_vocab & vocab;
// helper structure for returning normalization results
struct normalization_result {
@@ -918,7 +977,7 @@ private:
normalized->clear();
normalized->reserve(input.size() * 3);
- const std::string space = vocab.tokenizer_escape_whitespaces ? escaped_space : " ";
+ const std::string space = vocab.tokenizer_escape_whitespaces ? ugm_tokenizer->escaped_space : " ";
bool shall_prepend_space = !vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix;
bool shall_append_space = vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix;
@@ -1000,13 +1059,21 @@ private:
size_t xcda_array_size;
};
+ // this structure stores the best tokenization so far at input_offset
+ struct best_tokenization {
+ llama_token token_id;
+ size_t input_offset;
+ float score_sum;
+ };
+
struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) {
if (input_offset == input.size()) {
return { &input[input_offset], 0, 0 };
}
// if input prefix matches some user-defined token return this token as normalization result
- auto user_defined_token_match = user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset);
+ auto user_defined_token_match =
+ ugm_tokenizer->user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset);
if (user_defined_token_match.second > 0) {
return { &input[input_offset], user_defined_token_match.second, user_defined_token_match.second };
}
@@ -1014,8 +1081,8 @@ private:
size_t longest_prefix_length = 0;
size_t longest_prefix_offset = 0;
- if (xcda_array_size > 0) {
- struct xcda_array_view xcda_view(xcda_array, xcda_array_size);
+ if (ugm_tokenizer->xcda_array_size > 0) {
+ struct xcda_array_view xcda_view(ugm_tokenizer->xcda_array, ugm_tokenizer->xcda_array_size);
// Find the longest normalized sequence matching the input prefix by walking
// the XOR-compressed compact double array (XCDA) starting from the root node
@@ -1051,50 +1118,27 @@ private:
if (longest_prefix_length > 0) {
// we have a match, so return the replacement sequence
- if (longest_prefix_offset >= prefix_replacements_size) {
+ if (longest_prefix_offset >= ugm_tokenizer->prefix_replacements_size) {
throw std::runtime_error("Index out of array bounds in precompiled charsmap!");
}
- const char * prefix_replacement = &prefix_replacements[longest_prefix_offset];
+ const char * prefix_replacement = &(ugm_tokenizer->prefix_replacements)[longest_prefix_offset];
return { prefix_replacement, strlen(prefix_replacement), longest_prefix_length };
- } else {
- // check if the input prefix contains a valid sequence of UTF-8 code units
- try {
- // if yes, return this sequence unmodified
- size_t prefix_offset = input_offset;
- unicode_cpt_from_utf8(input, prefix_offset);
- return { &input[input_offset], prefix_offset - input_offset, prefix_offset - input_offset };
- } catch (std::invalid_argument & /*ex*/) {
- // if no, consume 1 byte and return U+FFFD - REPLACEMENT CHARACTER
- return { "\xEF\xBF\xBD", 3, 1 };
- }
+ }
+
+ // check if the input prefix contains a valid sequence of UTF-8 code units
+ try {
+ // if yes, return this sequence unmodified
+ size_t prefix_offset = input_offset;
+ unicode_cpt_from_utf8(input, prefix_offset);
+ return { &input[input_offset], prefix_offset - input_offset, prefix_offset - input_offset };
+ } catch (std::invalid_argument & /*ex*/) {
+ // if no, consume 1 byte and return U+FFFD - REPLACEMENT CHARACTER
+ return { "\xEF\xBF\xBD", 3, 1 };
}
}
- // escaped space symbol - U+2581 (Lower One Eighth Block)
- const std::string escaped_space = "\xE2\x96\x81";
-
- const char * prefix_replacements = NULL;
- size_t prefix_replacements_size = 0;
-
- const uint32_t * xcda_array = NULL;
- size_t xcda_array_size = 0;
-
- struct naive_trie user_defined_token_matcher;
-
- // this structure stores the best tokenization so far at input_offset
- struct best_tokenization {
- llama_token token_id;
- size_t input_offset;
- float score_sum;
- };
-
- float min_score = FLT_MAX;
- float max_score = -FLT_MAX;
-
- float unknown_token_score_penalty = 10.0;
- float unknown_token_score;
-
- struct naive_trie token_matcher;
+ const llama_vocab & vocab;
+ const llm_tokenizer_ugm * ugm_tokenizer;
};
//
@@ -1155,8 +1199,8 @@ static std::vector<uint8_t> llama_unescape_rwkv_token(const std::string & escape
return output;
}
-struct llm_tokenizer_rwkv {
- llm_tokenizer_rwkv(const llama_vocab & vocab): vocab(vocab) {
+struct llm_tokenizer_rwkv : llm_tokenizer {
+ llm_tokenizer_rwkv(const llama_vocab & vocab) : llm_tokenizer() {
// RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens.
// For now, we decode the vocab here into the lookup we'll use for tokenization.
@@ -1168,11 +1212,17 @@ struct llm_tokenizer_rwkv {
}
}
+ struct naive_trie token_matcher;
+};
+
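+// per-call RWKV state; the shared trie is held by reference rather than
+// pointer, hence the dereference in the cast below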
+struct llm_tokenizer_rwkv_session {
+ llm_tokenizer_rwkv_session(const llama_vocab & vocab) : vocab(vocab),
+ rwkv_tokenizer(static_cast<const llm_tokenizer_rwkv &>(*vocab.tokenizer)) {}
+
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
uint32_t position = 0;
-
while (position < text.size()) {
- const struct naive_trie * node = token_matcher.traverse(text[position]);
+ const struct naive_trie * node = rwkv_tokenizer.token_matcher.traverse(text[position]);
if (node == NULL) {
// no matching token found, add unknown token
output.push_back(vocab.special_unk_id);
@@ -1197,11 +1247,33 @@ struct llm_tokenizer_rwkv {
}
}
+private:
const llama_vocab & vocab;
-
- struct naive_trie token_matcher;
+ const llm_tokenizer_rwkv & rwkv_tokenizer;
};
+void llama_vocab::init_tokenizer() {
+ switch (type) {
+ case LLAMA_VOCAB_TYPE_SPM:
+ tokenizer = new llm_tokenizer_spm(*this);
+ break;
+ case LLAMA_VOCAB_TYPE_BPE:
+ tokenizer = new llm_tokenizer_bpe(*this);
+ break;
+ case LLAMA_VOCAB_TYPE_WPM:
+ tokenizer = new llm_tokenizer_wpm(*this);
+ break;
+ case LLAMA_VOCAB_TYPE_UGM:
+ tokenizer = new llm_tokenizer_ugm(*this);
+ break;
+ case LLAMA_VOCAB_TYPE_RWKV:
+ tokenizer = new llm_tokenizer_rwkv(*this);
+ break;
+ default:
+ GGML_ABORT("unsupported vocab type");
+ }
+}
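+
+// usage sketch (illustrative, not part of this patch): a loader is expected to
+// call init_tokenizer() once after the vocab tables are populated, e.g.
+//
+//   llama_vocab vocab;
+//   // ... fill type, type_pre, token_to_id, merges, precompiled_charsmap ...
+//   vocab.init_tokenizer(); // builds the shared tokenizer; freed in ~llama_vocab()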
+
//
// (de-) tokenize
//
@@ -1263,7 +1335,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer, bool parse_special) {
// if a fragment is text ( not yet processed )
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
- auto & raw_text = fragment.raw_text;
+ const auto & raw_text = fragment.raw_text;
auto raw_text_base_offset = fragment.offset;
auto raw_text_base_length = fragment.length;
@@ -1362,7 +1434,13 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
}
}
-std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool add_special, bool parse_special) {
+std::vector<llama_vocab::id> llama_tokenize_internal(
+ const llama_vocab & vocab,
+ std::string raw_text,
+ bool add_special,
+ bool parse_special) {
+ GGML_ASSERT(vocab.tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
+
std::vector<llama_vocab::id> output;
std::forward_list<fragment_buffer_variant> fragment_buffer;
@@ -1399,9 +1477,9 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
#ifdef PRETOKENIZERDEBUG
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
#endif
- llm_tokenizer_spm tokenizer(vocab);
llama_escape_whitespace(raw_text);
- tokenizer.tokenize(raw_text, output);
+ llm_tokenizer_spm_session session(vocab);
+ session.tokenize(raw_text, output);
is_prev_special = false;
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
output.push_back(fragment.token);
@@ -1423,10 +1501,11 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
} break;
case LLAMA_VOCAB_TYPE_BPE:
{
- llm_tokenizer_bpe tokenizer(vocab);
-
+ llm_tokenizer_bpe_session session(vocab);
+ // the session calls methods that do not exist on the llm_tokenizer base,
+ // so its constructor casts vocab.tokenizer to the concrete bpe tokenizer
if (add_special) {
- tokenizer.append_bos(output);
+ session.append_bos(output);
}
for (const auto & fragment : fragment_buffer) {
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
@@ -1435,15 +1514,15 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
#ifdef PRETOKENIZERDEBUG
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
#endif
- tokenizer.tokenize(raw_text, output);
+ session.tokenize(raw_text, output);
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
- tokenizer.append(fragment.token, output);
+ session.append(fragment.token, output);
}
}
if (add_special) {
- tokenizer.append_eos(output);
- tokenizer.check_double_bos_eos(output);
+ session.append_eos(output);
+ session.check_double_bos_eos(output);
}
} break;
case LLAMA_VOCAB_TYPE_WPM:
@@ -1453,7 +1532,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
output.push_back(vocab.special_cls_id);
}
- llm_tokenizer_wpm tokenizer(vocab);
+ llm_tokenizer_wpm_session session(vocab);
for (const auto & fragment : fragment_buffer) {
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
@@ -1462,7 +1541,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
#ifdef PRETOKENIZERDEBUG
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
#endif
- tokenizer.tokenize(raw_text, output);
+ session.tokenize(raw_text, output);
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
output.push_back(fragment.token);
}
@@ -1475,12 +1554,11 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
} break;
case LLAMA_VOCAB_TYPE_UGM:
{
- llm_tokenizer_ugm tokenizer(vocab);
-
- if (add_special && vocab.tokenizer_add_bos != 0) {
+ if (add_special && vocab.tokenizer_add_bos) {
GGML_ASSERT(vocab.special_bos_id != -1);
output.push_back(vocab.special_bos_id);
}
+ llm_tokenizer_ugm_session session(vocab);
for (const auto & fragment : fragment_buffer) {
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
@@ -1488,26 +1566,27 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
#ifdef PRETOKENIZERDEBUG
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
#endif
- tokenizer.tokenize(raw_text, output);
+ session.tokenize(raw_text, output);
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
output.push_back(fragment.token);
}
}
- if (add_special && vocab.tokenizer_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) {
+ if (add_special && vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
LLAMA_LOG_WARN(
"%s: Added a BOS token to the prompt as specified by the model but the prompt "
"also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
"Are you sure this is what you want?\n", __FUNCTION__);
}
- if (add_special && vocab.tokenizer_add_eos == 1) {
+ if (add_special && vocab.tokenizer_add_eos) {
GGML_ASSERT(vocab.special_eos_id != -1);
output.push_back(vocab.special_eos_id);
}
} break;
case LLAMA_VOCAB_TYPE_RWKV:
{
+ llm_tokenizer_rwkv_session session(vocab);
for (const auto & fragment : fragment_buffer) {
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
@@ -1516,8 +1595,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
#endif
- llm_tokenizer_rwkv tokenizer(vocab);
- tokenizer.tokenize(raw_text, output);
+ session.tokenize(raw_text, output);
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
output.push_back(fragment.token);
}
@@ -1570,11 +1648,7 @@ llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, lla
}
bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token) {
- return token != -1 && (
- token == llama_token_eos_impl(vocab) ||
- token == llama_token_eot_impl(vocab) ||
- token == llama_token_eom_impl(vocab)
- );
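+ // EOG is now a single set lookup, so a model can declare any number of
+ // end-of-generation tokens (EOS, EOT, EOM, ...) via special_eog_ids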
+ return token != -1 && vocab.special_eog_ids.count(token) > 0;
}
bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token) {
@@ -1634,13 +1708,13 @@ llama_token llama_token_eom_impl(const struct llama_vocab & vocab) {
}
int32_t llama_tokenize_impl(
- const struct llama_vocab & vocab,
- const char * text,
- int32_t text_len,
- llama_token * tokens,
- int32_t n_tokens_max,
- bool add_special,
- bool parse_special) {
+ const struct llama_vocab & vocab,
+ const char * text,
+ int32_t text_len,
+ llama_token * tokens,
+ int32_t n_tokens_max,
+ bool add_special,
+ bool parse_special) {
auto res = llama_tokenize_internal(vocab, std::string(text, text_len), add_special, parse_special);
if (n_tokens_max < (int) res.size()) {
// LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
@@ -1717,11 +1791,13 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
// suppressing them like CONTROL tokens.
if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
return _try_copy(token_text.data(), token_text.size());
- } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
+ }
+ if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
std::string result = token_text;
llama_unescape_whitespace(result);
return _try_copy(result.data(), result.size());
- } else if (attr & LLAMA_TOKEN_ATTR_BYTE) {
+ }
+ if (attr & LLAMA_TOKEN_ATTR_BYTE) {
char byte = (char) llama_token_to_byte(vocab, token);
return _try_copy((char*) &byte, 1);
}
@@ -1732,7 +1808,8 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
// suppressing them like CONTROL tokens.
if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
return _try_copy(token_text.data(), token_text.size());
- } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
+ }
+ if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
std::string result = llama_decode_text(token_text);
return _try_copy(result.data(), result.size());
}
@@ -1765,6 +1842,8 @@ int32_t llama_detokenize_impl(
int32_t text_len_max,
bool remove_special,
bool unparse_special) {
+ GGML_ASSERT(vocab.tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
+
int32_t avail = text_len_max;
int32_t total = 0;
diff --git a/src/llama-vocab.h b/src/llama-vocab.h
index dc4b5f12f..069bdc423 100644
--- a/src/llama-vocab.h
+++ b/src/llama-vocab.h
@@ -6,6 +6,9 @@
#include <vector>
#include <unordered_map>
#include <map>