Merge remote-tracking branch 'origin/master' into tool-call

This commit is contained in:
ochafik 2025-01-29 17:54:57 +00:00
commit babdefc4dd
69 changed files with 3254 additions and 845 deletions

View file

@ -2,6 +2,10 @@ ARG UBUNTU_VERSION=22.04
FROM ubuntu:$UBUNTU_VERSION AS build
ARG TARGETARCH
ARG GGML_CPU_ARM_ARCH=armv8-a
RUN apt-get update && \
apt-get install -y build-essential git cmake libcurl4-openssl-dev
@ -9,7 +13,14 @@ WORKDIR /app
COPY . .
RUN cmake -S . -B build -DGGML_BACKEND_DL=ON -DGGML_NATIVE=OFF -DGGML_CPU_ALL_VARIANTS=ON -DLLAMA_CURL=ON -DCMAKE_BUILD_TYPE=Release && \
RUN if [ "$TARGETARCH" = "amd64" ]; then \
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DLLAMA_CURL=ON -DGGML_NATIVE=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON; \
elif [ "$TARGETARCH" = "arm64" ]; then \
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DLLAMA_CURL=ON -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=${GGML_CPU_ARM_ARCH}; \
else \
echo "Unsupported architecture"; \
exit 1; \
fi && \
cmake --build build -j $(nproc)
RUN mkdir -p /app/lib && \

View file

@ -13,9 +13,13 @@ elif [[ "$arg1" == '--quantize' || "$arg1" == '-q' ]]; then
exec ./llama-quantize "$@"
elif [[ "$arg1" == '--run' || "$arg1" == '-r' ]]; then
exec ./llama-cli "$@"
elif [[ "$arg1" == '--bench' || "$arg1" == '-b' ]]; then
exec ./llama-bench "$@"
elif [[ "$arg1" == '--perplexity' || "$arg1" == '-p' ]]; then
exec ./llama-perplexity "$@"
elif [[ "$arg1" == '--all-in-one' || "$arg1" == '-a' ]]; then
echo "Converting PTH to GGML..."
for i in `ls $1/$2/ggml-model-f16.bin*`; do
for i in $(ls $1/$2/ggml-model-f16.bin*); do
if [ -f "${i/f16/q4_0}" ]; then
echo "Skip model quantization, it already exists: ${i/f16/q4_0}"
else
@ -30,6 +34,10 @@ else
echo "Available commands: "
echo " --run (-r): Run a model previously converted into ggml"
echo " ex: -m /models/7B/ggml-model-q4_0.bin -p \"Building a website can be done in 10 simple steps:\" -n 512"
echo " --bench (-b): Benchmark the performance of the inference for various parameters."
echo " ex: -m model.gguf"
echo " --perplexity (-p): Measure the perplexity of a model over a given text."
echo " ex: -m model.gguf -f file.txt"
echo " --convert (-c): Convert a llama model into ggml"
echo " ex: --outtype f16 \"/models/7B/\" "
echo " --quantize (-q): Optimize with quantization process ggml"

View file

@ -1,4 +1,4 @@
ARG UBUNTU_VERSION=jammy
ARG UBUNTU_VERSION=24.04
FROM ubuntu:$UBUNTU_VERSION AS build
@ -7,7 +7,7 @@ RUN apt update && apt install -y git build-essential cmake wget
# Install Vulkan SDK and cURL
RUN wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | apt-key add - && \
wget -qO /etc/apt/sources.list.d/lunarg-vulkan-jammy.list https://packages.lunarg.com/vulkan/lunarg-vulkan-jammy.list && \
wget -qO /etc/apt/sources.list.d/lunarg-vulkan-noble.list https://packages.lunarg.com/vulkan/lunarg-vulkan-noble.list && \
apt update -y && \
apt-get install -y vulkan-sdk libcurl4-openssl-dev curl
@ -34,7 +34,7 @@ RUN mkdir -p /app/full \
FROM ubuntu:$UBUNTU_VERSION AS base
RUN apt-get update \
&& apt-get install -y libgomp1 curl\
&& apt-get install -y libgomp1 curl libvulkan-dev \
&& apt autoremove -y \
&& apt clean -y \
&& rm -rf /tmp/* /var/tmp/* \
@ -55,8 +55,9 @@ RUN apt-get update \
git \
python3 \
python3-pip \
&& pip install --upgrade pip setuptools wheel \
&& pip install -r requirements.txt \
python3-wheel \
&& pip install --break-system-packages --upgrade setuptools \
&& pip install --break-system-packages -r requirements.txt \
&& apt autoremove -y \
&& apt clean -y \
&& rm -rf /tmp/* /var/tmp/* \

View file

@ -56,6 +56,7 @@ jobs:
mkdir build
cd build
cmake .. \
-DCMAKE_BUILD_RPATH="@loader_path" \
-DLLAMA_FATAL_WARNINGS=ON \
-DLLAMA_CURL=ON \
-DGGML_METAL_USE_BF16=ON \
@ -120,6 +121,7 @@ jobs:
# Metal is disabled due to intermittent failures with Github runners not having a GPU:
# https://github.com/ggerganov/llama.cpp/actions/runs/8635935781/job/23674807267#step:5:2313
cmake -B build \
-DCMAKE_BUILD_RPATH="@loader_path" \
-DLLAMA_FATAL_WARNINGS=ON \
-DLLAMA_CURL=ON \
-DGGML_METAL=OFF \
@ -160,8 +162,8 @@ jobs:
path: llama-${{ steps.tag.outputs.name }}-bin-macos-x64.zip
name: llama-bin-macos-x64.zip
ubuntu-latest-cmake:
runs-on: ubuntu-latest
ubuntu-cpu-cmake:
runs-on: ubuntu-22.04
steps:
- name: Clone
@ -181,7 +183,10 @@ jobs:
run: |
mkdir build
cd build
cmake .. -DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON -DGGML_RPC=ON
cmake .. \
-DLLAMA_FATAL_WARNINGS=ON \
-DLLAMA_CURL=ON \
-DGGML_RPC=ON
cmake --build . --config Release -j $(nproc)
- name: Test
@ -256,7 +261,10 @@ jobs:
run: |
mkdir build
cd build
cmake .. -DLLAMA_FATAL_WARNINGS=ON -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON -DCMAKE_BUILD_TYPE=${{ matrix.build_type }}
cmake .. \
-DLLAMA_FATAL_WARNINGS=ON \
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }}
cmake --build . --config ${{ matrix.build_type }} -j $(nproc)
- name: Build (no OpenMP)
@ -265,7 +273,11 @@ jobs:
run: |
mkdir build
cd build
cmake .. -DLLAMA_FATAL_WARNINGS=ON -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} -DGGML_OPENMP=OFF
cmake .. \
-DLLAMA_FATAL_WARNINGS=ON \
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
-DGGML_OPENMP=OFF
cmake --build . --config ${{ matrix.build_type }} -j $(nproc)
- name: Test
@ -295,7 +307,8 @@ jobs:
run: |
mkdir build
cd build
cmake -DGGML_RPC=ON ..
cmake .. \
-DGGML_RPC=ON
cmake --build . --config Release -j $(nproc)
- name: Test
@ -325,14 +338,16 @@ jobs:
run: |
mkdir build
cd build
cmake -DGGML_VULKAN=ON ..
cmake .. \
-DGGML_VULKAN=ON
cmake --build . --config Release -j $(nproc)
- name: Test
id: cmake_test
run: |
cd build
ctest -L main --verbose --timeout 900
# This is using llvmpipe and runs slower than other backends
ctest -L main --verbose --timeout 1800
ubuntu-22-cmake-hip:
runs-on: ubuntu-22.04
@ -352,13 +367,18 @@ jobs:
- name: Build with native CMake HIP support
id: cmake_build
run: |
cmake -B build -S . -DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" -DGGML_HIP=ON
cmake -B build -S . \
-DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
-DGGML_HIP=ON
cmake --build build --config Release -j $(nproc)
- name: Build with legacy HIP support
id: cmake_build_legacy_hip
run: |
cmake -B build2 -S . -DCMAKE_C_COMPILER=hipcc -DCMAKE_CXX_COMPILER=hipcc -DGGML_HIP=ON
cmake -B build2 -S . \
-DCMAKE_C_COMPILER=hipcc \
-DCMAKE_CXX_COMPILER=hipcc \
-DGGML_HIP=ON
cmake --build build2 --config Release -j $(nproc)
ubuntu-22-cmake-musa:
@ -379,7 +399,8 @@ jobs:
- name: Build with native CMake MUSA support
id: cmake_build
run: |
cmake -B build -S . -DGGML_MUSA=ON
cmake -B build -S . \
-DGGML_MUSA=ON
cmake --build build --config Release -j $(nproc)
ubuntu-22-cmake-sycl:
@ -420,7 +441,10 @@ jobs:
source /opt/intel/oneapi/setvars.sh
mkdir build
cd build
cmake -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx ..
cmake .. \
-DGGML_SYCL=ON \
-DCMAKE_C_COMPILER=icx \
-DCMAKE_CXX_COMPILER=icpx
cmake --build . --config Release -j $(nproc)
ubuntu-22-cmake-sycl-fp16:
@ -461,42 +485,13 @@ jobs:
source /opt/intel/oneapi/setvars.sh
mkdir build
cd build
cmake -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON ..
cmake .. \
-DGGML_SYCL=ON \
-DCMAKE_C_COMPILER=icx \
-DCMAKE_CXX_COMPILER=icpx \
-DGGML_SYCL_F16=ON
cmake --build . --config Release -j $(nproc)
# TODO: build with GGML_METAL=OFF because test-backend-ops fail on "Apple Paravirtual device" and I don't know
# how to debug it.
# ref: https://github.com/ggerganov/llama.cpp/actions/runs/7132125951/job/19422043567?pr=4359#step:5:6584
# would be great if we fix these
macOS-latest-cmake:
runs-on: macos-latest
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
- name: Dependencies
id: depends
continue-on-error: true
run: |
brew update
- name: Build
id: cmake_build
run: |
sysctl -a
mkdir build
cd build
cmake -DLLAMA_FATAL_WARNINGS=ON -DGGML_METAL=OFF ..
cmake --build . --config Release -j $(sysctl -n hw.logicalcpu)
- name: Test
id: cmake_test
run: |
cd build
ctest -L main --verbose --timeout 900
macOS-latest-cmake-ios:
runs-on: macos-latest
@ -619,6 +614,7 @@ jobs:
msystem: ${{matrix.sys}}
install: >-
base-devel
git
mingw-w64-${{matrix.env}}-toolchain
mingw-w64-${{matrix.env}}-cmake
mingw-w64-${{matrix.env}}-openblas
@ -827,7 +823,13 @@ jobs:
- name: Build with CMake
run: |
cmake -S . -B build -G Ninja -DCMAKE_BUILD_TYPE=Release -DGGML_NATIVE=OFF -DGGML_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=89-real -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined -DLLAMA_FATAL_WARNINGS=ON
cmake -S . -B build -G Ninja \
-DCMAKE_BUILD_TYPE=Release \
-DCMAKE_CUDA_ARCHITECTURES=89-real \
-DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined \
-DLLAMA_FATAL_WARNINGS=ON \
-DGGML_NATIVE=OFF \
-DGGML_CUDA=ON
cmake --build build
windows-2019-cmake-cuda:
@ -916,7 +918,11 @@ jobs:
shell: cmd
run: |
call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Auxiliary\Build\vcvars64.bat"
cmake -S . -B build -G "Ninja Multi-Config" -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_CUDA=ON -DGGML_RPC=ON
cmake -S . -B build -G "Ninja Multi-Config" ^
-DLLAMA_BUILD_SERVER=ON ^
-DGGML_NATIVE=OFF ^
-DGGML_CUDA=ON ^
-DGGML_RPC=ON
set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1
cmake --build build --config Release -j %NINJA_JOBS% -t ggml
cmake --build build --config Release
@ -1069,7 +1075,12 @@ jobs:
run: |
$env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path)
$env:CMAKE_PREFIX_PATH="${env:HIP_PATH}"
cmake -G "Unix Makefiles" -B build -S . -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" -DGGML_HIP=ON -DCMAKE_BUILD_TYPE=Release -DGGML_RPC=ON
cmake -G "Unix Makefiles" -B build -S . `
-DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" `
-DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" `
-DCMAKE_BUILD_TYPE=Release `
-DGGML_HIP=ON `
-DGGML_RPC=ON
cmake --build build -j ${env:NUMBER_OF_PROCESSORS}
windows-latest-cmake-hip-release:
@ -1107,7 +1118,13 @@ jobs:
run: |
$env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path)
$env:CMAKE_PREFIX_PATH="${env:HIP_PATH}"
cmake -G "Unix Makefiles" -B build -S . -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" -DGGML_HIP=ON -DCMAKE_BUILD_TYPE=Release -DAMDGPU_TARGETS=${{ matrix.gpu_target }} -DGGML_RPC=ON
cmake -G "Unix Makefiles" -B build -S . `
-DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" `
-DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" `
-DCMAKE_BUILD_TYPE=Release `
-DAMDGPU_TARGETS=${{ matrix.gpu_target }} `
-DGGML_HIP=ON `
-DGGML_RPC=ON
cmake --build build -j ${env:NUMBER_OF_PROCESSORS}
md "build\bin\rocblas\library\"
cp "${env:HIP_PATH}\bin\hipblas.dll" "build\bin\"
@ -1201,8 +1218,7 @@ jobs:
runs-on: ubuntu-latest
needs:
- ubuntu-latest-cmake
- macOS-latest-cmake
- ubuntu-cpu-cmake
- windows-latest-cmake
- windows-2019-cmake-cuda
- windows-latest-cmake-hip-release
@ -1461,3 +1477,37 @@ jobs:
# popd
# emcmake cmake . -DCMAKE_BUILD_TYPE=${{ matrix.build }}
# make
openEuler-latest-cmake-cann:
if: ${{ github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'Ascend NPU') }}
defaults:
run:
shell: bash -el {0}
runs-on: ubuntu-24.04-arm
strategy:
matrix:
cann:
- '8.0.rc3.beta1-910b-openeuler22.03-py3.10'
device:
- 'ascend910b3'
build:
- 'Release'
container: ascendai/cann:${{ matrix.cann }}
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Dependencies
run: |
yum update -y
yum install -y git gcc gcc-c++ make cmake
- name: Build
run: |
export LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/$(uname -m)-linux/devlib/:${LD_LIBRARY_PATH}
cmake -S . -B build \
-DCMAKE_BUILD_TYPE=${{ matrix.build }} \
-DGGML_CANN=on \
-DSOC_TYPE=${{ matrix.device }}
cmake --build build -j $(nproc)

View file

@ -28,10 +28,11 @@ jobs:
push_to_registry:
name: Push Docker image to Docker Hub
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
env:
COMMIT_SHA: ${{ github.sha }}
strategy:
fail-fast: false
matrix:
config:
# Multi-stage build

View file

@ -16,6 +16,7 @@ endif()
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
if (CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
set(LLAMA_STANDALONE ON)
@ -49,6 +50,8 @@ endif()
if (MSVC)
add_compile_options("$<$<COMPILE_LANGUAGE:C>:/utf-8>")
add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:/utf-8>")
add_compile_options("$<$<COMPILE_LANGUAGE:C>:/bigobj>")
add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:/bigobj>")
endif()
#
@ -185,27 +188,14 @@ set(LLAMA_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR} CACHE PATH "Location o
set(LLAMA_LIB_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Location of library files")
set(LLAMA_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location of binary files")
# At the moment some compile definitions are placed within the ggml/src
# directory but not exported on the `ggml` target. This could be improved by
# determining _precisely_ which defines are necessary for the llama-config
# package.
#
set(GGML_TRANSIENT_DEFINES)
get_target_property(GGML_DIRECTORY ggml SOURCE_DIR)
get_directory_property(GGML_DIR_DEFINES DIRECTORY ${GGML_DIRECTORY} COMPILE_DEFINITIONS)
if (GGML_DIR_DEFINES)
list(APPEND GGML_TRANSIENT_DEFINES ${GGML_DIR_DEFINES})
endif()
get_target_property(GGML_TARGET_DEFINES ggml COMPILE_DEFINITIONS)
if (GGML_TARGET_DEFINES)
list(APPEND GGML_TRANSIENT_DEFINES ${GGML_TARGET_DEFINES})
endif()
get_target_property(GGML_LINK_LIBRARIES ggml LINK_LIBRARIES)
# all public headers
set(LLAMA_PUBLIC_HEADERS
${CMAKE_CURRENT_SOURCE_DIR}/include/llama.h
${CMAKE_CURRENT_SOURCE_DIR}/include/llama-cpp.h)
set_target_properties(llama PROPERTIES PUBLIC_HEADER "${LLAMA_PUBLIC_HEADERS}")
set_target_properties(llama
PROPERTIES
PUBLIC_HEADER "${LLAMA_PUBLIC_HEADERS}")
install(TARGETS llama LIBRARY PUBLIC_HEADER)
configure_package_config_file(

View file

@ -16,7 +16,10 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)
## Hot topics
- **Introducing GGUF-my-LoRA** https://github.com/ggerganov/llama.cpp/discussions/10123
- **How to use [MTLResidencySet](https://developer.apple.com/documentation/metal/mtlresidencyset?language=objc) to keep the GPU memory active?** https://github.com/ggerganov/llama.cpp/pull/11427
- **VS Code extension for FIM completions:** https://github.com/ggml-org/llama.vscode
- Vim/Neovim plugin for FIM completions: https://github.com/ggml-org/llama.vim
- Introducing GGUF-my-LoRA https://github.com/ggerganov/llama.cpp/discussions/10123
- Hugging Face Inference Endpoints now support GGUF out of the box! https://github.com/ggerganov/llama.cpp/discussions/9669
- Hugging Face GGUF editor: [discussion](https://github.com/ggerganov/llama.cpp/discussions/9268) | [tool](https://huggingface.co/spaces/CISCai/gguf-editor)

View file

@ -3,159 +3,13 @@ set(LLAMA_BUILD_COMMIT @LLAMA_BUILD_COMMIT@)
set(LLAMA_BUILD_NUMBER @LLAMA_BUILD_NUMBER@)
set(LLAMA_SHARED_LIB @BUILD_SHARED_LIBS@)
set(GGML_STATIC @GGML_STATIC@)
set(GGML_NATIVE @GGML_NATIVE@)
set(GGML_LTO @GGML_LTO@)
set(GGML_CCACHE @GGML_CCACHE@)
set(GGML_AVX @GGML_AVX@)
set(GGML_AVX2 @GGML_AVX2@)
set(GGML_AVX512 @GGML_AVX512@)
set(GGML_AVX512_VBMI @GGML_AVX512_VBMI@)
set(GGML_AVX512_VNNI @GGML_AVX512_VNNI@)
set(GGML_AVX512_BF16 @GGML_AVX512_BF16@)
set(GGML_AMX_TILE @GGML_AMX_TILE@)
set(GGML_AMX_INT8 @GGML_AMX_INT8@)
set(GGML_AMX_BF16 @GGML_AMX_BF16@)
set(GGML_FMA @GGML_FMA@)
set(GGML_LASX @GGML_LASX@)
set(GGML_LSX @GGML_LSX@)
set(GGML_RVV @GGML_RVV@)
set(GGML_SVE @GGML_SVE@)
set(GGML_ACCELERATE @GGML_ACCELERATE@)
set(GGML_OPENMP @GGML_OPENMP@)
set(GGML_CPU_HBM @GGML_CPU_HBM@)
set(GGML_BLAS_VENDOR @GGML_BLAS_VENDOR@)
set(GGML_CUDA_FORCE_MMQ @GGML_CUDA_FORCE_MMQ@)
set(GGML_CUDA_FORCE_CUBLAS @GGML_CUDA_FORCE_CUBLAS@)
set(GGML_CUDA_F16 @GGML_CUDA_F16@)
set(GGML_CUDA_PEER_MAX_BATCH_SIZE @GGML_CUDA_PEER_MAX_BATCH_SIZE@)
set(GGML_CUDA_NO_PEER_COPY @GGML_CUDA_NO_PEER_COPY@)
set(GGML_CUDA_NO_VMM @GGML_CUDA_NO_VMM@)
set(GGML_CUDA_FA_ALL_QUANTS @GGML_CUDA_FA_ALL_QUANTS@)
set(GGML_CUDA_GRAPHS @GGML_CUDA_GRAPHS@)
set(GGML_HIP_UMA @GGML_HIP_UMA@)
set(GGML_VULKAN_CHECK_RESULTS @GGML_VULKAN_CHECK_RESULTS@)
set(GGML_VULKAN_DEBUG @GGML_VULKAN_DEBUG@)
set(GGML_VULKAN_MEMORY_DEBUG @GGML_VULKAN_MEMORY_DEBUG@)
set(GGML_VULKAN_SHADER_DEBUG_INFO @GGML_VULKAN_SHADER_DEBUG_INFO@)
set(GGML_VULKAN_PERF @GGML_VULKAN_PERF@)
set(GGML_VULKAN_VALIDATE @GGML_VULKAN_VALIDATE@)
set(GGML_VULKAN_RUN_TESTS @GGML_VULKAN_RUN_TESTS@)
set(GGML_METAL_USE_BF16 @GGML_METAL_USE_BF16@)
set(GGML_METAL_NDEBUG @GGML_METAL_NDEBUG@)
set(GGML_METAL_SHADER_DEBUG @GGML_METAL_SHADER_DEBUG@)
set(GGML_METAL_EMBED_LIBRARY @GGML_METAL_EMBED_LIBRARY@)
set(GGML_METAL_MACOSX_VERSION_MIN @GGML_METAL_MACOSX_VERSION_MIN@)
set(GGML_METAL_STD @GGML_METAL_STD@)
set(GGML_SYCL_F16 @GGML_SYCL_F16@)
set(GGML_SYCL_TARGET @GGML_SYCL_TARGET@)
set(GGML_SYCL_DEVICE_ARCH @GGML_SYCL_DEVICE_ARCH@)
@PACKAGE_INIT@
set_and_check(LLAMA_INCLUDE_DIR "@PACKAGE_LLAMA_INCLUDE_INSTALL_DIR@")
set_and_check(LLAMA_LIB_DIR "@PACKAGE_LLAMA_LIB_INSTALL_DIR@")
set_and_check(LLAMA_BIN_DIR "@PACKAGE_LLAMA_BIN_INSTALL_DIR@")
find_package(Threads REQUIRED)
set(_llama_transient_defines "@GGML_TRANSIENT_DEFINES@")
set(_llama_link_deps "")
set(_llama_link_opts "")
foreach(_ggml_lib ggml ggml-base)
string(REPLACE "-" "_" _ggml_lib_var "${_ggml_lib}_LIBRARY")
find_library(${_ggml_lib_var} ${_ggml_lib}
REQUIRED
HINTS ${LLAMA_LIB_DIR}
NO_CMAKE_FIND_ROOT_PATH
)
list(APPEND _llama_link_deps "${${_ggml_lib_var}}")
message(STATUS "Found ${${_ggml_lib_var}}")
endforeach()
foreach(backend amx blas cann cpu cuda hip kompute metal musa rpc sycl vulkan)
string(TOUPPER "GGML_${backend}" backend_id)
set(_ggml_lib "ggml-${backend}")
string(REPLACE "-" "_" _ggml_lib_var "${_ggml_lib}_LIBRARY")
find_library(${_ggml_lib_var} ${_ggml_lib}
HINTS ${LLAMA_LIB_DIR}
NO_CMAKE_FIND_ROOT_PATH
)
if(${_ggml_lib_var})
list(APPEND _llama_link_deps "${${_ggml_lib_var}}")
set(${backend_id} ON)
message(STATUS "Found backend ${${_ggml_lib_var}}")
else()
set(${backend_id} OFF)
endif()
endforeach()
if (NOT LLAMA_SHARED_LIB)
if (APPLE AND GGML_ACCELERATE)
find_library(ACCELERATE_FRAMEWORK Accelerate REQUIRED)
list(APPEND _llama_link_deps ${ACCELERATE_FRAMEWORK})
endif()
if (GGML_OPENMP)
find_package(OpenMP REQUIRED)
list(APPEND _llama_link_deps OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
endif()
if (GGML_CPU_HBM)
find_library(memkind memkind REQUIRED)
list(APPEND _llama_link_deps memkind)
endif()
if (GGML_BLAS)
find_package(BLAS REQUIRED)
list(APPEND _llama_link_deps ${BLAS_LIBRARIES})
list(APPEND _llama_link_opts ${BLAS_LINKER_FLAGS})
endif()
if (GGML_CUDA)
find_package(CUDAToolkit REQUIRED)
endif()
if (GGML_METAL)
find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
find_library(METAL_FRAMEWORK Metal REQUIRED)
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
list(APPEND _llama_link_deps ${FOUNDATION_LIBRARY}
${METAL_FRAMEWORK} ${METALKIT_FRAMEWORK})
endif()
if (GGML_VULKAN)
find_package(Vulkan REQUIRED)
list(APPEND _llama_link_deps Vulkan::Vulkan)
endif()
if (GGML_HIP)
find_package(hip REQUIRED)
find_package(hipblas REQUIRED)
find_package(rocblas REQUIRED)
list(APPEND _llama_link_deps hip::host roc::rocblas roc::hipblas)
endif()
if (GGML_SYCL)
find_package(DNNL)
if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL")
list(APPEND _llama_link_deps DNNL::dnnl)
endif()
if (WIN32)
find_package(IntelSYCL REQUIRED)
find_package(MKL REQUIRED)
list(APPEND _llama_link_deps IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL)
endif()
endif()
endif()
find_package(ggml REQUIRED HINTS ${LLAMA_LIB_DIR}/cmake)
find_library(llama_LIBRARY llama
REQUIRED
@ -167,12 +21,10 @@ add_library(llama UNKNOWN IMPORTED)
set_target_properties(llama
PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES "${LLAMA_INCLUDE_DIR}"
INTERFACE_LINK_LIBRARIES "${_llama_link_deps}"
INTERFACE_LINK_OPTIONS "${_llama_link_opts}"
INTERFACE_COMPILE_DEFINITIONS "${_llama_transient_defines}"
INTERFACE_LINK_LIBRARIES "ggml::ggml;ggml::ggml-base;"
IMPORTED_LINK_INTERFACE_LANGUAGES "CXX"
IMPORTED_LOCATION "${llama_LIBRARY}"
INTERFACE_COMPILE_FEATURES cxx_std_11
POSITION_INDEPENDENT_CODE ON )
INTERFACE_COMPILE_FEATURES c_std_90
POSITION_INDEPENDENT_CODE ON)
check_required_components(Llama)

View file

@ -877,7 +877,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params) {
params.warmup = false;
}
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}));
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING}));
add_opt(common_arg(
{"--spm-infill"},
string_format(

View file

@ -133,7 +133,7 @@ The docker build option is currently limited to *intel GPU* targets.
### Build image
```sh
# Using FP16
docker build -t llama-cpp-sycl --build-arg="GGML_SYCL_F16=ON" -f .devops/llama-cli-intel.Dockerfile .
docker build -t llama-cpp-sycl --build-arg="GGML_SYCL_F16=ON" --target light -f .devops/intel.Dockerfile .
```
*Notes*:

View file

@ -286,7 +286,7 @@ You don't need to install Vulkan SDK. It will be installed inside the container.
```sh
# Build the image
docker build -t llama-cpp-vulkan -f .devops/llama-cli-vulkan.Dockerfile .
docker build -t llama-cpp-vulkan --target light -f .devops/vulkan.Dockerfile .
# Then, use it:
docker run -it --rm -v "$(pwd):/app:Z" --device /dev/dri/renderD128:/dev/dri/renderD128 --device /dev/dri/card1:/dev/dri/card1 llama-cpp-vulkan -m "/app/models/YOUR_MODEL_FILE" -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33

View file

@ -60,9 +60,9 @@ Assuming one has the [nvidia-container-toolkit](https://github.com/NVIDIA/nvidia
## Building Docker locally
```bash
docker build -t local/llama.cpp:full-cuda -f .devops/full-cuda.Dockerfile .
docker build -t local/llama.cpp:light-cuda -f .devops/llama-cli-cuda.Dockerfile .
docker build -t local/llama.cpp:server-cuda -f .devops/llama-server-cuda.Dockerfile .
docker build -t local/llama.cpp:full-cuda --target full -f .devops/cuda.Dockerfile .
docker build -t local/llama.cpp:light-cuda --target light -f .devops/cuda.Dockerfile .
docker build -t local/llama.cpp:server-cuda --target server -f .devops/cuda.Dockerfile .
```
You may want to pass in some different `ARGS`, depending on the CUDA environment supported by your container host, as well as the GPU architecture.
@ -95,9 +95,9 @@ Assuming one has the [mt-container-toolkit](https://developer.mthreads.com/musa/
## Building Docker locally
```bash
docker build -t local/llama.cpp:full-musa -f .devops/full-musa.Dockerfile .
docker build -t local/llama.cpp:light-musa -f .devops/llama-cli-musa.Dockerfile .
docker build -t local/llama.cpp:server-musa -f .devops/llama-server-musa.Dockerfile .
docker build -t local/llama.cpp:full-musa --target full -f .devops/musa.Dockerfile .
docker build -t local/llama.cpp:light-musa --target light -f .devops/musa.Dockerfile .
docker build -t local/llama.cpp:server-musa --target server -f .devops/musa.Dockerfile .
```
You may want to pass in some different `ARGS`, depending on the MUSA environment supported by your container host, as well as the GPU architecture.

View file

@ -1,32 +0,0 @@
cmake_minimum_required(VERSION 3.12)
project("llama-cli-cmake-pkg" C CXX)
set(TARGET llama-cli-cmake-pkg)
find_package(Llama 0.0.1 REQUIRED)
# Bake common functionality in with target. Because applications
# using the relocatable Llama package should be outside of the
# source tree, llama-cli-cmake-pkg pretends the dependencies are built-in.
set(_common_path "${CMAKE_CURRENT_LIST_DIR}/../../common")
add_library(common OBJECT)
file(GLOB _common_files
"${_common_path}/*.h"
"${_common_path}/*.cpp"
)
target_sources(common PRIVATE ${_common_files})
# If the common project was part of "llama-cli-cmake-pkg" the transient
# defines would automatically be attached. Because the common func-
# tionality is separate, but dependent upon the defines, it must be
# explicitly extracted from the "llama" target.
#
get_target_property(_llama_transient_defines llama
INTERFACE_COMPILE_DEFINITIONS)
target_compile_definitions(common PRIVATE "${_llama_transient_defines}")
add_executable(${TARGET} ${CMAKE_CURRENT_LIST_DIR}/../main/main.cpp)
target_include_directories(${TARGET} PRIVATE ${_common_path})
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)

View file

@ -1,31 +0,0 @@
# llama.cpp/example/main-cmake-pkg
This program builds [llama-cli](../main) using a relocatable CMake package. It serves as an example of using the `find_package()` CMake command to conveniently include [llama.cpp](https://github.com/ggerganov/llama.cpp) in projects which live outside of the source tree.
## Building
Because this example is "outside of the source tree", it is important to first build/install llama.cpp using CMake. An example is provided here, but please see the [llama.cpp build instructions](../..) for more detailed build instructions.
### Considerations
When hardware acceleration libraries are used (e.g. CUDA, Metal, etc.), CMake must be able to locate the associated CMake package.
### Build llama.cpp and install to C:\LlamaCPP directory
```cmd
git clone https://github.com/ggerganov/llama.cpp
cd llama.cpp
cmake -B build -DBUILD_SHARED_LIBS=OFF -G "Visual Studio 17 2022" -A x64
cmake --build build --config Release
cmake --install build --prefix C:/LlamaCPP
```
### Build llama-cli-cmake-pkg
```cmd
cd ..\examples\main-cmake-pkg
cmake -B build -DBUILD_SHARED_LIBS=OFF -DCMAKE_PREFIX_PATH="C:/LlamaCPP/lib/cmake/Llama" -G "Visual Studio 17 2022" -A x64
cmake --build build --config Release
cmake --install build --prefix C:/MyLlamaApp
```

View file

@ -310,9 +310,9 @@ These options help improve the performance and memory usage of the LLaMA models.
### Batch Size
- `-b N, --batch-size N`: Set the batch size for prompt processing (default: `2048`). This large batch size benefits users who have BLAS installed and enabled it during the build. If you don't have BLAS enabled ("BLAS=0"), you can use a smaller number, such as 8, to see the prompt progress as it's evaluated in some situations.
- `-ub N`, `--ubatch-size N`: Physical batch size. This is the maximum number of tokens that may be processed at a time. Increasing this value may improve performance during prompt processing, at the expense of higher memory usage. Default: `512`.
- `-ub N`, `--ubatch-size N`: physical maximum batch size. This is for pipeline parallelization. Default: `512`.
- `-b N`, `--batch-size N`: Logical batch size. Increasing this value above the value of the physical batch size may improve prompt processing performance when using multiple GPUs with pipeline parallelism. Default: `2048`.
### Prompt Caching

View file

@ -3,11 +3,10 @@
The purpose of this example is to demonstrate a minimal usage of llama.cpp for running models.
```bash
llama-run granite-code
llama-run granite3-moe
```
```bash
llama-run -h
Description:
Runs a llm
@ -17,7 +16,7 @@ Usage:
Options:
-c, --context-size <value>
Context size (default: 2048)
-n, --ngl <value>
-n, -ngl, --ngl <value>
Number of GPU layers (default: 0)
--temp <value>
Temperature (default: 0.8)

View file

@ -147,7 +147,8 @@ class Opt {
if (handle_option_with_value(argc, argv, i, context_size) == 1) {
return 1;
}
} else if (options_parsing && (strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "--ngl") == 0)) {
} else if (options_parsing &&
(strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "-ngl") == 0 || strcmp(argv[i], "--ngl") == 0)) {
if (handle_option_with_value(argc, argv, i, ngl) == 1) {
return 1;
}
@ -180,6 +181,10 @@ class Opt {
}
}
if (model_.empty()){
return 1;
}
return 0;
}
@ -194,7 +199,7 @@ class Opt {
"Options:\n"
" -c, --context-size <value>\n"
" Context size (default: %d)\n"
" -n, --ngl <value>\n"
" -n, -ngl, --ngl <value>\n"
" Number of GPU layers (default: %d)\n"
" --temp <value>\n"
" Temperature (default: %.1f)\n"
@ -318,6 +323,10 @@ class HttpClient {
public:
int init(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
const bool progress, std::string * response_str = nullptr) {
if (std::filesystem::exists(output_file)) {
return 0;
}
std::string output_file_partial;
curl = curl_easy_init();
if (!curl) {
@ -345,7 +354,11 @@ class HttpClient {
data.file_size = set_resume_point(output_file_partial);
set_progress_options(progress, data);
set_headers(headers);
perform(url);
CURLcode res = perform(url);
if (res != CURLE_OK){
printe("Fetching resource '%s' failed: %s\n", url.c_str(), curl_easy_strerror(res));
return 1;
}
if (!output_file.empty()) {
std::filesystem::rename(output_file_partial, output_file);
}
@ -410,16 +423,12 @@ class HttpClient {
}
}
void perform(const std::string & url) {
CURLcode res;
CURLcode perform(const std::string & url) {
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
curl_easy_setopt(curl, CURLOPT_DEFAULT_PROTOCOL, "https");
curl_easy_setopt(curl, CURLOPT_FAILONERROR, 1L);
res = curl_easy_perform(curl);
if (res != CURLE_OK) {
printe("curl_easy_perform() failed: %s\n", curl_easy_strerror(res));
}
return curl_easy_perform(curl);
}
static std::string human_readable_time(double seconds) {
@ -557,13 +566,14 @@ class LlamaData {
}
sampler = initialize_sampler(opt);
return 0;
}
private:
#ifdef LLAMA_USE_CURL
int download(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
const bool progress, std::string * response_str = nullptr) {
int download(const std::string & url, const std::string & output_file, const bool progress,
const std::vector<std::string> & headers = {}, std::string * response_str = nullptr) {
HttpClient http;
if (http.init(url, headers, output_file, progress, response_str)) {
return 1;
@ -572,48 +582,85 @@ class LlamaData {
return 0;
}
#else
int download(const std::string &, const std::vector<std::string> &, const std::string &, const bool,
int download(const std::string &, const std::string &, const bool, const std::vector<std::string> & = {},
std::string * = nullptr) {
printe("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
return 1;
}
#endif
int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn) {
// Find the second occurrence of '/' after protocol string
size_t pos = model.find('/');
pos = model.find('/', pos + 1);
if (pos == std::string::npos) {
return 1;
}
const std::string hfr = model.substr(0, pos);
const std::string hff = model.substr(pos + 1);
const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
return download(url, headers, bn, true);
}
int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn) {
if (model.find('/') == std::string::npos) {
model = "library/" + model;
}
std::string model_tag = "latest";
size_t colon_pos = model.find(':');
// Helper function to handle model tag extraction and URL construction
std::pair<std::string, std::string> extract_model_and_tag(std::string & model, const std::string & base_url) {
std::string model_tag = "latest";
const size_t colon_pos = model.find(':');
if (colon_pos != std::string::npos) {
model_tag = model.substr(colon_pos + 1);
model = model.substr(0, colon_pos);
}
std::string manifest_url = "https://registry.ollama.ai/v2/" + model + "/manifests/" + model_tag;
std::string url = base_url + model + "/manifests/" + model_tag;
return { model, url };
}
// Helper function to download and parse the manifest
int download_and_parse_manifest(const std::string & url, const std::vector<std::string> & headers,
nlohmann::json & manifest) {
std::string manifest_str;
const int ret = download(manifest_url, headers, "", false, &manifest_str);
int ret = download(url, "", false, headers, &manifest_str);
if (ret) {
return ret;
}
nlohmann::json manifest = nlohmann::json::parse(manifest_str);
std::string layer;
manifest = nlohmann::json::parse(manifest_str);
return 0;
}
int huggingface_dl(std::string & model, const std::string & bn) {
// Find the second occurrence of '/' after protocol string
size_t pos = model.find('/');
pos = model.find('/', pos + 1);
std::string hfr, hff;
std::vector<std::string> headers = { "User-Agent: llama-cpp", "Accept: application/json" };
std::string url;
if (pos == std::string::npos) {
auto [model_name, manifest_url] = extract_model_and_tag(model, "https://huggingface.co/v2/");
hfr = model_name;
nlohmann::json manifest;
int ret = download_and_parse_manifest(manifest_url, headers, manifest);
if (ret) {
return ret;
}
hff = manifest["ggufFile"]["rfilename"];
} else {
hfr = model.substr(0, pos);
hff = model.substr(pos + 1);
}
url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
return download(url, bn, true, headers);
}
int ollama_dl(std::string & model, const std::string & bn) {
const std::vector<std::string> headers = { "Accept: application/vnd.docker.distribution.manifest.v2+json" };
if (model.find('/') == std::string::npos) {
model = "library/" + model;
}
auto [model_name, manifest_url] = extract_model_and_tag(model, "https://registry.ollama.ai/v2/");
nlohmann::json manifest;
int ret = download_and_parse_manifest(manifest_url, {}, manifest);
if (ret) {
return ret;
}
std::string layer;
for (const auto & l : manifest["layers"]) {
if (l["mediaType"] == "application/vnd.ollama.image.model") {
layer = l["digest"];
@ -621,8 +668,34 @@ class LlamaData {
}
}
std::string blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + layer;
return download(blob_url, headers, bn, true);
std::string blob_url = "https://registry.ollama.ai/v2/" + model_name + "/blobs/" + layer;
return download(blob_url, bn, true, headers);
}
int github_dl(const std::string & model, const std::string & bn) {
std::string repository = model;
std::string branch = "main";
const size_t at_pos = model.find('@');
if (at_pos != std::string::npos) {
repository = model.substr(0, at_pos);
branch = model.substr(at_pos + 1);
}
const std::vector<std::string> repo_parts = string_split(repository, "/");
if (repo_parts.size() < 3) {
printe("Invalid GitHub repository format\n");
return 1;
}
const std::string & org = repo_parts[0];
const std::string & project = repo_parts[1];
std::string url = "https://raw.githubusercontent.com/" + org + "/" + project + "/" + branch;
for (size_t i = 2; i < repo_parts.size(); ++i) {
url += "/" + repo_parts[i];
}
return download(url, bn, true);
}
std::string basename(const std::string & path) {
@ -634,37 +707,41 @@ class LlamaData {
return path.substr(pos + 1);
}
int remove_proto(std::string & model_) {
const std::string::size_type pos = model_.find("://");
int rm_until_substring(std::string & model_, const std::string & substring) {
const std::string::size_type pos = model_.find(substring);
if (pos == std::string::npos) {
return 1;
}
model_ = model_.substr(pos + 3); // Skip past "://"
model_ = model_.substr(pos + substring.size()); // Skip past the substring
return 0;
}
int resolve_model(std::string & model_) {
int ret = 0;
if (string_starts_with(model_, "file://") || std::filesystem::exists(model_)) {
remove_proto(model_);
rm_until_substring(model_, "://");
return ret;
}
const std::string bn = basename(model_);
const std::vector<std::string> headers = { "--header",
"Accept: application/vnd.docker.distribution.manifest.v2+json" };
if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://")) {
remove_proto(model_);
ret = huggingface_dl(model_, headers, bn);
} else if (string_starts_with(model_, "ollama://")) {
remove_proto(model_);
ret = ollama_dl(model_, headers, bn);
} else if (string_starts_with(model_, "https://")) {
download(model_, headers, bn, true);
} else {
ret = ollama_dl(model_, headers, bn);
const std::string bn = basename(model_);
if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://") ||
string_starts_with(model_, "hf.co/")) {
rm_until_substring(model_, "hf.co/");
rm_until_substring(model_, "://");
ret = huggingface_dl(model_, bn);
} else if ((string_starts_with(model_, "https://") || string_starts_with(model_, "http://")) &&
!string_starts_with(model_, "https://ollama.com/library/")) {
ret = download(model_, bn, true);
} else if (string_starts_with(model_, "github:") || string_starts_with(model_, "github://")) {
rm_until_substring(model_, "github:");
rm_until_substring(model_, "://");
ret = github_dl(model_, bn);
} else { // ollama:// or nothing
rm_until_substring(model_, "ollama.com/library/");
rm_until_substring(model_, "://");
ret = ollama_dl(model_, bn);
}
model_ = bn;

View file

@ -236,9 +236,13 @@ npm i
# to run the dev server
npm run dev
# to build the public/index.html
# to build the public/index.html.gz
npm run build
```
After `public/index.html.gz` has been generated we need to generate the c++
headers (like build/examples/server/index.html.gz.hpp) that will be included
by server.cpp. This is done by building `llama-server` as described in the
[build](#build) section above.
NOTE: if you are using the vite dev server, you can change the API base URL to llama.cpp. To do that, run this code snippet in browser's console:

Binary file not shown.

View file

@ -14,7 +14,7 @@
// mime type for sending response
#define MIMETYPE_JSON "application/json; charset=utf-8"
// auto generated files (update with ./deps.sh)
// auto generated files (see README.md for details)
#include "index.html.gz.hpp"
#include "loading.html.hpp"
@ -1436,6 +1436,10 @@ struct server_queue {
int post(server_task task, bool front = false) {
std::unique_lock<std::mutex> lock(mutex_tasks);
GGML_ASSERT(task.id != -1);
// if this is cancel task make sure to clean up pending tasks
if (task.type == SERVER_TASK_TYPE_CANCEL) {
cleanup_pending_task(task.id_target);
}
QUE_DBG("new task, id = %d, front = %d\n", task.id, front);
if (front) {
queue_tasks.push_front(std::move(task));
@ -1453,6 +1457,10 @@ struct server_queue {
if (task.id == -1) {
task.id = id++;
}
// if this is cancel task make sure to clean up pending tasks
if (task.type == SERVER_TASK_TYPE_CANCEL) {
cleanup_pending_task(task.id_target);
}
QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front);
if (front) {
queue_tasks.push_front(std::move(task));
@ -1553,6 +1561,20 @@ struct server_queue {
}
}
}
private:
void cleanup_pending_task(int id_target) {
// no need lock because this is called exclusively by post()
auto rm_func = [id_target](const server_task & task) {
return task.id_target == id_target;
};
queue_tasks.erase(
std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func),
queue_tasks.end());
queue_tasks_deferred.erase(
std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func),
queue_tasks_deferred.end());
}
};
struct server_response {
@ -1588,6 +1610,12 @@ struct server_response {
std::unique_lock<std::mutex> lock(mutex_results);
waiting_task_ids.erase(id_task);
// make sure to clean up all pending results
queue_results.erase(
std::remove_if(queue_results.begin(), queue_results.end(), [id_task](const server_task_result_ptr & res) {
return res->id == id_task;
}),
queue_results.end());
}
void remove_waiting_task_ids(const std::unordered_set<int> & id_tasks) {
@ -1607,7 +1635,7 @@ struct server_response {
return !queue_results.empty();
});
for (int i = 0; i < (int) queue_results.size(); i++) {
for (size_t i = 0; i < queue_results.size(); i++) {
if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
server_task_result_ptr res = std::move(queue_results[i]);
queue_results.erase(queue_results.begin() + i);
@ -1624,12 +1652,6 @@ struct server_response {
server_task_result_ptr recv_with_timeout(const std::unordered_set<int> & id_tasks, int timeout) {
while (true) {
std::unique_lock<std::mutex> lock(mutex_results);
bool cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout), [&]{
return !queue_results.empty();
});
if (!cr_res) {
return nullptr;
}
for (int i = 0; i < (int) queue_results.size(); i++) {
if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
@ -1638,6 +1660,11 @@ struct server_response {
return res;
}
}
std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout));
if (cr_res == std::cv_status::timeout) {
return nullptr;
}
}
// should never reach here
@ -1781,6 +1808,9 @@ struct server_context {
// force F16 KV cache for the draft model for extra performance
cparams_dft.type_k = GGML_TYPE_F16;
cparams_dft.type_v = GGML_TYPE_F16;
// the context is not needed - we will create one for each slot
llama_init_dft.context.reset();
}
chat_templates = common_chat_templates_from_model(model, params_base.chat_template);
@ -2389,8 +2419,8 @@ struct server_context {
server_task task(SERVER_TASK_TYPE_CANCEL);
task.id_target = id_task;
cancel_tasks.push_back(task);
queue_results.remove_waiting_task_id(id_task);
cancel_tasks.push_back(task);
}
// push to beginning of the queue, so it has highest priority
queue_tasks.post(cancel_tasks, true);

View file

@ -87,7 +87,7 @@ def test_completion_stream_vs_non_stream():
assert content_stream == res_non_stream.body["content"]
def test_completion_stream_with_openai_library():
def test_completion_with_openai_library():
global server
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
@ -102,7 +102,7 @@ def test_completion_stream_with_openai_library():
assert match_regex("(going|bed)+", res.choices[0].text)
def test_completion_with_openai_library():
def test_completion_stream_with_openai_library():
global server
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")

View file

@ -141,6 +141,7 @@
:msg="pendingMsg"
:key="pendingMsg.id"
:is-generating="isGenerating"
:show-thought-in-progress="config.showThoughtInProgress"
:edit-user-msg-and-regenerate="() => {}"
:regenerate-msg="() => {}"></message-bubble>
</div>
@ -202,6 +203,20 @@
</template>
</div>
</details>
<!-- Section: Reasoning models -->
<details class="collapse collapse-arrow bg-base-200 mb-2 overflow-visible">
<summary class="collapse-title font-bold">Reasoning models</summary>
<div class="collapse-content">
<div class="flex flex-row items-center mb-2">
<input type="checkbox" class="checkbox" v-model="config.showThoughtInProgress" />
<span class="ml-4">Expand though process by default for generating message</span>
</div>
<div class="flex flex-row items-center mb-2">
<input type="checkbox" class="checkbox" v-model="config.excludeThoughtOnReq" />
<span class="ml-4">Exclude thought process when sending request to API (Recommended for DeepSeek-R1)</span>
</div>
</div>
</details>
<!-- Section: Advanced config -->
<details class="collapse collapse-arrow bg-base-200 mb-2 overflow-visible">
<summary class="collapse-title font-bold">Advanced config</summary>
@ -261,7 +276,17 @@
<span v-if="msg.content === null" class="loading loading-dots loading-md"></span>
<!-- render message as markdown -->
<div v-else dir="auto">
<vue-markdown :source="msg.content"></vue-markdown>
<details v-if="msg.role === 'assistant' && splitMsgContent.cot" class="collapse bg-base-200 collapse-arrow mb-4" :open="splitMsgContent.isThinking && showThoughtInProgress">
<summary class="collapse-title">
<span v-if="splitMsgContent.isThinking">
<span v-if="isGenerating" class="loading loading-spinner loading-md mr-2" style="vertical-align: middle;"></span>
<b>Thinking</b>
</span>
<b v-else>Thought Process</b>
</summary>
<vue-markdown :source="splitMsgContent.cot" dir="auto" class="collapse-content"></vue-markdown>
</details>
<vue-markdown :source="splitMsgContent.content"></vue-markdown>
</div>
<!-- render timings if enabled -->
<div class="dropdown dropdown-hover dropdown-top mt-2" v-if="timings && config.showTokensPerSecond">

View file

@ -17,6 +17,11 @@ import { asyncIterator } from '@sec-ant/readable-stream/ponyfill/asyncIterator';
const isDev = import.meta.env.MODE === 'development';
// types
/** @typedef {{ id: number, role: 'user' | 'assistant', content: string, timings: any }} Message */
/** @typedef {{ role: 'user' | 'assistant', content: string }} APIMessage */
/** @typedef {{ id: string, lastModified: number, messages: Array<Message> }} Conversation */
// utility functions
const isString = (x) => !!x.toLowerCase;
const isBoolean = (x) => x === true || x === false;
@ -50,6 +55,8 @@ const CONFIG_DEFAULT = {
apiKey: '',
systemMessage: 'You are a helpful assistant.',
showTokensPerSecond: false,
showThoughtInProgress: false,
excludeThoughtOnReq: true,
// make sure these default values are in sync with `common.h`
samplers: 'edkypmxt',
temperature: 0.8,
@ -172,6 +179,7 @@ const MessageBubble = defineComponent({
config: Object,
msg: Object,
isGenerating: Boolean,
showThoughtInProgress: Boolean,
editUserMsgAndRegenerate: Function,
regenerateMsg: Function,
},
@ -188,7 +196,31 @@ const MessageBubble = defineComponent({
prompt_per_second: this.msg.timings.prompt_n / (this.msg.timings.prompt_ms / 1000),
predicted_per_second: this.msg.timings.predicted_n / (this.msg.timings.predicted_ms / 1000),
};
}
},
splitMsgContent() {
const content = this.msg.content;
if (this.msg.role !== 'assistant') {
return { content };
}
let actualContent = '';
let cot = '';
let isThinking = false;
let thinkSplit = content.split('<think>', 2);
actualContent += thinkSplit[0];
while (thinkSplit[1] !== undefined) {
// <think> tag found
thinkSplit = thinkSplit[1].split('</think>', 2);
cot += thinkSplit[0];
isThinking = true;
if (thinkSplit[1] !== undefined) {
// </think> closing tag found
isThinking = false;
thinkSplit = thinkSplit[1].split('<think>', 2);
actualContent += thinkSplit[0];
}
}
return { content: actualContent, cot, isThinking };
},
},
methods: {
copyMsg() {
@ -208,7 +240,10 @@ const MessageBubble = defineComponent({
// format: { [convId]: { id: string, lastModified: number, messages: [...] } }
// convId is a string prefixed with 'conv-'
const StorageUtils = {
// manage conversations
/**
* manage conversations
* @returns {Array<Conversation>}
*/
getAllConversations() {
const res = [];
for (const key in localStorage) {
@ -219,11 +254,19 @@ const StorageUtils = {
res.sort((a, b) => b.lastModified - a.lastModified);
return res;
},
// can return null if convId does not exist
/**
* can return null if convId does not exist
* @param {string} convId
* @returns {Conversation | null}
*/
getOneConversation(convId) {
return JSON.parse(localStorage.getItem(convId) || 'null');
},
// if convId does not exist, create one
/**
* if convId does not exist, create one
* @param {string} convId
* @param {Message} msg
*/
appendMsg(convId, msg) {
if (msg.content === null) return;
const conv = StorageUtils.getOneConversation(convId) || {
@ -235,12 +278,24 @@ const StorageUtils = {
conv.lastModified = Date.now();
localStorage.setItem(convId, JSON.stringify(conv));
},
/**
* Get new conversation id
* @returns {string}
*/
getNewConvId() {
return `conv-${Date.now()}`;
},
/**
* remove conversation by id
* @param {string} convId
*/
remove(convId) {
localStorage.removeItem(convId);
},
/**
* remove all conversations
* @param {string} convId
*/
filterAndKeepMsgs(convId, predicate) {
const conv = StorageUtils.getOneConversation(convId);
if (!conv) return;
@ -248,6 +303,11 @@ const StorageUtils = {
conv.lastModified = Date.now();
localStorage.setItem(convId, JSON.stringify(conv));
},
/**
* remove last message from conversation
* @param {string} convId
* @returns {Message | undefined}
*/
popMsg(convId) {
const conv = StorageUtils.getOneConversation(convId);
if (!conv) return;
@ -322,10 +382,12 @@ const mainApp = createApp({
data() {
return {
conversations: StorageUtils.getAllConversations(),
messages: [], // { id: number, role: 'user' | 'assistant', content: string }
/** @type {Array<Message>} */
messages: [],
viewingConvId: StorageUtils.getNewConvId(),
inputMsg: '',
isGenerating: false,
/** @type {Array<Message> | null} */
pendingMsg: null, // the on-going message from assistant
stopGeneration: () => {},
selectedTheme: StorageUtils.getTheme(),
@ -333,6 +395,7 @@ const mainApp = createApp({
showConfigDialog: false,
// const
themes: THEMES,
/** @type {CONFIG_DEFAULT} */
configDefault: {...CONFIG_DEFAULT},
configInfo: {...CONFIG_INFO},
isDev,
@ -425,42 +488,50 @@ const mainApp = createApp({
this.isGenerating = true;
try {
/** @type {CONFIG_DEFAULT} */
const config = this.config;
const abortController = new AbortController();
this.stopGeneration = () => abortController.abort();
/** @type {Array<APIMessage>} */
let messages = [
{ role: 'system', content: config.systemMessage },
...normalizeMsgsForAPI(this.messages),
];
if (config.excludeThoughtOnReq) {
messages = filterThoughtFromMsgs(messages);
}
if (isDev) console.log({messages});
const params = {
messages: [
{ role: 'system', content: this.config.systemMessage },
...this.messages,
],
messages,
stream: true,
cache_prompt: true,
samplers: this.config.samplers,
temperature: this.config.temperature,
dynatemp_range: this.config.dynatemp_range,
dynatemp_exponent: this.config.dynatemp_exponent,
top_k: this.config.top_k,
top_p: this.config.top_p,
min_p: this.config.min_p,
typical_p: this.config.typical_p,
xtc_probability: this.config.xtc_probability,
xtc_threshold: this.config.xtc_threshold,
repeat_last_n: this.config.repeat_last_n,
repeat_penalty: this.config.repeat_penalty,
presence_penalty: this.config.presence_penalty,
frequency_penalty: this.config.frequency_penalty,
dry_multiplier: this.config.dry_multiplier,
dry_base: this.config.dry_base,
dry_allowed_length: this.config.dry_allowed_length,
dry_penalty_last_n: this.config.dry_penalty_last_n,
max_tokens: this.config.max_tokens,
timings_per_token: !!this.config.showTokensPerSecond,
...(this.config.custom.length ? JSON.parse(this.config.custom) : {}),
samplers: config.samplers,
temperature: config.temperature,
dynatemp_range: config.dynatemp_range,
dynatemp_exponent: config.dynatemp_exponent,
top_k: config.top_k,
top_p: config.top_p,
min_p: config.min_p,
typical_p: config.typical_p,
xtc_probability: config.xtc_probability,
xtc_threshold: config.xtc_threshold,
repeat_last_n: config.repeat_last_n,
repeat_penalty: config.repeat_penalty,
presence_penalty: config.presence_penalty,
frequency_penalty: config.frequency_penalty,
dry_multiplier: config.dry_multiplier,
dry_base: config.dry_base,
dry_allowed_length: config.dry_allowed_length,
dry_penalty_last_n: config.dry_penalty_last_n,
max_tokens: config.max_tokens,
timings_per_token: !!config.showTokensPerSecond,
...(config.custom.length ? JSON.parse(config.custom) : {}),
};
const chunks = sendSSEPostRequest(`${BASE_URL}/v1/chat/completions`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
...(this.config.apiKey ? {'Authorization': `Bearer ${this.config.apiKey}`} : {})
...(config.apiKey ? {'Authorization': `Bearer ${config.apiKey}`} : {})
},
body: JSON.stringify(params),
signal: abortController.signal,
@ -477,7 +548,7 @@ const mainApp = createApp({
};
}
const timings = chunk.timings;
if (timings && this.config.showTokensPerSecond) {
if (timings && config.showTokensPerSecond) {
// only extract what's really needed, to save some space
this.pendingMsg.timings = {
prompt_n: timings.prompt_n,
@ -598,3 +669,33 @@ try {
<button class="btn" onClick="localStorage.clear(); window.location.reload();">Clear localStorage</button>
</div>`;
}
/**
* filter out redundant fields upon sending to API
* @param {Array<APIMessage>} messages
* @returns {Array<APIMessage>}
*/
function normalizeMsgsForAPI(messages) {
return messages.map((msg) => {
return {
role: msg.role,
content: msg.content,
};
});
}
/**
* recommended for DeepsSeek-R1, filter out content between <think> and </think> tags
* @param {Array<APIMessage>} messages
* @returns {Array<APIMessage>}
*/
function filterThoughtFromMsgs(messages) {
return messages.map((msg) => {
return {
role: msg.role,
content: msg.role === 'assistant'
? msg.content.split('</think>').at(-1).trim()
: msg.content,
};
});
}

View file

@ -0,0 +1,11 @@
cmake_minimum_required(VERSION 3.12)
project(llama-simple-cmake-pkg)
set(TARGET llama-simple-cmake-pkg)
find_package(Llama REQUIRED)
add_executable(${TARGET} ${CMAKE_CURRENT_LIST_DIR}/../simple/simple.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE llama ggml::all ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)

View file

@ -0,0 +1,34 @@
# llama.cpp/example/simple-cmake-pkg
This program builds [simple](../simple) using a relocatable CMake package. It serves as an example of using the `find_package()` CMake command to conveniently include [llama.cpp](https://github.com/ggerganov/llama.cpp) in projects which live outside of the source tree.
## Building
Because this example is "outside of the source tree", it is important to first build/install llama.cpp using CMake. An example is provided here, but please see the [llama.cpp build instructions](../..) for more detailed build instructions.
### Considerations
When hardware acceleration libraries are used (e.g. CUDA, Metal, Vulkan, etc.), the appropriate dependencies will be searched for automatically. So, for example, when finding a package
### Build llama.cpp and install to llama.cpp/inst
```sh
git clone https://github.com/ggerganov/llama.cpp
cd llama.cpp
cmake -S . -B build
cmake --build build
cmake --install build --prefix inst
### Build simple-cmake-pkg
```sh
cd examples/simple-cmake-pkg
cmake -S . -B build -DCMAKE_PREFIX_PATH=../../inst/lib/cmake
cmake --build build
```
### Run simple-cmake-pkg
```sh
./build/llama-simple-cmake-pkg -m ./models/llama-7b-v2/ggml-model-f16.gguf "Hello my name is"
```

View file

@ -58,7 +58,8 @@ else()
set(GGML_BLAS_VENDOR_DEFAULT "Generic")
endif()
if (CMAKE_CROSSCOMPILING)
if (CMAKE_CROSSCOMPILING OR DEFINED ENV{SOURCE_DATE_EPOCH})
message(STATUS "Setting GGML_NATIVE_DEFAULT to OFF")
set(GGML_NATIVE_DEFAULT OFF)
else()
set(GGML_NATIVE_DEFAULT ON)
@ -153,6 +154,8 @@ option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashA
option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT})
option(GGML_HIP "ggml: use HIP" OFF)
option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF)
option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON)
option(GGML_HIP_UMA "ggml: use HIP unified memory architecture" OFF)
option(GGML_VULKAN "ggml: use Vulkan" OFF)
option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF)
@ -264,3 +267,74 @@ if (GGML_STANDALONE)
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml.pc
DESTINATION share/pkgconfig)
endif()
#
# Create CMake package
#
# Generate version info based on git commit.
find_program(GIT_EXE NAMES git git.exe REQUIRED NO_CMAKE_FIND_ROOT_PATH)
execute_process(COMMAND ${GIT_EXE} rev-list --count HEAD
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
OUTPUT_VARIABLE GGML_BUILD_NUMBER
OUTPUT_STRIP_TRAILING_WHITESPACE
)
if(GGML_BUILD_NUMBER EQUAL 1)
message(WARNING "GGML build version fixed at 1 likely due to a shallow clone.")
endif()
execute_process(COMMAND ${GIT_EXE} rev-parse --short HEAD
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
OUTPUT_VARIABLE GGML_BUILD_COMMIT
OUTPUT_STRIP_TRAILING_WHITESPACE
)
# Capture variables prefixed with GGML_.
set(variable_set_statements
"
####### Expanded from @GGML_VARIABLES_EXPANED@ by configure_package_config_file() #######
####### Any changes to this file will be overwritten by the next CMake run #######
")
set(GGML_SHARED_LIB ${BUILD_SHARED_LIBS})
get_cmake_property(all_variables VARIABLES)
foreach(variable_name IN LISTS all_variables)
if(variable_name MATCHES "^GGML_")
string(REPLACE ";" "\\;"
variable_value "${${variable_name}}")
set(variable_set_statements
"${variable_set_statements}set(${variable_name} \"${variable_value}\")\n")
endif()
endforeach()
set(GGML_VARIABLES_EXPANDED ${variable_set_statements})
# Create the CMake package and set install location.
set(GGML_INSTALL_VERSION 0.0.${GGML_BUILD_NUMBER})
set(GGML_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR} CACHE PATH "Location of header files")
set(GGML_LIB_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Location of library files")
set(GGML_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location of binary files")
configure_package_config_file(
${CMAKE_CURRENT_SOURCE_DIR}/cmake/ggml-config.cmake.in
${CMAKE_CURRENT_BINARY_DIR}/ggml-config.cmake
INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/ggml
PATH_VARS GGML_INCLUDE_INSTALL_DIR
GGML_LIB_INSTALL_DIR
GGML_BIN_INSTALL_DIR)
write_basic_package_version_file(
${CMAKE_CURRENT_BINARY_DIR}/ggml-version.cmake
VERSION ${GGML_INSTALL_VERSION}
COMPATIBILITY SameMajorVersion)
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml-config.cmake
${CMAKE_CURRENT_BINARY_DIR}/ggml-version.cmake
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/ggml)

View file

@ -0,0 +1,147 @@
@GGML_VARIABLES_EXPANDED@
@PACKAGE_INIT@
set_and_check(GGML_INCLUDE_DIR "@PACKAGE_GGML_INCLUDE_INSTALL_DIR@")
set_and_check(GGML_LIB_DIR "@PACKAGE_GGML_LIB_INSTALL_DIR@")
set_and_check(GGML_BIN_DIR "@PACKAGE_GGML_BIN_INSTALL_DIR@")
find_package(Threads REQUIRED)
find_library(GGML_LIBRARY ggml
REQUIRED
HINTS ${GGML_LIB_DIR}
NO_CMAKE_FIND_ROOT_PATH)
add_library(ggml::ggml UNKNOWN IMPORTED)
set_target_properties(ggml::ggml
PROPERTIES
IMPORTED_LOCATION "${GGML_LIBRARY}")
find_library(GGML_BASE_LIBRARY ggml-base
REQUIRED
HINTS ${GGML_LIB_DIR}
NO_CMAKE_FIND_ROOT_PATH)
add_library(ggml::ggml-base UNKNOWN IMPORTED)
set_target_properties(ggml::ggml-base
PROPERTIES
IMPORTED_LOCATION "${GGML_BASE_LIBRARY}")
if (NOT GGML_SHARED_LIB)
if (APPLE AND GGML_ACCELERATE)
find_library(ACCELERATE_FRAMEWORK Accelerate REQUIRED)
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES ${ACCELERATE_FRAMEWORK})
endif()
if (GGML_OPENMP)
find_package(OpenMP REQUIRED)
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
endif()
if (GGML_CPU_HBM)
find_library(memkind memkind REQUIRED)
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES memkind)
endif()
if (GGML_BLAS)
find_package(BLAS REQUIRED)
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES ${BLAS_LIBRARIES})
list(APPEND GGML_CPU_INTERFACE_LINK_OPTIONS ${BLAS_LINKER_FLAGS})
endif()
if (GGML_CUDA)
find_package(CUDAToolkit REQUIRED)
endif()
if (GGML_METAL)
find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
find_library(METAL_FRAMEWORK Metal REQUIRED)
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
list(APPEND GGML_METAL_INTERFACE_LINK_LIBRARIES
${FOUNDATION_LIBRARY} ${METAL_FRAMEWORK} ${METALKIT_FRAMEWORK})
endif()
if (GGML_VULKAN)
find_package(Vulkan REQUIRED)
list(APPEND GGML_VULKAN_INTERFACE_LINK_LIBRARIES Vulkan::Vulkan)
endif()
if (GGML_HIP)
find_package(hip REQUIRED)
find_package(hipblas REQUIRED)
find_package(rocblas REQUIRED)
list(APPEND GGML_HIP_INTERFACE_LINK_LIBRARIES hip::host roc::rocblas roc::hipblas)
endif()
if (GGML_SYCL)
find_package(DNNL)
if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL")
list(APPEND GGML_SYCL_INTERFACE_LINK_LIBRARIES DNNL::dnnl)
endif()
if (WIN32)
find_package(IntelSYCL REQUIRED)
find_package(MKL REQUIRED)
list(APPEND GGML_SYCL_INTERFACE_LINK_LIBRARIES IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL)
endif()
endif()
endif()
set(_ggml_all_targets "")
foreach(_ggml_backend ${GGML_AVAILABLE_BACKENDS})
string(REPLACE "-" "_" _ggml_backend_pfx "${_ggml_backend}")
string(TOUPPER "${_ggml_backend_pfx}" _ggml_backend_pfx)
find_library(${_ggml_backend_pfx}_LIBRARY ${_ggml_backend}
REQUIRED
HINTS ${GGML_LIB_DIR}
NO_CMAKE_FIND_ROOT_PATH)
message(STATUS "Found ${${_ggml_backend_pfx}_LIBRARY}")
add_library(ggml::${_ggml_backend} UNKNOWN IMPORTED)
set_target_properties(ggml::${_ggml_backend}
PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES "${GGML_INCLUDE_DIR}"
IMPORTED_LINK_INTERFACE_LANGUAGES "CXX"
IMPORTED_LOCATION "${${_ggml_backend_pfx}_LIBRARY}"
INTERFACE_COMPILE_FEATURES c_std_90
POSITION_INDEPENDENT_CODE ON)
string(REGEX MATCH "^ggml-cpu" is_cpu_variant "${_ggml_backend}")
if(is_cpu_variant)
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES "ggml::ggml" "ggml::ggml-base")
set_target_properties(ggml::${_ggml_backend}
PROPERTIES
INTERFACE_LINK_LIBRARIES "${GGML_CPU_INTERFACE_LINK_LIBRARIES}")
if(GGML_CPU_INTERFACE_LINK_OPTIONS)
set_target_properties(ggml::${_ggml_backend}
PROPERTIES
INTERFACE_LINK_OPTIONS "${GGML_CPU_INTERFACE_LINK_OPTIONS}")
endif()
else()
list(APPEND ${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES "ggml::ggml" "ggml::ggml-base")
set_target_properties(ggml::${_ggml_backend}
PROPERTIES
INTERFACE_LINK_LIBRARIES "${${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES}")
if(${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS)
set_target_properties(ggml::${_ggml_backend}
PROPERTIES
INTERFACE_LINK_OPTIONS "${${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS}")
endif()
endif()
list(APPEND _ggml_all_targets ggml::${_ggml_backend})
endforeach()
add_library(ggml::all INTERFACE IMPORTED)
set_target_properties(ggml::all
PROPERTIES
INTERFACE_LINK_LIBRARIES "${_ggml_all_targets}")
check_required_components(ggml)

View file

@ -250,6 +250,17 @@ function(ggml_add_backend_library backend)
target_compile_definitions(${backend} PRIVATE GGML_BACKEND_BUILD)
target_compile_definitions(${backend} PUBLIC GGML_BACKEND_SHARED)
endif()
if(NOT GGML_AVAILABLE_BACKENDS)
set(GGML_AVAILABLE_BACKENDS "${backend}"
CACHE INTERNAL "List of backends for cmake package")
else()
list(FIND GGML_AVAILABLE_BACKENDS "${backend}" has_backend)
if(has_backend EQUAL -1)
set(GGML_AVAILABLE_BACKENDS "${GGML_AVAILABLE_BACKENDS};${backend}"
CACHE INTERNAL "List of backends for cmake package")
endif()
endif()
endfunction()
function(ggml_add_backend backend)
@ -297,7 +308,7 @@ if (GGML_CPU_ALL_VARIANTS)
# MSVC doesn't support AMX
ggml_add_cpu_backend_variant(sapphirerapids AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
endif()
else ()
elseif (GGML_CPU)
ggml_add_cpu_backend_variant_impl("")
endif()

View file

@ -1302,7 +1302,7 @@ struct ggml_threadpool {
// these are atomic as an annotation for thread-sanitizer
atomic_bool stop; // Used for stopping the threadpool altogether
atomic_bool pause; // Used for pausing the threadpool or individual threads
atomic_bool abort; // Used for aborting processing of a graph
atomic_int abort; // Used for aborting processing of a graph
struct ggml_compute_state * workers; // per thread state
int n_threads_max; // number of threads in the pool
@ -7883,7 +7883,7 @@ static void ggml_compute_forward_out_prod_f32(
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
}
@ -7892,7 +7892,7 @@ static void ggml_compute_forward_out_prod_f32(
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
ggml_vec_mad_f32(ne0, d, s0, *s1);
}
@ -13851,14 +13851,14 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
/*.threadpool=*/ tp,
};
for (int node_n = 0; node_n < cgraph->n_nodes && !tp->abort; node_n++) {
for (int node_n = 0; node_n < cgraph->n_nodes && atomic_load_explicit(&tp->abort, memory_order_relaxed) != node_n; node_n++) {
struct ggml_tensor * node = cgraph->nodes[node_n];
ggml_compute_forward(&params, node);
if (state->ith == 0 && cplan->abort_callback &&
cplan->abort_callback(cplan->abort_callback_data)) {
tp->abort = true;
atomic_store_explicit(&tp->abort, node_n + 1, memory_order_relaxed);
tp->ec = GGML_STATUS_ABORTED;
}
@ -14031,7 +14031,7 @@ static struct ggml_threadpool * ggml_threadpool_new_impl(
threadpool->current_chunk = 0;
threadpool->stop = false;
threadpool->pause = tpp->paused;
threadpool->abort = false;
threadpool->abort = -1;
threadpool->workers = NULL;
threadpool->n_threads_max = tpp->n_threads;
threadpool->n_threads_cur = tpp->n_threads;
@ -14110,7 +14110,7 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
threadpool->cgraph = cgraph;
threadpool->cplan = cplan;
threadpool->current_chunk = 0;
threadpool->abort = false;
threadpool->abort = -1;
threadpool->ec = GGML_STATUS_SUCCESS;
}

View file

@ -416,7 +416,8 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
case GGML_OP_IM2COL_BACK:
return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
case GGML_OP_OUT_PROD:
return (src0->type == GGML_TYPE_F32 || ggml_is_quantized(src0->type)) && src1->type == GGML_TYPE_F32;
return (src0->type == GGML_TYPE_F32 || (ggml_is_quantized(src0->type) && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3])) &&
src1->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
default:
return true;
}

View file

@ -93,26 +93,31 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
template <typename T>
static __global__ void k_repeat_back(
const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
const int64_t ne0, const int64_t ne1, const int64_t ne2) {
const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const size_t s00, const size_t s01, const size_t s02, const size_t s03,
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3) {
const int64_t tid0 = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
const int64_t tid1 = (int64_t) blockIdx.y*blockDim.y + threadIdx.y;
const int64_t tid2 = (int64_t) blockIdx.z*blockDim.z + threadIdx.z;
const int64_t tid0 = int64_t(blockIdx.x)*blockDim.x + threadIdx.x;
const int64_t tid1 = int64_t(blockIdx.y)*blockDim.y + threadIdx.y;
const int64_t tid23 = int64_t(blockIdx.z)*blockDim.z + threadIdx.z;
const int64_t tid2 = tid23 % ne2;
const int64_t tid3 = tid23 / ne2;
if (tid0 >= ne0) {
return;
}
T sum = 0;
for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
sum += src[i2*ne01*ne00 + i1*ne00 + i0];
for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) {
for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
sum += src[i3*s03 + i2*s02 + i1*s01 + i0*s00];
}
}
}
}
dst[tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
}
template<float (*bin_op)(const float, const float)>
@ -274,12 +279,14 @@ struct bin_bcast_cuda {
template <typename T>
static void repeat_back_cuda(
const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
const int64_t ne0, const int64_t ne1, const int64_t ne2, cudaStream_t stream) {
const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const size_t s00, const size_t s01, const size_t s02, const size_t s03,
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2);
k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>(src, dst, ne00, ne01, ne02, ne0, ne1, ne2);
const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2*ne3);
k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>
(src, dst, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3);
}
template<class op>
@ -326,27 +333,26 @@ void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst
const ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(src0->type == dst->type);
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ggml_can_repeat(dst, src0));
cudaStream_t stream = ctx.stream();
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
GGML_ASSERT(src0->ne[3] == 1);
GGML_TENSOR_UNARY_OP_LOCALS;
const int64_t ne0 = dst->ne[0];
const int64_t ne1 = dst->ne[1];
const int64_t ne2 = dst->ne[2];
GGML_ASSERT(dst->ne[3] == 1);
GGML_ASSERT(ne2*ne3 <= (1 << 15));
const size_t ts = ggml_type_size(src0->type);
const size_t s00 = nb00 / ts;
const size_t s01 = nb01 / ts;
const size_t s02 = nb02 / ts;
const size_t s03 = nb03 / ts;
switch (dst->type) {
case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0->data;
float * dst_d = (float *) dst->data;
repeat_back_cuda<float>(src0_d, dst_d, ne00, ne01, ne02, ne0, ne1, ne2, stream);
repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3, stream);
} break;
default: {
GGML_ASSERT(false);

View file

@ -46,20 +46,20 @@
#define GGML_CUDA_CC_VOLTA 700
#define GGML_CUDA_CC_TURING 750
#define GGML_CUDA_CC_AMPERE 800
#define GGML_CUDA_CC_OFFSET_AMD 1000000
#define GGML_CUDA_CC_OFFSET_AMD 0x1000000
// GCN/CNDA, wave size is 64
#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 803) // Tonga, Fiji, Polaris, minimum for fast fp16
#define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 900) // Vega56/64, minimum for fp16 dual issue
#define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 906) // MI50/Radeon VII, minimum for dp4a
#define GGML_CUDA_CC_CDNA (GGML_CUDA_CC_OFFSET_AMD + 908) // MI100, minimum for MFMA, acc registers
#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 910) // MI210, minimum acc register renameing
#define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 942) // MI300
#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16
#define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue
#define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a
#define GGML_CUDA_CC_CDNA (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers
#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum acc register renameing
#define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300
// RNDA removes MFMA, dp4a, xnack, acc registers, wave size is 32
#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 1010) // RX 5000
#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 1030) // RX 6000, minimum for dp4a
#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 1100) // RX 7000, minimum for WMMA
#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
#define GGML_CUDA_CC_QY1 210
#define GGML_CUDA_CC_QY2 220
@ -131,6 +131,10 @@ typedef float dfloat; // dequantize float
typedef float2 dfloat2;
#endif // GGML_CUDA_F16
#if (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))
#define GGML_USE_VMM
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))
#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
#define FP16_AVAILABLE
#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
@ -588,7 +592,7 @@ struct ggml_tensor_extra_gpu {
};
#if (CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)
#if ((CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)) || defined(GGML_HIP_GRAPHS)
#define USE_CUDA_GRAPH
#endif

View file

@ -42,6 +42,7 @@
#include <algorithm>
#include <array>
#include <atomic>
#include <charconv>
#include <cinttypes>
#include <cstddef>
#include <cstdint>
@ -62,7 +63,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
[[noreturn]]
void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) {
int id = -1; // in case cudaGetDevice fails
cudaGetDevice(&id);
(void)cudaGetDevice(&id);
GGML_LOG_ERROR(GGML_CUDA_NAME " error: %s\n", msg);
GGML_LOG_ERROR(" current device: %d, in function %s at %s:%d\n", id, func, file, line);
@ -119,12 +120,78 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
#endif
}
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
static int ggml_cuda_parse_id(char devName[]) {
// A list of possible Target IDs can be found under the rocclr/clr repo in device.cpp
// these values are not stable so this is susceptible to breakage
// https://github.com/ROCm/clr/blob/amd-staging/rocclr/device/device.cpp
int archMajor = 0x0;
int archMinor = 0x0;
int archNum = GGML_CUDA_CC_OFFSET_AMD;
int archLen = strlen(devName);
char archName[archLen + 1];
// strip leading 'gfx' while copying into our buffer
if (archLen > 3) {
strcpy(archName, &devName[3]);
archLen -= 3;
}
// trim trailing :xnack- or :sramecc- statuses
archLen = strcspn(archName, ":");
archName[archLen] = '\0';
// tease out the version information
if (archLen > 8) {
// versions labeled generic use '-' as delimiter
// strip the trailing "-generic" then iterate through what remains
if ((strstr(archName, "-generic"))) {
archName[archLen - 8] = '\0';
char * pch;
if ((pch = strtok(archName, "-"))) {
archMajor = (int)strtoul(pch, 0, 16);
if ((pch = strtok(NULL, "-"))) {
archMinor = 0x10 * (int)strtoul(pch, 0, 16);
}
}
}
} else if (archLen >= 3) {
// last two digits should be the minor * 0x10 + stepping
archMinor = (int)strtoul(&archName[archLen - 2], 0, 16);
archName[archLen - 2] = '\0';
// only the major version remains
archMajor = (int)strtoul(archName, 0, 16);
}
archNum += archMajor * 0x100;
archNum += archMinor;
return archNum;
}
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
static ggml_cuda_device_info ggml_cuda_init() {
#ifdef __HIP_PLATFORM_AMD__
// Workaround for a rocBLAS bug when using multiple graphics cards:
// https://github.com/ROCmSoftwarePlatform/rocBLAS/issues/1346
rocblas_initialize();
CUDA_CHECK(cudaDeviceSynchronize());
{
int major_version = 0;
size_t version_length = 0;
if (rocblas_get_version_string_size(&version_length) == rocblas_status_success) {
std::string version(version_length, '\0');
if (rocblas_get_version_string(version.data(), version.size()) == rocblas_status_success) {
version.resize(::strlen(version.c_str()));
int parsed_value = 0;
if (std::from_chars(version.c_str(), version.c_str() + version.length(), parsed_value).ec == std::errc()) {
major_version = parsed_value;
}
}
}
if (major_version < 4) {
GGML_LOG_DEBUG(GGML_CUDA_NAME " calling rocblas_initialize as a workaround for a rocBLAS bug\n");
rocblas_initialize();
CUDA_CHECK(cudaDeviceSynchronize());
}
}
#endif
ggml_cuda_device_info info = {};
@ -152,7 +219,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
for (int id = 0; id < info.device_count; ++id) {
int device_vmm = 0;
#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
#if defined(GGML_USE_VMM)
CUdevice device;
CU_CHECK(cuDeviceGet(&device, id));
CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));
@ -164,12 +231,11 @@ static ggml_cuda_device_info ggml_cuda_init() {
alloc_prop.location.id = id;
CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
}
#endif // !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
#endif // defined(GGML_USE_VMM)
info.devices[id].vmm = !!device_vmm;
cudaDeviceProp prop;
CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
info.default_tensor_split[id] = total_vram;
total_vram += prop.totalGlobalMem;
@ -178,10 +244,25 @@ static ggml_cuda_device_info ggml_cuda_init() {
info.devices[id].smpb = prop.sharedMemPerBlock;
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
info.devices[id].smpbo = prop.sharedMemPerBlock;
info.devices[id].cc = 100*prop.major + 10*prop.minor + GGML_CUDA_CC_OFFSET_AMD;
info.devices[id].cc = ggml_cuda_parse_id(prop.gcnArchName);
if ((info.devices[id].cc & 0xff00) == 0x0) {
GGML_LOG_WARN("invalid architecture ID received for device %d %s: %s cc %d.%d\n",
id, prop.name, prop.gcnArchName, prop.major, prop.minor);
// Fallback to prop.major and prop.minor
if (prop.major > 0) {
info.devices[id].cc = GGML_CUDA_CC_OFFSET_AMD + prop.major * 0x100;
info.devices[id].cc += prop.minor * 0x10;
}
}
GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s\n",
id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff, device_vmm ? "yes" : "no");
#else
info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
info.devices[id].cc = 100*prop.major + 10*prop.minor;
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n",
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
}
@ -300,7 +381,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
};
// pool with virtual memory
#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
#if defined(GGML_USE_VMM)
struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
@ -309,6 +390,9 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
size_t pool_used = 0;
size_t pool_size = 0;
size_t granularity;
#if defined(GGML_USE_HIP)
std::vector<std::pair<CUdeviceptr, size_t>> mappings;
#endif
explicit ggml_cuda_pool_vmm(int device) :
device(device),
@ -317,7 +401,14 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
~ggml_cuda_pool_vmm() {
if (pool_addr != 0) {
#if defined(GGML_USE_HIP)
// Workaround for https://github.com/ROCm/ROCR-Runtime/issues/285
for (std::pair<CUdeviceptr, size_t> & mapping : mappings) {
CU_CHECK(cuMemUnmap(mapping.first, mapping.second));
}
#else
CU_CHECK(cuMemUnmap(pool_addr, pool_size));
#endif
CU_CHECK(cuMemAddressFree(pool_addr, CUDA_POOL_VMM_MAX_SIZE));
}
}
@ -350,7 +441,11 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
}
// map at the end of the pool
CU_CHECK(cuMemMap(pool_addr + pool_size, reserve_size, 0, handle, 0));
CUdeviceptr start_ptr = (CUdeviceptr)((char *)(pool_addr) + pool_size);
CU_CHECK(cuMemMap(start_ptr, reserve_size, 0, handle, 0));
#if defined(GGML_USE_HIP)
mappings.push_back({start_ptr, reserve_size});
#endif
// the memory allocation handle is no longer needed after mapping
CU_CHECK(cuMemRelease(handle));
@ -360,7 +455,7 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
access.location.id = device;
access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
CU_CHECK(cuMemSetAccess(pool_addr + pool_size, reserve_size, &access, 1));
CU_CHECK(cuMemSetAccess((CUdeviceptr)((char *)(pool_addr) + pool_size), reserve_size, &access, 1));
// add to the pool
pool_size += reserve_size;
@ -372,7 +467,7 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
GGML_ASSERT(pool_addr != 0);
void * ptr = (void *) (pool_addr + pool_used);
void * ptr = (void *) ((CUdeviceptr)((char *)(pool_addr) + pool_used));
*actual_size = size;
pool_used += size;
@ -391,17 +486,17 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
pool_used -= size;
// all deallocations must be in reverse order of the allocations
GGML_ASSERT(ptr == (void *) (pool_addr + pool_used));
GGML_ASSERT(ptr == (void *) ((char *)(pool_addr) + pool_used));
}
};
#endif // !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
#endif // defined(GGML_USE_VMM)
std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
#if defined(GGML_USE_VMM)
if (ggml_cuda_info().devices[device].vmm) {
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
}
#endif // !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
#endif // defined(GGML_USE_VMM)
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
}
@ -547,7 +642,7 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac
cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device);
if (err != cudaSuccess) {
// clear the error
cudaGetLastError();
(void)cudaGetLastError();
GGML_LOG_ERROR("%s: allocating %.2f MiB on device %d: cudaMalloc failed: %s\n", __func__, size / 1024.0 / 1024.0, buft_ctx->device, cudaGetErrorString(err));
return nullptr;
}
@ -962,7 +1057,7 @@ static void * ggml_cuda_host_malloc(size_t size) {
cudaError_t err = cudaMallocHost((void **) &ptr, size);
if (err != cudaSuccess) {
// clear the error
cudaGetLastError();
(void)cudaGetLastError();
GGML_LOG_DEBUG("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__,
size / 1024.0 / 1024.0, cudaGetErrorString(err));
return nullptr;
@ -1082,7 +1177,9 @@ static void ggml_cuda_op_mul_mat_cublas(
const int compute_capability = ggml_cuda_info().devices[id].cc;
if (compute_capability >= GGML_CUDA_CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
if (compute_capability >= GGML_CUDA_CC_VOLTA && use_fp16) {
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
if (src0->type != GGML_TYPE_F16) {
@ -1103,28 +1200,38 @@ static void ggml_cuda_op_mul_mat_cublas(
to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
}
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);
const half alpha_f16 = 1.0f;
const half beta_f16 = 0.0f;
cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
cu_compute_type = CUBLAS_COMPUTE_32F;
}
CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
CUBLAS_CHECK(
cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
row_diff, src1_ncols, ne10,
&alpha_f16, src0_ptr, CUDA_R_16F, ne00,
src1_ptr, CUDA_R_16F, ne10,
&beta_f16, dst_f16.get(), CUDA_R_16F, ldc,
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
if (compute_capability == GGML_CUDA_CC_CDNA) {
const float alpha = 1.0f;
const float beta = 0.0f;
CUBLAS_CHECK(
cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
row_diff, src1_ncols, ne10,
&alpha, src0_ptr, CUDA_R_16F, ne00,
src1_ptr, CUDA_R_16F, ne10,
&beta, dst_dd_i, CUDA_R_32F, ldc,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
} else {
ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);
const half alpha_f16 = 1.0f;
const half beta_f16 = 0.0f;
CUBLAS_CHECK(
cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
row_diff, src1_ncols, ne10,
&alpha_f16, src0_ptr, CUDA_R_16F, ne00,
src1_ptr, CUDA_R_16F, ne10,
&beta_f16, dst_f16.get(), CUDA_R_16F, ldc,
CUBLAS_COMPUTE_16F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
}
} else {
ggml_cuda_pool_alloc<float> src0_ddq_as_f32(ctx.pool(id));
ggml_cuda_pool_alloc<float> src1_ddq_as_f32(ctx.pool(id));
@ -1197,7 +1304,7 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
CUDA_CHECK(err);
} else {
// reset the error
cudaGetLastError();
(void)cudaGetLastError();
}
} else {
cudaError_t err = cudaDeviceDisablePeerAccess(id_other);
@ -1205,7 +1312,7 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
CUDA_CHECK(err);
} else {
// reset the error
cudaGetLastError();
(void)cudaGetLastError();
}
}
}
@ -1613,10 +1720,6 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
cudaDataType_t cu_data_type = CUDA_R_16F;
if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
cu_compute_type = CUBLAS_COMPUTE_32F;
}
// dst strides
size_t nbd2 = dst->nb[2];
size_t nbd3 = dst->nb[3];
@ -1645,6 +1748,12 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
beta = &beta_f32;
}
if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
cu_compute_type = CUBLAS_COMPUTE_32F;
alpha = &alpha_f32;
beta = &beta_f32;
}
GGML_ASSERT(ne12 % ne02 == 0);
GGML_ASSERT(ne13 % ne03 == 0);
@ -2438,7 +2547,7 @@ static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vecto
if (stat == cudaErrorInvalidDeviceFunction) {
// Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
// We don't need to update blas nodes, so clear error and move on.
cudaGetLastError();
(void)cudaGetLastError();
} else {
GGML_ASSERT(stat == cudaSuccess);
}
@ -2493,14 +2602,20 @@ static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx,
static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
cudaGraphExecUpdateResultInfo result_info;
#ifdef __HIP_PLATFORM_AMD__
hipGraphNode_t errorNode;
hipError_t stat = hipGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
#else
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
#endif
if (stat == cudaErrorGraphExecUpdateFailure) {
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
#endif
// The pre-existing graph exec cannot be updated due to violated constraints
// so instead clear error and re-instantiate
cudaGetLastError();
(void)cudaGetLastError();
CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
cuda_ctx->cuda_graph->instance = nullptr;
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
@ -2728,7 +2843,7 @@ bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
if (err != cudaSuccess) {
// clear the error
cudaGetLastError();
(void)cudaGetLastError();
GGML_LOG_DEBUG("%s: failed to register %.2f MiB of pinned memory: %s\n", __func__,
size / 1024.0 / 1024.0, cudaGetErrorString(err));
@ -2748,7 +2863,7 @@ void ggml_backend_cuda_unregister_host_buffer(void * buffer) {
cudaError_t err = cudaHostUnregister(buffer);
if (err != cudaSuccess) {
// clear the error
cudaGetLastError();
(void)cudaGetLastError();
}
}
@ -3002,7 +3117,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
} break;
case GGML_OP_REPEAT_BACK:
return op->type == GGML_TYPE_F32 && op->src[0]->ne[3] == 1;
return op->type == GGML_TYPE_F32 && (op->src[0]->ne[2]*op->src[0]->ne[3]) <= (1 << 15);
case GGML_OP_CONCAT:
{
ggml_type src0_type = op->src[0]->type;
@ -3216,7 +3331,7 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t
features.push_back({ "FORCE_CUBLAS", "1" });
#endif
#ifdef GGML_CUDA_NO_VMM
#ifndef GGML_USE_VMM
features.push_back({ "NO_VMM", "1" });
#endif

View file

@ -142,7 +142,7 @@ static void mul_mat_vec_q_cuda(
int64_t nwarps = 1;
int64_t rows_per_cuda_block = 1;
if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_CDNA || ggml_cuda_info().devices[id].cc == GGML_CUDA_CC_RDNA1) { // NVIDIA and AMD older than RDNA2 but not CDNA
if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_RDNA2) { // NVIDIA and AMD older than RDNA2
switch(ncols_y) {
case 1:
nwarps = 4;
@ -166,6 +166,7 @@ static void mul_mat_vec_q_cuda(
break;
}
}
const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
const dim3 block_nums(nblocks, 1, 1);
const dim3 block_dims(WARP_SIZE, nwarps, 1);

View file

@ -34,6 +34,9 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
CUBLAS_CHECK(cublasSetStream(handle, stream));
const int64_t lda = nb01 / sizeof(float);
const int64_t ldc = nb1 / sizeof(float);
const bool src1_T = ggml_is_transposed(src1);
const cublasOperation_t src1_cublas_op = src1_T ? CUBLAS_OP_N : CUBLAS_OP_T;
const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
@ -57,9 +60,9 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
CUBLAS_CHECK(
cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
ne0, ne1, ne01,
&alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, ne00,
&alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda,
src1_d + i3 *s13 + i2 *s12, ldb,
&beta, dst_d + i3 *s3 + i2 *s2, ne0));
&beta, dst_d + i3 *s3 + i2 *s2, ldc));
}
}
}

View file

@ -13,6 +13,12 @@ __device__ float __forceinline__ t2f32<half>(half val) {
return __half2float(val);
}
// When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled.
// As we want to keep pragma unroll for all other cases we supress the clang transformation warning here.
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpass-failed"
#endif
template <bool use_shared, int ncols_template, int block_size_template, typename T>
static __global__ void soft_max_f32(
const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
@ -118,6 +124,9 @@ static __global__ void soft_max_f32(
dst[col] = vals[col] * inv_sum;
}
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif
static __global__ void soft_max_back_f32(
const float * grad, const float * dstf, float * dst, const int ncols, const float scale) {

View file

@ -19,6 +19,12 @@
#define CUBLAS_TF32_TENSOR_OP_MATH 0
#define CUDA_R_16F HIPBLAS_R_16F
#define CUDA_R_32F HIPBLAS_R_32F
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended
#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned
#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
#define cublasCreate hipblasCreate
@ -74,6 +80,21 @@
#define cudaMemGetInfo hipMemGetInfo
#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
#define cudaSetDevice hipSetDevice
#define cuDeviceGet hipDeviceGet
#define CUdevice hipDevice_t
#define CUdeviceptr hipDeviceptr_t
#define cuMemUnmap hipMemUnmap
#define CUmemAccessDesc hipMemAccessDesc
#define cuMemAddressFree hipMemAddressFree
#define cuMemRelease hipMemRelease
#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t
#define cuMemCreate hipMemCreate
#define cuMemAddressReserve hipMemAddressReserve
#define cuMemMap hipMemMap
#define cuMemSetAccess hipMemSetAccess
#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity
#define CUmemAllocationProp hipMemAllocationProp
#define cuDeviceGetAttribute hipDeviceGetAttribute
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
#define cudaStreamDestroy hipStreamDestroy
#define cudaStreamFireAndForget hipStreamFireAndForget
@ -81,6 +102,28 @@
#define cudaStreamPerThread hipStreamPerThread
#define cudaStreamSynchronize hipStreamSynchronize
#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
#define cudaGraphExec_t hipGraphExec_t
#define cudaGraphNode_t hipGraphNode_t
#define cudaKernelNodeParams hipKernelNodeParams
#define cudaKernelNodeParams hipKernelNodeParams
#define cudaGraphExecDestroy hipGraphExecDestroy
#define cudaGraphLaunch hipGraphLaunch
#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult
#define cudaGraphNodeType hipGraphNodeType
#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
#define cudaGraphInstantiate hipGraphInstantiate
#define cudaStreamEndCapture hipStreamEndCapture
#define cudaGraphDestroy hipGraphDestroy
#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams
#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction
#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams
#define cudaGraphNodeGetType hipGraphNodeGetType
#define cudaGraphGetNodes hipGraphGetNodes
#define cudaGraphExecUpdate hipGraphExecUpdate
#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed
#define cudaStreamBeginCapture hipStreamBeginCapture
#define cudaGraph_t hipGraph_t
#define cudaStream_t hipStream_t
#define cudaSuccess hipSuccess
#define __trap() do { abort(); __builtin_unreachable(); } while(0)

View file

@ -92,6 +92,14 @@ if (GGML_CUDA_NO_PEER_COPY)
add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
endif()
if (GGML_HIP_GRAPHS)
add_compile_definitions(GGML_HIP_GRAPHS)
endif()
if (GGML_HIP_NO_VMM)
add_compile_definitions(GGML_HIP_NO_VMM)
endif()
if (CXX_IS_HIPCC)
set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX)
target_link_libraries(ggml-hip PRIVATE hip::device)

View file

@ -19,7 +19,10 @@
// max number of MTLCommandBuffer used to submit a graph for processing
#define GGML_METAL_MAX_COMMAND_BUFFERS 8
#define UNUSED(x) (void)(x)
// create residency sets only on macOS >= 15.0
#if TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000
#define GGML_METAL_HAS_RESIDENCY_SETS 1
#endif
// globals
@ -39,6 +42,7 @@ static struct ggml_backend_metal_device_context {
bool has_simdgroup_reduction;
bool has_simdgroup_mm;
bool has_residency_sets;
bool has_bfloat;
bool use_bfloat;
@ -48,6 +52,7 @@ static struct ggml_backend_metal_device_context {
/*.mtl_device_ref_count =*/ 0,
/*.has_simdgroup_reduction =*/ false,
/*.has_simdgroup_mm =*/ false,
/*.has_residency_sets =*/ false,
/*.has_bfloat =*/ false,
/*.use_bfloat =*/ false,
/*.name =*/ "",
@ -59,12 +64,18 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
if (ctx->mtl_device == nil) {
ctx->mtl_device = MTLCreateSystemDefaultDevice();
}
if (ctx->mtl_device) {
ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == NULL;
#endif
ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
@ -90,8 +101,10 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
ctx->mtl_device_ref_count--;
if (ctx->mtl_device_ref_count == 0) {
[ctx->mtl_device release];
ctx->mtl_device = nil;
if (ctx->mtl_device) {
[ctx->mtl_device release];
ctx->mtl_device = nil;
}
}
}
@ -483,6 +496,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
ctx->queue = [device newCommandQueue];
if (ctx->queue == nil) {
GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
return NULL;
}
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
id<MTLLibrary> metal_library;
@ -649,6 +667,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_LOG_INFO("%s: simdgroup reduction = %s\n", __func__, ctx_dev->has_simdgroup_reduction ? "true" : "false");
GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, ctx_dev->has_simdgroup_mm ? "true" : "false");
GGML_LOG_INFO("%s: has residency sets = %s\n", __func__, ctx_dev->has_residency_sets ? "true" : "false");
GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, ctx_dev->has_bfloat ? "true" : "false");
GGML_LOG_INFO("%s: use bfloat = %s\n", __func__, ctx_dev->use_bfloat ? "true" : "false");
GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
@ -1035,8 +1054,70 @@ struct ggml_backend_metal_buffer_context {
// multiple buffers are used only to avoid the maximum buffer size limitation when using mmap
int n_buffers;
struct ggml_backend_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
// optional MTLResidencySet
id rset;
};
// rset init
static bool ggml_backend_metal_buffer_rset_init(
struct ggml_backend_metal_buffer_context * ctx,
struct ggml_backend_metal_device_context * ctx_dev,
id<MTLDevice> device) {
ctx->rset = nil;
if (!ctx_dev->has_residency_sets) {
return true;
}
#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
if (@available(macOS 15.0, *)) {
MTLResidencySetDescriptor * desc = [[MTLResidencySetDescriptor alloc] init];
desc.label = @"ggml_backend_metal";
desc.initialCapacity = ctx->n_buffers;
NSError * error;
ctx->rset = [device newResidencySetWithDescriptor:desc error:&error];
if (error) {
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
[desc release];
return false;
}
[desc release];
for (int i = 0; i < ctx->n_buffers; i++) {
[ctx->rset addAllocation:ctx->buffers[i].metal];
}
[ctx->rset commit];
[ctx->rset requestResidency];
return true;
}
#else
GGML_UNUSED(ctx_dev);
GGML_UNUSED(device);
#endif
return true;
}
// rset free
static void ggml_backend_metal_buffer_rset_free(struct ggml_backend_metal_buffer_context * ctx) {
#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
if (@available(macOS 15.0, *)) {
if (ctx->rset) {
[ctx->rset endResidency];
[ctx->rset removeAllAllocations];
[ctx->rset release];
}
}
#else
GGML_UNUSED(ctx);
#endif
}
// finds the Metal buffer that contains the tensor data on the GPU device
// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
// Metal buffer based on the host memory pointer
@ -4176,6 +4257,8 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
for (int i = 0; i < ctx->n_buffers; i++) {
[ctx->buffers[i].metal release];
}
ggml_backend_metal_buffer_rset_free(ctx);
ggml_backend_metal_device_rel(buffer->buft->device->context);
if (ctx->owned) {
@ -4198,19 +4281,19 @@ static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
static void ggml_backend_metal_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
memset((char *)tensor->data + offset, value, size);
UNUSED(buffer);
GGML_UNUSED(buffer);
}
static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
memcpy((char *)tensor->data + offset, data, size);
UNUSED(buffer);
GGML_UNUSED(buffer);
}
static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
memcpy(data, (const char *)tensor->data + offset, size);
UNUSED(buffer);
GGML_UNUSED(buffer);
}
static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
@ -4220,7 +4303,7 @@ static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, c
}
return false;
UNUSED(buffer);
GGML_UNUSED(buffer);
}
static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
@ -4246,7 +4329,7 @@ static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = {
static const char * ggml_backend_metal_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
return "Metal";
UNUSED(buft);
GGML_UNUSED(buft);
}
static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size_t size_aligned) {
@ -4270,8 +4353,8 @@ static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size_t s
}
#endif
#endif
UNUSED(device);
UNUSED(size_aligned);
GGML_UNUSED(device);
GGML_UNUSED(size_aligned);
}
static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
@ -4284,7 +4367,8 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
size_aligned += (size_page - (size_aligned % size_page));
}
id<MTLDevice> device = ggml_backend_metal_device_acq(buft->device->context);
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)buft->device->context;
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
ctx->all_data = ggml_metal_host_malloc(size_aligned);
ctx->all_size = size_aligned;
@ -4307,7 +4391,14 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
free(ctx);
ggml_backend_metal_device_rel(buft->device->context);
ggml_backend_metal_device_rel(ctx_dev);
return NULL;
}
if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
free(ctx);
ggml_backend_metal_device_rel(ctx_dev);
return NULL;
}
@ -4318,7 +4409,7 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
return 32;
UNUSED(buft);
GGML_UNUSED(buft);
}
static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
@ -4328,13 +4419,13 @@ static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_ty
return max_size;
UNUSED(buft);
GGML_UNUSED(buft);
}
static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
return true;
UNUSED(buft);
GGML_UNUSED(buft);
}
ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
@ -4357,7 +4448,7 @@ ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
static const char * ggml_backend_metal_buffer_from_ptr_type_get_name(ggml_backend_buffer_type_t buft) {
return "Metal_Mapped";
UNUSED(buft);
GGML_UNUSED(buft);
}
static ggml_backend_buffer_type_t ggml_backend_metal_buffer_from_ptr_type(void) {
@ -4400,7 +4491,8 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
size_aligned += (size_page - (size_aligned % size_page));
}
id<MTLDevice> device = ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);
struct ggml_backend_metal_device_context * ctx_dev = &g_ggml_ctx_dev_main;
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
// the buffer fits into the max buffer size allowed by the device
if (size_aligned <= device.maxBufferLength) {
@ -4453,6 +4545,13 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
}
}
if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
free(ctx);
ggml_backend_metal_device_rel(ctx_dev);
return NULL;
}
return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size);
}
@ -4461,7 +4560,7 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
static const char * ggml_backend_metal_name(ggml_backend_t backend) {
return "Metal";
UNUSED(backend);
GGML_UNUSED(backend);
}
static void ggml_backend_metal_free(ggml_backend_t backend) {
@ -4766,6 +4865,13 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
}
}
if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
free(ctx);
ggml_backend_metal_device_rel(ctx_dev);
return NULL;
}
return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size);
}
@ -4779,7 +4885,7 @@ static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml
return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;
UNUSED(dev);
GGML_UNUSED(dev);
}
static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {

View file

@ -3878,10 +3878,6 @@ static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_diag_mask_inf);
}
static void ggml_sycl_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_soft_max);
}
static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(dst->src[0])); // TODO: this restriction is temporary until non-cont support is implemented
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_rope);
@ -4090,7 +4086,7 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
ggml_sycl_diag_mask_inf(ctx, dst);
break;
case GGML_OP_SOFT_MAX:
ggml_sycl_soft_max(ctx, dst);
ggml_sycl_op_soft_max(ctx, dst);
break;
case GGML_OP_ROPE:
ggml_sycl_rope(ctx, dst);

View file

@ -1,7 +1,7 @@
#include "norm.hpp"
#include "softmax.hpp"
template <bool vals_smem, int ncols_template, int block_size_template>
static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
template <bool vals_smem, int ncols_template, int block_size_template, typename T>
static void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par,
const int nrows_y, const float scale, const float max_bias, const float m0,
const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
@ -29,7 +29,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
slope = sycl::pow(base, float(exp));
}
float *vals = vals_smem ? buf + std::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
float *vals = vals_smem ? buf + sycl::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
float max_val = -INFINITY;
for (int col0 = 0; col0 < ncols; col0 += block_size) {
@ -42,7 +42,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
const int ix = rowx*ncols + col;
const int iy = rowy*ncols + col;
const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);
const float val = x[ix]*scale + (mask ? slope*static_cast<float>(mask[iy]) : 0.0f);
vals[col] = val;
max_val = sycl::max(max_val, val);
@ -65,7 +65,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
item_ct1.barrier(sycl::access::fence_space::local_space);
max_val = buf[lane_id];
for (size_t i = 1; i < nreduce; i += 1) {
max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]);
max_val = sycl::max(max_val, buf[lane_id + i * WARP_SIZE]);
}
max_val = warp_reduce_max(max_val, item_ct1);
}
@ -122,8 +122,8 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
}
}
template <bool vals_smem, int ncols_template, int block_size_template>
static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par,
template <bool vals_smem, int ncols_template, int block_size_template, typename T>
static void soft_max_f32_submitter(const float * x, const T * mask, float * dst, const int ncols_par,
const int nrows_y, const float scale, const float max_bias, const float m0,
const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
const size_t n_local_scratch, queue_ptr stream) {
@ -141,7 +141,8 @@ static void soft_max_f32_submitter(const float * x, const float * mask, float *
});
}
static void soft_max_f32_sycl(const float * x, const float * mask,
template<typename T>
static void soft_max_f32_sycl(const float * x, const T * mask,
float * dst, const int ncols_x, const int nrows_x,
const int nrows_y, const float scale, const float max_bias,
queue_ptr stream, int device) {
@ -223,22 +224,16 @@ static void soft_max_f32_sycl(const float * x, const float * mask,
}
}
void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd,
const queue_ptr &main_stream) {
void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
GGML_ASSERT(!dst->src[1] || dst->src[1]->type == GGML_TYPE_F16 || dst->src[1]->type == GGML_TYPE_F32); // src1 contains mask and it is optional
const int64_t ne00 = src0->ne[0];
const int64_t nrows_x = ggml_nrows(src0);
const int64_t nrows_y = src0->ne[1];
const int64_t ne00 = dst->src[0]->ne[0];
const int64_t nrows_x = ggml_nrows(dst->src[0]);
const int64_t nrows_y = dst->src[0]->ne[1];
float scale = 1.0f;
float max_bias = 0.0f;
@ -246,6 +241,21 @@ void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *s
memcpy(&scale, dst->op_params + 0, sizeof(float));
memcpy(&max_bias, dst->op_params + 1, sizeof(float));
soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
ggml_sycl_set_device(ctx.device);
dpct::queue_ptr main_stream = ctx.stream();
if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F16) {
const sycl::half * src1_dd = static_cast<sycl::half *>(dst->src[1]->data);
soft_max_f32_sycl<sycl::half>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias,
main_stream, ctx.device);
} else if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F32) {
const float * src1_dd = static_cast<const float *>(dst->src[1]->data);
soft_max_f32_sycl<float>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
} else {
/* mask unavailable */
soft_max_f32_sycl<float>(src0_dd, nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
}
}

View file

@ -15,10 +15,6 @@
#include "common.hpp"
void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd,
const queue_ptr &main_stream);
void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, ggml_tensor *dst);
#endif // GGML_SYCL_SOFTMAX_HPP

View file

@ -85,6 +85,10 @@ struct vk_pipeline_struct {
uint32_t parameter_count;
std::array<uint32_t, 3> wg_denoms;
uint32_t align;
// set to true to request the pipeline is compiled after the dryrun
bool needed {};
// set to true when the shader has been compiled
bool compiled {};
};
typedef std::shared_ptr<vk_pipeline_struct> vk_pipeline;
@ -186,8 +190,11 @@ struct vk_device_struct {
bool mul_mat_id_m;
bool mul_mat_id_s;
vk_matmul_pipeline pipeline_matmul_f32;
vk_matmul_pipeline pipeline_matmul_f32_f16;
// set to true to indicate that some shaders need to be compiled after the dryrun
bool need_compiles {};
vk_matmul_pipeline pipeline_matmul_f32 {};
vk_matmul_pipeline pipeline_matmul_f32_f16 {};
vk_matmul_pipeline2 pipeline_matmul_f16;
vk_matmul_pipeline2 pipeline_matmul_f16_f32;
vk_pipeline pipeline_matmul_split_k_reduce;
@ -195,7 +202,7 @@ struct vk_device_struct {
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
vk_matmul_pipeline pipeline_matmul_id_f32;
vk_matmul_pipeline pipeline_matmul_id_f32 {};
vk_matmul_pipeline2 pipeline_matmul_id_f16;
vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;
@ -767,22 +774,15 @@ static uint32_t compile_count = 0;
static std::mutex compile_count_mutex;
static std::condition_variable compile_count_cond;
static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint,
uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants,
uint32_t align, bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) {
VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size <<
", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align <<
", " << disable_robustness << ", " << require_full_subgroups << ", " << required_subgroup_size << ")");
static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, size_t spv_size, const void* spv_data, const std::string entrypoint,
uint32_t parameter_count, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants,
bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) {
VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << pipeline->name << ", " << entrypoint << ", " << parameter_count <<
", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " <<
disable_robustness << ", " << require_full_subgroups << ", " << required_subgroup_size << ")");
GGML_ASSERT(parameter_count > 0);
GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT
pipeline = std::make_shared<vk_pipeline_struct>();
pipeline->name = name;
pipeline->parameter_count = parameter_count;
pipeline->push_constant_size = push_constant_size;
pipeline->wg_denoms = wg_denoms;
pipeline->align = align;
vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast<const uint32_t *>(spv_data));
pipeline->shader_module = device->device.createShaderModule(shader_module_create_info);
@ -864,7 +864,14 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
compute_pipeline_create_info.setPNext(&rci);
}
pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
try {
pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
} catch (const vk::SystemError& e) {
std::cerr << "ggml_vulkan: Compute pipeline creation failed for " << pipeline->name << std::endl;
std::cerr << "ggml_vulkan: " << e.what() << std::endl;
throw e;
}
pipeline->compiled = true;
{
std::lock_guard<std::mutex> guard(device->mutex);
@ -875,12 +882,6 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
std::lock_guard<std::mutex> guard(compile_count_mutex);
assert(compile_count > 0);
compile_count--;
// "Progress bar" for shader compiles
static uint32_t total_compile_count = 0;
if ((total_compile_count++ % 10) == 0) {
std::cerr << ".";
}
}
compile_count_cond.notify_all();
}
@ -906,6 +907,10 @@ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline)
static void ggml_pipeline_request_descriptor_sets(vk_device& device, vk_pipeline& pipeline, uint32_t n) {
VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")");
device->pipeline_descriptor_set_requirements[pipeline->name] += n;
if (!pipeline->compiled) {
pipeline->needed = true;
device->need_compiles = true;
}
}
static void ggml_pipeline_allocate_descriptor_sets(vk_device& device) {
@ -1388,8 +1393,6 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
static void ggml_vk_load_shaders(vk_device& device) {
VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
std::cerr << "ggml_vulkan: Compiling shaders";
// some shaders have a minimum subgroup size
const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u);
const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u);
@ -1527,15 +1530,33 @@ static void ggml_vk_load_shaders(vk_device& device) {
}
}
device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
device->pipeline_matmul_f32_f16 = std::make_shared<vk_matmul_pipeline_struct>();
device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
if (!device->pipeline_matmul_f32) {
device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
}
if (!device->pipeline_matmul_f32_f16) {
device->pipeline_matmul_f32_f16 = std::make_shared<vk_matmul_pipeline_struct>();
}
if (!device->pipeline_matmul_id_f32) {
device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
}
std::vector<std::future<void>> compiles;
auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint,
uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
if (!pipeline) {
pipeline = std::make_shared<vk_pipeline_struct>();
pipeline->name = name;
pipeline->parameter_count = parameter_count;
pipeline->push_constant_size = push_constant_size;
pipeline->wg_denoms = wg_denoms;
pipeline->align = align;
}
if (!pipeline->needed || pipeline->compiled) {
return;
}
{
// wait until fewer than N compiles are in progress
uint32_t N = std::max(1u, std::thread::hardware_concurrency());
@ -1545,8 +1566,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
}
compile_count++;
}
compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint,
parameter_count, push_constant_size, wg_denoms, specialization_constants, align, disable_robustness, require_full_subgroups, required_subgroup_size));
compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint,
parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
};
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
@ -1595,6 +1616,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
//CREATE_FA(GGML_TYPE_Q4_K, q4_k)
//CREATE_FA(GGML_TYPE_Q5_K, q5_k)
//CREATE_FA(GGML_TYPE_Q6_K, q6_k)
//CREATE_FA(GGML_TYPE_IQ2_XXS, iq2_xxs)
//CREATE_FA(GGML_TYPE_IQ2_XS, iq2_xs)
//CREATE_FA(GGML_TYPE_IQ2_S, iq2_s)
//CREATE_FA(GGML_TYPE_IQ3_XXS, iq3_xxs)
//CREATE_FA(GGML_TYPE_IQ3_S, iq3_s)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl)
#undef CREATE_FA
@ -1623,7 +1649,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
@ -1636,7 +1667,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
#undef CREATE_MM
#undef CREATE_MM2
} else
@ -1673,31 +1709,41 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
if (device->coopmat_acc_f16_support) {
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
} else {
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
}
// If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
@ -1707,31 +1753,41 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
if (device->coopmat_acc_f16_support) {
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
} else {
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
}
}
#undef CREATE_MM2
@ -1775,7 +1831,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
// If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
@ -1794,7 +1855,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
}
#undef CREATE_MM2
#undef CREATE_MM
@ -1830,7 +1896,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
// If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
@ -1849,7 +1920,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
}
#undef CREATE_MM
}
@ -1880,7 +1956,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f32_f32_len, mul_mat_vec_iq2_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f32_f32_len, mul_mat_vec_iq2_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f32_f32_len, mul_mat_vec_iq2_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f32_f32_len, mul_mat_vec_iq3_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
@ -1894,7 +1975,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f16_f32_len, mul_mat_vec_iq2_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f16_f32_len, mul_mat_vec_iq2_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f16_f32_len, mul_mat_vec_iq2_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f16_f32_len, mul_mat_vec_iq3_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true);
}
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
@ -1909,7 +1995,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true);
// dequant shaders
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
@ -1923,7 +2014,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_k", dequant_q4_k_len, dequant_q4_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_k", dequant_q5_k_len, dequant_q5_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_k", dequant_q6_k_len, dequant_q6_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XXS], "dequant_iq2_xxs", dequant_iq2_xxs_len, dequant_iq2_xxs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XS], "dequant_iq2_xs", dequant_iq2_xs_len, dequant_iq2_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_S], "dequant_iq2_s", dequant_iq2_s_len, dequant_iq2_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_XXS], "dequant_iq3_xxs", dequant_iq3_xxs_len, dequant_iq3_xxs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_S], "dequant_iq3_s", dequant_iq3_s_len, dequant_iq3_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
// get_rows
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
@ -1933,7 +2029,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs", get_rows_iq2_xxs_len, get_rows_iq2_xxs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XS], "get_rows_iq2_xs", get_rows_iq2_xs_len, get_rows_iq2_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_S], "get_rows_iq2_s", get_rows_iq2_s_len, get_rows_iq2_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_XXS], "get_rows_iq3_xxs", get_rows_iq3_xxs_len, get_rows_iq3_xxs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_S], "get_rows_iq3_s", get_rows_iq3_s_len, get_rows_iq3_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
@ -1942,7 +2043,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs_f32", get_rows_iq2_xxs_f32_len, get_rows_iq2_xxs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XS], "get_rows_iq2_xs_f32", get_rows_iq2_xs_f32_len, get_rows_iq2_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_S], "get_rows_iq2_s_f32", get_rows_iq2_s_f32_len, get_rows_iq2_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_XXS], "get_rows_iq3_xxs_f32", get_rows_iq3_xxs_f32_len, get_rows_iq3_xxs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_S], "get_rows_iq3_s_f32", get_rows_iq3_s_f32_len, get_rows_iq3_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
@ -2012,7 +2118,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
@ -2050,7 +2156,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
for (auto &c : compiles) {
c.wait();
}
std::cerr << "Done!" << std::endl;
device->need_compiles = false;
}
static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props);
@ -2869,6 +2975,11 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_NL:
break;
default:
@ -2917,6 +3028,11 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_NL:
break;
default:
@ -2948,6 +3064,11 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_NL:
break;
default:
@ -2991,6 +3112,11 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_NL:
break;
default:
@ -3017,6 +3143,11 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_NL:
break;
default:
@ -7656,6 +7787,9 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false);
}
if (ctx->device->need_compiles) {
ggml_vk_load_shaders(ctx->device);
}
ggml_vk_preallocate_buffers(ctx);
ggml_pipeline_allocate_descriptor_sets(ctx->device);
@ -7883,6 +8017,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_NL:
break;
default:
@ -7951,6 +8090,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
//case GGML_TYPE_Q4_K:
//case GGML_TYPE_Q5_K:
//case GGML_TYPE_Q6_K:
//case GGML_TYPE_IQ2_XXS:
//case GGML_TYPE_IQ2_XS:
//case GGML_TYPE_IQ2_S:
//case GGML_TYPE_IQ3_XXS:
//case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_NL:
break;
default:
@ -7968,6 +8112,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_NL:
return true;
default:

View file

@ -12,8 +12,8 @@ layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
#endif
void main() {
#if defined(DATA_A_IQ4_NL)
init_iq4nl_shmem();
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
init_iq_shmem(gl_WorkGroupSize);
if (gl_LocalInvocationIndex.x != 0) {
return;
}

View file

@ -217,8 +217,8 @@ void quantize(uint dst_idx, uint src_idx)
#endif
void main() {
#if defined(DATA_A_IQ4_NL)
init_iq4nl_shmem();
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
init_iq_shmem(gl_WorkGroupSize);
if (gl_LocalInvocationIndex.x != 0) {
return;
}

View file

@ -88,6 +88,222 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
}
#endif
#if defined(DATA_A_IQ2_XXS)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint ib32 = iqs / 32;
const uint ib8 = (iqs / 8) % 4;
const uint qs = data_a[a_offset + ib].qs[8 * ib32 + ib8];
// Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale)
const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[4 * ib32 + 2],
data_a_packed16[a_offset + ib].qs[4 * ib32 + 3]));
const float db = 0.25 * (0.5 + (signs >> 28));
const uint sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
// Add parity bit
const uint sign8 = sign7 | (bitCount(sign7) << 7);
const uint sign = sign8 >> (iqs % 8);
const u8vec4 grid = unpack8(iq2xxs_grid[qs][(iqs % 8) / 4] >> (8 * (iqs % 4)));
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
return db * vec2(
grid.x * (sign0 ? -1.0 : 1.0),
grid.y * (sign1 ? -1.0 : 1.0)
);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint ib32 = iqs / 32;
const uint ib8 = (iqs / 8) % 4;
const uint qs = data_a[a_offset + ib].qs[8 * ib32 + ib8];
// Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale)
const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[4 * ib32 + 2],
data_a_packed16[a_offset + ib].qs[4 * ib32 + 3]));
const float db = 0.25 * (0.5 + (signs >> 28));
const uint sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
// Add parity bit
const uint sign8 = sign7 | (bitCount(sign7) << 7);
const uint sign = sign8 >> (iqs % 8);
const u8vec4 grid = unpack8(iq2xxs_grid[qs][(iqs % 8) / 4] >> (8 * (iqs % 4)));
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
bool sign2 = (sign & 4) != 0;
bool sign3 = (sign & 8) != 0;
return db * vec4(
grid.x * (sign0 ? -1.0 : 1.0),
grid.y * (sign1 ? -1.0 : 1.0),
grid.z * (sign2 ? -1.0 : 1.0),
grid.w * (sign3 ? -1.0 : 1.0)
);
}
#endif
#if defined(DATA_A_IQ2_XS)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint scale = (data_a[a_offset + ib].scales[iqs / 32] >> (4 * ((iqs / 16) & 1))) & 0xf;
const uint qs = data_a[a_offset + ib].qs[iqs / 8];
const float db = 0.25 * (0.5 + scale);
const uint sign7 = qs >> 9;
// Add parity bit
const uint sign8 = sign7 | (bitCount(sign7) << 7);
const uint sign = sign8 >> (iqs % 8);
const u8vec4 grid = unpack8(iq2xs_grid[qs & 511][(iqs % 8) / 4] >> (8 * (iqs % 4)));
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
return db * vec2(
grid.x * (sign0 ? -1.0 : 1.0),
grid.y * (sign1 ? -1.0 : 1.0)
);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint scale = (data_a[a_offset + ib].scales[iqs / 32] >> (4 * ((iqs / 16) & 1))) & 0xf;
const uint qs = data_a[a_offset + ib].qs[iqs / 8];
const float db = 0.25 * (0.5 + scale);
const uint sign7 = qs >> 9;
// Add parity bit
const uint sign8 = sign7 | (bitCount(sign7) << 7);
const uint sign = sign8 >> (iqs % 8);
const u8vec4 grid = unpack8(iq2xs_grid[qs & 511][(iqs % 8) / 4] >> (8 * (iqs % 4)));
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
bool sign2 = (sign & 4) != 0;
bool sign3 = (sign & 8) != 0;
return db * vec4(
grid.x * (sign0 ? -1.0 : 1.0),
grid.y * (sign1 ? -1.0 : 1.0),
grid.z * (sign2 ? -1.0 : 1.0),
grid.w * (sign3 ? -1.0 : 1.0)
);
}
#endif
#if defined(DATA_A_IQ2_S)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint ib32 = iqs / 32;
const uint ib8 = iqs / 8;
const uint scale = (data_a[a_offset + ib].scales[ib32] >> (4 * ((iqs / 16) & 1))) & 0xf;
const uint qs = data_a[a_offset + ib].qs[ib8];
const uint qh = data_a[a_offset + ib].qh[ib32];
const uint qhshift = 2 * (ib8 % 4);
const uint sign = data_a[a_offset + ib].qs[QUANT_K / 8 + ib8] >> (iqs % 8);
const float db = 0.25 * (0.5 + scale);
const u8vec4 grid = unpack8(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(iqs % 8) / 4]);
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
return db * vec2(
grid[iqs % 4] * (sign0 ? -1.0 : 1.0),
grid[(iqs % 4) + 1] * (sign1 ? -1.0 : 1.0)
);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint ib32 = iqs / 32;
const uint ib8 = iqs / 8;
const uint scale = (data_a[a_offset + ib].scales[ib32] >> (4 * ((iqs / 16) & 1))) & 0xf;
const uint qs = data_a[a_offset + ib].qs[ib8];
const uint qh = data_a[a_offset + ib].qh[ib32];
const uint qhshift = 2 * (ib8 % 4);
const uint sign = data_a[a_offset + ib].qs[QUANT_K / 8 + ib8] >> (iqs % 8);
const float db = 0.25 * (0.5 + scale);
const u8vec4 grid = unpack8(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(iqs % 8) / 4]);
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
bool sign2 = (sign & 4) != 0;
bool sign3 = (sign & 8) != 0;
return db * vec4(
grid.x * (sign0 ? -1.0 : 1.0),
grid.y * (sign1 ? -1.0 : 1.0),
grid.z * (sign2 ? -1.0 : 1.0),
grid.w * (sign3 ? -1.0 : 1.0)
);
}
#endif
#if defined(DATA_A_IQ3_XXS)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint ib4 = iqs / 4;
const uint ib32 = iqs / 32;
const uint is = QUANT_K / 4 + 4 * ib32;
const uint qs = data_a[a_offset + ib].qs[ib4];
// Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale)
const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[is / 2],
data_a_packed16[a_offset + ib].qs[is / 2 + 1]));
const float db = 0.5 * (0.5 + (signs >> 28));
const uint sign7 = bitfieldExtract(signs, 7 * (int(ib4 / 2) % 4), 7);
// Add parity bit
const uint sign8 = sign7 | (bitCount(sign7) << 7);
const uint sign = sign8 >> (iqs % 8);
const u8vec4 grid = unpack8(iq3xxs_grid[qs] >> (8 * (iqs % 4)));
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
return db * vec2(
grid.x * (sign0 ? -1.0 : 1.0),
grid.y * (sign1 ? -1.0 : 1.0)
);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint ib4 = iqs / 4;
const uint ib32 = iqs / 32;
const uint is = QUANT_K / 4 + 4 * ib32;
const uint qs = data_a[a_offset + ib].qs[ib4];
const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[is / 2],
data_a_packed16[a_offset + ib].qs[is / 2 + 1]));
const float db = 0.5 * (0.5 + (signs >> 28));
const uint sign7 = bitfieldExtract(signs, 7 * (int(ib4 / 2) % 4), 7);
// Add parity bit
const uint sign8 = sign7 | (bitCount(sign7) << 7);
const uint sign = sign8 >> (iqs % 8);
const u8vec4 grid = unpack8(iq3xxs_grid[qs]);
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
bool sign2 = (sign & 4) != 0;
bool sign3 = (sign & 8) != 0;
return db * vec4(
grid.x * (sign0 ? -1.0 : 1.0),
grid.y * (sign1 ? -1.0 : 1.0),
grid.z * (sign2 ? -1.0 : 1.0),
grid.w * (sign3 ? -1.0 : 1.0)
);
}
#endif
#if defined(DATA_A_IQ3_S)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint qs = data_a[a_offset + ib].qs[iqs / 4];
const uint qh = data_a[a_offset + ib].qh[iqs / 32];
const uint sign = data_a[a_offset + ib].signs[iqs / 8] >> (iqs % 8);
const uint scale = data_a[a_offset + ib].scales[iqs / 64];
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
const float db = 1 + 2 * ((scale >> (4 * ((iqs / 32) & 1))) & 0xf);
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - ((iqs / 4) % 8))) & 256)] >> (8 * (iqs % 4));
return db * vec2(
int(grid & 0xFF) * (sign0 ? -1.0 : 1.0),
int((grid >> 8) & 0xFF) * (sign1 ? -1.0 : 1.0)
);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint ib4 = iqs / 4;
const uint ib32 = iqs / 32;
const uint qs = data_a[a_offset + ib].qs[ib4];
const uint qh = data_a[a_offset + ib].qh[ib32];
const uint sign = data_a[a_offset + ib].signs[iqs / 8] >> (iqs % 8);
const uint scale = data_a[a_offset + ib].scales[ib32 / 2];
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
bool sign2 = (sign & 4) != 0;
bool sign3 = (sign & 8) != 0;
const float db = 1 + 2 * ((scale >> (4 * (ib32 & 1))) & 0xf);
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - ib4 % 8)) & 256)] >> (8 * (iqs % 4));
return db * vec4(
int(grid & 0xFF) * (sign0 ? -1.0 : 1.0),
int((grid >> 8) & 0xFF) * (sign1 ? -1.0 : 1.0),
int((grid >> 16) & 0xFF) * (sign2 ? -1.0 : 1.0),
int((grid >> 24) & 0xFF) * (sign3 ? -1.0 : 1.0)
);
}
#endif
#if defined(DATA_A_IQ4_NL)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
@ -105,7 +321,7 @@ vec2 get_dm(uint ib, uint a_offset) {
}
#endif
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL)
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
vec2 get_dm(uint ib, uint a_offset) {
return vec2(float(data_a[a_offset + ib].d), 0);
}

View file

@ -301,6 +301,160 @@ float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2
return ret;
}
#if defined(DATA_A_IQ2_XXS)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS {
block_iq2_xxs block;
};
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS_packed16 {
block_iq2_xxs_packed16 block;
};
float16_t dequantFuncIQ2_XXS(const in decodeBufIQ2_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufIQ2_XXS_packed16 bl16 = decodeBufIQ2_XXS_packed16(bl);
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint ib32 = (idx & 0xE0) >> 5; // 0..7
const uint ib8 = (idx & 0x18) >> 3; // 0..3
const uint iqs = 8 * ib32 + ib8;
const uint8_t qs = bl.block.qs[iqs];
const uint signscale = pack32(u16vec2(bl16.block.qs[4*ib32+2], bl16.block.qs[4*ib32+3]));
const float16_t dscale = bl.block.d * 0.25hf * (0.5hf + float16_t(signscale >> 28));
uint sign = bitfieldExtract(signscale, 7 * int(ib8), 7);
sign |= bitCount(sign) << 7;
const uint8_t g = unpack8(iq2xxs_grid[qs][(idx & 4) >> 2])[idx & 3];
float16_t ret = dscale * float16_t(g) * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
return ret;
}
#endif
#if defined(DATA_A_IQ2_XS)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XS {
block_iq2_xs block;
};
float16_t dequantFuncIQ2_XS(const in decodeBufIQ2_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint is = (idx & 0xE0) >> 5; // 0..8
const uint sshift = (idx & 0x10) >> 2; // 0,4
const uint iqs = (idx & 0xF8) >> 3; // 0..63
const uint16_t qs = bl.block.qs[iqs];
const float16_t dscale = bl.block.d * 0.25hf * (0.5hf + float16_t((bl.block.scales[is] >> sshift) & 0xF));
uint sign = uint(qs >> 9);
sign |= bitCount(sign) << 7;
const uint8_t g = unpack8(iq2xs_grid[qs & 0x1FF][(idx & 4) >> 2])[idx & 3];
float16_t ret = dscale * float16_t(g) * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
return ret;
}
#endif
#if defined(DATA_A_IQ2_S)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_S {
block_iq2_s block;
};
float16_t dequantFuncIQ2_S(const in decodeBufIQ2_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
uint idx = coordInBlock[1];
uint lsb = idx & 1;
idx /= 2;
const uint ib8 = (idx % 128) / 4; // 0..31
const uint ib32 = ib8 / 4; // 0..7
const uint scale = (bl.block.scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
const uint qs = bl.block.qs[ib8];
const uint qh = bl.block.qh[ib32];
const uint qhshift = 2 * (ib8 % 4);
const uint sign = bl.block.qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4));
const float d = float(bl.block.d);
const float db = d * 0.25 * (0.5 + scale);
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1];
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid));
return float16_t(v[lsb]);
}
#endif
#if defined(DATA_A_IQ3_XXS)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_XXS {
block_iq3_xxs block;
};
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_XXS_packed16 {
block_iq3_xxs_packed16 block;
};
float16_t dequantFuncIQ3_XXS(const in decodeBufIQ3_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
uint idx = coordInBlock[1];
uint lsb = idx & 1;
idx /= 2;
const uint iqs = (idx % 128) / 2; // 0..63
const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values
const float d = float(bl.block.d);
const uint qs = bl.block.qs[iqs];
const uint signs = pack32(u8vec4(
bl.block.qs[is+0],
bl.block.qs[is+1],
bl.block.qs[is+2],
bl.block.qs[is+3]
));
const float db = d * 0.5 * (0.5 + (signs >> 28));
const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1));
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
return float16_t(v[lsb]);
}
#endif
#if defined(DATA_A_IQ3_S)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_S {
block_iq3_s block;
};
float16_t dequantFuncIQ3_S(const in decodeBufIQ3_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
uint idx = coordInBlock[1];
uint lsb = idx & 1;
idx /= 2;
const uint iqs = (idx % 128) / 2; // 0..63
const uint iqh = iqs / 8;
const float d = float(bl.block.d);
const uint qs = bl.block.qs[iqs];
const uint qh = bl.block.qh[iqh];
const int8_t sign = int8_t(bl.block.signs[iqs / 2] >> (2 * (idx % 4)));
const uint scale = bl.block.scales[iqs / 16];
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2));
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
return float16_t(v[lsb]);
}
#endif
#if defined(DATA_A_IQ4_NL)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL {
block_iq4_nl block;
@ -340,6 +494,16 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
#define dequantFuncA dequantFuncQ5_K
#elif defined(DATA_A_Q6_K)
#define dequantFuncA dequantFuncQ6_K
#elif defined(DATA_A_IQ2_XXS)
#define dequantFuncA dequantFuncIQ2_XXS
#elif defined(DATA_A_IQ2_XS)
#define dequantFuncA dequantFuncIQ2_XS
#elif defined(DATA_A_IQ2_S)
#define dequantFuncA dequantFuncIQ2_S
#elif defined(DATA_A_IQ3_XXS)
#define dequantFuncA dequantFuncIQ3_XXS
#elif defined(DATA_A_IQ3_S)
#define dequantFuncA dequantFuncIQ3_S
#elif defined(DATA_A_IQ4_NL)
#define dequantFuncA dequantFuncIQ4_NL
#endif

View file

@ -0,0 +1,44 @@
#version 450
#include "dequant_head.comp"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_iq2_s data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
// Each thread handles 1 subblock (32 values with 2 scales)
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
init_iq_shmem(gl_WorkGroupSize);
if (ib >= p.nel / 256) {
return;
}
const uint ib32 = gl_LocalInvocationID.x % 8;
const uint b_idx = 256 * ib + 32 * ib32;
const float d = float(data_a[ib].d);
const vec2 scale = vec2(data_a[ib].scales[ib32] & 0xf, data_a[ib].scales[ib32] >> 4);
const vec2 db = d * (0.5 + scale) * 0.25;
uint qh = data_a[ib].qh[ib32];
[[unroll]] for (uint l = 0; l < 4; ++l) {
uint qs = data_a[ib].qs[4 * ib32 + l];
const uint8_t sign = data_a[ib].qs[QUANT_K / 8 + 4 * ib32 + l];
qs |= (qh << (8 - 2 * l)) & 0x300;
const uvec2 grid = iq2s_grid[qs & 511];
const u8vec4 grid0 = unpack8(grid.x);
const u8vec4 grid1 = unpack8(grid.y);
data_b[b_idx + 8 * l + 0] = D_TYPE(db[l/2] * grid0.x * ((sign & 1) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 1] = D_TYPE(db[l/2] * grid0.y * ((sign & 2) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 2] = D_TYPE(db[l/2] * grid0.z * ((sign & 4) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 3] = D_TYPE(db[l/2] * grid0.w * ((sign & 8) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 4] = D_TYPE(db[l/2] * grid1.x * ((sign & 16) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 5] = D_TYPE(db[l/2] * grid1.y * ((sign & 32) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 6] = D_TYPE(db[l/2] * grid1.z * ((sign & 64) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 7] = D_TYPE(db[l/2] * grid1.w * ((sign & 128) != 0 ? -1.0 : 1.0));
}
}

View file

@ -0,0 +1,43 @@
#version 450
#include "dequant_head.comp"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_iq2_xs data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
// Each thread handles 1 subblock (32 values with 2 scales)
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
init_iq_shmem(gl_WorkGroupSize);
if (ib >= p.nel / 256) {
return;
}
const uint ib32 = gl_LocalInvocationID.x % 8;
const uint b_idx = 256 * ib + 32 * ib32;
const float d = float(data_a[ib].d);
const vec2 scale = vec2(data_a[ib].scales[ib32] & 0xf, data_a[ib].scales[ib32] >> 4);
const vec2 db = d * (0.5 + scale) * 0.25;
[[unroll]] for (uint l = 0; l < 4; ++l) {
uint16_t qs = data_a[ib].qs[4 * ib32 + l];
const uint sign7 = qs >> 9;
const uint sign8 = sign7 | (bitCount(sign7) << 7); // parity bit
const uvec2 grid = iq2xs_grid[qs & 511];
const u8vec4 grid0 = unpack8(grid.x);
const u8vec4 grid1 = unpack8(grid.y);
data_b[b_idx + 8 * l + 0] = D_TYPE(db[l/2] * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 1] = D_TYPE(db[l/2] * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 2] = D_TYPE(db[l/2] * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 3] = D_TYPE(db[l/2] * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 4] = D_TYPE(db[l/2] * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 5] = D_TYPE(db[l/2] * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 6] = D_TYPE(db[l/2] * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 7] = D_TYPE(db[l/2] * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0));
}
}

View file

@ -0,0 +1,48 @@
#version 450
#include "dequant_head.comp"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_iq2_xxs data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
// Each thread handles 1 scale block (32 values)
// Each block is described by 4 lattice indices, 4x7 sign bits and 4 scale bits
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
init_iq_shmem(gl_WorkGroupSize);
if (ib >= p.nel / 256) {
return;
}
const uint is = gl_LocalInvocationID.x % 8;
const uint b_idx = 256 * ib + 32 * is;
const float d = float(data_a[ib].d);
uint signscale = pack32(u8vec4(
data_a[ib].qs[8*is + 4],
data_a[ib].qs[8*is + 5],
data_a[ib].qs[8*is + 6],
data_a[ib].qs[8*is + 7]
));
const float db = d * (0.5 + (signscale >> 28)) * 0.25;
[[unroll]] for (uint l = 0; l < 4; ++l) {
const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7);
const uint sign8 = sign7 | (bitCount(sign7) << 7); // parity bit
const uvec2 grid = iq2xxs_grid[data_a[ib].qs[8 * is + l]];
const u8vec4 grid0 = unpack8(grid.x);
const u8vec4 grid1 = unpack8(grid.y);
data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 1] = D_TYPE(db * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 2] = D_TYPE(db * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 3] = D_TYPE(db * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 4] = D_TYPE(db * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 5] = D_TYPE(db * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 6] = D_TYPE(db * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 7] = D_TYPE(db * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0));
}
}

View file

@ -0,0 +1,39 @@
#version 450
#include "dequant_head.comp"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_iq3_s data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
// Each thread handles 1 scale nibble.
// Each block contains 4 scale bytes (8 scales) for 256 output values.
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
init_iq_shmem(gl_WorkGroupSize);
if (ib >= p.nel / 256) {
return;
}
const uint is = gl_LocalInvocationID.x % 8;
const uint b_idx = 256 * ib + 32 * is;
const float d = float(data_a[ib].d);
const float db = d * (1 + 2 * ((data_a[ib].scales[is] >> (4 * (is % 2))) & 0xf));
// We must produce 32 values using 4 sign bytes, 1 qh byte, 8 qs bytes.
uint qh = data_a[ib].qh[is];
[[unroll]] for (uint l = 0; l < 8; ++l) {
uint qs = data_a[ib].qs[8 * is + l];
uint gidx = qs | ((qh << (8 - l)) & 256);
uint8_t signs = data_a[ib].signs[8 * is + l / 2] >> (4 * (l & 1));
u8vec4 grid = unpack8(iq3s_grid[gidx]);
data_b[b_idx + 4 * l + 0] = D_TYPE(db * grid.x * ((signs & 1) != 0 ? -1.0 : 1.0));
data_b[b_idx + 4 * l + 1] = D_TYPE(db * grid.y * ((signs & 2) != 0 ? -1.0 : 1.0));
data_b[b_idx + 4 * l + 2] = D_TYPE(db * grid.z * ((signs & 4) != 0 ? -1.0 : 1.0));
data_b[b_idx + 4 * l + 3] = D_TYPE(db * grid.w * ((signs & 8) != 0 ? -1.0 : 1.0));
}
}

View file

@ -0,0 +1,49 @@
#version 450
#include "dequant_head.comp"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_iq3_xxs data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
// Each thread handles 1 scale block (32 values)
// 8 threads handle 1 superblock
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
init_iq_shmem(gl_WorkGroupSize);
if (ib >= p.nel / 256) {
return;
}
const uint is = gl_LocalInvocationID.x % 8;
const uint b_idx = 256 * ib + 32 * is;
const uint s_idx = QUANT_K / 4 + 4 * is;
const float d = float(data_a[ib].d);
uint signscale = pack32(u8vec4(
data_a[ib].qs[s_idx + 0],
data_a[ib].qs[s_idx + 1],
data_a[ib].qs[s_idx + 2],
data_a[ib].qs[s_idx + 3]
));
const float db = d * (0.5 + (signscale >> 28)) * 0.5;
[[unroll]] for (uint l = 0; l < 4; ++l) {
const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7);
// Restore parity bit.
const uint sign8 = sign7 | (bitCount(sign7) << 7);
const u8vec4 grid0 = unpack8(iq3xxs_grid[data_a[ib].qs[8 * is + 2 * l]]);
const u8vec4 grid1 = unpack8(iq3xxs_grid[data_a[ib].qs[8 * is + 2 * l + 1]]);
data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 1] = D_TYPE(db * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 2] = D_TYPE(db * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 3] = D_TYPE(db * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 4] = D_TYPE(db * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 5] = D_TYPE(db * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 6] = D_TYPE(db * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 7] = D_TYPE(db * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0));
}
}

View file

@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
init_iq4nl_shmem();
init_iq_shmem(gl_WorkGroupSize);
const uint tid = gl_LocalInvocationID.x % 64;
const uint il = tid/32;

View file

@ -12,7 +12,7 @@ layout (push_constant) uniform parameter
#include "types.comp"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout(local_size_x = 1, local_size_y = 512, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

View file

@ -104,8 +104,8 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
#endif
void main() {
#if defined(DATA_A_IQ4_NL)
init_iq4nl_shmem();
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
init_iq_shmem(gl_WorkGroupSize);
#endif
const uint32_t N = p.N;

View file

@ -12,8 +12,8 @@ void main() {
const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
#if defined(DATA_A_IQ4_NL)
init_iq4nl_shmem();
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
init_iq_shmem(gl_WorkGroupSize);
#endif
if (i00 >= p.ne00) {

View file

@ -133,8 +133,8 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
#if defined(DATA_A_IQ4_NL)
init_iq4nl_shmem();
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
init_iq_shmem(gl_WorkGroupSize);
#endif
// do NUM_ROWS at a time, unless there aren't enough remaining rows

View file

@ -95,8 +95,8 @@ shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
#endif
void main() {
#if defined(DATA_A_IQ4_NL)
init_iq4nl_shmem();
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
init_iq_shmem(gl_WorkGroupSize);
#endif
#ifdef MUL_MAT_ID
@ -343,10 +343,8 @@ void main() {
const uint qsshift = halfsplit * 2; // 0,2,4,6
const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128
const int8_t us = int8_t(is < 4 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+8] >> 0) & 3) << 4) :
is < 8 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+4] >> 2) & 3) << 4) :
is < 12 ? (data_a[ib].scales[is-8] >> 4) | (((data_a[ib].scales[is+0] >> 4) & 3) << 4) :
(data_a[ib].scales[is-8] >> 4) | (((data_a[ib].scales[is-4] >> 6) & 3) << 4));
const int8_t us = int8_t(((data_a[ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF)
| (((data_a[ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4));
const float dl = float(data_a[ib].d) * float(us - 32);
buf_a[buf_idx ] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4)));
@ -439,6 +437,118 @@ void main() {
buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32));
buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
#elif defined(DATA_A_IQ2_XXS)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx
const uint ib32 = (idx % 128) / 16; // 0..7
const uint ib8 = (idx / 4) % 4;
const float d = float(data_a[ib].d);
const uint qs = data_a[ib].qs[8 * ib32 + ib8];
const uint signs = pack32(u8vec4(
data_a[ib].qs[8*ib32 + 4],
data_a[ib].qs[8*ib32 + 5],
data_a[ib].qs[8*ib32 + 6],
data_a[ib].qs[8*ib32 + 7]
));
const float db = d * 0.25 * (0.5 + (signs >> 28));
const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1));
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_IQ2_XS)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx
const uint ib32 = (idx % 128) / 16; // 0..7
const uint ib8 = (idx / 4) % 4; // 0..3
const float d = float(data_a[ib].d);
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
const float db = d * 0.25 * (0.5 + scale);
const uint qs = data_a[ib].qs[4 * ib32 + ib8];
const uint sign7 = qs >> 9;
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1));
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_IQ2_S)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx
const uint ib8 = (idx % 128) / 4; // 0..31
const uint ib32 = ib8 / 4; // 0..7
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
const uint qs = data_a[ib].qs[ib8];
const uint qh = data_a[ib].qh[ib32];
const uint qhshift = 2 * (ib8 % 4);
const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4));
const float d = float(data_a[ib].d);
const float db = d * 0.25 * (0.5 + scale);
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1];
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid));
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_IQ3_XXS)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx
const uint iqs = (idx % 128) / 2; // 0..63
const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values
const float d = float(data_a[ib].d);
const uint qs = data_a[ib].qs[iqs];
const uint signs = pack32(u8vec4(
data_a[ib].qs[is+0],
data_a[ib].qs[is+1],
data_a[ib].qs[is+2],
data_a[ib].qs[is+3]
));
const float db = d * 0.5 * (0.5 + (signs >> 28));
const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1));
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_IQ3_S)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx
const uint iqs = (idx % 128) / 2; // 0..63
const uint iqh = iqs / 8;
const float d = float(data_a[ib].d);
const uint qs = data_a[ib].qs[iqs];
const uint qh = data_a[ib].qh[iqh];
const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (2 * (idx % 4)));
const uint scale = data_a[ib].scales[iqs / 16];
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2));
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_IQ4_NL)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;

View file

@ -106,8 +106,8 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem
#endif
void main() {
#if defined(DATA_A_IQ4_NL)
init_iq4nl_shmem();
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
init_iq_shmem(gl_WorkGroupSize);
#endif
#ifdef MUL_MAT_ID

View file

@ -294,6 +294,738 @@ struct block_q6_K_packed16
// IQuants
#define QUANT_K_IQ2_XXS 256
#define QUANT_R_IQ2_XXS 1
struct block_iq2_xxs
{
float16_t d;
uint8_t qs[QUANT_K_IQ2_XXS/4];
};
struct block_iq2_xxs_packed16
{
float16_t d;
uint16_t qs[QUANT_K_IQ2_XXS/8];
};
#if defined(DATA_A_IQ2_XXS)
const uvec2[256] iq2xxs_grid_const = {
uvec2(0x08080808, 0x08080808), uvec2(0x0808082b, 0x08080808), uvec2(0x08081919, 0x08080808), uvec2(0x08082b08, 0x08080808),
uvec2(0x08082b2b, 0x08080808), uvec2(0x08190819, 0x08080808), uvec2(0x08191908, 0x08080808), uvec2(0x082b0808, 0x08080808),
uvec2(0x082b082b, 0x08080808), uvec2(0x082b2b08, 0x08080808), uvec2(0x082b2b2b, 0x08080808), uvec2(0x19080819, 0x08080808),
uvec2(0x19081908, 0x08080808), uvec2(0x19190808, 0x08080808), uvec2(0x19192b08, 0x08080808), uvec2(0x192b0819, 0x08080808),
uvec2(0x192b1908, 0x08080808), uvec2(0x2b080808, 0x08080808), uvec2(0x2b08082b, 0x08080808), uvec2(0x2b082b2b, 0x08080808),
uvec2(0x2b2b082b, 0x08080808), uvec2(0x08080819, 0x08080819), uvec2(0x08081908, 0x08080819), uvec2(0x08190808, 0x08080819),
uvec2(0x08191919, 0x08080819), uvec2(0x19080808, 0x08080819), uvec2(0x2b081908, 0x08080819), uvec2(0x2b192b08, 0x08080819),
uvec2(0x08080808, 0x0808082b), uvec2(0x0808082b, 0x0808082b), uvec2(0x082b082b, 0x0808082b), uvec2(0x2b08082b, 0x0808082b),
uvec2(0x08080819, 0x08081908), uvec2(0x08081908, 0x08081908), uvec2(0x08190808, 0x08081908), uvec2(0x082b0819, 0x08081908),
uvec2(0x082b1908, 0x08081908), uvec2(0x19080808, 0x08081908), uvec2(0x1908082b, 0x08081908), uvec2(0x19082b08, 0x08081908),
uvec2(0x192b0808, 0x08081908), uvec2(0x2b080819, 0x08081908), uvec2(0x2b081908, 0x08081908), uvec2(0x2b190808, 0x08081908),
uvec2(0x2b2b1908, 0x08081908), uvec2(0x08080808, 0x08081919), uvec2(0x0808082b, 0x08081919), uvec2(0x08082b08, 0x08081919),
uvec2(0x082b0808, 0x08081919), uvec2(0x1908192b, 0x08081919), uvec2(0x192b2b19, 0x08081919), uvec2(0x2b080808, 0x08081919),
uvec2(0x2b190819, 0x08081919), uvec2(0x08082b19, 0x0808192b), uvec2(0x08190808, 0x0808192b), uvec2(0x19080808, 0x0808192b),
uvec2(0x2b081908, 0x0808192b), uvec2(0x2b2b1908, 0x0808192b), uvec2(0x08080808, 0x08082b08), uvec2(0x08081919, 0x08082b08),
uvec2(0x08082b08, 0x08082b08), uvec2(0x08191908, 0x08082b08), uvec2(0x082b2b08, 0x08082b08), uvec2(0x19080819, 0x08082b08),
uvec2(0x19081908, 0x08082b08), uvec2(0x19190808, 0x08082b08), uvec2(0x1919082b, 0x08082b08), uvec2(0x2b082b08, 0x08082b08),
uvec2(0x08081908, 0x08082b19), uvec2(0x19080808, 0x08082b19), uvec2(0x0808082b, 0x08082b2b), uvec2(0x08191908, 0x08082b2b),
uvec2(0x08080819, 0x08190808), uvec2(0x08081908, 0x08190808), uvec2(0x08190808, 0x08190808), uvec2(0x082b0819, 0x08190808),
uvec2(0x19080808, 0x08190808), uvec2(0x192b0808, 0x08190808), uvec2(0x2b081908, 0x08190808), uvec2(0x2b190808, 0x08190808),
uvec2(0x2b191919, 0x08190808), uvec2(0x08080808, 0x08190819), uvec2(0x08082b08, 0x08190819), uvec2(0x082b0808, 0x08190819),
uvec2(0x19190808, 0x08190819), uvec2(0x19192b2b, 0x08190819), uvec2(0x2b080808, 0x08190819), uvec2(0x082b1908, 0x0819082b),
uvec2(0x19081919, 0x0819082b), uvec2(0x08080808, 0x08191908), uvec2(0x08082b08, 0x08191908), uvec2(0x082b0808, 0x08191908),
uvec2(0x082b1919, 0x08191908), uvec2(0x19082b19, 0x08191908), uvec2(0x2b080808, 0x08191908), uvec2(0x08192b08, 0x08191919),
uvec2(0x192b082b, 0x08191919), uvec2(0x08080808, 0x0819192b), uvec2(0x0819192b, 0x0819192b), uvec2(0x08080819, 0x08192b08),
uvec2(0x08081908, 0x08192b08), uvec2(0x08190808, 0x08192b08), uvec2(0x19080808, 0x08192b08), uvec2(0x2b080819, 0x08192b08),
uvec2(0x08080808, 0x08192b19), uvec2(0x08081919, 0x08192b19), uvec2(0x2b2b0808, 0x08192b19), uvec2(0x19190819, 0x08192b2b),
uvec2(0x08080808, 0x082b0808), uvec2(0x0808082b, 0x082b0808), uvec2(0x08082b2b, 0x082b0808), uvec2(0x19081908, 0x082b0808),
uvec2(0x192b0819, 0x082b0808), uvec2(0x2b080808, 0x082b0808), uvec2(0x2b08082b, 0x082b0808), uvec2(0x082b2b19, 0x082b0819),
uvec2(0x19082b08, 0x082b0819), uvec2(0x08080808, 0x082b082b), uvec2(0x0808082b, 0x082b082b), uvec2(0x08080819, 0x082b1908),
uvec2(0x08081908, 0x082b1908), uvec2(0x08190808, 0x082b1908), uvec2(0x19080808, 0x082b1908), uvec2(0x1919192b, 0x082b1908),
uvec2(0x08080808, 0x082b1919), uvec2(0x19080819, 0x082b1919), uvec2(0x192b1908, 0x082b1919), uvec2(0x2b190808, 0x082b192b),
uvec2(0x08082b08, 0x082b2b08), uvec2(0x082b0808, 0x082b2b08), uvec2(0x2b191908, 0x082b2b08), uvec2(0x19081908, 0x082b2b2b),
uvec2(0x08080819, 0x19080808), uvec2(0x08081908, 0x19080808), uvec2(0x08190808, 0x19080808), uvec2(0x08192b08, 0x19080808),
uvec2(0x082b0819, 0x19080808), uvec2(0x082b1908, 0x19080808), uvec2(0x19080808, 0x19080808), uvec2(0x19082b08, 0x19080808),
uvec2(0x1919192b, 0x19080808), uvec2(0x192b0808, 0x19080808), uvec2(0x2b080819, 0x19080808), uvec2(0x2b081908, 0x19080808),
uvec2(0x2b190808, 0x19080808), uvec2(0x08080808, 0x19080819), uvec2(0x082b0808, 0x19080819), uvec2(0x192b0819, 0x19080819),
uvec2(0x2b080808, 0x19080819), uvec2(0x2b081919, 0x19080819), uvec2(0x08080819, 0x1908082b), uvec2(0x08190808, 0x1908082b),
uvec2(0x19082b08, 0x1908082b), uvec2(0x1919192b, 0x1908082b), uvec2(0x192b2b08, 0x1908082b), uvec2(0x08080808, 0x19081908),
uvec2(0x08082b08, 0x19081908), uvec2(0x082b0808, 0x19081908), uvec2(0x2b080808, 0x19081908), uvec2(0x2b192b19, 0x19081908),
uvec2(0x0819082b, 0x19081919), uvec2(0x082b1908, 0x19081919), uvec2(0x08080808, 0x1908192b), uvec2(0x08080819, 0x19082b08),
uvec2(0x08081908, 0x19082b08), uvec2(0x08190808, 0x19082b08), uvec2(0x19080808, 0x19082b08), uvec2(0x19081919, 0x19082b08),
uvec2(0x08080808, 0x19082b19), uvec2(0x19192b08, 0x19082b19), uvec2(0x192b0819, 0x19082b19), uvec2(0x2b08082b, 0x19082b19),
uvec2(0x19081919, 0x19082b2b), uvec2(0x2b190808, 0x19082b2b), uvec2(0x08080808, 0x19190808), uvec2(0x08082b08, 0x19190808),
uvec2(0x08190819, 0x19190808), uvec2(0x08192b19, 0x19190808), uvec2(0x082b0808, 0x19190808), uvec2(0x2b080808, 0x19190808),
uvec2(0x2b082b08, 0x19190808), uvec2(0x08081908, 0x19190819), uvec2(0x1908082b, 0x19190819), uvec2(0x2b2b1908, 0x19190819),
uvec2(0x2b190819, 0x1919082b), uvec2(0x2b190808, 0x19191908), uvec2(0x2b19082b, 0x19191908), uvec2(0x08082b2b, 0x19191919),
uvec2(0x08080819, 0x1919192b), uvec2(0x19191908, 0x1919192b), uvec2(0x08080808, 0x19192b08), uvec2(0x08190819, 0x19192b08),
uvec2(0x08192b19, 0x19192b08), uvec2(0x192b1908, 0x19192b08), uvec2(0x19080808, 0x19192b19), uvec2(0x08082b08, 0x19192b2b),
uvec2(0x08081908, 0x192b0808), uvec2(0x08190808, 0x192b0808), uvec2(0x19080808, 0x192b0808), uvec2(0x192b2b08, 0x192b0808),
uvec2(0x08080808, 0x192b0819), uvec2(0x19191919, 0x192b0819), uvec2(0x08192b08, 0x192b082b), uvec2(0x192b0808, 0x192b082b),
uvec2(0x08080808, 0x192b1908), uvec2(0x08081919, 0x192b1908), uvec2(0x08190808, 0x192b1919), uvec2(0x0819082b, 0x192b1919),
uvec2(0x2b081908, 0x192b1919), uvec2(0x1908082b, 0x192b2b08), uvec2(0x08080808, 0x2b080808), uvec2(0x0808082b, 0x2b080808),
uvec2(0x08082b2b, 0x2b080808), uvec2(0x19080819, 0x2b080808), uvec2(0x2b08082b, 0x2b080808), uvec2(0x08081908, 0x2b080819),
uvec2(0x08192b08, 0x2b080819), uvec2(0x19080808, 0x2b080819), uvec2(0x08190819, 0x2b08082b), uvec2(0x08080819, 0x2b081908),
uvec2(0x08081908, 0x2b081908), uvec2(0x08190808, 0x2b081908), uvec2(0x08191919, 0x2b081908), uvec2(0x19080808, 0x2b081908),
uvec2(0x192b0808, 0x2b081908), uvec2(0x08080808, 0x2b081919), uvec2(0x1908192b, 0x2b081919), uvec2(0x2b191908, 0x2b081919),
uvec2(0x08082b19, 0x2b08192b), uvec2(0x19080808, 0x2b08192b), uvec2(0x192b0808, 0x2b08192b), uvec2(0x0808082b, 0x2b082b08),
uvec2(0x08081908, 0x2b082b19), uvec2(0x08190819, 0x2b082b2b), uvec2(0x08081908, 0x2b190808), uvec2(0x08190808, 0x2b190808),
uvec2(0x082b1908, 0x2b190808), uvec2(0x19080808, 0x2b190808), uvec2(0x2b2b0819, 0x2b190808), uvec2(0x0819192b, 0x2b190819),
uvec2(0x2b080808, 0x2b190819), uvec2(0x19081919, 0x2b19082b), uvec2(0x08080808, 0x2b191908), uvec2(0x082b082b, 0x2b191908),
uvec2(0x19081908, 0x2b191908), uvec2(0x19190819, 0x2b191919), uvec2(0x2b080819, 0x2b192b08), uvec2(0x082b0808, 0x2b192b19),
uvec2(0x0808082b, 0x2b2b0808), uvec2(0x19190808, 0x2b2b0808), uvec2(0x2b081919, 0x2b2b0808), uvec2(0x08082b19, 0x2b2b0819),
uvec2(0x08080808, 0x2b2b082b), uvec2(0x08192b08, 0x2b2b1908), uvec2(0x19190808, 0x2b2b2b08), uvec2(0x08081908, 0x2b2b2b19)
};
shared uvec2 iq2xxs_grid[256];
void init_iq_shmem(uvec3 wgsize)
{
// copy the table into shared memory and sync
for (uint i = gl_LocalInvocationIndex.x; i < iq2xxs_grid.length(); i += wgsize.x) {
iq2xxs_grid[i] = iq2xxs_grid_const[i];
}
barrier();
}
#define QUANT_K QUANT_K_IQ2_XXS
#define QUANT_R QUANT_R_IQ2_XXS
#define A_TYPE block_iq2_xxs
#define A_TYPE_PACKED16 block_iq2_xxs_packed16
#endif
#define QUANT_K_IQ2_XS 256
#define QUANT_R_IQ2_XS 1
struct block_iq2_xs
{
float16_t d;
uint16_t qs[QUANT_K_IQ2_XS/8];
uint8_t scales[QUANT_K_IQ2_XS/32];
};
struct block_iq2_xs_packed16
{
float16_t d;
uint16_t qs[QUANT_K_IQ2_XS/8];
uint16_t scales[QUANT_K_IQ2_XS/64];
};
#if defined(DATA_A_IQ2_XS)
const uvec2 iq2xs_grid_const[512] = {
uvec2(0x08080808, 0x08080808), uvec2(0x0808082b, 0x08080808), uvec2(0x08081919, 0x08080808), uvec2(0x08082b08, 0x08080808),
uvec2(0x08082b2b, 0x08080808), uvec2(0x08190819, 0x08080808), uvec2(0x08191908, 0x08080808), uvec2(0x0819192b, 0x08080808),
uvec2(0x08192b19, 0x08080808), uvec2(0x082b0808, 0x08080808), uvec2(0x082b082b, 0x08080808), uvec2(0x082b1919, 0x08080808),
uvec2(0x082b2b08, 0x08080808), uvec2(0x19080819, 0x08080808), uvec2(0x19081908, 0x08080808), uvec2(0x1908192b, 0x08080808),
uvec2(0x19082b19, 0x08080808), uvec2(0x19190808, 0x08080808), uvec2(0x1919082b, 0x08080808), uvec2(0x19191919, 0x08080808),
uvec2(0x19192b08, 0x08080808), uvec2(0x192b0819, 0x08080808), uvec2(0x192b1908, 0x08080808), uvec2(0x2b080808, 0x08080808),
uvec2(0x2b08082b, 0x08080808), uvec2(0x2b081919, 0x08080808), uvec2(0x2b082b08, 0x08080808), uvec2(0x2b190819, 0x08080808),
uvec2(0x2b191908, 0x08080808), uvec2(0x2b192b19, 0x08080808), uvec2(0x2b2b0808, 0x08080808), uvec2(0x08080819, 0x08080819),
uvec2(0x08081908, 0x08080819), uvec2(0x0808192b, 0x08080819), uvec2(0x08082b19, 0x08080819), uvec2(0x08190808, 0x08080819),
uvec2(0x0819082b, 0x08080819), uvec2(0x08191919, 0x08080819), uvec2(0x08192b08, 0x08080819), uvec2(0x08192b2b, 0x08080819),
uvec2(0x082b0819, 0x08080819), uvec2(0x082b1908, 0x08080819), uvec2(0x19080808, 0x08080819), uvec2(0x1908082b, 0x08080819),
uvec2(0x19081919, 0x08080819), uvec2(0x19082b08, 0x08080819), uvec2(0x19190819, 0x08080819), uvec2(0x19191908, 0x08080819),
uvec2(0x192b0808, 0x08080819), uvec2(0x192b2b08, 0x08080819), uvec2(0x2b080819, 0x08080819), uvec2(0x2b081908, 0x08080819),
uvec2(0x2b190808, 0x08080819), uvec2(0x08080808, 0x0808082b), uvec2(0x0808082b, 0x0808082b), uvec2(0x08081919, 0x0808082b),
uvec2(0x08082b08, 0x0808082b), uvec2(0x08190819, 0x0808082b), uvec2(0x08191908, 0x0808082b), uvec2(0x082b0808, 0x0808082b),
uvec2(0x19080819, 0x0808082b), uvec2(0x19081908, 0x0808082b), uvec2(0x19190808, 0x0808082b), uvec2(0x19191919, 0x0808082b),
uvec2(0x2b080808, 0x0808082b), uvec2(0x2b082b2b, 0x0808082b), uvec2(0x08080819, 0x08081908), uvec2(0x08081908, 0x08081908),
uvec2(0x0808192b, 0x08081908), uvec2(0x08082b19, 0x08081908), uvec2(0x08190808, 0x08081908), uvec2(0x0819082b, 0x08081908),
uvec2(0x08191919, 0x08081908), uvec2(0x08192b08, 0x08081908), uvec2(0x082b0819, 0x08081908), uvec2(0x082b1908, 0x08081908),
uvec2(0x19080808, 0x08081908), uvec2(0x1908082b, 0x08081908), uvec2(0x19081919, 0x08081908), uvec2(0x19082b08, 0x08081908),
uvec2(0x19190819, 0x08081908), uvec2(0x19191908, 0x08081908), uvec2(0x1919192b, 0x08081908), uvec2(0x192b0808, 0x08081908),
uvec2(0x2b080819, 0x08081908), uvec2(0x2b081908, 0x08081908), uvec2(0x2b190808, 0x08081908), uvec2(0x08080808, 0x08081919),
uvec2(0x0808082b, 0x08081919), uvec2(0x08081919, 0x08081919), uvec2(0x08082b08, 0x08081919), uvec2(0x08190819, 0x08081919),
uvec2(0x08191908, 0x08081919), uvec2(0x082b0808, 0x08081919), uvec2(0x19080819, 0x08081919), uvec2(0x19081908, 0x08081919),
uvec2(0x19190808, 0x08081919), uvec2(0x192b0819, 0x08081919), uvec2(0x2b080808, 0x08081919), uvec2(0x08080819, 0x0808192b),
uvec2(0x08081908, 0x0808192b), uvec2(0x08190808, 0x0808192b), uvec2(0x082b192b, 0x0808192b), uvec2(0x19080808, 0x0808192b),
uvec2(0x1908082b, 0x0808192b), uvec2(0x2b081908, 0x0808192b), uvec2(0x08080808, 0x08082b08), uvec2(0x0808082b, 0x08082b08),
uvec2(0x08081919, 0x08082b08), uvec2(0x08082b08, 0x08082b08), uvec2(0x08082b2b, 0x08082b08), uvec2(0x08190819, 0x08082b08),
uvec2(0x08191908, 0x08082b08), uvec2(0x082b0808, 0x08082b08), uvec2(0x082b1919, 0x08082b08), uvec2(0x19080819, 0x08082b08),
uvec2(0x19081908, 0x08082b08), uvec2(0x19190808, 0x08082b08), uvec2(0x19192b08, 0x08082b08), uvec2(0x2b080808, 0x08082b08),
uvec2(0x2b2b0808, 0x08082b08), uvec2(0x2b2b2b2b, 0x08082b08), uvec2(0x08080819, 0x08082b19), uvec2(0x08081908, 0x08082b19),
uvec2(0x08190808, 0x08082b19), uvec2(0x19080808, 0x08082b19), uvec2(0x2b080819, 0x08082b19), uvec2(0x2b082b19, 0x08082b19),
uvec2(0x08080808, 0x08082b2b), uvec2(0x082b0808, 0x08082b2b), uvec2(0x082b2b08, 0x08082b2b), uvec2(0x2b19192b, 0x08082b2b),
uvec2(0x2b2b0808, 0x08082b2b), uvec2(0x08080819, 0x08190808), uvec2(0x08081908, 0x08190808), uvec2(0x0808192b, 0x08190808),
uvec2(0x08082b19, 0x08190808), uvec2(0x08190808, 0x08190808), uvec2(0x0819082b, 0x08190808), uvec2(0x08191919, 0x08190808),
uvec2(0x08192b08, 0x08190808), uvec2(0x082b0819, 0x08190808), uvec2(0x082b1908, 0x08190808), uvec2(0x19080808, 0x08190808),
uvec2(0x1908082b, 0x08190808), uvec2(0x19081919, 0x08190808), uvec2(0x19082b08, 0x08190808), uvec2(0x19190819, 0x08190808),
uvec2(0x19191908, 0x08190808), uvec2(0x192b0808, 0x08190808), uvec2(0x192b2b2b, 0x08190808), uvec2(0x2b080819, 0x08190808),
uvec2(0x2b081908, 0x08190808), uvec2(0x2b190808, 0x08190808), uvec2(0x08080808, 0x08190819), uvec2(0x0808082b, 0x08190819),
uvec2(0x08081919, 0x08190819), uvec2(0x08082b08, 0x08190819), uvec2(0x08190819, 0x08190819), uvec2(0x08191908, 0x08190819),
uvec2(0x082b0808, 0x08190819), uvec2(0x19080819, 0x08190819), uvec2(0x19081908, 0x08190819), uvec2(0x19190808, 0x08190819),
uvec2(0x2b080808, 0x08190819), uvec2(0x2b191908, 0x08190819), uvec2(0x2b19192b, 0x08190819), uvec2(0x08080819, 0x0819082b),
uvec2(0x08081908, 0x0819082b), uvec2(0x0808192b, 0x0819082b), uvec2(0x08190808, 0x0819082b), uvec2(0x19080808, 0x0819082b),
uvec2(0x192b0808, 0x0819082b), uvec2(0x08080808, 0x08191908), uvec2(0x0808082b, 0x08191908), uvec2(0x08081919, 0x08191908),
uvec2(0x08082b08, 0x08191908), uvec2(0x08190819, 0x08191908), uvec2(0x08191908, 0x08191908), uvec2(0x082b0808, 0x08191908),
uvec2(0x19080819, 0x08191908), uvec2(0x19081908, 0x08191908), uvec2(0x19082b19, 0x08191908), uvec2(0x19190808, 0x08191908),
uvec2(0x192b1908, 0x08191908), uvec2(0x2b080808, 0x08191908), uvec2(0x08080819, 0x08191919), uvec2(0x08081908, 0x08191919),
uvec2(0x08190808, 0x08191919), uvec2(0x19080808, 0x08191919), uvec2(0x08080808, 0x0819192b), uvec2(0x08191908, 0x0819192b),
uvec2(0x19082b19, 0x0819192b), uvec2(0x08080819, 0x08192b08), uvec2(0x08081908, 0x08192b08), uvec2(0x08190808, 0x08192b08),
uvec2(0x0819082b, 0x08192b08), uvec2(0x19080808, 0x08192b08), uvec2(0x19191908, 0x08192b08), uvec2(0x2b08192b, 0x08192b08),
uvec2(0x08080808, 0x08192b19), uvec2(0x08081919, 0x08192b19), uvec2(0x192b192b, 0x08192b19), uvec2(0x19190819, 0x08192b2b),
uvec2(0x2b2b2b19, 0x08192b2b), uvec2(0x08080808, 0x082b0808), uvec2(0x0808082b, 0x082b0808), uvec2(0x08081919, 0x082b0808),
uvec2(0x08082b08, 0x082b0808), uvec2(0x08082b2b, 0x082b0808), uvec2(0x08190819, 0x082b0808), uvec2(0x08191908, 0x082b0808),
uvec2(0x082b0808, 0x082b0808), uvec2(0x19080819, 0x082b0808), uvec2(0x19081908, 0x082b0808), uvec2(0x19190808, 0x082b0808),
uvec2(0x2b080808, 0x082b0808), uvec2(0x2b2b0808, 0x082b0808), uvec2(0x08080819, 0x082b0819), uvec2(0x08081908, 0x082b0819),
uvec2(0x08190808, 0x082b0819), uvec2(0x19080808, 0x082b0819), uvec2(0x19082b08, 0x082b0819), uvec2(0x192b1919, 0x082b0819),
uvec2(0x08080808, 0x082b082b), uvec2(0x082b082b, 0x082b082b), uvec2(0x2b080808, 0x082b082b), uvec2(0x2b2b2b08, 0x082b082b),
uvec2(0x08080819, 0x082b1908), uvec2(0x08081908, 0x082b1908), uvec2(0x08190808, 0x082b1908), uvec2(0x082b2b19, 0x082b1908),
uvec2(0x19080808, 0x082b1908), uvec2(0x08080808, 0x082b1919), uvec2(0x19080819, 0x082b1919), uvec2(0x1919082b, 0x082b1919),
uvec2(0x2b192b19, 0x082b1919), uvec2(0x08080819, 0x082b192b), uvec2(0x08192b2b, 0x082b192b), uvec2(0x2b2b192b, 0x082b192b),
uvec2(0x08080808, 0x082b2b08), uvec2(0x08082b08, 0x082b2b08), uvec2(0x08082b2b, 0x082b2b08), uvec2(0x082b0808, 0x082b2b08),
uvec2(0x19191919, 0x082b2b08), uvec2(0x2b082b08, 0x082b2b08), uvec2(0x2b2b082b, 0x082b2b08), uvec2(0x192b2b08, 0x082b2b19),
uvec2(0x2b190808, 0x082b2b19), uvec2(0x08082b08, 0x082b2b2b), uvec2(0x082b0808, 0x082b2b2b), uvec2(0x2b08082b, 0x082b2b2b),
uvec2(0x2b082b08, 0x082b2b2b), uvec2(0x2b082b2b, 0x082b2b2b), uvec2(0x08080819, 0x19080808), uvec2(0x08081908, 0x19080808),
uvec2(0x0808192b, 0x19080808), uvec2(0x08082b19, 0x19080808), uvec2(0x08190808, 0x19080808), uvec2(0x0819082b, 0x19080808),
uvec2(0x08191919, 0x19080808), uvec2(0x08192b08, 0x19080808), uvec2(0x082b0819, 0x19080808), uvec2(0x082b1908, 0x19080808),
uvec2(0x19080808, 0x19080808), uvec2(0x1908082b, 0x19080808), uvec2(0x19081919, 0x19080808), uvec2(0x19082b08, 0x19080808),
uvec2(0x19082b2b, 0x19080808), uvec2(0x19190819, 0x19080808), uvec2(0x19191908, 0x19080808), uvec2(0x192b0808, 0x19080808),
uvec2(0x192b1919, 0x19080808), uvec2(0x2b080819, 0x19080808), uvec2(0x2b081908, 0x19080808), uvec2(0x2b190808, 0x19080808),
uvec2(0x08080808, 0x19080819), uvec2(0x0808082b, 0x19080819), uvec2(0x08081919, 0x19080819), uvec2(0x08082b08, 0x19080819),
uvec2(0x08190819, 0x19080819), uvec2(0x08191908, 0x19080819), uvec2(0x082b0808, 0x19080819), uvec2(0x19080819, 0x19080819),
uvec2(0x19081908, 0x19080819), uvec2(0x19190808, 0x19080819), uvec2(0x2b080808, 0x19080819), uvec2(0x2b081919, 0x19080819),
uvec2(0x2b2b082b, 0x19080819), uvec2(0x08080819, 0x1908082b), uvec2(0x08081908, 0x1908082b), uvec2(0x08190808, 0x1908082b),
uvec2(0x0819082b, 0x1908082b), uvec2(0x082b2b19, 0x1908082b), uvec2(0x19080808, 0x1908082b), uvec2(0x08080808, 0x19081908),
uvec2(0x0808082b, 0x19081908), uvec2(0x08081919, 0x19081908), uvec2(0x08082b08, 0x19081908), uvec2(0x08190819, 0x19081908),
uvec2(0x08191908, 0x19081908), uvec2(0x08192b19, 0x19081908), uvec2(0x082b0808, 0x19081908), uvec2(0x19080819, 0x19081908),
uvec2(0x19081908, 0x19081908), uvec2(0x19190808, 0x19081908), uvec2(0x2b080808, 0x19081908), uvec2(0x2b191908, 0x19081908),
uvec2(0x08080819, 0x19081919), uvec2(0x08081908, 0x19081919), uvec2(0x08190808, 0x19081919), uvec2(0x082b1908, 0x19081919),
uvec2(0x19080808, 0x19081919), uvec2(0x2b192b2b, 0x19081919), uvec2(0x08080808, 0x1908192b), uvec2(0x08082b2b, 0x1908192b),
uvec2(0x19081908, 0x1908192b), uvec2(0x19190808, 0x1908192b), uvec2(0x08080819, 0x19082b08), uvec2(0x08081908, 0x19082b08),
uvec2(0x08190808, 0x19082b08), uvec2(0x19080808, 0x19082b08), uvec2(0x19081919, 0x19082b08), uvec2(0x19191908, 0x19082b08),
uvec2(0x192b082b, 0x19082b08), uvec2(0x08080808, 0x19082b19), uvec2(0x08190819, 0x19082b19), uvec2(0x19081908, 0x19082b19),
uvec2(0x19190808, 0x19082b19), uvec2(0x192b2b19, 0x19082b19), uvec2(0x08081908, 0x19082b2b), uvec2(0x08080808, 0x19190808),
uvec2(0x0808082b, 0x19190808), uvec2(0x08081919, 0x19190808), uvec2(0x08082b08, 0x19190808), uvec2(0x08190819, 0x19190808),
uvec2(0x08191908, 0x19190808), uvec2(0x082b0808, 0x19190808), uvec2(0x082b2b08, 0x19190808), uvec2(0x19080819, 0x19190808),
uvec2(0x19081908, 0x19190808), uvec2(0x19190808, 0x19190808), uvec2(0x2b080808, 0x19190808), uvec2(0x08080819, 0x19190819),
uvec2(0x08081908, 0x19190819), uvec2(0x08190808, 0x19190819), uvec2(0x08191919, 0x19190819), uvec2(0x19080808, 0x19190819),
uvec2(0x1908082b, 0x19190819), uvec2(0x08080808, 0x1919082b), uvec2(0x19081908, 0x1919082b), uvec2(0x2b2b2b2b, 0x1919082b),
uvec2(0x08080819, 0x19191908), uvec2(0x08081908, 0x19191908), uvec2(0x08190808, 0x19191908), uvec2(0x082b0819, 0x19191908),
uvec2(0x19080808, 0x19191908), uvec2(0x192b0808, 0x19191908), uvec2(0x2b080819, 0x19191908), uvec2(0x2b2b0819, 0x19191908),
uvec2(0x08080808, 0x19191919), uvec2(0x08082b08, 0x19191919), uvec2(0x2b080808, 0x19191919), uvec2(0x2b082b08, 0x19191919),
uvec2(0x082b0819, 0x1919192b), uvec2(0x192b2b08, 0x1919192b), uvec2(0x2b2b0819, 0x1919192b), uvec2(0x08080808, 0x19192b08),
uvec2(0x08191908, 0x19192b08), uvec2(0x19080819, 0x19192b08), uvec2(0x19190808, 0x19192b08), uvec2(0x2b192b19, 0x19192b08),
uvec2(0x08192b2b, 0x19192b19), uvec2(0x19080808, 0x19192b19), uvec2(0x1908082b, 0x19192b19), uvec2(0x2b081919, 0x19192b2b),
uvec2(0x08080819, 0x192b0808), uvec2(0x08081908, 0x192b0808), uvec2(0x08190808, 0x192b0808), uvec2(0x19080808, 0x192b0808),
uvec2(0x19191908, 0x192b0808), uvec2(0x192b082b, 0x192b0808), uvec2(0x2b08192b, 0x192b0808), uvec2(0x2b2b2b19, 0x192b0808),
uvec2(0x08080808, 0x192b0819), uvec2(0x082b1908, 0x192b082b), uvec2(0x19082b2b, 0x192b082b), uvec2(0x2b19082b, 0x192b082b),
uvec2(0x08080808, 0x192b1908), uvec2(0x0819192b, 0x192b1908), uvec2(0x08190808, 0x192b1919), uvec2(0x19080808, 0x192b1919),
uvec2(0x19081919, 0x192b1919), uvec2(0x2b2b1908, 0x192b1919), uvec2(0x08080819, 0x192b2b08), uvec2(0x192b2b2b, 0x192b2b08),
uvec2(0x082b1919, 0x192b2b19), uvec2(0x0808192b, 0x192b2b2b), uvec2(0x19191908, 0x192b2b2b), uvec2(0x192b082b, 0x192b2b2b),
uvec2(0x08080808, 0x2b080808), uvec2(0x0808082b, 0x2b080808), uvec2(0x08081919, 0x2b080808), uvec2(0x08082b08, 0x2b080808),
uvec2(0x08190819, 0x2b080808), uvec2(0x08191908, 0x2b080808), uvec2(0x082b0808, 0x2b080808), uvec2(0x082b2b2b, 0x2b080808),
uvec2(0x19080819, 0x2b080808), uvec2(0x19081908, 0x2b080808), uvec2(0x19190808, 0x2b080808), uvec2(0x2b080808, 0x2b080808),
uvec2(0x2b08082b, 0x2b080808), uvec2(0x2b2b2b08, 0x2b080808), uvec2(0x2b2b2b2b, 0x2b080808), uvec2(0x08080819, 0x2b080819),
uvec2(0x08081908, 0x2b080819), uvec2(0x0808192b, 0x2b080819), uvec2(0x08190808, 0x2b080819), uvec2(0x19080808, 0x2b080819),
uvec2(0x19190819, 0x2b080819), uvec2(0x19192b19, 0x2b080819), uvec2(0x08080808, 0x2b08082b), uvec2(0x082b0808, 0x2b08082b),
uvec2(0x2b080808, 0x2b08082b), uvec2(0x2b08082b, 0x2b08082b), uvec2(0x2b2b0808, 0x2b08082b), uvec2(0x2b2b2b08, 0x2b08082b),
uvec2(0x08080819, 0x2b081908), uvec2(0x08081908, 0x2b081908), uvec2(0x08190808, 0x2b081908), uvec2(0x0819082b, 0x2b081908),
uvec2(0x08191919, 0x2b081908), uvec2(0x19080808, 0x2b081908), uvec2(0x192b0808, 0x2b081908), uvec2(0x2b082b19, 0x2b081908),
uvec2(0x08080808, 0x2b081919), uvec2(0x19081908, 0x2b081919), uvec2(0x2b2b1919, 0x2b081919), uvec2(0x08192b08, 0x2b08192b),
uvec2(0x192b2b2b, 0x2b08192b), uvec2(0x08080808, 0x2b082b08), uvec2(0x08082b08, 0x2b082b08), uvec2(0x082b1919, 0x2b082b08),
uvec2(0x19192b2b, 0x2b082b08), uvec2(0x2b080808, 0x2b082b08), uvec2(0x2b08082b, 0x2b082b08), uvec2(0x2b2b2b08, 0x2b082b08),
uvec2(0x0808192b, 0x2b082b19), uvec2(0x082b082b, 0x2b082b2b), uvec2(0x2b080808, 0x2b082b2b), uvec2(0x2b082b08, 0x2b082b2b),
uvec2(0x2b19192b, 0x2b082b2b), uvec2(0x2b2b2b08, 0x2b082b2b), uvec2(0x08080819, 0x2b190808), uvec2(0x08081908, 0x2b190808),
uvec2(0x08190808, 0x2b190808), uvec2(0x19080808, 0x2b190808), uvec2(0x1919192b, 0x2b190808), uvec2(0x2b081908, 0x2b190808),
uvec2(0x08080808, 0x2b190819), uvec2(0x082b082b, 0x2b190819), uvec2(0x192b1908, 0x2b190819), uvec2(0x1919192b, 0x2b19082b),
uvec2(0x2b082b19, 0x2b19082b), uvec2(0x08080808, 0x2b191908), uvec2(0x08081919, 0x2b191908), uvec2(0x19081908, 0x2b191908),
uvec2(0x19190808, 0x2b191908), uvec2(0x19192b08, 0x2b191908), uvec2(0x082b2b19, 0x2b191919), uvec2(0x2b190808, 0x2b191919),
uvec2(0x2b19082b, 0x2b191919), uvec2(0x19080819, 0x2b19192b), uvec2(0x19190819, 0x2b192b08), uvec2(0x2b2b192b, 0x2b192b08),
uvec2(0x19082b19, 0x2b192b19), uvec2(0x08191919, 0x2b192b2b), uvec2(0x192b0808, 0x2b192b2b), uvec2(0x08080808, 0x2b2b0808),
uvec2(0x0808082b, 0x2b2b0808), uvec2(0x08082b08, 0x2b2b0808), uvec2(0x08082b2b, 0x2b2b0808), uvec2(0x082b0808, 0x2b2b0808),
uvec2(0x082b2b2b, 0x2b2b0808), uvec2(0x2b2b0808, 0x2b2b0808), uvec2(0x19190819, 0x2b2b0819), uvec2(0x19192b19, 0x2b2b0819),
uvec2(0x2b2b192b, 0x2b2b0819), uvec2(0x08080808, 0x2b2b082b), uvec2(0x0808082b, 0x2b2b082b), uvec2(0x08082b08, 0x2b2b082b),
uvec2(0x082b2b2b, 0x2b2b082b), uvec2(0x2b080808, 0x2b2b082b), uvec2(0x2b2b0808, 0x2b2b082b), uvec2(0x19080808, 0x2b2b1908),
uvec2(0x2b191919, 0x2b2b1908), uvec2(0x192b1919, 0x2b2b192b), uvec2(0x2b192b08, 0x2b2b192b), uvec2(0x08082b2b, 0x2b2b2b08),
uvec2(0x082b0808, 0x2b2b2b08), uvec2(0x082b082b, 0x2b2b2b08), uvec2(0x082b2b08, 0x2b2b2b08), uvec2(0x2b2b0808, 0x2b2b2b08),
uvec2(0x2b2b2b08, 0x2b2b2b08), uvec2(0x08081908, 0x2b2b2b19), uvec2(0x2b081908, 0x2b2b2b19), uvec2(0x2b08192b, 0x2b2b2b19),
uvec2(0x082b2b08, 0x2b2b2b2b), uvec2(0x082b2b2b, 0x2b2b2b2b), uvec2(0x2b190819, 0x2b2b2b2b), uvec2(0x2b2b2b2b, 0x2b2b2b2b),
};
shared uvec2 iq2xs_grid[512];
void init_iq_shmem(uvec3 wgsize)
{
// copy the table into shared memory and sync
for (uint i = gl_LocalInvocationIndex.x; i < iq2xs_grid.length(); i += wgsize.x) {
iq2xs_grid[i] = iq2xs_grid_const[i];
}
barrier();
}
#define QUANT_K QUANT_K_IQ2_XS
#define QUANT_R QUANT_R_IQ2_XS
#define A_TYPE block_iq2_xs
#define A_TYPE_PACKED16 block_iq2_xs_packed16
#endif
#define QUANT_K_IQ2_S 256
#define QUANT_R_IQ2_S 1
struct block_iq2_s
{
float16_t d;
uint8_t qs[QUANT_K_IQ2_S/4];
uint8_t qh[QUANT_K_IQ2_S/32];
uint8_t scales[QUANT_K_IQ2_S/32];
};
#if defined(DATA_A_IQ2_S)
const uvec2 iq2s_grid_const[1024] = {
uvec2(0x08080808, 0x08080808), uvec2(0x0808082b, 0x08080808), uvec2(0x08081919, 0x08080808), uvec2(0x08082b08, 0x08080808),
uvec2(0x08082b2b, 0x08080808), uvec2(0x08190819, 0x08080808), uvec2(0x08191908, 0x08080808), uvec2(0x0819192b, 0x08080808),
uvec2(0x08192b19, 0x08080808), uvec2(0x082b0808, 0x08080808), uvec2(0x082b082b, 0x08080808), uvec2(0x082b1919, 0x08080808),
uvec2(0x082b2b08, 0x08080808), uvec2(0x19080819, 0x08080808), uvec2(0x19081908, 0x08080808), uvec2(0x1908192b, 0x08080808),
uvec2(0x19082b19, 0x08080808), uvec2(0x19190808, 0x08080808), uvec2(0x1919082b, 0x08080808), uvec2(0x19191919, 0x08080808),
uvec2(0x19192b08, 0x08080808), uvec2(0x192b0819, 0x08080808), uvec2(0x192b1908, 0x08080808), uvec2(0x192b192b, 0x08080808),
uvec2(0x192b2b19, 0x08080808), uvec2(0x2b080808, 0x08080808), uvec2(0x2b08082b, 0x08080808), uvec2(0x2b081919, 0x08080808),
uvec2(0x2b082b08, 0x08080808), uvec2(0x2b190819, 0x08080808), uvec2(0x2b191908, 0x08080808), uvec2(0x2b2b0808, 0x08080808),
uvec2(0x2b2b1919, 0x08080808), uvec2(0x2b2b2b2b, 0x08080808), uvec2(0x08080819, 0x08080819), uvec2(0x08081908, 0x08080819),
uvec2(0x0808192b, 0x08080819), uvec2(0x08082b19, 0x08080819), uvec2(0x08190808, 0x08080819), uvec2(0x0819082b, 0x08080819),
uvec2(0x08191919, 0x08080819), uvec2(0x08192b08, 0x08080819), uvec2(0x082b0819, 0x08080819), uvec2(0x082b1908, 0x08080819),
uvec2(0x19080808, 0x08080819), uvec2(0x1908082b, 0x08080819), uvec2(0x19081919, 0x08080819), uvec2(0x19082b08, 0x08080819),
uvec2(0x19190819, 0x08080819), uvec2(0x19191908, 0x08080819), uvec2(0x1919192b, 0x08080819), uvec2(0x19192b19, 0x08080819),
uvec2(0x192b0808, 0x08080819), uvec2(0x192b1919, 0x08080819), uvec2(0x192b2b08, 0x08080819), uvec2(0x2b080819, 0x08080819),
uvec2(0x2b081908, 0x08080819), uvec2(0x2b190808, 0x08080819), uvec2(0x2b19082b, 0x08080819), uvec2(0x2b191919, 0x08080819),
uvec2(0x2b2b0819, 0x08080819), uvec2(0x2b2b1908, 0x08080819), uvec2(0x08080808, 0x0808082b), uvec2(0x0808082b, 0x0808082b),
uvec2(0x08081919, 0x0808082b), uvec2(0x08082b08, 0x0808082b), uvec2(0x08190819, 0x0808082b), uvec2(0x08191908, 0x0808082b),
uvec2(0x082b0808, 0x0808082b), uvec2(0x082b2b2b, 0x0808082b), uvec2(0x19080819, 0x0808082b), uvec2(0x19081908, 0x0808082b),
uvec2(0x1908192b, 0x0808082b), uvec2(0x19082b19, 0x0808082b), uvec2(0x19190808, 0x0808082b), uvec2(0x19191919, 0x0808082b),
uvec2(0x2b080808, 0x0808082b), uvec2(0x2b081919, 0x0808082b), uvec2(0x2b082b2b, 0x0808082b), uvec2(0x2b191908, 0x0808082b),
uvec2(0x2b2b082b, 0x0808082b), uvec2(0x08080819, 0x08081908), uvec2(0x08081908, 0x08081908), uvec2(0x0808192b, 0x08081908),
uvec2(0x08082b19, 0x08081908), uvec2(0x08190808, 0x08081908), uvec2(0x0819082b, 0x08081908), uvec2(0x08191919, 0x08081908),
uvec2(0x08192b08, 0x08081908), uvec2(0x082b0819, 0x08081908), uvec2(0x082b1908, 0x08081908), uvec2(0x082b192b, 0x08081908),
uvec2(0x082b2b19, 0x08081908), uvec2(0x19080808, 0x08081908), uvec2(0x1908082b, 0x08081908), uvec2(0x19081919, 0x08081908),
uvec2(0x19082b08, 0x08081908), uvec2(0x19082b2b, 0x08081908), uvec2(0x19190819, 0x08081908), uvec2(0x19191908, 0x08081908),
uvec2(0x1919192b, 0x08081908), uvec2(0x19192b19, 0x08081908), uvec2(0x192b0808, 0x08081908), uvec2(0x192b082b, 0x08081908),
uvec2(0x192b1919, 0x08081908), uvec2(0x2b080819, 0x08081908), uvec2(0x2b081908, 0x08081908), uvec2(0x2b08192b, 0x08081908),
uvec2(0x2b082b19, 0x08081908), uvec2(0x2b190808, 0x08081908), uvec2(0x2b191919, 0x08081908), uvec2(0x2b192b08, 0x08081908),
uvec2(0x2b2b0819, 0x08081908), uvec2(0x2b2b1908, 0x08081908), uvec2(0x08080808, 0x08081919), uvec2(0x0808082b, 0x08081919),
uvec2(0x08081919, 0x08081919), uvec2(0x08082b08, 0x08081919), uvec2(0x08082b2b, 0x08081919), uvec2(0x08190819, 0x08081919),
uvec2(0x08191908, 0x08081919), uvec2(0x0819192b, 0x08081919), uvec2(0x08192b19, 0x08081919), uvec2(0x082b0808, 0x08081919),
uvec2(0x082b1919, 0x08081919), uvec2(0x082b2b08, 0x08081919), uvec2(0x19080819, 0x08081919), uvec2(0x19081908, 0x08081919),
uvec2(0x1908192b, 0x08081919), uvec2(0x19082b19, 0x08081919), uvec2(0x19190808, 0x08081919), uvec2(0x1919082b, 0x08081919),
uvec2(0x19191919, 0x08081919), uvec2(0x19192b08, 0x08081919), uvec2(0x192b0819, 0x08081919), uvec2(0x192b1908, 0x08081919),
uvec2(0x2b080808, 0x08081919), uvec2(0x2b08082b, 0x08081919), uvec2(0x2b081919, 0x08081919), uvec2(0x2b082b08, 0x08081919),
uvec2(0x2b190819, 0x08081919), uvec2(0x2b191908, 0x08081919), uvec2(0x2b2b0808, 0x08081919), uvec2(0x08080819, 0x0808192b),
uvec2(0x08081908, 0x0808192b), uvec2(0x0808192b, 0x0808192b), uvec2(0x08082b19, 0x0808192b), uvec2(0x08190808, 0x0808192b),
uvec2(0x08191919, 0x0808192b), uvec2(0x19080808, 0x0808192b), uvec2(0x19081919, 0x0808192b), uvec2(0x19082b08, 0x0808192b),
uvec2(0x19190819, 0x0808192b), uvec2(0x19191908, 0x0808192b), uvec2(0x192b0808, 0x0808192b), uvec2(0x2b080819, 0x0808192b),
uvec2(0x2b081908, 0x0808192b), uvec2(0x2b190808, 0x0808192b), uvec2(0x08080808, 0x08082b08), uvec2(0x0808082b, 0x08082b08),
uvec2(0x08081919, 0x08082b08), uvec2(0x08082b08, 0x08082b08), uvec2(0x08190819, 0x08082b08), uvec2(0x08191908, 0x08082b08),
uvec2(0x0819192b, 0x08082b08), uvec2(0x08192b19, 0x08082b08), uvec2(0x082b0808, 0x08082b08), uvec2(0x082b1919, 0x08082b08),
uvec2(0x082b2b2b, 0x08082b08), uvec2(0x19080819, 0x08082b08), uvec2(0x19081908, 0x08082b08), uvec2(0x1908192b, 0x08082b08),
uvec2(0x19082b19, 0x08082b08), uvec2(0x19190808, 0x08082b08), uvec2(0x1919082b, 0x08082b08), uvec2(0x19191919, 0x08082b08),
uvec2(0x19192b08, 0x08082b08), uvec2(0x192b0819, 0x08082b08), uvec2(0x192b1908, 0x08082b08), uvec2(0x2b080808, 0x08082b08),
uvec2(0x2b081919, 0x08082b08), uvec2(0x2b191908, 0x08082b08), uvec2(0x2b2b2b2b, 0x08082b08), uvec2(0x08080819, 0x08082b19),
uvec2(0x08081908, 0x08082b19), uvec2(0x08190808, 0x08082b19), uvec2(0x0819082b, 0x08082b19), uvec2(0x08191919, 0x08082b19),
uvec2(0x08192b08, 0x08082b19), uvec2(0x082b0819, 0x08082b19), uvec2(0x19080808, 0x08082b19), uvec2(0x19081919, 0x08082b19),
uvec2(0x19082b08, 0x08082b19), uvec2(0x19190819, 0x08082b19), uvec2(0x19191908, 0x08082b19), uvec2(0x192b0808, 0x08082b19),
uvec2(0x2b080819, 0x08082b19), uvec2(0x2b190808, 0x08082b19), uvec2(0x08080808, 0x08082b2b), uvec2(0x08190819, 0x08082b2b),
uvec2(0x08191908, 0x08082b2b), uvec2(0x082b082b, 0x08082b2b), uvec2(0x082b2b08, 0x08082b2b), uvec2(0x082b2b2b, 0x08082b2b),
uvec2(0x19190808, 0x08082b2b), uvec2(0x2b192b19, 0x08082b2b), uvec2(0x08080819, 0x08190808), uvec2(0x08081908, 0x08190808),
uvec2(0x0808192b, 0x08190808), uvec2(0x08082b19, 0x08190808), uvec2(0x08190808, 0x08190808), uvec2(0x0819082b, 0x08190808),
uvec2(0x08191919, 0x08190808), uvec2(0x08192b08, 0x08190808), uvec2(0x082b0819, 0x08190808), uvec2(0x082b1908, 0x08190808),
uvec2(0x082b192b, 0x08190808), uvec2(0x19080808, 0x08190808), uvec2(0x1908082b, 0x08190808), uvec2(0x19081919, 0x08190808),
uvec2(0x19082b08, 0x08190808), uvec2(0x19190819, 0x08190808), uvec2(0x19191908, 0x08190808), uvec2(0x1919192b, 0x08190808),
uvec2(0x19192b19, 0x08190808), uvec2(0x192b0808, 0x08190808), uvec2(0x192b082b, 0x08190808), uvec2(0x192b1919, 0x08190808),
uvec2(0x192b2b08, 0x08190808), uvec2(0x2b080819, 0x08190808), uvec2(0x2b081908, 0x08190808), uvec2(0x2b08192b, 0x08190808),
uvec2(0x2b190808, 0x08190808), uvec2(0x2b191919, 0x08190808), uvec2(0x2b192b08, 0x08190808), uvec2(0x2b2b0819, 0x08190808),
uvec2(0x2b2b1908, 0x08190808), uvec2(0x08080808, 0x08190819), uvec2(0x0808082b, 0x08190819), uvec2(0x08081919, 0x08190819),
uvec2(0x08082b08, 0x08190819), uvec2(0x08082b2b, 0x08190819), uvec2(0x08190819, 0x08190819), uvec2(0x08191908, 0x08190819),
uvec2(0x0819192b, 0x08190819), uvec2(0x08192b19, 0x08190819), uvec2(0x082b0808, 0x08190819), uvec2(0x082b082b, 0x08190819),
uvec2(0x082b1919, 0x08190819), uvec2(0x082b2b08, 0x08190819), uvec2(0x19080819, 0x08190819), uvec2(0x19081908, 0x08190819),
uvec2(0x1908192b, 0x08190819), uvec2(0x19082b19, 0x08190819), uvec2(0x19190808, 0x08190819), uvec2(0x1919082b, 0x08190819),
uvec2(0x19191919, 0x08190819), uvec2(0x19192b08, 0x08190819), uvec2(0x192b0819, 0x08190819), uvec2(0x192b1908, 0x08190819),
uvec2(0x2b080808, 0x08190819), uvec2(0x2b08082b, 0x08190819), uvec2(0x2b081919, 0x08190819), uvec2(0x2b082b08, 0x08190819),
uvec2(0x2b190819, 0x08190819), uvec2(0x2b191908, 0x08190819), uvec2(0x08080819, 0x0819082b), uvec2(0x08081908, 0x0819082b),
uvec2(0x08082b19, 0x0819082b), uvec2(0x08190808, 0x0819082b), uvec2(0x08191919, 0x0819082b), uvec2(0x082b0819, 0x0819082b),
uvec2(0x082b1908, 0x0819082b), uvec2(0x19080808, 0x0819082b), uvec2(0x19081919, 0x0819082b), uvec2(0x19190819, 0x0819082b),
uvec2(0x19191908, 0x0819082b), uvec2(0x2b080819, 0x0819082b), uvec2(0x2b081908, 0x0819082b), uvec2(0x2b190808, 0x0819082b),
uvec2(0x08080808, 0x08191908), uvec2(0x0808082b, 0x08191908), uvec2(0x08081919, 0x08191908), uvec2(0x08082b08, 0x08191908),
uvec2(0x08190819, 0x08191908), uvec2(0x08191908, 0x08191908), uvec2(0x0819192b, 0x08191908), uvec2(0x08192b19, 0x08191908),
uvec2(0x082b0808, 0x08191908), uvec2(0x082b1919, 0x08191908), uvec2(0x082b2b08, 0x08191908), uvec2(0x19080819, 0x08191908),
uvec2(0x19081908, 0x08191908), uvec2(0x1908192b, 0x08191908), uvec2(0x19082b19, 0x08191908), uvec2(0x19190808, 0x08191908),
uvec2(0x1919082b, 0x08191908), uvec2(0x19191919, 0x08191908), uvec2(0x19192b08, 0x08191908), uvec2(0x192b0819, 0x08191908),
uvec2(0x192b1908, 0x08191908), uvec2(0x2b080808, 0x08191908), uvec2(0x2b08082b, 0x08191908), uvec2(0x2b081919, 0x08191908),
uvec2(0x2b082b08, 0x08191908), uvec2(0x2b190819, 0x08191908), uvec2(0x2b191908, 0x08191908), uvec2(0x2b2b0808, 0x08191908),
uvec2(0x08080819, 0x08191919), uvec2(0x08081908, 0x08191919), uvec2(0x0808192b, 0x08191919), uvec2(0x08082b19, 0x08191919),
uvec2(0x08190808, 0x08191919), uvec2(0x0819082b, 0x08191919), uvec2(0x08191919, 0x08191919), uvec2(0x08192b08, 0x08191919),
uvec2(0x082b0819, 0x08191919), uvec2(0x082b1908, 0x08191919), uvec2(0x19080808, 0x08191919), uvec2(0x1908082b, 0x08191919),
uvec2(0x19081919, 0x08191919), uvec2(0x19082b08, 0x08191919), uvec2(0x19190819, 0x08191919), uvec2(0x19191908, 0x08191919),
uvec2(0x192b0808, 0x08191919), uvec2(0x2b080819, 0x08191919), uvec2(0x2b081908, 0x08191919), uvec2(0x2b190808, 0x08191919),
uvec2(0x08080808, 0x0819192b), uvec2(0x08081919, 0x0819192b), uvec2(0x08082b08, 0x0819192b), uvec2(0x08190819, 0x0819192b),
uvec2(0x08191908, 0x0819192b), uvec2(0x082b0808, 0x0819192b), uvec2(0x19080819, 0x0819192b), uvec2(0x19081908, 0x0819192b),
uvec2(0x19190808, 0x0819192b), uvec2(0x2b080808, 0x0819192b), uvec2(0x2b2b2b2b, 0x0819192b), uvec2(0x08080819, 0x08192b08),
uvec2(0x08081908, 0x08192b08), uvec2(0x0808192b, 0x08192b08), uvec2(0x08082b19, 0x08192b08), uvec2(0x08190808, 0x08192b08),
uvec2(0x08191919, 0x08192b08), uvec2(0x08192b08, 0x08192b08), uvec2(0x082b0819, 0x08192b08), uvec2(0x19080808, 0x08192b08),
uvec2(0x1908082b, 0x08192b08), uvec2(0x19081919, 0x08192b08), uvec2(0x19082b08, 0x08192b08), uvec2(0x19190819, 0x08192b08),
uvec2(0x19191908, 0x08192b08), uvec2(0x192b0808, 0x08192b08), uvec2(0x2b080819, 0x08192b08), uvec2(0x2b081908, 0x08192b08),
uvec2(0x08080808, 0x08192b19), uvec2(0x0808082b, 0x08192b19), uvec2(0x08081919, 0x08192b19), uvec2(0x08082b08, 0x08192b19),
uvec2(0x08190819, 0x08192b19), uvec2(0x08191908, 0x08192b19), uvec2(0x082b0808, 0x08192b19), uvec2(0x19080819, 0x08192b19),
uvec2(0x19081908, 0x08192b19), uvec2(0x19190808, 0x08192b19), uvec2(0x192b2b19, 0x08192b19), uvec2(0x2b2b082b, 0x08192b19),
uvec2(0x08081908, 0x08192b2b), uvec2(0x08190808, 0x08192b2b), uvec2(0x19080808, 0x08192b2b), uvec2(0x1919192b, 0x08192b2b),
uvec2(0x08080808, 0x082b0808), uvec2(0x0808082b, 0x082b0808), uvec2(0x08081919, 0x082b0808), uvec2(0x08082b08, 0x082b0808),
uvec2(0x08190819, 0x082b0808), uvec2(0x08191908, 0x082b0808), uvec2(0x0819192b, 0x082b0808), uvec2(0x08192b19, 0x082b0808),
uvec2(0x082b0808, 0x082b0808), uvec2(0x082b1919, 0x082b0808), uvec2(0x082b2b2b, 0x082b0808), uvec2(0x19080819, 0x082b0808),
uvec2(0x19081908, 0x082b0808), uvec2(0x19190808, 0x082b0808), uvec2(0x1919082b, 0x082b0808), uvec2(0x19191919, 0x082b0808),
uvec2(0x192b1908, 0x082b0808), uvec2(0x2b080808, 0x082b0808), uvec2(0x2b082b2b, 0x082b0808), uvec2(0x2b191908, 0x082b0808),
uvec2(0x2b2b2b2b, 0x082b0808), uvec2(0x08080819, 0x082b0819), uvec2(0x08081908, 0x082b0819), uvec2(0x08190808, 0x082b0819),
uvec2(0x0819082b, 0x082b0819), uvec2(0x08191919, 0x082b0819), uvec2(0x082b0819, 0x082b0819), uvec2(0x19080808, 0x082b0819),
uvec2(0x1908082b, 0x082b0819), uvec2(0x19081919, 0x082b0819), uvec2(0x19190819, 0x082b0819), uvec2(0x19191908, 0x082b0819),
uvec2(0x192b0808, 0x082b0819), uvec2(0x2b080819, 0x082b0819), uvec2(0x2b081908, 0x082b0819), uvec2(0x2b190808, 0x082b0819),
uvec2(0x08080808, 0x082b082b), uvec2(0x08082b2b, 0x082b082b), uvec2(0x082b082b, 0x082b082b), uvec2(0x082b2b08, 0x082b082b),
uvec2(0x082b2b2b, 0x082b082b), uvec2(0x19081908, 0x082b082b), uvec2(0x19190808, 0x082b082b), uvec2(0x2b082b08, 0x082b082b),
uvec2(0x2b082b2b, 0x082b082b), uvec2(0x2b2b2b08, 0x082b082b), uvec2(0x08080819, 0x082b1908), uvec2(0x08081908, 0x082b1908),
uvec2(0x0808192b, 0x082b1908), uvec2(0x08082b19, 0x082b1908), uvec2(0x08190808, 0x082b1908), uvec2(0x08191919, 0x082b1908),
uvec2(0x08192b08, 0x082b1908), uvec2(0x082b0819, 0x082b1908), uvec2(0x082b1908, 0x082b1908), uvec2(0x19080808, 0x082b1908),
uvec2(0x1908082b, 0x082b1908), uvec2(0x19081919, 0x082b1908), uvec2(0x19082b08, 0x082b1908), uvec2(0x19190819, 0x082b1908),
uvec2(0x19191908, 0x082b1908), uvec2(0x192b0808, 0x082b1908), uvec2(0x2b080819, 0x082b1908), uvec2(0x2b081908, 0x082b1908),
uvec2(0x2b190808, 0x082b1908), uvec2(0x08080808, 0x082b1919), uvec2(0x08081919, 0x082b1919), uvec2(0x08082b08, 0x082b1919),
uvec2(0x08190819, 0x082b1919), uvec2(0x08191908, 0x082b1919), uvec2(0x082b0808, 0x082b1919), uvec2(0x19080819, 0x082b1919),
uvec2(0x19081908, 0x082b1919), uvec2(0x19190808, 0x082b1919), uvec2(0x192b192b, 0x082b1919), uvec2(0x2b080808, 0x082b1919),
uvec2(0x08080819, 0x082b192b), uvec2(0x08081908, 0x082b192b), uvec2(0x08190808, 0x082b192b), uvec2(0x19080808, 0x082b192b),
uvec2(0x19192b19, 0x082b192b), uvec2(0x08080808, 0x082b2b08), uvec2(0x08081919, 0x082b2b08), uvec2(0x08190819, 0x082b2b08),
uvec2(0x08191908, 0x082b2b08), uvec2(0x19080819, 0x082b2b08), uvec2(0x19081908, 0x082b2b08), uvec2(0x19190808, 0x082b2b08),
uvec2(0x2b082b2b, 0x082b2b08), uvec2(0x2b2b2b2b, 0x082b2b08), uvec2(0x08080819, 0x082b2b19), uvec2(0x08081908, 0x082b2b19),
uvec2(0x08190808, 0x082b2b19), uvec2(0x2b191919, 0x082b2b19), uvec2(0x08082b2b, 0x082b2b2b), uvec2(0x082b082b, 0x082b2b2b),
uvec2(0x192b1908, 0x082b2b2b), uvec2(0x2b082b08, 0x082b2b2b), uvec2(0x2b082b2b, 0x082b2b2b), uvec2(0x08080819, 0x19080808),
uvec2(0x08081908, 0x19080808), uvec2(0x0808192b, 0x19080808), uvec2(0x08082b19, 0x19080808), uvec2(0x08190808, 0x19080808),
uvec2(0x0819082b, 0x19080808), uvec2(0x08191919, 0x19080808), uvec2(0x08192b08, 0x19080808), uvec2(0x08192b2b, 0x19080808),
uvec2(0x082b0819, 0x19080808), uvec2(0x082b1908, 0x19080808), uvec2(0x082b192b, 0x19080808), uvec2(0x19080808, 0x19080808),
uvec2(0x1908082b, 0x19080808), uvec2(0x19081919, 0x19080808), uvec2(0x19082b08, 0x19080808), uvec2(0x19082b2b, 0x19080808),
uvec2(0x19190819, 0x19080808), uvec2(0x19191908, 0x19080808), uvec2(0x1919192b, 0x19080808), uvec2(0x19192b19, 0x19080808),
uvec2(0x192b0808, 0x19080808), uvec2(0x192b082b, 0x19080808), uvec2(0x192b1919, 0x19080808), uvec2(0x2b080819, 0x19080808),
uvec2(0x2b081908, 0x19080808), uvec2(0x2b190808, 0x19080808), uvec2(0x2b191919, 0x19080808), uvec2(0x2b192b08, 0x19080808),
uvec2(0x2b2b0819, 0x19080808), uvec2(0x2b2b1908, 0x19080808), uvec2(0x08080808, 0x19080819), uvec2(0x0808082b, 0x19080819),
uvec2(0x08081919, 0x19080819), uvec2(0x08082b08, 0x19080819), uvec2(0x08190819, 0x19080819), uvec2(0x08191908, 0x19080819),
uvec2(0x0819192b, 0x19080819), uvec2(0x08192b19, 0x19080819), uvec2(0x082b0808, 0x19080819), uvec2(0x082b082b, 0x19080819),
uvec2(0x082b1919, 0x19080819), uvec2(0x19080819, 0x19080819), uvec2(0x19081908, 0x19080819), uvec2(0x1908192b, 0x19080819),
uvec2(0x19082b19, 0x19080819), uvec2(0x19190808, 0x19080819), uvec2(0x1919082b, 0x19080819), uvec2(0x19191919, 0x19080819),
uvec2(0x19192b08, 0x19080819), uvec2(0x192b0819, 0x19080819), uvec2(0x192b1908, 0x19080819), uvec2(0x2b080808, 0x19080819),
uvec2(0x2b08082b, 0x19080819), uvec2(0x2b081919, 0x19080819), uvec2(0x2b082b08, 0x19080819), uvec2(0x2b190819, 0x19080819),
uvec2(0x2b191908, 0x19080819), uvec2(0x2b2b0808, 0x19080819), uvec2(0x08080819, 0x1908082b), uvec2(0x08081908, 0x1908082b),
uvec2(0x08190808, 0x1908082b), uvec2(0x0819082b, 0x1908082b), uvec2(0x08191919, 0x1908082b), uvec2(0x08192b08, 0x1908082b),
uvec2(0x082b1908, 0x1908082b), uvec2(0x19080808, 0x1908082b), uvec2(0x19081919, 0x1908082b), uvec2(0x19082b08, 0x1908082b),
uvec2(0x19190819, 0x1908082b), uvec2(0x19191908, 0x1908082b), uvec2(0x192b0808, 0x1908082b), uvec2(0x2b080819, 0x1908082b),
uvec2(0x2b081908, 0x1908082b), uvec2(0x08080808, 0x19081908), uvec2(0x0808082b, 0x19081908), uvec2(0x08081919, 0x19081908),
uvec2(0x08082b08, 0x19081908), uvec2(0x08082b2b, 0x19081908), uvec2(0x08190819, 0x19081908), uvec2(0x08191908, 0x19081908),
uvec2(0x0819192b, 0x19081908), uvec2(0x08192b19, 0x19081908), uvec2(0x082b0808, 0x19081908), uvec2(0x082b082b, 0x19081908),
uvec2(0x082b1919, 0x19081908), uvec2(0x082b2b08, 0x19081908), uvec2(0x19080819, 0x19081908), uvec2(0x19081908, 0x19081908),
uvec2(0x1908192b, 0x19081908), uvec2(0x19082b19, 0x19081908), uvec2(0x19190808, 0x19081908), uvec2(0x1919082b, 0x19081908),
uvec2(0x19191919, 0x19081908), uvec2(0x19192b08, 0x19081908), uvec2(0x192b0819, 0x19081908), uvec2(0x192b1908, 0x19081908),
uvec2(0x2b080808, 0x19081908), uvec2(0x2b08082b, 0x19081908), uvec2(0x2b081919, 0x19081908), uvec2(0x2b082b08, 0x19081908),
uvec2(0x2b190819, 0x19081908), uvec2(0x2b191908, 0x19081908), uvec2(0x2b2b0808, 0x19081908), uvec2(0x08080819, 0x19081919),
uvec2(0x08081908, 0x19081919), uvec2(0x0808192b, 0x19081919), uvec2(0x08082b19, 0x19081919), uvec2(0x08190808, 0x19081919),
uvec2(0x0819082b, 0x19081919), uvec2(0x08191919, 0x19081919), uvec2(0x08192b08, 0x19081919), uvec2(0x082b0819, 0x19081919),
uvec2(0x082b1908, 0x19081919), uvec2(0x19080808, 0x19081919), uvec2(0x1908082b, 0x19081919), uvec2(0x19081919, 0x19081919),
uvec2(0x19082b08, 0x19081919), uvec2(0x19190819, 0x19081919), uvec2(0x19191908, 0x19081919), uvec2(0x192b0808, 0x19081919),
uvec2(0x192b2b2b, 0x19081919), uvec2(0x2b080819, 0x19081919), uvec2(0x2b081908, 0x19081919), uvec2(0x2b190808, 0x19081919),
uvec2(0x08080808, 0x1908192b), uvec2(0x0808082b, 0x1908192b), uvec2(0x08081919, 0x1908192b), uvec2(0x08082b08, 0x1908192b),
uvec2(0x08190819, 0x1908192b), uvec2(0x08191908, 0x1908192b), uvec2(0x082b0808, 0x1908192b), uvec2(0x19080819, 0x1908192b),
uvec2(0x19081908, 0x1908192b), uvec2(0x19190808, 0x1908192b), uvec2(0x2b080808, 0x1908192b), uvec2(0x2b2b1919, 0x1908192b),
uvec2(0x08080819, 0x19082b08), uvec2(0x08081908, 0x19082b08), uvec2(0x08082b19, 0x19082b08), uvec2(0x08190808, 0x19082b08),
uvec2(0x0819082b, 0x19082b08), uvec2(0x08191919, 0x19082b08), uvec2(0x08192b08, 0x19082b08), uvec2(0x082b0819, 0x19082b08),
uvec2(0x082b1908, 0x19082b08), uvec2(0x19080808, 0x19082b08), uvec2(0x1908082b, 0x19082b08), uvec2(0x19081919, 0x19082b08),
uvec2(0x19082b08, 0x19082b08), uvec2(0x19190819, 0x19082b08), uvec2(0x19191908, 0x19082b08), uvec2(0x192b0808, 0x19082b08),
uvec2(0x2b081908, 0x19082b08), uvec2(0x2b190808, 0x19082b08), uvec2(0x08080808, 0x19082b19), uvec2(0x0808082b, 0x19082b19),
uvec2(0x08081919, 0x19082b19), uvec2(0x08082b08, 0x19082b19), uvec2(0x08190819, 0x19082b19), uvec2(0x08191908, 0x19082b19),
uvec2(0x082b0808, 0x19082b19), uvec2(0x19080819, 0x19082b19), uvec2(0x19081908, 0x19082b19), uvec2(0x19190808, 0x19082b19),
uvec2(0x2b080808, 0x19082b19), uvec2(0x2b19192b, 0x19082b19), uvec2(0x08080819, 0x19082b2b), uvec2(0x08081908, 0x19082b2b),
uvec2(0x08190808, 0x19082b2b), uvec2(0x19080808, 0x19082b2b), uvec2(0x08080808, 0x19190808), uvec2(0x0808082b, 0x19190808),
uvec2(0x08081919, 0x19190808), uvec2(0x08082b08, 0x19190808), uvec2(0x08190819, 0x19190808), uvec2(0x08191908, 0x19190808),
uvec2(0x0819192b, 0x19190808), uvec2(0x08192b19, 0x19190808), uvec2(0x082b0808, 0x19190808), uvec2(0x082b082b, 0x19190808),
uvec2(0x082b1919, 0x19190808), uvec2(0x082b2b08, 0x19190808), uvec2(0x19080819, 0x19190808), uvec2(0x19081908, 0x19190808),
uvec2(0x1908192b, 0x19190808), uvec2(0x19082b19, 0x19190808), uvec2(0x19190808, 0x19190808), uvec2(0x1919082b, 0x19190808),
uvec2(0x19191919, 0x19190808), uvec2(0x19192b08, 0x19190808), uvec2(0x192b0819, 0x19190808), uvec2(0x192b1908, 0x19190808),
uvec2(0x2b080808, 0x19190808), uvec2(0x2b08082b, 0x19190808), uvec2(0x2b081919, 0x19190808), uvec2(0x2b082b08, 0x19190808),
uvec2(0x2b190819, 0x19190808), uvec2(0x2b191908, 0x19190808), uvec2(0x08080819, 0x19190819), uvec2(0x08081908, 0x19190819),
uvec2(0x0808192b, 0x19190819), uvec2(0x08082b19, 0x19190819), uvec2(0x08190808, 0x19190819), uvec2(0x0819082b, 0x19190819),
uvec2(0x08191919, 0x19190819), uvec2(0x08192b08, 0x19190819), uvec2(0x082b0819, 0x19190819), uvec2(0x082b1908, 0x19190819),
uvec2(0x19080808, 0x19190819), uvec2(0x1908082b, 0x19190819), uvec2(0x19081919, 0x19190819), uvec2(0x19082b08, 0x19190819),
uvec2(0x19190819, 0x19190819), uvec2(0x19191908, 0x19190819), uvec2(0x192b0808, 0x19190819), uvec2(0x2b080819, 0x19190819),
uvec2(0x2b081908, 0x19190819), uvec2(0x2b190808, 0x19190819), uvec2(0x08080808, 0x1919082b), uvec2(0x08081919, 0x1919082b),
uvec2(0x08082b08, 0x1919082b), uvec2(0x08190819, 0x1919082b), uvec2(0x08191908, 0x1919082b), uvec2(0x082b0808, 0x1919082b),
uvec2(0x19080819, 0x1919082b), uvec2(0x19081908, 0x1919082b), uvec2(0x19190808, 0x1919082b), uvec2(0x192b2b19, 0x1919082b),
uvec2(0x2b080808, 0x1919082b), uvec2(0x08080819, 0x19191908), uvec2(0x08081908, 0x19191908), uvec2(0x0808192b, 0x19191908),
uvec2(0x08082b19, 0x19191908), uvec2(0x08190808, 0x19191908), uvec2(0x0819082b, 0x19191908), uvec2(0x08191919, 0x19191908),
uvec2(0x08192b08, 0x19191908), uvec2(0x082b0819, 0x19191908), uvec2(0x082b1908, 0x19191908), uvec2(0x19080808, 0x19191908),
uvec2(0x1908082b, 0x19191908), uvec2(0x19081919, 0x19191908), uvec2(0x19082b08, 0x19191908), uvec2(0x19190819, 0x19191908),
uvec2(0x19191908, 0x19191908), uvec2(0x192b0808, 0x19191908), uvec2(0x2b080819, 0x19191908), uvec2(0x2b081908, 0x19191908),
uvec2(0x2b190808, 0x19191908), uvec2(0x08080808, 0x19191919), uvec2(0x0808082b, 0x19191919), uvec2(0x08081919, 0x19191919),
uvec2(0x08082b08, 0x19191919), uvec2(0x08190819, 0x19191919), uvec2(0x08191908, 0x19191919), uvec2(0x082b0808, 0x19191919),
uvec2(0x19080819, 0x19191919), uvec2(0x19081908, 0x19191919), uvec2(0x19190808, 0x19191919), uvec2(0x2b080808, 0x19191919),
uvec2(0x08080819, 0x1919192b), uvec2(0x08081908, 0x1919192b), uvec2(0x08190808, 0x1919192b), uvec2(0x082b192b, 0x1919192b),
uvec2(0x19080808, 0x1919192b), uvec2(0x08080808, 0x19192b08), uvec2(0x0808082b, 0x19192b08), uvec2(0x08081919, 0x19192b08),
uvec2(0x08082b08, 0x19192b08), uvec2(0x08190819, 0x19192b08), uvec2(0x08191908, 0x19192b08), uvec2(0x082b0808, 0x19192b08),
uvec2(0x19080819, 0x19192b08), uvec2(0x19081908, 0x19192b08), uvec2(0x19190808, 0x19192b08), uvec2(0x19192b2b, 0x19192b08),
uvec2(0x2b080808, 0x19192b08), uvec2(0x08080819, 0x19192b19), uvec2(0x08081908, 0x19192b19), uvec2(0x08190808, 0x19192b19),
uvec2(0x19080808, 0x19192b19), uvec2(0x08080808, 0x19192b2b), uvec2(0x08192b19, 0x19192b2b), uvec2(0x2b081919, 0x19192b2b),
uvec2(0x2b2b2b08, 0x19192b2b), uvec2(0x08080819, 0x192b0808), uvec2(0x08081908, 0x192b0808), uvec2(0x0808192b, 0x192b0808),
uvec2(0x08190808, 0x192b0808), uvec2(0x0819082b, 0x192b0808), uvec2(0x08191919, 0x192b0808), uvec2(0x08192b08, 0x192b0808),
uvec2(0x082b0819, 0x192b0808), uvec2(0x082b1908, 0x192b0808), uvec2(0x19080808, 0x192b0808), uvec2(0x19081919, 0x192b0808),
uvec2(0x19082b08, 0x192b0808), uvec2(0x19190819, 0x192b0808), uvec2(0x19191908, 0x192b0808), uvec2(0x192b0808, 0x192b0808),
uvec2(0x2b081908, 0x192b0808), uvec2(0x2b190808, 0x192b0808), uvec2(0x08080808, 0x192b0819), uvec2(0x0808082b, 0x192b0819),
uvec2(0x08081919, 0x192b0819), uvec2(0x08082b08, 0x192b0819), uvec2(0x08190819, 0x192b0819), uvec2(0x08191908, 0x192b0819),
uvec2(0x082b0808, 0x192b0819), uvec2(0x19080819, 0x192b0819), uvec2(0x19081908, 0x192b0819), uvec2(0x19190808, 0x192b0819),
uvec2(0x2b080808, 0x192b0819), uvec2(0x2b192b19, 0x192b0819), uvec2(0x08081908, 0x192b082b), uvec2(0x08190808, 0x192b082b),
uvec2(0x19080808, 0x192b082b), uvec2(0x1919192b, 0x192b082b), uvec2(0x2b2b0819, 0x192b082b), uvec2(0x08080808, 0x192b1908),
uvec2(0x08081919, 0x192b1908), uvec2(0x08082b08, 0x192b1908), uvec2(0x08190819, 0x192b1908), uvec2(0x08191908, 0x192b1908),
uvec2(0x082b0808, 0x192b1908), uvec2(0x19080819, 0x192b1908), uvec2(0x19081908, 0x192b1908), uvec2(0x19190808, 0x192b1908),
uvec2(0x2b080808, 0x192b1908), uvec2(0x08080819, 0x192b1919), uvec2(0x08081908, 0x192b1919), uvec2(0x08190808, 0x192b1919),
uvec2(0x19080808, 0x192b1919), uvec2(0x19082b2b, 0x192b1919), uvec2(0x192b2b08, 0x192b1919), uvec2(0x2b19082b, 0x192b1919),
uvec2(0x08080808, 0x192b192b), uvec2(0x2b191908, 0x192b192b), uvec2(0x08080819, 0x192b2b08), uvec2(0x08081908, 0x192b2b08),
uvec2(0x08190808, 0x192b2b08), uvec2(0x192b1919, 0x192b2b08), uvec2(0x2b192b08, 0x192b2b08), uvec2(0x08080808, 0x192b2b19),
uvec2(0x082b2b2b, 0x192b2b19), uvec2(0x1908082b, 0x192b2b2b), uvec2(0x2b2b0819, 0x192b2b2b), uvec2(0x08080808, 0x2b080808),
uvec2(0x0808082b, 0x2b080808), uvec2(0x08081919, 0x2b080808), uvec2(0x08082b08, 0x2b080808), uvec2(0x08190819, 0x2b080808),
uvec2(0x08191908, 0x2b080808), uvec2(0x08192b19, 0x2b080808), uvec2(0x082b0808, 0x2b080808), uvec2(0x082b1919, 0x2b080808),
uvec2(0x19080819, 0x2b080808), uvec2(0x19081908, 0x2b080808), uvec2(0x19190808, 0x2b080808), uvec2(0x1919082b, 0x2b080808),
uvec2(0x19191919, 0x2b080808), uvec2(0x19192b08, 0x2b080808), uvec2(0x192b0819, 0x2b080808), uvec2(0x2b080808, 0x2b080808),
uvec2(0x2b081919, 0x2b080808), uvec2(0x2b190819, 0x2b080808), uvec2(0x2b191908, 0x2b080808), uvec2(0x08080819, 0x2b080819),
uvec2(0x08081908, 0x2b080819), uvec2(0x08082b19, 0x2b080819), uvec2(0x08190808, 0x2b080819), uvec2(0x0819082b, 0x2b080819),
uvec2(0x08191919, 0x2b080819), uvec2(0x08192b08, 0x2b080819), uvec2(0x082b0819, 0x2b080819), uvec2(0x082b1908, 0x2b080819),
uvec2(0x19080808, 0x2b080819), uvec2(0x1908082b, 0x2b080819), uvec2(0x19081919, 0x2b080819), uvec2(0x19082b08, 0x2b080819),
uvec2(0x19190819, 0x2b080819), uvec2(0x19191908, 0x2b080819), uvec2(0x2b080819, 0x2b080819), uvec2(0x2b081908, 0x2b080819),
uvec2(0x2b190808, 0x2b080819), uvec2(0x2b2b2b19, 0x2b080819), uvec2(0x08080808, 0x2b08082b), uvec2(0x08081919, 0x2b08082b),
uvec2(0x08082b2b, 0x2b08082b), uvec2(0x08190819, 0x2b08082b), uvec2(0x08191908, 0x2b08082b), uvec2(0x19080819, 0x2b08082b),
uvec2(0x19081908, 0x2b08082b), uvec2(0x19190808, 0x2b08082b), uvec2(0x08080819, 0x2b081908), uvec2(0x08081908, 0x2b081908),
uvec2(0x0808192b, 0x2b081908), uvec2(0x08082b19, 0x2b081908), uvec2(0x08190808, 0x2b081908), uvec2(0x0819082b, 0x2b081908),
uvec2(0x08191919, 0x2b081908), uvec2(0x08192b08, 0x2b081908), uvec2(0x082b0819, 0x2b081908), uvec2(0x19080808, 0x2b081908),
uvec2(0x1908082b, 0x2b081908), uvec2(0x19081919, 0x2b081908), uvec2(0x19082b08, 0x2b081908), uvec2(0x19190819, 0x2b081908),
uvec2(0x19191908, 0x2b081908), uvec2(0x192b0808, 0x2b081908), uvec2(0x2b080819, 0x2b081908), uvec2(0x2b081908, 0x2b081908),
uvec2(0x2b190808, 0x2b081908), uvec2(0x08080808, 0x2b081919), uvec2(0x0808082b, 0x2b081919), uvec2(0x08081919, 0x2b081919),
uvec2(0x08082b08, 0x2b081919), uvec2(0x08190819, 0x2b081919), uvec2(0x08191908, 0x2b081919), uvec2(0x082b0808, 0x2b081919),
uvec2(0x19080819, 0x2b081919), uvec2(0x19081908, 0x2b081919), uvec2(0x19190808, 0x2b081919), uvec2(0x2b080808, 0x2b081919),
uvec2(0x2b082b2b, 0x2b081919), uvec2(0x08080819, 0x2b08192b), uvec2(0x08081908, 0x2b08192b), uvec2(0x08190808, 0x2b08192b),
uvec2(0x082b2b19, 0x2b08192b), uvec2(0x19080808, 0x2b08192b), uvec2(0x08080808, 0x2b082b08), uvec2(0x08081919, 0x2b082b08),
uvec2(0x08190819, 0x2b082b08), uvec2(0x08191908, 0x2b082b08), uvec2(0x19080819, 0x2b082b08), uvec2(0x19081908, 0x2b082b08),
uvec2(0x19190808, 0x2b082b08), uvec2(0x2b2b082b, 0x2b082b08), uvec2(0x08080819, 0x2b082b19), uvec2(0x08081908, 0x2b082b19),
uvec2(0x19080808, 0x2b082b19), uvec2(0x192b1919, 0x2b082b19), uvec2(0x082b082b, 0x2b082b2b), uvec2(0x19192b08, 0x2b082b2b),
uvec2(0x19192b2b, 0x2b082b2b), uvec2(0x2b08082b, 0x2b082b2b), uvec2(0x2b2b082b, 0x2b082b2b), uvec2(0x08080819, 0x2b190808),
uvec2(0x08081908, 0x2b190808), uvec2(0x08082b19, 0x2b190808), uvec2(0x08190808, 0x2b190808), uvec2(0x0819082b, 0x2b190808),
uvec2(0x08191919, 0x2b190808), uvec2(0x08192b08, 0x2b190808), uvec2(0x082b1908, 0x2b190808), uvec2(0x19080808, 0x2b190808),
uvec2(0x1908082b, 0x2b190808), uvec2(0x19081919, 0x2b190808), uvec2(0x19082b08, 0x2b190808), uvec2(0x19190819, 0x2b190808),
uvec2(0x19191908, 0x2b190808), uvec2(0x192b0808, 0x2b190808), uvec2(0x2b080819, 0x2b190808), uvec2(0x2b081908, 0x2b190808),
uvec2(0x2b190808, 0x2b190808), uvec2(0x08080808, 0x2b190819), uvec2(0x08081919, 0x2b190819), uvec2(0x08190819, 0x2b190819),
uvec2(0x08191908, 0x2b190819), uvec2(0x19080819, 0x2b190819), uvec2(0x19081908, 0x2b190819), uvec2(0x19190808, 0x2b190819),
uvec2(0x19192b2b, 0x2b190819), uvec2(0x08080819, 0x2b19082b), uvec2(0x08081908, 0x2b19082b), uvec2(0x08190808, 0x2b19082b),
uvec2(0x19080808, 0x2b19082b), uvec2(0x2b2b192b, 0x2b19082b), uvec2(0x08080808, 0x2b191908), uvec2(0x0808082b, 0x2b191908),
uvec2(0x08081919, 0x2b191908), uvec2(0x08082b08, 0x2b191908), uvec2(0x08190819, 0x2b191908), uvec2(0x08191908, 0x2b191908),
uvec2(0x082b0808, 0x2b191908), uvec2(0x19080819, 0x2b191908), uvec2(0x19081908, 0x2b191908), uvec2(0x19190808, 0x2b191908),
uvec2(0x2b080808, 0x2b191908), uvec2(0x2b19192b, 0x2b191908), uvec2(0x08080819, 0x2b191919), uvec2(0x08081908, 0x2b191919),
uvec2(0x08190808, 0x2b191919), uvec2(0x19080808, 0x2b191919), uvec2(0x2b192b08, 0x2b191919), uvec2(0x2b2b0819, 0x2b191919),
uvec2(0x08080808, 0x2b19192b), uvec2(0x1908192b, 0x2b19192b), uvec2(0x192b1908, 0x2b19192b), uvec2(0x08080819, 0x2b192b08),
uvec2(0x08081908, 0x2b192b08), uvec2(0x08190808, 0x2b192b08), uvec2(0x082b192b, 0x2b192b08), uvec2(0x19080808, 0x2b192b08),
uvec2(0x2b2b2b19, 0x2b192b08), uvec2(0x08080808, 0x2b192b19), uvec2(0x19082b19, 0x2b192b19), uvec2(0x1919082b, 0x2b192b19),
uvec2(0x2b190808, 0x2b192b2b), uvec2(0x08080808, 0x2b2b0808), uvec2(0x08081919, 0x2b2b0808), uvec2(0x08082b2b, 0x2b2b0808),
uvec2(0x08191908, 0x2b2b0808), uvec2(0x082b082b, 0x2b2b0808), uvec2(0x082b2b2b, 0x2b2b0808), uvec2(0x19080819, 0x2b2b0808),
uvec2(0x19081908, 0x2b2b0808), uvec2(0x19190808, 0x2b2b0808), uvec2(0x2b2b082b, 0x2b2b0808), uvec2(0x2b2b2b2b, 0x2b2b0808),
uvec2(0x19080808, 0x2b2b0819), uvec2(0x192b1919, 0x2b2b0819), uvec2(0x0808082b, 0x2b2b082b), uvec2(0x08082b2b, 0x2b2b082b),
uvec2(0x082b082b, 0x2b2b082b), uvec2(0x082b2b08, 0x2b2b082b), uvec2(0x082b2b2b, 0x2b2b082b), uvec2(0x2b08082b, 0x2b2b082b),
uvec2(0x2b082b08, 0x2b2b082b), uvec2(0x2b082b2b, 0x2b2b082b), uvec2(0x2b2b2b08, 0x2b2b082b), uvec2(0x08080819, 0x2b2b1908),
uvec2(0x08081908, 0x2b2b1908), uvec2(0x08190808, 0x2b2b1908), uvec2(0x19080808, 0x2b2b1908), uvec2(0x2b082b19, 0x2b2b1908),
uvec2(0x2b2b1908, 0x2b2b1908), uvec2(0x08080808, 0x2b2b1919), uvec2(0x08192b19, 0x2b2b1919), uvec2(0x19190819, 0x2b2b192b),
uvec2(0x08082b2b, 0x2b2b2b08), uvec2(0x082b2b08, 0x2b2b2b08), uvec2(0x2b2b082b, 0x2b2b2b08), uvec2(0x19191908, 0x2b2b2b19),
uvec2(0x2b08192b, 0x2b2b2b19), uvec2(0x08082b08, 0x2b2b2b2b), uvec2(0x08082b2b, 0x2b2b2b2b), uvec2(0x082b0808, 0x2b2b2b2b),
uvec2(0x082b082b, 0x2b2b2b2b), uvec2(0x082b2b08, 0x2b2b2b2b), uvec2(0x2b082b08, 0x2b2b2b2b), uvec2(0x2b2b2b2b, 0x2b2b2b2b)
};
shared uvec2 iq2s_grid[1024];
void init_iq_shmem(uvec3 wgsize)
{
// copy the table into shared memory and sync
for (uint i = gl_LocalInvocationIndex.x; i < iq2s_grid.length(); i += wgsize.x) {
iq2s_grid[i] = iq2s_grid_const[i];
}
barrier();
}
#define QUANT_K QUANT_K_IQ2_S
#define QUANT_R QUANT_R_IQ2_S
#define A_TYPE block_iq2_s
#endif
#define QUANT_K_IQ3_XXS 256
#define QUANT_R_IQ3_XXS 1
struct block_iq3_xxs
{
float16_t d;
uint8_t qs[QUANT_K_IQ3_XXS/4 + QUANT_K_IQ3_XXS/8];
};
struct block_iq3_xxs_packed16
{
float16_t d;
uint16_t qs[QUANT_K_IQ3_XXS/8 + QUANT_K_IQ3_XXS/16];
};
#if defined(DATA_A_IQ3_XXS)
const uint32_t iq3xxs_grid_const[256] = {
0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404,
0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e,
0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c,
0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c,
0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34,
0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c,
0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c,
0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04,
0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c,
0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414,
0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434,
0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c,
0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e,
0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24,
0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24,
0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c,
0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c,
0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14,
0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414,
0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e,
0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404,
0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c,
0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c,
0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14,
0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c,
0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c,
0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14,
0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14,
0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
};
shared uint32_t iq3xxs_grid[256];
void init_iq_shmem(uvec3 wgsize)
{
// copy the table into shared memory and sync
for (uint i = gl_LocalInvocationIndex.x; i < iq3xxs_grid.length(); i += wgsize.x) {
iq3xxs_grid[i] = iq3xxs_grid_const[i];
}
barrier();
}
#define QUANT_K QUANT_K_IQ3_XXS
#define QUANT_R QUANT_R_IQ3_XXS
#define A_TYPE block_iq3_xxs
#define A_TYPE_PACKED16 block_iq3_xxs_packed16
#endif
#define QUANT_K_IQ3_S 256
#define QUANT_R_IQ3_S 1
struct block_iq3_s
{
float16_t d;
uint8_t qs[QUANT_K_IQ3_S/4];
uint8_t qh[QUANT_K_IQ3_S/32];
uint8_t signs[QUANT_K_IQ3_S/8];
uint8_t scales[QUANT_K_IQ3_S/64];
};
struct block_iq3_s_packed16
{
float16_t d;
uint16_t qs[QUANT_K_IQ3_S/4/2];
uint16_t qh[QUANT_K_IQ3_S/32/2];
uint16_t signs[QUANT_K_IQ3_S/8/2];
uint16_t scales[QUANT_K_IQ3_S/64/2];
};
#if defined(DATA_A_IQ3_S)
const uint32_t iq3s_grid_const[512] = {
0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305,
0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905,
0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09,
0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b,
0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b,
0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d,
0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03,
0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505,
0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03,
0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901,
0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d,
0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303,
0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501,
0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105,
0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505,
0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101,
0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707,
0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b,
0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01,
0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f,
0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305,
0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103,
0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509,
0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503,
0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b,
0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f,
0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f,
0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f,
0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109,
0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f,
0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509,
0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501,
0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303,
0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f,
0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907,
0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703,
0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03,
0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01,
0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01,
0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903,
0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505,
0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b,
0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107,
0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509,
0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303,
0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103,
0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05,
0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b,
0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f,
0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701,
0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909,
0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305,
0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d,
0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b,
0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d,
0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307,
0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09,
0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309,
0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709,
0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f,
0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303,
0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503,
0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b,
0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
};
shared uint32_t iq3s_grid[512];
void init_iq_shmem(uvec3 wgsize)
{
// copy the table into shared memory and sync
for (uint i = gl_LocalInvocationIndex.x; i < iq3s_grid.length(); i += wgsize.x) {
iq3s_grid[i] = iq3s_grid_const[i];
}
barrier();
}
#define QUANT_K QUANT_K_IQ3_S
#define QUANT_R QUANT_R_IQ3_S
#define A_TYPE block_iq3_s
#define A_TYPE_PACKED16 block_iq3_s_packed16
#endif
#define QUANT_K_IQ4_NL 32
#define QUANT_R_IQ4_NL 2
@ -318,11 +1050,11 @@ const int8_t kvalues_iq4nl_const[16] = {
shared FLOAT_TYPE kvalues_iq4nl[16];
void init_iq4nl_shmem()
void init_iq_shmem(uvec3 wgsize)
{
// copy the table into shared memory and sync
if (gl_LocalInvocationIndex.x < 16) {
kvalues_iq4nl[gl_LocalInvocationIndex.x] = FLOAT_TYPE(kvalues_iq4nl_const[gl_LocalInvocationIndex.x]);
for (uint i = gl_LocalInvocationIndex.x; i < kvalues_iq4nl.length(); i += wgsize.x) {
kvalues_iq4nl[i] = FLOAT_TYPE(kvalues_iq4nl_const[i]);
}
barrier();
}

View file

@ -17,13 +17,13 @@
#include <cstring>
#include <cstdlib>
#include <cassert>
#include <algorithm>
#include <sys/stat.h>
#include <sys/types.h>
#ifdef _WIN32
#include <windows.h>
#include <direct.h> // For _mkdir on Windows
#include <algorithm> // For std::replace on w64devkit
#else
#include <unistd.h>
#include <sys/wait.h>
@ -55,6 +55,11 @@ const std::vector<std::string> type_names = {
"q4_k",
"q5_k",
"q6_k",
"iq2_xxs",
"iq2_xs",
"iq2_s",
"iq3_xxs",
"iq3_s",
"iq4_nl"
};
@ -502,6 +507,7 @@ void write_output_files() {
fprintf(hdr, "#include <cstdint>\n\n");
fprintf(src, "#include \"%s\"\n\n", basename(target_hpp).c_str());
std::sort(shader_fnames.begin(), shader_fnames.end());
for (const auto& pair : shader_fnames) {
const std::string& name = pair.first;
#ifdef _WIN32

View file

@ -128,6 +128,10 @@ static void ggml_print_backtrace_symbols(void) {
#endif
static void ggml_print_backtrace(void) {
const char * GGML_NO_BACKTRACE = getenv("GGML_NO_BACKTRACE");
if (GGML_NO_BACKTRACE) {
return;
}
char attach[32];
snprintf(attach, sizeof(attach), "attach %d", getpid());
int pid = fork();
@ -5339,7 +5343,7 @@ static void ggml_compute_backward(
} break;
case GGML_OP_MUL: {
if (src0_needs_grads) {
ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, src1, grad));
ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, src1));
}
if (src1_needs_grads) {
struct ggml_tensor * tmp = ggml_mul(ctx, src0, grad);
@ -5431,21 +5435,25 @@ static void ggml_compute_backward(
// src1.shape [n,p,qq,rr]
if (src0_needs_grads) {
struct ggml_tensor * s1_tg =
GGML_ASSERT(grad->ne[2] == src1->ne[2]);
GGML_ASSERT(grad->ne[3] == src1->ne[3]);
struct ggml_tensor * tmp =
ggml_out_prod(ctx, // [n,m,qq,rr]
src1, // [n,p,qq,rr]
grad); // [m,p,qq,rr]
const int64_t qq = s1_tg->ne[2];
const int64_t rr = s1_tg->ne[3];
const int64_t q1 = src0->ne[2];
const int64_t r1 = src0->ne[3];
const bool ne2_broadcasted = qq > q1;
const bool ne3_broadcasted = rr > r1;
if (ne2_broadcasted || ne3_broadcasted) {
// sum broadcast repetitions of s1_tg into shape of src0
s1_tg = ggml_repeat_back(ctx, s1_tg, src0);
if (!ggml_are_same_shape(tmp, src0)) {
GGML_ASSERT(tmp->ne[0] == src0->ne[0]);
GGML_ASSERT(tmp->ne[1] == src0->ne[1]);
GGML_ASSERT(tmp->ne[3] == 1);
const int64_t nr2 = tmp->ne[2] / src0->ne[2];
const size_t nb2 = tmp->nb[2] * nr2;
const size_t nb3 = tmp->nb[2];
tmp = ggml_view_4d(ctx, tmp, src0->ne[0], src0->ne[1], src0->ne[2], nr2, tmp->nb[1], nb2, nb3, 0);
tmp = ggml_repeat_back(ctx, tmp, src0);
}
ggml_add_or_set(ctx, cgraph, isrc0, s1_tg /*= [n,m,q1,r1]*/);
ggml_add_or_set(ctx, cgraph, isrc0, tmp);
}
if (src1_needs_grads) {
ggml_add_or_set(ctx, cgraph, isrc1,
@ -5514,7 +5522,9 @@ static void ggml_compute_backward(
if (src0_needs_grads) {
GGML_ASSERT(!cgraph->grads[isrc0] || ggml_is_contiguous(cgraph->grads[isrc0]));
GGML_ASSERT(ggml_is_contiguous(grad));
ggml_add_or_set(ctx, cgraph, isrc0, grad);
GGML_ASSERT(ggml_nelements(tensor) == ggml_nelements(src0));
ggml_add_or_set(ctx, cgraph, isrc0,
ggml_are_same_shape(tensor, src0) ? grad : ggml_reshape(ctx, grad, src0));
}
} break;
case GGML_OP_RESHAPE: {

View file

@ -1 +1 @@
d92321c0d151fe73a47d89738c7c3091ac904297
32f0b85987396945afea2291d5f4c5862434292b

View file

@ -819,7 +819,7 @@ void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps
for (const auto & file : files) {
auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU));
auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa");
std::unique_ptr<llama_mmap> mapping(new llama_mmap(file.get(), prefetch ? -1 : 0, is_numa_fn()));
std::unique_ptr<llama_mmap> mapping = std::make_unique<llama_mmap>(file.get(), prefetch ? -1 : 0, is_numa_fn());
mmaps_used.emplace_back(mapping->size(), 0);
if (mlock_mmaps) {
std::unique_ptr<llama_mlock> mlock_mmap(new llama_mlock());

View file

@ -1303,10 +1303,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, (int)n_layer + 1);
auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev {
if (il < i_gpu_start || (il - i_gpu_start) >= act_gpu_layers) {
LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s\n", il, ggml_backend_dev_name(cpu_dev));
return {cpu_dev, &pimpl->cpu_buft_list};
}
const int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + n_devices(), float(il - i_gpu_start)/act_gpu_layers) - splits.begin();
auto * dev = devices.at(layer_gpu);
LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s\n", il, ggml_backend_dev_name(dev));
return {dev, &pimpl->gpu_buft_list.at(dev)};
};

View file

@ -1245,8 +1245,13 @@ struct llama_vocab::impl {
std::vector<llama_token> cache_special_tokens;
std::vector<std::string> cache_token_to_piece; // llama_token_to_piece(special = true);
std::map<std::pair<std::string, std::string>, int> bpe_ranks;
struct pair_hash {
size_t operator()(const std::pair<std::string, std::string> & p) const {
return std::hash<std::string>{}(p.first) ^ //create some hash for pair
(std::hash<std::string>{}(p.second) << 1);
}
};
std::unordered_map<std::pair<std::string, std::string>, int, pair_hash> bpe_ranks;
// set of all tokens that cause "end of generation"
std::set<llama_token> special_eog_ids;

View file

@ -7700,17 +7700,13 @@ struct llm_build_context {
1
);
struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att));
ggml_build_forward_expand(
gf,
ggml_cpy(
ctx0,
wkv_states,
ggml_view_1d(
ctx0,
kv_self.v_l[il],
hparams.n_embd_v_s() * n_seqs,
hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
)
ggml_view_1d(ctx0, last_norm_att, n_embd * n_seqs, 0),
ggml_view_1d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self.k_l[il]))
)
);
@ -8432,13 +8428,141 @@ static enum ggml_status llama_graph_compute(
return status;
}
static int llama_prepare_sbatch(
llama_context & lctx,
const llama_batch & batch,
uint32_t & n_outputs) {
const auto & model = lctx.model;
const auto & hparams = model.hparams;
const auto & cparams = lctx.cparams;
const uint32_t n_tokens_all = batch.n_tokens;
const int64_t n_embd = hparams.n_embd;
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
if (batch.token) {
for (uint32_t i = 0; i < n_tokens_all; ++i) {
if (batch.token[i] < 0 || uint32_t(batch.token[i]) >= model.vocab.n_tokens()) {
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
return -1;
}
}
}
GGML_ASSERT(n_tokens_all <= cparams.n_batch);
GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
lctx.n_queued_tokens += n_tokens_all;
lctx.embd_seq.clear();
// count outputs
if (batch.logits && !embd_pooled) {
for (uint32_t i = 0; i < n_tokens_all; ++i) {
n_outputs += batch.logits[i] != 0;
}
} else if (lctx.logits_all || embd_pooled) {
n_outputs = n_tokens_all;
} else {
// keep last output only
n_outputs = 1;
}
lctx.sbatch.from_batch(batch, n_embd,
/* simple_split */ !lctx.kv_self.recurrent,
/* logits_all */ n_outputs == n_tokens_all);
// reserve output buffer
if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs);
return -2;
};
return 0;
}
static int llama_prepare_ubatch(
llama_context & lctx,
llama_kv_slot_restorer & kv_slot_restorer,
llama_ubatch & ubatch,
const uint32_t n_outputs,
const uint32_t n_tokens_all) {
GGML_ASSERT(lctx.sbatch.n_tokens > 0);
auto & kv_self = lctx.kv_self;
const auto & cparams = lctx.cparams;
const auto & hparams = lctx.model.hparams;
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
if (lctx.kv_self.recurrent) {
if (embd_pooled) {
// Pooled embeddings cannot be split across ubatches (yet)
ubatch = lctx.sbatch.split_seq(cparams.n_ubatch);
} else {
// recurrent model architectures are easier to implement
// with equal-length sequences
ubatch = lctx.sbatch.split_equal(cparams.n_ubatch);
}
} else {
ubatch = lctx.sbatch.split_simple(cparams.n_ubatch);
}
// count the outputs in this u_batch
{
int32_t n_outputs_new = 0;
if (n_outputs == n_tokens_all) {
n_outputs_new = ubatch.n_tokens;
} else {
GGML_ASSERT(ubatch.output);
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
n_outputs_new += int32_t(ubatch.output[i] != 0);
}
}
// needs to happen before the graph is built
lctx.n_outputs = n_outputs_new;
}
// non-causal masks do not use the KV cache
if (hparams.causal_attn) {
llama_kv_cache_update(&lctx);
// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
kv_self.head = 0;
}
const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
if (!slot) {
return 1;
}
kv_slot_restorer.save(slot);
if (!kv_self.recurrent) {
// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
const uint32_t pad = llama_kv_cache_get_padding(cparams);
kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
//kv_self.n = llama_kv_cache_cell_max(kv_self);
}
}
return 0;
}
// decode a batch of tokens by evaluating the transformer
// in case of unsuccessful decoding (error or warning),
// the kv_cache state will be returned to its original state
// (for non-recurrent models) or cleaned (for recurrent models)
//
// - lctx: llama context
// - batch: batch to evaluate
// - inp_batch: batch to evaluate
//
// return 0 on success
// return positive int on warning
@ -8455,37 +8579,18 @@ static int llama_decode_impl(
return -1;
}
// temporary allocate memory for the input batch if needed
// temporarily allocate memory for the input batch if needed
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
const llama_batch & batch = batch_allocr.batch;
const uint32_t n_tokens_all = batch.n_tokens;
const auto & model = lctx.model;
const auto & vocab = model.vocab;
const auto & hparams = model.hparams;
const auto & cparams = lctx.cparams;
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
if (batch.token) {
for (uint32_t i = 0; i < n_tokens_all; ++i) {
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
return -1;
}
}
}
GGML_ASSERT(n_tokens_all <= cparams.n_batch);
GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
if (lctx.t_compute_start_us == 0) {
lctx.t_compute_start_us = ggml_time_us();
}
lctx.n_queued_tokens += n_tokens_all;
auto & kv_self = lctx.kv_self;
llama_kv_slot_restorer kv_slot_restorer(kv_self);
@ -8495,99 +8600,27 @@ static int llama_decode_impl(
uint32_t n_outputs = 0;
uint32_t n_outputs_prev = 0;
const auto n_ubatch = cparams.n_ubatch;
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
lctx.embd_seq.clear();
// count outputs
if (batch.logits && !embd_pooled) {
for (uint32_t i = 0; i < n_tokens_all; ++i) {
n_outputs += batch.logits[i] != 0;
{
const int ret = llama_prepare_sbatch(lctx, batch, n_outputs);
if (ret != 0) {
return ret;
}
} else if (lctx.logits_all || embd_pooled) {
n_outputs = n_tokens_all;
} else {
// keep last output only
n_outputs = 1;
}
lctx.sbatch.from_batch(batch, n_embd,
/* simple_split */ !kv_self.recurrent,
/* logits_all */ n_outputs == n_tokens_all);
// reserve output buffer
if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs);
return -2;
};
while (lctx.sbatch.n_tokens > 0) {
llama_ubatch ubatch;
if (kv_self.recurrent) {
if (embd_pooled) {
// Pooled embeddings cannot be split across ubatches (yet)
ubatch = lctx.sbatch.split_seq(n_ubatch);
} else {
// recurrent model architectures are easier to implement
// with equal-length sequences
ubatch = lctx.sbatch.split_equal(n_ubatch);
}
} else {
ubatch = lctx.sbatch.split_simple(n_ubatch);
}
const uint32_t n_tokens = ubatch.n_tokens;
// count the outputs in this u_batch
{
int32_t n_outputs_new = 0;
if (n_outputs == n_tokens_all) {
n_outputs_new = n_tokens;
} else {
GGML_ASSERT(ubatch.output);
for (uint32_t i = 0; i < n_tokens; i++) {
n_outputs_new += (int32_t) (ubatch.output[i] != 0);
}
const int ret = llama_prepare_ubatch(lctx, kv_slot_restorer, ubatch, n_outputs, batch.n_tokens);
if (ret != 0) {
return ret;
}
// needs to happen before the graph is built
lctx.n_outputs = n_outputs_new;
}
int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
const int n_threads = ubatch.n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
ggml_threadpool_t threadpool = ubatch.n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
GGML_ASSERT(n_threads > 0);
// non-causal masks do not use the KV cache
if (hparams.causal_attn) {
llama_kv_cache_update(&lctx);
// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
if (kv_self.head > kv_self.used + 2*n_tokens) {
kv_self.head = 0;
}
const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
if (!slot) {
return 1;
}
kv_slot_restorer.save(slot);
if (!kv_self.recurrent) {
// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
const uint32_t pad = llama_kv_cache_get_padding(cparams);
kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
//kv_self.n = llama_kv_cache_cell_max(kv_self);
}
}
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
ggml_backend_sched_reset(lctx.sched.get());
@ -8640,7 +8673,7 @@ static int llama_decode_impl(
// update the kv ring buffer
{
kv_self.head += n_tokens;
kv_self.head += ubatch.n_tokens;
// Ensure kv cache head points to a valid index.
if (kv_self.head >= kv_self.size) {
@ -9405,6 +9438,7 @@ static struct llama_model * llama_model_load_from_file_impl(
model->devices.push_back(*dev);
}
} else {
std::vector<ggml_backend_dev_t> rpc_servers;
// use all available devices
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
@ -9415,10 +9449,19 @@ static struct llama_model * llama_model_load_from_file_impl(
break;
case GGML_BACKEND_DEVICE_TYPE_GPU:
model->devices.push_back(dev);
ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev);
if (ggml_backend_reg_name(reg) == std::string("RPC")) {
rpc_servers.push_back(dev);
} else {
model->devices.push_back(dev);
}
break;
}
}
// add RPC servers at the front of the list
if (!rpc_servers.empty()) {
model->devices.insert(model->devices.begin(), rpc_servers.begin(), rpc_servers.end());
}
}
// if using single GPU mode, remove all except the main GPU

View file

@ -1302,6 +1302,59 @@ struct test_repeat : public test_case {
}
};
// GGML_OP_REPEAT_BACK
struct test_repeat_back : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
const std::array<int, 4> nr;
const bool v; // whether src is a noncontiguous view
std::string vars() override {
return VARS_TO_STR4(type, ne, nr, v);
}
size_t op_size(ggml_tensor * t) override {
return ggml_nbytes(t) * 2;
}
test_repeat_back(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {8, 6, 4, 2},
std::array<int, 4> nr = {2, 2, 2, 2},
bool v = false)
: type(type), ne(ne), nr(nr), v(v) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * src = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);
ggml_set_name(src, "src");
if (v) {
GGML_ASSERT(ne[0] % 2 == 0);
GGML_ASSERT(ne[1] % 2 == 0);
GGML_ASSERT(ne[2] % 2 == 0);
GGML_ASSERT(ne[3] % 2 == 0);
GGML_ASSERT(nr[0] % 2 == 0 || nr[0] == 1);
GGML_ASSERT(nr[1] % 2 == 0 || nr[1] == 1);
GGML_ASSERT(nr[2] % 2 == 0 || nr[2] == 1);
GGML_ASSERT(nr[3] % 2 == 0 || nr[3] == 1);
const int64_t ne00 = nr[0] == 1 ? src->ne[0] : src->ne[0] / 2;
const int64_t ne01 = nr[1] == 1 ? src->ne[1] : src->ne[1] / 2;
const int64_t ne02 = nr[2] == 1 ? src->ne[2] : src->ne[2] / 2;
const int64_t ne03 = nr[3] == 1 ? src->ne[3] : src->ne[3] / 2;
src = ggml_view_4d(ctx, src, ne00, ne01, ne02, ne03, src->nb[1], src->nb[2], src->nb[3], 0);
}
ggml_tensor * target = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(target, "target");
ggml_tensor * out = ggml_repeat_back(ctx, src, target);
ggml_set_name(out, "out");
return out;
}
};
// GGML_OP_DUP
struct test_dup : public test_case {
const ggml_type type;
@ -1849,6 +1902,10 @@ struct test_mul_mat : public test_case {
return 5e-4;
}
int64_t grad_nmax() override {
return 20000;
}
uint64_t op_flops(ggml_tensor * t) override {
GGML_UNUSED(t);
return 2 * m * n * k * bs[0] * nr[0] * bs[1] * nr[1];
@ -1878,8 +1935,12 @@ struct test_mul_mat : public test_case {
a = ggml_new_tensor_4d(ctx, type_a, ne_a[per[0]], ne_a[per[1]], ne_a[per[2]], ne_a[per[3]]);
b = ggml_new_tensor_4d(ctx, type_b, ne_b[per[0]], ne_b[per[1]], ne_b[per[2]], ne_b[per[3]]);
ggml_set_param(ctx, a);
ggml_set_param(ctx, b);
if (!ggml_is_quantized(type_a)) {
if (bs[1] == 1 && nr[1] == 1) {
ggml_set_param(ctx, a);
}
ggml_set_param(ctx, b);
}
ggml_set_name(a, "a");
ggml_set_name(b, "b");
@ -1890,8 +1951,12 @@ struct test_mul_mat : public test_case {
} else {
a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]);
b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
ggml_set_param(ctx, a);
ggml_set_param(ctx, b);
if (!ggml_is_quantized(type_a)) {
if (bs[1] == 1 && nr[1] == 1) {
ggml_set_param(ctx, a);
}
ggml_set_param(ctx, b);
}
ggml_set_name(a, "a");
ggml_set_name(b, "b");
}
@ -2282,11 +2347,12 @@ struct test_soft_max : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
const bool mask;
const ggml_type m_prec;
const float scale;
const float max_bias;
std::string vars() override {
return VARS_TO_STR5(type, ne, mask, scale, max_bias);
return VARS_TO_STR6(type, ne, mask, m_prec, scale, max_bias);
}
// the 1024 test with bias occasionally fails:
@ -2298,9 +2364,10 @@ struct test_soft_max : public test_case {
test_soft_max(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 5, 4, 3},
bool mask = false,
ggml_type m_prec = GGML_TYPE_F32,
float scale = 1.0f,
float max_bias = 0.0f)
: type(type), ne(ne), mask(mask), scale(scale), max_bias(max_bias) {}
: type(type), ne(ne), mask(mask), m_prec(m_prec), scale(scale), max_bias(max_bias) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
@ -2309,7 +2376,7 @@ struct test_soft_max : public test_case {
ggml_tensor * mask = nullptr;
if (this->mask) {
mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne[0], ne[1]);
mask = ggml_new_tensor_2d(ctx, m_prec, ne[0], ne[1]);
ggml_set_name(mask, "mask");
}
@ -3798,6 +3865,16 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_repeat(GGML_TYPE_I16, {10, 5, 4, ne3}, {1, 1, 1, 2}));
}
for (bool view : {false, true}) {
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 1}, view));
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {2, 1, 1, 1}, view));
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 2, 1, 1}, view));
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 2, 1}, view));
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 2}, view));
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I32, {8, 6, 4, 2}, {2, 1, 1, 1}, view));
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I16, {8, 6, 4, 2}, {1, 1, 1, 2}, view));
}
test_cases.emplace_back(new test_dup(GGML_TYPE_F32));
test_cases.emplace_back(new test_dup(GGML_TYPE_F16));
test_cases.emplace_back(new test_dup(GGML_TYPE_I32));
@ -3909,38 +3986,35 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));
test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 128, 4));
for (int i = 1; i < 9; ++i) {
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_1, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_1, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_K, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_K, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q6_K, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_IQ4_NL, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
for (ggml_type type_a : all_types) {
for (int i = 1; i < 10; ++i) {
test_cases.emplace_back(new test_mul_mat(type_a, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
}
}
#if 1
for (ggml_type type_a : base_types) {
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
// test cases without permutation
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 1}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {2, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 1}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 1}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {2, 2}));
// test cases with permutation
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
@ -4078,17 +4152,28 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (float scale : {1.0f, 0.1f}) {
for (int64_t ne0 : {16, 1024}) {
for (int64_t ne1 : {16, 1024}) {
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, scale, max_bias));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, scale, max_bias));
if (mask) {
for (ggml_type m_prec : {GGML_TYPE_F32, GGML_TYPE_F16}) {
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, m_prec, scale, max_bias));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, scale, max_bias));
}
} else {
/* The precision of mask here doesn't matter as boolean mask is false */
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, GGML_TYPE_F32, scale, max_bias));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, GGML_TYPE_F32, scale, max_bias));
}
}
}
}
}
}
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, 0.1f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 8.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F32, 0.1f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F16, 0.1f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, GGML_TYPE_F32, 0.1f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32, 0.1f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16, 0.1f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32, 0.1f, 8.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16, 0.1f, 8.0f));
for (float max_bias : {0.0f, 8.0f}) {
for (float scale : {1.0f, 0.1f}) {
@ -4224,13 +4309,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, 1.0f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, 1.0f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, 1.0f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, 1.0f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, 1.0f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, 1.0f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, 1.0f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 10, 1, 1}));
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));